In [1]:
# from numba import cuda
# cuda.select_device(2)
# # cuda.close()
# import torch 
# torch.cuda.is_available()
# torch.cuda.device_count()
# import torch
# torch.cuda.empty_cache()

In [2]:
%load_ext autoreload
%autoreload 2
%autosave 60

Autosaving every 60 seconds


In [3]:
device = "cuda"

In [4]:
splits = [10,20,30]
num_classes = len(splits)+1

In [5]:
from dataset import YouTubeDataset
dataset = YouTubeDataset(splits)

  from .autonotebook import tqdm as notebook_tqdm


0 - T001: 0.0001842975616455078 seconds.
1 - T003: 0.0007531642913818359 seconds.
2 - T004: 7.081031799316406e-05 seconds.
3 - T006: 6.914138793945312e-05 seconds.
4 - T007: 7.05718994140625e-05 seconds.
5 - T009: 6.818771362304688e-05 seconds.
6 - T011: 6.365776062011719e-05 seconds.
7 - T013: 7.987022399902344e-05 seconds.
8 - T015: 6.842613220214844e-05 seconds.
9 - T016: 7.462501525878906e-05 seconds.
10 - T017: 8.368492126464844e-05 seconds.
11 - T018: 7.748603820800781e-05 seconds.
12 - T020: 6.794929504394531e-05 seconds.
13 - T021: 6.508827209472656e-05 seconds.
14 - T022: 6.937980651855469e-05 seconds.
15 - T023: 7.62939453125e-05 seconds.
16 - T024: 7.414817810058594e-05 seconds.
17 - T027: 6.937980651855469e-05 seconds.
18 - T028: 6.222724914550781e-05 seconds.
19 - T031: 6.318092346191406e-05 seconds.
20 - T033: 6.580352783203125e-05 seconds.
21 - T034: 7.62939453125e-05 seconds.
22 - T035: 6.365776062011719e-05 seconds.
23 - T036: 7.176399230957031e-05 seconds.
24 - T037: 

In [6]:
print(dataset.audio.shape)

torch.Size([211, 262144])


In [7]:
from models.single_modality_classifiers import Wav2VecClassifier

model = Wav2VecClassifier(num_classes)

In [8]:
dataset.audio = dataset.audio.to(device)

In [9]:
import torch
import math
import os
import time
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from inference import eval_audio, get_scores
from torch.nn.functional import cross_entropy

def train_model(model, dataset, learning_rate, lr_decay, weight_decay, batch_size, num_epochs, device, isCheckpoint=False, train_val_split = None, isVerbose=True):
    loss_history = []
    early_stop_loss = 0.4
    model.train()

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()), learning_rate, weight_decay=weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda epoch: lr_decay ** epoch
    )

    # sample minibatch data
    if not train_val_split:
      train_ids = [i for i in range(len(dataset))]
      val_ids = None
    else:
      train_ids, val_ids = train_val_split

    iter_per_epoch = math.ceil(len(train_ids) // batch_size)
    class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=np.arange(model.module.num_classes), y=dataset.label[train_ids].numpy()), dtype=torch.float, device=device)
    loss_fn = torch.nn.NLLLoss(weight = class_weights)
    # loss_fn = cross_entropy
    
    for i in range(num_epochs):
        start_t = time.time()
        local_hist = []
        correct_cnt = 0
        y_preds = torch.empty((0,),device='cpu')
        y_trues = torch.empty((0,),device=device)
        for j in range(iter_per_epoch):
            _, waveforms, y_true = dataset[train_ids[j * batch_size: (j + 1) * batch_size]]
            optimizer.zero_grad()
            y_true = y_true.to(device)
            digits = model(waveforms)
            y_preds = torch.hstack([y_preds,digits.argmax(dim=1).to('cpu')])
            y_trues = torch.hstack([y_trues,y_true])

            probs = torch.nn.LogSoftmax(dim=1)(digits)
            loss = loss_fn(probs,y_true)
            loss.backward()

            local_hist.append(loss.item())
            optimizer.step()

        end_t = time.time()

        loss_mean = np.array(local_hist).mean()
        loss_history.append(loss_mean)
            
        print(
            f"(Epoch {i}), time: {end_t - start_t:.1f}s, loss: {loss_mean:.3f}"
        )
        if isVerbose:
            train_accuracy, train_precision, train_recall, train_f1 = get_scores(y_trues.to("cpu"), y_preds, model.module.num_classes) # This is an aggregated result due to GPU size limit
            print(f"    Training Set - accuracy: {train_accuracy:.2f}, precision: {train_precision:.2f}, recall: {train_recall:.2f}, f1-score: {train_f1:.2f},")
            if val_ids is not None:
                val_accuracy, val_precision, val_recall, val_f1 = eval_audio(model, dataset, val_ids, num_classes, device, is_verbose = (loss_mean < early_stop_loss))
                print(f"    Validation Set - accuracy: {val_accuracy:.2f}, precision: {val_precision:.2f}, recall: {val_recall:.2f}, f1-score: {val_f1:.2f},")
        if i%200 == 0 and isCheckpoint:
          dir = "checkpoints"
          if not os.path.exists(dir):
            os.mkdir(dir)
          file = f"epoch{i}.pt"
          path = dir+'/'+file
          torch.save({
                      'epoch': i,
                      'model_state_dict': model.module.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'loss': loss_mean,
                      }, path)

        lr_scheduler.step()

        if loss_mean < early_stop_loss:
          break
    
    return loss_history

