In [1]:
!pip install timm

Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
     |████████████████████████████████| 431 kB 929 kB/s            
Installing collected packages: timm
Successfully installed timm-0.5.4


In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import ast
import os

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from matplotlib import pyplot as plt

import librosa

import torch
import torchaudio as ta
import timm

from tqdm import tqdm

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [3]:
# Root directory of the competition input data.
data_root = '/kaggle/input/birdclef-2022'

# Resolve both CSV locations first, then read them.
train_meta_path = os.path.join(data_root, 'train_metadata.csv')
ebird_taxonomy_path = os.path.join(data_root, 'eBird_Taxonomy_v2021.csv')

train_meta = pd.read_csv(train_meta_path)
ebird_taxonomy = pd.read_csv(ebird_taxonomy_path)

In [4]:
def _parse_label_list(value):
    """Return `value` as a Python object, parsing it first if it is still a string.

    `ast.literal_eval` only accepts Python literals, so unlike `eval` it cannot
    execute arbitrary code found in the CSV. Values that are not strings are
    returned unchanged, which makes re-running this cell harmless.
    """
    return ast.literal_eval(value) if isinstance(value, str) else value


# Turn the textual `secondary_labels` column into real lists.
train_meta['secondary_labels'] = train_meta.secondary_labels.apply(_parse_label_list)
# Every label attached to a recording: the secondary labels followed by the primary one.
train_meta['target_raw'] = train_meta.secondary_labels + train_meta.primary_label.apply(lambda x: [x])

In [5]:
# Sorted vocabulary of every label found in `target_raw`.
# The set comprehension visits each label once; `target_raw.sum()` would instead
# concatenate the per-row lists pairwise (quadratic in the number of rows) and
# returns 0 rather than a list when the frame is empty.
all_species = sorted({s for labels in train_meta.target_raw for s in labels})
species2id = {s: i for i, s in enumerate(all_species)}
id2species = dict(enumerate(all_species))


def _multi_hot(labels):
    """Encode `labels` as a 0/1 list aligned with `all_species`."""
    present = set(labels)  # constant-time membership tests
    return [int(s in present) for s in all_species]


train_meta['target'] = train_meta.target_raw.apply(_multi_hot)

In [6]:
def load_wav(fname, offset, duration, sample_rate=32000):
    """Load (part of) a recording from `<data_root>/train_audio`.

    Args:
        fname: path of the audio file relative to the `train_audio` directory.
        offset: number of seconds to skip from the start of the file.
        duration: maximum number of seconds to read.
        sample_rate: rate the audio is resampled to. Defaults to 32000, the
            rate the rest of this notebook uses for padding and spectrograms.

    Returns:
        Tuple `(wav, sr)` exactly as returned by `librosa.load`.
    """
    fpath = os.path.join(data_root, 'train_audio', fname)
    # Without an explicit `sr`, librosa resamples to its 22050 Hz default, which
    # would then be treated as 32 kHz audio downstream. `offset` is now forwarded
    # too; it used to be accepted and ignored.
    wav, sr = librosa.load(fpath, sr=sample_rate, offset=offset, duration=duration)
    assert sr == sample_rate, sr
    return wav, sr

In [7]:
# duration = 30
# sample_rate = 32000

# wav, sr = load_wav('afrsil1/XC125458.ogg', 0, duration)
# to_pad = duration * sample_rate - wav.shape[0]

# if to_pad > 0:
#     wav = np.pad(wav, (0, to_pad))



### Torch Dataset

In [8]:
from torch.utils.data import Dataset, DataLoader

class BirdDataset(Dataset):
    """Dataset yielding fixed-length waveforms and multi-hot targets.

    Each row of `df` must provide a `filename` (passed to `load_wav`) and a
    `target` (a sequence of 0/1 values).

    Args:
        df: metadata frame, one row per recording.
        duration: clip length in seconds; shorter clips are zero-padded.
        sample_rate: sample rate used to convert `duration` into a sample count.
    """

    def __init__(self, df, duration=30, sample_rate=32000):
        super().__init__()
        self.df = df
        self.duration = duration
        self.sample_rate = sample_rate

    def __getitem__(self, idx):
        """Return `{'wav': float32 tensor, 'target': float32 tensor}` for row `idx`."""
        row = self.df.iloc[idx]  # look the row up once
        n_samples = int(self.duration * self.sample_rate)

        # TODO: add random offset
        wav, _ = load_wav(row['filename'], 0, self.duration)
        # Force an exact length so that samples can always be stacked into a batch.
        wav = wav[:n_samples]
        to_pad = n_samples - wav.shape[0]
        if to_pad > 0:
            wav = np.pad(wav, (0, to_pad))

        # TODO: add weighting

        return {
            'wav': torch.tensor(wav, dtype=torch.float32),
            # `dtype=float` would create a float64 tensor, which does not match
            # the float32 logits produced by the model.
            'target': torch.tensor(row['target'], dtype=torch.float32),
        }

    def __len__(self):
        """Number of recordings in the metadata frame."""
        return len(self.df)

