In [1]:
from src.dataset import AnimeAudioDataset
from src.net import Net
from src.utils import save_model, load_model, predict

import torchaudio
import torch
import os

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset

# Select the GPU when one is available; otherwise fall back to the CPU
# instead of naming a CUDA device that does not exist on this machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Data

In [2]:
# Build the dataset object, handing it the torch device selected above.
# NOTE(review): presumably the dataset places its tensors on `device` —
# confirm in src/dataset.py.
anime_audio_data = AnimeAudioDataset(device)

## Define Network

In [3]:
# Mini-batch size. It is passed to the Net constructor here and reused by the
# DataLoader and predict() calls in later cells.
batch_size = 3
# The network is built from the batch size and three dataset attributes
# (max_length, l_mean, l_std).
net = Net(batch_size, anime_audio_data.max_length, anime_audio_data.l_mean, anime_audio_data.l_std)
# Move the model with the shared `device` object (as the criterion below
# does) rather than hard-coding `.cuda()`.
net.to(device)
criterion = nn.MSELoss().to(device)
# Plain SGD without momentum.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0)

In [4]:
# Call load_model on "test.model" with the network and optimizer.
# NOTE(review): presumably this restores their saved state in place; the
# return value is echoed as the cell output (9 in the recorded run) and is
# presumably the last saved epoch — confirm in src/utils.py.
load_model("test.model", net, optimizer)

9

In [41]:
# Shuffled mini-batch iterator over the whole dataset.
dataloader = torch.utils.data.DataLoader(
    dataset=anime_audio_data,
    batch_size=batch_size,
    shuffle=True,
)

## Train

In [None]:
# Put the network in training mode before optimising.
net.train()

# NOTE: `epoch` and `data` are deliberately left as globals with these names;
# later cells read them (the checkpoint cell uses `epoch`, the scratch cells
# index into `data`).
for epoch in range(50):

    # Sum of the per-batch losses for this epoch (a total, not a mean).
    running_loss = 0

    for data, label in dataloader:
        # Collapse the label batch to 1-D before comparing with the output.
        # NOTE(review): assumes net() returns a 1-D tensor in the same
        # layout — confirm in src/net.py.
        label = label.flatten()
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, label)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    print("loss:", running_loss)

loss: 0.0003207752415619325
loss: 0.00042806966303032823
loss: 0.0006737194562447257
loss: 0.0003324301242173533
loss: 0.0004525403928710148
loss: 0.00027878620494448114
loss: 0.0006532271145260893
loss: 0.0002477251437085215
loss: 0.0003225981017749291
loss: 0.0013520068459911272
loss: 0.00023205029174278025
loss: 0.0002960552828881191
loss: 0.00020813575974898413
loss: 0.00027318836146150716
loss: 0.0003177058752044104
loss: 0.0003497104462439893
loss: 0.00033027897006832063
loss: 0.0003761379011848476
loss: 0.00016255309128609952
loss: 0.00018672613441594876
loss: 6.26860091870185e-05
loss: 9.515970759821357e-05
loss: 0.00015142260235734284
loss: 0.00023108592904463876
loss: 0.00032268610812025145
loss: 0.0002900924991990905
loss: 0.000126515065858257
loss: 0.00015961114786477992
loss: 0.00018724164328887127
loss: 0.00022657685576632502
loss: 8.015564435481792e-05
loss: 0.00014768033543077763
loss: 0.00011310154695820529
loss: 7.234696568048093e-05
loss: 9.750164645083714e-05
loss: 

In [49]:
# Hand the network, optimizer and `epoch` to save_model under "test.model".
# `epoch` is the loop variable left behind by the training cell.
# NOTE(review): if training was interrupted, `epoch` is the epoch that was in
# progress rather than the last completed one — confirm this is intended.
save_model("test.model", net, optimizer, epoch)
# Switch to evaluation mode and disable autograd for the prediction call.
net.eval()
with torch.no_grad():
    rv = predict("data/Audio/Boku_no_Hero_Academia_10.wav", net, batch_size)
# Average the prediction over its first axis and show the result.
print(rv.mean(axis=0))

tensor([   8945.0752,   91498.9375, 1358045.5000, 1448425.3750])


## Compare against actual labels

In [None]:
# Hand-made label files to compare the network output against.
label_1 = "data/Labels/nick/One_Punch_Man_1.label"
label_2 = "data/Labels/nick/One_Punch_Man_5.label"
label_3 = "data/Labels/nick/One_Punch_Man_6.label"

# Each file contributes its first four lines, parsed as integers.
labels = []
for filename in [label_1, label_2, label_3]:
    # Use a context manager so the file handle is closed even if a line
    # fails to parse (the handle was previously never closed).
    with open(filename, "r") as f:
        label = [int(f.readline()) for _ in range(4)]
    labels.append(label)

# One row per label file, moved to the same device as the model.
labels = torch.Tensor(labels)
labels = labels.to(device)

In [None]:
# Print every audio file whose name starts with 'One', with its dataset index.
for i, audio_file in enumerate(anime_audio_data.audio_files):
    if audio_file.startswith('One'):
        print(audio_file, i)

In [None]:
def temp(data):
    """Print the rescaled network output and the rescaled label for one item.

    Args:
        data: An ``(audio, label)`` pair, e.g. ``anime_audio_data[i]``.

    The network output and the label are both multiplied by
    ``anime_audio_data.l_std`` and shifted by ``anime_audio_data.l_mean``
    before being printed, output first and label second. Returns ``None``.
    """
    audio, label = data
    # Fetch the label statistics once instead of moving them per expression.
    l_std = anime_audio_data.l_std.to(device)
    l_mean = anime_audio_data.l_mean.to(device)
    # Inspection only: skip autograd bookkeeping for the forward pass, the
    # same way the prediction cell wraps predict() in torch.no_grad().
    with torch.no_grad():
        out = net(audio.unsqueeze(0)) * l_std + l_mean
    label = label * l_std + l_mean
    print(out)
    print(label)
# Show prediction vs. label for dataset items 2, 3 and 4.
for sample_idx in (2, 3, 4):
    temp(anime_audio_data[sample_idx])

In [None]:
# util to find magic numbers
x = data[0].unsqueeze(0)
x = net.pool(F.relu(net.conv1(x)))
print(x.shape)
x = x.view(-1, 5859 * 16)
# print(x.shape)
x = F.relu(net.fc1(x))
x = F.relu(net.fc2(x))
x = net.fc3(x)
print(x)
pass

In [None]:
# find magic num
x = data[0].unsqueeze(0)
x = net.pool(F.relu(net.conv1(x)))
# x = net.pool(F.relu(net.conv2(x)))
print(x.shape)
x = x.view(-1, net.magic_num)
x = net.fc1(x)
print(x)