From cd5aa5c0d1b6a00fd4e17937a27cc00f8395fe1d Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Tue, 15 Dec 2020 19:45:17 -0500 Subject: [PATCH] minor comment fixes --- README.md | 12 +++++++++++- pytorch_tdnn/tdnn.py | 4 ++-- pytorch_tdnn/tdnnf.py | 4 ++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 7dafee6..ae8a460 100644 --- a/README.md +++ b/README.md @@ -30,8 +30,12 @@ tdnn = TDNNLayer( 512, # output dim [-3,0,3], # context ) + +y = tdnn(x) ``` +Here, `x` should have the shape `(batch_size, sequence_length, input_dim)`. + **Note:** The `context` list should follow these constraints: * The length of the list should be 2 or an odd number. * If the length is 2, it should be of the form `[-1,1]` or `[-3,3]`, but not @@ -44,14 +48,20 @@ tdnn = TDNNLayer( ```python from pytorch_tdnn.tdnnf import TDNNF as TDNNFLayer -tdnn = TDNNFLayer( +tdnnf = TDNNFLayer( 512, # input dim 512, # output dim 256, # bottleneck dim 1, # time stride ) + +y = tdnnf(x, training=True) ``` +The argument `training` is used to perform the semi-orthogonality step only during +the model training. If this call is made from within a `forward()` function of an +`nn.Module` class, `training` can be set to `self.training`. + **Note:** Time stride should be greater than or equal to 0. For example, if the time stride is 1, a context of `[-1,1]` is used for each stage of splicing. diff --git a/pytorch_tdnn/tdnn.py b/pytorch_tdnn/tdnn.py index 6733bca..cf1b7ad 100644 --- a/pytorch_tdnn/tdnn.py +++ b/pytorch_tdnn/tdnn.py @@ -49,9 +49,9 @@ def __init__(self, def forward(self, x): """ - :param x: is one batch of data, x.size(): [batch_size, input_dim, sequence_length] + :param x: is one batch of data, x.size(): [batch_size, input_dim, in_seq_length] sequence length is the dimension of the arbitrary length data - :return: [batch_size, output_dim, len(valid_steps)] + :return: [batch_size, output_dim, out_seq_length ] """ return self.temporal_conv(x) diff --git a/pytorch_tdnn/tdnnf.py b/pytorch_tdnn/tdnnf.py index 875c6ac..27e3be7 100644 --- a/pytorch_tdnn/tdnnf.py +++ b/pytorch_tdnn/tdnnf.py @@ -162,10 +162,10 @@ def __init__(self, def forward(self, x, training=True): """ - :param x: is one batch of data, x.size(): [batch_size, sequence_length, input_dim] + :param x: is one batch of data, x.size(): [batch_size, input_dim, in_seq_length] sequence length is the dimension of the arbitrary length data :param training: True if model is in training phase - :return: [batch_size, output_dim, len(valid_steps)] + :return: [batch_size, output_dim, out_seq_length] """ semi_ortho_step = training and (random.uniform(0,1) < 0.25) x = self.factor1(x, semi_ortho_step=semi_ortho_step)