In [1]:
import math

from loaders import *
from episode import *
from dataset import *

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [5]:
class CausalConv1d(nn.Module):
    """1-D convolution whose output at step t only depends on inputs at steps <= t.

    ``nn.Conv1d`` pads both ends by ``(kernel - 1) * dilation``; trimming the same
    number of trailing steps removes every output that would see the future and
    leaves the output length equal to the input length.
    Adapted from https://discuss.pytorch.org/t/causal-convolution/3456/4

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel: convolution kernel size (default 2).
        dilation: spacing between kernel taps (default 2).
    """

    def __init__(self, in_channels, out_channels, kernel=2, dilation=2):
        super(CausalConv1d, self).__init__()

        self.kernel = kernel
        self.padding = (kernel - 1) * dilation
        self.dilation = dilation
        # Use the requested kernel size; it was previously hard-coded to 2, which
        # desynchronised the conv from `self.padding` for any other kernel.
        self.causal_conv = nn.Conv1d(in_channels, out_channels, kernel,
                                     padding=self.padding, dilation=dilation)

    def forward(self, input):
        """Apply the causal convolution.

        Args:
            input: tensor of shape (batch, in_channels, seq_len).

        Returns:
            Tensor of shape (batch, out_channels, seq_len).
        """
        output = self.causal_conv(input)
        # Drop the right-hand padding so no output position sees future inputs.
        # Guard padding == 0 (kernel == 1): `x[..., :-0]` would yield an empty tensor.
        if self.padding > 0:
            output = output[:, :, :-self.padding]
        return output

In [6]:
class DenseBlock(nn.Module):
    """Gated causal convolution whose output is concatenated onto its input.

    Computes ``tanh(filter_conv(x)) * sigmoid(gate_conv(x))`` and appends the
    result to ``x`` along the channel axis.

    Args:
        in_channels: channels of the incoming tensor.
        num_filters: channels produced by the gated activation (appended to the input).
        dilation_rate: dilation of both causal convolutions.
    """

    def __init__(self, in_channels, num_filters, dilation_rate=2):
        super(DenseBlock, self).__init__()

        # Separate convolutions for the tanh "filter" path and the sigmoid "gate"
        # path; sharing one conv made the gate a fixed function of the filter.
        self.causal_conv = CausalConv1d(in_channels, num_filters, dilation=dilation_rate)
        self.gate_conv = CausalConv1d(in_channels, num_filters, dilation=dilation_rate)

    def forward(self, input):
        """Return ``cat([input, tanh(filter) * sigmoid(gate)], dim=1)``.

        Args:
            input: tensor of shape (batch, in_channels, seq_len).

        Returns:
            Tensor of shape (batch, in_channels + num_filters, seq_len), assuming
            the causal convolutions preserve sequence length.
        """
        xf = torch.tanh(self.causal_conv(input))
        xg = torch.sigmoid(self.gate_conv(input))
        activations = xf * xg
        return torch.cat([input, activations], dim=1)
        

In [7]:
class TCBlock(nn.Module):
    """Stack of DenseBlocks with exponentially growing dilation (1, 2, 4, ...).

    Uses ceil(log2(seq_len)) blocks. With DenseBlock's default kernel size of 2,
    the stack's receptive field is 2**layer_count >= seq_len. Each block appends
    ``filters`` channels, so the output has ``in_channels + layer_count * filters``
    channels (also stored as ``self.out_channels``).

    Args:
        in_channels: channels of the input tensor.
        filters: channels appended by each DenseBlock.
        seq_len: sequence length the receptive field must cover (must be >= 1).

    Raises:
        ValueError: if ``seq_len < 1``.
    """

    def __init__(self, in_channels, filters, seq_len):
        super(TCBlock, self).__init__()
        if seq_len < 1:
            raise ValueError("seq_len must be >= 1, got {}".format(seq_len))
        # math.log2 is exact for powers of two; log(n)/log(2) can overshoot
        # (e.g. 2**29 -> 29.000000000000004) and add a spurious extra layer.
        layer_count = math.ceil(math.log2(seq_len))
        blocks = []
        channel_count = in_channels
        for layer in range(layer_count):
            # DenseBlock's keyword is `dilation_rate`; `dilation=` raised a TypeError.
            blocks.append(DenseBlock(channel_count, filters, dilation_rate=2 ** layer))
            channel_count += filters
        self.blocks = nn.Sequential(*blocks)
        self.out_channels = channel_count

    def forward(self, input):
        """Run the input through every DenseBlock in order.

        Args:
            input: tensor of shape (batch, in_channels, seq_len).

        Returns:
            Tensor with ``self.out_channels`` channels.
        """
        return self.blocks(input)
    

