In [60]:
# path to your train/test/meta folders
DATA_PATH = '/content/drive/My drive/Colab Notebooks'

# names of valuable files/folders
train_meta_fname = 'train.csv'
test_meta_fname = 'sample_submission.csv'
train_data_folder = 'train'
test_data_folder = 'test'

In [72]:
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchaudio
import torchvision
from torchaudio import transforms
from efficientnet_pytorch import EfficientNet
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from tqdm import tqdm

In [73]:
# set seeds
import random
import numpy as np

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True

In [74]:
df_train = pd.read_csv(os.path.join(DATA_PATH, train_meta_fname))
df_test = pd.read_csv(os.path.join(DATA_PATH, test_meta_fname))
df_train.head(2)

Unnamed: 0,fname,label
0,8bcbcc394ba64fe85ed4.wav,Finger_snapping
1,00d77b917e241afa06f1.wav,Squeak


In [75]:
n_classes = df_train.label.nunique()
print(n_classes)
classes_dict = {cl:i for i,cl in enumerate(df_train.label.unique())}
df_train['label_encoded'] = df_train.label.map(classes_dict)
df_train.head()

41


Unnamed: 0,fname,label,label_encoded
0,8bcbcc394ba64fe85ed4.wav,Finger_snapping,0
1,00d77b917e241afa06f1.wav,Squeak,1
2,17bb93b73b8e79234cb3.wav,Electric_piano,2
3,7d5c7a40a936136da55e.wav,Harmonica,3
4,17e0ee7565a33d6c2326.wav,Snare_drum,4


In [76]:
# https://github.com/lukemelas/EfficientNet-PyTorch
class BaseLineModel(nn.Module):
    
    def __init__(self, sample_rate=16000, n_classes=41):
        super().__init__()
        self.ms = torchaudio.transforms.MelSpectrogram(sample_rate)
#         self.bn1 = nn.BatchNorm2d(1)
        
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=3, padding=1)
        self.cnn3 = nn.Conv2d(in_channels=10, out_channels=3, kernel_size=3, padding=1)
        
        self.features = EfficientNet.from_pretrained('efficientnet-b0')
        # use it as features
#         for param in self.features.parameters():
#             param.requires_grad = False
            
        self.lin1 = nn.Linear(1000, 333)
        
        self.lin2 = nn.Linear(333, 111)
                
        self.lin3 = nn.Linear(111, n_classes)
        
    def forward(self, x):
        x = self.ms(x)
#         x = self.bn1(x)
                
        x = F.relu(self.cnn1(x))
        x = F.relu(self.cnn3(x))
        
        x = self.features(x)

        x = x.view(x.shape[0], -1)
        x = F.relu(x)
        
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = self.lin3(x)
        return x
    
    def inference(self, x):
        x = self.forward(x)
        x = F.softmax(x)
        return x

In [77]:
def sample_or_pad(waveform, wav_len=32000):
    m, n = waveform.shape
    if n < wav_len:
        padded_wav = torch.zeros(1, wav_len)
        padded_wav[:, :n] = waveform
        return padded_wav
    elif n > wav_len:
        offset = np.random.randint(0, n - wav_len)
        sampled_wav = waveform[:, offset:offset+wav_len]
        return sampled_wav
    else:
        return waveform
        
class EventDetectionDataset(Dataset):
    def __init__(self, data_path, x, y=None):
        self.x = x
        self.y = y
        self.data_path = data_path
    
    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        path2wav = os.path.join(self.data_path, self.x[idx])
        waveform, sample_rate = torchaudio.load(path2wav, normalization=True)
        waveform = sample_or_pad(waveform)
        if self.y is not None:
            return waveform, self.y[idx]
        return waveform

In [78]:
X_train, X_val, y_train, y_val = train_test_split(df_train.fname.values, df_train.label_encoded.values, 
                                                  test_size=0.2, random_state=42)
train_loader = DataLoader(
                        EventDetectionDataset(os.path.join(DATA_PATH, train_data_folder), X_train, y_train),
                        batch_size=41
                )
val_loader = DataLoader(
                        EventDetectionDataset(os.path.join(DATA_PATH, train_data_folder), X_val, y_val),
                        batch_size=41
                )
test_loader = DataLoader(
                        EventDetectionDataset(os.path.join(DATA_PATH, test_data_folder), df_test.fname.values, None),
                        batch_size=41, shuffle=False
                )

In [79]:
def eval_model(model, eval_dataset):
    model.eval()
    forecast, true_labs = [], []
    with torch.no_grad():
        for wavs, labs in tqdm(eval_dataset):
            wavs, labs = wavs.cuda(), labs.detach().numpy()
            true_labs.append(labs)
            outputs = model.inference(wavs)
            
            outputs = outputs.detach().cpu().numpy().argmax(axis=1)
            forecast.append(outputs)
    forecast = [x for sublist in forecast for x in sublist]
    true_labs = [x for sublist in true_labs for x in sublist]
    return f1_score(forecast, true_labs, average='macro')

In [80]:
criterion = nn.CrossEntropyLoss()
model = BaseLineModel()
model = model.cuda()
lr = 1e-3

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

Loaded pretrained weights for efficientnet-b0


In [81]:
n_epoch = 100
best_f1 = 0
for epoch in range(n_epoch):
    model.train()
    for wavs, labs in tqdm(train_loader):
        optimizer.zero_grad()
        wavs, labs = wavs.cuda(), labs.cuda()
        outputs = model(wavs)
        loss = criterion(outputs, labs)
        loss.backward()
        optimizer.step()
