In [1]:
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [5]:
class CausalConv1d(nn.Module):
    """1-D convolution whose output at step t depends only on inputs at steps <= t.

    Implemented by letting ``nn.Conv1d`` pad both ends by
    ``(kernel - 1) * dilation`` and then trimming that many trailing steps,
    so the output has the same length as the input.
    See https://discuss.pytorch.org/t/causal-convolution/3456/4

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel: convolution kernel size (default 2).
        dilation: spacing between kernel taps (default 2).

    Expects a 3-D input laid out as for ``nn.Conv1d``: (batch, channels, time).
    """

    def __init__(self, in_channels, out_channels, kernel=2, dilation=2):
        super(CausalConv1d, self).__init__()

        self.kernel = kernel
        self.padding = (kernel - 1) * dilation
        self.dilation = dilation
        # The conv must use the same `kernel` the padding was derived from;
        # otherwise the trimmed output is neither causal nor input-length.
        self.causal_conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel,
            padding=self.padding,
            dilation=dilation,
        )

    def forward(self, input):
        out = self.causal_conv(input)
        # Drop the trailing `padding` steps, which would see future inputs.
        # An explicit end index also works when padding == 0, where the
        # slice `[:-0]` would yield an empty tensor.
        return out[:, :, :out.size(2) - self.padding]

In [6]:
class DenseBlock(nn.Module):
    """Gated causal-convolution block that appends its output to its input.

    Computes ``tanh(conv_f(x)) * sigmoid(conv_g(x))`` with two independent
    causal convolutions and concatenates the result onto ``x`` along dim 1,
    so the output has ``in_channels + num_filters`` channels.

    Args:
        in_channels: channels of the input tensor.
        num_filters: channels produced (and appended) by this block.
        dilation_rate: dilation of both causal convolutions (default 2).

    Assumes the (batch, channels, time) layout used by CausalConv1d.
    """

    def __init__(self, in_channels, num_filters, dilation_rate=2):
        super(DenseBlock, self).__init__()
        # Filter branch (passed through tanh).
        self.causal_conv = CausalConv1d(in_channels, num_filters, dilation=dilation_rate)
        # Gate branch (passed through sigmoid). It needs its own weights:
        # reusing the filter conv would make the gate a fixed function of the
        # filter pre-activation instead of a learned gate.
        self.gate_conv = CausalConv1d(in_channels, num_filters, dilation=dilation_rate)

    def forward(self, input):
        # torch.tanh / torch.sigmoid replace the deprecated F.tanh / F.sigmoid.
        filt = torch.tanh(self.causal_conv(input))
        gate = torch.sigmoid(self.gate_conv(input))
        activations = filt * gate
        # Dense connection: keep the input and append the new channels.
        return torch.cat([input, activations], dim=1)
        

In [7]:
class TCBlock(nn.Module):
    """Stack of DenseBlocks with dilations 1, 2, 4, ... covering ``seq_len`` steps.

    Builds ``ceil(log2(seq_len))`` DenseBlocks. With kernel-2 causal convs,
    that many doubling dilations give a receptive field of at least
    ``seq_len``. Each DenseBlock appends ``filters`` channels, so the output
    has ``in_channels + ceil(log2(seq_len)) * filters`` channels. For
    ``seq_len == 1`` the stack is empty and the input is returned unchanged.

    Args:
        in_channels: channels of the input tensor.
        filters: channels appended by each DenseBlock.
        seq_len: sequence length the receptive field must cover (>= 1).

    Raises:
        ValueError: if ``seq_len < 1``.
    """

    def __init__(self, in_channels, filters, seq_len):
        super(TCBlock, self).__init__()
        if seq_len < 1:
            raise ValueError("seq_len must be >= 1, got {}".format(seq_len))
        # math.log2 is exact on powers of two, whereas log(n) / log(2) can
        # land one ulp above the integer and make ceil() add a spurious layer.
        layer_count = math.ceil(math.log2(seq_len))
        blocks = []
        channel_count = in_channels
        for layer in range(layer_count):
            # DenseBlock's keyword is `dilation_rate` (`dilation=` is a TypeError).
            blocks.append(DenseBlock(channel_count, filters, dilation_rate=2 ** layer))
            channel_count += filters
        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)
    

In [9]:
class AttentionBlock(nn.Module):
    """Causal single-head scaled dot-product attention that appends its read.

    Each position attends only to itself and earlier positions. The attention
    read (``v_size`` features) is concatenated onto the input along the last
    dimension, so the output has ``dims + v_size`` features.

    Args:
        dims: feature size of the input (last dimension).
        k_size: size of keys and queries.
        v_size: size of values, i.e. features appended to the output.
        seq_len: unused; kept for interface compatibility.

    Expects a 3-D input (``torch.bmm`` requires it), presumably
    (batch, seq, dims). Note that this differs from the (batch, channels, time)
    layout of the conv blocks; the caller is assumed to transpose.
    """

    def __init__(self, dims, k_size, v_size, seq_len):
        super(AttentionBlock, self).__init__()

        self.key_layer = nn.Linear(dims, k_size)
        self.query_layer = nn.Linear(dims, k_size)
        self.value_layer = nn.Linear(dims, v_size)
        self.sqrt_k = math.sqrt(k_size)

    def forward(self, input):
        keys = self.key_layer(input)
        queries = self.query_layer(input)
        values = self.value_layer(input)

        # (batch, n_queries, n_keys); scaling belongs on the logits, before softmax.
        logits = torch.bmm(queries, keys.transpose(2, 1)) / self.sqrt_k

        # Hide keys that lie after each query (strict upper triangle). The
        # bool mask is built on the logits' device and broadcasts over batch.
        n_q, n_k = logits.size(1), logits.size(2)
        mask = torch.triu(
            torch.ones(n_q, n_k, dtype=torch.bool, device=logits.device),
            diagonal=1,
        )
        logits = logits.masked_fill(mask, float('-inf'))

        # Normalise over keys (last dim) so each query's weights sum to 1.
        probs = F.softmax(logits, dim=2)
        read = torch.bmm(probs, values)
        return torch.cat([input, read], dim=2)


# Backward-compatible alias for the original (misspelled) class name.
AttetionBlock = AttentionBlock
        