In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class MusicDataset(Dataset):
    """Non-overlapping windows along axis 1 of a tensor, each paired with the step that follows it.

    ``data`` is indexed as ``data[:, a:b, :]``, so it must have at least three
    dimensions; axis 1 is the one that gets windowed.

    Item ``i`` is the pair

    * ``data[:, i * seq_length:(i + 1) * seq_length, :]`` (the input window), and
    * ``data[:, (i + 1) * seq_length, :]`` (the single step right after the window).

    Args:
        data: Tensor to slice windows from.
        seq_length: Number of steps per input window. Must be at least 1.

    Raises:
        ValueError: If ``seq_length`` is smaller than 1.
    """

    def __init__(self, data, seq_length=50):
        if seq_length < 1:
            raise ValueError('seq_length must be >= 1, got {}'.format(seq_length))
        self.data = data
        self.seq_length = seq_length

    def __getitem__(self, index):
        """Return ``(window, next_step)`` for window number ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        # Explicit bounds check: without it a negative or overshooting index would
        # silently produce an empty/truncated window instead of failing.
        if not 0 <= index < len(self):
            raise IndexError('index {} is out of range for {} windows'.format(index, len(self)))
        start = index * self.seq_length
        stop = start + self.seq_length
        train_seq = self.data[:, start:stop, :]
        # The target is the first step *after* the window, i.e. position `stop`.
        target_seq = self.data[:, stop, :]
        return train_seq, target_seq

    def __len__(self):
        """Number of windows that still have a following step to use as target."""
        # Every window needs one extra step after it, hence the "- 1": when the
        # axis length is an exact multiple of seq_length the final window has no
        # target and must not be counted.
        return max(0, (self.data.size(1) - 1) // self.seq_length)

def grad_clipping(model, theta):
    """Rescale the model's gradients in place so their global L2 norm is at most ``theta``.

    The norm is taken over all trainable parameters that currently hold a
    gradient. Parameters with ``requires_grad=False`` or with ``grad is None``
    (e.g. not reached by the last backward pass) are skipped. If the norm is
    already ``<= theta`` nothing is modified.

    Args:
        model: Module whose ``parameters()`` are inspected.
        theta: Maximum allowed global gradient norm.

    Returns:
        None. Gradients are modified in place.
    """
    grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
    if not grads:
        # Nothing to clip (no trainable parameter has a gradient yet).
        return
    norm = torch.sqrt(sum(torch.sum(g ** 2) for g in grads))
    if norm > theta:
        scale = theta / norm
        for g in grads:
            g.mul_(scale)

class MusicGenerationRNN(nn.Module):
    """GRU that consumes a sequence and predicts one vector from the final hidden state.

    The final hidden state of every GRU layer is concatenated and fed through a
    single linear layer, which is why that layer takes
    ``hidden_size * n_layers`` input features.

    Args:
        input_size: Feature size of each GRU input step.
        hidden_size: Hidden size of each GRU layer.
        output_size: Size of the predicted vector.
        batch_size: Stored on the instance only; nothing in this class reads it.
        n_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size=32, n_layers=1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.batch_size = batch_size
        self.gru = nn.GRU(input_size, hidden_size, n_layers)
        self.linear = nn.Linear(hidden_size * n_layers, output_size)

    def forward(self, input, hidden):
        """Run the GRU over ``input`` and project the final hidden state.

        Args:
            input: 4-D tensor laid out as ``(batch, a, seq, b)``; axes ``a`` and
                ``b`` are merged into the GRU feature axis, so ``a * b`` must
                equal ``input_size``.
            hidden: Initial hidden state, ``(n_layers, batch, hidden_size)``.

        Returns:
            ``(output, hidden)`` where ``output`` is ``(batch, output_size)`` and
            ``hidden`` is the GRU's final hidden state.
        """
        # (batch, a, seq, b) -> (batch, seq, a, b) -> (batch, seq, a*b)
        input = input.permute(0, 2, 1, 3).flatten(2, 3)
        # nn.GRU was built with the default batch_first=False: (seq, batch, features).
        input = input.permute(1, 0, 2)
        _, hidden = self.gru(input, hidden)
        # (n_layers, batch, hidden) -> (batch, n_layers * hidden)
        h_n = hidden.permute(1, 0, 2).contiguous().flatten(1, 2)
        output = self.linear(h_n)
        return output, hidden

    def init_hidden(self, batch_size):
        """Return an all-zero hidden state of shape ``(n_layers, batch_size, hidden_size)``.

        The tensor is created directly on the device (and with the dtype) of the
        model's own parameters, so it always matches the model regardless of any
        notebook-level ``device`` variable.
        """
        reference = next(self.parameters())
        return torch.zeros(self.n_layers, batch_size, self.hidden_size,
                           device=reference.device, dtype=reference.dtype)


