In [22]:
%reset -f
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%config Completer.use_jedi = False

In [23]:
from tqdm import tqdm
# from tqdm.notebook import tqdm as tqdm
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms
from tsai.all import *

torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend('soundfile')
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import psutil
import requests

from utils import *
from wingbeat_datasets import *
from wingbeat_models import *

print(f'Total RAM      : {bytes2GB(psutil.virtual_memory().total):5.2f} GB')
print(f'Available RAM  : {bytes2GB(psutil.virtual_memory().available):5.2f} GB\n')

Total RAM      : 31.21 GB
Available RAM  : 20.83 GB



In [24]:
num_epochs = 35
batch_size = 32
batch_size_val = batch_size * 2
validation_split = .2
shuffle_dataset = True
num_workers = psutil.cpu_count()
random_seed= 42
setting = 'raw'

In [25]:
dmel1 = WingbeatsDataset(dsname="Melanogaster_RL/Y", custom_label=[0], transform=transforms.Compose([FilterWingbeat(setting='bandpass'), TransformWingbeat(setting=setting)])).clean()
dmel2 = WingbeatsDataset(dsname="Melanogaster_RL/Z", custom_label=[0], transform=transforms.Compose([FilterWingbeat(setting='bandpass'), TransformWingbeat(setting=setting)])).clean()
dsuz1 = WingbeatsDataset(dsname="Suzukii_RL/L",      custom_label=[1], transform=transforms.Compose([FilterWingbeat(setting='bandpass'), TransformWingbeat(setting=setting)])).clean()
dsuz2 = WingbeatsDataset(dsname="Suzukii_RL/R",      custom_label=[1], transform=transforms.Compose([FilterWingbeat(setting='bandpass'), TransformWingbeat(setting=setting)])).clean()

Found 29002 in dataset: Melanogaster_RL/Y, and 1 label(s): ['Y']
Label(s) changed to [0]
Nr. of valid wingbeats: 12819
Found 24763 in dataset: Melanogaster_RL/Z, and 1 label(s): ['Z']
Label(s) changed to [0]
Nr. of valid wingbeats: 11778
Found 21940 in dataset: Suzukii_RL/L, and 1 label(s): ['L']
Label(s) changed to [1]
Nr. of valid wingbeats: 14729
Found 14348 in dataset: Suzukii_RL/R, and 1 label(s): ['R']
Label(s) changed to [1]
Nr. of valid wingbeats: 10372


In [26]:
transformed_dataset = ConcatDataset([dmel1, dsuz1])

train_size = int(0.8 * len(transformed_dataset))
valid_size = len(transformed_dataset) - train_size
train_dataset, valid_dataset = torch.utils.data.random_split(transformed_dataset, [train_size, valid_size])
test_dataset = ConcatDataset([dmel2, dsuz2])

In [27]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size_val, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size_val, num_workers=num_workers)

In [28]:
if setting.startswith('psd'):
    model = Conv1dNetPSD()
else:
    model = Conv1dNetRAW()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
early_stopping = EarlyStopping(patience=7, verbose=False)
# print(model)

In [29]:
# from sklearn import preprocessing
# import itertools

# le = preprocessing.LabelEncoder()
# all_labels = [transformed_dataset.datasets[i].labels for i in range(len(transformed_dataset.datasets))]
# all_labels = list(itertools.chain.from_iterable(all_labels))
# le.fit(all_labels)

## Training

In [30]:
# Choosing whether to train on a gpu
train_on_gpu = torch.cuda.is_available()
print(f'Train on gpu: {train_on_gpu}')# Number of gpus
model = model.to('cuda', dtype=torch.float)

Train on gpu: True


In [31]:
# Model training
for epoch in range(num_epochs):
    # Going through the training set
    correct_train = 0
    model.train()
    for x_batch,y_batch,path_batch,idx_batch in tqdm(train_dataloader, desc='Training..\t'):        

        y_batch = torch.as_tensor(y_batch).type(torch.LongTensor)
        x_batch,y_batch = x_batch.cuda(), y_batch.cuda()
        
        optimizer.zero_grad()
        pred = model(x_batch)
        loss = criterion(pred, y_batch)
        loss.backward()
        optimizer.step()
        correct_train += (pred.argmax(axis=1) == y_batch).float().sum().item()

    train_accuracy = correct_train / (len(train_dataloader)*batch_size) * 100.
    # Going through the validation set
    correct_valid = 0
    model.eval()
    for x_batch,y_batch,path_batch,idx_batch in tqdm(valid_dataloader, desc='Validating..\t'):
        
        y_batch = torch.as_tensor(y_batch).type(torch.LongTensor)
        x_batch,y_batch = x_batch.cuda(), y_batch.cuda()

        pred = model(x_batch)
        val_loss = criterion(pred, y_batch)
        correct_valid += (pred.argmax(axis=1) == y_batch).float().sum().item()
    valid_accuracy = correct_valid / (len(valid_dataloader)*batch_size_val) * 100.
    scheduler.step(val_loss)
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

        # Printing results
    print(f"Epoch {epoch}: train_acc: {train_accuracy:.2f}% loss: {loss:.3f},  val_loss: {val_loss:.3f} val_acc: {valid_accuracy:.2f}%")


