In [1]:
from src.dataset import AnimeAudioDataset
from src.net import Net

import torchaudio
import torch
import os

import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

%matplotlib notebook
import matplotlib.pyplot as plt

device = torch.device("cuda")

## Load Data

In [2]:
dataset = AnimeAudioDataset(audio_dir="data/Audio",
                         label_dir="data/Labels",
                         device=device,
                         load_all_in_mem=False)

In [3]:
validation_split = 0.2

val_size = int(len(dataset) * validation_split)
train_size = len(dataset) - val_size
train_data, val_data = torch.utils.data.random_split(dataset, [train_size, val_size])

## Define Network

In [4]:
batch_size = 5
net = Net(dataset.max_length, dataset.l_mean, dataset.l_std)
net.cuda()
criterion = nn.MSELoss().to(device)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0)

# Train

In [5]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)

train_loss_arr = []
val_loss_arr = []

fig = plt.figure()
ax = fig.add_subplot(111)
plt.ion()

fig.show()
fig.canvas.draw()

for epoch in range(10):

    train_loss = 0
    val_loss = 0

    for data, label in train_dataloader:

        optimizer.zero_grad()

        output = net(data)
        loss = criterion(output, label)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
    
    for data, label in val_dataloader:
        output = net(data)
        loss = criterion(output, label)
        val_loss += loss.item()
        
    train_loss /= len(train_data)
    val_loss /= len(val_data)
    
    train_loss_arr.append(train_loss)
    val_loss_arr.append(val_loss)
    
    ax.clear()
    ax.plot(train_loss_arr, label="train_loss")
    ax.plot(val_loss_arr, label="val_loss")
    ax.legend()
    fig.canvas.draw()
    
    print("train_loss:", train_loss)
    print("val_loss:", val_loss)

<IPython.core.display.Javascript object>

train_loss: 0.18685141033851183
val_loss: 0.2952554523944855
train_loss: 0.17697688879875037
val_loss: 0.2811432909220457
train_loss: 0.1470703216126332
val_loss: 0.32964543004830676
train_loss: 0.15037102682086137
val_loss: 0.28832151864965755
train_loss: 0.1430724675838764
val_loss: 0.3040878288447857
train_loss: 0.3198485953303484
val_loss: 0.29616475601991016
train_loss: 0.10567359468684746
val_loss: 0.2757300063967705
train_loss: 0.06534843146800995
val_loss: 0.29399314398566884
train_loss: 0.05490700327433073
val_loss: 0.2836545618871848
train_loss: 0.03740486155192439
val_loss: 0.26782448093096417


# Compare actual labels

In [None]:
label_1 = "data/Labels/nick/One_Punch_Man_1.label"
label_2 = "data/Labels/nick/One_Punch_Man_5.label"
label_3 = "data/Labels/nick/One_Punch_Man_6.label"

labels = []
for filename in [label_1, label_2, label_3]:
    f = open(filename, "r")
    label = []
    for i in range(4):
        label.append(int(f.readline()))
    labels.append(label)

labels = torch.Tensor(labels)
labels = labels.to(device)

In [None]:
for i in range(len(anime_audio_data.audio_files)):
    if anime_audio_data.audio_files[i].startswith('One'):
        print(anime_audio_data.audio_files[i], i)

In [None]:
'''
for i, data in enumerate(anime_audio_data):
    data, _ = data
    #data = data.unsqueeze(0)
    out = net(data) * anime_audio_data.l_std.to(device) + anime_audio_data.l_mean.to(device)
    print(out)
    print(labels[i])
'''
def temp(data):
    audio, label = data
    out = net(audio.unsqueeze(0)) * anime_audio_data.l_std.to(device) + anime_audio_data.l_mean.to(device)
    label = label * anime_audio_data.l_std.to(device) + anime_audio_data.l_mean.to(device)
    print(out)
    print(label)
temp(anime_audio_data[2])
temp(anime_audio_data[3])
temp(anime_audio_data[4])