In [1]:
from birdcall.data import *
from birdcall.metrics import *
from birdcall.ops import *

import torch
import torchvision
from torch import nn
import numpy as np
import pandas as pd
from pathlib import Path
import soundfile as sf

In [2]:
# Load the precomputed artefacts: data splits, positive / negative recording lists and
# the two class vocabularies, in the same order as they are listed below.
# NOTE(review): pd.read_pickle can execute arbitrary code — only load trusted files.
(splits,
 positive_class_items,
 negative_class_items,
 north_american_birds_common,
 all_classes) = [pd.read_pickle(pickle_path) for pickle_path in (
    'data/splits.pkl',
    'data/positive_class_items.pkl',
    'data/negative_class_items.pkl',
    'data/north_american_birds_common.pkl',
    'data/classes.pkl',
)]

In [3]:
#export data
from collections import defaultdict

def translate_class(items, old_vocab, new_vocab):
    """Re-index class labels from one vocabulary to another.

    Args:
        items: iterable of ``(cls_idx, path, duration)`` tuples where ``cls_idx``
            indexes ``old_vocab``.
        old_vocab: sequence mapping old class indices to class names.
        new_vocab: sequence of class names defining the new indices.

    Returns:
        List of ``(new_cls_idx, path, duration)`` tuples, where ``new_cls_idx`` is the
        position of the class name's first occurrence in ``new_vocab``.

    Raises:
        ValueError: if a class name is missing from ``new_vocab`` (as ``list.index`` did).
    """
    # Build the name -> index lookup once instead of scanning new_vocab for every item.
    # setdefault keeps the first occurrence, matching list.index semantics.
    new_index = {}
    for position, name in enumerate(new_vocab):
        new_index.setdefault(name, position)

    translated = []
    for cls_idx, path, duration in items:
        name = old_vocab[cls_idx]
        if name not in new_index:
            raise ValueError(f'{name!r} is not in new_vocab')
        translated.append((new_index[name], path, duration))
    return translated

