In [1]:
import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from datetime import datetime
from tqdm import tqdm
from glob import glob

In [2]:
# Notebook-wide configuration: target GPU index and size of the tag vocabulary
# (length of the multi-hot target / number of model outputs).
CUDA_DEV, NUM_TAGS = 0, 256

In [3]:
# Competition data root on Kaggle and the sub-directory holding per-track embeddings.
data_dir = '/kaggle/input/yandex-cup-2023-recsys'
emb_dir = 'track_embeddings'

# Read the train split first, then the test split.
df_train, df_test = (
    pd.read_csv(os.path.join(data_dir, f'{split}.csv'))
    for split in ('train', 'test')
)

In [4]:
# Load every per-track embedding file into memory, keyed by the integer track id
# parsed from the filename stem. A dict comprehension keeps the loop temporaries
# (previously `fn`, `filename`, `track_idx`, `embeds`) out of the notebook's
# global namespace, where they shadowed names used as locals elsewhere.
# Sorting the paths makes the load order (and dict order) reproducible.
track_idx2embeds = {
    int(os.path.splitext(os.path.basename(path))[0]): np.load(path)
    for path in tqdm(sorted(glob(os.path.join(data_dir, emb_dir, '*'))))
}

len(track_idx2embeds)

100%|██████████| 76714/76714 [09:38<00:00, 132.70it/s]


76714

In [5]:
class TaggingDataset(Dataset):
    """Map-style dataset yielding per-track embedding sequences and tag targets.

    Items are addressed by row position in ``df``:

    * training mode (``testing=False``): ``(track_idx, embeds, target)``, where
      ``target`` is a float32 multi-hot vector of length ``num_tags`` built from
      the comma-separated integer ``tags`` column;
    * test mode (``testing=True``): ``(track_idx, embeds)``; no ``tags`` column
      is needed.

    Args:
        df: frame with a ``track`` column (and ``tags`` unless ``testing``).
        testing: if True, targets are neither parsed nor returned.
        embeddings: mapping ``track id -> embedding array``. When None, the
            notebook-global ``track_idx2embeds`` is looked up on each access.
        num_tags: length of the multi-hot target. When None, the global
            ``NUM_TAGS`` is looked up on each access.
    """

    def __init__(self, df, testing=False, embeddings=None, num_tags=None):
        self.df = df
        self.testing = testing
        self.embeddings = embeddings
        self.num_tags = num_tags
        # Extract the needed columns once: `df.iloc[idx]` on a mixed-dtype frame
        # builds a fresh object-dtype Series for every item, which is slow.
        self._tracks = df['track'].to_numpy()
        self._tags = None if testing else df['tags'].to_numpy()

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        embeddings = track_idx2embeds if self.embeddings is None else self.embeddings
        track_idx = self._tracks[idx]
        embeds = embeddings[track_idx]
        if self.testing:
            return track_idx, embeds
        num_tags = NUM_TAGS if self.num_tags is None else self.num_tags
        tags = [int(x) for x in self._tags[idx].split(',')]
        # float32 so the target matches float32 logits in BCEWithLogitsLoss;
        # np.zeros would otherwise default to float64.
        target = np.zeros(num_tags, dtype=np.float32)
        target[tags] = 1
        return track_idx, embeds, target

In [6]:
# Training split yields (track, embeds, target); test split yields (track, embeds).
train_dataset = TaggingDataset(df_train, testing=False)
test_dataset = TaggingDataset(df_test, testing=True)

In [7]:
class Network(nn.Module):
    """Track-level multi-label tagger over variable-length embedding sequences.

    Each track's embedding sequence is mean-pooled over its first axis, projected
    to ``hidden_dim``, layer-normalised, passed through a small MLP and mapped to
    ``num_classes`` raw logits (no sigmoid is applied).

    Args:
        num_classes: number of output logits; when None, the global ``NUM_TAGS``
            is used (looked up when the model is constructed).
        input_dim: size of each per-frame embedding.
        hidden_dim: width of the hidden layers.
    """

    def __init__(
        self,
        num_classes = None,
        input_dim = 768,
        hidden_dim = 512
    ):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_TAGS
        self.num_classes = num_classes
        # Despite the name this is a LayerNorm; the attribute name is kept so that
        # existing state_dict keys ("bn.weight", "bn.bias") still load.
        self.bn = nn.LayerNorm(hidden_dim)
        self.projector = nn.Linear(input_dim, hidden_dim)
        self.lin = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, embeds):
        """Compute logits for a batch of tracks.

        Args:
            embeds: iterable of tensors, one per track; each is assumed to be
                (num_frames, input_dim), averaged over dim 0 (TODO confirm).

        Returns:
            Tensor of shape (len(embeds), num_classes) with raw logits.
        """
        # A Linear layer commutes with averaging: mean_i(W x_i + b) = W mean_i(x_i) + b.
        # Pooling first and projecting the whole batch in one matmul is therefore
        # equivalent to projecting every frame of every track, but much cheaper.
        pooled = torch.stack([x.mean(0) for x in embeds])
        x = self.bn(self.projector(pooled))
        x = self.lin(x)
        return self.fc(x)


In [8]:
def train_epoch(model, loader, criterion, optimizer, device=None, log_every=100):
    """Run one optimisation pass over ``loader`` and return the smoothed loss.

    Args:
        model: module mapping a list of per-track embedding tensors to logits.
        loader: iterable of ``(track_idxs, embeds, target)`` batches, where
            ``embeds`` is a list of tensors and ``target`` a multi-hot tensor.
        criterion: loss on ``(logits, target)``, e.g. ``nn.BCEWithLogitsLoss``.
        optimizer: optimizer over ``model``'s parameters.
        device: where to move the embeddings; defaults to the global ``CUDA_DEV``.
        log_every: print the running loss every ``log_every`` batches
            (the first batch is always printed).

    Returns:
        The exponentially smoothed loss after the last batch, or ``None`` if
        ``loader`` yielded no batches.
    """
    if device is None:
        device = CUDA_DEV
    model.train()
    running_loss = None
    alpha = 0.8  # EMA weight given to the previous running loss
    for iteration, (_, embeds, target) in enumerate(loader):
        optimizer.zero_grad()
        embeds = [x.to(device) for x in embeds]
        pred_logits = model(embeds)
        # Match the logits' device and dtype: targets built with np.zeros are
        # float64, and mixing them with float32 logits promotes the loss to float64.
        target = target.to(device=pred_logits.device, dtype=pred_logits.dtype)
        loss = criterion(pred_logits, target)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if running_loss is None:
            running_loss = loss_value
        else:
            running_loss = alpha * running_loss + (1 - alpha) * loss_value
        if iteration % log_every == 0:
            print('   {} batch {} loss {}'.format(
                datetime.now(), iteration + 1, running_loss
            ))
    return running_loss

