In [None]:
# Start from a clean namespace (no confirmation prompt) so stale kernel
# variables cannot leak into this run.
%reset -f
# Re-import edited local modules automatically before each cell executes.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inside the notebook output.
%matplotlib inline
# Turn off jedi-based tab completion in the IPython completer.
%config Completer.use_jedi = False

In [None]:
from tqdm import tqdm
# from tqdm.notebook import tqdm as tqdm
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms
from torchaudio.transforms import Spectrogram, AmplitudeToDB

# torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend('sox_io')
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import psutil
import requests

from utils import *
from wingbeat_datasets import *
from wingbeat_models import *

# print(f'Total RAM      : {bytes2GB(psutil.virtual_memory().total):5.2f} GB')
# print(f'Available RAM  : {bytes2GB(psutil.virtual_memory().available):5.2f} GB\n')

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration: every tunable knob of the notebook lives here.
# ---------------------------------------------------------------------------

# Training schedule
num_epochs = 35
batch_size = 32
# Validation / test loaders use batches twice the size of training batches.
batch_size_val = batch_size * 2

# Data splitting and loading
validation_split = .2      # intended fraction of the training pool held out for validation
shuffle_dataset = True     # whether training samples should be shuffled
# psutil.cpu_count() may return None when the CPU count cannot be determined;
# DataLoader requires an int, and 0 means "load in the main process".
num_workers = psutil.cpu_count() or 0
random_seed = 42

# Input representation and backbone architecture
setting = 'melstft'  # 'stft', 'reassigned_stft', 'melstft'
modelname = 'densenet121'  # 'resnet34','resnet152'

## Datasets and Dataloaders

In [None]:
def _wingbeat_pipeline():
    """Return a fresh band-pass filter -> normalise -> `setting` transform chain.

    A new Compose (with new transform instances) is built on every call so the
    four datasets below do not share transform objects.
    """
    steps = [
        FilterWingbeat(setting='bandpass'),
        NormalizeWingbeat(),
        TransformWingbeat(setting=setting),
    ]
    return transforms.Compose(steps)


# Melanogaster recordings carry label 0, Suzukii recordings label 1.
# NOTE(review): `.clean()` is defined in wingbeat_datasets; presumably it drops
# unusable recordings and returns the dataset -- confirm against that module.
dmel1 = WingbeatsDataset(dsname="Melanogaster_RL/Y", custom_label=[0], transform=_wingbeat_pipeline()).clean()
dmel2 = WingbeatsDataset(dsname="Melanogaster_RL/Z", custom_label=[0], transform=_wingbeat_pipeline()).clean()
dsuz1 = WingbeatsDataset(dsname="Suzukii_RL/Y", custom_label=[1], transform=_wingbeat_pipeline()).clean()
dsuz2 = WingbeatsDataset(dsname="Suzukii_RL/R", custom_label=[1], transform=_wingbeat_pipeline()).clean()

In [None]:
# Training pool: one Melanogaster and one Suzukii recording set.
transformed_dataset = ConcatDataset([dmel1, dsuz1])

# Hold out `validation_split` of the pool for validation. The split is driven
# by the config values instead of a hard-coded 0.8, and the generator is seeded
# with `random_seed` so the same samples land in each subset on every run.
train_size = int((1 - validation_split) * len(transformed_dataset))
valid_size = len(transformed_dataset) - train_size
split_generator = torch.Generator().manual_seed(random_seed)
train_dataset, valid_dataset = torch.utils.data.random_split(
    transformed_dataset, [train_size, valid_size], generator=split_generator
)

# Test set: the two recording sets that are not part of the training pool.
test_dataset = ConcatDataset([dmel2, dsuz2])

In [None]:
# Training batches are reshuffled every epoch when `shuffle_dataset` is set
# (the flag was previously defined but never applied). Evaluation loaders keep
# a fixed order so their outputs stay aligned with the underlying datasets.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_dataset, num_workers=num_workers)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size_val, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size_val, num_workers=num_workers)

## Model definition

In [None]:
from torchvision.models import resnet34, densenet121, resnet152
import torch.optim as optim

# Two output classes: label 0 and label 1 as assigned to the datasets.
num_classes = 2

# Supported backbones. ResNets expose their classification head as `.fc`,
# DenseNet exposes it as `.classifier`.
resnet_builders = {'resnet34': resnet34, 'resnet152': resnet152}

if setting not in ['stft', 'reassigned_stft', 'melstft']:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError(f"Unsupported setting: {setting!r}")

if modelname == 'densenet121':
    model = densenet121(pretrained=False)
    num_ftrs = model.classifier.in_features
    # Replace the real head. Assigning to `model.fc` would only attach an unused
    # layer and leave the 1000-way ImageNet classifier in the forward pass.
    model.classifier = nn.Linear(num_ftrs, num_classes)
elif modelname in resnet_builders:
    # Instantiate the ResNet variant that was actually requested.
    model = resnet_builders[modelname](pretrained=False)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
