In [1]:
import torch
import torchvision
import torch.nn as nn
import time
import json
import datetime
import matplotlib.pyplot as plt

In [2]:
%cd
from DeepLearning.Project2.data_loading_preparation import *

/home/kacper


In [3]:
def get_generic_classifier(input_size, hidden_size, output_size):
    classifier = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_size, output_size),
    )
    return classifier
    
class LSTMSimple(nn.Module):
    def __init__(
        self,
        input_size, 
        hidden_size,
        num_layers,
        num_classes = 2,
        avgpool_dim = 32,
        classifier_size = 512,
        add_dropout=True
    ):       
        super().__init__()
        self.normalization = nn.BatchNorm1d(input_size)
        self.features = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.avgpool = nn.AdaptiveAvgPool1d(avgpool_dim)
        self.classifier = get_generic_classifier(
            hidden_size * avgpool_dim,
            classifier_size,
            num_classes
        )

    def forward(self, x):
        x = self.normalization(x)
        x = x.mT
        x, _ = self.features(x)
        x = x.mT
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [4]:
def eval_accuracy_detect(model, dataloader, reccurrent, training_device='cuda'):
    with torch.no_grad():
        model.to(training_device)
        correct = 0
        all_so_far = 0
        for inputs, labels in dataloader:
            inputs, labels = inputs.float().to(training_device), labels.float().to(training_device)
            labels =  (labels == 11).long()
            yhat = model(inputs)
            pred = torch.argmax(yhat, dim=1)

            
            all_so_far += labels.size().numel()
            correct += torch.sum(pred.eq(labels)).item()
    return correct/all_so_far

In [5]:
def backup_to_ram(model):
    from copy import deepcopy
    return deepcopy(model).cpu()

class EarlyStopper:
    def __init__(self, patience = 3, backup_method=backup_to_ram):
        self.patience = patience
        self.current = 0
        
        self.backup_method = backup_method
        
        self.best_backup = None
        self.best_accuracy = 0.

    def should_continue(self, accuracy, model = None):
        if self.best_accuracy < accuracy:
            self.current = 0
            self.best_accuracy = accuracy
            if model is not None:
                self.best_backup = self.backup_method(model)
            return True
        
        self.current += 1
        
        if self.current >= self.patience:
            return False
        return True

In [6]:
def run_experiment(experiment_name, train_func, run, train_params=None):
    path = f"experiments_rnn/{experiment_name}_run_{run}_"
    print("Running experiment for ", path[:-1])
    
    import os
    try:
        if os.stat(path + "report.json").st_size != 0:
            print("Report exists already for " + path[:-1] + ". Skipping...")
            return
    except OSError:
        pass
    
    model, trajectory, validation_accuracy = train_func(train_params)
    
    with open(path + "report.json", "w") as f:
        json.dump(
            {
                "name": experiment_name,
                "train_params": train_params,
                "run": run,
                "best_accuracy_validation": validation_accuracy,
                "time_generated": datetime.datetime.now().isoformat(),
                "trajectory": trajectory
            },
            f
        )
    torch.save(model, path + "model.pt")

In [7]:
training_device = "cuda"
device = "cuda"
max_epochs = 250

def noaugLSTMdetector(params):  
    train, test, val = load_audio_dataloaders_validation(bs=128, limit_11=1.0)
    model = LSTMSimple(**params).to(training_device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=[0.9, 0.999], eps=10e-8)
    early_stopper = EarlyStopper(patience = 20)
    trajectory = []
    model.to(training_device)

    for epoch in range(1, max_epochs+1):
        model.train()
        for x, y in train:
            optimizer.zero_grad()
            x, y = x.float().to(device), y.to(device)
            y =  (y == 11).long()
            (yhat) = model(x)
            loss = criterion(yhat.softmax(1), y)
            loss.backward()
            optimizer.step()
        
        validation_accuracy = eval_accuracy_detect(model, val, training_device)
        
        print("Epoch: {}, Accuracy on validation set: {}".format(epoch, validation_accuracy))
        
        trajectory.append({
            "epoch": epoch,
            "validation": validation_accuracy,
        })
        
        if not early_stopper.should_continue(validation_accuracy, model):
            print("Early stop")
            model = early_stopper.best_backup
            model = model.to(device)
            break

    return model, trajectory, validation_accuracy


