In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell executes, so edits to local helper modules take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt

import IPython.display as ipd
import librosa

In [None]:
# Display whether PyTorch can use a CUDA device in this environment.
torch.cuda.is_available()

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
from torch import nn
from torch.nn import functional as F
import torch
from birds_utils import get_fourier_weights, DataGeneratorV2, Dataset, get_pytorch_model, get_pytorch_model_all_conv

In [None]:
import torchvision

In [None]:
# --- Settings used by the dataset and the audio playback cell ---------------
sr = 44100      # sample rate handed to Dataset(sr=...) and to the audio player
N = 2           # how many times the training file list is repeated
min_std = 0.5   # forwarded unchanged to Dataset(min_std=...)
duration = 5    # forwarded to Dataset(chunk_seconds=...)

# Earlier variant (disabled): sample-rate-specific lists named
# train_files_{sr}.npy / train_labels_{sr}.npy / val_files_{sr}.npy /
# val_labels_{sr}.npy. The generic lists are loaded instead and every path has
# its 'npy_22050' component redirected to 'npy_44100'.
train_files = [
    path.replace('npy_22050', 'npy_44100')
    for path in np.load('train_files.npy')
]
train_labels = np.load('train_labels.npy')
val_files = [
    path.replace('npy_22050', 'npy_44100')
    for path in np.load('val_files.npy')
]
val_labels = np.load('val_labels.npy')

# Keyword arguments for the DataLoader.
params = dict(batch_size=32, shuffle=True, num_workers=1)

# The label vocabulary is the set of distinct training labels.
classes = np.unique(train_labels)

training_set = Dataset(
    list(train_files) * N,
    classes,
    chunk_seconds=duration,
    sr=sr,
    min_std=min_std,
    multilabel=True,
)
training_generator = torch.utils.data.DataLoader(training_set, **params)

In [None]:
# Pull a single batch for the inspection cells that follow. next(iter(...))
# raises StopIteration right here if the loader is empty, instead of silently
# leaving X and y undefined the way a `for ...: break` loop would.
X, y = next(iter(training_generator))

In [None]:
# Display the shapes of the sample batch's inputs and targets.
X.shape, y.shape

In [None]:
# Move the sample batch onto the selected device (rebinding X in place).
X = X.to(device)

In [None]:
# Display the tensor type string of the sample batch.
X.type()

In [None]:
# First positional argument given to get_pytorch_model_all_conv when the model
# is built. Presumably the analysis-window length in samples -- TODO confirm
# against birds_utils.
window_size = 1024

In [None]:
# nn.Sequential(*list(model_resnet.children())[:-1])

In [None]:
# Build the network and place it on the selected device.
# NOTE(review): n_classes is hard-coded to 10; presumably it should equal
# len(classes) -- confirm against the label set.
model = get_pytorch_model_all_conv(
    window_size,
    resnet='resnet18',
    pretrained=True,
    n_classes=10,
    init_fourier=True,
    train_fourier=False,
)
model = model.to(device)

In [None]:
# model = torch.load('model_1_sec_18.pth')

In [None]:
# Fourier layer should not be trainable: display requires_grad of the first
# parameter of model.cos. With train_fourier=False at construction this is
# expected to show False.
list(model.cos.parameters())[0].requires_grad

In [None]:
# plt.plot(model.cos.weight.data[0, 0, :])
# plt.plot(model.cos.weight.data[1, 0, :])

In [None]:
%time
model.eval()
spec, y_res = model(X)

In [None]:
# Display the shapes of the two tensors returned by the model.
spec.shape, y_res.shape

In [None]:
# N is reused here as the index of the batch item to inspect (it previously
# held the file-list repetition factor).
N = 5
# Show item N of spec as a grayscale image, flipped vertically for display.
plt.imshow(np.flipud(spec.detach().cpu().numpy()[N, :,:]), cmap='gray')
# Last expression of the cell: an audio player for the flattened item N of X,
# played back at the configured sample rate.
ipd.Audio(X.detach().cpu().numpy()[N].reshape(-1), rate=sr)


In [None]:
# Display the largest and smallest values of item N of spec.
spec[N].max(), spec[N].min()

