In [1]:
%reset -f
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%config Completer.use_jedi = False

In [2]:
from tqdm import tqdm
# from tqdm.notebook import tqdm as tqdm
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms
from torchaudio.transforms import Spectrogram, AmplitudeToDB

# torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend('sox_io')
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import psutil
import requests

from utils import *
from wingbeat_datasets import *
from wingbeat_models import *

# print(f'Total RAM      : {bytes2GB(psutil.virtual_memory().total):5.2f} GB')
# print(f'Available RAM  : {bytes2GB(psutil.virtual_memory().available):5.2f} GB\n')

  '"sox" backend is being deprecated. '


Available workers: 16


In [3]:
num_epochs = 35
batch_size = 32
batch_size_val = batch_size * 2
validation_split = .2
shuffle_dataset = True
num_workers = psutil.cpu_count()
random_seed= 42


## Datasets and Dataloaders

In [4]:
setting = 'stftraw'

dmel1 = WingbeatsDataset(dsname="Melanogaster_RL/Y", custom_label=[0], transform=transforms.Compose([FilterWingbeat(setting='bandpass'), NormalizeWingbeat(), TransformWingbeat(setting=setting)])).clean()
dmel2 = WingbeatsDataset(dsname="Melanogaster_RL/Z", custom_label=[0], transform=transforms.Compose([FilterWingbeat(setting='bandpass'), NormalizeWingbeat(), TransformWingbeat(setting=setting)])).clean()
dsuz1 = WingbeatsDataset(dsname="Suzukii_RL/Y",      custom_label=[1], transform=transforms.Compose([FilterWingbeat(setting='bandpass'), NormalizeWingbeat(), TransformWingbeat(setting=setting)])).clean()
dsuz2 = WingbeatsDataset(dsname="Suzukii_RL/R",      custom_label=[1], transform=transforms.Compose([FilterWingbeat(setting='bandpass'), NormalizeWingbeat(), TransformWingbeat(setting=setting)])).clean()

transformed_dataset = ConcatDataset([dmel1, dsuz1])

train_size = int(0.8 * len(transformed_dataset))
valid_size = len(transformed_dataset) - train_size
train_dataset, valid_dataset = torch.utils.data.random_split(transformed_dataset, [train_size, valid_size])
test_dataset = ConcatDataset([dmel2, dsuz2])

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size_val, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size_val, num_workers=num_workers)

Found 29002 in dataset: Melanogaster_RL/Y, and 1 label(s): ['D. melanogaster']
Label(s) changed to [0]
Nr. of valid wingbeats: 12819
Found 24763 in dataset: Melanogaster_RL/Z, and 1 label(s): ['D. melanogaster']
Label(s) changed to [0]
Nr. of valid wingbeats: 11778
Found 25732 in dataset: Suzukii_RL/Y, and 1 label(s): ['D. suzukii']
Label(s) changed to [1]
Nr. of valid wingbeats: 17088
Found 14348 in dataset: Suzukii_RL/R, and 1 label(s): ['D. suzukii']
Label(s) changed to [1]
Nr. of valid wingbeats: 10372


## Model definition

In [5]:
from torchvision.models import resnet34, densenet121, resnet152
import torch.optim as optim

if setting.startswith('psd'):
    model = DrosophilaNetPSD()
elif setting == 'raw':
    model = DrosophilaNetRAW()
elif setting == 'stft':
    model = resnet152(pretrained=False)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs,2)
elif setting == 'stftraw':
    modelA = DrosophilaNetRAW() 
    modelB = resnet152(pretrained=False)
    num_ftrs = modelB.fc.in_features
    modelB.fc = nn.Linear(num_ftrs,2)

    model = ModelEnsemble(modelA, modelB)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
early_stopping = EarlyStopping(patience=7, verbose=2)
# print(model)

In [6]:
x1, x2 = torch.randn(64, 3, 129,120), torch.randn(64, 1, 5000)

output = model(x2, x1)

In [7]:
output.shape

torch.Size([64, 2])

## Training

In [8]:
# Choosing whether to train on a gpu
train_on_gpu = torch.cuda.is_available()
print(f'Train on gpu: {train_on_gpu}')# Number of gpus
model = model.to('cuda', dtype=torch.float)

Train on gpu: True


In [9]:
import warnings
warnings.filterwarnings("ignore")

