In [None]:
import torch
import torchaudio

from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split


from pitch_tracker.utils import dataset
from pitch_tracker.utils.dataset import AudioDataset
from pitch_tracker.utils.constants import (F_MIN, HOP_LENGTH, N_FFT, N_MELS,
                                           PICKING_FRAME_SIZE,
                                           PICKING_FRAME_STEP,
                                           PICKING_FRAME_TIME, SAMPLE_RATE,
                                           STEP_FRAME, STEP_TIME, WIN_LENGTH,
                                           N_CLASS, )
from pitch_tracker.utils import files

In [None]:
DATASET_DIR = '../content/pickled_database/'
SPLIT_SEED = 1

# Sort before splitting: `list_folder_paths_in_dir` may yield folders in filesystem
# order (not verified here), in which case the seeded split below would not be
# reproducible across machines. Sorting pins the input order.
dataset_paths = sorted(files.list_folder_paths_in_dir(DATASET_DIR))

# 60 / 20 / 20 split by song: hold out 40%, then halve the holdout into validation and test.
train_set, holdout_set = train_test_split(dataset_paths, test_size=0.40, random_state=SPLIT_SEED, shuffle=True)
validation_set, test_set = train_test_split(holdout_set, test_size=0.50, random_state=SPLIT_SEED, shuffle=True)
assert len(train_set) + len(validation_set) + len(test_set) == len(dataset_paths)

print(f'train_song_set: {len(train_set)}')
print(f'validation_song_set: {len(validation_set)}')
print(f'test_song_set: {len(test_set)}')

In [None]:
BATCH_SIZE = 8

train_dataset = AudioDataset(train_set)
# Use a new name rather than rebinding `validation_set`: rebinding turned the path list
# into a Dataset, so re-running this cell would wrap the Dataset a second time.
validation_dataset = AudioDataset(validation_set)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation metrics are order-independent, so the validation set is not shuffled.
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
from collections import OrderedDict
from typing import Tuple, Union
from functools import partial


def create_conv2d_block(
        conv2d_input: Tuple[int, int, Union[Tuple[int, int], int]],
        maxpool_kernel_size: Union[Tuple[int, int], int, None],):
    """Build a Conv2d -> ReLU -> BatchNorm2d stack, optionally followed by MaxPool2d.

    Args:
        conv2d_input: ``(in_channels, out_channels, kernel_size)`` forwarded to
            ``nn.Conv2d`` (default stride 1, no padding).
        maxpool_kernel_size: kernel size for a trailing ``nn.MaxPool2d``. A falsy
            value (``None`` or ``0``) omits the pooling layer.

    Returns:
        An ``nn.Sequential`` whose children are named ``conv2d``, ``relu``,
        ``batch_norm`` and, when pooling is enabled, ``maxpool2d``.
    """
    in_channels, out_channels, kernel_size = conv2d_input

    layers = OrderedDict()
    layers['conv2d'] = nn.Conv2d(in_channels, out_channels, kernel_size)
    layers['relu'] = nn.ReLU()
    layers['batch_norm'] = nn.BatchNorm2d(out_channels)
    if maxpool_kernel_size:
        layers['maxpool2d'] = nn.MaxPool2d(maxpool_kernel_size)

    return nn.Sequential(layers)

