Skip to content

Commit

Permalink
minor comment fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
desh2608 committed Dec 16, 2020
1 parent 3626816 commit cd5aa5c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ tdnn = TDNNLayer(
512, # output dim
[-3,0,3], # context
)

y = tdnn(x)
```

Here, `x` should have the shape `(batch_size, sequence_length, input_dim)`.

**Note:** The `context` list should follow these constraints:
* The length of the list should be 2 or an odd number.
* If the length is 2, it should be of the form `[-1,1]` or `[-3,3]`, but not
Expand All @@ -44,14 +48,20 @@ tdnn = TDNNLayer(
```python
from pytorch_tdnn.tdnnf import TDNNF as TDNNFLayer

tdnn = TDNNFLayer(
tdnnf = TDNNFLayer(
512, # input dim
512, # output dim
256, # bottleneck dim
1, # time stride
)

y = tdnnf(x, training=True)
```

The argument `training` is used to perform the semi-orthogonality step only during
the model training. If this call is made from within a `forward()` function of an
`nn.Module` class, `training` can be set to `self.training`.

**Note:** Time stride should be greater than or equal to 0. For example, if
the time stride is 1, a context of `[-1,1]` is used for each stage of splicing.

Expand Down
4 changes: 2 additions & 2 deletions pytorch_tdnn/tdnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def __init__(self,

def forward(self, x):
"""
:param x: is one batch of data, x.size(): [batch_size, input_dim, sequence_length]
:param x: is one batch of data, x.size(): [batch_size, input_dim, in_seq_length]
sequence length is the dimension of the arbitrary length data
:return: [batch_size, output_dim, len(valid_steps)]
:return: [batch_size, output_dim, out_seq_length ]
"""
return self.temporal_conv(x)

Expand Down
4 changes: 2 additions & 2 deletions pytorch_tdnn/tdnnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,10 @@ def __init__(self,

def forward(self, x, training=True):
"""
:param x: is one batch of data, x.size(): [batch_size, sequence_length, input_dim]
:param x: is one batch of data, x.size(): [batch_size, input_dim, in_seq_length]
sequence length is the dimension of the arbitrary length data
:param training: True if model is in training phase
:return: [batch_size, output_dim, len(valid_steps)]
:return: [batch_size, output_dim, out_seq_length]
"""
semi_ortho_step = training and (random.uniform(0,1) < 0.25)
x = self.factor1(x, semi_ortho_step=semi_ortho_step)
Expand Down

0 comments on commit cd5aa5c

Please sign in to comment.