In [None]:
def multi_acc(y_pred, y_test):
    """Count correct single-label predictions in one batch.

    The predicted class of each row is the index of its largest log-softmax
    score along dim 1; predictions are compared element-wise with ``y_test``.

    Args:
        y_pred: tensor of per-class scores with the classes along dim 1.
        y_test: tensor of target class indices.

    Returns:
        ``(hits, total)``: ``hits`` is a 0-dim float tensor holding the number
        of matches, ``total`` is the ``len()`` of the comparison result.
    """
    predicted = y_pred.log_softmax(dim=1).max(dim=1).indices
    matches = predicted.eq(y_test).float()
    return matches.sum(), len(matches)

def multilabel_acc(y_pred, y_test):
    """Sum the targets at positions the model predicts as positive.

    A position counts as predicted-positive when the sigmoid of its logit is
    strictly greater than 0.5. That boolean mask is multiplied by ``y_test``
    and summed, so only positions where the target is non-zero contribute;
    predicted positives on zero targets are not penalised.

    Args:
        y_pred: tensor of raw logits.
        y_test: tensor of targets, broadcastable against ``y_pred``.

    Returns:
        ``(hits, total)``: ``hits`` is a 0-dim tensor with the summed product,
        ``total`` is ``len(y_pred)`` (the size of its first dimension). With
        more than one positive target per row, ``hits`` can exceed ``total``.
    """
    predicted_positive = y_pred.sigmoid().gt(0.5)
    hits = torch.mul(predicted_positive, y_test).sum()
    n_rows = len(y_pred)
    return hits, n_rows

def validate(dgen_val, acc_func=multi_acc):
    """Evaluate the notebook-level ``model`` on every batch of ``dgen_val``.

    Relies on the notebook globals ``model``, ``criterion`` and ``device``.
    The model is put in evaluation mode for the duration of the call and its
    previous train/eval mode is restored afterwards, also when an exception
    interrupts the loop.

    Args:
        dgen_val: sized iterable yielding ``(X, y)`` tensor pairs.
        acc_func: callable ``(y_pred, labels) -> (n_ok, n_total)``.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats: the mean of the per-batch
        losses, and the summed ``n_ok`` divided by the summed ``n_total``
        (``nan`` if ``n_total`` sums to zero).

    Raises:
        ValueError: if ``dgen_val`` yields no batches.
    """
    was_training = model.training
    model.eval()

    running_loss = 0.0
    total_ok = 0.0
    total_predictions = 0
    n_batches = 0
    batches_per_epoch = len(dgen_val)
    try:
        with torch.no_grad():
            for n_batches, (X, y) in enumerate(dgen_val, start=1):
                inputs, labels = X.to(device), y.to(device)
                # The model returns a pair; only the second element is scored.
                _, y_pred = model(inputs)
                loss = criterion(y_pred, labels)
                ok, total = acc_func(y_pred, labels)
                # Accumulate plain Python numbers rather than tensors.
                running_loss += float(loss)
                total_ok += float(ok)
                total_predictions += total
                acc = total_ok / total_predictions if total_predictions else float('nan')
                print(f'\r{n_batches}/{batches_per_epoch} - val loss: {running_loss/n_batches}, val acc: {acc}', end='')
    finally:
        # Put the model back in whichever mode the caller had it in.
        model.train(was_training)

    if n_batches == 0:
        raise ValueError('validate() received a data generator that yielded no batches')
    accuracy = total_ok / total_predictions if total_predictions else float('nan')
    return running_loss / n_batches, accuracy