In [9]:
class AttentionBlock(nn.Module):
    """Causal single-head scaled dot-product attention.

    Each position attends only to itself and earlier positions. The attended
    values are concatenated onto the input along the feature axis.

    Args:
        dims: feature size of the input (last axis).
        k_size: size of keys and queries.
        v_size: size of values (features appended to the output).
        seq_len: not used by the computation; kept for interface compatibility.
    """

    def __init__(self, dims, k_size, v_size, seq_len):
        super(AttentionBlock, self).__init__()

        self.key_layer = nn.Linear(dims, k_size)
        self.query_layer = nn.Linear(dims, k_size)
        self.value_layer = nn.Linear(dims, v_size)
        self.sqrt_k = math.sqrt(k_size)

    def forward(self, input):
        """Attend causally over the sequence and append the result.

        Args:
            input: tensor of shape (batch, seq_len, dims) (3-D, as required by bmm).

        Returns:
            Tensor of shape (batch, seq_len, dims + v_size).
        """
        keys = self.key_layer(input)
        queries = self.query_layer(input)
        values = self.value_layer(input)

        # Scale the logits (not the probabilities) by sqrt(k_size).
        logits = torch.bmm(queries, keys.transpose(1, 2)) / self.sqrt_k

        # Mask key positions j > i so query i only sees positions <= i.
        # A bool mask built on the logits' device broadcasts over the batch.
        seq_len = logits.size(1)
        mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=logits.device),
            diagonal=1,
        )
        logits = logits.masked_fill(mask, float('-inf'))

        # Normalise over the key axis (last dim) so each query's weights sum to 1.
        probs = F.softmax(logits, dim=2)
        read = torch.bmm(probs, values)
        return torch.cat([input, read], dim=2)


# Backward-compatible alias for the original (misspelled) class name.
AttetionBlock = AttentionBlock
        

In [2]:
""" Some global variables """
_loader = Loader(502) # 500 + SOS + EOS
loader = MIDILoader(_loader)

use_cuda = torch.cuda.is_available()

# Base event vocabulary: 16*128*2 + 32*16 + 100 + 1 = 4709 ids.
# TODO(review): confirm whether the tokenizer is 1-indexed.
BASE_VOCABULARY_SIZE = 16*128*2 + 32*16 + 100 + 1
# SOS and EOS take the two ids directly after the base vocabulary. They are
# derived rather than hard-coded so they stay consistent if the base size changes.
SOS_TOKEN = BASE_VOCABULARY_SIZE          # 4709
EOS_TOKEN = BASE_VOCABULARY_SIZE + 1      # 4710
vocabulary_size = BASE_VOCABULARY_SIZE + 2  # 4711

encoding_size = 500
# NOTE(review): a dense float64 identity of 4711 x 4711 is ~178 MB; consider
# indexing or building one-hot vectors on demand instead.
one_hot_embeddings = np.eye(vocabulary_size)

In [10]:
# Build the episode sampler from its YAML config. The path is relative;
# presumably it is resolved against the kernel's working directory — confirm.
SAMPLER_CONFIG_PATH = "../src/test.yaml"
eps = load_sampler_from_config(SAMPLER_CONFIG_PATH)

