In [None]:
import sys, os
if os.path.abspath(os.pardir) not in sys.path:
    sys.path.insert(1, os.path.abspath(os.pardir))
import CONFIG
%reload_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import librosa
import librosa.display
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import warnings
warnings.filterwarnings("ignore")

from tqdm.auto import tqdm

In [None]:
# Run configuration: data location, audio/feature parameters, batch size and RNG seed.
DATA_DIR = CONFIG.config.DATA.BASE
SR = 32000      # sample rate (Hz) every clip is resampled to by librosa.load
N_MELS = 128
BATCH_SIZE = 8
N_FFT = 2048    # FFT window length, in samples
SEED = 42

# Seed every RNG the notebook draws from so that Restart & Run All is reproducible:
# Python's `random`, NumPy's global state (which DataFrame.sample uses when no random_state
# is given), and torch (DataLoader shuffling and layer weight initialisation).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Training metadata, trimmed to the species label, the clip's file name and its duration.
train_df = pd.read_csv(os.path.join(DATA_DIR, "train.csv")).loc[:, ['ebird_code', 'filename', 'duration']]
# Test metadata and the submission template; the training cells below do not use them.
test_df = pd.read_csv(os.path.join(DATA_DIR, "test.csv"))
submission = pd.read_csv(os.path.join(DATA_DIR, "sample_submission.csv"))

In [None]:
# Absolute path of each training clip: <DATA_DIR>/train_audio/<ebird_code>/<filename>.
# A comprehension over the two columns avoids DataFrame.apply(axis=1), which builds a
# Series object for every row; the old `= ""` pre-assignment was immediately overwritten.
train_df['filepath'] = [
    os.path.join(DATA_DIR, "train_audio", code, fname)
    for code, fname in zip(train_df["ebird_code"], train_df["filename"])
]
train_df.head()

