In [1]:
#default_exp ops

In [2]:
from birdcall.data import *
from birdcall.metrics import *

import torch
import torchvision
from torch import nn
import numpy as np
import pandas as pd

In [3]:
classes = pd.read_pickle('data/classes.pkl')
train_ds = MelspecPoolDataset(pd.read_pickle('data/train_set.pkl'), classes, len_mult=60, normalize=False)
valid_ds = MelspecPoolDataset(pd.read_pickle('data/val_set.pkl'), classes, len_mult=50, normalize=False)

In [4]:
len(train_ds), len(valid_ds)

(15840, 13200)

In [5]:
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=NUM_WORKERS)
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=2*16, shuffle=False, num_workers=NUM_WORKERS)

In [6]:
#export
import torch

def lme_pool(x, alpha=1.0): # log-mean-exp pool
    '''alpha -> approximates maxpool, alpha -> 0 approximates mean pool'''
    T = x.shape[1]
    mult_log = torch.log(torch.tensor(1/T))
    return 1/alpha * (mult_log + torch.logsumexp((alpha * x), dim=1))

In [7]:
class FrontEnd(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(80, affine=False)
        self.register_parameter('alpha', torch.nn.Parameter(torch.tensor(0.)))
        
    def forward(self, x):
        bs, im_num, ch, y_dim, x_dim = x.shape
        x = x ** torch.sigmoid(self.alpha)
        x = x.view(-1, y_dim, x_dim)
        x = self.bn(x)
        return x.view(bs, im_num, ch, y_dim, x_dim)

In [8]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.frontend = FrontEnd()
        self.cnn = nn.Sequential(*list(torchvision.models.resnet34(True).children())[:-2])
        self.classifier = nn.Sequential(*[
            nn.Linear(512, 512), nn.ReLU(), nn.Dropout(p=0.5), nn.BatchNorm1d(512),
            nn.Linear(512, 512), nn.ReLU(), nn.Dropout(p=0.5), nn.BatchNorm1d(512),
            nn.Linear(512, len(classes))
        ])
    
    def forward(self, x):
        bs, im_num, ch, y_dim, x_dim = x.shape
        x = self.frontend(x)
        x = self.cnn(x.view(-1, ch, y_dim, x_dim))
        x = x.mean((2,3))
        x = self.classifier(x)
        x = x.view(bs, im_num, -1)
        x = lme_pool(x)
        return x

In [9]:
# model = Model().cuda()
# x = model(b[0].cuda())

# x.shape

In [10]:
model = Model().cuda()

In [11]:
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score
import time

In [12]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), 1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 5)

In [13]:
t0 = time.time()
for epoch in range(130):
    running_loss = 0.0
    for i, data in enumerate(train_dl, 0):
        model.train()
        inputs, labels = data[0].cuda(), data[1].cuda()
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        if np.isnan(loss.item()): 
            print(f'!!! nan encountered in loss !!! alpha: epoch: {epoch}\n')
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()


    if epoch % 5 == 4:
        model.eval();
        preds = []
        targs = []

        with torch.no_grad():
            for data in valid_dl:
                inputs, labels = data[0].cuda(), data[1].cuda()
                outputs = model(inputs)
                preds.append(outputs.cpu().detach())
                targs.append(labels.cpu().detach())

            preds = torch.cat(preds)
            targs = torch.cat(targs)

        accuracy = accuracy_score(preds.sigmoid() > 0.5, targs)
        f1 = f1_score(preds.sigmoid() > 0.5, targs, average='micro')
        print(f'[{epoch + 1}, {(time.time() - t0)/60:.1f}] loss: {running_loss / (len(train_dl)-1):.3f}, acc: {accuracy:.3f}, f1: {f1:.3f}')
        running_loss = 0.0

        torch.save(model.state_dict(), f'models/{epoch+1}_lmepool_frontend_{round(f1, 2)}.pth')

[5, 22.3] loss: 0.023, acc: 0.000, f1: 0.000
[10, 44.6] loss: 0.020, acc: 0.001, f1: 0.003
[15, 66.9] loss: 0.019, acc: 0.015, f1: 0.039
[20, 89.3] loss: 0.017, acc: 0.029, f1: 0.060
[25, 111.6] loss: 0.015, acc: 0.097, f1: 0.178
[30, 133.9] loss: 0.013, acc: 0.197, f1: 0.333
[35, 156.4] loss: 0.012, acc: 0.314, f1: 0.471
[40, 178.8] loss: 0.010, acc: 0.377, f1: 0.540
[45, 201.3] loss: 0.009, acc: 0.413, f1: 0.576
[50, 223.7] loss: 0.008, acc: 0.459, f1: 0.614
[55, 246.2] loss: 0.008, acc: 0.462, f1: 0.621
[60, 268.6] loss: 0.007, acc: 0.506, f1: 0.652
[65, 291.1] loss: 0.006, acc: 0.534, f1: 0.674
[70, 313.6] loss: 0.006, acc: 0.537, f1: 0.681
[75, 336.1] loss: 0.005, acc: 0.557, f1: 0.696
[80, 358.5] loss: 0.005, acc: 0.555, f1: 0.695
[85, 381.0] loss: 0.005, acc: 0.578, f1: 0.710
[90, 403.4] loss: 0.004, acc: 0.562, f1: 0.702
[95, 425.7] loss: 0.004, acc: 0.578, f1: 0.708
[100, 448.1] loss: 0.004, acc: 0.577, f1: 0.707
[105, 470.4] loss: 0.004, acc: 0.572, f1: 0.702
[110, 492.8] los

In [15]:
f1s = []
ts = []
for t in np.linspace(0.4, 1, 61):
    f1s.append(f1_score(preds.sigmoid() > t, targs, average='micro'))
    ts.append(t)

In [16]:
max(f1s), accuracy_score(preds.sigmoid() > ts[np.argmax(f1s)], targs)

(0.7215644209744626, 0.5946969696969697)

In [17]:
ts[np.argmax(f1s)]

0.5700000000000001

In [21]:
from birdcall.metrics import *

preds_to_tp_fp_fn(preds, targs)

(tensor(8610), tensor(2106), tensor(4590))