In [None]:
import torchaudio
import torch

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset

# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Data

In [None]:
class AnimeAudioDataset(Dataset):
    """In-memory dataset of One Punch Man episode audio and their 4 integer labels.

    Every waveform is zero-padded at the end to the longest clip and stacked into
    ``self.data``; labels are z-score normalized column-wise. Both tensors are moved
    to the module-level ``device``.

    Attributes:
        data: stacked waveforms, shape (num_clips, channels, longest_length).
        labels: normalized labels, shape (num_clips, NUM_LABELS).
        label_mean, label_std: per-column statistics; a prediction ``p`` maps back
            to label units with ``p * label_std + label_mean``.
    """

    # Episode stems shared by the audio and label loaders so the two stay aligned.
    EPISODES = ("One_Punch_Man_1", "One_Punch_Man_5", "One_Punch_Man_6")
    AUDIO_DIR = "data/Audio"
    LABEL_DIR = "data/Labels/nick"
    NUM_LABELS = 4

    def __init__(self):
        self.data = self._pad_audio(self._load_audio())
        self.labels, self.label_mean, self.label_std = \
                self._normalize_labels(self._load_labels())

        self.data = self.data.to(device)
        self.labels = self.labels.to(device)

    def _load_audio(self):
        """Load each episode's waveform with torchaudio and return them as a list.

        The sample rate returned by torchaudio is discarded.
        NOTE(review): clips are assumed to share one sample rate — not checked.
        """
        waveforms = []
        for episode in self.EPISODES:
            waveform, _sample_rate = torchaudio.load(f"{self.AUDIO_DIR}/{episode}.wav")
            waveforms.append(waveform)
        return waveforms

    def _load_labels(self):
        """Read NUM_LABELS integers (one per line) from each episode's label file.

        Returns:
            float tensor of shape (len(EPISODES), NUM_LABELS).
        """
        labels = []
        for episode in self.EPISODES:
            # `with` closes each file deterministically instead of leaking the handle.
            with open(f"{self.LABEL_DIR}/{episode}.label", "r") as f:
                labels.append([int(f.readline()) for _ in range(self.NUM_LABELS)])
        return torch.Tensor(labels)

    def _normalize_labels(self, labels):
        """Z-score normalize ``labels`` column-wise.

        Returns:
            (normalized_labels, mean, std). Columns whose std is zero or undefined
            (all values equal, or a single row) get std 1, so they normalize to 0
            instead of NaN/inf; ``normalized * std + mean`` still recovers the input.
        """
        l_mean = labels.mean(dim=0)
        l_std = labels.std(dim=0)
        # NaN > 0 is False, so the single-row case is covered as well.
        l_std = torch.where(l_std > 0, l_std, torch.ones_like(l_std))
        return (labels - l_mean) / l_std, l_mean, l_std

    def _pad_audio(self, data):
        """Right-pad each (channels, length) waveform with zeros to the longest length.

        Channel count, dtype and device come from each waveform rather than being
        hard-coded, so mono or non-float32 clips pad correctly.

        Returns:
            tensor of shape (len(data), channels, longest_length).
        """
        longest = max(w.shape[1] for w in data)
        padded = [
            torch.cat(
                (w, torch.zeros(w.shape[0], longest - w.shape[1], dtype=w.dtype, device=w.device)),
                dim=1,
            )
            for w in data
        ]
        return torch.stack(padded)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # idx can be a tensor
        return self.data[idx], self.labels[idx]

# Build the dataset once: loads every clip and label file, pads/normalizes them,
# and moves the resulting tensors to `device`.
anime_audio_data = AnimeAudioDataset() 

## Define Network