In [None]:
# Label encoding: maps each eBird species code (the `ebird_code` column of train.csv, which
# is also the train_audio/ sub-directory name) to a contiguous class index 0..263.
# These indices are the targets fed to CrossEntropyLoss and the order of the model's 264
# output logits, so entries must not be reordered or renumbered once a model is trained.
BIRD_CODE = {
    'aldfly': 0, 'ameavo': 1, 'amebit': 2, 'amecro': 3, 'amegfi': 4,
    'amekes': 5, 'amepip': 6, 'amered': 7, 'amerob': 8, 'amewig': 9,
    'amewoo': 10, 'amtspa': 11, 'annhum': 12, 'astfly': 13, 'baisan': 14,
    'baleag': 15, 'balori': 16, 'banswa': 17, 'barswa': 18, 'bawwar': 19,
    'belkin1': 20, 'belspa2': 21, 'bewwre': 22, 'bkbcuc': 23, 'bkbmag1': 24,
    'bkbwar': 25, 'bkcchi': 26, 'bkchum': 27, 'bkhgro': 28, 'bkpwar': 29,
    'bktspa': 30, 'blkpho': 31, 'blugrb1': 32, 'blujay': 33, 'bnhcow': 34,
    'boboli': 35, 'bongul': 36, 'brdowl': 37, 'brebla': 38, 'brespa': 39,
    'brncre': 40, 'brnthr': 41, 'brthum': 42, 'brwhaw': 43, 'btbwar': 44,
    'btnwar': 45, 'btywar': 46, 'buffle': 47, 'buggna': 48, 'buhvir': 49,
    'bulori': 50, 'bushti': 51, 'buwtea': 52, 'buwwar': 53, 'cacwre': 54,
    'calgul': 55, 'calqua': 56, 'camwar': 57, 'cangoo': 58, 'canwar': 59,
    'canwre': 60, 'carwre': 61, 'casfin': 62, 'caster1': 63, 'casvir': 64,
    'cedwax': 65, 'chispa': 66, 'chiswi': 67, 'chswar': 68, 'chukar': 69,
    'clanut': 70, 'cliswa': 71, 'comgol': 72, 'comgra': 73, 'comloo': 74,
    'commer': 75, 'comnig': 76, 'comrav': 77, 'comred': 78, 'comter': 79,
    'comyel': 80, 'coohaw': 81, 'coshum': 82, 'cowscj1': 83, 'daejun': 84,
    'doccor': 85, 'dowwoo': 86, 'dusfly': 87, 'eargre': 88, 'easblu': 89,
    'easkin': 90, 'easmea': 91, 'easpho': 92, 'eastow': 93, 'eawpew': 94,
    'eucdov': 95, 'eursta': 96, 'evegro': 97, 'fiespa': 98, 'fiscro': 99,
    'foxspa': 100, 'gadwal': 101, 'gcrfin': 102, 'gnttow': 103, 'gnwtea': 104,
    'gockin': 105, 'gocspa': 106, 'goleag': 107, 'grbher3': 108, 'grcfly': 109,
    'greegr': 110, 'greroa': 111, 'greyel': 112, 'grhowl': 113, 'grnher': 114,
    'grtgra': 115, 'grycat': 116, 'gryfly': 117, 'haiwoo': 118, 'hamfly': 119,
    'hergul': 120, 'herthr': 121, 'hoomer': 122, 'hoowar': 123, 'horgre': 124,
    'horlar': 125, 'houfin': 126, 'houspa': 127, 'houwre': 128, 'indbun': 129,
    'juntit1': 130, 'killde': 131, 'labwoo': 132, 'larspa': 133, 'lazbun': 134,
    'leabit': 135, 'leafly': 136, 'leasan': 137, 'lecthr': 138, 'lesgol': 139,
    'lesnig': 140, 'lesyel': 141, 'lewwoo': 142, 'linspa': 143, 'lobcur': 144,
    'lobdow': 145, 'logshr': 146, 'lotduc': 147, 'louwat': 148, 'macwar': 149,
    'magwar': 150, 'mallar3': 151, 'marwre': 152, 'merlin': 153, 'moublu': 154,
    'mouchi': 155, 'moudov': 156, 'norcar': 157, 'norfli': 158, 'norhar2': 159,
    'normoc': 160, 'norpar': 161, 'norpin': 162, 'norsho': 163, 'norwat': 164,
    'nrwswa': 165, 'nutwoo': 166, 'olsfly': 167, 'orcwar': 168, 'osprey': 169,
    'ovenbi1': 170, 'palwar': 171, 'pasfly': 172, 'pecsan': 173, 'perfal': 174,
    'phaino': 175, 'pibgre': 176, 'pilwoo': 177, 'pingro': 178, 'pinjay': 179,
    'pinsis': 180, 'pinwar': 181, 'plsvir': 182, 'prawar': 183, 'purfin': 184,
    'pygnut': 185, 'rebmer': 186, 'rebnut': 187, 'rebsap': 188, 'rebwoo': 189,
    'redcro': 190, 'redhea': 191, 'reevir1': 192, 'renpha': 193, 'reshaw': 194,
    'rethaw': 195, 'rewbla': 196, 'ribgul': 197, 'rinduc': 198, 'robgro': 199,
    'rocpig': 200, 'rocwre': 201, 'rthhum': 202, 'ruckin': 203, 'rudduc': 204,
    'rufgro': 205, 'rufhum': 206, 'rusbla': 207, 'sagspa1': 208, 'sagthr': 209,
    'savspa': 210, 'saypho': 211, 'scatan': 212, 'scoori': 213, 'semplo': 214,
    'semsan': 215, 'sheowl': 216, 'shshaw': 217, 'snobun': 218, 'snogoo': 219,
    'solsan': 220, 'sonspa': 221, 'sora': 222, 'sposan': 223, 'spotow': 224,
    'stejay': 225, 'swahaw': 226, 'swaspa': 227, 'swathr': 228, 'treswa': 229,
    'truswa': 230, 'tuftit': 231, 'tunswa': 232, 'veery': 233, 'vesspa': 234,
    'vigswa': 235, 'warvir': 236, 'wesblu': 237, 'wesgre': 238, 'weskin': 239,
    'wesmea': 240, 'wessan': 241, 'westan': 242, 'wewpew': 243, 'whbnut': 244,
    'whcspa': 245, 'whfibi': 246, 'whtspa': 247, 'whtswi': 248, 'wilfly': 249,
    'wilsni1': 250, 'wiltur': 251, 'winwre3': 252, 'wlswar': 253, 'wooduc': 254,
    'wooscj2': 255, 'woothr': 256, 'y00475': 257, 'yebfly': 258, 'yebsap': 259,
    'yehbla': 260, 'yelwar': 261, 'yerwar': 262, 'yetvir': 263
}