def freqmaskLSTMdetector(params):  
    train, test, val = load_audio_dataloaders_freqmask(bs=128, limit_11=1.0)
    model = LSTMSimple(**params).to(training_device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=[0.9, 0.999], eps=10e-8)
    early_stopper = EarlyStopper(patience = 20)
    trajectory = []
    model.to(training_device)

    for epoch in range(1, max_epochs+1):
        model.train()
        for x, y in train:
            optimizer.zero_grad()
            x, y = x.float().to(device), y.to(device)
            y =  (y == 11).long()
            (yhat) = model(x)
            loss = criterion(yhat.softmax(1), y)
            loss.backward()
            optimizer.step()
        
        validation_accuracy = eval_accuracy_detect(model, val, training_device)
        
        print("Epoch: {}, Accuracy on validation set: {}".format(epoch, validation_accuracy))
        
        trajectory.append({
            "epoch": epoch,
            "validation": validation_accuracy,
        })
        
        if not early_stopper.should_continue(validation_accuracy, model):
            print("Early stop")
            model = early_stopper.best_backup
            model = model.to(device)
            break

    return model, trajectory, validation_accuracy



In [8]:
param_grid = {
    "smallsingle": {
        "hidden_size": 64,
        "input_size": 20,
        "num_layers": 1,
    },
    "mediumsingle": {
        "hidden_size": 128,
        "input_size": 20,
        "num_layers": 1,
    },
}


experiment_list = [
    (
        f"trainer_{trainer.__name__}_params_{param_name}", 
        trainer,
        str(run),
        param
    )
    for run in range(1, 11) 
    for trainer in [noaugLSTMdetector, freqmaskLSTMdetector]
    for param_name, param in param_grid.items()
]

In [9]:
len(experiment_list)

40

In [10]:
for experiment in experiment_list:
    print(
        "Time:", datetime.datetime.now().isoformat(),
        "Experiment:", experiment[0]
    )
    try:
        run_experiment(*experiment)
    except Exception as e:
        print("Error occured, skipping...\n", repr(e))

Time: 2023-04-22T10:27:01.379644 Experiment: trainer_noaugLSTMdetector_params_smallsingle
Running experiment for  experiments_rnn/trainer_noaugLSTMdetector_params_smallsingle_run_1
Report exists already for experiments_rnn/trainer_noaugLSTMdetector_params_smallsingle_run_1. Skipping...
Time: 2023-04-22T10:27:01.379956 Experiment: trainer_noaugLSTMdetector_params_mediumsingle
Running experiment for  experiments_rnn/trainer_noaugLSTMdetector_params_mediumsingle_run_1
Report exists already for experiments_rnn/trainer_noaugLSTMdetector_params_mediumsingle_run_1. Skipping...
Time: 2023-04-22T10:27:01.379999 Experiment: trainer_freqmaskLSTMdetector_params_smallsingle
Running experiment for  experiments_rnn/trainer_freqmaskLSTMdetector_params_smallsingle_run_1
Report exists already for experiments_rnn/trainer_freqmaskLSTMdetector_params_smallsingle_run_1. Skipping...
Time: 2023-04-22T10:27:01.380036 Experiment: trainer_freqmaskLSTMdetector_params_mediumsingle
Running experiment for  experimen

Epoch: 8, Accuracy on validation set: 0.794736068225261
Epoch: 9, Accuracy on validation set: 0.7995882958388473
Epoch: 10, Accuracy on validation set: 0.7822379061902661
Epoch: 11, Accuracy on validation set: 0.802676077047493
Epoch: 12, Accuracy on validation set: 0.8034112630495516
Epoch: 13, Accuracy on validation set: 0.8076753418614909
Epoch: 14, Accuracy on validation set: 0.812086457873842
Epoch: 15, Accuracy on validation set: 0.8034112630495516
Epoch: 16, Accuracy on validation set: 0.8063520070577856
Epoch: 17, Accuracy on validation set: 0.8134097926775474
Epoch: 18, Accuracy on validation set: 0.8228201735038965
Epoch: 19, Accuracy on validation set: 0.8178209086898985
Epoch: 20, Accuracy on validation set: 0.8154683134833113
Epoch: 21, Accuracy on validation set: 0.8214968387001912
Epoch: 22, Accuracy on validation set: 0.8228201735038965
Epoch: 23, Accuracy on validation set: 0.8162034994853697
Epoch: 24, Accuracy on validation set: 0.8248786943096603
Epoch: 25, Accuracy