In [10]:
from sklearn.model_selection import KFold
import torch

def train_model_cv5(model, dataset):
    loss_hist = []
    kf = KFold(n_splits=5)
    cnt = 1
    for train_index, val_index in kf.split(dataset):
        model = Wav2VecClassifier(num_classes)
        model = torch.nn.DataParallel(model)
        model.to(device)
        model.module.base.requires_grad = False
        print("Fold "+str(cnt)+" (val", val_index[0],"-",str(val_index[-1])+")")
        loss_hist_fold = train_model(model, device = device, dataset=dataset, train_val_split=(train_index, val_index),learning_rate=1e-5, lr_decay=0.99, weight_decay=1e-4, batch_size=7, num_epochs=300, isCheckpoint = False, isVerbose = True)
        loss_hist.append(loss_hist_fold)
        cnt += 1
    return loss_hist

In [11]:
lost_hist_folds = train_model_cv5(model, dataset)

Fold 1 (val 0 - 42)
(Epoch 0), time: 19.1s, loss: 1.394
    Training Set - accuracy: 0.34, precision: 0.08, recall: 0.25, f1-score: 0.13,
    Validation Set - accuracy: 0.12, precision: 0.03, recall: 0.25, f1-score: 0.06,
(Epoch 1), time: 40.7s, loss: 1.390
    Training Set - accuracy: 0.34, precision: 0.08, recall: 0.25, f1-score: 0.13,
    Validation Set - accuracy: 0.12, precision: 0.03, recall: 0.25, f1-score: 0.06,
(Epoch 2), time: 9.7s, loss: 1.384
    Training Set - accuracy: 0.33, precision: 0.14, recall: 0.24, f1-score: 0.13,
    Validation Set - accuracy: 0.15, precision: 0.28, recall: 0.26, f1-score: 0.07,
(Epoch 3), time: 9.7s, loss: 1.374
    Training Set - accuracy: 0.43, precision: 0.25, recall: 0.29, f1-score: 0.23,
    Validation Set - accuracy: 0.28, precision: 0.19, recall: 0.17, f1-score: 0.13,
(Epoch 4), time: 9.6s, loss: 1.362
    Training Set - accuracy: 0.46, precision: 0.23, recall: 0.29, f1-score: 0.25,
    Validation Set - accuracy: 0.45, precision: 0.19, rec

In [12]:
splits = [10,20]
num_classes = len(splits)+1

dataset = YouTubeDataset(splits)
model = Wav2VecClassifier(num_classes)
train_model_cv5(model, dataset)

0 - T001: 0.00013518333435058594 seconds.
1 - T003: 0.0006113052368164062 seconds.
2 - T004: 8.702278137207031e-05 seconds.
3 - T006: 8.320808410644531e-05 seconds.
4 - T007: 8.416175842285156e-05 seconds.
5 - T009: 6.628036499023438e-05 seconds.
6 - T011: 6.198883056640625e-05 seconds.
7 - T013: 6.222724914550781e-05 seconds.
8 - T015: 6.651878356933594e-05 seconds.
9 - T016: 7.224082946777344e-05 seconds.
10 - T017: 7.128715515136719e-05 seconds.
11 - T018: 6.699562072753906e-05 seconds.
12 - T020: 5.7697296142578125e-05 seconds.
13 - T021: 5.745887756347656e-05 seconds.
14 - T022: 5.9604644775390625e-05 seconds.
15 - T023: 7.081031799316406e-05 seconds.
16 - T024: 6.508827209472656e-05 seconds.
17 - T027: 6.365776062011719e-05 seconds.
18 - T028: 5.698204040527344e-05 seconds.
19 - T031: 5.7220458984375e-05 seconds.
20 - T033: 5.6743621826171875e-05 seconds.
21 - T034: 6.985664367675781e-05 seconds.
22 - T035: 5.6743621826171875e-05 seconds.
23 - T036: 6.508827209472656e-05 seconds.

