In [2]:
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader

from src.data import AudioDatasetInference
from src.models import BasicClassifier
from src.utils import score

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
base_dir = 'data'
train_dir = base_dir + '/train_audio/'
test_dir = base_dir + '/test_soundscapes/'
unlabeled_dir = base_dir + '/unlabeled_soundscapes/'

class_names = sorted(os.listdir(train_dir))
n_classes = len(class_names)
class_labels = list(range(n_classes))
label2name = dict(zip(class_labels, class_names))
name2label = {v:k for k,v in label2name.items()}

In [4]:
test_df = pd.read_csv('valid_df.csv')
files = test_df['filepath']
targets = test_df['target']

In [5]:
n_classes = 182
test_dataset = AudioDatasetInference(files, targets=None, n_classes=n_classes, duration=5)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8)

In [6]:
device = torch.device('cuda')

model = BasicClassifier(n_classes, pretrained=True).to(device)
checkpoint_name = "checkpoints/efficientnet_v2_s_imagenet_base_32.pth"
checkpoint = torch.load(checkpoint_name, map_location='cpu')
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [10]:
ids = []
preds = np.empty(shape=(0, n_classes), dtype='float32')

test_iter = tqdm(range(len(test_dataset)))
for i in test_iter:
    specs, file = test_dataset[i]
    filename = file.split('/')[-1][:-4]
    specs = specs.to(device)
    
    with torch.no_grad():
        outs = model(specs)
        outs = nn.functional.softmax(outs, dim=1).detach().cpu()

    frame_ids = [f'{filename}_{(frame_id+1)*5}' for frame_id in range(len(specs))]
    ids += frame_ids

    preds = np.concatenate([preds, outs], axis=0)

100%|██████████| 4892/4892 [09:30<00:00,  8.58it/s] 


In [13]:
preds.shape

(41777, 182)

Unnamed: 0,asbfly,ashdro1,ashpri1,ashwoo2,asikoe2,asiope1,aspfly1,aspswi1,barfly1,barswa,...,whcbar1,whiter2,whrmun,whtkin2,woosan,wynlau1,yebbab1,yebbul3,zitcis1,target
count,41777.0,41777.0,41777.0,41777.0,41777.0,41777.0,41777.0,41777.0,41777.0,41777.0,...,41777.0,41777.0,41777.0,41777.0,41777.0,41777.0,41777.0,41777.0,41777.0,41777.0
mean,0.006201795,0.01062672,0.005032634,0.0004724374,0.01193423,0.0001523317,0.0006526313,0.0007585039,0.0002793207,0.0324964,...,0.0003909016,0.004662952,0.003778348,0.006037107,0.01126869,0.0002273498,0.0008745489,0.0001223989,0.009589118,73.604711
std,0.01666745,0.02569967,0.02179432,0.003291188,0.03290289,0.001408183,0.002779891,0.006007601,0.001935249,0.07120551,...,0.003213739,0.02277466,0.01621723,0.01861588,0.03493118,0.0009412312,0.005407379,0.001137161,0.03294986,50.455416
min,3.863898e-10,1.231771e-08,8.428349e-10,1.996736e-11,2.857773e-10,4.819793e-12,9.834163e-12,9.394703e-14,5.584176e-12,6.51928e-08,...,4.335576e-12,2.942689e-09,2.771318e-10,2.04875e-07,9.509377e-11,2.5482e-12,1.193256e-12,2.33877e-12,3.880675e-07,0.0
25%,0.0007668906,0.0009264194,0.0003444882,1.874943e-05,0.001320408,1.841844e-06,5.427812e-05,2.180173e-05,9.379202e-06,0.004048649,...,2.108786e-05,0.000231537,0.0002183181,0.0006262859,0.001114765,8.8345e-06,4.31499e-05,1.665278e-06,0.001083554,33.0
50%,0.002241388,0.003225964,0.001162751,7.036459e-05,0.00395436,7.607206e-06,0.0001747055,9.486527e-05,3.762457e-05,0.01170743,...,7.306557e-05,0.0007852175,0.0007797772,0.001875969,0.003359729,3.719285e-05,0.0001555454,7.597778e-06,0.002908994,70.0
75%,0.00594815,0.009927367,0.003540781,0.0002413811,0.01102945,2.971022e-05,0.0005096063,0.000360746,0.0001368521,0.0309886,...,0.0002358265,0.002584768,0.002472342,0.005244429,0.009422971,0.0001344772,0.0005176723,3.181732e-05,0.007436871,106.0
max,0.6333799,0.8097235,0.9888772,0.2076199,0.9628189,0.08207593,0.209274,0.5053285,0.1718858,0.9940274,...,0.4724048,0.9706143,0.7958552,0.7115533,0.9793032,0.04059278,0.4231272,0.07246704,0.981944,181.0