In [9]:
def predict(model, loader, device=None):
    """Run ``model`` over ``loader`` and return track ids with sigmoid probabilities.

    Args:
        model: module mapping a list of embedding tensors to logits.
        loader: iterable of ``(track_idx, embeds)`` batches; ``track_idx`` is a
            CPU tensor holding one id per track in any shape, e.g. (B, 1) as
            produced by ``collate_fn_test``, or (B,).
        device: where to move ``embeds``; defaults to the global ``CUDA_DEV``.

    Returns:
        ``(track_idxs, predictions)``: a 1-D array of ids and a 2-D array of
        per-tag probabilities with one row per track, both in loader order.

    Note: leaves ``model`` in eval mode.
    """
    if device is None:
        device = CUDA_DEV
    model.eval()
    track_idxs = []
    predictions = []
    with torch.no_grad():
        for track_idx, embeds in loader:
            embeds = [x.to(device) for x in embeds]
            pred_probs = torch.sigmoid(model(embeds))
            predictions.append(pred_probs.cpu().numpy())
            # Flatten each batch before joining: np.vstack on 1-D id batches
            # fails when the final batch is smaller than the others.
            track_idxs.append(track_idx.numpy().reshape(-1))
    predictions = np.vstack(predictions)
    track_idxs = np.concatenate(track_idxs)
    return track_idxs, predictions

In [10]:
def collate_fn(b):
    """Collate training samples ``(track_idx, embeds, target)`` into a batch.

    Embedding sequences may differ in length across tracks, so they are kept as
    a list of tensors rather than stacked.

    Returns:
        track_idxs: integer tensor of shape (B, 1) (np.vstack of scalar ids).
        embeds: list of B tensors sharing memory with the input arrays.
        targets: float32 tensor with one stacked target row per sample.
    """
    track_idxs = torch.from_numpy(np.vstack([x[0] for x in b]))
    embeds = [torch.from_numpy(x[1]) for x in b]
    # Cast to float32 to match the model's logits; targets built with
    # np.zeros arrive as float64.
    targets = np.vstack([x[2] for x in b]).astype(np.float32, copy=False)
    targets = torch.from_numpy(targets)
    return track_idxs, embeds, targets

def collate_fn_test(b):
    """Collate test samples ``(track_idx, embeds)`` into ``(ids, embeds_list)``.

    The ids become an integer tensor of shape (B, 1); the variable-length
    embedding arrays are wrapped as tensors and returned as a list.
    """
    stacked_ids = np.vstack([item[0] for item in b])
    id_tensor = torch.from_numpy(stacked_ids)
    embed_tensors = [torch.from_numpy(item[1]) for item in b]
    return id_tensor, embed_tensors

In [11]:
# Variable-length embedding sequences need custom collation; only the
# training loader is shuffled so test predictions keep the CSV row order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    collate_fn=collate_fn,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=128,
    shuffle=False,
    collate_fn=collate_fn_test,
)

In [12]:
# Seed torch's global RNG before building the model and before the first
# shuffled epoch. This makes the weight initialisation and the DataLoader batch
# order reproducible across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

model = Network()
criterion = nn.BCEWithLogitsLoss()

