In [1]:
#!g1.1
import numpy as np
import pandas as pd
import os
import random
from tqdm.auto import tqdm
!export CUBLAS_WORKSPACE_CONFIG=':4096:8'
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_grad_enabled(False)
nw = 8

BASE = '/home/jupyter/mnt/datasets/full_dataset/'
meta = pd.read_csv(os.path.join(BASE, 'train_meta.tsv'), sep='\t')
train_meta = meta

In [3]:
#!g1.1
checkpoints = [
    'att2net_fold_0_32_epoch_0.55146_ndcg.pth',
    'att2net_fold_1_30_epoch_0.53716_ndcg.pth',
    'att2net_fold_2_35_epoch_0.54485_ndcg.pth',
    'att2net_fold_3_34_epoch_0.53384_ndcg.pth',
    'att2net_fold_4_35_epoch_0.53940_ndcg.pth',
    'att2net_fold_5_31_epoch_0.53342_ndcg.pth',
    'att2net_fold_6_32_epoch_0.55152_ndcg.pth',
    'att2net_fold_7_32_epoch_0.54322_ndcg.pth',
    'att2net_fold_8_32_epoch_0.54593_ndcg.pth',
    'att2net_fold_9_32_epoch_0.54423_ndcg.pth'
] #for ensembling

model_bank = []

from simple_network import Net

for ckp in checkpoints:
    cur_model = Net(input_dim=512, emb_dim=1024); cur_model.eval(); cur_model.to(DEVICE);
    cur_model.load_state_dict(torch.load(ckp))
    model_bank.append(cur_model)
    
#basically what follows is the same ASNorm inference, but with multiple models

In [4]:
#!g1.1
from torch.utils.data import Dataset, DataLoader
from datasets import CohortDataset

cohort_size = 10000
counts = train_meta.groupby('artistid').count()
cohort_idx = np.random.permutation(counts[counts.trackid == 10].index.values)[:cohort_size]
    
cohort_dataset = CohortDataset(train_meta, os.path.join(BASE, 'train_features'), cohort_idx)
cohort_loader = DataLoader(cohort_dataset, batch_size=64, shuffle=False, num_workers=nw, pin_memory=True, drop_last=False)

In [5]:
#!g1.1
cohort_embs = [[] for i in range(len(model_bank))]
for X, y in tqdm(cohort_loader, desc='collect'):
    X = X.to(DEVICE); y = y.to(DEVICE);
    b, n_per_spk, feat, seq = X.shape
    X = X.reshape(b * n_per_spk, feat, seq)
    for i, model in enumerate(model_bank):
        emb = model(X)
        emb = F.normalize(emb, dim=-1)
        emb = emb.reshape(b, n_per_spk, emb.shape[-1]).mean(dim=1)
        cohort_embs[i].append(emb)
        
for i in range(len(cohort_embs)):
    cohort_embs[i] = torch.cat(cohort_embs[i], dim=0)
    cohort_embs[i] = F.normalize(cohort_embs[i], dim=-1)
    
incohort_mean = [[] for i in range(len(model_bank))]
incohort_std = [[] for i in range(len(model_bank))]

for NUM in tqdm(range(len(cohort_embs)), desc='score'):
    for i, cohort_emb in tqdm(enumerate(cohort_embs[NUM]), desc=f'score: {NUM}'):
        incohort_scores = torch.matmul(cohort_embs[NUM], cohort_emb)
        incohort_scores = torch.cat([incohort_scores[:i], incohort_scores[i+1:]], dim=0)
        incohort_mean[NUM].append(incohort_scores.mean())
        incohort_std[NUM].append(incohort_scores.std())
    incohort_mean[NUM] = torch.stack(incohort_mean[NUM]).squeeze()
    incohort_std[NUM] = torch.stack(incohort_std[NUM]).squeeze()

HBox(children=(HTML(value='collect'), FloatProgress(value=0.0, max=157.0), HTML(value='')))

