# In [None]:
import torch
import torch.utils.data as data
from torch.utils.data import Dataset
from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import random
import torchaudio
import os
from torch import nn


class GuitarFxDataset(Dataset):
    """Dataset of guitar-effect audio clips returned as dB-scaled spectrograms.

    Each row of the annotations CSV describes one clip: column 0 holds the
    file name (relative to ``audio_dir``) and column 3 holds the label that is
    returned alongside the spectrogram.

    Args:
        annotations_file: Path of the CSV read with ``pandas.read_csv``.
        audio_dir: Directory the file names in column 0 are relative to.
        transformation: Module applied to the mono, resampled waveform (the
            script passes a ``MelSpectrogram``); it is moved to ``device``.
        target_sample_rate: Sample rate every clip is resampled to.
        device: Device the waveform and the transformation are placed on.
    """

    def __init__(self, annotations_file, audio_dir, transformation, target_sample_rate, device):
        self.annotations = pd.read_csv(annotations_file)
        self.audio_dir = audio_dir
        self.device = device
        self.transformation = transformation.to(self.device)
        self.target_sample_rate = target_sample_rate
        # Number of samples fetched so far (kept for progress reporting).
        self.c = 0
        # Built once here instead of on every __getitem__ call.
        self._amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
        # Resample modules cached per source sample rate: building one
        # precomputes its filter kernel, which is wasted work if done per clip.
        self._resamplers = {}

    def __len__(self):
        """Return the number of annotated clips."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Load clip *index* and return ``(dB spectrogram, label)``."""
        audio_sample_path = self._get_audio_sample_path(index)
        label = self._get_audio_sample_label(index)
        self.c += 1
        # torchaudio.load returns the waveform and its sample rate.
        signal, sr = torchaudio.load(audio_sample_path)
        signal = signal.to(self.device)
        signal = self._resample_if_necessary(signal, sr)  # normalise the sample rate
        signal = self._mix_down_if_necessary(signal)  # collapse to mono if needed
        signal = self.transformation(signal)  # e.g. mel spectrogram
        signal = self._amplitude_to_db(signal)  # convert to decibels
        return signal, label

    def _resample_if_necessary(self, signal, sr):
        """Resample *signal* from *sr* to ``target_sample_rate`` when they differ."""
        if sr == self.target_sample_rate:
            return signal
        resampler = self._resamplers.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate).to(self.device)
            self._resamplers[sr] = resampler
        return resampler(signal)

    def _mix_down_if_necessary(self, signal):
        """Average the channels (dim 0) into one when there is more than one."""
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _get_audio_sample_path(self, index):
        """Return the path of clip *index*: ``audio_dir`` joined with CSV column 0."""
        # iloc is positional pandas indexing: (row, column).
        return os.path.join(self.audio_dir, self.annotations.iloc[index, 0])

    def _get_audio_sample_label(self, index):
        """Return the label of clip *index* (CSV column 3)."""
        return self.annotations.iloc[index, 3]
    
class CNNNetwork(nn.Module):

    def __init__(self, l1_reg=0.0001, l2_reg=0.001):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=3,
                padding = 2,
                stride=1
            ), 
            nn.ReLU(),     #Rectified Linear Unit
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.6)
        )
        self.layer4 = nn.Sequential(
            nn.Linear(70992, 432),#64*81     10368   6912 13056 25024 #?46464 13248 70992
            nn.ReLU()
        )
        self.layer5 = nn.Sequential(
            nn.Linear(432, 36),
            nn.ReLU(),
            nn.Dropout(0.4)
        )
       
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(0.3)
        self.linear = nn.Linear(36, 12) #input size e output size
        self.softmax = nn.Softmax(dim =1)
        
        self.l1_reg = l1_reg
        self.l2_reg = l2_reg
        
    def forward(self, input_data):
        x = self.conv1(input_data)
        x = self.flatten(x)
        x = self.layer4(x)
        x = self.layer5(x)
        
         # Aggiungi la regolarizzazione L1 ai pesi
        l1_loss = 0
        for param in self.parameters():
            l1_loss += torch.norm(param, 1)

        # Aggiungi la regolarizzazione L2 ai pesi
        l2_loss = 0
        for param in self.parameters():
            l2_loss += torch.norm(param, 2)
        
        logits = self.linear(x)
        predictions = self.softmax(logits)
        return predictions, self.l1_reg * l1_loss, self.l2_reg * l2_loss