#     if epoch % 10 == 0:
    f1 = eval_model(model, val_loader)
    f1_train = eval_model(model, train_loader)
    print(f'epoch: {epoch}, f1_test: {f1}, f1_train: {f1_train}')
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), '../baseline_fulldiv.pt')
        
    lr = lr * 0.95
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

100%|██████████| 111/111 [00:13<00:00,  8.28it/s]
100%|██████████| 28/28 [00:01<00:00, 16.32it/s]
100%|██████████| 111/111 [00:07<00:00, 15.62it/s]
  1%|          | 1/111 [00:00<00:14,  7.63it/s]

epoch: 0, f1_test: 0.10629352612610805, f1_train: 0.11949593809290951


100%|██████████| 111/111 [00:13<00:00,  8.29it/s]
100%|██████████| 28/28 [00:01<00:00, 15.62it/s]
100%|██████████| 111/111 [00:07<00:00, 15.62it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 1, f1_test: 0.2285954430090747, f1_train: 0.2344379529881226


100%|██████████| 111/111 [00:13<00:00,  7.95it/s]
100%|██████████| 28/28 [00:01<00:00, 16.20it/s]
100%|██████████| 111/111 [00:06<00:00, 15.89it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 2, f1_test: 0.25355097340148225, f1_train: 0.29173411590335957


100%|██████████| 111/111 [00:13<00:00,  8.11it/s]
100%|██████████| 28/28 [00:01<00:00, 16.18it/s]
100%|██████████| 111/111 [00:07<00:00, 15.45it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 3, f1_test: 0.3561665093600878, f1_train: 0.4097082985139217


100%|██████████| 111/111 [00:13<00:00,  8.31it/s]
100%|██████████| 28/28 [00:01<00:00, 16.06it/s]
100%|██████████| 111/111 [00:07<00:00, 15.54it/s]
  1%|          | 1/111 [00:00<00:14,  7.59it/s]

epoch: 4, f1_test: 0.3494387466921597, f1_train: 0.45181671057461237


100%|██████████| 111/111 [00:13<00:00,  8.18it/s]
100%|██████████| 28/28 [00:01<00:00, 15.34it/s]
100%|██████████| 111/111 [00:07<00:00, 15.62it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 5, f1_test: 0.4328167708131114, f1_train: 0.5271681413449881


100%|██████████| 111/111 [00:13<00:00,  8.18it/s]
100%|██████████| 28/28 [00:01<00:00, 17.17it/s]
100%|██████████| 111/111 [00:07<00:00, 14.95it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 6, f1_test: 0.4746482252679639, f1_train: 0.5223790932627627


100%|██████████| 111/111 [00:13<00:00,  8.22it/s]
100%|██████████| 28/28 [00:01<00:00, 15.77it/s]
100%|██████████| 111/111 [00:07<00:00, 15.53it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 7, f1_test: 0.5529561680590374, f1_train: 0.6627944015908095


100%|██████████| 111/111 [00:13<00:00,  8.09it/s]
100%|██████████| 28/28 [00:01<00:00, 16.57it/s]
100%|██████████| 111/111 [00:07<00:00, 15.58it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 8, f1_test: 0.5610194290738628, f1_train: 0.696918293914471


100%|██████████| 111/111 [00:13<00:00,  8.27it/s]
100%|██████████| 28/28 [00:01<00:00, 16.09it/s]
100%|██████████| 111/111 [00:07<00:00, 15.43it/s]
  1%|          | 1/111 [00:00<00:14,  7.62it/s]

epoch: 9, f1_test: 0.48521044792437856, f1_train: 0.6039715630375571


100%|██████████| 111/111 [00:13<00:00,  8.16it/s]
100%|██████████| 28/28 [00:01<00:00, 15.54it/s]
100%|██████████| 111/111 [00:06<00:00, 16.05it/s]
  1%|          | 1/111 [00:00<00:14,  7.65it/s]

epoch: 10, f1_test: 0.53396085877329, f1_train: 0.690427261583096


100%|██████████| 111/111 [00:13<00:00,  8.14it/s]
100%|██████████| 28/28 [00:01<00:00, 17.08it/s]
100%|██████████| 111/111 [00:07<00:00, 15.78it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 11, f1_test: 0.5996557969970892, f1_train: 0.7674035868883995


100%|██████████| 111/111 [00:13<00:00,  8.29it/s]
100%|██████████| 28/28 [00:01<00:00, 16.29it/s]
100%|██████████| 111/111 [00:07<00:00, 15.23it/s]
  1%|          | 1/111 [00:00<00:19,  5.71it/s]

epoch: 12, f1_test: 0.5317764072598671, f1_train: 0.6863034606639081


100%|██████████| 111/111 [00:13<00:00,  8.32it/s]
100%|██████████| 28/28 [00:01<00:00, 15.80it/s]
100%|██████████| 111/111 [00:07<00:00, 15.14it/s]
  1%|          | 1/111 [00:00<00:14,  7.63it/s]

epoch: 13, f1_test: 0.5256393615720969, f1_train: 0.6747247578159984


100%|██████████| 111/111 [00:13<00:00,  8.24it/s]
100%|██████████| 28/28 [00:01<00:00, 16.77it/s]
100%|██████████| 111/111 [00:07<00:00, 15.23it/s]
  1%|          | 1/111 [00:00<00:14,  7.63it/s]

epoch: 14, f1_test: 0.5418535358430485, f1_train: 0.6913837244584565


100%|██████████| 111/111 [00:13<00:00,  8.15it/s]
100%|██████████| 28/28 [00:01<00:00, 15.75it/s]
100%|██████████| 111/111 [00:07<00:00, 15.53it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 15, f1_test: 0.6226587230729366, f1_train: 0.8014973351379933


100%|██████████| 111/111 [00:13<00:00,  8.02it/s]
100%|██████████| 28/28 [00:01<00:00, 14.33it/s]
100%|██████████| 111/111 [00:07<00:00, 15.21it/s]
  1%|          | 1/111 [00:00<00:14,  7.69it/s]

epoch: 16, f1_test: 0.42308197459757096, f1_train: 0.5504070654979158


100%|██████████| 111/111 [00:13<00:00,  8.32it/s]
100%|██████████| 28/28 [00:01<00:00, 17.15it/s]
100%|██████████| 111/111 [00:07<00:00, 15.73it/s]
  1%|          | 1/111 [00:00<00:14,  7.58it/s]

epoch: 17, f1_test: 0.49867998540384506, f1_train: 0.658857005069475


100%|██████████| 111/111 [00:13<00:00,  8.20it/s]
100%|██████████| 28/28 [00:01<00:00, 17.15it/s]
100%|██████████| 111/111 [00:07<00:00, 15.66it/s]
  1%|          | 1/111 [00:00<00:14,  7.53it/s]

epoch: 18, f1_test: 0.5579497821999947, f1_train: 0.7319272730225758


100%|██████████| 111/111 [00:13<00:00,  8.26it/s]
100%|██████████| 28/28 [00:01<00:00, 17.11it/s]
100%|██████████| 111/111 [00:06<00:00, 16.30it/s]
  1%|          | 1/111 [00:00<00:15,  7.06it/s]

epoch: 19, f1_test: 0.5571325006232938, f1_train: 0.6957903774635286


100%|██████████| 111/111 [00:13<00:00,  8.15it/s]
100%|██████████| 28/28 [00:01<00:00, 16.62it/s]
100%|██████████| 111/111 [00:07<00:00, 15.64it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 20, f1_test: 0.46955998060031356, f1_train: 0.6558132821268612


100%|██████████| 111/111 [00:14<00:00,  7.49it/s]
100%|██████████| 28/28 [00:01<00:00, 14.14it/s]
100%|██████████| 111/111 [00:08<00:00, 13.50it/s]
  1%|          | 1/111 [00:00<00:15,  7.01it/s]

epoch: 21, f1_test: 0.48981598954550704, f1_train: 0.650733932483711


100%|██████████| 111/111 [00:13<00:00,  8.19it/s]
100%|██████████| 28/28 [00:01<00:00, 17.03it/s]
100%|██████████| 111/111 [00:07<00:00, 14.33it/s]
  1%|          | 1/111 [00:00<00:14,  7.53it/s]

epoch: 22, f1_test: 0.48420795722671806, f1_train: 0.6607939313215059


100%|██████████| 111/111 [00:14<00:00,  7.80it/s]
100%|██████████| 28/28 [00:01<00:00, 15.58it/s]
100%|██████████| 111/111 [00:07<00:00, 14.05it/s]
  1%|          | 1/111 [00:00<00:14,  7.66it/s]

epoch: 23, f1_test: 0.4508735701586984, f1_train: 0.5896833421726222


100%|██████████| 111/111 [00:13<00:00,  8.09it/s]
100%|██████████| 28/28 [00:01<00:00, 17.32it/s]
100%|██████████| 111/111 [00:07<00:00, 14.76it/s]
  1%|          | 1/111 [00:00<00:14,  7.58it/s]

epoch: 24, f1_test: 0.571814169964218, f1_train: 0.7611057880835792


100%|██████████| 111/111 [00:13<00:00,  8.04it/s]
100%|██████████| 28/28 [00:01<00:00, 16.14it/s]
100%|██████████| 111/111 [00:07<00:00, 14.59it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 25, f1_test: 0.5233877218544075, f1_train: 0.7131283965364282


100%|██████████| 111/111 [00:14<00:00,  7.49it/s]
100%|██████████| 28/28 [00:01<00:00, 15.80it/s]
100%|██████████| 111/111 [00:07<00:00, 13.96it/s]
  1%|          | 1/111 [00:00<00:14,  7.64it/s]

epoch: 26, f1_test: 0.5341731840891024, f1_train: 0.7076250982567919


100%|██████████| 111/111 [00:14<00:00,  7.90it/s]
100%|██████████| 28/28 [00:01<00:00, 15.27it/s]
100%|██████████| 111/111 [00:07<00:00, 15.26it/s]
  1%|          | 1/111 [00:00<00:19,  5.65it/s]

epoch: 27, f1_test: 0.5473086976098137, f1_train: 0.7375620636068481


100%|██████████| 111/111 [00:14<00:00,  7.83it/s]
100%|██████████| 28/28 [00:01<00:00, 14.75it/s]
100%|██████████| 111/111 [00:07<00:00, 14.54it/s]
  1%|          | 1/111 [00:00<00:16,  6.86it/s]

epoch: 28, f1_test: 0.5083482172203078, f1_train: 0.6905922529062383


100%|██████████| 111/111 [00:14<00:00,  7.76it/s]
100%|██████████| 28/28 [00:02<00:00, 13.70it/s]
100%|██████████| 111/111 [00:08<00:00, 13.54it/s]
  1%|          | 1/111 [00:00<00:14,  7.66it/s]

epoch: 29, f1_test: 0.5513084368952619, f1_train: 0.7726534498333878


100%|██████████| 111/111 [00:14<00:00,  7.77it/s]
100%|██████████| 28/28 [00:01<00:00, 15.59it/s]
100%|██████████| 111/111 [00:08<00:00, 13.79it/s]
  1%|          | 1/111 [00:00<00:16,  6.82it/s]

epoch: 30, f1_test: 0.594515325055727, f1_train: 0.8269128506549175


100%|██████████| 111/111 [00:14<00:00,  7.64it/s]
100%|██████████| 28/28 [00:02<00:00, 12.37it/s]
100%|██████████| 111/111 [00:07<00:00, 14.09it/s]
  1%|          | 1/111 [00:00<00:20,  5.30it/s]

epoch: 31, f1_test: 0.5583317614606507, f1_train: 0.7563482798684882


100%|██████████| 111/111 [00:14<00:00,  7.63it/s]
100%|██████████| 28/28 [00:01<00:00, 14.78it/s]
100%|██████████| 111/111 [00:07<00:00, 14.12it/s]
  1%|          | 1/111 [00:00<00:15,  6.97it/s]

epoch: 32, f1_test: 0.5941509798244281, f1_train: 0.8002843711761052


100%|██████████| 111/111 [00:14<00:00,  7.52it/s]
100%|██████████| 28/28 [00:02<00:00, 13.02it/s]
100%|██████████| 111/111 [00:08<00:00, 13.76it/s]
  1%|          | 1/111 [00:00<00:14,  7.63it/s]

epoch: 33, f1_test: 0.5140710680051351, f1_train: 0.6838671653410366


100%|██████████| 111/111 [00:14<00:00,  7.71it/s]
100%|██████████| 28/28 [00:01<00:00, 15.58it/s]
100%|██████████| 111/111 [00:08<00:00, 13.61it/s]
  1%|          | 1/111 [00:00<00:18,  5.89it/s]

epoch: 34, f1_test: 0.5970381640976719, f1_train: 0.8124793248735787


100%|██████████| 111/111 [00:14<00:00,  7.67it/s]
100%|██████████| 28/28 [00:01<00:00, 15.37it/s]
100%|██████████| 111/111 [00:07<00:00, 14.35it/s]
  1%|          | 1/111 [00:00<00:15,  6.95it/s]

epoch: 35, f1_test: 0.5479606898586282, f1_train: 0.7397387901488383


100%|██████████| 111/111 [00:14<00:00,  7.71it/s]
100%|██████████| 28/28 [00:02<00:00, 13.28it/s]
100%|██████████| 111/111 [00:07<00:00, 14.02it/s]
  1%|          | 1/111 [00:00<00:14,  7.54it/s]

epoch: 36, f1_test: 0.6114975741330437, f1_train: 0.8352293229768932


100%|██████████| 111/111 [00:14<00:00,  7.61it/s]
100%|██████████| 28/28 [00:01<00:00, 14.37it/s]
100%|██████████| 111/111 [00:07<00:00, 14.05it/s]
  1%|          | 1/111 [00:00<00:17,  6.18it/s]

epoch: 37, f1_test: 0.502708376143018, f1_train: 0.6816692501871874


100%|██████████| 111/111 [00:14<00:00,  7.85it/s]
100%|██████████| 28/28 [00:01<00:00, 16.58it/s]
100%|██████████| 111/111 [00:07<00:00, 14.10it/s]
  1%|          | 1/111 [00:00<00:17,  6.12it/s]

epoch: 38, f1_test: 0.5708448016964366, f1_train: 0.7498244171365385


100%|██████████| 111/111 [00:14<00:00,  7.84it/s]
100%|██████████| 28/28 [00:01<00:00, 15.39it/s]
100%|██████████| 111/111 [00:08<00:00, 13.47it/s]
  1%|          | 1/111 [00:00<00:15,  7.26it/s]

epoch: 39, f1_test: 0.5879043397739223, f1_train: 0.7929470964158711


100%|██████████| 111/111 [00:14<00:00,  7.78it/s]
100%|██████████| 28/28 [00:01<00:00, 15.33it/s]
100%|██████████| 111/111 [00:08<00:00, 13.75it/s]
  1%|          | 1/111 [00:00<00:15,  7.11it/s]

epoch: 40, f1_test: 0.5062550721992987, f1_train: 0.7154368513125894


100%|██████████| 111/111 [00:14<00:00,  7.53it/s]
100%|██████████| 28/28 [00:02<00:00, 13.71it/s]
100%|██████████| 111/111 [00:08<00:00, 13.63it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 41, f1_test: 0.6252973186176202, f1_train: 0.8505659057903101


100%|██████████| 111/111 [00:14<00:00,  7.82it/s]
100%|██████████| 28/28 [00:02<00:00, 13.66it/s]
100%|██████████| 111/111 [00:08<00:00, 13.82it/s]
  1%|          | 1/111 [00:00<00:16,  6.87it/s]

epoch: 42, f1_test: 0.6111373805075606, f1_train: 0.7887227495054853


100%|██████████| 111/111 [00:14<00:00,  7.73it/s]
100%|██████████| 28/28 [00:02<00:00, 13.23it/s]
100%|██████████| 111/111 [00:07<00:00, 14.82it/s]
  1%|          | 1/111 [00:00<00:16,  6.74it/s]

epoch: 43, f1_test: 0.5286643820007692, f1_train: 0.7337212960601662


100%|██████████| 111/111 [00:15<00:00,  7.30it/s]
100%|██████████| 28/28 [00:01<00:00, 14.39it/s]
100%|██████████| 111/111 [00:08<00:00, 13.76it/s]
  1%|          | 1/111 [00:00<00:14,  7.42it/s]

epoch: 44, f1_test: 0.5956040081370013, f1_train: 0.8300039085580488


100%|██████████| 111/111 [00:14<00:00,  7.83it/s]
100%|██████████| 28/28 [00:01<00:00, 15.77it/s]
100%|██████████| 111/111 [00:07<00:00, 15.53it/s]
  1%|          | 1/111 [00:00<00:14,  7.54it/s]

epoch: 45, f1_test: 0.5644468974243524, f1_train: 0.7662344297044733


100%|██████████| 111/111 [00:13<00:00,  8.08it/s]
100%|██████████| 28/28 [00:01<00:00, 15.67it/s]
100%|██████████| 111/111 [00:07<00:00, 15.71it/s]
  1%|          | 1/111 [00:00<00:14,  7.59it/s]

epoch: 46, f1_test: 0.5497737424438712, f1_train: 0.7428294431954231


100%|██████████| 111/111 [00:13<00:00,  8.03it/s]
100%|██████████| 28/28 [00:01<00:00, 17.19it/s]
100%|██████████| 111/111 [00:07<00:00, 15.10it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 47, f1_test: 0.667015937016888, f1_train: 0.9370448561961691


100%|██████████| 111/111 [00:13<00:00,  8.03it/s]
100%|██████████| 28/28 [00:01<00:00, 16.56it/s]
100%|██████████| 111/111 [00:07<00:00, 14.86it/s]
  1%|          | 1/111 [00:00<00:14,  7.54it/s]

epoch: 48, f1_test: 0.59345882150071, f1_train: 0.8171683320930807


100%|██████████| 111/111 [00:13<00:00,  8.17it/s]
100%|██████████| 28/28 [00:01<00:00, 16.25it/s]
100%|██████████| 111/111 [00:07<00:00, 15.16it/s]
  1%|          | 1/111 [00:00<00:14,  7.59it/s]

epoch: 49, f1_test: 0.5879719740163387, f1_train: 0.8222908617749788


100%|██████████| 111/111 [00:13<00:00,  8.12it/s]
100%|██████████| 28/28 [00:01<00:00, 15.42it/s]
100%|██████████| 111/111 [00:07<00:00, 15.07it/s]
  1%|          | 1/111 [00:00<00:14,  7.52it/s]

epoch: 50, f1_test: 0.5819638496946737, f1_train: 0.831340747459372


100%|██████████| 111/111 [00:14<00:00,  7.86it/s]
100%|██████████| 28/28 [00:01<00:00, 17.25it/s]
100%|██████████| 111/111 [00:07<00:00, 15.11it/s]
  1%|          | 1/111 [00:00<00:14,  7.67it/s]

epoch: 51, f1_test: 0.5602648771549301, f1_train: 0.7694531491527006


100%|██████████| 111/111 [00:13<00:00,  8.00it/s]
100%|██████████| 28/28 [00:01<00:00, 15.81it/s]
100%|██████████| 111/111 [00:07<00:00, 15.00it/s]
  1%|          | 1/111 [00:00<00:14,  7.64it/s]

epoch: 52, f1_test: 0.6540697523698289, f1_train: 0.8966780785843255


100%|██████████| 111/111 [00:13<00:00,  8.24it/s]
100%|██████████| 28/28 [00:01<00:00, 15.95it/s]
100%|██████████| 111/111 [00:07<00:00, 15.47it/s]
  1%|          | 1/111 [00:00<00:14,  7.60it/s]

epoch: 53, f1_test: 0.5979072065353643, f1_train: 0.8376848861974048


100%|██████████| 111/111 [00:13<00:00,  8.07it/s]
100%|██████████| 28/28 [00:01<00:00, 15.42it/s]
100%|██████████| 111/111 [00:07<00:00, 15.44it/s]
  1%|          | 1/111 [00:00<00:14,  7.67it/s]

epoch: 54, f1_test: 0.582892471187802, f1_train: 0.8061160909971169


100%|██████████| 111/111 [00:13<00:00,  8.10it/s]
100%|██████████| 28/28 [00:01<00:00, 14.75it/s]
100%|██████████| 111/111 [00:06<00:00, 16.05it/s]
  1%|          | 1/111 [00:00<00:14,  7.57it/s]

epoch: 55, f1_test: 0.5926460967201866, f1_train: 0.8180962275848912


100%|██████████| 111/111 [00:13<00:00,  8.07it/s]
100%|██████████| 28/28 [00:01<00:00, 15.94it/s]
100%|██████████| 111/111 [00:07<00:00, 14.78it/s]
  1%|          | 1/111 [00:00<00:14,  7.60it/s]

epoch: 56, f1_test: 0.5866476535193236, f1_train: 0.8027100360425875


100%|██████████| 111/111 [00:13<00:00,  8.28it/s]
100%|██████████| 28/28 [00:01<00:00, 16.94it/s]
100%|██████████| 111/111 [00:07<00:00, 14.82it/s]
  1%|          | 1/111 [00:00<00:14,  7.54it/s]

epoch: 57, f1_test: 0.5576817707664975, f1_train: 0.7627134282126853


100%|██████████| 111/111 [00:13<00:00,  8.14it/s]
100%|██████████| 28/28 [00:01<00:00, 16.07it/s]
100%|██████████| 111/111 [00:07<00:00, 14.76it/s]
  1%|          | 1/111 [00:00<00:14,  7.59it/s]

epoch: 58, f1_test: 0.6403069437722618, f1_train: 0.8869356171278581


100%|██████████| 111/111 [00:13<00:00,  8.24it/s]
100%|██████████| 28/28 [00:01<00:00, 16.50it/s]
100%|██████████| 111/111 [00:07<00:00, 14.80it/s]
  1%|          | 1/111 [00:00<00:14,  7.64it/s]

epoch: 59, f1_test: 0.6233325155086638, f1_train: 0.8314687704112131


100%|██████████| 111/111 [00:13<00:00,  8.23it/s]
100%|██████████| 28/28 [00:01<00:00, 14.48it/s]
100%|██████████| 111/111 [00:07<00:00, 15.82it/s]
  1%|          | 1/111 [00:00<00:14,  7.57it/s]

epoch: 60, f1_test: 0.6111797245309151, f1_train: 0.8500943930465292


100%|██████████| 111/111 [00:14<00:00,  7.67it/s]
100%|██████████| 28/28 [00:01<00:00, 14.94it/s]
100%|██████████| 111/111 [00:07<00:00, 15.44it/s]
  1%|          | 1/111 [00:00<00:14,  7.40it/s]

epoch: 61, f1_test: 0.5841503500434735, f1_train: 0.8152748429040421


100%|██████████| 111/111 [00:13<00:00,  8.18it/s]
100%|██████████| 28/28 [00:01<00:00, 16.87it/s]
100%|██████████| 111/111 [00:07<00:00, 15.71it/s]
  1%|          | 1/111 [00:00<00:14,  7.67it/s]

epoch: 62, f1_test: 0.6490465731669248, f1_train: 0.8738844130944925


100%|██████████| 111/111 [00:14<00:00,  7.83it/s]
100%|██████████| 28/28 [00:01<00:00, 16.52it/s]
100%|██████████| 111/111 [00:07<00:00, 15.29it/s]
  1%|          | 1/111 [00:00<00:14,  7.58it/s]

epoch: 63, f1_test: 0.5980012113678181, f1_train: 0.819128921017111


100%|██████████| 111/111 [00:13<00:00,  8.10it/s]
100%|██████████| 28/28 [00:01<00:00, 15.58it/s]
100%|██████████| 111/111 [00:07<00:00, 15.53it/s]
  1%|          | 1/111 [00:00<00:14,  7.61it/s]

epoch: 64, f1_test: 0.5753841921511882, f1_train: 0.7936820605801598


100%|██████████| 111/111 [00:13<00:00,  8.12it/s]
100%|██████████| 28/28 [00:01<00:00, 16.69it/s]
100%|██████████| 111/111 [00:07<00:00, 15.76it/s]
  1%|          | 1/111 [00:00<00:14,  7.59it/s]

epoch: 65, f1_test: 0.5669555213556223, f1_train: 0.7864931656562749


100%|██████████| 111/111 [00:14<00:00,  7.74it/s]
100%|██████████| 28/28 [00:01<00:00, 16.60it/s]
100%|██████████| 111/111 [00:07<00:00, 15.34it/s]
  1%|          | 1/111 [00:00<00:14,  7.55it/s]

epoch: 66, f1_test: 0.5758809289694308, f1_train: 0.7894801496741621


100%|██████████| 111/111 [00:13<00:00,  8.02it/s]
100%|██████████| 28/28 [00:01<00:00, 16.98it/s]
100%|██████████| 111/111 [00:07<00:00, 15.72it/s]
  1%|          | 1/111 [00:00<00:14,  7.60it/s]

epoch: 67, f1_test: 0.5793752970296472, f1_train: 0.7995635484922041


100%|██████████| 111/111 [00:13<00:00,  7.98it/s]
100%|██████████| 28/28 [00:01<00:00, 16.83it/s]
100%|██████████| 111/111 [00:06<00:00, 15.89it/s]
  1%|          | 1/111 [00:00<00:14,  7.58it/s]

epoch: 68, f1_test: 0.6533397293366141, f1_train: 0.8787712033238771


100%|██████████| 111/111 [00:13<00:00,  8.19it/s]
100%|██████████| 28/28 [00:01<00:00, 16.79it/s]
100%|██████████| 111/111 [00:07<00:00, 15.72it/s]
  1%|          | 1/111 [00:00<00:14,  7.60it/s]

epoch: 69, f1_test: 0.6143461645590067, f1_train: 0.8316362203853974


100%|██████████| 111/111 [00:13<00:00,  7.95it/s]
100%|██████████| 28/28 [00:01<00:00, 15.96it/s]
100%|██████████| 111/111 [00:06<00:00, 15.87it/s]
  1%|          | 1/111 [00:00<00:14,  7.41it/s]

epoch: 70, f1_test: 0.5672506727198753, f1_train: 0.7680698174491756


100%|██████████| 111/111 [00:14<00:00,  7.93it/s]
100%|██████████| 28/28 [00:01<00:00, 16.55it/s]
100%|██████████| 111/111 [00:07<00:00, 15.33it/s]
  1%|          | 1/111 [00:00<00:18,  5.93it/s]

epoch: 71, f1_test: 0.6117411214544367, f1_train: 0.8468875587473783


100%|██████████| 111/111 [00:15<00:00,  7.34it/s]
100%|██████████| 28/28 [00:02<00:00, 12.51it/s]
100%|██████████| 111/111 [00:08<00:00, 13.77it/s]
  1%|          | 1/111 [00:00<00:15,  6.97it/s]

epoch: 72, f1_test: 0.5903563306658821, f1_train: 0.8322547375454714


100%|██████████| 111/111 [00:14<00:00,  7.60it/s]
100%|██████████| 28/28 [00:01<00:00, 17.11it/s]
100%|██████████| 111/111 [00:07<00:00, 15.19it/s]
  1%|          | 1/111 [00:00<00:14,  7.48it/s]

epoch: 73, f1_test: 0.5746663070254596, f1_train: 0.7936453581633662


100%|██████████| 111/111 [00:13<00:00,  8.37it/s]
100%|██████████| 28/28 [00:01<00:00, 15.35it/s]
100%|██████████| 111/111 [00:07<00:00, 15.15it/s]
  1%|          | 1/111 [00:00<00:14,  7.56it/s]

epoch: 74, f1_test: 0.5964844261354332, f1_train: 0.8198766081960296


100%|██████████| 111/111 [00:13<00:00,  8.18it/s]
100%|██████████| 28/28 [00:01<00:00, 17.03it/s]
100%|██████████| 111/111 [00:07<00:00, 15.46it/s]
  1%|          | 1/111 [00:00<00:14,  7.58it/s]

epoch: 75, f1_test: 0.5799678401610108, f1_train: 0.7962146896867508


100%|██████████| 111/111 [00:13<00:00,  8.21it/s]
100%|██████████| 28/28 [00:01<00:00, 16.83it/s]
100%|██████████| 111/111 [00:06<00:00, 16.15it/s]
  1%|          | 1/111 [00:00<00:14,  7.55it/s]

epoch: 76, f1_test: 0.5983982955591436, f1_train: 0.8288971403228306


100%|██████████| 111/111 [00:14<00:00,  7.72it/s]
100%|██████████| 28/28 [00:01<00:00, 14.00it/s]
100%|██████████| 111/111 [00:09<00:00, 11.60it/s]
  1%|          | 1/111 [00:00<00:15,  7.24it/s]

epoch: 77, f1_test: 0.5920983485153527, f1_train: 0.8340713921309678


100%|██████████| 111/111 [00:15<00:00,  7.25it/s]
100%|██████████| 28/28 [00:01<00:00, 15.64it/s]
100%|██████████| 111/111 [00:06<00:00, 15.89it/s]
  1%|          | 1/111 [00:00<00:14,  7.57it/s]

epoch: 78, f1_test: 0.6151718495229238, f1_train: 0.8463174368624474


100%|██████████| 111/111 [00:13<00:00,  8.24it/s]
100%|██████████| 28/28 [00:01<00:00, 16.83it/s]
100%|██████████| 111/111 [00:07<00:00, 14.80it/s]
  1%|          | 1/111 [00:00<00:14,  7.62it/s]

epoch: 79, f1_test: 0.5680156569605705, f1_train: 0.8037536913464516


100%|██████████| 111/111 [00:13<00:00,  8.16it/s]
100%|██████████| 28/28 [00:01<00:00, 16.03it/s]
100%|██████████| 111/111 [00:07<00:00, 14.81it/s]
  1%|          | 1/111 [00:00<00:14,  7.65it/s]

epoch: 80, f1_test: 0.5953446185682619, f1_train: 0.8253539916686162


100%|██████████| 111/111 [00:13<00:00,  8.31it/s]
100%|██████████| 28/28 [00:01<00:00, 15.71it/s]
100%|██████████| 111/111 [00:06<00:00, 15.86it/s]
  1%|          | 1/111 [00:00<00:14,  7.63it/s]

epoch: 81, f1_test: 0.5761642475136113, f1_train: 0.7977725837347271


100%|██████████| 111/111 [00:13<00:00,  8.04it/s]
100%|██████████| 28/28 [00:01<00:00, 15.17it/s]
100%|██████████| 111/111 [00:08<00:00, 12.57it/s]
  1%|          | 1/111 [00:00<00:15,  7.22it/s]

epoch: 82, f1_test: 0.5961413184294474, f1_train: 0.8240459726751141


100%|██████████| 111/111 [00:14<00:00,  7.55it/s]
100%|██████████| 28/28 [00:01<00:00, 16.87it/s]
100%|██████████| 111/111 [00:07<00:00, 15.43it/s]
  1%|          | 1/111 [00:00<00:14,  7.62it/s]

epoch: 83, f1_test: 0.5741595580947121, f1_train: 0.7943838298842252


100%|██████████| 111/111 [00:13<00:00,  8.19it/s]
100%|██████████| 28/28 [00:01<00:00, 16.74it/s]
100%|██████████| 111/111 [00:07<00:00, 14.92it/s]
  0%|          | 0/111 [00:00<?, ?it/s]

epoch: 84, f1_test: 0.5867677109048006, f1_train: 0.8238582879556913


100%|██████████| 111/111 [00:13<00:00,  8.09it/s]
100%|██████████| 28/28 [00:01<00:00, 16.59it/s]
100%|██████████| 111/111 [00:07<00:00, 15.84it/s]
  1%|          | 1/111 [00:00<00:14,  7.57it/s]

epoch: 85, f1_test: 0.6101148049296465, f1_train: 0.8392029636262524


100%|██████████| 111/111 [00:13<00:00,  8.22it/s]
100%|██████████| 28/28 [00:01<00:00, 16.91it/s]
100%|██████████| 111/111 [00:07<00:00, 15.29it/s]
  1%|          | 1/111 [00:00<00:14,  7.60it/s]

epoch: 86, f1_test: 0.5890405435417683, f1_train: 0.820247113798823


100%|██████████| 111/111 [00:13<00:00,  8.30it/s]
100%|██████████| 28/28 [00:01<00:00, 15.72it/s]
100%|██████████| 111/111 [00:07<00:00, 13.99it/s]
  1%|          | 1/111 [00:00<00:15,  6.93it/s]

epoch: 87, f1_test: 0.5980981549755122, f1_train: 0.8356972736466732


100%|██████████| 111/111 [00:15<00:00,  7.11it/s]
100%|██████████| 28/28 [00:01<00:00, 14.37it/s]
100%|██████████| 111/111 [00:07<00:00, 14.72it/s]
  1%|          | 1/111 [00:00<00:14,  7.64it/s]

epoch: 88, f1_test: 0.6107625537127437, f1_train: 0.828531595933628


100%|██████████| 111/111 [00:13<00:00,  8.24it/s]
100%|██████████| 28/28 [00:01<00:00, 16.81it/s]
100%|██████████| 111/111 [00:07<00:00, 14.79it/s]
  1%|          | 1/111 [00:00<00:14,  7.62it/s]

epoch: 89, f1_test: 0.5973138493476995, f1_train: 0.826230694138632


100%|██████████| 111/111 [00:13<00:00,  7.95it/s]
100%|██████████| 28/28 [00:01<00:00, 16.91it/s]
100%|██████████| 111/111 [00:07<00:00, 15.17it/s]
  1%|          | 1/111 [00:00<00:14,  7.62it/s]

epoch: 90, f1_test: 0.6252921004675466, f1_train: 0.8318156346116717


100%|██████████| 111/111 [00:13<00:00,  8.21it/s]
100%|██████████| 28/28 [00:02<00:00, 12.71it/s]
100%|██████████| 111/111 [00:07<00:00, 15.61it/s]
  1%|          | 1/111 [00:00<00:14,  7.63it/s]

epoch: 91, f1_test: 0.6178930636089086, f1_train: 0.8570748523307519


100%|██████████| 111/111 [00:13<00:00,  8.18it/s]
100%|██████████| 28/28 [00:01<00:00, 16.34it/s]
100%|██████████| 111/111 [00:07<00:00, 14.75it/s]
  1%|          | 1/111 [00:00<00:14,  7.65it/s]

epoch: 92, f1_test: 0.5810700157798249, f1_train: 0.8219680630526542


100%|██████████| 111/111 [00:15<00:00,  7.00it/s]
100%|██████████| 28/28 [00:01<00:00, 16.31it/s]
100%|██████████| 111/111 [00:08<00:00, 13.06it/s]
  1%|          | 1/111 [00:00<00:14,  7.63it/s]

epoch: 93, f1_test: 0.5858783355916727, f1_train: 0.8020294917334412


100%|██████████| 111/111 [00:13<00:00,  8.05it/s]
100%|██████████| 28/28 [00:01<00:00, 14.44it/s]
100%|██████████| 111/111 [00:07<00:00, 15.50it/s]
  1%|          | 1/111 [00:00<00:14,  7.63it/s]

epoch: 94, f1_test: 0.6033179688121204, f1_train: 0.8262599397800517


100%|██████████| 111/111 [00:14<00:00,  7.86it/s]
100%|██████████| 28/28 [00:01<00:00, 16.17it/s]
100%|██████████| 111/111 [00:07<00:00, 15.56it/s]
  1%|          | 1/111 [00:00<00:14,  7.54it/s]

epoch: 95, f1_test: 0.6104338015295782, f1_train: 0.8176572660797187


100%|██████████| 111/111 [00:13<00:00,  8.23it/s]
100%|██████████| 28/28 [00:01<00:00, 16.49it/s]
100%|██████████| 111/111 [00:07<00:00, 15.52it/s]
  1%|          | 1/111 [00:00<00:14,  7.61it/s]

epoch: 96, f1_test: 0.6005834219551153, f1_train: 0.8367462670515983


100%|██████████| 111/111 [00:13<00:00,  8.11it/s]
100%|██████████| 28/28 [00:01<00:00, 16.37it/s]
100%|██████████| 111/111 [00:07<00:00, 15.12it/s]
  1%|          | 1/111 [00:00<00:20,  5.49it/s]

epoch: 97, f1_test: 0.5828030508041074, f1_train: 0.8271996039726968


100%|██████████| 111/111 [00:14<00:00,  7.91it/s]
100%|██████████| 28/28 [00:01<00:00, 14.47it/s]
100%|██████████| 111/111 [00:07<00:00, 14.92it/s]
  1%|          | 1/111 [00:00<00:14,  7.56it/s]

epoch: 98, f1_test: 0.6181327235393923, f1_train: 0.8340865443475837


100%|██████████| 111/111 [00:13<00:00,  7.98it/s]
100%|██████████| 28/28 [00:02<00:00, 13.00it/s]
100%|██████████| 111/111 [00:07<00:00, 15.50it/s]

epoch: 99, f1_test: 0.6140860930295599, f1_train: 0.8440663933052893





In [82]:
# make a model
model_name = 'baseline_fulldiv.pt'
model = BaseLineModel().cuda()
model.load_state_dict(torch.load(os.path.join('/content/drive/My drive/Colab Notebooks', model_name)))
model.eval()
forecast = []
with torch.no_grad():
    for wavs in tqdm(test_loader):
        wavs = wavs.cuda()
        outputs = model.inference(wavs)
        outputs = outputs.detach().cpu().numpy().argmax(axis=1)
        forecast.append(outputs)
forecast = [x for sublist in forecast for x in sublist]
decoder = {classes_dict[cl]:cl for cl in classes_dict}
forecast = pd.Series(forecast).map(decoder)
df_test['label'] = forecast
df_test.to_csv(f'{model_name}.csv', index=None)

  2%|▏         | 2/93 [00:00<00:05, 15.58it/s]

Loaded pretrained weights for efficientnet-b0


100%|██████████| 93/93 [00:06<00:00, 15.36it/s]