# Model training
for epoch in range(num_epochs):
    # Going through the training set
    correct_train = 0
    model.train()
    for x_batch,y_batch,path_batch,idx_batch in tqdm(train_dataloader, desc='Training..\t'):        

        y_batch = torch.as_tensor(y_batch).type(torch.LongTensor)
        x_batch[0], x_batch[1] ,y_batch = x_batch[0].cuda(), x_batch[1].cuda(), y_batch.cuda()
        
        optimizer.zero_grad()
        pred = model(x_batch[0], x_batch[1])
        loss = criterion(pred, y_batch)
        loss.backward()
        optimizer.step()
        correct_train += (pred.argmax(axis=1) == y_batch).float().sum().item()

    train_accuracy = correct_train / (len(train_dataloader)*batch_size) * 100.
    # Going through the validation set
    correct_valid = 0
    model.eval()
    for x_batch,y_batch,path_batch,idx_batch in tqdm(valid_dataloader, desc='Validating..\t'):
        
        y_batch = torch.as_tensor(y_batch).type(torch.LongTensor)
        x_batch[0], x_batch[1] ,y_batch = x_batch[0].cuda(), x_batch[1].cuda(), y_batch.cuda()

        pred = model(x_batch[0], x_batch[1])
        val_loss = criterion(pred, y_batch)
        correct_valid += (pred.argmax(axis=1) == y_batch).float().sum().item()
    valid_accuracy = correct_valid / (len(valid_dataloader)*batch_size_val) * 100.
    scheduler.step(val_loss)
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

        # Printing results
    print(f"Epoch {epoch}: train_acc: {train_accuracy:.2f}% loss: {loss:.3f},  val_loss: {val_loss:.3f} val_acc: {valid_accuracy:.2f}%")


Training..	: 100%|██████████| 748/748 [03:30<00:00,  3.55it/s]
Validating..	: 100%|██████████| 94/94 [00:14<00:00,  6.68it/s]


Validation loss decreased (inf --> 0.585617).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 0: train_acc: 82.66% loss: 0.123,  val_loss: 0.586 val_acc: 70.05%


Training..	: 100%|██████████| 748/748 [03:28<00:00,  3.58it/s]
Validating..	: 100%|██████████| 94/94 [00:13<00:00,  7.12it/s]


Validation loss decreased (0.585617 --> 0.374369).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 1: train_acc: 89.49% loss: 0.111,  val_loss: 0.374 val_acc: 81.22%