In [11]:
pred_df = pd.DataFrame(ids, columns=['row_id'])
pred_df.loc[:, class_names] = preds

def get_target(row_id):
    name = row_id.split('_')[0]
    target = test_df.loc[test_df['xc_id'] == name, 'target'].values[0]
    return target

pred_df['target'] = pred_df.row_id.map(get_target)
submission = pred_df[class_names]

solution = pd.DataFrame()
for class_name in class_names:
    # Create a new column where the value is 1 if the label matches the class index, and 0 otherwise
    solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)

  pred_df['target'] = pred_df.row_id.map(get_target)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df['target'] == name2label[class_name]).astype(int)
  solution[class_name] = (pred_df[

In [12]:
score(solution, submission, 'row_id')

0.6097012204702028

In [18]:
pred_df.head(30)

Unnamed: 0,row_id,asbfly,ashdro1,ashpri1,ashwoo2,asikoe2,asiope1,aspfly1,aspswi1,barfly1,...,whcbar1,whiter2,whrmun,whtkin2,woosan,wynlau1,yebbab1,yebbul3,zitcis1,target
0,XC756601_5,0.00263,0.01508,0.001529,2.8e-05,0.003712,1.806909e-06,0.000369,1.081393e-06,3.676947e-05,...,0.000568,0.000103,6.6e-05,0.001026,0.000126,0.0001861497,2e-06,0.0006483643,8.8e-05,139
1,XC756601_10,0.000685,0.000622,0.000238,8e-06,0.001599,1.940699e-07,5.9e-05,1.203048e-07,1.714421e-05,...,0.000464,4.8e-05,1.5e-05,0.000208,0.001796,2.619723e-06,1.4e-05,0.0003804449,6.2e-05,139
2,XC756601_15,0.005079,0.000165,0.003467,3e-06,0.001357,9.642807e-08,0.000158,5.396671e-06,9.081466e-06,...,0.000261,3.3e-05,5.6e-05,0.003161,0.000308,2.353099e-05,6e-06,0.001028877,4.3e-05,139
3,XC756601_20,0.001118,0.003576,0.000853,6.5e-05,0.048988,3.553077e-05,0.000173,0.001332497,0.0001812702,...,0.000299,0.001274,0.009416,0.008995,0.000577,2.784823e-05,0.00176,4.488797e-06,0.00307,139
4,XC756601_25,0.002985,0.001214,7e-06,7e-06,0.0061,1.708488e-06,0.000635,6.468e-05,7.462902e-06,...,1e-06,0.000511,0.000179,0.000461,0.012366,3.464679e-06,3.3e-05,3.77071e-08,0.000323,139
5,XC756601_30,0.000589,0.000132,0.007702,1.9e-05,0.000435,2.823442e-05,4.2e-05,0.001281502,6.44524e-05,...,1.4e-05,0.060516,0.005159,0.001397,0.002099,5.900901e-05,8.2e-05,4.016891e-07,0.139044,139
6,XC756601_35,0.001406,0.003722,0.000127,0.000116,0.0118,1.117657e-05,0.000398,0.0004705411,6.103696e-06,...,2.7e-05,0.00093,0.000499,0.001412,0.02881,7.426153e-06,0.000172,5.79608e-07,0.012893,139
7,XC756601_40,0.011152,0.002333,0.000159,0.000535,0.002115,0.0001379954,0.001585,0.0009661235,0.000659307,...,0.00055,0.000322,0.004594,0.010217,0.000582,3.334399e-05,0.000225,1.781803e-05,0.000723,139
8,XC756601_45,0.002321,0.007784,0.00215,0.00019,0.00085,1.698514e-05,5e-05,0.003782778,0.000100904,...,3.2e-05,0.001651,0.000641,0.000325,0.03209,4.532235e-05,0.001114,8.625884e-06,0.28174,139
9,XC756601_50,0.00524,0.002573,0.00479,0.001531,0.003738,7.23503e-05,2.3e-05,0.0002325403,0.0001184422,...,0.000117,0.001149,0.001144,0.004024,0.004159,5.246757e-05,0.000101,1.918984e-06,0.138511,139


In [1]:
pred_df.to_csv('test.csv', index=False)

NameError: name 'pred_df' is not defined

In [19]:
xx = pd.read_csv('test.csv')

In [20]:
xx

Unnamed: 0.1,Unnamed: 0,row_id,asbfly,ashdro1,ashpri1,ashwoo2,asikoe2,asiope1,aspfly1,aspswi1,...,whcbar1,whiter2,whrmun,whtkin2,woosan,wynlau1,yebbab1,yebbul3,zitcis1,target
0,0,XC756601_5,0.002630,0.015080,0.001529,0.000028,0.003712,1.806909e-06,0.000369,1.081393e-06,...,0.000568,0.000103,0.000066,0.001026,0.000126,0.000186,0.000002,6.483643e-04,0.000088,139
1,1,XC756601_10,0.000685,0.000622,0.000238,0.000008,0.001599,1.940699e-07,0.000059,1.203048e-07,...,0.000464,0.000048,0.000015,0.000208,0.001796,0.000003,0.000014,3.804449e-04,0.000062,139
2,2,XC756601_15,0.005079,0.000165,0.003467,0.000003,0.001357,9.642807e-08,0.000158,5.396670e-06,...,0.000261,0.000033,0.000056,0.003161,0.000308,0.000024,0.000006,1.028877e-03,0.000043,139
3,3,XC756601_20,0.001118,0.003576,0.000853,0.000065,0.048988,3.553077e-05,0.000173,1.332497e-03,...,0.000299,0.001274,0.009416,0.008995,0.000577,0.000028,0.001760,4.488796e-06,0.003070,139
4,4,XC756601_25,0.002985,0.001214,0.000007,0.000007,0.006100,1.708488e-06,0.000635,6.468000e-05,...,0.000001,0.000511,0.000179,0.000461,0.012366,0.000003,0.000033,3.770710e-08,0.000323,139
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
41772,41772,XC493181_60,0.000833,0.004182,0.005688,0.000033,0.026400,3.715507e-05,0.000037,1.198244e-04,...,0.000271,0.000296,0.000095,0.000985,0.008691,0.000008,0.000429,8.151869e-05,0.000731,10
41773,41773,XC493181_65,0.004899,0.009130,0.003177,0.002091,0.013983,1.629001e-05,0.002164,5.767886e-05,...,0.000492,0.003032,0.007130,0.004210,0.002038,0.000048,0.000091,6.134300e-05,0.010749,10
41774,41774,XC493181_70,0.000252,0.002724,0.002290,0.000294,0.000147,9.447167e-06,0.000494,1.330459e-06,...,0.000032,0.026561,0.000082,0.000260,0.000071,0.000013,0.000003,5.637811e-06,0.002162,10
41775,41775,XC750334_5,0.005262,0.004812,0.000305,0.000024,0.004555,3.983036e-06,0.000415,1.541633e-05,...,0.000025,0.000185,0.000685,0.003649,0.008156,0.000012,0.000038,2.765975e-05,0.005658,106