# Reverse lookup: class index -> eBird code (the indices above are unique, so this is lossless).
INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}

In [None]:
def get_melspectrogram(audio_path, hop_length=1024):
    """Load an audio file and return its log-scaled (dB) mel spectrogram.

    The clip is resampled to ``SR`` on load. The mel transform uses ``N_FFT``-sample
    windows and ``N_MELS`` mel bands, so editing those config constants takes effect
    (``N_MELS`` used to be ignored in favour of librosa's default).

    Args:
        audio_path: Path to an audio file that ``librosa.load`` can decode.
        hop_length: Hop between successive frames, in samples. Code that maps frames
            back to time (e.g. ``librosa.display.specshow``) must use the same value.

    Returns:
        np.ndarray of shape ``(N_MELS, n_frames)``: power in dB relative to the clip's
        loudest bin (``ref=np.max``), so every value is <= 0.
    """
    # librosa.load returns (samples, rate); the rate is SR because we requested it.
    samples, _ = librosa.load(audio_path, sr=SR)
    mel_power = librosa.feature.melspectrogram(
        y=samples, sr=SR, n_fft=N_FFT, hop_length=hop_length, n_mels=N_MELS
    )
    return librosa.power_to_db(mel_power, ref=np.max)

In [None]:
def plot_random_melmel_spec():
    """Plot the mel spectrogram of one randomly chosen training clip.

    Prints the spectrogram's shape, then draws it on a new figure with a time axis in
    seconds, a mel-scaled frequency axis, a dB colorbar, and the species/file as title.
    """
    row = train_df.sample().iloc[0]
    mel_spec = get_melspectrogram(row["filepath"])
    print(mel_spec.shape)

    # specshow defaults to sr=22050 and hop_length=512. Pass the values the spectrogram was
    # really computed with (SR, and hop 1024 as in get_melspectrogram) or the time and
    # frequency axes are mislabelled. fmax is left at its default (SR / 2), which matches
    # melspectrogram's default top band edge; the old fmax=8000 mislabelled the mel axis.
    hop_length = 1024  # must match get_melspectrogram's hop_length
    fig, ax = plt.subplots()
    img = librosa.display.specshow(
        mel_spec, sr=SR, hop_length=hop_length, x_axis='time', y_axis='mel', ax=ax
    )
    fig.colorbar(img, ax=ax, format="%+2.0f dB")
    ax.set(title=f"{row['ebird_code']} / {row['filename']}")

In [None]:
# Eyeball one random clip as a sanity check of the feature pipeline (trailing ; hides the repr).
plot_random_melmel_spec();

