In [1]:
import warnings
warnings.filterwarnings('ignore')

import src.dataset as dataset
from src.net import Net

import torchaudio
import torch
import os

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda")

## Load Data

In [2]:
training_data, validation_data = dataset.get_train_val_data(
                            audio_dir="data/TempAudio/",
                            label_dir="data/Labels",
                            device=device,
                            load_all_in_mem=False)

## Define Network

In [3]:
batch_size = 5
net = Net(training_data.max_length, batch_size)
net.cuda()
criterion = nn.MSELoss().to(device)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0)

# Train

In [5]:
train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(validation_data)

writer = SummaryWriter("runs/experiment_1")

for epoch in range(50):
    
    train_loss = 0
    val_loss = 0
    
    for data, label in train_dataloader:
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, label)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        
    for data, label in val_dataloader:
        output = net(data)
        loss = criterion(output, label)
        val_loss += loss.item()
    
    train_loss /= len(training_data)
    val_loss /= len(validation_data)

    writer.add_scalar('Loss/train', train_loss, epoch + 1)
    writer.add_scalar('Loss/val', val_loss, epoch + 1)
    
    print("train_loss:", train_loss)
    print("val_loss:", val_loss)
    print()

train_loss: 0.12434544414281845
val_loss: 0.058523308485746384

train_loss: 0.10766292735934258
val_loss: 0.04519055783748627

train_loss: 0.10284284502267838
val_loss: 0.05679596588015556

train_loss: 0.09527323395013809
val_loss: 0.04234014451503754

train_loss: 0.07954220660030842
val_loss: 0.044094305485486984

train_loss: 0.08257294073700905
val_loss: 0.0402560792863369

train_loss: 0.07717148959636688
val_loss: 0.046958159655332565

train_loss: 0.07507700473070145
val_loss: 0.037683967500925064

train_loss: 0.06725197657942772
val_loss: 0.03915981203317642

train_loss: 0.0602297019213438
val_loss: 0.03465183079242706

train_loss: 0.0705949030816555
val_loss: 0.026055876165628433

train_loss: 0.05217808671295643
val_loss: 0.03282998502254486

train_loss: 0.06392803974449635
val_loss: 0.02248215489089489

train_loss: 0.05440438725054264
val_loss: 0.020859133452177048

train_loss: 0.054827358573675156
val_loss: 0.01651868224143982

train_loss: 0.045367952436208725
val_loss: 0.017767

KeyboardInterrupt: 

# Compare actual labels

In [None]:
label_1 = "data/Labels/nick/One_Punch_Man_1.label"
label_2 = "data/Labels/nick/One_Punch_Man_5.label"
label_3 = "data/Labels/nick/One_Punch_Man_6.label"

labels = []
for filename in [label_1, label_2, label_3]:
    f = open(filename, "r")
    label = []
    for i in range(4):
        label.append(int(f.readline()))
    labels.append(label)

labels = torch.Tensor(labels)
labels = labels.to(device)

In [None]:
for i in range(len(anime_audio_data.audio_files)):
    if anime_audio_data.audio_files[i].startswith('One'):
        print(anime_audio_data.audio_files[i], i)

In [None]:
'''
for i, data in enumerate(anime_audio_data):
    data, _ = data
    #data = data.unsqueeze(0)
    out = net(data) * anime_audio_data.l_std.to(device) + anime_audio_data.l_mean.to(device)
    print(out)
    print(labels[i])
'''
def temp(data):
    audio, label = data
    out = net(audio.unsqueeze(0)) * anime_audio_data.l_std.to(device) + anime_audio_data.l_mean.to(device)
    label = label * anime_audio_data.l_std.to(device) + anime_audio_data.l_mean.to(device)
    print(out)
    print(label)
temp(anime_audio_data[2])
temp(anime_audio_data[3])
temp(anime_audio_data[4])