# Definition of the model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from constants import *
import time
import datetime

# Use the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): in the visible code only `BeatTracker.init_hidden` uses this;
# the model and the training batches are never moved to `device` — confirm intent.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

In [None]:
class BeatTracker(nn.Module):
    """Bidirectional LSTM that classifies onset frames as beat / non-beat.

    ``forward`` maps a batch of spectrogram frame sequences to per-frame
    log-probabilities over two classes (0 = not a beat, 1 = beat). Training,
    prediction and evaluation only look at frames whose onset flag is 1.

    The model owns its loss function and Adam optimizer, so ``learn`` and
    ``fit`` can be called directly on an instance.

    Args:
        hidden_size: Hidden units per LSTM direction.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, hidden_size=128, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # `nb` (features per frame) presumably comes from `constants` — confirm.
        self.lstm = nn.LSTM(
            nb,
            hidden_size,
            num_layers,
            bidirectional=True,
            # dropout=0.5,
            batch_first=True,
        )
        # Forward and backward outputs are concatenated -> 2 * hidden_size.
        self.hid_to_beat = nn.Linear(2 * hidden_size, 2)
        # Not used by `forward`: the LSTM starts from its default zero state.
        self.hidden = None  # self.init_hidden()

        # `forward` returns log-probabilities, which is what NLLLoss expects.
        self.loss_function = nn.NLLLoss()

        self.lr = 0.001
        self.optimizer = optim.Adam(self.parameters(), lr=self.lr)

    def init_hidden(self):
        """Return zero ``(h0, c0)`` LSTM states for a batch of one sequence.

        The first dimension is ``2 * num_layers`` because the LSTM is
        bidirectional. Not called by ``forward``.
        """
        h0 = torch.zeros(2 * self.num_layers, 1, self.hidden_size, device=device)
        c0 = torch.zeros(2 * self.num_layers, 1, self.hidden_size, device=device)
        return h0, c0

    def forward(self, spec):
        """Return per-frame log-probabilities with a trailing dimension of 2.

        The LSTM is ``batch_first``, so ``spec`` is presumably
        ``(batch, frames, nb)`` — TODO confirm against the dataset.
        """
        x, _ = self.lstm(spec)  # discard the final (h_n, c_n) states
        x = self.hid_to_beat(x)
        return F.log_softmax(x, dim=-1)

    def set_lr(self, lr):
        """Set the learning rate of every optimizer parameter group."""
        self.lr = lr
        for group in self.optimizer.param_groups:
            group['lr'] = lr

    def learn(self, spec, onsets, isbeat):
        """Run one optimisation step on a batch.

        Only frames where ``onsets == 1`` contribute to the loss.

        Returns:
            ``(loss, accuracy)`` where ``accuracy`` is the fraction of onset
            frames whose predicted class equals ``isbeat``.
        """
        self.optimizer.zero_grad()
        mask = onsets == 1
        output = self(spec)[mask]
        target = isbeat[mask]
        loss = self.loss_function(output, target)
        loss.backward()
        self.optimizer.step()

        predic = torch.argmax(output, dim=1)
        accuracy = torch.sum(predic == target).item() / predic.numel()

        return loss.item(), accuracy

    def fit(self, dataset, batch_size=1, epochs=1):
        """Train on ``dataset`` for ``epochs`` passes, reshuffling every epoch.

        Prints one progress line per epoch: mean loss, mean accuracy, seconds
        per batch, and an ETA extrapolated from the last epoch's duration.

        Returns:
            ``(loss_hist, accu_hist)``: arrays of shape ``(epochs, n_batches)``
            holding the loss and accuracy of every batch.
        """
        # Ceil division: the DataLoader keeps the last, possibly smaller, batch.
        n_batches = -(-len(dataset) // batch_size)
        loss_hist = np.zeros((epochs, n_batches))
        accu_hist = np.zeros((epochs, n_batches))
        for e in range(epochs):
            start = time.time()

            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
            for i, (spec, onsets, isbeat) in enumerate(dataloader):
                loss_hist[e, i], accu_hist[e, i] = self.learn(spec, onsets, isbeat)

            t = time.time() - start
            eta = str(datetime.timedelta(seconds=int(t * (epochs - e - 1))))
            print(f'| Epoch: {e + 1:{len(str(epochs))}} | ', end='')
            print(f'Loss: {np.mean(loss_hist[e]):7.4f} | ', end='')
            print(f'Accuracy: {np.mean(accu_hist[e]):5.4f} | ', end='')
            print(f'{t / len(dataloader):.2f} s/b | eta: {eta} |')
        return loss_hist, accu_hist

    def predict(self, specs, onsets):
        """Return a tensor shaped like ``onsets`` with 1 on predicted beat frames.

        Only onset frames (``onsets == 1``) can be predicted as beats; every
        other frame is 0. Works for any batch size.
        """
        with torch.no_grad():
            mask = onsets == 1
            pred_t = torch.argmax(self(specs)[mask], dim=1)
            pred = torch.zeros_like(onsets)
            # Boolean-mask assignment visits positions in the same row-major
            # order as `output[mask]`, so each prediction lands on its frame.
            pred[mask] = pred_t.to(pred.dtype)
        return pred

    def evaluate(self, specs, onsets, isbeat):
        """Return confusion counts ``(tn, fp, fn, tp)`` over onset frames.

        Class 1 (beat) is the positive class; frames with ``onsets != 1`` are
        ignored.
        """
        with torch.no_grad():
            mask = onsets == 1
            predic = torch.argmax(self(specs)[mask], dim=1)
            target = isbeat[mask]

            tn = torch.sum((predic == 0) & (target == 0)).item()
            fp = torch.sum((predic == 1) & (target == 0)).item()
            fn = torch.sum((predic == 0) & (target == 1)).item()
            tp = torch.sum((predic == 1) & (target == 1)).item()
        return tn, fp, fn, tp

    def evaluate_from_dataset(self, dataset):
        """Evaluate on all samples of ``dataset`` at once.

        Returns the same ``(tn, fp, fn, tp)`` tuple as ``evaluate``.
        """
        # A single batch holding the whole dataset. The default collate stacks
        # samples, so this assumes they all have the same number of frames.
        dataloader = DataLoader(dataset, batch_size=len(dataset))
        specs, onsets, isbeat = next(iter(dataloader))
        return self.evaluate(specs, onsets, isbeat)

    def freeze(self):
        """Disable gradient computation for every parameter."""
        for p in self.parameters():
            p.requires_grad = False

    def unfreeze(self):
        """Re-enable gradient computation for every parameter."""
        for p in self.parameters():
            p.requires_grad = True

class ToTensor:
    """Callable transform that turns a NumPy sample into torch tensors.

    The sample is a ``(spec_np, onsets_np, isbeat_np)`` triple. The
    spectrogram is transposed so that its first axis indexes frames.
    ``onsets_np`` is used as an array of frame indices and ``isbeat_np`` as a
    per-onset flag. Both are expanded into dense per-frame 0/1 ``long``
    tensors whose length is the number of frames.
    """

    def __call__(self, sample):
        spec_np, onsets_np, isbeat_np = sample

        spec = torch.tensor(spec_np.T)
        n_frames = spec.size(0)

        onsets = torch.zeros(n_frames, dtype=torch.long)
        onsets[onsets_np] = 1

        # Keep only the onset frames flagged as beats.
        beat_frames = onsets_np[isbeat_np == 1]
        isbeat = torch.zeros_like(onsets)
        isbeat[beat_frames] = 1

        return spec, onsets, isbeat
    
def beat_track(isbeat):
    """Run librosa's beat tracker using ``isbeat`` as the onset envelope.

    ``isbeat`` is squeezed along dimension 0 and converted to NumPy, so it is
    presumably a ``(1, frames)`` tensor — TODO confirm. ``sr`` and ``hl``
    come from module scope (presumably ``constants``). The estimated tempo is
    discarded; only the beat positions returned by librosa are kept.
    """
    envelope = isbeat.squeeze(0).numpy()
    _, beat_positions = librosa.beat.beat_track(
        sr=sr,
        onset_envelope=envelope,
        hop_length=hl,
        tightness=800,  # strongly favour beats that follow a steady tempo
    )
    return beat_positions

# Visualization of the dataset

In [None]:
from GTZAN import GTZAN
from visualization import *

Load a subset of the GTZAN dataset preprocessed with `preprocess-GTZAN`. For training, it is split into a train set and a validation set further below.

In [None]:
# Dataset used for visualization only; each item unpacks to four values
# (see the next cell), presumably because of `getbeats=True`.
# NOTE(review): the meaning of the positional args (937, 'country', 20) is
# defined by the GTZAN class — confirm there.
dataset = GTZAN(937, 'country', 20, getbeats=True)

Visualize a randomly chosen example.

In [None]:
# Pick one sample uniformly at random from the visualization dataset.
spec, onsets, isbeat, beats = dataset[np.random.randint(len(dataset))]

In [None]:
# Presumably plots the spectrogram of the sampled example (helper from `visualization`).
showspec(spec)

In [None]:
# Presumably overlays onsets, beat labels and annotated beats on the spectrogram,
# limited to `duration=10` (seconds?) — see `showdata` in `visualization`.
showdata(spec, onsets, isbeat, beats, duration=10)

# Training of the model

In [None]:
from torch.utils.data import random_split, DataLoader
import matplotlib.pyplot as plt

In [None]:
# Fresh, untrained model: 2-layer bidirectional LSTM, 128 hidden units per direction.
# NOTE(review): the model is not moved to `device`, so it runs on the CPU.
model = BeatTracker(hidden_size=128, num_layers=2)
print_params(model)  # helper defined elsewhere; presumably prints parameter counts

In [None]:
# Training dataset: samples presumably pass through `ToTensor`, yielding
# (spec, onsets, isbeat) tensors in the form `BeatTracker` consumes.
# NOTE(review): 932 here vs 937 in the visualization cell — confirm intended.
dataset = GTZAN(932, 'country', 20, ToTensor())

# 80/20 random train/validation split; no explicit seed is set here, so the
# split can differ between runs.
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
trainset, validset = random_split(dataset, [train_size, valid_size])

print(f'Train size: {train_size}')
print(f'Valid size: {valid_size}')

In [None]:
# Confusion counts (tn, fp, fn, tp) of the not-yet-trained model, requested for `validset`.
confusion(*model.evaluate_from_dataset(validset))

In [None]:
# Train for 3 epochs with batches of 3 samples; returns per-batch loss and
# accuracy histories of shape (epochs, n_batches).
# NOTE(review): `lost_hist` looks like a typo for `loss_hist`; left unchanged
# in case later cells refer to it.
lost_hist, accu_hist = model.fit(trainset, batch_size=3, epochs=3)

In [None]:
# Confusion counts after training, requested for `validset`, for comparison with the pre-training cell.
confusion(*model.evaluate_from_dataset(validset))

In [None]:
# Persist only the module weights (state_dict); the optimizer state is not saved.
# NOTE(review): assumes ./data already exists; saving presumably fails otherwise.
torch.save(model.state_dict(), './data/model_02.pt')

# Test