class MelspecPoolDatasetNegativeClass(torch.utils.data.Dataset):
    """Training dataset that pools random mel-spectrogram crops per example.

    Each example is drawn either from the positive recordings of one vocabulary
    class (label: one-hot over the vocabulary) or from the negative recordings
    (label: all zeros).

    Args:
        items: ``(cls_idx, path, duration)`` tuples of positive recordings; ``cls_idx``
            indexes ``north_american_birds_common``.
        items_neg_class: ``(cls_idx, path, duration)`` tuples of negative recordings;
            their ``cls_idx`` is ignored.
        north_american_birds_common: class vocabulary; its length sets the one-hot size
            and the number of positive classes cycled through by ``idx``.
        len_mult: dataset length is ``len_mult * len(north_american_birds_common)``.
        specs_per_example: number of crops per example; must be divisible by 3 when
            ``reshape_to_3ch`` is True.
        reshape_to_3ch: group consecutive crops as the 3 channels of one image.
        spec_dur: crop length in seconds.
        neg_class_prob: an example is negative when ``np.random.rand() <= neg_class_prob``,
            i.e. with probability ``neg_class_prob`` (0.54 by default).
        all_classes: optional full class list, stored as ``self.all_classes``. This used
            to be read from a notebook-global ``classes``, which is undefined once this
            cell is exported to a module.
    """
    def __init__(self, items, items_neg_class, north_american_birds_common, len_mult=20,
                 specs_per_example=30, reshape_to_3ch=True, spec_dur=1.66,
                 neg_class_prob=0.54, all_classes=None):
        self.cls_idx_to_recs = defaultdict(list)
        for item in items:
            self.cls_idx_to_recs[item[0]].append(item)
        self.items = items
        self.items_neg_class = items_neg_class
        self.all_classes = all_classes
        self.vocab = north_american_birds_common
        self.specs_per_example = specs_per_example
        self.len_mult = len_mult
        self.reshape_to_3ch = reshape_to_3ch
        self.spec_dur = spec_dur
        self.neg_class_prob = neg_class_prob

    def __getitem__(self, idx):
        """Return ``(specs float32 array, one-hot label)``; ``idx`` only selects the class."""
        if np.random.rand() > self.neg_class_prob:
            cls_idx = idx % len(self.vocab)
            # .get avoids inserting empty buckets into the defaultdict
            recs = self.cls_idx_to_recs.get(cls_idx)
            if not recs:
                raise ValueError(
                    f'no positive recordings for class index {cls_idx} ({self.vocab[cls_idx]!r})')
            _, path, duration = recs[np.random.randint(0, len(recs))]
        else:
            cls_idx = -1
            _, path, duration = self.items_neg_class[np.random.randint(len(self.items_neg_class))]

        example = self.sample_specs(path, duration, self.specs_per_example)
        if self.reshape_to_3ch:
            # Stack every 3 consecutive crops as channels; spatial dims come from the
            # crops themselves instead of a hard-coded 80x212.
            example = example.reshape(-1, 3, *example.shape[1:])
        return example.astype(np.float32), self.one_hot_encode(cls_idx)

    def sample_specs(self, path, duration, count):
        """Read ``path`` and return ``count`` mel spectrograms of random ``spec_dur``-second crops.

        ``duration`` is accepted but unused. Spectrograms are stacked along axis 0.
        NOTE(review): np.tile below assumes mono (1-D) audio — TODO confirm.
        """
        audio, _ = sf.read(path)

        if audio.shape[0] < self.spec_dur*SAMPLE_RATE:
            audio = np.tile(audio, int(self.spec_dur*SAMPLE_RATE) // audio.shape[0] + 1)  # the shortest rec in the train set is 0.39 sec

        crop_len = int(self.spec_dur*SAMPLE_RATE)
        crops = []
        for _ in range(count):
            start_frame = int(np.random.rand() * (audio.shape[0] - self.spec_dur * SAMPLE_RATE))
            crops.append(audio[start_frame:start_frame+crop_len])

        return np.stack([audio_to_melspec(crop) for crop in crops])

    def show(self, idx):
        """Plot the first spectrogram of example ``idx``.

        ``plt`` is not imported in the visible cells; presumably it comes from a birdcall
        star import — TODO confirm.
        """
        spec = self[idx][0][0]
        if spec.ndim == 3:  # reshape_to_3ch: take the first channel
            spec = spec[0]
        return plt.imshow(spec)

    def one_hot_encode(self, y):
        """One-hot vector over the vocabulary; ``y == -1`` (negative class) gives all zeros."""
        one_hot = np.zeros((len(self.vocab)))
        if y != -1:
            one_hot[y] = 1
        return one_hot

    def __len__(self):
        return self.len_mult * len(self.vocab)

In [4]:
# Fold 0 of the precomputed splits: select train / validation positives by index and
# re-index their labels from the full class list to the North American vocabulary.
# NOTE(review): if positive_class_items is a list of (int, str, float) tuples, np.array
# would coerce every field to str; this relies on it converting to an object array
# (e.g. a DataFrame) — confirm.
train_items, val_items = (
    translate_class(np.array(positive_class_items)[fold_indices].tolist(),
                    all_classes, north_american_birds_common)
    for fold_indices in (splits[0][0], splits[0][1])
)
# Negative recordings carry no class; mark them with index -1.
negative_class_items = [(-1, item[1], item[2]) for item in negative_class_items]

In [5]:
# Reuse the vocabularies already loaded from data/classes.pkl and
# data/north_american_birds_common.pkl in the data-loading cell instead of re-reading them.
classes = all_classes  # alias kept for code that refers to the global name `classes`

train_ds = MelspecPoolDatasetNegativeClass(train_items, negative_class_items, north_american_birds_common, len_mult=300)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, num_workers=NUM_WORKERS, pin_memory=True, shuffle=True)