HBox(children=(HTML(value='score'), FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(HTML(value='score: 0'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'),…

HBox(children=(HTML(value='score: 1'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'),…

HBox(children=(HTML(value='score: 2'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'),…

HBox(children=(HTML(value='score: 3'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'),…

HBox(children=(HTML(value='score: 4'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'),…

HBox(children=(HTML(value='score: 5'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'),…

HBox(children=(HTML(value='score: 6'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'),…

HBox(children=(HTML(value='score: 7'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'),…

HBox(children=(HTML(value='score: 8'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'),…

HBox(children=(HTML(value='score: 9'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'),…

HBox(children=(HTML(value='score: 10'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px')…

HBox(children=(HTML(value='score: 11'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px')…

HBox(children=(HTML(value='score: 12'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px')…

HBox(children=(HTML(value='score: 13'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px')…

HBox(children=(HTML(value='score: 14'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px')…

HBox(children=(HTML(value='score: 15'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px')…

HBox(children=(HTML(value='score: 16'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px')…

HBox(children=(HTML(value='score: 17'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px')…

HBox(children=(HTML(value='score: 18'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px')…

HBox(children=(HTML(value='score: 19'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px')…

























In [6]:
#!g1.1
from datasets import TestDataset

test_meta = pd.read_csv(os.path.join(BASE, 'test_meta.tsv'), sep='\t')
test_dataset = TestDataset(test_meta, os.path.join(BASE, 'test_features'))
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=nw, drop_last=False, pin_memory=True)

In [7]:
#!g1.1
KNN = 400

all_embs = [[] for i in range(len(model_bank))]

for X, y in tqdm(test_loader, desc='test collect', leave=False):
    n_per_spk = X.shape[1]
    X = X.to(DEVICE); y = y.to(DEVICE);
    X = X.reshape(X.shape[0] * X.shape[1], X.shape[2], X.shape[3])
    for i, model in enumerate(model_bank):
        emb = model(X)
        emb = emb.reshape(emb.shape[0] // n_per_spk, n_per_spk, -1)
        emb = F.normalize(emb, dim=-1).mean(dim=1)
        all_embs[i].append(emb)

for i in range(len(all_embs)):
    all_embs[i] = torch.cat(all_embs[i], dim=0)
    all_embs[i] = F.normalize(all_embs[i], dim=-1)


HBox(children=(HTML(value='test collect'), FloatProgress(value=0.0, max=162.0), HTML(value='')))

In [8]:
#!g1.1
snorm_mean = [[] for i in range(len(model_bank))]
snorm_std = [[] for i in range(len(model_bank))]
nnk_means = [[] for i in range(len(model_bank))]
nnk_stds = [[] for i in range(len(model_bank))]

for NUM in range(len(all_embs)):
    for i, emb in tqdm(enumerate(all_embs[NUM]), total=len(all_embs[NUM]), leave=False, desc=f'snorm: {NUM}'):
        cohort_scores = torch.matmul(cohort_embs[NUM], emb)
        snorm_mean[NUM].append(cohort_scores.mean())
        snorm_std[NUM].append(cohort_scores.std())
        nnk = torch.argsort(cohort_scores, descending=True)[:KNN]
        sk = (cohort_scores[nnk] - cohort_scores.mean()) / (cohort_scores.std()) + (cohort_scores[nnk] - incohort_mean[NUM][nnk]) / incohort_std[NUM][nnk]
        nnk_means[NUM].append(sk.mean())
        nnk_stds[NUM].append(sk.std())

for i in range(len(snorm_mean)):
    snorm_mean[i] = torch.stack(snorm_mean[i]).squeeze()
    snorm_std[i] = torch.stack(snorm_std[i]).squeeze()
    nnk_means[i] = torch.stack(nnk_means[i]).squeeze()
    nnk_stds[i] = torch.stack(nnk_stds[i]).squeeze()
    
ANS = []
for i in tqdm(range(len(all_embs[0])), total=len(all_embs[0]), leave=False):
    scores_forthisobj = 0
    for j in range(len(all_embs)):
        scores_cur = torch.matmul(all_embs[j], all_embs[j][i])
        snorm_scores_cur = (scores_cur - snorm_mean[j][i]) / (snorm_std[j][i]) + (scores_cur - snorm_mean[j]) / snorm_std[j]
        ckd_scores_cur = (snorm_scores_cur - nnk_means[j][i]) / nnk_stds[j][i] + (snorm_scores_cur - nnk_means[j]) / nnk_stds[j]
        snorm_a = 0.3
        scores_cur = snorm_a * snorm_scores_cur + (1 - snorm_a) * ckd_scores_cur
        assert scores_cur.dim() == 1
        scores_cur[i] = -1000
        scores_forthisobj = scores_forthisobj + scores_cur
        
    scores_forthisobj = scores_forthisobj / len(all_embs) #average over models
    scores_forthisobj = torch.argsort(scores_forthisobj, descending=True)[:100].cpu().numpy()
    ANS.append(scores_forthisobj)


HBox(children=(HTML(value='snorm: 0'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 1'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 2'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 3'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 4'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 5'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 6'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 7'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 8'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 9'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 10'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 11'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 12'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 13'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 14'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 15'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 16'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 17'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 18'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value='snorm: 19'), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41377.0), HTML(value='')))

In [9]:
#!g1.1
with open('submission_att2net_10fold.txt', 'w') as f:
    for i, ans_row in tqdm(enumerate(ANS), total=len(ANS)):
        test_row = test_meta.iloc[i]
        fstring = str(test_row.trackid) + '\t' + ' '.join([str(a) for a in list(test_meta.iloc[ans_row].trackid)]) + '\n'
        f.write(fstring)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41377.0), HTML(value='')))




In [None]:
#!g1.1