INFO:few-shot:Preprocessing data. 0.00%
INFO:few-shot:Preprocessing data. 3.12%
INFO:few-shot:Preprocessing data. 9.93%
INFO:few-shot:Preprocessing data. 10.11%
INFO:few-shot:Preprocessing data. 10.29%
INFO:few-shot:Preprocessing data. 10.48%
INFO:few-shot:Preprocessing data. 10.66%
INFO:few-shot:Preprocessing data. 11.03%
INFO:few-shot:Preprocessing data. 11.21%
INFO:few-shot:Preprocessing data. 11.58%
INFO:few-shot:Preprocessing data. 11.95%
INFO:few-shot:Preprocessing data. 12.32%
INFO:few-shot:Preprocessing data. 12.50%
INFO:few-shot:Preprocessing data. 12.68%
INFO:few-shot:Preprocessing data. 12.87%
INFO:few-shot:Preprocessing data. 13.05%
INFO:few-shot:Preprocessing data. 13.42%
INFO:few-shot:Preprocessing data. 13.60%
INFO:few-shot:Preprocessing data. 13.79%
INFO:few-shot:Preprocessing data. 13.97%
INFO:few-shot:Preprocessing data. 14.15%
INFO:few-shot:Preprocessing data. 14.52%
INFO:few-shot:Preprocessing data. 14.71%
INFO:few-shot:Preprocessing data. 14.89%
INFO:few-shot:Prepr

INFO:few-shot:Preprocessing data. 46.88%
INFO:few-shot:Preprocessing data. 47.06%
INFO:few-shot:Preprocessing data. 47.24%
INFO:few-shot:Preprocessing data. 47.43%
INFO:few-shot:Preprocessing data. 47.61%
INFO:few-shot:Preprocessing data. 47.79%
INFO:few-shot:Preprocessing data. 47.98%
INFO:few-shot:Preprocessing data. 48.16%
INFO:few-shot:Preprocessing data. 48.35%
INFO:few-shot:Preprocessing data. 48.53%
INFO:few-shot:Preprocessing data. 48.71%
INFO:few-shot:Preprocessing data. 48.90%
INFO:few-shot:Preprocessing data. 49.08%
INFO:few-shot:Preprocessing data. 49.26%
INFO:few-shot:Preprocessing data. 49.45%
INFO:few-shot:Preprocessing data. 49.63%
INFO:few-shot:Preprocessing data. 49.82%
INFO:few-shot:Preprocessing data. 50.18%
INFO:few-shot:Preprocessing data. 50.37%
INFO:few-shot:Preprocessing data. 50.55%
INFO:few-shot:Preprocessing data. 50.74%
INFO:few-shot:Preprocessing data. 50.92%
INFO:few-shot:Preprocessing data. 51.10%
INFO:few-shot:Preprocessing data. 51.29%
INFO:few-shot:Pr

INFO:few-shot:Preprocessing data. 89.34%
INFO:few-shot:Preprocessing data. 89.71%
INFO:few-shot:Preprocessing data. 89.89%
INFO:few-shot:Preprocessing data. 90.26%
INFO:few-shot:Preprocessing data. 90.44%
INFO:few-shot:Preprocessing data. 90.62%
INFO:few-shot:Preprocessing data. 90.81%
INFO:few-shot:Preprocessing data. 90.99%
INFO:few-shot:Preprocessing data. 91.18%
INFO:few-shot:Preprocessing data. 91.36%
INFO:few-shot:Preprocessing data. 91.54%
INFO:few-shot:Preprocessing data. 91.91%
INFO:few-shot:Preprocessing data. 92.28%
INFO:few-shot:Preprocessing data. 92.46%
INFO:few-shot:Preprocessing data. 92.65%
INFO:few-shot:Preprocessing data. 92.83%
INFO:few-shot:Preprocessing data. 93.01%
INFO:few-shot:Preprocessing data. 93.20%
INFO:few-shot:Preprocessing data. 93.38%
INFO:few-shot:Preprocessing data. 93.57%
INFO:few-shot:Preprocessing data. 93.93%
INFO:few-shot:Preprocessing data. 94.12%
INFO:few-shot:Preprocessing data. 94.30%
INFO:few-shot:Preprocessing data. 94.49%
INFO:few-shot:Pr

In [14]:
# NOTE(review): displaying the sampler raises
# AttributeError: 'EpisodeSampler' object has no attribute 'root'.
# EpisodeSampler's repr appears to read an attribute that is never set. Fix it in
# the defining module rather than leaving a failing cell in the notebook.
eps

AttributeError: 'EpisodeSampler' object has no attribute 'root'