In [4]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
import soundfile as sf
import torch.nn as nn
import torch
import math
import pickle

from scipy.io.wavfile import write as write_wav
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils import weight_norm

from tinydb import TinyDB, Query

sys.path.append('/export/home/jrichter/Projects/audio_visual/frameworks/uhh-sp')
sys.path.append('/export/home/jrichter/Projects/audio_visual/python')

from uhh_sp.dsp import stft, istft
from utils import count_parameters

### Stochastic Temporal Convolutional Network (STCN)

In [206]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.2, 
                 activation=nn.ReLU()):
        super(ResidualBlock, self).__init__()
        
        self.resample = (weight_norm(nn.Conv1d(in_channels, out_channels, 1)) 
                         if in_channels != out_channels else None)
        self.padding = nn.ConstantPad1d(((kernel_size - 1) * dilation, 0), 0)
        self.convolution = weight_norm(nn.Conv1d(out_channels, out_channels, 
                                                 kernel_size, dilation=dilation))
        self.activation = activation
        self.dropout = nn.Dropout(dropout)     
        
        self.init_weights()
        
    def init_weights(self):
        self.convolution.weight.data.normal_(0, 0.01)
        if self.resample is not None:
            self.resample.weight.data.normal_(0, 0.01)
        
    def forward(self, x):
        x = x if self.resample is None else self.resample(x)
        y = self.dropout(self.activation(self.convolution(self.padding(x))))
        return self.activation(x + y)

In [239]:
class TCN(nn.Module):
    def __init__(self, channels, kernel_size=2, dropout=0.2, activation=nn.ReLU()):
        super(TCN, self).__init__()
        
        self.layers = nn.Sequential(*[ResidualBlock(channels[i], channels[i+1],
            kernel_size, 2**i, dropout, activation) for i in range(len(channels) - 1)]) 
                                    
    def forward(self, x):
        return self.layers(x)
    
    def representation(self, x, level):
        for i in range(level):
            x = self.layers[i](x)
        return x

In [221]:
class LatentLayer(nn.Module):
    def __init__(self, channels, kernel_size):
        super(LatentLayer, self).__init__()
        
        self.latent_variables = None

In [226]:
class STCN(nn.Module):
    def __init__(self, channels, kernel_size=2, dropout=0.2, activation=nn.ReLU()):
        super(STCN, self).__init__()
        
        self.network = TCN(channels, kernel_size, dropout, activation)
        self.latent_variables = LatentVariables(channels, kernel_size)
    
        self.num_hidden_layers = num_hidden_layers
        
        # encoder  
        self.enc_in = nn.Linear(in_out_dim, hidden_dim)
        self.enc_hidden = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_hidden_layers)])
        self.enc_out_1 = nn.Linear(hidden_dim, latent_dim)
        self.enc_out_2 = nn.Linear(hidden_dim, latent_dim)
        
        # decoder
        self.dec_in = nn.Linear(latent_dim, hidden_dim)
        self.dec_hidden = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_hidden_layers)])
        self.dec_out = nn.Linear(hidden_dim, in_out_dim)
              
    def encode(self, x):
        h = torch.tanh(self.enc_in(x))
        for i in range(self.num_hidden_layers):
            h = torch.tanh(self.enc_hidden[i](h))     
        return self.enc_out_1(h), self.enc_out_2(h)
    
    def decode(self, z):
        h = torch.tanh(self.dec_in(z))
        for i in range(self.num_hidden_layers):
            h = torch.tanh(self.dec_hidden[i](h))
        return self.dec_out(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 513))
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar
        

In [227]:
channels = [2, 5, 5, 5, 3]
kernel_size = 3
dropout = 0.0
activation = lambda x : x

model = STCN(channels, kernel_size, dropout, activation)
model

STCN(
  (deterministic_network): TCN(
    (layers): Sequential(
      (0): ResidualBlock(
        (resample): Conv1d(2, 5, kernel_size=(1,), stride=(1,))
        (padding): ConstantPad1d(padding=(2, 0), value=0)
        (convolution): Conv1d(5, 5, kernel_size=(3,), stride=(1,))
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): ResidualBlock(
        (padding): ConstantPad1d(padding=(4, 0), value=0)
        (convolution): Conv1d(5, 5, kernel_size=(3,), stride=(1,), dilation=(2,))
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (2): ResidualBlock(
        (padding): ConstantPad1d(padding=(8, 0), value=0)
        (convolution): Conv1d(5, 5, kernel_size=(3,), stride=(1,), dilation=(4,))
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (3): ResidualBlock(
        (resample): Conv1d(5, 3, kernel_size=(1,), stride=(1,))
        (padding): ConstantPad1d(padding=(16, 0), value=0)
        (convolution): Conv1d(3, 3, kernel_size=(3,), stride=(1,), di

In [209]:
batch_size = 2
seq_length = 10

input = torch.randn(batch_size, in_channels, seq_length)
model(input)

tensor([[[-0.9051, -0.4287, -0.5468, -0.5347, -0.6750, -1.2396, -0.6867,
          -0.9820, -0.8129, -0.2056],
         [-0.3419, -0.6165, -0.3817, -0.3315, -0.6376, -0.0914, -0.8208,
          -0.0322, -0.0062, -0.0756],
         [ 0.0613, -0.2904,  0.0367,  0.1918, -0.2204,  0.8778, -0.2903,
           0.6046,  0.4820, -0.0654]],

        [[-0.4264, -0.4999, -0.8912, -0.7094, -0.8899, -1.4694, -0.9052,
          -0.5631, -0.8975, -0.6074],
         [-0.3022, -0.5099, -0.3062, -0.1465, -0.2287, -0.1667, -0.2507,
          -0.4289,  0.1311, -0.3389],
         [-0.0573, -0.1170,  0.3888,  0.4377,  0.3673,  0.8339,  0.2395,
          -0.1399,  0.4895, -0.2250]]], grad_fn=<AddBackward0>)