Skip to content

Commit

Permalink
Docs update (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Oct 10, 2019
1 parent 632c43e commit ec51720
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions torch_struct/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def log_prob(self, value):
Compute log probability over values :math:`p(z)`.
Parameters:
value (tensor): sample_sample x batch_shape x event_shapesss
value (tensor): sample_shape x batch_shape x event_shape
"""

d = value.dim()
Expand All @@ -55,14 +55,14 @@ def entropy(self):
Compute entropy for distribution :math:`H[z]`.
Returns:
entropy - batch_shape
entropy (*batch_shape*)
"""
return self.struct(EntropySemiring).sum(self.log_potentials, self.lengths)

@lazy_property
def argmax(self):
r"""
Compute an argmax for distribution :math:`\\arg\max p(z)`.
Compute an argmax for distribution :math:`\arg\max p(z)`.
Returns:
argmax (*batch_shape x event_shape*)
Expand Down Expand Up @@ -100,7 +100,7 @@ def sample(self, sample_shape=torch.Size()):
sample_shape (int): number of samples
Returns:
samples - sample_shape x batch_shape x event_shape
samples (*sample_shape x batch_shape x event_shape*)
"""
assert len(sample_shape) == 1
nsamples = sample_shape[0]
Expand Down Expand Up @@ -128,7 +128,7 @@ def enumerate_support(self, expand=True):
Compute the full exponential enumeration set.
Returns:
(enum, enum_lengths) - tuple cardinality x batch_shape x event_shape
(enum, enum_lengths) - *tuple cardinality x batch_shape x event_shape*
"""
_, _, edges, enum_lengths = self.struct().enumerate(
self.log_potentials, self.lengths
Expand All @@ -145,7 +145,7 @@ class LinearChainCRF(StructDistribution):
Event shape is of the form:
Parameters:
log_potentials (tensor) : event shape ((N-1) x C x C ) e.g.
log_potentials (tensor) : event shape (*(N-1) x C x C*) e.g.
:math:`\phi(n, z_{n+1}, z_{n})`
lengths (long tensor) : batch_shape integers for length masking.
Expand All @@ -163,7 +163,7 @@ class SemiMarkovCRF(StructDistribution):
Event shape is of the form:
Parameters:
log_potentials : event shape (N x K x C x C) e.g.
log_potentials : event shape (*N x K x C x C*) e.g.
:math:`\phi(n, k, z_{n+1}, z_{n})`
lengths (long tensor) : batch shape integers for length masking.
Expand All @@ -180,13 +180,13 @@ class DependencyCRF(StructDistribution):
Event shape is of the form:
Parameters:
log_potentials (tensor) : event shape (N x N) head, child with
log_potentials (tensor) : event shape (*N x N*) head, child with
arc scores with root scores on diagonal e.g.
:math:`\phi(i, j)` where :math:`\phi(i, i)` is (root, i).
lengths (long tensor) : batch shape integers for length masking.
Compact representation: N long tensor in [0, N] (indexing is +1)
Compact representation: N long tensor in [0, .. N] (indexing is +1)
"""

struct = DepTree
Expand All @@ -199,11 +199,11 @@ class TreeCRF(StructDistribution):
Event shape is of the form:
Parameters:
log_potentials (tensor) : event_shape N x N x NT, e.g.
log_potentials (tensor) : event_shape (*N x N x NT*), e.g.
:math:`\phi(i, j, nt)`
lengths (long tensor) : batch shape integers for length masking.
Compact representation: N x N x NT long tensor (Same)
Compact representation: *N x N x NT* long tensor (Same)
"""
struct = CKY_CRF

Expand All @@ -217,12 +217,12 @@ class SentCFG(StructDistribution):
Parameters:
log_potentials (tuple) : event tuple with event shapes
terms (N x T)
rules (NT x (NT+T) x (NT+T))
root (NT)
terms (*N x T*)
rules (*NT x (NT+T) x (NT+T)*)
root (*NT*)
lengths (long tensor) : batch shape integers for length masking.
Compact representation: N x N x NT long tensor
Compact representation: *N x N x NT* long tensor
"""

struct = CKY
Expand Down

0 comments on commit ec51720

Please sign in to comment.