In [1]:
DATA_PATH = 'E:/Mein/Учеба/Audio/'

# names of valuable files/folders
train_meta_fname = 'train.csv'
test_meta_fname = 'sample_submission.csv'
train_data_folder = 'train'
test_data_folder = 'test'

In [2]:
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchaudio
import torchvision
from torchaudio import transforms
from efficientnet_pytorch import EfficientNet
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from tqdm import tqdm

In [3]:
import random
import numpy as np

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True

In [4]:
df_train = pd.read_csv(os.path.join(DATA_PATH, train_meta_fname))
df_test = pd.read_csv(os.path.join(DATA_PATH, test_meta_fname))
df_train.head(2)

Unnamed: 0,fname,label
0,8bcbcc394ba64fe85ed4.wav,Finger_snapping
1,00d77b917e241afa06f1.wav,Squeak


In [5]:
n_classes = df_train.label.nunique()
print(n_classes)
classes_dict = {cl:i for i,cl in enumerate(df_train.label.unique())}
df_train['label_encoded'] = df_train.label.map(classes_dict)
df_train.head()

41


Unnamed: 0,fname,label,label_encoded
0,8bcbcc394ba64fe85ed4.wav,Finger_snapping,0
1,00d77b917e241afa06f1.wav,Squeak,1
2,17bb93b73b8e79234cb3.wav,Electric_piano,2
3,7d5c7a40a936136da55e.wav,Harmonica,3
4,17e0ee7565a33d6c2326.wav,Snare_drum,4


In [6]:
class BaseLineModel(nn.Module):
    
    def __init__(self, sample_rate=16000, n_classes=41):
        super().__init__()
        self.ms = torchaudio.transforms.MelSpectrogram(sample_rate)
#         self.bn1 = nn.BatchNorm2d(1)
        
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=3, padding=1)
        self.cnn3 = nn.Conv2d(in_channels=10, out_channels=3, kernel_size=3, padding=1)
        
        self.features = EfficientNet.from_pretrained('efficientnet-b0')
        # use it as features
#         for param in self.features.parameters():
#             param.requires_grad = False
            
        self.lin1 = nn.Linear(1000, 333)
        
        self.lin2 = nn.Linear(333, 111)
                
        self.lin3 = nn.Linear(111, n_classes)
        
    def forward(self, x):
        x = self.ms(x)
#         x = self.bn1(x)
                
        x = F.relu(self.cnn1(x))
        x = F.relu(self.cnn3(x))
        
        x = self.features(x)

        x = x.view(x.shape[0], -1)
        x = F.relu(x)
        
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = self.lin3(x)
        return x
    
    def inference(self, x):
        x = self.forward(x)
        x = F.softmax(x)
        return x

In [7]:
def sample_or_pad(waveform, wav_len=32000):
    m, n = waveform.shape
    if n < wav_len:
        padded_wav = torch.zeros(1, wav_len)
        padded_wav[:, :n] = waveform
        return padded_wav
    elif n > wav_len:
        offset = np.random.randint(0, n - wav_len)
        sampled_wav = waveform[:, offset:offset+wav_len]
        return sampled_wav
    else:
        return waveform
        
class EventDetectionDataset(Dataset):
    def __init__(self, data_path, x, y=None):
        self.x = x
        self.y = y
        self.data_path = data_path
    
    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        path2wav = os.path.join(self.data_path, self.x[idx])
        waveform, sample_rate = torchaudio.load(path2wav, normalization=True)
        waveform = sample_or_pad(waveform)
        if self.y is not None:
            return waveform, self.y[idx]
        return waveform

In [8]:
X_train, X_val, y_train, y_val = train_test_split(df_train.fname.values, df_train.label_encoded.values, 
                                                  test_size=0.2, random_state=42)
train_loader = DataLoader(
                        EventDetectionDataset(os.path.join(DATA_PATH, train_data_folder), X_train, y_train),
                        batch_size=41
                )
val_loader = DataLoader(
                        EventDetectionDataset(os.path.join(DATA_PATH, train_data_folder), X_val, y_val),
                        batch_size=41
                )
