From ec5172027b0d5040cab392cb36301338459c54b1 Mon Sep 17 00:00:00 2001 From: srush Date: Thu, 10 Oct 2019 10:52:02 -0700 Subject: [PATCH] Docs update (#18) --- torch_struct/distributions.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index e5b06e36..7fa731f0 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -37,7 +37,7 @@ def log_prob(self, value): Compute log probability over values :math:`p(z)`. Parameters: - value (tensor): sample_sample x batch_shape x event_shapesss + value (tensor): sample_shape x batch_shape x event_shape """ d = value.dim() @@ -55,14 +55,14 @@ def entropy(self): Compute entropy for distribution :math:`H[z]`. Returns: - entropy - batch_shape + entropy (*batch_shape*) """ return self.struct(EntropySemiring).sum(self.log_potentials, self.lengths) @lazy_property def argmax(self): r""" - Compute an argmax for distribution :math:`\\arg\max p(z)`. + Compute an argmax for distribution :math:`\arg\max p(z)`. Returns: argmax (*batch_shape x event_shape*) @@ -100,7 +100,7 @@ def sample(self, sample_shape=torch.Size()): sample_shape (int): number of samples Returns: - samples - sample_shape x batch_shape x event_shape + samples (*sample_shape x batch_shape x event_shape*) """ assert len(sample_shape) == 1 nsamples = sample_shape[0] @@ -128,7 +128,7 @@ def enumerate_support(self, expand=True): Compute the full exponential enumeration set. Returns: - (enum, enum_lengths) - tuple cardinality x batch_shape x event_shape + (enum, enum_lengths) - *tuple cardinality x batch_shape x event_shape* """ _, _, edges, enum_lengths = self.struct().enumerate( self.log_potentials, self.lengths @@ -145,7 +145,7 @@ class LinearChainCRF(StructDistribution): Event shape is of the form: Parameters: - log_potentials (tensor) : event shape ((N-1) x C x C ) e.g. + log_potentials (tensor) : event shape (*(N-1) x C x C*) e.g. :math:`\phi(n, z_{n+1}, z_{n})` lengths (long tensor) : batch_shape integers for length masking. @@ -163,7 +163,7 @@ class SemiMarkovCRF(StructDistribution): Event shape is of the form: Parameters: - log_potentials : event shape (N x K x C x C) e.g. + log_potentials : event shape (*N x K x C x C*) e.g. :math:`\phi(n, k, z_{n+1}, z_{n})` lengths (long tensor) : batch shape integers for length masking. @@ -180,13 +180,13 @@ class DependencyCRF(StructDistribution): Event shape is of the form: Parameters: - log_potentials (tensor) : event shape (N x N) head, child with + log_potentials (tensor) : event shape (*N x N*) head, child with arc scores with root scores on diagonal e.g. :math:`\phi(i, j)` where :math:`\phi(i, i)` is (root, i). lengths (long tensor) : batch shape integers for length masking. - Compact representation: N long tensor in [0, N] (indexing is +1) + Compact representation: N long tensor in [0, .. N] (indexing is +1) """ struct = DepTree @@ -199,11 +199,11 @@ class TreeCRF(StructDistribution): Event shape is of the form: Parameters: - log_potentials (tensor) : event_shape N x N x NT, e.g. + log_potentials (tensor) : event_shape (*N x N x NT*), e.g. :math:`\phi(i, j, nt)` lengths (long tensor) : batch shape integers for length masking. - Compact representation: N x N x NT long tensor (Same) + Compact representation: *N x N x NT* long tensor (Same) """ struct = CKY_CRF @@ -217,12 +217,12 @@ class SentCFG(StructDistribution): Parameters: log_potentials (tuple) : event tuple with event shapes - terms (N x T) - rules (NT x (NT+T) x (NT+T)) - root (NT) + terms (*N x T*) + rules (*NT x (NT+T) x (NT+T)*) + root (*NT*) lengths (long tensor) : batch shape integers for length masking. - Compact representation: N x N x NT long tensor + Compact representation: *N x N x NT* long tensor """ struct = CKY