from torch.utils.data import WeightedRandomSampler, DataLoader
import torch.nn.functional as F
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
import torch.onnx
# Training hyper-parameters.
BATCH_SIZE = 256
EPOCHS = 700
LEARNING_RATE = 1e-4
# Dataset locations: annotations CSV and the directory its file names are relative to.
ANNOTATIONS_FILE = "../input/guitarfxdataset1/annotations1.csv"
AUDIO_DIR = "../input/guitarfxdataset/audio"
# Every clip is resampled to this rate (Hz) before the mel transform.
SAMPLE_RATE = 44_100


def prepareDatasets(dataset, valid_size):
    """Split *dataset* into ``(train, validation)`` with a fixed seed.

    *valid_size* is forwarded as sklearn's ``test_size``; ``random_state=42``
    makes the split reproducible across runs.

    NOTE(review): for a torch ``Dataset`` sklearn appears to index every item,
    so both parts come back as already-loaded ``(signal, label)`` pairs rather
    than lazy views — confirm before relying on lazy loading.
    """
    splits = train_test_split(dataset, test_size=valid_size, random_state=42)
    return splits[0], splits[1]

def create_data_loader(gfxd, batch_size, sampler):
    """Wrap *gfxd* in a ``DataLoader`` whose sample order comes from *sampler*.

    ``shuffle`` stays off: ordering is delegated entirely to *sampler*.
    """
    return DataLoader(gfxd, sampler=sampler, batch_size=batch_size, shuffle=False)
def weighted_accuracy(logits, labels, class_weights):
    """Return the class-weighted fraction of correct arg-max predictions.

    Each sample counts with weight ``class_weights[label]``, so the result is
    ``sum(w_i * correct_i) / sum(w_i)`` as a Python float. Because only the
    arg-max is used, logits or probabilities give the same answer.
    """
    sample_weights = class_weights[labels]
    hits = logits.argmax(dim=1).eq(labels).float()
    return ((hits * sample_weights).sum() / sample_weights.sum()).item()

def compute_batch_weights(y_batch, num_classes):
    """Inverse-frequency class weights for one batch, normalised to sum to *num_classes*.

    Args:
        y_batch: Non-negative integer labels of the batch (anything
            ``np.bincount`` accepts).
        num_classes: Minimum number of classes to produce weights for.

    Returns:
        A float32 tensor of shape ``(1, K)`` where ``K`` is
        ``max(num_classes, max(y_batch) + 1)``. Classes that do not occur in
        the batch get weight 0 instead of ``inf``/``nan``.
    """
    # Class frequencies in the batch.
    class_counts = np.bincount(y_batch, minlength=num_classes).astype(np.float32)
    # Inverse frequency, skipping absent classes to avoid dividing by zero.
    class_weights = np.zeros_like(class_counts)
    present = class_counts > 0
    class_weights[present] = 1.0 / class_counts[present]
    total = class_weights.sum()
    if total > 0:
        # Scale so the weights sum to num_classes (dividing by
        # total * num_classes would make them sum to 1 / num_classes).
        class_weights *= num_classes / total
    return torch.tensor(class_weights, dtype=torch.float32).unsqueeze(0)

