Skip to content

Commit

Permalink
add initial pass at arg_constraints and arg validation. (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
erip committed Jan 30, 2022
1 parent 29faea2 commit 7146de5
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions torch_struct/distributions.py
@@ -1,4 +1,5 @@
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property
from .linearchain import LinearChain
Expand Down Expand Up @@ -36,15 +37,18 @@ class StructDistribution(Distribution):
log_potentials (tensor, batch_shape x event_shape) : log-potentials :math:`\phi`
lengths (long tensor, batch_shape) : integers for length masking
"""
validate_args = False
arg_constraints = {
"log_potentials": constraints.real,
"lengths": constraints.nonnegative_integer
}

def __init__(self, log_potentials, lengths=None, args={}):
def __init__(self, log_potentials, lengths=None, args={}, validate_args=False):
batch_shape = log_potentials.shape[:1]
event_shape = log_potentials.shape[1:]
self.log_potentials = log_potentials
self.lengths = lengths
self.args = args
super().__init__(batch_shape=batch_shape, event_shape=event_shape)
super().__init__(batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args)

def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)
Expand Down Expand Up @@ -295,11 +299,17 @@ class AlignmentCRF(StructDistribution):
"""
struct = Alignment

def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
arg_constraints = {
"log_potentials": constraints.real,
"local": constraints.boolean,
"max_gap": constraints.nonnegative_integer,
"lengths": constraints.nonnegative_integer
}

def __init__(self, log_potentials, local=False, lengths=None, max_gap=None, validate_args=False):
self.local = local
self.max_gap = max_gap
super().__init__(log_potentials, lengths)
super().__init__(log_potentials, lengths, validate_args=validate_args)

def _struct(self, sr=None):
return self.struct(
Expand All @@ -324,9 +334,9 @@ class HMM(StructDistribution):
Implemented as a special case of linear chain CRF.
"""

def __init__(self, transition, emission, init, observations, lengths=None):
def __init__(self, transition, emission, init, observations, lengths=None, validate_args=False):
log_potentials = HMM.struct.hmm(transition, emission, init, observations)
super().__init__(log_potentials, lengths)
super().__init__(log_potentials, lengths, validate_args=validate_args)

struct = LinearChain

Expand Down Expand Up @@ -380,8 +390,8 @@ class DependencyCRF(StructDistribution):
"""

def __init__(self, log_potentials, lengths=None, args={}, multiroot=True):
super(DependencyCRF, self).__init__(log_potentials, lengths, args)
def __init__(self, log_potentials, lengths=None, args={}, multiroot=True, validate_args=False):
super(DependencyCRF, self).__init__(log_potentials, lengths, args, validate_args=validate_args)
self.struct = DepTree
setattr(self.struct, "multiroot", multiroot)

Expand Down Expand Up @@ -436,13 +446,13 @@ class SentCFG(StructDistribution):

struct = CKY

def __init__(self, log_potentials, lengths=None):
def __init__(self, log_potentials, lengths=None, validate_args=False):
batch_shape = log_potentials[0].shape[:1]
event_shape = log_potentials[0].shape[1:]
self.log_potentials = log_potentials
self.lengths = lengths
super(StructDistribution, self).__init__(
batch_shape=batch_shape, event_shape=event_shape
batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args
)


Expand All @@ -468,8 +478,12 @@ class NonProjectiveDependencyCRF(StructDistribution):
"""

def __init__(self, log_potentials, lengths=None, args={}, multiroot=False):
super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args)
arg_constraints = {
"log_potentials": constraints.real
}

def __init__(self, log_potentials, lengths=None, args={}, multiroot=False, validate_args=False):
super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args, validate_args=validate_args)
self.multiroot = multiroot

@lazy_property
Expand Down

0 comments on commit 7146de5

Please sign in to comment.