epochs = 100
model = model.to(CUDA_DEV)
criterion = criterion.to(CUDA_DEV)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in tqdm(range(epochs)):
    train_epoch(model, train_dataloader, criterion, optimizer)

  0%|          | 0/100 [00:00<?, ?it/s]

   2023-10-25 18:20:00.689212 batch 1 loss 0.7380401492561646
   2023-10-25 18:20:04.804588 batch 101 loss 0.08221160314581685
   2023-10-25 18:20:08.871037 batch 201 loss 0.06677076556559228
   2023-10-25 18:20:12.971962 batch 301 loss 0.06113852584081181


  1%|          | 1/100 [00:18<29:50, 18.08s/it]

   2023-10-25 18:20:16.993327 batch 1 loss 0.06180681880682215
   2023-10-25 18:20:21.088439 batch 101 loss 0.05847679932389772
   2023-10-25 18:20:25.235757 batch 201 loss 0.05609370486928812
   2023-10-25 18:20:29.351700 batch 301 loss 0.05563216616552631


  2%|▏         | 2/100 [00:34<27:55, 17.10s/it]

   2023-10-25 18:20:33.405447 batch 1 loss 0.05833762639696083
   2023-10-25 18:20:37.459470 batch 101 loss 0.053852809040926826
   2023-10-25 18:20:41.529177 batch 201 loss 0.052677309672474464
   2023-10-25 18:20:45.551341 batch 301 loss 0.05293379023092989


  3%|▎         | 3/100 [00:50<26:56, 16.67s/it]

   2023-10-25 18:20:49.555066 batch 1 loss 0.048718741509699015
   2023-10-25 18:20:53.605134 batch 101 loss 0.05188734073392412
   2023-10-25 18:20:57.844088 batch 201 loss 0.05098870696460726
   2023-10-25 18:21:01.997918 batch 301 loss 0.0493686198414952


  4%|▍         | 4/100 [01:07<26:32, 16.58s/it]

   2023-10-25 18:21:06.013501 batch 1 loss 0.052855650206267146
   2023-10-25 18:21:10.011023 batch 101 loss 0.050590014912497955
   2023-10-25 18:21:14.081833 batch 201 loss 0.04960063489046956
   2023-10-25 18:21:18.103920 batch 301 loss 0.05064226500965645


  5%|▌         | 5/100 [01:23<25:59, 16.42s/it]

   2023-10-25 18:21:22.128981 batch 1 loss 0.048388380770878925
   2023-10-25 18:21:26.219409 batch 101 loss 0.047954737735023546
   2023-10-25 18:21:30.334947 batch 201 loss 0.047308004824893224
   2023-10-25 18:21:34.388579 batch 301 loss 0.0487684543892852


  6%|▌         | 6/100 [01:39<25:39, 16.38s/it]

   2023-10-25 18:21:38.431132 batch 1 loss 0.04823226074432185
   2023-10-25 18:21:42.486786 batch 101 loss 0.04719236996655274
   2023-10-25 18:21:46.522359 batch 201 loss 0.047912046282786724
   2023-10-25 18:21:50.552771 batch 301 loss 0.04829571363579562


  7%|▋         | 7/100 [01:55<25:16, 16.31s/it]

   2023-10-25 18:21:54.588637 batch 1 loss 0.045794874134823915
   2023-10-25 18:21:58.686977 batch 101 loss 0.04629958911879592
   2023-10-25 18:22:02.840034 batch 201 loss 0.04761958702217688
   2023-10-25 18:22:06.870596 batch 301 loss 0.0464652025637487


  8%|▊         | 8/100 [02:11<24:59, 16.30s/it]

   2023-10-25 18:22:10.868624 batch 1 loss 0.04698114000202899
   2023-10-25 18:22:14.880422 batch 101 loss 0.04621104226145992
   2023-10-25 18:22:18.880828 batch 201 loss 0.04510951288993498
   2023-10-25 18:22:22.931557 batch 301 loss 0.04627557350558701


  9%|▉         | 9/100 [02:28<24:36, 16.23s/it]

   2023-10-25 18:22:26.937674 batch 1 loss 0.043499826263864616
   2023-10-25 18:22:31.118372 batch 101 loss 0.045163988783401006
   2023-10-25 18:22:35.195850 batch 201 loss 0.04605355761160327
   2023-10-25 18:22:39.220628 batch 301 loss 0.045799375057137916


 10%|█         | 10/100 [02:44<24:24, 16.28s/it]

   2023-10-25 18:22:43.329847 batch 1 loss 0.048115097944308516
   2023-10-25 18:22:47.365559 batch 101 loss 0.045273197432914046
   2023-10-25 18:22:51.447308 batch 201 loss 0.045436168967921964
   2023-10-25 18:22:55.500461 batch 301 loss 0.04468738163944635


 11%|█         | 11/100 [03:00<24:07, 16.26s/it]

   2023-10-25 18:22:59.546290 batch 1 loss 0.045124971179113155
   2023-10-25 18:23:03.864988 batch 101 loss 0.04462090630704797
   2023-10-25 18:23:07.889161 batch 201 loss 0.045185046799056286
   2023-10-25 18:23:11.945228 batch 301 loss 0.045199442696605516


 12%|█▏        | 12/100 [03:17<23:56, 16.33s/it]

   2023-10-25 18:23:16.031673 batch 1 loss 0.044619993594302534
   2023-10-25 18:23:20.063151 batch 101 loss 0.044989912209800115
   2023-10-25 18:23:24.098613 batch 201 loss 0.045098571064107904
   2023-10-25 18:23:28.119495 batch 301 loss 0.04399157140399525


 13%|█▎        | 13/100 [03:33<23:39, 16.31s/it]

   2023-10-25 18:23:32.311478 batch 1 loss 0.04150987304445993
   2023-10-25 18:23:36.370324 batch 101 loss 0.04445278391337641
   2023-10-25 18:23:40.408184 batch 201 loss 0.0445897543170139
   2023-10-25 18:23:44.467209 batch 301 loss 0.043940671358534106


 14%|█▍        | 14/100 [03:49<23:19, 16.27s/it]

   2023-10-25 18:23:48.478274 batch 1 loss 0.042527387810392545
   2023-10-25 18:23:52.536228 batch 101 loss 0.04407666809617196
   2023-10-25 18:23:56.578875 batch 201 loss 0.04332529685467119
   2023-10-25 18:24:00.596768 batch 301 loss 0.04407340968274134


 15%|█▌        | 15/100 [04:05<23:04, 16.29s/it]

   2023-10-25 18:24:04.807429 batch 1 loss 0.043281685315350184
   2023-10-25 18:24:08.848599 batch 101 loss 0.04327949361363884
   2023-10-25 18:24:12.920786 batch 201 loss 0.043095643880390694
   2023-10-25 18:24:16.955835 batch 301 loss 0.04355063953299182


 16%|█▌        | 16/100 [04:22<22:45, 16.26s/it]

   2023-10-25 18:24:20.995967 batch 1 loss 0.04514786257245247
   2023-10-25 18:24:25.017820 batch 101 loss 0.043228896916173876
   2023-10-25 18:24:29.029767 batch 201 loss 0.043520257799909826
   2023-10-25 18:24:33.097318 batch 301 loss 0.041668585520302906


 17%|█▋        | 17/100 [04:38<22:28, 16.25s/it]

   2023-10-25 18:24:37.225982 batch 1 loss 0.045718156768101904
   2023-10-25 18:24:41.257669 batch 101 loss 0.041944832034322
   2023-10-25 18:24:45.255565 batch 201 loss 0.043605641033851575
   2023-10-25 18:24:49.266578 batch 301 loss 0.043023118724716876


 18%|█▊        | 18/100 [04:54<22:10, 16.22s/it]

   2023-10-25 18:24:53.387699 batch 1 loss 0.04162969793013786
   2023-10-25 18:24:57.400769 batch 101 loss 0.04492240045307235
   2023-10-25 18:25:01.467199 batch 201 loss 0.042506937364097906
   2023-10-25 18:25:05.553979 batch 301 loss 0.04248855897943678


 19%|█▉        | 19/100 [05:10<21:54, 16.23s/it]

   2023-10-25 18:25:09.617613 batch 1 loss 0.04321522372055403
   2023-10-25 18:25:13.677828 batch 101 loss 0.04380601636028917
   2023-10-25 18:25:17.730038 batch 201 loss 0.04352806572381997
   2023-10-25 18:25:21.784390 batch 301 loss 0.0443972533077715


 20%|██        | 20/100 [05:26<21:37, 16.22s/it]

   2023-10-25 18:25:25.811763 batch 1 loss 0.03665109337616339
   2023-10-25 18:25:29.883937 batch 101 loss 0.04322077667346569
   2023-10-25 18:25:34.016691 batch 201 loss 0.042271945779563225
   2023-10-25 18:25:38.176413 batch 301 loss 0.04284365086960353


 21%|██        | 21/100 [05:43<21:25, 16.27s/it]

   2023-10-25 18:25:42.202476 batch 1 loss 0.041474817928936605
   2023-10-25 18:25:46.210884 batch 101 loss 0.041055567286396366
   2023-10-25 18:25:50.218931 batch 201 loss 0.0428714843285661
   2023-10-25 18:25:54.299369 batch 301 loss 0.042880781688627106


 22%|██▏       | 22/100 [05:59<21:08, 16.26s/it]

   2023-10-25 18:25:58.441560 batch 1 loss 0.03512247411873238
   2023-10-25 18:26:02.519382 batch 101 loss 0.04043881744591721
   2023-10-25 18:26:06.637584 batch 201 loss 0.040884224612767354
   2023-10-25 18:26:10.854551 batch 301 loss 0.04280920395729933


 23%|██▎       | 23/100 [06:15<20:56, 16.31s/it]

   2023-10-25 18:26:14.882073 batch 1 loss 0.03946293095142253
   2023-10-25 18:26:18.903260 batch 101 loss 0.04033556082373935
   2023-10-25 18:26:22.949763 batch 201 loss 0.040910897205627676
   2023-10-25 18:26:26.957961 batch 301 loss 0.04245210733175943


 24%|██▍       | 24/100 [06:32<20:35, 16.25s/it]

   2023-10-25 18:26:30.991123 batch 1 loss 0.041677884843882955
   2023-10-25 18:26:35.012401 batch 101 loss 0.039978061296239424
   2023-10-25 18:26:39.140701 batch 201 loss 0.041753194141659455
   2023-10-25 18:26:43.357494 batch 301 loss 0.041415345357421224


 25%|██▌       | 25/100 [06:48<20:21, 16.29s/it]

   2023-10-25 18:26:47.375776 batch 1 loss 0.03686500038459173
   2023-10-25 18:26:51.448099 batch 101 loss 0.04040344137354556
   2023-10-25 18:26:55.507525 batch 201 loss 0.04082171493417241
   2023-10-25 18:26:59.548873 batch 301 loss 0.0408648157370157


 26%|██▌       | 26/100 [07:04<20:03, 16.27s/it]

   2023-10-25 18:27:03.592860 batch 1 loss 0.04353278135216791
   2023-10-25 18:27:07.705468 batch 101 loss 0.040153939895287934
   2023-10-25 18:27:11.896096 batch 201 loss 0.04127061029972911
   2023-10-25 18:27:15.972939 batch 301 loss 0.04025091807406611


 27%|██▋       | 27/100 [07:21<19:52, 16.33s/it]

   2023-10-25 18:27:20.069708 batch 1 loss 0.03773182573510324
   2023-10-25 18:27:24.116963 batch 101 loss 0.04122237273862238
   2023-10-25 18:27:28.141221 batch 201 loss 0.040541791049106365
   2023-10-25 18:27:32.181031 batch 301 loss 0.04077306044429756


 28%|██▊       | 28/100 [07:37<19:33, 16.29s/it]

   2023-10-25 18:27:36.274669 batch 1 loss 0.036808775066747006
   2023-10-25 18:27:40.328695 batch 101 loss 0.0390007678162455
   2023-10-25 18:27:44.635110 batch 201 loss 0.0408749346256639
   2023-10-25 18:27:48.634713 batch 301 loss 0.040575586609459334


 29%|██▉       | 29/100 [07:53<19:18, 16.32s/it]

   2023-10-25 18:27:52.657706 batch 1 loss 0.0387051278734879
   2023-10-25 18:27:56.709064 batch 101 loss 0.04042715615192809
   2023-10-25 18:28:00.744712 batch 201 loss 0.03976899625067684
   2023-10-25 18:28:04.810915 batch 301 loss 0.04010958579575039


 30%|███       | 30/100 [08:09<19:00, 16.30s/it]

   2023-10-25 18:28:08.896703 batch 1 loss 0.041378650631729116
   2023-10-25 18:28:13.000352 batch 101 loss 0.039908134221743406
   2023-10-25 18:28:17.123416 batch 201 loss 0.03924793557759719
   2023-10-25 18:28:21.164300 batch 301 loss 0.03986236155060219


 31%|███       | 31/100 [08:26<18:43, 16.29s/it]

   2023-10-25 18:28:25.165427 batch 1 loss 0.04683948319410643
   2023-10-25 18:28:29.185593 batch 101 loss 0.04085011783087754
   2023-10-25 18:28:33.225224 batch 201 loss 0.03956765257657194
   2023-10-25 18:28:37.298854 batch 301 loss 0.04049754686542989


 32%|███▏      | 32/100 [08:42<18:28, 16.30s/it]

   2023-10-25 18:28:41.496770 batch 1 loss 0.03784356840796626
   2023-10-25 18:28:45.567039 batch 101 loss 0.04012903661196887
   2023-10-25 18:28:49.618652 batch 201 loss 0.039591913713508634
   2023-10-25 18:28:53.699149 batch 301 loss 0.040146792608957114


 33%|███▎      | 33/100 [08:58<18:10, 16.27s/it]

   2023-10-25 18:28:57.697908 batch 1 loss 0.0394368872343512
   2023-10-25 18:29:01.746507 batch 101 loss 0.03958408337242554
   2023-10-25 18:29:05.795570 batch 201 loss 0.0403523810971444
   2023-10-25 18:29:09.834138 batch 301 loss 0.04088432002645537


 34%|███▍      | 34/100 [09:15<17:54, 16.28s/it]

   2023-10-25 18:29:13.988292 batch 1 loss 0.04191972106913984
   2023-10-25 18:29:18.154032 batch 101 loss 0.04064864377962038
   2023-10-25 18:29:22.217770 batch 201 loss 0.03947677169181345
   2023-10-25 18:29:26.255726 batch 301 loss 0.03955240367374008


 35%|███▌      | 35/100 [09:31<17:39, 16.29s/it]

   2023-10-25 18:29:30.320920 batch 1 loss 0.04135253473352037
   2023-10-25 18:29:34.417781 batch 101 loss 0.03937510819422006
   2023-10-25 18:29:38.474345 batch 201 loss 0.039711386483671285
   2023-10-25 18:29:42.603705 batch 301 loss 0.03941406740732009


 36%|███▌      | 36/100 [09:47<17:24, 16.31s/it]

   2023-10-25 18:29:46.685384 batch 1 loss 0.03890863989766757
   2023-10-25 18:29:50.764344 batch 101 loss 0.03841069868505521
   2023-10-25 18:29:54.806722 batch 201 loss 0.03887612157961224
   2023-10-25 18:29:58.819819 batch 301 loss 0.0401544820750039


 37%|███▋      | 37/100 [10:03<17:05, 16.28s/it]

   2023-10-25 18:30:02.873396 batch 1 loss 0.036833180889569245
   2023-10-25 18:30:06.931091 batch 101 loss 0.040034118591691745
   2023-10-25 18:30:10.962911 batch 201 loss 0.038235887366072
   2023-10-25 18:30:15.118057 batch 301 loss 0.039224571202880004


 38%|███▊      | 38/100 [10:20<16:50, 16.30s/it]

   2023-10-25 18:30:19.221916 batch 1 loss 0.03925628555877805
   2023-10-25 18:30:23.431201 batch 101 loss 0.03871379749934269
   2023-10-25 18:30:27.456348 batch 201 loss 0.03803565136969636
   2023-10-25 18:30:31.515721 batch 301 loss 0.03892634274643158


 39%|███▉      | 39/100 [10:36<16:34, 16.30s/it]

   2023-10-25 18:30:35.529341 batch 1 loss 0.03827534142731212
   2023-10-25 18:30:39.648432 batch 101 loss 0.038019405748363086
   2023-10-25 18:30:43.840467 batch 201 loss 0.038356693595196494
   2023-10-25 18:30:47.899605 batch 301 loss 0.03723026193625922


 40%|████      | 40/100 [10:53<16:22, 16.37s/it]

   2023-10-25 18:30:52.053608 batch 1 loss 0.04170156644935813
   2023-10-25 18:30:56.122832 batch 101 loss 0.03893772284632212
   2023-10-25 18:31:00.183204 batch 201 loss 0.03868926496256136
   2023-10-25 18:31:04.286491 batch 301 loss 0.03931307331391903


 41%|████      | 41/100 [11:09<16:03, 16.34s/it]

   2023-10-25 18:31:08.311924 batch 1 loss 0.037927209803873096
   2023-10-25 18:31:12.366939 batch 101 loss 0.037697062065583396
   2023-10-25 18:31:16.470176 batch 201 loss 0.03827392830839878
   2023-10-25 18:31:20.509393 batch 301 loss 0.04029127379063277


 42%|████▏     | 42/100 [11:25<15:46, 16.33s/it]

   2023-10-25 18:31:24.622669 batch 1 loss 0.039993960878429796
   2023-10-25 18:31:28.669017 batch 101 loss 0.03722140268365146
   2023-10-25 18:31:32.722558 batch 201 loss 0.03827497320696097
   2023-10-25 18:31:36.756126 batch 301 loss 0.03985350938933708


 43%|████▎     | 43/100 [11:41<15:28, 16.29s/it]

   2023-10-25 18:31:40.833435 batch 1 loss 0.04183223200161446
   2023-10-25 18:31:44.986138 batch 101 loss 0.038316485923101405
   2023-10-25 18:31:49.014180 batch 201 loss 0.03877596660336042
   2023-10-25 18:31:53.082707 batch 301 loss 0.03849424503368373


 44%|████▍     | 44/100 [11:58<15:13, 16.31s/it]

   2023-10-25 18:31:57.190338 batch 1 loss 0.03476398104999136
   2023-10-25 18:32:01.298779 batch 101 loss 0.03818613004642182
   2023-10-25 18:32:05.343324 batch 201 loss 0.038487008492302736
   2023-10-25 18:32:09.384787 batch 301 loss 0.03771793229933086


 45%|████▌     | 45/100 [12:14<14:56, 16.30s/it]

   2023-10-25 18:32:13.477105 batch 1 loss 0.037639185070932925
   2023-10-25 18:32:17.578260 batch 101 loss 0.03813117771722894
   2023-10-25 18:32:21.687674 batch 201 loss 0.038131678517577164
   2023-10-25 18:32:25.892962 batch 301 loss 0.03881240787936505


 46%|████▌     | 46/100 [12:31<14:43, 16.37s/it]

   2023-10-25 18:32:29.998454 batch 1 loss 0.038406796281975596
   2023-10-25 18:32:34.070919 batch 101 loss 0.03812868922945208
   2023-10-25 18:32:38.145728 batch 201 loss 0.03757165458471767
   2023-10-25 18:32:42.216534 batch 301 loss 0.03798054320591068


 47%|████▋     | 47/100 [12:47<14:27, 16.37s/it]

   2023-10-25 18:32:46.359237 batch 1 loss 0.034704431491229495
   2023-10-25 18:32:50.408335 batch 101 loss 0.03739609071089692
   2023-10-25 18:32:54.536151 batch 201 loss 0.038306655069347724
   2023-10-25 18:32:58.647944 batch 301 loss 0.038413850030817494


 48%|████▊     | 48/100 [13:03<14:11, 16.37s/it]

   2023-10-25 18:33:02.723576 batch 1 loss 0.03723487159302513
   2023-10-25 18:33:06.794780 batch 101 loss 0.03879025488936582
   2023-10-25 18:33:10.830837 batch 201 loss 0.038945367918832574
   2023-10-25 18:33:14.892217 batch 301 loss 0.036448217561468994


 49%|████▉     | 49/100 [13:20<13:53, 16.35s/it]

   2023-10-25 18:33:19.042783 batch 1 loss 0.036260314150413114
   2023-10-25 18:33:23.154325 batch 101 loss 0.037504040016446645
   2023-10-25 18:33:27.206654 batch 201 loss 0.03714385878470321
   2023-10-25 18:33:31.338068 batch 301 loss 0.03769940713589014


 50%|█████     | 50/100 [13:36<13:37, 16.35s/it]

   2023-10-25 18:33:35.395289 batch 1 loss 0.034945236883490466
   2023-10-25 18:33:39.466444 batch 101 loss 0.037762022561893246
   2023-10-25 18:33:43.555863 batch 201 loss 0.037414168623703364
   2023-10-25 18:33:47.669614 batch 301 loss 0.03624059002899676


 51%|█████     | 51/100 [13:52<13:20, 16.34s/it]

   2023-10-25 18:33:51.710637 batch 1 loss 0.038848179068716786
   2023-10-25 18:33:55.725839 batch 101 loss 0.037509127392605535
   2023-10-25 18:33:59.843585 batch 201 loss 0.037004768288689205
   2023-10-25 18:34:03.889101 batch 301 loss 0.03744713488091318


 52%|█████▏    | 52/100 [14:08<13:02, 16.29s/it]

   2023-10-25 18:34:07.889093 batch 1 loss 0.040783952967282794
   2023-10-25 18:34:11.915328 batch 101 loss 0.03676186139011783
   2023-10-25 18:34:15.901506 batch 201 loss 0.037640719737824124
   2023-10-25 18:34:19.936853 batch 301 loss 0.037184707275266855


 53%|█████▎    | 53/100 [14:25<12:42, 16.22s/it]

   2023-10-25 18:34:23.941979 batch 1 loss 0.033510643961434314
   2023-10-25 18:34:27.942849 batch 101 loss 0.036508818268622824
   2023-10-25 18:34:32.070680 batch 201 loss 0.03771428404343525
   2023-10-25 18:34:36.086174 batch 301 loss 0.03836207635293762


 54%|█████▍    | 54/100 [14:41<12:26, 16.22s/it]

   2023-10-25 18:34:40.159115 batch 1 loss 0.03597658029551953
   2023-10-25 18:34:44.256604 batch 101 loss 0.03766522568365629
   2023-10-25 18:34:48.343597 batch 201 loss 0.036622345665068265
   2023-10-25 18:34:52.398440 batch 301 loss 0.037353577553478665


 55%|█████▌    | 55/100 [14:57<12:10, 16.24s/it]

   2023-10-25 18:34:56.444936 batch 1 loss 0.03857920454911721
   2023-10-25 18:35:00.510439 batch 101 loss 0.037132234747640054
   2023-10-25 18:35:04.623578 batch 201 loss 0.03709349271890855
   2023-10-25 18:35:08.641249 batch 301 loss 0.036513507064665336


 56%|█████▌    | 56/100 [15:13<11:54, 16.24s/it]

   2023-10-25 18:35:12.688011 batch 1 loss 0.03540596708148866
   2023-10-25 18:35:16.691051 batch 101 loss 0.036336353450492236
   2023-10-25 18:35:20.766568 batch 201 loss 0.03658535177133496
   2023-10-25 18:35:24.790041 batch 301 loss 0.03795066862912571


 57%|█████▋    | 57/100 [15:29<11:36, 16.20s/it]

   2023-10-25 18:35:28.804686 batch 1 loss 0.036663297897265815
   2023-10-25 18:35:32.834757 batch 101 loss 0.0365320390577107
   2023-10-25 18:35:36.933862 batch 201 loss 0.03660773484819431
   2023-10-25 18:35:40.981442 batch 301 loss 0.0361329166103812


 58%|█████▊    | 58/100 [15:46<11:20, 16.20s/it]

   2023-10-25 18:35:44.986679 batch 1 loss 0.03397931605714661
   2023-10-25 18:35:49.064460 batch 101 loss 0.03631639169025572
   2023-10-25 18:35:53.108370 batch 201 loss 0.03574661332102882
   2023-10-25 18:35:57.165943 batch 301 loss 0.03612196445298562


 59%|█████▉    | 59/100 [16:02<11:04, 16.21s/it]

   2023-10-25 18:36:01.219779 batch 1 loss 0.03647280758948911
   2023-10-25 18:36:05.319918 batch 101 loss 0.03607320816843266
   2023-10-25 18:36:09.331843 batch 201 loss 0.03607582572812345
   2023-10-25 18:36:13.386296 batch 301 loss 0.0362937253122904


 60%|██████    | 60/100 [16:18<10:47, 16.20s/it]

   2023-10-25 18:36:17.398008 batch 1 loss 0.040476280397640896
   2023-10-25 18:36:21.463675 batch 101 loss 0.034670838008779604
   2023-10-25 18:36:25.485745 batch 201 loss 0.03664535812380144
   2023-10-25 18:36:29.493548 batch 301 loss 0.037071356789753644


 61%|██████    | 61/100 [16:34<10:30, 16.18s/it]

   2023-10-25 18:36:33.518462 batch 1 loss 0.036327137652470864
   2023-10-25 18:36:37.612846 batch 101 loss 0.03581105407039148
   2023-10-25 18:36:41.650986 batch 201 loss 0.035973760156331254
   2023-10-25 18:36:45.649564 batch 301 loss 0.03722946446119607


 62%|██████▏   | 62/100 [16:50<10:14, 16.17s/it]

   2023-10-25 18:36:49.666363 batch 1 loss 0.03363628117725384
   2023-10-25 18:36:53.714544 batch 101 loss 0.03612089616179371
   2023-10-25 18:36:57.732426 batch 201 loss 0.03651234600229585
   2023-10-25 18:37:01.779461 batch 301 loss 0.037717166845523044


 63%|██████▎   | 63/100 [17:06<09:57, 16.15s/it]

   2023-10-25 18:37:05.775712 batch 1 loss 0.03631438815832239
   2023-10-25 18:37:09.855126 batch 101 loss 0.03639726214414802
   2023-10-25 18:37:13.875031 batch 201 loss 0.03538456445511562
   2023-10-25 18:37:17.864703 batch 301 loss 0.03714757965854788


 64%|██████▍   | 64/100 [17:23<09:42, 16.17s/it]

   2023-10-25 18:37:21.986222 batch 1 loss 0.03233522553122117
   2023-10-25 18:37:26.009546 batch 101 loss 0.03547284805920569
   2023-10-25 18:37:30.017049 batch 201 loss 0.035582230187028635
   2023-10-25 18:37:34.073436 batch 301 loss 0.036772423843605004


 65%|██████▌   | 65/100 [17:39<09:25, 16.15s/it]

   2023-10-25 18:37:38.108120 batch 1 loss 0.03439839203572141
   2023-10-25 18:37:42.262680 batch 101 loss 0.035606532809023456
   2023-10-25 18:37:46.262541 batch 201 loss 0.03654976893488608
   2023-10-25 18:37:50.258661 batch 301 loss 0.035916326659816304


 66%|██████▌   | 66/100 [17:55<09:09, 16.18s/it]

   2023-10-25 18:37:54.331743 batch 1 loss 0.03362658371137215
   2023-10-25 18:37:58.395986 batch 101 loss 0.03675101409666987
   2023-10-25 18:38:02.432992 batch 201 loss 0.036042136386343376
   2023-10-25 18:38:06.456958 batch 301 loss 0.034879325058309


 67%|██████▋   | 67/100 [18:11<08:54, 16.19s/it]

   2023-10-25 18:38:10.542716 batch 1 loss 0.03377475683552056
   2023-10-25 18:38:14.576936 batch 101 loss 0.03528754483901058
   2023-10-25 18:38:18.587642 batch 201 loss 0.03624102577716584
   2023-10-25 18:38:22.624563 batch 301 loss 0.035748563811155654


 68%|██████▊   | 68/100 [18:27<08:37, 16.17s/it]

   2023-10-25 18:38:26.681045 batch 1 loss 0.03751431547523783
   2023-10-25 18:38:30.692256 batch 101 loss 0.034725741684187086
   2023-10-25 18:38:34.725148 batch 201 loss 0.03533925649356243
   2023-10-25 18:38:38.754056 batch 301 loss 0.03513355547340315


 69%|██████▉   | 69/100 [18:43<08:21, 16.19s/it]

   2023-10-25 18:38:42.905877 batch 1 loss 0.033117262318319594
   2023-10-25 18:38:46.925711 batch 101 loss 0.03523857369041477
   2023-10-25 18:38:50.975073 batch 201 loss 0.034937685986894555
   2023-10-25 18:38:55.004504 batch 301 loss 0.03600380486542893


 70%|███████   | 70/100 [19:00<08:05, 16.17s/it]

   2023-10-25 18:38:59.028871 batch 1 loss 0.03500341691316009
   2023-10-25 18:39:03.108148 batch 101 loss 0.03572872953699582
   2023-10-25 18:39:07.133490 batch 201 loss 0.035885035060473885
   2023-10-25 18:39:11.138829 batch 301 loss 0.03508152183321401


 71%|███████   | 71/100 [19:16<07:48, 16.16s/it]

   2023-10-25 18:39:15.180348 batch 1 loss 0.03905447212892707
   2023-10-25 18:39:19.187589 batch 101 loss 0.034864293531105706
   2023-10-25 18:39:23.220548 batch 201 loss 0.035041407845045246
   2023-10-25 18:39:27.294274 batch 301 loss 0.03557354326461527


 72%|███████▏  | 72/100 [19:32<07:32, 16.15s/it]

   2023-10-25 18:39:31.310030 batch 1 loss 0.033363023063192894
   2023-10-25 18:39:35.336556 batch 101 loss 0.034947935638928035
   2023-10-25 18:39:39.385294 batch 201 loss 0.036277275414990026
   2023-10-25 18:39:43.425537 batch 301 loss 0.03651100201101442


 73%|███████▎  | 73/100 [19:48<07:16, 16.17s/it]

   2023-10-25 18:39:47.519032 batch 1 loss 0.032828607608120834
   2023-10-25 18:39:51.588156 batch 101 loss 0.034149976648397
   2023-10-25 18:39:55.678479 batch 201 loss 0.0356411299465224
   2023-10-25 18:39:59.740697 batch 301 loss 0.035565047216523364


 74%|███████▍  | 74/100 [20:04<07:01, 16.21s/it]

   2023-10-25 18:40:03.832353 batch 1 loss 0.032519295041660304
   2023-10-25 18:40:07.938913 batch 101 loss 0.03380509292654433
   2023-10-25 18:40:12.041238 batch 201 loss 0.0345208887710206
   2023-10-25 18:40:16.137218 batch 301 loss 0.03380073144176708


 75%|███████▌  | 75/100 [20:21<06:46, 16.26s/it]

   2023-10-25 18:40:20.199174 batch 1 loss 0.035989860958061004
   2023-10-25 18:40:24.294669 batch 101 loss 0.034339626856081065
   2023-10-25 18:40:28.411605 batch 201 loss 0.03411130251127516
   2023-10-25 18:40:32.498034 batch 301 loss 0.033987835357392186


 76%|███████▌  | 76/100 [20:37<06:31, 16.29s/it]

   2023-10-25 18:40:36.578256 batch 1 loss 0.03252435010394306
   2023-10-25 18:40:40.649242 batch 101 loss 0.03427643374931143
   2023-10-25 18:40:44.724420 batch 201 loss 0.03505469371725637
   2023-10-25 18:40:48.811686 batch 301 loss 0.03606984074172948


 77%|███████▋  | 77/100 [20:53<06:14, 16.29s/it]

   2023-10-25 18:40:52.859381 batch 1 loss 0.0318253255364253
   2023-10-25 18:40:57.084640 batch 101 loss 0.034414105021301614
   2023-10-25 18:41:01.181008 batch 201 loss 0.03455846717080435
   2023-10-25 18:41:05.261019 batch 301 loss 0.03474906846768249


 78%|███████▊  | 78/100 [21:10<05:59, 16.33s/it]

   2023-10-25 18:41:09.288162 batch 1 loss 0.03403894695443781
   2023-10-25 18:41:13.360238 batch 101 loss 0.03357511496929846
   2023-10-25 18:41:17.382969 batch 201 loss 0.03491168385476312
   2023-10-25 18:41:21.474303 batch 301 loss 0.03465339739612377


 79%|███████▉  | 79/100 [21:26<05:42, 16.30s/it]

   2023-10-25 18:41:25.505505 batch 1 loss 0.03650927864031808
   2023-10-25 18:41:29.608733 batch 101 loss 0.03507094185326295
   2023-10-25 18:41:33.678072 batch 201 loss 0.03405001396299697
   2023-10-25 18:41:37.741457 batch 301 loss 0.035311022550039345


 80%|████████  | 80/100 [21:42<05:26, 16.31s/it]

   2023-10-25 18:41:41.850930 batch 1 loss 0.03346206162335019
   2023-10-25 18:41:45.910003 batch 101 loss 0.03338159227063437
   2023-10-25 18:41:49.998501 batch 201 loss 0.03346656455220341
   2023-10-25 18:41:54.031863 batch 301 loss 0.035446302830364915


 81%|████████  | 81/100 [21:59<05:09, 16.29s/it]

   2023-10-25 18:41:58.102360 batch 1 loss 0.030949168184429606
   2023-10-25 18:42:02.185972 batch 101 loss 0.03523517991411031
   2023-10-25 18:42:06.211082 batch 201 loss 0.034550328319892465
   2023-10-25 18:42:10.254376 batch 301 loss 0.034537028049599995


 82%|████████▏ | 82/100 [22:15<04:53, 16.28s/it]

   2023-10-25 18:42:14.352202 batch 1 loss 0.029723756141369406
   2023-10-25 18:42:18.387502 batch 101 loss 0.0331803019455338
   2023-10-25 18:42:22.558785 batch 201 loss 0.034082279851003275
   2023-10-25 18:42:26.584047 batch 301 loss 0.03561703504378284


 83%|████████▎ | 83/100 [22:31<04:37, 16.30s/it]

   2023-10-25 18:42:30.713474 batch 1 loss 0.03300703496042036
   2023-10-25 18:42:34.844287 batch 101 loss 0.0343949770549184
   2023-10-25 18:42:38.944461 batch 201 loss 0.034361033979935916
   2023-10-25 18:42:43.014828 batch 301 loss 0.03442864807676323


 84%|████████▍ | 84/100 [22:48<04:21, 16.35s/it]

   2023-10-25 18:42:47.155738 batch 1 loss 0.03460310742124334
   2023-10-25 18:42:51.230753 batch 101 loss 0.033202660332150025
   2023-10-25 18:42:55.294147 batch 201 loss 0.03374914508507868
   2023-10-25 18:42:59.374422 batch 301 loss 0.03457200752731887


 85%|████████▌ | 85/100 [23:04<04:05, 16.34s/it]

   2023-10-25 18:43:03.485225 batch 1 loss 0.03303978020958784
   2023-10-25 18:43:07.556770 batch 101 loss 0.03325887584177956
   2023-10-25 18:43:11.601779 batch 201 loss 0.03416466125543332
   2023-10-25 18:43:15.612547 batch 301 loss 0.03421392961445458


 86%|████████▌ | 86/100 [23:20<03:48, 16.30s/it]

   2023-10-25 18:43:19.685710 batch 1 loss 0.033724459094688256
   2023-10-25 18:43:23.788387 batch 101 loss 0.03299118909363471
   2023-10-25 18:43:27.791562 batch 201 loss 0.03490136134365917
   2023-10-25 18:43:31.928645 batch 301 loss 0.034771850694634704


 87%|████████▋ | 87/100 [23:37<03:31, 16.29s/it]

   2023-10-25 18:43:35.940407 batch 1 loss 0.03103473644164617
   2023-10-25 18:43:39.959521 batch 101 loss 0.034339243768131085
   2023-10-25 18:43:43.988761 batch 201 loss 0.03429345049898885
   2023-10-25 18:43:47.996516 batch 301 loss 0.03473901359596343


 88%|████████▊ | 88/100 [23:53<03:14, 16.23s/it]

   2023-10-25 18:43:52.029112 batch 1 loss 0.03226188335437645
   2023-10-25 18:43:56.182177 batch 101 loss 0.03387858505560761
   2023-10-25 18:44:00.232489 batch 201 loss 0.03347290295971637
   2023-10-25 18:44:04.292138 batch 301 loss 0.03307292507756015


 89%|████████▉ | 89/100 [24:09<02:58, 16.24s/it]

   2023-10-25 18:44:08.314937 batch 1 loss 0.03308469995692511
   2023-10-25 18:44:12.327160 batch 101 loss 0.03400486080724475
   2023-10-25 18:44:16.317157 batch 201 loss 0.03311589418811553
   2023-10-25 18:44:20.367337 batch 301 loss 0.03437873315295137


 90%|█████████ | 90/100 [24:25<02:42, 16.20s/it]

   2023-10-25 18:44:24.418411 batch 1 loss 0.03145837600473385
   2023-10-25 18:44:28.492921 batch 101 loss 0.03264679105086892
   2023-10-25 18:44:32.598494 batch 201 loss 0.034877819595241874
   2023-10-25 18:44:36.600123 batch 301 loss 0.03370669332313222


 91%|█████████ | 91/100 [24:41<02:25, 16.20s/it]

   2023-10-25 18:44:40.617317 batch 1 loss 0.031088975814142324
   2023-10-25 18:44:44.619821 batch 101 loss 0.03387284934021792
   2023-10-25 18:44:48.609719 batch 201 loss 0.034341341178616064
   2023-10-25 18:44:52.631385 batch 301 loss 0.03466501152001032


 92%|█████████▏| 92/100 [24:57<02:09, 16.15s/it]

   2023-10-25 18:44:56.636222 batch 1 loss 0.030949051375362885
   2023-10-25 18:45:00.696015 batch 101 loss 0.03355267651611703
   2023-10-25 18:45:04.915978 batch 201 loss 0.03418652766058161
   2023-10-25 18:45:08.935230 batch 301 loss 0.03334804253542714


 93%|█████████▎| 93/100 [25:14<01:53, 16.21s/it]

   2023-10-25 18:45:12.989434 batch 1 loss 0.03070429241966464
   2023-10-25 18:45:16.998158 batch 101 loss 0.033338083698948945
   2023-10-25 18:45:21.035216 batch 201 loss 0.03346828872806759
   2023-10-25 18:45:25.047860 batch 301 loss 0.03405896785719889


 94%|█████████▍| 94/100 [25:30<01:37, 16.22s/it]

   2023-10-25 18:45:29.232963 batch 1 loss 0.03047275872054417
   2023-10-25 18:45:33.322992 batch 101 loss 0.03273399522092772
   2023-10-25 18:45:37.341885 batch 201 loss 0.033412047209496955
   2023-10-25 18:45:41.430629 batch 301 loss 0.034415628087095904


 95%|█████████▌| 95/100 [25:46<01:21, 16.22s/it]

   2023-10-25 18:45:45.468100 batch 1 loss 0.033666779755151865
   2023-10-25 18:45:49.500409 batch 101 loss 0.032254290786798
   2023-10-25 18:45:53.521917 batch 201 loss 0.03302708333437089
   2023-10-25 18:45:57.585954 batch 301 loss 0.03460171402745815


 96%|█████████▌| 96/100 [26:02<01:05, 16.27s/it]

   2023-10-25 18:46:01.851928 batch 1 loss 0.027862511610378657
   2023-10-25 18:46:05.960921 batch 101 loss 0.033468735813396995
   2023-10-25 18:46:09.974093 batch 201 loss 0.03216469153614425
   2023-10-25 18:46:14.033596 batch 301 loss 0.03296266177898056


 97%|█████████▋| 97/100 [26:19<00:48, 16.25s/it]

   2023-10-25 18:46:18.065250 batch 1 loss 0.029468442640252824
   2023-10-25 18:46:22.150953 batch 101 loss 0.03290932472369486
   2023-10-25 18:46:26.147424 batch 201 loss 0.03370581462434756
   2023-10-25 18:46:30.139567 batch 301 loss 0.03412388266822143


 98%|█████████▊| 98/100 [26:35<00:32, 16.28s/it]

   2023-10-25 18:46:34.391515 batch 1 loss 0.030857259624481298
   2023-10-25 18:46:38.694564 batch 101 loss 0.032198381246665096
   2023-10-25 18:46:42.859210 batch 201 loss 0.03327277193791083
   2023-10-25 18:46:46.930307 batch 301 loss 0.033411353198106494


 99%|█████████▉| 99/100 [26:52<00:16, 16.40s/it]

   2023-10-25 18:46:51.065061 batch 1 loss 0.030924108633544268
   2023-10-25 18:46:55.098419 batch 101 loss 0.03223530343964418
   2023-10-25 18:46:59.320140 batch 201 loss 0.03221924849145485
   2023-10-25 18:47:03.806772 batch 301 loss 0.03275228829756491