test_loader = DataLoader(
                        EventDetectionDataset(os.path.join(DATA_PATH, test_data_folder), df_test.fname.values, None),
                        batch_size=41, shuffle=False
                )

In [9]:
def eval_model(model, eval_dataset):
    model.eval()
    forecast, true_labs = [], []
    with torch.no_grad():
        for wavs, labs in tqdm(eval_dataset):
            wavs, labs = wavs.cuda(), labs.detach().numpy()
            true_labs.append(labs)
            outputs = model.inference(wavs)
            
            outputs = outputs.detach().cpu().numpy().argmax(axis=1)
            forecast.append(outputs)
    forecast = [x for sublist in forecast for x in sublist]
    true_labs = [x for sublist in true_labs for x in sublist]
    return f1_score(forecast, true_labs, average='macro')

In [10]:
criterion = nn.CrossEntropyLoss()
model = BaseLineModel()
model = model.cuda()
lr = 1e-3

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

Loaded pretrained weights for efficientnet-b0


In [11]:
n_epoch = 45
best_f1 = 0
for epoch in range(n_epoch):
    model.train()
    for wavs, labs in tqdm(train_loader):
        optimizer.zero_grad()
        wavs, labs = wavs.cuda(), labs.cuda()
        outputs = model(wavs)
        loss = criterion(outputs, labs)
        loss.backward()
        optimizer.step()
#     if epoch % 10 == 0:
    f1 = eval_model(model, val_loader)
    f1_train = eval_model(model, train_loader)
    print(f'epoch: {epoch}, f1_test: {f1}, f1_train: {f1_train}')
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), 'baseline_fulldiv_2.pt')
        
    lr = lr * 0.95
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

