In [1]:
from birdcall.data import *
from birdcall.metrics import *
from birdcall.ops import *

import torch
import torchvision
from torch import nn
import numpy as np
import pandas as pd
from pathlib import Path
import soundfile as sf

In [2]:
BS = 100
MAX_LR = 1e-3
T_MAX = 10

In [3]:
classes = pd.read_pickle('data/classes.pkl')
splits = pd.read_pickle('data/all_splits.pkl')
all_train_items = pd.read_pickle('data/all_train_items_npy.pkl')

train_items = np.array(all_train_items)[splits[0][0]].tolist()
val_items = np.array(all_train_items)[splits[0][1]].tolist()

In [4]:
len(train_items), len(val_items)

(17099, 4275)

In [5]:
#export data

class MelspecDataset(torch.utils.data.Dataset):
    def __init__(self, items, classes):
        self.items = items
        self.vocab = classes
        
    def __getitem__(self, idx):
        cls, path, _ = self.items[idx]
        example = self.get_spec(path)
        return example, self.one_hot_encode(cls)
    
    def get_spec(self, path):
        frames_per_spec = 212  
        x = np.load(path)

        specs = []
        for _ in range(3):
            if x.shape[1] < frames_per_spec:
                spec = np.zeros((80, frames_per_spec))
                start_frame = np.random.randint(frames_per_spec-x.shape[1])
                spec[:, start_frame:start_frame+x.shape[1]] = x
            else:
                start_frame = int(np.random.rand() * (x.shape[1] - frames_per_spec))
                spec = x[:, start_frame:start_frame+frames_per_spec]
            specs.append(spec)

        return np.stack(specs).reshape(3, 80, frames_per_spec).astype(np.float32)
    
    def show(self, idx):
        x = self[idx][0]
        return plt.imshow(x.transpose(1,2,0)[:, :, 0])
        
    def one_hot_encode(self, cls):
        y = self.vocab.index(cls)
        one_hot = np.zeros((len(self.vocab)))
        one_hot[y] = 1
        return one_hot
    def __len__(self):
        return len(self.items)

In [6]:
train_ds = MelspecDataset(train_items, classes)
val_ds = MelspecDataset(val_items, classes)

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BS, num_workers=NUM_WORKERS, pin_memory=True, shuffle=True)
valid_dl = torch.utils.data.DataLoader(val_ds, batch_size=BS, num_workers=NUM_WORKERS, pin_memory=True, shuffle=False)

In [7]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(3)
        self.cnn = nn.Sequential(*list(torchvision.models.resnet50(True).children())[:-2], nn.AdaptiveMaxPool2d(1))
        self.classifier = nn.Sequential(*[
            nn.Linear(2048, 1024), nn.ReLU(), nn.Dropout(p=0.2), nn.BatchNorm1d(1024),
            nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(p=0.2), nn.BatchNorm1d(1024),
            nn.Linear(1024, len(classes))
        ])
    
    def forward(self, x):
#         max_per_example = x.view(x.shape[0], -1).max(1)[0]
#         x /= max_per_example[:, None, None, None]
        x = self.bn(x)
        x = self.cnn(x)[:, :, 0, 0]
        x = self.classifier(x)
        return x

In [8]:
model = Model().cuda()

In [9]:
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score
import time

In [10]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), MAX_LR)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_MAX)

In [11]:
sc_ds = SoundscapeMelspecPoolDataset(pd.read_pickle('data/soundscape_items_npy.pkl'), classes)
sc_dl = torch.utils.data.DataLoader(sc_ds, batch_size=2*BS, num_workers=NUM_WORKERS, pin_memory=True)

In [None]:
t0 = time.time()
for epoch in range(100):
    running_loss = 0.0
    for i, data in enumerate(train_dl, 0):
        model.train()
        inputs, labels = data[0].cuda(), data[1].cuda()
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        if np.isnan(loss.item()): 
            raise Exception(f'!!! nan encountered in loss !!! epoch: {epoch}\n')
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()

    model.eval();
    preds = []
    targs = []

    for _ in range(5):
        with torch.no_grad():
            for data in valid_dl:
                inputs, labels = data[0].cuda(), data[1].cuda()
                outputs = model(inputs)
                preds.append(outputs.cpu().detach())
                targs.append(labels.cpu().detach())

    preds = torch.cat(preds)
    targs = torch.cat(targs)

    f1s = []
    ts = []
    for t in np.linspace(0.4, 1, 61):
        f1s.append(f1_score(preds.sigmoid() > t, targs, average='micro'))
        ts.append(t)

    sc_preds = []
    sc_targs = []
    with torch.no_grad():
        for data in sc_dl:
            inputs, labels = data[0].squeeze(1).cuda(), data[1].cuda()
            outputs = model(inputs)
            sc_preds.append(outputs.cpu().detach())
            sc_targs.append(labels.cpu().detach())

    sc_preds = torch.cat(sc_preds)
    sc_targs = torch.cat(sc_targs)
    sc_f1 = f1_score(sc_preds.sigmoid() > 0.5, sc_targs, average='micro')

    sc_f1s = []
    sc_ts = []
    for t in np.linspace(0.4, 1, 61):
        sc_f1s.append(f1_score(sc_preds.sigmoid() > t, sc_targs, average='micro'))
        sc_ts.append(t)

    f1 = f1_score(preds.sigmoid() > 0.5, targs, average='micro')
    print(f'[{epoch + 1}, {(time.time() - t0)/60:.1f}] loss: {running_loss / (len(train_dl)-1):.3f}, f1: {max(f1s):.3f}, sc_f1: {max(sc_f1s):.3f}')
    running_loss = 0.0

    torch.save(model.state_dict(), f'models/{epoch+1}_single_example_per_epoch_{round(f1, 2)}.pth')

[1, 1.7] loss: 0.367, f1: 0.000, sc_f1: 0.000
[2, 3.4] loss: 0.035, f1: 0.000, sc_f1: 0.000
[3, 5.1] loss: 0.027, f1: 0.000, sc_f1: 0.000
[4, 6.8] loss: 0.026, f1: 0.001, sc_f1: 0.000
[5, 8.6] loss: 0.026, f1: 0.000, sc_f1: 0.000
[6, 10.3] loss: 0.026, f1: 0.000, sc_f1: 0.000
[7, 12.0] loss: 0.026, f1: 0.000, sc_f1: 0.000
[8, 13.7] loss: 0.025, f1: 0.001, sc_f1: 0.000
[9, 15.4] loss: 0.025, f1: 0.000, sc_f1: 0.000
[10, 17.1] loss: 0.025, f1: 0.000, sc_f1: 0.000
[11, 18.8] loss: 0.024, f1: 0.000, sc_f1: 0.000
[12, 20.5] loss: 0.024, f1: 0.000, sc_f1: 0.000
[13, 22.2] loss: 0.024, f1: 0.000, sc_f1: 0.000
[14, 23.9] loss: 0.024, f1: 0.000, sc_f1: 0.000
[15, 25.6] loss: 0.024, f1: 0.002, sc_f1: 0.000
[16, 27.4] loss: 0.023, f1: 0.000, sc_f1: 0.000
[17, 29.1] loss: 0.023, f1: 0.000, sc_f1: 0.000
[18, 30.8] loss: 0.023, f1: 0.001, sc_f1: 0.000
[19, 32.5] loss: 0.023, f1: 0.000, sc_f1: 0.000
[20, 34.2] loss: 0.023, f1: 0.001, sc_f1: 0.000
[21, 35.9] loss: 0.022, f1: 0.002, sc_f1: 0.000
[22, 3