100%|██████████| 100/100 [27:09<00:00, 16.29s/it]


In [13]:
%%time
# Run inference with the trained model over `test_dataloader` (both defined in
# earlier cells outside this view). Judging by how the results are consumed in
# the next cell, this presumably returns one track id per test row plus a
# matching row of per-tag scores. NOTE(review): confirm against `predict`.
track_idxs, predictions = predict(model, test_dataloader)

CPU times: user 4.56 s, sys: 19.9 ms, total: 4.58 s
Wall time: 4.58 s


In [14]:
# Build the submission frame: one row per test track, with that track's
# per-tag scores serialized as a single comma-separated string in the
# `prediction` column.
# zip() silently stops at the shorter input, so a length mismatch between the
# ids and the score rows would quietly drop tracks from the submission.
assert len(track_idxs) == len(predictions), (
    f'{len(track_idxs)} track ids vs {len(predictions)} prediction rows'
)
predictions_df = pd.DataFrame([
    {'track': track, 'prediction': ','.join(map(str, probs))}
    for track, probs in zip(track_idxs, predictions)
])
# Every row of the test set must be present in the submission.
assert len(predictions_df) == len(df_test), (
    f'submission has {len(predictions_df)} rows, test set has {len(df_test)}'
)

In [15]:
# Write the submission CSV, omitting the pandas index column.
# The filename appears to encode the run's batch size and epoch count.
# NOTE(review): keep it in sync with the training configuration by hand.
submit_file = 'submit_batch-128_epochs-100.csv'
predictions_df.to_csv(path_or_buf=submit_file, header=True, index=False)

In [16]:
from IPython.display import FileLink

# The cell's last expression renders as a clickable download link to the submission.
FileLink(path=submit_file)