In [1]:
# from numba import cuda
# cuda.select_device(2)
# # cuda.close()
# import torch 
# torch.cuda.is_available()
# torch.cuda.device_count()
# import torch
# torch.cuda.empty_cache()

In [2]:
%load_ext autoreload
%autoreload 2
%autosave 60

Autosaving every 60 seconds


In [3]:
device = "cuda:2"

In [4]:
splits = [10,20,30]
num_classes = len(splits)+1

In [5]:
from dataset import YouTubeDataset
dataset = YouTubeDataset(splits)

  from .autonotebook import tqdm as notebook_tqdm


0 - T001: 0.00012040138244628906 seconds.
1 - T003: 0.0007658004760742188 seconds.
2 - T004: 8.702278137207031e-05 seconds.
3 - T006: 8.320808410644531e-05 seconds.
4 - T007: 8.463859558105469e-05 seconds.
5 - T009: 8.296966552734375e-05 seconds.
6 - T011: 7.867813110351562e-05 seconds.
7 - T013: 0.00010275840759277344 seconds.
8 - T015: 6.580352783203125e-05 seconds.
9 - T016: 6.508827209472656e-05 seconds.
10 - T017: 7.700920104980469e-05 seconds.
11 - T018: 7.534027099609375e-05 seconds.
12 - T020: 6.29425048828125e-05 seconds.
13 - T021: 6.365776062011719e-05 seconds.
14 - T022: 6.699562072753906e-05 seconds.
15 - T023: 7.677078247070312e-05 seconds.
16 - T024: 6.985664367675781e-05 seconds.
17 - T027: 6.937980651855469e-05 seconds.
18 - T028: 6.198883056640625e-05 seconds.
19 - T031: 6.270408630371094e-05 seconds.
20 - T033: 6.175041198730469e-05 seconds.
21 - T034: 7.605552673339844e-05 seconds.
22 - T035: 0.00010013580322265625 seconds.
23 - T036: 9.1552734375e-05 seconds.
24 - 

In [6]:
print(dataset.audio.shape)

torch.Size([211, 131072])


In [7]:
from models.single_modality_classifiers import Wav2VecClassifier

model = Wav2VecClassifier(num_classes)

In [8]:
import torch
import math
import os
import time
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from inference import eval_audio, get_scores
from torch.nn.functional import cross_entropy

def train_model(model, dataset, learning_rate, lr_decay, weight_decay, batch_size, num_epochs, device, isCheckpoint=False, train_val_split = None, isVerbose=True):
    loss_history = []

    model.to(device)
    dataset.audio = dataset.audio.to(device)
    model.train()

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()), learning_rate, weight_decay=weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda epoch: lr_decay ** epoch
    )

    # sample minibatch data
    if not train_val_split:
      train_ids = [i for i in range(len(dataset))]
      val_ids = None
    else:
      train_ids, val_ids = train_val_split

    iter_per_epoch = math.ceil(len(train_ids) // batch_size)
    class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=np.arange(model.num_classes), y=dataset.label[train_ids].numpy()), dtype=torch.float, device=device)
    loss_fn = torch.nn.NLLLoss(weight = class_weights)
    # loss_fn = cross_entropy
    
    for i in range(num_epochs):
        start_t = time.time()
        local_hist = []
        correct_cnt = 0
        y_preds = torch.empty((0,),device=device)
        y_trues = torch.empty((0,),device=device)
        for j in range(iter_per_epoch):
            _, waveforms, y_true = dataset[train_ids[j * batch_size: (j + 1) * batch_size]]

            waveforms = waveforms.to(device)
            y_true = y_true.to(device)

            optimizer.zero_grad()

            digits = model(waveforms)
            y_preds = torch.hstack([y_preds,digits.argmax(dim=1)])
            y_trues = torch.hstack([y_trues,y_true])

            probs = torch.nn.LogSoftmax(dim=1)(digits)
            loss = loss_fn(probs,y_true)
            loss.backward()

            local_hist.append(loss.item())
            optimizer.step()

        end_t = time.time()

        loss_mean = np.array(local_hist).mean()
        loss_history.append(loss_mean)
            
        print(
            f"(Epoch {i}), time: {end_t - start_t:.1f}s, loss: {loss_mean:.3f}"
        )
        if isVerbose:
            train_accuracy, train_precision, train_recall, train_f1 = get_scores(y_trues.to('cpu'), y_preds.to('cpu'), model.num_classes) # This is an aggregated result due to GPU size limit
            print(f"    Training Set - accuracy: {train_accuracy:.2f}, precision: {train_precision:.2f}, recall: {train_recall:.2f}, f1-score: {train_f1:.2f},")
            if val_ids is not None:
                val_accuracy, val_precision, val_recall, val_f1 = eval_audio(model, dataset, val_ids, num_classes, device, is_verbose = (loss_mean < 0.5))
                print(f"    Validation Set - accuracy: {val_accuracy:.2f}, precision: {val_precision:.2f}, recall: {val_recall:.2f}, f1-score: {val_f1:.2f},")
        if i%200 == 0 and isCheckpoint:
          dir = "checkpoints"
          if not os.path.exists(dir):
            os.mkdir(dir)
          file = f"epoch{i}.pt"
          path = dir+'/'+file
          torch.save({
                      'epoch': i,
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'loss': loss_mean,
                      }, path)

        lr_scheduler.step()

        if loss_mean < 0.5:
          break
    
    return loss_history