In [None]:
class BirdSoundDataset(Dataset):
    """Map-style dataset yielding a fixed-size mel spectrogram and class label per clip.

    Every access decodes the clip from disk via ``get_melspectrogram``; nothing is cached.

    Args:
        df: DataFrame with a ``filepath`` column (audio path) and an ``ebird_code`` column
            whose values are keys of ``BIRD_CODE``.
        spec_size: ``(height, width)`` each spectrogram is resized to with adaptive average
            pooling, so clips of any duration can be batched together.
    """

    def __init__(self, df, spec_size=(128, 512)):
        self.df = df
        self.spec_size = tuple(spec_size)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{'spec': float32 tensor (1, *spec_size), 'label': int64 scalar tensor}``."""
        row = self.df.iloc[idx]
        # as_tensor converts once; the old code re-wrapped an existing tensor with
        # torch.tensor(), an extra copy that PyTorch warns about.
        spec = torch.as_tensor(get_melspectrogram(row['filepath']), dtype=torch.float32)
        # Add a channel axis -> (1, n_mels, n_frames); the (C, H, W) input is then pooled
        # (or stretched, for short clips) to spec_size.
        spec = F.adaptive_avg_pool2d(spec.unsqueeze(dim=0), self.spec_size)
        return {
            'spec': spec,
            'label': torch.tensor(BIRD_CODE[row['ebird_code']], dtype=torch.long),
        }

In [None]:
train_dataset = BirdSoundDataset(train_df)
# Sanity-check one item: show its shape and decoded label instead of dumping a 128x512 tensor.
sample_item = train_dataset[10]
assert sample_item['spec'].shape == (1, 128, 512), sample_item['spec'].shape
sample_item['spec'].shape, INV_BIRD_CODE[sample_item['label'].item()]

In [None]:
class SimpleCNNModel(nn.Module):
    """Small CNN classifier over single-channel spectrograms.

    Two unpadded 5x5 convolutions (ReLU after each) shrink each spatial side by 8. The
    feature map is then flattened and passed through two fully connected layers, with a
    ReLU between them, to produce class logits.

    Args:
        input_size: ``(height, width)`` of the input spectrogram; determines ``fc1``'s
            input width. Defaults to the ``(128, 512)`` that ``BirdSoundDataset`` produces.
        n_classes: Number of output logits; 264 matches the 264 entries of ``BIRD_CODE``.

    Note: with the defaults, ``fc1`` alone holds 12*120*504*400 (about 290M) weights.
    """

    def __init__(self, input_size=(128, 512), n_classes=264):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        # Each unpadded 5x5 conv trims kernel_size - 1 = 4 from every spatial side.
        height, width = (side - 2 * (5 - 1) for side in input_size)
        if height < 1 or width < 1:
            raise ValueError(f"input_size {tuple(input_size)} is too small for two 5x5 convolutions")
        self.fc1 = nn.Linear(in_features=12 * height * width, out_features=400)
        self.out = nn.Linear(in_features=400, out_features=n_classes)

    def forward(self, t):
        """Map ``(batch, 1, H, W)`` (or one unbatched ``(1, H, W)``) input to ``(batch, n_classes)`` logits."""
        if t.dim() == 3:
            # A single sample as returned by BirdSoundDataset; treat it as a batch of one.
            t = t.unsqueeze(0)
        out = F.relu(self.conv1(t))
        out = F.relu(self.conv2(out))
        # flatten(1) keeps the batch axis, so a wrongly sized input fails in fc1 instead of
        # being silently re-split into a different batch size (as reshape(-1, N) did).
        out = torch.flatten(out, 1)
        # Without this non-linearity fc1 and out would compose into a single linear map.
        out = F.relu(self.fc1(out))
        return self.out(out)

In [None]:
train_dataset = BirdSoundDataset(train_df)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,  # from the config cell (previously hard-coded to 50, bypassing it)
    shuffle=True,
    num_workers=0,  # load in the main process
    drop_last=True,
)
# With drop_last=True a dataset smaller than one batch yields no batches at all.
assert len(train_dataloader) > 0, "train_dataset is smaller than one batch"

In [None]:
# Build the classifier with its default sizes. Its fc1 layer maps 12*120*504 features to
# 400 units, i.e. about 290M weights (~1.2 GB in float32), so this allocation is heavy.
model = SimpleCNNModel()

In [None]:
# Adam over every trainable tensor of the model, with a deliberately small step size.
optimizer = torch.optim.Adam(params=model.parameters(), lr=4e-5)

In [None]:
# Pull one batch to check collation; display shapes/labels instead of printing raw tensors.
for first_batch in train_dataloader:
    break
first_batch['spec'].shape, first_batch['label']

In [None]:
N_EPOCHS = 10
criterion = nn.CrossEntropyLoss()  # stateless; build once instead of on every step
device = next(model.parameters()).device  # follow the model if it was moved to a GPU

model.train()
for epoch in range(N_EPOCHS):
    running_loss = 0.0
    n_batches = 0  # deliberately not named `iter`, which would shadow the builtin iter()
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch}", leave=False)
    for batch in progress:
        specs = batch['spec'].to(device)
        labels = batch['label'].to(device)

        logits = model(specs)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
        # Running average on the progress bar instead of printing every other step.
        progress.set_postfix(avg_loss=f"{running_loss / n_batches:.4f}")

    if n_batches == 0:
        raise RuntimeError(
            "train_dataloader produced no batches; with drop_last=True the dataset "
            "must contain at least one full batch"
        )
    print(f"Epoch {epoch} Loss {running_loss / n_batches}")