In [None]:
def train_model(dataset, validation_generator, criterion, acc_func=multi_acc, epochs=1, best_val_acc = 0):
    """Train the notebook-level ``model`` and checkpoint it when validation improves.

    Relies on the notebook globals ``model``, ``optimizer`` and ``device``.
    After every epoch ``validate`` is run on ``validation_generator`` with the
    same ``acc_func`` used for training. Note that ``validate`` reads the
    notebook-level ``criterion``, so the ``criterion`` passed here should be
    that same object.

    Args:
        dataset: sized iterable yielding ``(X, y)`` tensor pairs for training.
        validation_generator: sized iterable yielding ``(X, y)`` pairs for
            validation.
        criterion: loss callable ``(y_pred, labels) -> scalar tensor``.
        acc_func: callable ``(y_pred, labels) -> (n_ok, n_total)``.
        epochs: number of passes over ``dataset``.
        best_val_acc: validation accuracy that must be exceeded before a
            checkpoint is written; raised each time a better epoch is seen.

    Side effects:
        Writes ``model_{epoch+1}_.pth`` (state dict) whenever the validation
        accuracy exceeds the best value so far, and prints progress lines.
    """
    batches_per_epoch = len(dataset)
    for epoch in range(epochs):  # loop over the dataset multiple times
        # Make sure training mode is active at the start of every epoch.
        model.train()
        running_loss = 0.0
        total_ok = 0.0
        total_predictions = 0
        for i, (X, y) in enumerate(dataset):
            inputs, labels = X.to(device), y.to(device)
            # (1) Reset gradients left over from the previous step
            optimizer.zero_grad()
            # (2) Forward pass; the model returns a pair, the second element is scored
            _, y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            # (3) Backward pass
            loss.backward()
            # (4) Update the weights
            optimizer.step()

            # Accumulate a plain float: summing the loss tensor itself would
            # carry autograd history from every batch of the epoch.
            running_loss += float(loss)
            with torch.no_grad():
                ok, total = acc_func(y_pred, labels)
            total_ok += float(ok)
            total_predictions += total

            print(f'\r{epoch+1}/{epochs} - {i+1}/{batches_per_epoch} - loss: {running_loss/(i+1)}, acc: {total_ok/max(total_predictions, 1)}', end='')

        print()
        # Validate with the caller's accuracy function, not a hard-coded one.
        loss, acc = validate(validation_generator, acc_func=acc_func)
        print()
        if acc > best_val_acc:
            best_val_acc = acc
            print('Best model saved')
            torch.save(model.state_dict(), f'model_{epoch+1}_.pth')
        print('--------------------------------------------------------------------------')
        

In [None]:
# Rebuild the datasets and loaders for the real training run: batches of 16
# and file lists repeated N = 7 times. Both loaders share the same
# DataLoader keyword arguments.
N = 7
classes = np.unique(train_labels)
params = dict(batch_size=16, shuffle=True, num_workers=1)

training_set = Dataset(
    list(train_files) * N,
    classes,
    chunk_seconds=duration,
    sr=sr,
    min_std=min_std,
    multilabel=True,
)
validation_set = Dataset(
    list(val_files) * N,
    classes,
    chunk_seconds=duration,
    sr=sr,
    min_std=min_std,
    multilabel=True,
)

training_generator = torch.utils.data.DataLoader(training_set, **params)
validation_generator = torch.utils.data.DataLoader(validation_set, **params)

In [25]:
# Step size for the Adam optimiser.
LEARNING_RATE = 0.001
# Single-label alternative, kept for reference:
# criterion = torch.nn.CrossEntropyLoss()
# Binary cross-entropy on raw logits; the sigmoid is applied inside the loss,
# so the model output is passed to it unactivated.
criterion = torch.nn.BCEWithLogitsLoss()
# The optimiser is given every parameter of the model.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Train for up to 100 epochs. With best_val_acc=0.5 a checkpoint is only
# written once the validation accuracy exceeds 0.5.
train_model(training_generator, validation_generator, criterion, epochs=100, acc_func=multilabel_acc, best_val_acc=0.5)

1/100 - 346/346 - loss: 0.19698800146579742, acc: 0.39199709892272954
78/78 - val loss: 0.35373106598854065, val acc: 0.37288135290145874
--------------------------------------------------------------------------
2/100 - 346/346 - loss: 0.11426668614149094, acc: 0.6784356236457825
78/78 - val loss: 0.1679152548313141, val acc: 0.58111375570297241
Best model saved
--------------------------------------------------------------------------
3/100 - 346/346 - loss: 0.08152970671653748, acc: 0.7874343395233154
78/78 - val loss: 0.18607202172279358, val acc: 0.5891848206520081
Best model saved
--------------------------------------------------------------------------
4/100 - 346/346 - loss: 0.06542854011058807, acc: 0.8292594552040119
78/78 - val loss: 0.13372105360031128, val acc: 0.6997578740119934
Best model saved
--------------------------------------------------------------------------
5/100 - 346/346 - loss: 0.04875045642256737, acc: 0.88321560621261675
78/78 - val loss: 0.1296156793832

In [30]:
# Save the model's current weights (state dict only), whatever the validation score.
torch.save(model.state_dict(), 'model_44100_last.pth')

In [None]:
# print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [None]:
# Random float32 tensor of shape (2, 1, 44100) with values drawn uniformly
# from [0, 1). Presumably a dummy input for the model -- TODO confirm; it is
# not used in the cells visible here.
X_img = torch.from_numpy(np.random.rand(2, 1, 44100)).float()