In [9]:
from sklearn.model_selection import KFold
import torch

def train_model_cv5(model, dataset):
    loss_hist = []
    kf = KFold(n_splits=5)
    cnt = 1
    for train_index, val_index in kf.split(dataset):
        model.reset()
        model.base.requires_grad = False
        print("Fold "+str(cnt)+" (val", val_index[0],"-",str(val_index[-1])+")")
        loss_hist_fold = train_model(model, device = device, dataset=dataset, train_val_split=(train_index, val_index),learning_rate=1e-5, lr_decay=0.99, weight_decay=1e-4, batch_size=10, num_epochs=300, isCheckpoint = False, isVerbose = True)
        loss_hist.append(loss_hist_fold)
        cnt += 1
    return loss_hist

In [10]:
lost_hist_folds = train_model_cv5(model, dataset)

Fold 1 (val 0 - 42)
(Epoch 0), time: 9.0s, loss: 1.406
    Training Set - accuracy: 0.47, precision: 0.12, recall: 0.25, f1-score: 0.16,
    Validation Set - accuracy: 0.82, precision: 0.21, recall: 0.25, f1-score: 0.23,
(Epoch 1), time: 8.8s, loss: 1.397
    Training Set - accuracy: 0.47, precision: 0.12, recall: 0.25, f1-score: 0.16,
    Validation Set - accuracy: 0.82, precision: 0.21, recall: 0.25, f1-score: 0.23,
(Epoch 2), time: 8.8s, loss: 1.391
    Training Set - accuracy: 0.47, precision: 0.12, recall: 0.25, f1-score: 0.16,
    Validation Set - accuracy: 0.82, precision: 0.21, recall: 0.25, f1-score: 0.23,
(Epoch 3), time: 8.9s, loss: 1.382
    Training Set - accuracy: 0.46, precision: 0.12, recall: 0.25, f1-score: 0.16,
    Validation Set - accuracy: 0.75, precision: 0.20, recall: 0.23, f1-score: 0.21,
(Epoch 4), time: 8.9s, loss: 1.372
    Training Set - accuracy: 0.46, precision: 0.12, recall: 0.25, f1-score: 0.16,
    Validation Set - accuracy: 0.72, precision: 0.20, recal

In [11]:
splits = [10,20]
num_classes = len(splits)+1

dataset = YouTubeDataset(splits)
model = Wav2VecClassifier(num_classes)
train_model_cv5(model, dataset)

0 - T001: 0.00019621849060058594 seconds.
1 - T003: 0.0009250640869140625 seconds.
2 - T004: 0.00012254714965820312 seconds.
3 - T006: 9.393692016601562e-05 seconds.
4 - T007: 9.489059448242188e-05 seconds.
5 - T009: 9.036064147949219e-05 seconds.
6 - T011: 8.821487426757812e-05 seconds.
7 - T013: 8.153915405273438e-05 seconds.
8 - T015: 9.34600830078125e-05 seconds.
9 - T016: 8.535385131835938e-05 seconds.
10 - T017: 0.00010347366333007812 seconds.
11 - T018: 9.918212890625e-05 seconds.
12 - T020: 8.678436279296875e-05 seconds.
13 - T021: 8.559226989746094e-05 seconds.
14 - T022: 9.322166442871094e-05 seconds.
15 - T023: 0.00010395050048828125 seconds.
16 - T024: 9.560585021972656e-05 seconds.
17 - T027: 9.608268737792969e-05 seconds.
18 - T028: 8.416175842285156e-05 seconds.
19 - T031: 8.630752563476562e-05 seconds.
20 - T033: 8.559226989746094e-05 seconds.
21 - T034: 0.00010538101196289062 seconds.
22 - T035: 8.749961853027344e-05 seconds.
23 - T036: 0.00010228157043457031 seconds.