100%|██████████| 111/111 [05:00<00:00,  2.71s/it]
100%|██████████| 28/28 [01:10<00:00,  2.52s/it]
100%|██████████| 111/111 [03:46<00:00,  2.04s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 0, f1_test: 0.08672982012717403, f1_train: 0.08584106618754195


100%|██████████| 111/111 [05:11<00:00,  2.81s/it]
100%|██████████| 28/28 [01:13<00:00,  2.62s/it]
100%|██████████| 111/111 [03:49<00:00,  2.07s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 1, f1_test: 0.2243916748047431, f1_train: 0.23550236737321567


100%|██████████| 111/111 [04:26<00:00,  2.40s/it]
100%|██████████| 28/28 [01:01<00:00,  2.20s/it]
100%|██████████| 111/111 [03:46<00:00,  2.04s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 2, f1_test: 0.20631553412600312, f1_train: 0.23052652910105234


100%|██████████| 111/111 [03:50<00:00,  2.08s/it]
100%|██████████| 28/28 [00:54<00:00,  1.93s/it]
100%|██████████| 111/111 [03:04<00:00,  1.66s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 3, f1_test: 0.22903201266913606, f1_train: 0.26031545277395346


100%|██████████| 111/111 [03:37<00:00,  1.96s/it]
100%|██████████| 28/28 [00:49<00:00,  1.76s/it]
100%|██████████| 111/111 [03:26<00:00,  1.86s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 4, f1_test: 0.35486816338961247, f1_train: 0.3938568166972967


100%|██████████| 111/111 [03:35<00:00,  1.94s/it]
100%|██████████| 28/28 [00:49<00:00,  1.77s/it]
100%|██████████| 111/111 [02:56<00:00,  1.59s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 5, f1_test: 0.33311821481150206, f1_train: 0.4154129571616797


100%|██████████| 111/111 [03:45<00:00,  2.04s/it]
100%|██████████| 28/28 [00:49<00:00,  1.78s/it]
100%|██████████| 111/111 [03:10<00:00,  1.72s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 6, f1_test: 0.42413863605275265, f1_train: 0.5111541997333628


100%|██████████| 111/111 [04:03<00:00,  2.19s/it]
100%|██████████| 28/28 [00:51<00:00,  1.85s/it]
100%|██████████| 111/111 [03:03<00:00,  1.66s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 7, f1_test: 0.4529135271078379, f1_train: 0.5730540646852728


100%|██████████| 111/111 [04:08<00:00,  2.24s/it]
100%|██████████| 28/28 [00:50<00:00,  1.79s/it]
100%|██████████| 111/111 [03:11<00:00,  1.72s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 8, f1_test: 0.4936825629925663, f1_train: 0.6431187529753074


100%|██████████| 111/111 [03:36<00:00,  1.95s/it]
100%|██████████| 28/28 [01:04<00:00,  2.29s/it]
100%|██████████| 111/111 [02:58<00:00,  1.61s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 9, f1_test: 0.4038542388591014, f1_train: 0.496867134295633


100%|██████████| 111/111 [03:38<00:00,  1.97s/it]
100%|██████████| 28/28 [00:47<00:00,  1.71s/it]
100%|██████████| 111/111 [03:09<00:00,  1.71s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 10, f1_test: 0.4675401638936091, f1_train: 0.6132172885757705


100%|██████████| 111/111 [03:33<00:00,  1.93s/it]
100%|██████████| 28/28 [00:49<00:00,  1.75s/it]
100%|██████████| 111/111 [02:58<00:00,  1.61s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 11, f1_test: 0.520426769799189, f1_train: 0.6815268907463442


100%|██████████| 111/111 [03:40<00:00,  1.98s/it]
100%|██████████| 28/28 [00:49<00:00,  1.78s/it]
100%|██████████| 111/111 [03:16<00:00,  1.77s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 12, f1_test: 0.5439964549583134, f1_train: 0.7021997656523984


100%|██████████| 111/111 [04:21<00:00,  2.35s/it]
100%|██████████| 28/28 [00:59<00:00,  2.13s/it]
100%|██████████| 111/111 [03:51<00:00,  2.09s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 13, f1_test: 0.471546250828828, f1_train: 0.6013093081377464


100%|██████████| 111/111 [05:07<00:00,  2.77s/it]
100%|██████████| 28/28 [00:54<00:00,  1.93s/it]
100%|██████████| 111/111 [03:27<00:00,  1.87s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 14, f1_test: 0.5732330392920386, f1_train: 0.7372126475993634


100%|██████████| 111/111 [04:34<00:00,  2.48s/it]
100%|██████████| 28/28 [00:54<00:00,  1.95s/it]
100%|██████████| 111/111 [03:32<00:00,  1.91s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 15, f1_test: 0.5208280803383415, f1_train: 0.6738171315770931


100%|██████████| 111/111 [04:09<00:00,  2.24s/it]
100%|██████████| 28/28 [00:57<00:00,  2.05s/it]
100%|██████████| 111/111 [03:12<00:00,  1.73s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 16, f1_test: 0.4238591726575478, f1_train: 0.5695944320583652


100%|██████████| 111/111 [04:05<00:00,  2.22s/it]
100%|██████████| 28/28 [00:48<00:00,  1.72s/it]
100%|██████████| 111/111 [03:12<00:00,  1.73s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 17, f1_test: 0.4579840801880822, f1_train: 0.6349720147426362


100%|██████████| 111/111 [03:44<00:00,  2.03s/it]
100%|██████████| 28/28 [00:49<00:00,  1.75s/it]
100%|██████████| 111/111 [02:58<00:00,  1.61s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 18, f1_test: 0.4928989961800084, f1_train: 0.654218562743386


100%|██████████| 111/111 [03:36<00:00,  1.95s/it]
100%|██████████| 28/28 [00:47<00:00,  1.71s/it]
100%|██████████| 111/111 [03:08<00:00,  1.70s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 19, f1_test: 0.4077470647403362, f1_train: 0.5419211915858181


100%|██████████| 111/111 [03:30<00:00,  1.90s/it]
100%|██████████| 28/28 [00:49<00:00,  1.78s/it]
100%|██████████| 111/111 [02:54<00:00,  1.58s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 20, f1_test: 0.3134654096476439, f1_train: 0.43109494058531456


100%|██████████| 111/111 [03:38<00:00,  1.97s/it]
100%|██████████| 28/28 [00:49<00:00,  1.76s/it]
100%|██████████| 111/111 [03:20<00:00,  1.81s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 21, f1_test: 0.4279699700519587, f1_train: 0.5808301914649276


100%|██████████| 111/111 [03:49<00:00,  2.07s/it]
100%|██████████| 28/28 [00:49<00:00,  1.76s/it]
100%|██████████| 111/111 [02:55<00:00,  1.58s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 22, f1_test: 0.3908169853288906, f1_train: 0.5207418520727665


100%|██████████| 111/111 [03:43<00:00,  2.01s/it]
100%|██████████| 28/28 [00:50<00:00,  1.79s/it]
100%|██████████| 111/111 [03:16<00:00,  1.77s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 23, f1_test: 0.42784314534615514, f1_train: 0.5915147084641941


100%|██████████| 111/111 [04:20<00:00,  2.35s/it]
100%|██████████| 28/28 [00:51<00:00,  1.83s/it]
100%|██████████| 111/111 [03:09<00:00,  1.71s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 24, f1_test: 0.473787787425132, f1_train: 0.6263514806434017


100%|██████████| 111/111 [03:42<00:00,  2.00s/it]
100%|██████████| 28/28 [00:56<00:00,  2.01s/it]
100%|██████████| 111/111 [03:11<00:00,  1.73s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 25, f1_test: 0.3735565457891073, f1_train: 0.48274043844304143


100%|██████████| 111/111 [04:37<00:00,  2.50s/it]
100%|██████████| 28/28 [00:49<00:00,  1.76s/it]
100%|██████████| 111/111 [03:18<00:00,  1.79s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 26, f1_test: 0.48225267059365656, f1_train: 0.6630963157234333


100%|██████████| 111/111 [04:15<00:00,  2.30s/it]
100%|██████████| 28/28 [00:49<00:00,  1.76s/it]
100%|██████████| 111/111 [03:16<00:00,  1.77s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 27, f1_test: 0.41643731923793054, f1_train: 0.59381409564686


100%|██████████| 111/111 [04:09<00:00,  2.25s/it]
100%|██████████| 28/28 [00:55<00:00,  1.99s/it]
100%|██████████| 111/111 [03:25<00:00,  1.85s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 28, f1_test: 0.4235059434086725, f1_train: 0.5766101771707053


100%|██████████| 111/111 [04:59<00:00,  2.70s/it]
100%|██████████| 28/28 [00:54<00:00,  1.95s/it]
100%|██████████| 111/111 [03:48<00:00,  2.06s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 29, f1_test: 0.3717440256254967, f1_train: 0.510286946105459


100%|██████████| 111/111 [04:04<00:00,  2.20s/it]
100%|██████████| 28/28 [00:58<00:00,  2.11s/it]
100%|██████████| 111/111 [03:54<00:00,  2.11s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 30, f1_test: 0.5226673124114498, f1_train: 0.7043775433928576


100%|██████████| 111/111 [04:12<00:00,  2.28s/it]
100%|██████████| 28/28 [00:49<00:00,  1.78s/it]
100%|██████████| 111/111 [03:44<00:00,  2.02s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 31, f1_test: 0.3808671769851983, f1_train: 0.5489014199396945


100%|██████████| 111/111 [04:18<00:00,  2.33s/it]
100%|██████████| 28/28 [00:51<00:00,  1.82s/it]
100%|██████████| 111/111 [03:22<00:00,  1.83s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 32, f1_test: 0.4408889591375032, f1_train: 0.638177443703883


100%|██████████| 111/111 [04:24<00:00,  2.38s/it]
100%|██████████| 28/28 [00:55<00:00,  1.97s/it]
100%|██████████| 111/111 [03:25<00:00,  1.85s/it]


epoch: 33, f1_test: 0.5816243358271652, f1_train: 0.7669283478705207


100%|██████████| 111/111 [03:52<00:00,  2.10s/it]
100%|██████████| 28/28 [00:52<00:00,  1.89s/it]
100%|██████████| 111/111 [03:17<00:00,  1.77s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 34, f1_test: 0.4524324645136685, f1_train: 0.6437849761681761


100%|██████████| 111/111 [04:22<00:00,  2.36s/it]
100%|██████████| 28/28 [00:51<00:00,  1.84s/it]
100%|██████████| 111/111 [03:54<00:00,  2.11s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 35, f1_test: 0.48910224056249074, f1_train: 0.6806841068996061


100%|██████████| 111/111 [04:00<00:00,  2.17s/it]
100%|██████████| 28/28 [00:52<00:00,  1.88s/it]
100%|██████████| 111/111 [03:27<00:00,  1.87s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 36, f1_test: 0.4990988967149006, f1_train: 0.6639625127098123


100%|██████████| 111/111 [04:01<00:00,  2.18s/it]
100%|██████████| 28/28 [00:50<00:00,  1.81s/it]
100%|██████████| 111/111 [03:43<00:00,  2.01s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 37, f1_test: 0.514350258476882, f1_train: 0.7069969584300682


100%|██████████| 111/111 [03:47<00:00,  2.05s/it]
100%|██████████| 28/28 [01:10<00:00,  2.50s/it]
100%|██████████| 111/111 [03:03<00:00,  1.65s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 38, f1_test: 0.5245882795806621, f1_train: 0.699778822833345


100%|██████████| 111/111 [03:55<00:00,  2.12s/it]
100%|██████████| 28/28 [00:55<00:00,  1.98s/it]
100%|██████████| 111/111 [03:17<00:00,  1.78s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 39, f1_test: 0.4660056199636203, f1_train: 0.6400046940047314


100%|██████████| 111/111 [03:57<00:00,  2.14s/it]
100%|██████████| 28/28 [00:57<00:00,  2.06s/it]
100%|██████████| 111/111 [03:12<00:00,  1.73s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 40, f1_test: 0.4572894070740288, f1_train: 0.6560077485601234


100%|██████████| 111/111 [04:06<00:00,  2.22s/it]
100%|██████████| 28/28 [00:51<00:00,  1.85s/it]
100%|██████████| 111/111 [04:01<00:00,  2.18s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 41, f1_test: 0.5235067252838232, f1_train: 0.6909909090720606


100%|██████████| 111/111 [03:53<00:00,  2.11s/it]
100%|██████████| 28/28 [00:58<00:00,  2.08s/it]
100%|██████████| 111/111 [03:11<00:00,  1.73s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 42, f1_test: 0.5047501062627344, f1_train: 0.6794654618420454


100%|██████████| 111/111 [04:24<00:00,  2.38s/it]
100%|██████████| 28/28 [01:10<00:00,  2.51s/it]
100%|██████████| 111/111 [03:49<00:00,  2.07s/it]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 43, f1_test: 0.4821024434279351, f1_train: 0.6975675782535434


100%|██████████| 111/111 [04:14<00:00,  2.30s/it]
100%|██████████| 28/28 [00:53<00:00,  1.90s/it]
100%|██████████| 111/111 [03:17<00:00,  1.78s/it]

epoch: 44, f1_test: 0.5015752864318939, f1_train: 0.7229030860335409





In [12]:
model_name = 'baseline_fulldiv_2.pt'
model = BaseLineModel().cuda()
model.load_state_dict(torch.load(os.path.join(model_name)))
model.eval()
forecast = []
with torch.no_grad():
    for wavs in tqdm(test_loader):
        wavs = wavs.cuda()
        outputs = model.inference(wavs)
        outputs = outputs.detach().cpu().numpy().argmax(axis=1)
        forecast.append(outputs)
forecast = [x for sublist in forecast for x in sublist]
decoder = {classes_dict[cl]:cl for cl in classes_dict}
forecast = pd.Series(forecast).map(decoder)
df_test['label'] = forecast
df_test.to_csv(f'{model_name}.csv', index=None)

Loaded pretrained weights for efficientnet-b0


100%|██████████| 93/93 [02:22<00:00,  1.53s/it]