[[1.1121671100457509,
  1.1068716545899708,
  1.1036824310819309,
  1.1013781254490216,
  1.09637950360775,
  1.090869478881359,
  1.0792051230867703,
  1.065562590956688,
  1.0514291500051816,
  1.0387335643172264,
  1.0386073291301727,
  1.0039257208506267,
  1.0100058342019718,
  0.9688134888807932,
  0.9283306375145912,
  0.9033289973934492,
  0.8919535999496778,
  0.8557506228486697,
  0.8294085363547007,
  0.7852052648862203,
  0.7916455840071043,
  0.7985212231675783,
  0.7076503510276476,
  0.6777491619189581,
  0.6443247894446055,
  0.6133542507886887,
  0.5889756840964159,
  0.5599232316017151,
  0.543206458290418,
  0.5622219517827034,
  0.5332504796485106,
  0.510855051378409,
  0.4845707733184099,
  0.46152288777132827,
  0.45881975380082923,
  0.45565666755040485,
  0.4410321644196908,
  0.4440429024398327,
  0.3999539551635583],
 [1.1034414966901143,
  1.0980794404943783,
  1.096241702636083,
  1.0935727084676425,
  1.0901789988080661,
  1.0841078758239746,
  1.077645346

In [13]:
splits = [10]
num_classes = len(splits)+1

dataset = YouTubeDataset(splits)
model = Wav2VecClassifier(num_classes)
train_model_cv5(model, dataset)

0 - T001: 0.00016355514526367188 seconds.
1 - T003: 0.0007355213165283203 seconds.
2 - T004: 0.00010752677917480469 seconds.
3 - T006: 6.961822509765625e-05 seconds.
4 - T007: 7.05718994140625e-05 seconds.
5 - T009: 6.985664367675781e-05 seconds.
6 - T011: 6.413459777832031e-05 seconds.
7 - T013: 6.389617919921875e-05 seconds.
8 - T015: 6.914138793945312e-05 seconds.
9 - T016: 6.580352783203125e-05 seconds.
10 - T017: 7.319450378417969e-05 seconds.
11 - T018: 6.628036499023438e-05 seconds.
12 - T020: 6.389617919921875e-05 seconds.
13 - T021: 6.461143493652344e-05 seconds.
14 - T022: 6.937980651855469e-05 seconds.
15 - T023: 6.604194641113281e-05 seconds.
16 - T024: 7.152557373046875e-05 seconds.
17 - T027: 9.107589721679688e-05 seconds.
18 - T028: 7.772445678710938e-05 seconds.
19 - T031: 7.82012939453125e-05 seconds.
20 - T033: 7.677078247070312e-05 seconds.
21 - T034: 8.034706115722656e-05 seconds.
22 - T035: 7.963180541992188e-05 seconds.
23 - T036: 7.62939453125e-05 seconds.
24 - T

[[0.6994764357805252,
  0.6904510433475176,
  0.6898252392808596,
  0.6849656080206236,
  0.6791842579841614,
  0.6685983339945475,
  0.6542384674151739,
  0.6329509342710177,
  0.6054744000236193,
  0.5779221889873346,
  0.5620468606551489,
  0.5109893977642059,
  0.49451879784464836,
  0.4607798469563325,
  0.39171617353955906],
 [0.7053423399726549,
  0.7024213100473086,
  0.7006821483373642,
  0.6980388263861338,
  0.6949471111098925,
  0.69208163022995,
  0.686525970697403,
  0.6730467031399409,
  0.672952781120936,
  0.656952107946078,
  0.6412579466899236,
  0.6403519734740257,
  0.5955919673045477,
  0.6145541605850061,
  0.5954622179269791,
  0.6248841397464275,
  0.6526696135600408,
  0.5877433208127817,
  0.5946950229505698,
  0.5373520590364933,
  0.5204312403996786,
  0.48310720920562744,
  0.4795517648259799,
  0.4969275730351607,
  0.408834982663393,
  0.4054417299727599,
  0.3580077402293682],
 [0.7104900976022085,
  0.7006963963309923,
  0.6980793550610542,
  0.6973808