Epoch: 31, Accuracy on validation set: 0.8436994559623585
Epoch: 32, Accuracy on validation set: 0.8372298191442435
Epoch: 33, Accuracy on validation set: 0.8411998235553595
Epoch: 34, Accuracy on validation set: 0.8420820467578297
Epoch: 35, Accuracy on validation set: 0.837523893545067
Epoch: 36, Accuracy on validation set: 0.8350242611380679
Epoch: 37, Accuracy on validation set: 0.8453168651668872
Epoch: 38, Accuracy on validation set: 0.8476694603734745
Epoch: 39, Accuracy on validation set: 0.8312012939273636
Epoch: 40, Accuracy on validation set: 0.8397294515512425
Epoch: 41, Accuracy on validation set: 0.8411998235553595
Epoch: 42, Accuracy on validation set: 0.8382590795471254
Early stop
Time: 2023-04-22T11:07:11.256528 Experiment: trainer_noaugLSTMdetector_params_mediumsingle
Running experiment for  experiments_rnn/trainer_noaugLSTMdetector_params_mediumsingle_run_6
Epoch: 1, Accuracy on validation set: 0.7326863696515218
Epoch: 2, Accuracy on validation set: 0.77723864137626

Epoch: 9, Accuracy on validation set: 0.8216438759006028
Epoch: 10, Accuracy on validation set: 0.8156153506837229
Epoch: 11, Accuracy on validation set: 0.8276724011174827
Epoch: 12, Accuracy on validation set: 0.8112042346713718
Epoch: 13, Accuracy on validation set: 0.8238494339067784
Epoch: 14, Accuracy on validation set: 0.8348772239376562
Epoch: 15, Accuracy on validation set: 0.827525363917071
Epoch: 16, Accuracy on validation set: 0.8170857226878401
Epoch: 17, Accuracy on validation set: 0.8373768563446552
Epoch: 18, Accuracy on validation set: 0.8426701955594765
Epoch: 19, Accuracy on validation set: 0.8442876047640053
Epoch: 20, Accuracy on validation set: 0.845905013968534
Epoch: 21, Accuracy on validation set: 0.8411998235553595
Epoch: 22, Accuracy on validation set: 0.8516394647845905
Epoch: 23, Accuracy on validation set: 0.8472283487722394
Epoch: 24, Accuracy on validation set: 0.837965005146302
Epoch: 25, Accuracy on validation set: 0.8451698279664756
Epoch: 26, Accurac

Epoch: 1, Accuracy on validation set: 0.7023967063667108
Epoch: 2, Accuracy on validation set: 0.7428319364799294
Epoch: 3, Accuracy on validation set: 0.7578297309219233
Epoch: 4, Accuracy on validation set: 0.7678282605499192
Epoch: 5, Accuracy on validation set: 0.7767975297750331
Epoch: 6, Accuracy on validation set: 0.7809145713865608
Epoch: 7, Accuracy on validation set: 0.7920893986178503
Epoch: 8, Accuracy on validation set: 0.794736068225261
Epoch: 9, Accuracy on validation set: 0.794736068225261
Epoch: 10, Accuracy on validation set: 0.8017938538450228
Epoch: 11, Accuracy on validation set: 0.7984119982355536
Epoch: 12, Accuracy on validation set: 0.8020879282458462
Epoch: 13, Accuracy on validation set: 0.8107631230701368
Epoch: 14, Accuracy on validation set: 0.8069401558594325
Epoch: 15, Accuracy on validation set: 0.8173797970886635
Epoch: 16, Accuracy on validation set: 0.8142920158800176
Epoch: 17, Accuracy on validation set: 0.8172327598882517
Epoch: 18, Accuracy on va

Epoch: 25, Accuracy on validation set: 0.8481105719747096
Epoch: 26, Accuracy on validation set: 0.8419350095574181
Epoch: 27, Accuracy on validation set: 0.8435524187619468
Epoch: 28, Accuracy on validation set: 0.8431113071607117
Epoch: 29, Accuracy on validation set: 0.8457579767681224
Epoch: 30, Accuracy on validation set: 0.8428172327598883
Epoch: 31, Accuracy on validation set: 0.8309072195265402
Epoch: 32, Accuracy on validation set: 0.8411998235553595
Epoch: 33, Accuracy on validation set: 0.8406116747537127
Epoch: 34, Accuracy on validation set: 0.839435377150419
Epoch: 35, Accuracy on validation set: 0.8411998235553595
Epoch: 36, Accuracy on validation set: 0.8401705631524776
Epoch: 37, Accuracy on validation set: 0.8417879723570063
Epoch: 38, Accuracy on validation set: 0.8435524187619468
Epoch: 39, Accuracy on validation set: 0.8488457579767681
Epoch: 40, Accuracy on validation set: 0.8409057491545361
Epoch: 41, Accuracy on validation set: 0.8472283487722394
Epoch: 42, Accu