else:
    raise ValueError(f"Unsupported modelname: {modelname!r}")

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999))
# Lower the learning rate when the monitored validation loss stops improving.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
# EarlyStopping comes from the project's helper modules (star imports).
early_stopping = EarlyStopping(patience=7, verbose=2)
# print(model)

## Training

In [None]:
# Choose the compute device: use the GPU when one is available, otherwise fall
# back to the CPU instead of failing on an unconditional move to 'cuda'.
train_on_gpu = torch.cuda.is_available()
print(f'Train on gpu: {train_on_gpu}')
device = torch.device('cuda' if train_on_gpu else 'cpu')
model = model.to(device, dtype=torch.float)

In [None]:
import warnings
warnings.filterwarnings("ignore")

# Feed batches to whichever device the model currently lives on.
device = next(model.parameters()).device

# Model training
for epoch in range(num_epochs):
    # ---- Training pass -------------------------------------------------
    model.train()
    correct_train, seen_train, train_loss_sum = 0, 0, 0.0
    for x_batch, y_batch, path_batch, idx_batch in tqdm(train_dataloader, desc='Training..\t'):

        y_batch = torch.as_tensor(y_batch).type(torch.LongTensor)
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        pred = model(x_batch)
        loss = criterion(pred, y_batch)
        loss.backward()
        optimizer.step()

        # Weight by the real batch size: the last batch may be smaller, so
        # len(loader) * batch_size would overcount the samples.
        n_batch = pred.size(0)
        train_loss_sum += loss.item() * n_batch
        correct_train += (pred.argmax(dim=1) == y_batch).sum().item()
        seen_train += n_batch

    train_accuracy = correct_train / max(seen_train, 1) * 100.
    train_loss = train_loss_sum / max(seen_train, 1)

    # ---- Validation pass -----------------------------------------------
    model.eval()
    correct_valid, seen_valid, val_loss_sum = 0, 0, 0.0
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for x_batch, y_batch, path_batch, idx_batch in tqdm(valid_dataloader, desc='Validating..\t'):

            y_batch = torch.as_tensor(y_batch).type(torch.LongTensor)
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            pred = model(x_batch)
            n_batch = pred.size(0)
            val_loss_sum += criterion(pred, y_batch).item() * n_batch
            correct_valid += (pred.argmax(dim=1) == y_batch).sum().item()
            seen_valid += n_batch

    valid_accuracy = correct_valid / max(seen_valid, 1) * 100.
    # Mean loss over the whole validation set, not just the final batch, so the
    # scheduler and early stopping react to a stable signal.
    val_loss = val_loss_sum / max(seen_valid, 1)

    # Printing results (before the early-stopping check so the last epoch is reported too)
    print(f"Epoch {epoch}: train_acc: {train_accuracy:.2f}% loss: {train_loss:.3f},  val_loss: {val_loss:.3f} val_acc: {valid_accuracy:.2f}%")

    scheduler.step(val_loss)
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break


In [None]:
# print(x_batch, y_batch)

## Testing

In [None]:
# Evaluate on the held-out test loader using the device the model lives on.
device = next(model.parameters()).device

correct_test, seen_test = 0, 0
model.eval()
# Inference only: disable autograd so no computation graph is kept per batch.
with torch.no_grad():
    for x_batch, y_batch, path_batch, idx_batch in tqdm(test_dataloader, desc="Testing..\t"):

        y_batch = torch.as_tensor(y_batch).type(torch.LongTensor)
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        pred = model(x_batch)
        correct_test += (pred.argmax(dim=1) == y_batch).sum().item()
        seen_test += pred.size(0)

# Divide by the number of samples actually evaluated; len(loader) * batch size
# overcounts whenever the final batch is partial.
test_accuracy = correct_test / max(seen_test, 1) * 100.
print(test_accuracy)

In [None]:
# Put the model in float32 on the GPU when one exists, otherwise on the CPU
# (an unconditional 'cuda' move fails on CPU-only machines), then switch to
# evaluation mode.
model = model.to('cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.float)
model.eval()

@torch.no_grad()
def get_all_preds(model, loader):
    """Run `model` over every batch of `loader` and return the stacked outputs.

    Args:
        model: a torch module; batches are moved to the device of its first
            parameter (CPU if it has no parameters).
        loader: iterable yielding 4-tuples whose first element is the input
            batch tensor; the remaining three elements are ignored.

    Returns:
        A CPU tensor with the per-batch outputs concatenated along dim 0, or
        an empty tensor when `loader` yields nothing.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Collect per-batch outputs on the CPU and concatenate once at the end:
    # avoids mixing devices in torch.cat and the quadratic cost of growing a
    # tensor inside the loop.
    batch_preds = []
    for x_batch, _y_batch, _path_batch, _idx_batch in loader:
        preds = model(x_batch.to(device).float())
        batch_preds.append(preds.cpu())

    if not batch_preds:
        return torch.tensor([])
    return torch.cat(batch_preds, dim=0)

In [None]:
# Run inference before relocating the model: the batches and the weights must
# be on the same device during the forward passes. Afterwards bring both the
# predictions and the model back to the CPU.
z = get_all_preds(model, test_dataloader).cpu()
model = model.to('cpu', dtype=torch.float)