### Model

In [9]:
class Net(torch.nn.Module):
    """Waveform classifier: mel-spectrogram (dB) -> timm backbone -> linear head.

    The loss is `BCEWithLogitsLoss`, i.e. every class is scored independently.
    """

    def __init__(self):
        super().__init__()
        self.audio2image = self._init_audio2image()
        self.backbone = self._init_backbone()
        self.head = self._init_head(self.backbone.feature_info[-1]['num_chs'])
        self.loss = torch.nn.BCEWithLogitsLoss()

    def forward(self, wav_tensor, y=None):
        """Score a batch of waveforms.

        Args:
            wav_tensor: batch of waveforms (assumed `(batch, samples)` - TODO confirm).
            y: optional targets with the same shape as the head output. When
                omitted (inference), no loss is computed.

        Returns:
            Dict with `'loss'` (a scalar tensor, or `None` when `y` is `None`)
            and `'logits'`. Note that `'logits'` holds `sigmoid(logits)`, i.e.
            probabilities; the key name is kept for backward compatibility.
        """
        spectrogram = self.audio2image(wav_tensor)
        # Swap the last two axes, then insert a singleton channel axis for the 2-D backbone.
        spectrogram = spectrogram.permute(0, 2, 1)
        spectrogram = spectrogram[:, None, :, :]
        features = self.backbone(spectrogram)
        logits = self.head(features)
        if y is None:
            loss = None
        else:
            # Align the target dtype with the logits so float64 targets cannot
            # promote the loss to double (or trip a dtype check).
            loss = self.loss(logits, y.to(logits.dtype))
        return {'loss': loss, 'logits': logits.sigmoid()}

    @staticmethod
    def _init_audio2image():
        """Build the waveform -> dB-scaled mel-spectrogram transform."""
        mel = ta.transforms.MelSpectrogram(
            sample_rate=32000,
            n_fft=2048,
            win_length=2048,
            hop_length=512,
            f_min=16,
            # NOTE(review): 16386 Hz is above the 16000 Hz Nyquist limit of 32 kHz audio - confirm.
            f_max=16386,
            pad=0,
            n_mels=256,
            power=2,
            normalized=False,
        )
        db_scale = ta.transforms.AmplitudeToDB(top_db=80.0)
        return torch.nn.Sequential(mel, db_scale)

    @staticmethod
    def _init_backbone():
        """Create a pretrained single-channel timm ResNet-18 without pooling or classifier."""
        return timm.create_model(
            "resnet18",
            pretrained=True,
            num_classes=0,   # no classification layer
            global_pool="",  # keep the spatial feature map; pooling happens in the head
            in_chans=1,
        )

    @staticmethod
    def _init_head(num_chs):
        """Build the head: global average pool -> flatten -> one logit per species.

        Args:
            num_chs: number of channels in the backbone's final feature map.
        """
        return torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d(output_size=1),
            torch.nn.Flatten(),
            torch.nn.Linear(num_chs, len(all_species)),
        )
        

### Train loop

In [10]:
# Model and its optimizer (Adam with default hyper-parameters).
model = Net()
optimizer = torch.optim.Adam(model.parameters())

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Shuffled training batches; the last incomplete batch is dropped.
train_dataset = BirdDataset(train_meta)
loader_options = dict(
    batch_size=16,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)
train_dataloader = DataLoader(train_dataset, **loader_options)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


In [11]:
def train_epoch(model, optimizer, dataloader, device):
    """Run one optimisation pass over `dataloader`.

    Args:
        model: module called as `model(wav, target)` and returning a dict with a `'loss'` entry.
        optimizer: optimizer that updates `model`'s parameters.
        dataloader: iterable of batches, each a dict with `'wav'` and `'target'` tensors.
        device: device each batch is moved to; the model must already be on it.

    Returns:
        List of per-batch loss values (Python floats), in iteration order.
    """
    # Make sure train-time behaviour (BatchNorm statistics, dropout) is active even
    # if the model was switched to eval mode earlier.
    model.train()
    loss_list = []
    for batch in tqdm(dataloader):
        wav = batch['wav'].to(device)
        target = batch['target'].to(device)
        loss = model(wav, target)['loss']
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
    return loss_list
    

In [12]:
# Training is disabled by default; set to True to run it.
RUN_TRAINING = False

if RUN_TRAINING:
    epochs = 2
    model.to(device)
    for e in range(epochs):
        epoch_loss = train_epoch(model, optimizer, train_dataloader, device)
        # train_epoch returns one loss per batch, and a list cannot be formatted
        # with `:.3f`, so report the epoch mean instead.
        print(f'{e} train loss:', f'{np.mean(epoch_loss):.3f}', sep='\t')