In [None]:
class Net(nn.Module):
    """Small 1-D CNN that regresses 4 (normalized) label values from a 2-channel clip.

    Pipeline: Conv1d(2 -> 4 channels, kernel 1600, stride 10) -> ReLU
    -> MaxPool1d(5) -> flatten -> Linear(magic_num -> 4).

    Args:
        input_length: samples per channel of every input clip. When given,
            ``magic_num`` (the flattened feature count fed to ``fc1``) is derived
            from it; when omitted, the hard-coded ``234395 * 4`` is used.
    """

    def __init__(self, input_length=None):
        super(Net, self).__init__()
        self.conv1 = nn.Conv1d(2, 4, 1600, stride=10)
        self.pool = nn.MaxPool1d(5)
        if input_length is None:
            # NOTE(review): only valid for one specific clip length; pass
            # input_length to derive it instead of hard-coding.
            self.magic_num = 234395 * 4
        else:
            self.magic_num = self._flat_features(input_length)
        self.fc1 = nn.Linear(self.magic_num, 4)

    def _flat_features(self, input_length):
        """Per-sample feature count after conv1 + pool for clips of ``input_length`` samples.

        Uses the Conv1d / MaxPool1d output-length formula for padding 0 and
        dilation 1, which is how both layers are configured here.

        Raises:
            ValueError: if the clip is too short to produce any pooled output.
        """
        conv_len = (input_length - self.conv1.kernel_size[0]) // self.conv1.stride[0] + 1
        pool_len = (conv_len - self.pool.kernel_size) // self.pool.stride + 1
        if pool_len < 1:
            raise ValueError(f"input_length={input_length} is too short for conv1 + pool")
        return self.conv1.out_channels * pool_len

    def forward(self, x):
        """Map clips shaped (batch, 2, length), or one (2, length) clip, to (batch, 4)."""
        if x.dim() == 2:
            # Unbatched clip: treat it as a batch of one so the output is (1, 4).
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))
        # flatten(1) keeps samples separate; a wrong magic_num now fails loudly in fc1
        # instead of view(-1, magic_num) silently regrouping values across samples.
        x = torch.flatten(x, 1)
        return self.fc1(x)

In [None]:
# Seed weight initialization so Restart & Run All reproduces the same training run.
torch.manual_seed(0)
# Follow the configured `device` instead of hard-coding .cuda().
net = Net().to(device)
criterion = nn.MSELoss().to(device)
optimizer = optim.SGD(net.parameters(), lr=0.003, momentum=0)

## Train

In [None]:
# Per-sample SGD: every clip is its own batch of one, for 10 passes over the data.
for epoch in range(10):
    running_loss = 0

    for data, label in anime_audio_data:
        # Add the leading batch axis expected by the network and the loss.
        data = data[None]
        label = label[None]

        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Sum (not mean) of this epoch's per-clip losses.
    print("loss:", running_loss)

## Compare predictions with actual labels

In [None]:
# Raw (un-normalized) labels, for comparison with the de-normalized predictions.
# Reuses the dataset's loader rather than duplicating the file-reading code.
labels = anime_audio_data._load_labels().to(device)

In [None]:
# Map each prediction back to label units and print it next to the true label.
# Inference only: no_grad avoids building an autograd graph over every full clip.
label_std = anime_audio_data.label_std.to(device)
label_mean = anime_audio_data.label_mean.to(device)
with torch.no_grad():
    for i, (data, _) in enumerate(anime_audio_data):
        data = data.unsqueeze(0)
        out = net(data) * label_std + label_mean
        print(out)
        print(labels[i])

In [None]:
# Utility to find the "magic number": the per-sample feature count fc1 must accept.
# Only conv1 + pool are run (Net defines no fc2/fc3), so this works even when
# fc1's size is out of date. Uses a dataset clip so it does not rely on a
# `data` variable left over from another cell.
sample_clip, _ = anime_audio_data[0]
with torch.no_grad():
    x = net.pool(F.relu(net.conv1(sample_clip.unsqueeze(0))))
print(x.shape)
print("features per sample:", x[0].numel())

In [None]:
# find magic num: run one clip through the network step by step to check that
# net.magic_num matches the flattened conv/pool output that fc1 receives.
sample_clip, _ = anime_audio_data[0]
with torch.no_grad():
    x = net.pool(F.relu(net.conv1(sample_clip.unsqueeze(0))))
    print(x.shape)
    x = x.view(-1, net.magic_num)
    x = net.fc1(x)
print(x)