[[1.1295877173542976,
  1.1273719817399979,
  1.1268511712551117,
  1.124887079000473,
  1.120848223567009,
  1.1153303906321526,
  1.1076990067958832,
  1.1024466305971146,
  1.089106909930706,
  1.0906677693128586,
  1.0913375169038773,
  1.0770012326538563,
  1.0709569789469242,
  1.0649440288543701,
  1.0545285604894161,
  1.034836869686842,
  1.022515494376421,
  0.9983116127550602,
  0.9832006767392159,
  1.0010471381247044,
  1.0138780772686005,
  0.9809206649661064,
  0.9553652666509151,
  0.9312872663140297,
  0.9159387163817883,
  0.8806792236864567,
  0.8545030504465103,
  0.8353528715670109,
  0.8315897472202778,
  0.8039390929043293,
  0.7733923085033894,
  0.7478423789143562,
  0.689947135746479,
  0.6941433772444725,
  0.7089653164148331,
  0.653754087164998,
  0.589550107717514,
  0.5779519621282816,
  0.5418468713760376,
  0.5064200181514025,
  0.49614852480590343],
 [1.1255304217338562,
  1.1167023032903671,
  1.10970588773489,
  1.1021500378847122,
  1.09011860936880

In [12]:
splits = [10]
num_classes = len(splits)+1

dataset = YouTubeDataset(splits)
model = Wav2VecClassifier(num_classes)
train_model_cv5(model, dataset)

0 - T001: 0.00040268898010253906 seconds.
1 - T003: 0.0006885528564453125 seconds.
2 - T004: 0.0002930164337158203 seconds.
3 - T006: 0.00023746490478515625 seconds.
4 - T007: 0.0002677440643310547 seconds.
5 - T009: 0.00022935867309570312 seconds.
6 - T011: 0.0002467632293701172 seconds.
7 - T013: 0.0002090930938720703 seconds.
8 - T015: 0.0002536773681640625 seconds.
9 - T016: 0.00021839141845703125 seconds.
10 - T017: 0.0002696514129638672 seconds.
11 - T018: 0.0002193450927734375 seconds.
12 - T020: 0.0002453327178955078 seconds.
13 - T021: 0.00021576881408691406 seconds.
14 - T022: 0.00025010108947753906 seconds.
15 - T023: 0.00022482872009277344 seconds.
16 - T024: 0.0002605915069580078 seconds.
17 - T027: 0.00022673606872558594 seconds.
18 - T028: 0.00015997886657714844 seconds.
19 - T031: 0.00017881393432617188 seconds.
20 - T033: 0.0001633167266845703 seconds.
21 - T034: 0.0001575946807861328 seconds.
22 - T035: 0.0001678466796875 seconds.
23 - T036: 0.00014328956604003906 sec

[[0.7464047539979219,
  0.7282361555844545,
  0.7211398221552372,
  0.7127474062144756,
  0.6998687721788883,
  0.6914235651493073,
  0.6826597899198532,
  0.6733386367559433,
  0.6579379066824913,
  0.6358466930687428,
  0.6398527063429356,
  0.6865830272436142,
  0.6680218912661076,
  0.6974000334739685,
  0.6706861667335033,
  0.6388220228254795,
  0.6111139599233866,
  0.5651876330375671,
  0.5597941409796476,
  0.5628745798021555,
  0.5896639954298735,
  0.5712015479803085,
  0.5613272916525602,
  0.5308751985430717,
  0.5039190147072077,
  0.48604968190193176],
 [0.698051143437624,
  0.6921644918620586,
  0.6900754831731319,
  0.6851878836750984,
  0.6797115318477154,
  0.6697666682302952,
  0.6665422096848488,
  0.6361907552927732,
  0.6220542211085558,
  0.6034978069365025,
  0.5748561136424541,
  0.5601145159453154,
  0.573477054014802,
  0.5323265343904495,
  0.5793350785970688,
  0.5909880064427853,
  0.5283417478203773,
  0.5145398955792189,
  0.5055317282676697,
  0.500572