def _run_pass(model, loader, criterion, device, optimiser=None):
    """Run *model* once over every batch of *loader*.

    With an *optimiser* the model is put in training mode and updated after
    every batch using ``criterion + l1 + l2`` (the penalties returned by the
    model). Without one it runs in eval mode with gradients disabled and the
    penalties are ignored.

    Returns:
        ``(mean_batch_loss, predictions, targets)``, where predictions and
        targets are all batches concatenated on the CPU (empty tensors when
        *loader* yields nothing).
    """
    training = optimiser is not None
    model.train(training)
    total_loss = 0.0
    batch_predictions, batch_targets = [], []
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions, l1_loss, l2_loss = model(inputs)
            # The model returns softmax probabilities, but CrossEntropyLoss
            # expects raw logits. Passing log-probabilities yields the true
            # cross-entropy because log_softmax(log p) == log p when p sums
            # to 1; the clamp keeps log() finite.
            batch_loss = criterion(torch.log(predictions.clamp_min(1e-12)), targets)
            if training:
                batch_loss = batch_loss + l1_loss + l2_loss
                optimiser.zero_grad()
                batch_loss.backward()
                optimiser.step()
            total_loss += batch_loss.item()
            batch_predictions.append(predictions.detach().cpu())
            batch_targets.append(targets.detach().cpu())
    if not batch_targets:
        return 0.0, torch.empty(0, 0), torch.empty(0, dtype=torch.long)
    return total_loss / len(batch_targets), torch.cat(batch_predictions), torch.cat(batch_targets)


def _epoch_metrics(predictions, targets, class_weights):
    """Return ``(weighted F1, class-weighted accuracy)`` over a whole epoch."""
    if targets.numel() == 0:
        return 0.0, 0.0
    predicted_labels = torch.argmax(predictions, dim=1)
    epoch_f1 = f1_score(targets.numpy(), predicted_labels.numpy(), average='weighted')
    epoch_acc = weighted_accuracy(predictions, targets, class_weights.detach().cpu())
    return epoch_f1, epoch_acc


def train_one_epoch(model, train_data_loader, validation_loader, optimiser, class_weights, val_class_weights, device):
    """Train *model* for one epoch, then evaluate it on the validation loader.

    The class-weighted cross-entropy already balances the classes, so no
    extra per-sample weighting is applied on top of it. Loss is the mean
    per-batch loss (training loss includes the model's L1/L2 penalties); F1
    (sklearn ``average='weighted'``) and class-weighted accuracy are computed
    over every sample of the epoch, not just the last batch.

    Side effects: appends to the module-level lists ``loss_history``,
    ``f1_history``, ``correct_history``, ``val_loss_history``,
    ``val_f1_history`` and ``val_correct_history`` (created in the
    ``__main__`` block) and prints a summary line for each split.
    """
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    # Validation is scored with the validation-set class weights.
    val_criterion = nn.CrossEntropyLoss(weight=val_class_weights)

    epoch_loss, train_preds, train_targets = _run_pass(
        model, train_data_loader, criterion, device, optimiser)
    epoch_f1, epoch_acc = _epoch_metrics(train_preds, train_targets, class_weights)

    val_epoch_loss, val_preds, val_targets = _run_pass(
        model, validation_loader, val_criterion, device)
    val_epoch_f1, val_epoch_acc = _epoch_metrics(val_preds, val_targets, val_class_weights)

    loss_history.append(epoch_loss)
    f1_history.append(epoch_f1)
    correct_history.append(epoch_acc)
    val_loss_history.append(val_epoch_loss)
    val_f1_history.append(val_epoch_f1)
    val_correct_history.append(val_epoch_acc)

    print('Training Loss:{:.3f}, training f1 score:{:.3f}, training accuracy:{:.3f}%'.format(epoch_loss, epoch_f1, epoch_acc * 100))
    print('validation_loss:{:.3f}, validation f1 score:{:.3f}, validation accuracy:{:.3f}%'.format(val_epoch_loss, val_epoch_f1, val_epoch_acc * 100))
            
def train(model, train_data_loader, validation_loader, optimiser, class_weights, val_class_weights, device, epochs):
    """Call ``train_one_epoch`` *epochs* times, printing an epoch header and a separator."""
    for epoch_number in range(1, epochs + 1):
        print(f"Epoch {epoch_number}")
        train_one_epoch(model, train_data_loader, validation_loader, optimiser,
                        class_weights, val_class_weights, device)
        print("--------------")
    print("Training is done.")

if __name__ == "__main__":
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device {device}")
    