In [6]:
#export data
class MelspecShortishValidatioDatasetNegativeClass(torch.utils.data.Dataset):
    """Deterministic validation dataset over positive and negative recordings.

    Each item ``(cls_idx, path, num_specs)`` becomes ``num_specs * 3`` consecutive
    1.66 s crops taken from the start of the recording. The recording is looped if it
    is shorter than ``num_specs * 5`` seconds. ``cls_idx == -1`` yields an all-zero label.

    Args:
        items: positive ``(cls_idx, path, num_specs)`` tuples.
        vocab: class vocabulary; its length sets the one-hot size.
        negative_class_items: optional negative tuples appended after ``items``.
        reshape_to_3ch: group consecutive crops as the 3 channels of one image.
    """
    def __init__(self, items, vocab, negative_class_items=None, reshape_to_3ch=True):
        self.vocab = vocab
        if negative_class_items is None:  # avoid a shared mutable default
            negative_class_items = []
        self.items = list(items) + list(negative_class_items)
        self.reshape_to_3ch = reshape_to_3ch

    def __len__(self): return len(self.items)

    def __getitem__(self, idx):
        return self.create_example(self.items[idx])

    def create_example(self, item):
        """Return ``(specs float32 array, one-hot label)`` for one item tuple."""
        cls_idx, path, num_specs = item

        audio, _ = sf.read(path)

        example_duration = num_specs * 5 * SAMPLE_RATE
        if audio.shape[0] < example_duration:
            audio = np.tile(audio, example_duration // audio.shape[0] + 1)
        audio = audio[:example_duration]

        # 3 crops per 5 s window, laid end to end from the start of the clip
        crop_len = int(1.66*SAMPLE_RATE)
        crops = []
        for k in range(num_specs * 3):
            start_frame = int(k * 1.66 * SAMPLE_RATE)
            crops.append(audio[start_frame:start_frame+crop_len])

        specs = np.stack([audio_to_melspec(crop) for crop in crops])
        if self.reshape_to_3ch:
            specs = specs.reshape(-1, 3, *specs.shape[1:])

        one_hot = np.zeros(len(self.vocab))
        if cls_idx != -1:
            one_hot[cls_idx] = 1

        return specs.astype(np.float32), one_hot

In [7]:
#export data
def bin_items_negative_class(items):
    """Group ``(cls_idx, path, duration)`` items into bins keyed by a window count.

    Each output tuple replaces ``duration`` with the bin's count. The count is
    1, 2, 4, 6 or 10 for durations below 7.5, 12.5, 25 and 45, or otherwise.

    Returns:
        ``defaultdict(list)`` mapping count -> list of ``(cls_idx, path, count)``.
    """
    # (exclusive upper bound on duration, count); anything not below the last bound
    # (including NaN) falls through to 10.
    length_bins = ((7.5, 1), (12.5, 2), (25, 4), (45, 6))
    grouped = defaultdict(list)
    for cls_idx, path, duration in items:
        num_specs = next((count for bound, count in length_bins if duration < bound), 10)
        grouped[num_specs].append((cls_idx, path, num_specs))
    return grouped

In [8]:
val_items_binned = bin_items_negative_class(val_items)

# Draw a random subset of negatives for validation without shuffling
# `negative_class_items` in place or rebinding it. train_ds keeps a reference to that
# list, and re-running the training-loader cell should still see the full set.
VAL_NEGATIVE_SAMPLE_SIZE = 2500
val_negative_idxs = np.random.permutation(len(negative_class_items))[:VAL_NEGATIVE_SAMPLE_SIZE]
val_negative_class_items = [negative_class_items[i] for i in val_negative_idxs]
negative_class_items_binned = bin_items_negative_class(val_negative_class_items)

In [9]:
class FrontEnd(nn.Module):
    """Learnable power compression followed by per-row batch normalisation.

    Expects a 5-D input ``(batch, images, channels, rows, cols)``. ``rows`` must be 80
    to match ``BatchNorm1d(80)``. The exponent is ``sigmoid(alpha)`` (0.5 at init).
    NOTE(review): negative inputs give NaN under a fractional power; presumably the
    spectrograms are non-negative — TODO confirm.
    """
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(80, affine=False)
        # Attribute assignment registers the parameter under the same name 'alpha'
        self.alpha = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        full_shape = x.shape
        _, _, _, n_rows, n_cols = full_shape
        compressed = x ** self.alpha.sigmoid()
        # Fold batch/images/channels together so each of the 80 rows is normalised
        normalised = self.bn(compressed.view(-1, n_rows, n_cols))
        return normalised.view(full_shape)

In [10]:
class Model(nn.Module):
    """FrontEnd + ResNet34 trunk + MLP head, pooled across images with ``lme_pool``.

    Args:
        num_classes: number of output logits. Defaults to
            ``len(north_american_birds_common)`` (the notebook-global vocabulary),
            matching the original behaviour.
    """
    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = len(north_american_birds_common)
        self.frontend = FrontEnd()
        # ResNet34 with its last two children (avgpool, fc) removed
        self.cnn = nn.Sequential(*list(torchvision.models.resnet34(True).children())[:-2])
        self.classifier = nn.Sequential(*[
            nn.Linear(512, 512), nn.ReLU(), nn.Dropout(p=0.5), nn.BatchNorm1d(512),
            nn.Linear(512, 512), nn.ReLU(), nn.Dropout(p=0.5), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes)
        ])

    def forward(self, x):
        """``x``: ``(bs, im_num, ch, y_dim, x_dim)``; returns ``lme_pool`` of per-image logits."""
        bs, im_num, ch, y_dim, x_dim = x.shape
        x = self.frontend(x)
        # Run the CNN on every image independently
        x = self.cnn(x.view(-1, ch, y_dim, x_dim))
        x = x.mean((2, 3))  # average over the feature map's spatial dims
        x = self.classifier(x)
        x = x.view(bs, im_num, -1)
        # presumably pools over the im_num axis (defined in birdcall.ops) — TODO confirm
        x = lme_pool(x)
        return x

