Skip to content

Commit

Permalink
fix formatting clobber
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Dec 5, 2020
1 parent 93ebb94 commit a105c09
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions torch_struct/distributions.py
Expand Up @@ -90,7 +90,9 @@ def cross_entropy(self, other):
cross entropy (*batch_shape*)
"""

return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
return self._struct(CrossEntropySemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)

def kl(self, other):
"""
Expand All @@ -99,7 +101,9 @@ def kl(self, other):
Returns:
cross entropy (*batch_shape*)
"""
return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
return self._struct(KLDivergenceSemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)

@lazy_property
def max(self):
Expand Down Expand Up @@ -127,7 +131,9 @@ def kmax(self, k):
kmax (*k x batch_shape*)
"""
with torch.enable_grad():
return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True)
return self._struct(KMaxSemiring(k)).sum(
self.log_potentials, self.lengths, _raw=True
)

def topk(self, k):
r"""
Expand All @@ -137,7 +143,9 @@ def topk(self, k):
kmax (*k x batch_shape x event_shape*)
"""
with torch.enable_grad():
return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True)
return self._struct(KMaxSemiring(k)).marginals(
self.log_potentials, self.lengths, _raw=True
)

@lazy_property
def mode(self):
Expand Down Expand Up @@ -192,7 +200,9 @@ def sample(self, sample_shape=torch.Size()):
samples = []
for k in range(nsamples):
if k % 10 == 0:
sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths)
sample = self._struct(MultiSampledSemiring).marginals(
self.log_potentials, lengths=self.lengths
)
sample = sample.detach()
tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
samples.append(tmp_sample)
Expand All @@ -213,7 +223,9 @@ def enumerate_support(self, expand=True):
Returns:
(enum, enum_lengths) - (*tuple cardinality x batch_shape x event_shape*)
"""
_, _, edges, enum_lengths = self._struct().enumerate(self.log_potentials, self.lengths)
_, _, edges, enum_lengths = self._struct().enumerate(
self.log_potentials, self.lengths
)
# if expand:
# edges = edges.unsqueeze(1).expand(edges.shape[:1] + self.batch_shape[:1] + edges.shape[1:])
return edges, enum_lengths
Expand Down Expand Up @@ -284,7 +296,9 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
super().__init__(log_potentials, lengths)

def _struct(self, sr=None):
return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap)
return self.struct(
sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
)


class HMM(StructDistribution):
Expand Down Expand Up @@ -437,7 +451,9 @@ def __init__(self, log_potentials, lengths=None):
event_shape = log_potentials[0].shape[1:]
self.log_potentials = log_potentials
self.lengths = lengths
super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape)
super(StructDistribution, self).__init__(
batch_shape=batch_shape, event_shape=event_shape
)


class NonProjectiveDependencyCRF(StructDistribution):
Expand Down

0 comments on commit a105c09

Please sign in to comment.