def train_epoch(dataloader, model, optimizer, criterion):
    """Run one optimisation pass over ``dataloader`` and return the scaled running loss.

    For every batch: reset gradients, run the model from a fresh zero hidden
    state, back-propagate ``criterion(output, target)``, clip the global
    gradient norm to 1 and take an optimizer step.

    Args:
        dataloader: Iterable of ``(train_seq, target_seq)`` batches.
        model: Module exposing ``init_hidden(batch_size=...)`` and
            ``model(train_seq, hidden) -> (output, hidden)``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss callable taking ``(output, target)``.

    Returns:
        ``sum(per-batch loss) / number of samples seen * 100``.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    # Move batches to wherever the model actually lives.
    model_device = next(model.parameters()).device
    running_loss = 0
    n_obs = 0

    for train_seq, target_seq in dataloader:
        train_seq = train_seq.to(model_device)
        # Targets arrive with two trailing feature axes; merge them to match the model output.
        target_seq = target_seq.to(model_device).flatten(1, 2)
        # Size the hidden state from the batch itself so any batch size
        # (including a smaller final batch) works.
        batch_size = train_seq.size(0)
        hidden = model.init_hidden(batch_size=batch_size)
        optimizer.zero_grad()
        output, _ = model(train_seq, hidden)
        loss = criterion(output, target_seq)
        loss.backward()
        grad_clipping(model, 1)
        optimizer.step()
        # NOTE(review): each batch loss is added unweighted and the sum is divided by
        # the sample count, so the result scales with batch size -- confirm this is intended.
        running_loss += loss.item()
        n_obs += batch_size

    return running_loss / n_obs * 100


def test_epoch(dataloader, model, optimizer, criterion):
    running_loss = 0
    n_obs = 0

    for train_seq, target_seq in dataloader:
        train_seq = train_seq.to(device)
        target_seq = target_seq.to(device)
        hidden = model.init_hidden(batch_size=64)
        target_seq = target_seq.flatten(1, 2)
        optimizer.zero_grad()
        output, hidden = model(train_seq, hidden)
        loss = criterion(output, target_seq)
        running_loss += loss.item()
        n_obs += train_seq.size()[0]

    return running_loss / n_obs * 100


def fit(model, checkpoint_path, optimizer, scheduler, criterion, train_dataloader, test_dataloader, n_epochs=None):
    """Train and evaluate ``model`` for a number of epochs, then save it.

    Per epoch: switch to train mode, run ``train_epoch``, step the scheduler,
    switch to eval mode, run ``test_epoch`` and print both losses with a
    timestamp. After the last epoch the whole model object is written with
    ``torch.save`` to ``<checkpoint_path>/model_torch.pt``.

    Args:
        model: Module to train.
        checkpoint_path: Directory for ``model_torch.pt``; created if missing.
        optimizer: Optimizer passed through to the epoch functions.
        scheduler: Learning-rate scheduler; ``step()`` is called once per epoch.
        criterion: Loss callable passed through to the epoch functions.
        train_dataloader: Batches used for training.
        test_dataloader: Batches used for evaluation.
        n_epochs: Number of epochs. When omitted, the module-level ``n_epochs``
            variable is used, as this function did before the parameter existed.

    Returns:
        ``(train_losses, test_losses)``: two lists with one entry per epoch.

    Raises:
        NameError: If ``n_epochs`` is omitted and no module-level ``n_epochs`` exists.
    """
    if n_epochs is None:
        try:
            n_epochs = globals()['n_epochs']
        except KeyError:
            raise NameError("name 'n_epochs' is not defined; pass n_epochs=... to fit()") from None

    # Create the target directory up front so a missing folder fails (or is fixed)
    # before training rather than after it, when the save would otherwise be lost.
    if checkpoint_path:
        os.makedirs(checkpoint_path, exist_ok=True)

    train_losses = []
    test_losses = []

    for epoch in range(1, n_epochs + 1):
        model.train()
        train_epoch_loss = train_epoch(train_dataloader, model, optimizer, criterion)
        train_losses.append(train_epoch_loss)
        scheduler.step()

        model.eval()
        test_epoch_loss = test_epoch(test_dataloader, model, optimizer, criterion)
        test_losses.append(test_epoch_loss)

        print('Epoch {}, Train Loss: {}, Test Loss: {}, Time: {}'.format(epoch, train_epoch_loss, test_epoch_loss,
                                                                         datetime.now()))

    # Saves the full pickled module (not just a state_dict), so loading it later
    # requires the model class to be importable under the same name.
    torch.save(model, os.path.join(checkpoint_path, 'model_torch.pt'))

    return train_losses, test_losses

In [None]:
# --- Paths ---------------------------------------------------------------------
input_path = os.path.join('data')
preparation_path = os.path.join(input_path, '02_preparation')
model_path = os.path.join(input_path, '03_model')
checkpoint_path = os.path.join(model_path, 'gan', 'checkpoints')
# The model checkpoint and both loss CSVs are written here; make sure the folder
# exists before spending time on training.
os.makedirs(checkpoint_path, exist_ok=True)

# --- Data ----------------------------------------------------------------------
# NOTE(review): torch.load unpickles the file -- only load tensors from a trusted source.
# map_location='cpu' keeps the load working on machines that lack the device the
# tensor was saved from; batches are moved to the training device inside the epoch loops.
torch_tensor = torch.load(os.path.join(preparation_path, 'tensor.pt'), map_location='cpu')
torch_tensor = torch_tensor.type(torch.float32)

seq_length = 64
batch_size = 64
# First 80 % of axis 1 is used for training, the remainder for testing.
split_index = int(torch_tensor.shape[1] * 0.8)

train_dataset = MusicDataset(torch_tensor[:, 0:split_index, :], seq_length=seq_length)
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          drop_last=True)

test_dataset = MusicDataset(torch_tensor[:, split_index:, :], seq_length=seq_length)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         drop_last=True)

# --- Hyperparameters -----------------------------------------------------------
hidden_size = 512
n_layers = 3
n_epochs = 1
lr = 0.001
# Multiplicative learning-rate decay applied once per epoch.
lr_lambda = 0.98

# --- Model, optimisation, training ---------------------------------------------
model = MusicGenerationRNN(input_size=504, hidden_size=hidden_size, output_size=504, n_layers=n_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# The lambda reads the module-level float `lr_lambda`; the keyword of the same name
# is just LambdaLR's parameter.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: lr_lambda ** epoch)
criterion = nn.MSELoss()
train_losses, test_losses = fit(model, checkpoint_path, optimizer, scheduler, criterion, train_loader, test_loader)

# --- Persist and plot the loss curves --------------------------------------------
pd.DataFrame(train_losses).to_csv(os.path.join(checkpoint_path, 'train_losses_rnn.csv'), index=False)
pd.DataFrame(test_losses).to_csv(os.path.join(checkpoint_path, 'test_losses_rnn.csv'), index=False)

# Markers keep the curves visible even when only a single epoch was run.
plt.plot(train_losses, marker='o', label='Train Loss')
plt.plot(test_losses, marker='o', label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train vs. test loss per epoch')
plt.legend()
plt.show()