class NeuralNetwork(nn.Module):
    """Three conv blocks followed by a per-row dense head and a class softmax.

    Pipeline: 3 x (Conv2d(k=3) -> ReLU -> BatchNorm2d -> MaxPool2d(3)), with
    channels 1 -> 256 -> 256 -> 210, then ``Flatten(2)`` so each of the 210
    output channels becomes one row. Each row goes through
    ``LazyLinear(512) -> Linear(512, n_classes) -> Softmax``.

    The input is assumed to be ``(batch, 1, H, W)`` (single input channel). What
    H/W and the 210 rows represent (presumably time frames) is not visible here;
    TODO confirm.

    Args:
        n_classes: width of the output layer. Defaults to 88, the previously
            hard-coded value.

    Output shape: ``(batch, 210, n_classes)``, normalized over the last
    (class) dimension.
    """

    def __init__(self, n_classes: int = 88):
        super().__init__()
        self.conv2d_block1 = create_conv2d_block(
            conv2d_input=(1, 256, 3),
            maxpool_kernel_size=3,
        )
        self.conv2d_block2 = create_conv2d_block(
            conv2d_input=(256, 256, 3),
            maxpool_kernel_size=3,
        )
        self.conv2d_block3 = create_conv2d_block(
            conv2d_input=(256, 210, 3),
            maxpool_kernel_size=3,
        )
        # (B, 210, H', W') -> (B, 210, H' * W'): keep the 210 channels as rows.
        self.flatten_layer = nn.Flatten(2)
        # Input width H' * W' depends on the input size, so it is inferred on the first forward.
        self.dense_layer = nn.LazyLinear(512)
        self.output_layer = nn.Linear(512, n_classes)
        # Normalize over the class dimension. The previous Softmax(0) normalized
        # across the *batch*, mixing unrelated samples.
        self.softmax_layer = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.conv2d_block1(x)
        x = self.conv2d_block2(x)
        x = self.conv2d_block3(x)
        x = self.flatten_layer(x)
        x = self.dense_layer(x)
        x = self.output_layer(x)
        return self.softmax_layer(x)

model = NeuralNetwork().to(device)
print(model)

In [None]:
# Adam over all model parameters; BCELoss is applied directly to the model's softmax outputs.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = nn.BCELoss().to(device)

In [None]:
# Dry run on one training batch. Running real data through the model materializes
# the LazyLinear layer (its input width is only known after a forward pass) and lets
# us compare prediction and target shapes. eval() + no_grad() stop this check from
# updating BatchNorm running statistics and from keeping a full autograd graph alive
# in the global `pred` for the rest of the session.
sample_feature, sample_label = next(iter(train_dataloader))
model.eval()
with torch.no_grad():
    pred = model(sample_feature.to(device))
model.train()

print(pred.shape)
print(sample_feature.shape)
print(sample_label[2].shape)

In [None]:
# The model already ends in a softmax, so score the sample batch with the same
# probability-space criterion used for training (`loss_fn`). A *_with_logits loss
# would apply a second sigmoid to values that are already probabilities.
loss_fn(pred.to(device), sample_label[2].to(device))

In [None]:
def train(dataloader, model, loss_fn, optimizer, device=None, log_interval=1):
    """Run one training epoch over ``dataloader``, printing the running loss.

    Each batch must unpack as ``(X, (y1, y2, y3))``; only ``y3`` is used as the target.

    Args:
        dataloader: iterable of batches that also exposes ``.dataset`` with ``len()``.
        model: module to train; it is switched to train mode.
        loss_fn: callable ``loss_fn(pred, y3)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: where ``X`` and ``y3`` are moved. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        log_interval: print after every ``log_interval``-th batch. The default 1
            keeps the previous print-every-batch behaviour.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    size = len(dataloader.dataset)
    model.train()
    seen = 0
    for batch, (X, (_, _, y3)) in enumerate(dataloader):
        X, y3 = X.to(device), y3.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y3)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count samples actually processed. `(batch + 1) * len(X)` under-reports
        # when the final batch is smaller than the others.
        seen += len(X)
        if batch % log_interval == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, (y1,y2,y3) in dataloader:
            X, y3 = X.to(device), y3.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y3).item()
            correct += (pred.argmax(2) == y3.argmax(2)).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= (size*210)
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
# print(pred.argmax(2))
# print(sample_label[2].argmax(2))

In [None]:
# Row-wise argmax of a small one-hot matrix: shows that dim=1 picks the hot column per row.
torch.Tensor([[0, 0, 0, 1], [0, 1, 0, 0]]).argmax(dim=1)

In [None]:
# Alternate one training pass and one validation pass per epoch.
epochs = 20
for epoch_number in range(1, epochs + 1):
    print(f"Epoch {epoch_number}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(validation_dataloader, model, loss_fn)
print("Done!")

In [None]:
# Peek at the shape of the third label tensor in the first training batch (if any).
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    batch = 0
    X, (y1, y2, y3) = first_batch
    print(y3.shape)