In [11]:
model = Model().to('cuda')  # batches are moved with .cuda() in the training loop

In [12]:
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score
import time

In [13]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), 1e-3)
# The training loop calls scheduler.step() after every batch, so T_max counts batches,
# not epochs. With T_max=5 the LR swung between 1e-3 and 0 every few batches. This
# spans 5 epochs of batches (assumed intent: match the 5-epoch validation interval —
# confirm).
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5 * len(train_dl))

In [None]:
t0 = time.time()
Path('models').mkdir(exist_ok=True)  # torch.save below fails if the directory is missing

# Build the validation loaders once, one per length bin (num_specs). Each combines the
# positive and sampled negative recordings of that bin, using the negative-class
# validation dataset defined in this notebook. Iterating the union of bin keys keeps
# negatives whose bin has no positive recording.
valid_dls = [
    torch.utils.data.DataLoader(
        MelspecShortishValidatioDatasetNegativeClass(
            val_items_binned.get(num_specs, []),
            north_american_birds_common,
            negative_class_items_binned.get(num_specs, []),
        ),
        batch_size=2*16, num_workers=NUM_WORKERS, pin_memory=True,
    )
    for num_specs in sorted(set(val_items_binned) | set(negative_class_items_binned))
]

for epoch in range(130):
    model.train()  # eval() is only set during validation, so once per epoch suffices
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in train_dl:
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss_value = loss.item()  # one device sync per batch
        if np.isnan(loss_value):
            print(f'!!! nan encountered in loss !!! alpha: {model.frontend.alpha.item():.4f} epoch: {epoch}\n')
        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped per batch

        running_loss += loss_value
        num_batches += 1

    if epoch % 5 == 4:
        model.eval()
        preds = []
        targs = []
        with torch.no_grad():
            for valid_dl in valid_dls:
                for inputs, labels in valid_dl:
                    preds.append(model(inputs.cuda()).cpu())
                    targs.append(labels)  # labels never need to visit the GPU

        preds = torch.cat(preds)
        targs = torch.cat(targs)

        pred_labels = preds.sigmoid() > 0.5
        accuracy = accuracy_score(targs, pred_labels)
        f1 = f1_score(targs, pred_labels, average='micro')
        # Mean loss over this epoch's batches (the old divisor len(train_dl)-1 was off by one)
        print(f'[{epoch + 1}, {(time.time() - t0)/60:.1f}] loss: {running_loss / max(num_batches, 1):.3f}, acc: {accuracy:.3f}, f1: {f1:.3f}')

        torch.save(model.state_dict(), f'models/{epoch+1}_lmepool_frontend_neg_class_refactored_{round(f1, 2)}.pth')

In [None]:
# Sweep the decision threshold on the last validation predictions, recording micro-F1.
# sklearn expects (y_true, y_pred); sigmoid is computed once rather than per threshold.
probs = preds.sigmoid()
ts = list(np.linspace(0.4, 1, 61))
f1s = [f1_score(targs, probs > t, average='micro') for t in ts]

In [None]:
# Best micro-F1 from the sweep, and accuracy at the threshold that achieved it
(max(f1s), accuracy_score(targs, preds.sigmoid() > ts[int(np.argmax(f1s))]))

In [None]:
# Threshold that maximised micro-F1 (first one on ties)
ts[int(np.argmax(f1s))]

In [None]:
from birdcall.metrics import *

# Summarise the last validation predictions with birdcall.metrics.preds_to_tp_fp_fn.
# By its name, presumably true-positive / false-positive / false-negative counts —
# confirm in birdcall.metrics.
preds_to_tp_fp_fn(preds, targs)