Training..	: 100%|██████████| 748/748 [03:29<00:00,  3.57it/s]
Validating..	: 100%|██████████| 94/94 [00:13<00:00,  7.14it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 1 out of 7
Epoch 2: train_acc: 90.83% loss: 0.110,  val_loss: 0.450 val_acc: 77.94%


Training..	: 100%|██████████| 748/748 [03:29<00:00,  3.57it/s]
Validating..	: 100%|██████████| 94/94 [00:13<00:00,  7.12it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 2 out of 7
Epoch 3: train_acc: 91.52% loss: 0.095,  val_loss: 0.390 val_acc: 83.21%


Training..	: 100%|██████████| 748/748 [03:29<00:00,  3.57it/s]
Validating..	: 100%|██████████| 94/94 [00:13<00:00,  7.20it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 3 out of 7
Epoch 4: train_acc: 92.13% loss: 0.093,  val_loss: 0.378 val_acc: 84.44%


Training..	: 100%|██████████| 748/748 [03:29<00:00,  3.57it/s]
Validating..	: 100%|██████████| 94/94 [00:13<00:00,  7.07it/s]


Validation loss decreased (0.374369 --> 0.321917).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 5: train_acc: 92.64% loss: 0.094,  val_loss: 0.322 val_acc: 88.15%


Training..	: 100%|██████████| 748/748 [03:38<00:00,  3.42it/s]
Validating..	: 100%|██████████| 94/94 [00:13<00:00,  7.06it/s]


Validation loss decreased (0.321917 --> 0.254378).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 6: train_acc: 93.19% loss: 0.085,  val_loss: 0.254 val_acc: 91.69%


Training..	: 100%|██████████| 748/748 [03:35<00:00,  3.48it/s]
Validating..	: 100%|██████████| 94/94 [00:13<00:00,  7.14it/s]


Validation loss decreased (0.254378 --> 0.249778).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 7: train_acc: 93.71% loss: 0.086,  val_loss: 0.250 val_acc: 92.55%


Training..	: 100%|██████████| 748/748 [03:30<00:00,  3.55it/s]
Validating..	: 100%|██████████| 94/94 [00:13<00:00,  7.07it/s]


Validation loss decreased (0.249778 --> 0.236054).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 8: train_acc: 94.12% loss: 0.086,  val_loss: 0.236 val_acc: 92.90%


Training..	: 100%|██████████| 748/748 [03:25<00:00,  3.63it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.25it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 1 out of 7
Epoch 9: train_acc: 94.53% loss: 0.080,  val_loss: 0.268 val_acc: 92.37%


Training..	: 100%|██████████| 748/748 [03:24<00:00,  3.65it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.39it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 2 out of 7
Epoch 10: train_acc: 94.87% loss: 0.049,  val_loss: 0.252 val_acc: 92.74%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.72it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.46it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 3 out of 7
Epoch 11: train_acc: 95.35% loss: 0.022,  val_loss: 0.252 val_acc: 91.64%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.72it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.43it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch    13: reducing learning rate of group 0 to 1.0000e-03.
EarlyStopping counter: 4 out of 7
Epoch 12: train_acc: 95.62% loss: 0.023,  val_loss: 0.518 val_acc: 88.25%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.40it/s]


Validation loss decreased (0.236054 --> 0.107374).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 13: train_acc: 96.77% loss: 0.089,  val_loss: 0.107 val_acc: 93.80%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.44it/s]


Validation loss decreased (0.107374 --> 0.099122).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 14: train_acc: 97.35% loss: 0.089,  val_loss: 0.099 val_acc: 93.80%


Training..	: 100%|██████████| 748/748 [03:21<00:00,  3.72it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.41it/s]


Validation loss decreased (0.099122 --> 0.085694).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 15: train_acc: 97.61% loss: 0.080,  val_loss: 0.086 val_acc: 93.92%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.44it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 1 out of 7
Epoch 16: train_acc: 97.95% loss: 0.033,  val_loss: 0.089 val_acc: 93.82%


Training..	: 100%|██████████| 748/748 [03:21<00:00,  3.72it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.46it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 2 out of 7
Epoch 17: train_acc: 98.09% loss: 0.030,  val_loss: 0.092 val_acc: 93.95%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.41it/s]


Validation loss decreased (0.085694 --> 0.085285).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 18: train_acc: 98.34% loss: 0.028,  val_loss: 0.085 val_acc: 94.03%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.38it/s]


Validation loss decreased (0.085285 --> 0.079363).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 19: train_acc: 98.58% loss: 0.032,  val_loss: 0.079 val_acc: 93.92%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.40it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 1 out of 7
Epoch 20: train_acc: 98.72% loss: 0.018,  val_loss: 0.086 val_acc: 93.97%


Training..	: 100%|██████████| 748/748 [03:21<00:00,  3.72it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.43it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 2 out of 7
Epoch 21: train_acc: 98.81% loss: 0.010,  val_loss: 0.087 val_acc: 94.00%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.43it/s]


Validation loss decreased (0.079363 --> 0.071379).  Saving model ...


Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch 22: train_acc: 99.09% loss: 0.011,  val_loss: 0.071 val_acc: 93.83%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.44it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 1 out of 7
Epoch 23: train_acc: 99.10% loss: 0.008,  val_loss: 0.077 val_acc: 93.90%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.72it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.45it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 2 out of 7
Epoch 24: train_acc: 99.23% loss: 0.004,  val_loss: 0.082 val_acc: 93.88%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.43it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 3 out of 7
Epoch 25: train_acc: 99.36% loss: 0.012,  val_loss: 0.105 val_acc: 93.87%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.45it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

Epoch    27: reducing learning rate of group 0 to 1.0000e-04.
EarlyStopping counter: 4 out of 7
Epoch 26: train_acc: 99.42% loss: 0.045,  val_loss: 0.143 val_acc: 93.87%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.41it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 5 out of 7
Epoch 27: train_acc: 99.54% loss: 0.004,  val_loss: 0.095 val_acc: 93.98%


Training..	: 100%|██████████| 748/748 [03:21<00:00,  3.72it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.42it/s]
Training..	:   0%|          | 0/748 [00:00<?, ?it/s]

EarlyStopping counter: 6 out of 7
Epoch 28: train_acc: 99.58% loss: 0.002,  val_loss: 0.086 val_acc: 93.98%


Training..	: 100%|██████████| 748/748 [03:20<00:00,  3.73it/s]
Validating..	: 100%|██████████| 94/94 [00:12<00:00,  7.42it/s]

EarlyStopping counter: 7 out of 7
Early stopping





In [10]:
# print(x_batch, y_batch)

## Testing

In [11]:
correct_test = 0
model.eval()
for x_batch,y_batch,path_batch,idx_batch in tqdm(test_dataloader, desc="Testing..\t"):

    y_batch = torch.as_tensor(y_batch).type(torch.LongTensor)
    x_batch[0], x_batch[1] ,y_batch = x_batch[0].cuda(), x_batch[1].cuda(), y_batch.cuda()

    pred = model(x_batch[0], x_batch[1])
    val_loss = criterion(pred, y_batch)
    correct_test += (pred.argmax(axis=1) == y_batch).float().sum().item()
test_accuracy = correct_test / (len(test_dataloader)*batch_size_val) * 100.
print(test_accuracy)

Testing..	: 100%|██████████| 347/347 [00:44<00:00,  7.86it/s]

88.32853025936599





In [None]:
Try mel spectrograms


In [None]:
@torch.no_grad()
def get_all_preds(model, loader):
    all_preds = torch.tensor([])
    for x_batch,y_batch,path_batch,idx_batch in loader:

        preds = model(x_batch.float())
        all_preds = torch.cat(
            (all_preds, preds)
            ,dim=0
        )
    return all_preds

In [None]:
get_all_preds(model, test_dataloader)