#instantiate Dataset and create dataLoader
    mel_spectograms = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=1024,
        hop_length=512,
        n_mels=100,
    )

    gfxd = GuitarFxDataset(ANNOTATIONS_FILE,
                            AUDIO_DIR,
                            mel_spectograms,
                            SAMPLE_RATE,
                            device)
    
    #Visualizzo uno o più spettogrammi
    counter =0
    while counter <10:
        random_index = random.randint(0, len(gfxd) - 1)
        sample_item = gfxd[random_index]

        # Estrai lo spettrogramma di Mel dall'elemento
        mel_spectrogram = sample_item[0].cpu().numpy().squeeze()
        print(mel_spectrogram.shape)
        lbl = sample_item[1]

        # Visualizza lo spettrogramma di Mel
        print(f"Il seguente grafico si riferisce a uno spettogramma di label {lbl}")
        plt.figure(figsize=(10, 4))
        plt.imshow(mel_spectrogram, cmap='viridis', origin='lower')
        plt.title('Spettrogramma di Mel')
        plt.xlabel('Tempo')
        plt.ylabel('Frequenza')
        plt.colorbar(format='%+2.0f dB')
        plt.show()
        counter += 1
    
    train_dataset, validation_dataset = prepareDatasets(gfxd, 0.33) # 20% of samples for validation
    
    
    label_list = []
    val_label_list = []
    for data in train_dataset:
        label_list.append(data[1])
    for data in validation_dataset:
        val_label_list.append(data[1])
    #class_weights=class_weight.compute_class_weight(class_weight = "balanced", classes= np.unique(label_list), y= label_list)
    #class_weights=torch.tensor(class_weights,dtype=torch.float)
    class_weights = class_weight.compute_class_weight('balanced', classes=np.unique(label_list),y=label_list)
    class_weights =torch.tensor(class_weights, dtype=torch.float32).to(device)
    val_class_weights = class_weight.compute_class_weight('balanced', classes=np.unique(val_label_list),y=val_label_list)
    val_class_weights = torch.tensor(val_class_weights, dtype=torch.float).to(device)
    
    weights = [class_weights[label] for label in label_list]
    sampler = WeightedRandomSampler(weights, len(weights))
    
    val_weights = [val_class_weights[label] for label in val_label_list]
    val_sampler = WeightedRandomSampler(val_weights, len(val_weights))
    
    train_dataloader = create_data_loader(train_dataset, BATCH_SIZE, sampler)
    valid_dataloader = create_data_loader(validation_dataset, BATCH_SIZE, val_sampler)
    
    #construct model and assign it to device
    cnn = CNNNetwork().to(device)
    print(cnn)

    #initialize loss function + optimiser
    #loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(class_weights,dtype=torch.float).to(device),reduction='mean')
    optimiser = torch.optim.Adam(cnn.parameters(),lr=LEARNING_RATE)

    loss_history = []
    f1_history =[]
    correct_history = []
    val_loss_history = []
    val_f1_history = []
    val_correct_history = []

    #train model
    train(cnn, train_dataloader, valid_dataloader, optimiser,class_weights,val_class_weights, device, EPOCHS)
    #plot the results
    plt.plot(loss_history, label='Training Loss')
    plt.plot(val_loss_history, label='Validation Loss')
    plt.legend()
    plt.show()
    plt.plot(correct_history, label='Training accuracy')
    plt.plot(val_correct_history, label='Validation accuracy')
    plt.legend()
    plt.show()
    plt.plot(f1_history, label='Training F1 score')
    plt.plot(val_f1_history, label='Validation F1 score')
    plt.legend()
    plt.show()

    torch.save(cnn.state_dict(), "./feedforwardnet5_3.pth")
    
    from torch.utils.mobile_optimizer import optimize_for_mobile
    cnn = cnn.to("cpu")
   
    cnn.eval() 

    # Generate some random noise
    #1, 100, 173
    dummy_input = torch.randn(1, 1 , 100, 173, requires_grad=True).to("cpu")

    # Generate the optimized model
    traced_script_module = torch.jit.trace(cnn, dummy_input)
    traced_script_module_optimized = optimize_for_mobile(traced_script_module)

    # Save the optimzied model
    traced_script_module_optimized._save_for_lite_interpreter("./model.pt")
    
    