Epoch: 1, Accuracy on validation set: 0.717247463608293
Epoch: 2, Accuracy on validation set: 0.7240111748272313
Epoch: 3, Accuracy on validation set: 0.7579767681223349
Epoch: 4, Accuracy on validation set: 0.7854727245993236
Epoch: 5, Accuracy on validation set: 0.7948831054256726
Epoch: 6, Accuracy on validation set: 0.8128216438759006
Epoch: 7, Accuracy on validation set: 0.8167916482870167
Epoch: 8, Accuracy on validation set: 0.8078223790619027
Epoch: 9, Accuracy on validation set: 0.8179679458903103
Epoch: 10, Accuracy on validation set: 0.8319364799294221
Epoch: 11, Accuracy on validation set: 0.8295838847228348
Epoch: 12, Accuracy on validation set: 0.8303190707248934
Epoch: 13, Accuracy on validation set: 0.845463902367299
Epoch: 14, Accuracy on validation set: 0.848404646375533
Epoch: 15, Accuracy on validation set: 0.8434053815615351
Epoch: 16, Accuracy on validation set: 0.838994265549184
Epoch: 17, Accuracy on validation set: 0.853845022790766
Epoch: 18, Accuracy on valid

Epoch: 10, Accuracy on validation set: 0.8081164534627261
Epoch: 11, Accuracy on validation set: 0.810616085869725
Epoch: 12, Accuracy on validation set: 0.8129686810763123
Epoch: 13, Accuracy on validation set: 0.8123805322746654
Epoch: 14, Accuracy on validation set: 0.8166446110866049
Epoch: 15, Accuracy on validation set: 0.8231142479047199
Epoch: 16, Accuracy on validation set: 0.8247316571092487
Epoch: 17, Accuracy on validation set: 0.8229672107043082
Epoch: 18, Accuracy on validation set: 0.8216438759006028
Epoch: 19, Accuracy on validation set: 0.8164975738861933
Epoch: 20, Accuracy on validation set: 0.8188501690927805
Epoch: 21, Accuracy on validation set: 0.8197323922952507
Epoch: 22, Accuracy on validation set: 0.8225260991030731
Epoch: 23, Accuracy on validation set: 0.8288486987207764
Epoch: 24, Accuracy on validation set: 0.8259079547125423
Epoch: 25, Accuracy on validation set: 0.8294368475224232
Epoch: 26, Accuracy on validation set: 0.8216438759006028
Epoch: 27, Accu

Epoch: 5, Accuracy on validation set: 0.810616085869725
Epoch: 6, Accuracy on validation set: 0.8156153506837229
Epoch: 7, Accuracy on validation set: 0.8266431407146008
Epoch: 8, Accuracy on validation set: 0.8164975738861933
Epoch: 9, Accuracy on validation set: 0.8232612851051316
Epoch: 10, Accuracy on validation set: 0.8248786943096603
Epoch: 11, Accuracy on validation set: 0.8417879723570063
Epoch: 12, Accuracy on validation set: 0.835465372739303
Epoch: 13, Accuracy on validation set: 0.8419350095574181
Epoch: 14, Accuracy on validation set: 0.8447287163652404
Epoch: 15, Accuracy on validation set: 0.8470813115718276
Epoch: 16, Accuracy on validation set: 0.8520805763858256
Epoch: 17, Accuracy on validation set: 0.8557565063961182
Epoch: 18, Accuracy on validation set: 0.8438464931627702
Epoch: 19, Accuracy on validation set: 0.8489927951771798
Epoch: 20, Accuracy on validation set: 0.8495809439788267
Epoch: 21, Accuracy on validation set: 0.8451698279664756
Epoch: 22, Accuracy o

Epoch: 18, Accuracy on validation set: 0.8334068519335391
Epoch: 19, Accuracy on validation set: 0.8335538891339509
Epoch: 20, Accuracy on validation set: 0.8441405675635936
Epoch: 21, Accuracy on validation set: 0.8362005587413616
Epoch: 22, Accuracy on validation set: 0.8450227907660638
Epoch: 23, Accuracy on validation set: 0.8445816791648287
Epoch: 24, Accuracy on validation set: 0.836494633142185
Epoch: 25, Accuracy on validation set: 0.8250257315100721
Epoch: 26, Accuracy on validation set: 0.8367887075430084
Epoch: 27, Accuracy on validation set: 0.8467872371710042
Epoch: 28, Accuracy on validation set: 0.8411998235553595
Epoch: 29, Accuracy on validation set: 0.8334068519335391
Epoch: 30, Accuracy on validation set: 0.8351712983384796
Epoch: 31, Accuracy on validation set: 0.835465372739303
Epoch: 32, Accuracy on validation set: 0.8347301867372445
Epoch: 33, Accuracy on validation set: 0.8188501690927805
Epoch: 34, Accuracy on validation set: 0.8278194383178944
Epoch: 35, Accur