-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcausalconv.py
More file actions
40 lines (30 loc) · 1.81 KB
/
causalconv.py
File metadata and controls
40 lines (30 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn.functional as F
# This wraps a traditional Conv1d layer so that it can't "peek" into the future at all
# This is accomplished by padding the input with zeros on the left side
# The input shape is (N, C_in, L) and the output shape is (N, C_out, floor(L / stride)) — the lengths match when stride is 1
#
# You can also allow it to "peek" into the future by setting additional_context to a positive integer
class CausalConv1d(torch.nn.Module):
    """Causal (left-padded) wrapper around ``torch.nn.Conv1d``.

    The input is zero-padded on the left so the convolution never reads
    samples to the right of ("after") the current output position. Setting
    ``additional_context`` to a positive integer moves that many frames of
    padding from the left to no padding at all, letting the layer peek that
    far into the future.

    Input shape is (N, C_in, L); output shape is (N, C_out, floor(L / stride)).

    Raises:
        ValueError: if ``additional_context`` is negative, exceeds the total
            padding, or if the stride is so large that the computed padding
            would be negative.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, additional_context: int = 0):
        super().__init__()
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation)
        # Total padding that makes L_out == floor(L_in / stride);
        # for stride == 1 this is the classic (kernel_size - 1) * dilation.
        self.padding = (kernel_size - 1) * dilation - stride + 1
        if self.padding < 0:
            # Previously this configuration was only rejected indirectly with
            # a confusing "additional_context can't be greater than the
            # padding" message (and F.pad with a negative pad silently TRIMS
            # the input); fail with an explicit explanation instead.
            raise ValueError("stride is too large for the receptive field; computed padding is negative")
        if additional_context < 0:
            raise ValueError("additional_context must be non-negative")
        if additional_context > self.padding:
            raise ValueError("additional_context can't be greater than the padding")
        self.additional_context = additional_context
        # Whatever future context we allow is removed from the left pad.
        self.left_padding = self.padding - additional_context

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the causal convolution to a full sequence of shape (N, C_in, L)."""
        # Right padding is always zero: during training we don't know what
        # happens AFTER the sample, so padding the future with zeros would be
        # an invalid assumption — the output is simply shorter instead.
        x = F.pad(x, (self.left_padding, 0))
        return self.conv(x)

    def streaming_forward(self, x: torch.Tensor, state: torch.Tensor):
        """Process one chunk ``x`` given carry-over ``state``.

        Returns ``(result, new_state)`` where ``new_state`` is the tail of the
        concatenated buffer that has not yet been fully consumed by the
        convolution.

        NOTE(review): no padding is applied here, so this assumes the initial
        ``state`` is a zero tensor of length ``self.left_padding`` along
        dim 2 — confirm against the caller that constructs the first state.
        """
        buffered = torch.cat((state, x), dim=2)
        result = self.conv(buffered)
        # Advance the buffer by stride frames per produced output frame; the
        # remainder becomes the state for the next chunk.
        state = buffered[:, :, result.shape[2] * self.conv.stride[0]:]
        return result, state