Training..	: 100%|██████████| 689/689 [00:07<00:00, 87.02it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 81.13it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 0: train_acc: 84.06% loss: 0.292,  val_loss: 2.679 val_acc: 46.14%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 88.31it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 81.86it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 1: train_acc: 87.40% loss: 0.368,  val_loss: 3.340 val_acc: 46.46%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 86.77it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 76.93it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 2: train_acc: 89.60% loss: 0.307,  val_loss: 6.167 val_acc: 46.35%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 87.08it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 77.41it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 3: train_acc: 91.58% loss: 0.226,  val_loss: 4.117 val_acc: 49.03%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 87.46it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 76.82it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch     5: reducing learning rate of group 0 to 1.0000e-03.
Epoch 4: train_acc: 92.65% loss: 0.152,  val_loss: 28.550 val_acc: 46.03%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 87.16it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 76.35it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 5: train_acc: 95.02% loss: 0.161,  val_loss: 0.218 val_acc: 62.52%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 87.13it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 77.83it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 6: train_acc: 95.51% loss: 0.150,  val_loss: 0.016 val_acc: 66.40%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 86.33it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 76.08it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 7: train_acc: 95.75% loss: 0.163,  val_loss: 1.601 val_acc: 55.87%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 86.26it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 75.94it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 8: train_acc: 96.08% loss: 0.149,  val_loss: 0.107 val_acc: 81.45%


Training..	: 100%|██████████| 689/689 [00:08<00:00, 86.10it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 75.35it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 9: train_acc: 96.21% loss: 0.144,  val_loss: 12.653 val_acc: 50.59%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 86.53it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 76.92it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch    11: reducing learning rate of group 0 to 1.0000e-04.
Epoch 10: train_acc: 96.40% loss: 0.121,  val_loss: 5.303 val_acc: 54.18%


Training..	: 100%|██████████| 689/689 [00:08<00:00, 86.11it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 76.02it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 11: train_acc: 96.84% loss: 0.126,  val_loss: 0.007 val_acc: 94.25%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 86.44it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 77.19it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 12: train_acc: 96.80% loss: 0.128,  val_loss: 0.015 val_acc: 93.48%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 86.34it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 73.48it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 13: train_acc: 96.89% loss: 0.118,  val_loss: 0.005 val_acc: 94.54%


Training..	: 100%|██████████| 689/689 [00:08<00:00, 82.79it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 73.14it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 14: train_acc: 96.96% loss: 0.138,  val_loss: 0.001 val_acc: 93.88%


Training..	: 100%|██████████| 689/689 [00:08<00:00, 84.90it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 76.11it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 15: train_acc: 96.89% loss: 0.127,  val_loss: 0.001 val_acc: 91.72%


Training..	: 100%|██████████| 689/689 [00:08<00:00, 85.54it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 74.93it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 16: train_acc: 96.99% loss: 0.118,  val_loss: 0.004 val_acc: 94.83%


Training..	: 100%|██████████| 689/689 [00:07<00:00, 86.15it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 75.88it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 17: train_acc: 97.04% loss: 0.148,  val_loss: 0.023 val_acc: 93.88%


Training..	: 100%|██████████| 689/689 [00:08<00:00, 85.41it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 74.80it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 18: train_acc: 96.99% loss: 0.104,  val_loss: 0.031 val_acc: 91.63%


Training..	: 100%|██████████| 689/689 [00:08<00:00, 85.34it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 76.29it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch    20: reducing learning rate of group 0 to 1.0000e-05.
Epoch 19: train_acc: 96.97% loss: 0.158,  val_loss: 0.017 val_acc: 92.87%


Training..	: 100%|██████████| 689/689 [00:08<00:00, 82.11it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 74.77it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 20: train_acc: 97.14% loss: 0.121,  val_loss: 0.004 val_acc: 94.77%


Training..	: 100%|██████████| 689/689 [00:08<00:00, 85.87it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 70.94it/s] 
Training..	:   0%|          | 0/689 [00:00<?, ?it/s]

Epoch 21: train_acc: 97.14% loss: 0.125,  val_loss: 0.003 val_acc: 94.77%


Training..	: 100%|██████████| 689/689 [00:08<00:00, 85.56it/s]
Validating..	: 100%|██████████| 87/87 [00:01<00:00, 75.71it/s] 

Early stopping





In [32]:
x_batch

tensor([[[ 1.3861e-04,  5.4624e-04,  4.9648e-04,  ..., -3.7630e-03,
          -5.6457e-03, -7.6273e-03]],

        [[-5.9532e-05, -2.8347e-04, -5.0161e-04,  ...,  4.4629e-04,
          -3.1843e-04,  8.8621e-04]],

        [[-1.1551e-04, -8.1876e-04, -2.4229e-03,  ..., -9.6568e-04,
           7.1459e-04,  2.1520e-03]],

        [[-8.3522e-05, -4.5068e-04, -1.1141e-03,  ...,  2.3201e-03,
           2.1110e-03, -1.6186e-04]],

        [[-3.0210e-05, -7.5936e-05,  1.2054e-05,  ..., -1.2536e-03,
           1.0209e-04, -7.5016e-05]],

        [[-1.3239e-04, -5.9753e-04, -1.0327e-03,  ..., -5.0662e-03,
          -6.0150e-03, -6.2820e-03]]], device='cuda:0')

## Testing

In [33]:
correct_test = 0
model.eval()
for x_batch,y_batch,path_batch,idx_batch in test_dataloader:

    y_batch = torch.as_tensor(y_batch).type(torch.LongTensor)
    x_batch,y_batch = x_batch.cuda(), y_batch.cuda()

    pred = model(x_batch)
    val_loss = criterion(pred, y_batch)
    correct_test += (pred.argmax(axis=1) == y_batch).float().sum().item()
test_accuracy = correct_test / (len(test_dataloader)*batch_size_val) * 100.
print(test_accuracy)

90.63400576368876


In [38]:
@torch.no_grad()
def get_all_preds(model, loader):
    all_preds = torch.tensor([])
    for x_batch,y_batch,path_batch,idx_batch in loader:

        preds = model(x_batch.float())
        all_preds = torch.cat(
            (all_preds, preds)
            ,dim=0
        )
    return all_preds

In [39]:
get_all_preds(model, test_dataloader)

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same