In [1]:
from birdcall.data import *
from birdcall.metrics import *
from birdcall.ops import *

import torch
import torchvision
from torch import nn
import numpy as np
import pandas as pd

In [4]:
# Echo the class object to confirm the star-import from birdcall.data exposes it.
MelspecShortishValidatioDataset

birdcall.data.MelspecShortishValidatioDataset

In [None]:
# Class list shared by the training and validation datasets.
classes = pd.read_pickle('data/classes.pkl')

# Training data: batching is delegated to the BatchSampler brought in by the
# star imports at the top of the notebook.
train_ds = MelspecShortishDataset(
    pd.read_pickle('data/train_set.pkl'),
    classes,
)
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_sampler=BatchSampler(len_mult=60),
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Earlier validation setup, kept for reference only.
# NOTE(review): the commented-out loader below wraps train_ds, not valid_ds.
# valid_ds = MelspecShortishDataset(pd.read_pickle('data/val_set.pkl'), classes)
# valid_dl = torch.utils.data.DataLoader(train_ds, batch_sampler=BatchSampler(len_mult=60), num_workers=NUM_WORKERS, pin_memory=True)

# Validation data: fixed-size batches in a deterministic order.
valid_ds = MelspecPoolDataset(
    pd.read_pickle('data/val_set.pkl'),
    classes,
    len_mult=50,
    normalize=False,
)
valid_dl = torch.utils.data.DataLoader(
    valid_ds,
    batch_size=2 * 16,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [2]:
class FrontEnd(nn.Module):
    """Learnable power compression followed by non-affine batch normalisation.

    The input is raised element-wise to the power ``sigmoid(alpha)`` (a single
    learnable scalar, so the exponent always lies in (0, 1)) and then
    normalised with ``BatchNorm1d`` using the ``y_dim`` axis as the feature
    axis.

    Args:
        num_features: Size of the ``y_dim`` axis of the inputs; forwarded to
            ``BatchNorm1d``. Defaults to 80, the value previously hard-coded.

    Shape:
        Input and output are both ``(bs, im_num, ch, y_dim, x_dim)`` with
        ``y_dim == num_features``.
    """

    def __init__(self, num_features=80):
        super().__init__()
        # affine=False: normalise only, no learnable scale/shift.
        self.bn = nn.BatchNorm1d(num_features, affine=False)
        # alpha starts at 0, so the initial exponent is sigmoid(0) = 0.5.
        self.register_parameter('alpha', torch.nn.Parameter(torch.tensor(0.)))

    def forward(self, x):
        bs, im_num, ch, y_dim, x_dim = x.shape
        # NOTE(review): a fractional power of a negative value is NaN, so this
        # assumes non-negative inputs — confirm against the dataset.
        x = x ** torch.sigmoid(self.alpha)
        # Collapse the leading three axes so BatchNorm1d sees (N, y_dim, x_dim).
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(-1, y_dim, x_dim)
        x = self.bn(x)
        return x.view(bs, im_num, ch, y_dim, x_dim)

In [3]:
class Model(nn.Module):
    """Per-image ResNet-34 classifier whose outputs are pooled across images.

    Input is ``(bs, im_num, ch, y_dim, x_dim)``. Every image goes through the
    ``FrontEnd``, the truncated ResNet-34 and the fully connected head; the
    per-image outputs are then combined along the ``im_num`` axis by
    ``lme_pool``.
    """

    def __init__(self):
        super().__init__()
        self.frontend = FrontEnd()

        # Keep everything except the last two children of resnet34 (in
        # torchvision's ResNet these are the average-pool and fc layers).
        backbone_layers = list(torchvision.models.resnet34(True).children())
        self.cnn = nn.Sequential(*backbone_layers[:-2])

        # Two identical hidden blocks followed by the output projection; the
        # flat Sequential layout keeps the state-dict indices 0..8.
        head = []
        for _ in range(2):
            head.extend([nn.Linear(512, 512), nn.ReLU(), nn.Dropout(p=0.5), nn.BatchNorm1d(512)])
        head.append(nn.Linear(512, len(classes)))
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        n_batch, n_imgs, n_ch, height, width = x.shape
        # Fold the image axis into the batch axis for the 2-D CNN.
        feats = self.frontend(x).view(-1, n_ch, height, width)
        # Average over the two spatial axes of the CNN feature map.
        feats = self.cnn(feats).mean((2, 3))
        # Restore the (batch, image, class) layout before pooling over images.
        logits = self.classifier(feats).view(n_batch, n_imgs, -1)
        return lme_pool(logits)

In [4]:
# Instantiate the network and move its parameters and buffers to the GPU.
model = Model().cuda()

In [5]:
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score
import time

In [6]:
# Binary cross-entropy on raw logits, optimised with Adam under a
# cosine-annealed learning rate.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)

In [8]:
# Train for 260 epochs; every 5th epoch run validation, print metrics and
# save a checkpoint whose filename embeds the epoch number and rounded F1.
t0 = time.time()
for epoch in range(260):
    running_loss = 0.0
    # Validation (below) leaves the model in eval mode, so switch back to
    # training mode once per epoch instead of once per batch.
    model.train()
    for batch in train_dl:
        inputs, labels = batch[0].cuda(), batch[1].cuda()
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Read the scalar once; each .item() call forces a device-to-host sync.
        loss_value = loss.item()
        if np.isnan(loss_value):
            print(f'!!! nan encountered in loss !!! alpha: {model.frontend.alpha.item()} epoch: {epoch}\n')
        loss.backward()
        optimizer.step()
        # NOTE(review): the scheduler is stepped once per batch. With the
        # CosineAnnealingLR(T_max=5) created earlier, the learning rate sweeps
        # a half-cosine every 5 batches rather than every 5 epochs — confirm
        # this is intended (checkpoints are taken every 5 epochs).
        scheduler.step()

        running_loss += loss_value

    if epoch % 5 == 4:
        model.eval()
        preds = []
        targs = []

        # no_grad already yields tensors without autograd history, so no
        # extra .detach() is needed before moving them to the CPU.
        with torch.no_grad():
            for batch in valid_dl:
                inputs, labels = batch[0].cuda(), batch[1].cuda()
                outputs = model(inputs)
                preds.append(outputs.cpu())
                targs.append(labels.cpu())

        preds = torch.cat(preds)
        targs = torch.cat(targs)

        # sklearn metrics take (y_true, y_pred) in that order.
        hard_preds = preds.sigmoid() > 0.5
        accuracy = accuracy_score(targs, hard_preds)
        f1 = f1_score(targs, hard_preds, average='micro')
        print(f'[{epoch + 1}, {(time.time() - t0)/60:.1f}] loss: {running_loss / len(train_dl):.3f}, acc: {accuracy:.3f}, f1: {f1:.3f}')

        torch.save(model.state_dict(), f'models/{epoch+1}_lmepool_frontend_shortish_{round(f1, 2)}.pth')

[5, 12.3] loss: 0.025, acc: 0.000, f1: 0.000
[10, 24.3] loss: 0.024, acc: 0.000, f1: 0.000
[15, 36.3] loss: 0.023, acc: 0.000, f1: 0.000
[20, 48.4] loss: 0.021, acc: 0.002, f1: 0.006
[25, 60.4] loss: 0.020, acc: 0.003, f1: 0.006
[30, 72.6] loss: 0.020, acc: 0.002, f1: 0.008
[35, 84.7] loss: 0.019, acc: 0.004, f1: 0.021
[40, 97.0] loss: 0.019, acc: 0.011, f1: 0.029
[45, 109.4] loss: 0.018, acc: 0.017, f1: 0.044
[50, 121.6] loss: 0.018, acc: 0.025, f1: 0.055
[55, 133.8] loss: 0.018, acc: 0.029, f1: 0.066
[60, 145.9] loss: 0.017, acc: 0.037, f1: 0.071
[65, 158.2] loss: 0.017, acc: 0.040, f1: 0.069
[70, 170.5] loss: 0.016, acc: 0.055, f1: 0.121
[75, 182.7] loss: 0.016, acc: 0.051, f1: 0.110
[80, 195.0] loss: 0.016, acc: 0.060, f1: 0.141
[85, 207.6] loss: 0.015, acc: 0.091, f1: 0.160
[90, 220.0] loss: 0.015, acc: 0.081, f1: 0.123
[95, 232.4] loss: 0.016, acc: 0.073, f1: 0.110
[100, 245.1] loss: 0.014, acc: 0.103, f1: 0.166
[105, 257.4] loss: 0.014, acc: 0.124, f1: 0.208
[110, 270.0] loss: 0

In [9]:
# Sweep decision thresholds from 0.4 to 1.0 in steps of 0.01 and record the
# micro-averaged F1 at each one.
probs = preds.sigmoid()  # loop-invariant, so computed once
ts = list(np.linspace(0.4, 1, 61))
# sklearn metrics take (y_true, y_pred) in that order.
f1s = [f1_score(targs, probs > t, average='micro') for t in ts]

In [10]:
# Best micro-F1 found by the sweep, and the accuracy at the threshold that achieves it.
max(f1s), accuracy_score(preds.sigmoid() > ts[np.argmax(f1s)], targs)

(0.5804704550965969, 0.43636363636363634)

In [11]:
# Threshold at which the sweep's micro-F1 peaks.
ts[np.argmax(f1s)]

0.43000000000000005

In [12]:
# Re-import of birdcall.metrics (it is already star-imported in the first cell).
from birdcall.metrics import *

# Presumably returns (true positives, false positives, false negatives) counts,
# judging by the name — confirm in birdcall.metrics.
preds_to_tp_fp_fn(preds, targs)

(tensor(6479), tensor(2963), tensor(6721))

In [15]:
# Continue training for epochs 260..334; every 5th epoch run validation,
# print metrics and save a checkpoint whose filename embeds the epoch number
# and rounded F1. The timer restarts, so printed minutes are relative to this run.
t0 = time.time()
for epoch in range(260, 335):
    running_loss = 0.0
    # Validation (below) leaves the model in eval mode, so switch back to
    # training mode once per epoch instead of once per batch.
    model.train()
    for batch in train_dl:
        inputs, labels = batch[0].cuda(), batch[1].cuda()
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Read the scalar once; each .item() call forces a device-to-host sync.
        loss_value = loss.item()
        if np.isnan(loss_value):
            print(f'!!! nan encountered in loss !!! alpha: {model.frontend.alpha.item()} epoch: {epoch}\n')
        loss.backward()
        optimizer.step()
        # NOTE(review): the scheduler is stepped once per batch. With the
        # CosineAnnealingLR(T_max=5) created earlier, the learning rate sweeps
        # a half-cosine every 5 batches rather than every 5 epochs — confirm
        # this is intended (checkpoints are taken every 5 epochs).
        scheduler.step()

        running_loss += loss_value

    if epoch % 5 == 4:
        model.eval()
        preds = []
        targs = []

        # no_grad already yields tensors without autograd history, so no
        # extra .detach() is needed before moving them to the CPU.
        with torch.no_grad():
            for batch in valid_dl:
                inputs, labels = batch[0].cuda(), batch[1].cuda()
                outputs = model(inputs)
                preds.append(outputs.cpu())
                targs.append(labels.cpu())

        preds = torch.cat(preds)
        targs = torch.cat(targs)

        # sklearn metrics take (y_true, y_pred) in that order.
        hard_preds = preds.sigmoid() > 0.5
        accuracy = accuracy_score(targs, hard_preds)
        f1 = f1_score(targs, hard_preds, average='micro')
        print(f'[{epoch + 1}, {(time.time() - t0)/60:.1f}] loss: {running_loss / len(train_dl):.3f}, acc: {accuracy:.3f}, f1: {f1:.3f}')

        torch.save(model.state_dict(), f'models/{epoch+1}_lmepool_frontend_shortish_{round(f1, 2)}.pth')

[265, 12.3] loss: 0.005, acc: 0.447, f1: 0.584
[270, 24.8] loss: 0.005, acc: 0.407, f1: 0.313
[275, 37.3] loss: 0.005, acc: 0.448, f1: 0.573
[280, 49.8] loss: 0.005, acc: 0.437, f1: 0.448
[285, 62.0] loss: 0.005, acc: 0.398, f1: 0.553
[290, 74.5] loss: 0.005, acc: 0.459, f1: 0.484
[295, 86.8] loss: 0.005, acc: 0.467, f1: 0.599
[300, 99.1] loss: 0.004, acc: 0.458, f1: 0.566
[305, 111.3] loss: 0.004, acc: 0.432, f1: 0.194
[310, 123.6] loss: 0.004, acc: 0.433, f1: 0.235
[315, 136.2] loss: 0.004, acc: 0.460, f1: 0.434
[320, 148.5] loss: 0.003, acc: 0.467, f1: 0.611
[325, 161.1] loss: 0.004, acc: 0.468, f1: 0.387
[330, 173.6] loss: 0.004, acc: 0.477, f1: 0.502
[335, 186.3] loss: 0.004, acc: 0.469, f1: 0.568


In [7]:
# Restore the weights saved after epoch 335 (the filename follows the save pattern of the training loop).
model.load_state_dict(torch.load('models/335_lmepool_frontend_shortish_0.57.pth'))

<All keys matched successfully>

In [18]:
# Switch to a different checkpoint. Its filename lacks the 'shortish' tag used by the
# save calls in this notebook, so it was presumably produced by another run.
model.load_state_dict(torch.load('models/130_lmepool_frontend_0.72.pth'))

<All keys matched successfully>

In [22]:
from IPython.lib.display import FileLink

In [28]:
# Render a clickable link to the checkpoint file in the notebook output.
FileLink('models/130_lmepool_frontend_0.72.pth')

In [19]:
# bin_items comes from the birdcall star-imports. The result is used as a mapping
# (its keys are iterated as `num_specs` during evaluation) — presumably validation
# items grouped by number of spectrograms; confirm in birdcall.
val_items = bin_items(pd.read_pickle('data/val_set.pkl'), pd.read_pickle('data/classes.pkl'))

In [20]:
%%time
# Evaluate the model on every bin of validation items, collecting logits,
# targets and the names of the evaluated items.
model.eval()
# Loop-invariant: read the class list once instead of once per bin.
classes = pd.read_pickle('data/classes.pkl')
preds = []
targs = []
fns = []

for num_specs, bin_contents in val_items.items():
    valid_ds = MelspecShortishValidatioDataset(bin_contents, classes)
    # No batch_size is given, so DataLoader uses its default of 1.
    valid_dl = torch.utils.data.DataLoader(valid_ds, num_workers=NUM_WORKERS, pin_memory=True)

    fns += [item[1].name for item in valid_ds.items]

    # no_grad already yields tensors without autograd history, so no extra
    # .detach() is needed before moving them to the CPU.
    with torch.no_grad():
        for batch in valid_dl:
            inputs, labels = batch[0].cuda(), batch[1].cuda()
            outputs = model(inputs)
            preds.append(outputs.cpu())
            targs.append(labels.cpu())

preds = torch.cat(preds)
targs = torch.cat(targs)

CPU times: user 31.5 s, sys: 7.55 s, total: 39.1 s
Wall time: 35.8 s


In [21]:
# Score the collected predictions at the default 0.5 probability threshold.
# The threshold is applied once; sklearn metrics take (y_true, y_pred).
hard_preds = preds.sigmoid() > 0.5
accuracy = accuracy_score(targs, hard_preds)
f1 = f1_score(targs, hard_preds, average='micro')
accuracy, f1

(0.6055796055796056, 0.7320735179911985)

In [13]:
# accuracy = accuracy_score(preds.sigmoid() > 0.5, targs)
# f1 = f1_score(preds.sigmoid() > 0.5, targs, average='micro')
# accuracy, f1

(0.5161135161135161, 0.5741515574151557)

In [23]:
# Sweep decision thresholds from 0.4 to 1.0 in steps of 0.01 and record the
# micro-averaged F1 at each one.
probs = preds.sigmoid()  # loop-invariant, so computed once
ts = list(np.linspace(0.4, 1, 61))
# sklearn metrics take (y_true, y_pred) in that order.
f1s = [f1_score(targs, probs > t, average='micro') for t in ts]

In [27]:
# Accuracy at the threshold that maximises micro-F1 in the sweep, and that best micro-F1.
accuracy_score(preds.sigmoid() > ts[np.argmax(f1s)], targs), max(f1s)

(0.6089466089466089, 0.7345473198255963)

In [25]:
# Threshold at which the sweep's micro-F1 peaks.
ts[np.argmax(f1s)]

0.46

In [26]:
# Re-import of birdcall.metrics (it is already star-imported in the first cell).
from birdcall.metrics import *

# Presumably returns (true positives, false positives, false negatives) counts,
# judging by the name — confirm in birdcall.metrics.
preds_to_tp_fp_fn(preds, targs)

(tensor(1368), tensor(312), tensor(711))