Skip to content

Commit

Permalink
modifications to TDNNF semi-orthogonal error
Browse files Browse the repository at this point in the history
  • Loading branch information
desh2608 committed Dec 18, 2020
1 parent cd5aa5c commit a861e69
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 27 deletions.
23 changes: 16 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pip install pytorch-tdnn
```

To install for development, clone the repository, and then run the following from
within the roor directory.
within the root directory.

```bash
pip install -e .
Expand All @@ -34,7 +34,7 @@ tdnn = TDNNLayer(
y = tdnn(x)
```

Here, `x` should have the shape `(batch_size, sequence_length, input_dim)`.
Here, `x` should have the shape `(batch_size, input_dim, sequence_length)`.

**Note:** The `context` list should follow these constraints:
* The length of the list should be 2 or an odd number.
Expand All @@ -55,20 +55,29 @@ tdnnf = TDNNFLayer(
1, # time stride
)

y = tdnnf(x, training=True)
y = tdnnf(x, semi_ortho_step=True)
```

The argument `training` is used to perform the semi-orthogonality step only during
the model training. If this call is made from within a `forward()` function of an
`nn.Module` class, `training` can be set to `self.training`.
The argument `semi_ortho_step` determines whether to take the step towards semi-
orthogonality for the constrained convolutional layers in the 3-stage splicing.
If this call is made from within a `forward()` function of an
`nn.Module` class, it can be set as follows to approximate Kaldi-style training
where the step is taken once every 4 iterations:

```python
import random
semi_ortho_step = self.training and (random.uniform(0,1) < 0.25)
```

**Note:** Time stride should be greater than or equal to 0. For example, if
the time stride is 1, a context of `[-1,1]` is used for each stage of splicing.

### Credits

* The TDNN implementation is based on: https://github.com/jonasvdd/TDNN.
* The TDNN implementation is based on: https://github.com/jonasvdd/TDNN and https://github.com/m-wiesner/nnet_pytorch.
* Semi-orthogonal convolutions used in TDNN-F are based on: https://github.com/cvqluu/Factorized-TDNN.
* Thanks to [Matthew Wiesner](https://github.com/m-wiesner) for helpful discussions
about the implementations.

This repository aims to wrap up these implementations in easily installable PyPI
packages, which can be used directly in PyTorch-based neural network training.
Expand Down
6 changes: 4 additions & 2 deletions pytorch_tdnn/tdnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ class TDNN(torch.nn.Module):
def __init__(self,
input_dim: int,
output_dim: int,
context: list):
context: list,
bias: bool = True):
"""
Implementation of TDNN using the dilation argument of the PyTorch Conv1d class
For the sake of speed, the context is subject to two constraints:
Expand Down Expand Up @@ -44,7 +45,8 @@ def __init__(self,
output_dim,
kernel_size=kernel_size,
dilation=dilation,
padding=padding
padding=padding,
bias=bias # will be set to False for semi-orthogonal TDNNF convolutions
))

def forward(self, x):
Expand Down
25 changes: 7 additions & 18 deletions pytorch_tdnn/tdnnf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# This implementation is based on: https://github.com/cvqluu/Factorized-TDNN

import random

import torch
import torch.nn.functional as F

Expand All @@ -14,7 +12,7 @@ def __init__(self,
input_dim: int,
output_dim: int,
context: list,
init: str = 'kaldi'):
init: str = 'xavier'):
"""
Semi-orthogonal convolutions. The forward function takes an additional
parameter that specifies whether to take the semi-orthogonality step.
Expand All @@ -23,7 +21,7 @@ def __init__(self,
:param output_dim: The number of channels produced by the temporal convolution
:param init: Initialization method for weight matrix, 'kaldi' or 'xavier' (default = 'xavier')
"""
super(SemiOrthogonalConv, self).__init__(input_dim, output_dim, context)
super(SemiOrthogonalConv, self).__init__(input_dim, output_dim, context, bias=False)
self.init_method = init
self.reset_parameters()

Expand All @@ -38,7 +36,7 @@ def reset_parameters(self):
elif self.init_method == 'xavier':
# Use Xavier initialization
torch.nn.init.xavier_normal_(
self.temporal_conv
self.temporal_conv.weight
)

def step_semi_orth(self):
Expand Down Expand Up @@ -72,12 +70,11 @@ def get_semi_orth_weight(M):
ratio = trace_PP * P.shape[0] / (trace_P * trace_P)

# the following is the tweak to avoid divergence (more info in Kaldi)
assert ratio > 0.99, "Ratio of traces is less than 0.99"
# assert ratio > 0.9, "Ratio of traces is less than 0.9"
if ratio > 1.02:
update_speed *= 0.5
if ratio > 1.1:
update_speed *= 0.5

scale2 = trace_PP/trace_P
update = P - (torch.matrix_power(P, 0) * scale2)
alpha = update_speed / scale2
Expand Down Expand Up @@ -106,12 +103,7 @@ def get_semi_orth_error(M):
if mshape[0] > mshape[1]: # semi orthogonal constraint for rows > cols
M = M.T
P = torch.mm(M, M.T)
PP = torch.mm(P, P.T)
trace_P = torch.trace(P)
trace_PP = torch.trace(PP)
scale2 = torch.sqrt(trace_PP/trace_P) ** 2
update = P - (torch.matrix_power(P, 0) * scale2)
return torch.norm(update, p='fro')
return torch.norm(P, p='fro')

def forward(self, x, semi_ortho_step = False):
"""
Expand Down Expand Up @@ -149,8 +141,6 @@ def __init__(self,
self.bottleneck_dim = bottleneck_dim
self.output_dim = output_dim

random.seed(0)

if time_stride == 0:
context = [0]
else:
Expand All @@ -160,14 +150,13 @@ def __init__(self,
self.factor2 = SemiOrthogonalConv(bottleneck_dim, bottleneck_dim, context)
self.factor3 = TDNN(bottleneck_dim, output_dim, context)

def forward(self, x, training=True):
def forward(self, x, semi_ortho_step=True):
"""
:param x: is one batch of data, x.size(): [batch_size, input_dim, in_seq_length]
sequence length is the dimension of the arbitrary length data
:param training: True if model is in training phase
:param semi_ortho_step: if True, update parameter for semi-orthogonality
:return: [batch_size, output_dim, out_seq_length]
"""
semi_ortho_step = training and (random.uniform(0,1) < 0.25)
x = self.factor1(x, semi_ortho_step=semi_ortho_step)
x = self.factor2(x, semi_ortho_step=semi_ortho_step)
x = self.factor3(x)
Expand Down

0 comments on commit a861e69

Please sign in to comment.