In [None]:
import csv
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from matplotlib import pyplot as plt
from datasets import datasetMusic

In [None]:
# Positional index of every audio feature, in file order (0 .. 11).
# NOTE(review): presumably these are column positions that `datasetMusic`
# uses to pick features out of the CSV rows — confirm against datasets.py.
(
    temp,     # 0
    key,      # 1
    mode,     # 2
    loud,     # 3
    time_s,   # 4
    dura,     # 5
    dance,    # 6
    acoust,   # 7
    speech,   # 8
    live,     # 9
    energy,   # 10
    instru,   # 11
) = range(12)

# Feature subset handed to both datasets; the model's input width is derived
# from its length (`len(sliced)` in the setup cell).
sliced = [loud, dance, energy, speech, acoust]

In [None]:
# Training split read from train.csv; `sliced` chooses the feature columns
# (presumably — see datasets.datasetMusic). Batches of 256, reshuffled each epoch.
train_ds = datasetMusic(sliced=sliced, path='train.csv')
train_dl = DataLoader(train_ds, batch_size=256, shuffle=True)

In [None]:
# Validation split read from valid.csv, same feature selection as training.
# Kept in file order and one sample per batch; keep batch_size at 1 unless the
# evaluation code handles multi-element tensors.
valid_ds = datasetMusic(sliced=sliced, path='valid.csv')
valid_dl = DataLoader(valid_ds, batch_size=1, shuffle=False)

In [None]:
class LinearReg(nn.Module):
    """Single-output linear regression model: y = x @ W.T + b.

    The parameters are converted to float64 with ``.double()``, so inputs
    must be float64 tensors too (nn.Linear rejects mismatched dtypes).

    Args:
        in_dim: number of input features (default 5).
    """

    def __init__(self, in_dim: int = 5) -> None:
        super(LinearReg, self).__init__()
        # One affine layer mapping `in_dim` features to a single scalar;
        # Module.double() converts in place and returns the same module.
        layer = nn.Linear(in_dim, 1)
        self.linear = layer.double()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map inputs of shape (..., in_dim) to predictions of shape (..., 1)."""
        return self.linear(x)

In [None]:
# basic setup
SEED        = 0
# Seed torch's global RNG before the model is built. This fixes the nn.Linear
# weight initialisation and, because the shuffled train_dl draws its sampler
# seed from the same global RNG when iterated, also the batch order — so a
# Restart & Run All reproduces the same training run.
torch.manual_seed(SEED)

device      = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
lr          = 1e-3
model       = LinearReg(in_dim=len(sliced)).to(device)  # one input per selected feature
criterion   = nn.MSELoss()
optimizer   = optim.SGD(model.parameters(), lr=lr)
epochs      = 3000

In [None]:
# training (regression for valence)
# `losses` holds one Python float per training batch. Storing `loss.item()`
# instead of the loss tensor avoids keeping every batch's autograd graph alive
# and lets matplotlib plot the values (it cannot convert grad-requiring or
# CUDA tensors).
LOG_EVERY = 100  # print a progress line every LOG_EVERY epochs (and the last one)

losses = []
model.train()
for epoch in range(epochs):
    epoch_losses = []
    for batch in train_dl:
        inData  = batch[0].to(device)
        outData = batch[1].to(device)

        # nn.Linear(..., 1) yields (batch, 1). Reshape to the target's shape so
        # MSELoss compares element-wise; a (batch,) target would otherwise be
        # broadcast against (batch, 1) into a (batch, batch) matrix.
        # NOTE(review): assumes exactly one target value per sample — confirm
        # against datasetMusic.
        preds   = model(inData).reshape(outData.shape)
        loss    = criterion(preds, outData)  # MSELoss(input, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

    losses.extend(epoch_losses)
    if epoch % LOG_EVERY == 0 or epoch == epochs - 1:
        # max(..., 1) guards against an empty loader
        mean_loss = sum(epoch_losses) / max(len(epoch_losses), 1)
        print(f'Epoch: {epoch:5d}, batches: {len(epoch_losses):3d}, mean loss: {mean_loss:.4f}')

fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title='Training loss (valence regression)', xlabel='batch step', ylabel='MSE loss')
plt.show()

In [None]:
# validating
# Bucketed "accuracy" for the valence regressor: only samples whose target is
# clearly high (> HIGH) or clearly low (< LOW) are scored, and a prediction is
# correct when it falls on the matching side of DECISION.
LOW, HIGH, DECISION = 0.25, 0.75, 0.5


def bucket_accuracy(model, loader, device, low=LOW, high=HIGH, decision=DECISION):
    """Count correct predictions on the confidently-labelled samples of `loader`.

    Comparisons are element-wise, so any batch size works and every sample in
    a batch is counted individually.

    Args:
        model: regressor returning one prediction per sample.
        loader: iterable of batches whose items 0 and 1 are inputs and targets.
        device: device inputs/targets are moved to before the forward pass.
        low, high: only targets below `low` or above `high` are scored.
        decision: a prediction above it means "high", below it means "low";
            a prediction exactly equal to it is never counted as correct.

    Returns:
        (correct, total) as Python ints.
    """
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        for batch in loader:
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            # NOTE(review): assumes one target value per sample — confirm
            # against datasetMusic.
            preds = model(inputs).reshape(targets.shape)

            is_high = targets > high
            is_low = targets < low
            hits = (is_high & (preds > decision)) | (is_low & (preds < decision))

            total += int((is_high | is_low).sum().item())
            correct += int(hits.sum().item())
    return correct, total


def report_accuracy(name, correct, total):
    """Print the accuracy line for split `name`, guarding an empty scored subset."""
    if total == 0:
        print(f'Accuracy in {name} datasets: n/a (no confidently-labelled samples)')
    else:
        print(f'Accuracy in {name} datasets: {(100 * correct / total):.2f} %')


report_accuracy('valid', *bucket_accuracy(model, valid_dl, device))

# Separate, unshuffled loader for scoring the training split. It must NOT be
# assigned to `train_dl`: that would make a re-run of the training cell train
# with batch_size=1 and no shuffling.
train_eval_dl = DataLoader(dataset=train_ds, shuffle=False, batch_size=256)
report_accuracy('train', *bucket_accuracy(model, train_eval_dl, device))