In [1]:
#!g1.1
import numpy as np
import pandas as pd
import os
import random
from tqdm.auto import tqdm
# Required by torch.use_deterministic_algorithms(True) for cuBLAS-backed CUDA ops.
# `!export ...` only sets the variable in a short-lived subshell, never in this
# kernel process, so set it here -- before any CUDA/cuBLAS work happens.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

In [2]:
#!g1.1
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
nw = 8  # DataLoader worker processes (passed as num_workers)

def set_seed(seed=777):
    """Seed every RNG in use and force deterministic cuDNN / CUDA kernels.

    Args:
        seed: value fed to Python's ``random``, NumPy's global RNG and torch
            (CPU generator and all CUDA devices).
    """
    # torch.use_deterministic_algorithms(True) makes cuBLAS-backed CUDA ops raise
    # a RuntimeError unless CUBLAS_WORKSPACE_CONFIG is set in *this* process before
    # the first CUDA/cuBLAS call. A notebook `!export` only affects a subshell, so
    # set it here; setdefault keeps any value the user already chose.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed()  # reproducibility

In [3]:
#!g1.1
FOLD = '9'  # this train fold is held out for validation
BASE = '/home/jupyter/mnt/datasets/full_dataset/'  # dataset root (hard-coded absolute path)

meta = pd.read_csv(os.path.join(BASE, 'train_meta.tsv'), sep='\t')
# The first component of archive_features_path is the fold id.
fold_of_row = meta.archive_features_path.str.split('/').str[0]
val_subset = fold_of_row == FOLD
train_meta, val_meta = meta[~val_subset], meta[val_subset]

# Distinct artists (NaN counted once, like len(unique())) plus one.
# NOTE(review): the reason for the extra class is not visible here -- confirm
# against the label encoding in ArtistDataset.
n_classes = meta.artistid.nunique(dropna=False) + 1
batch_size = 3000

In [4]:
#!g1.1
FOLDS = list('1234567890')  # all fold ids ('0' last); not referenced elsewhere in this notebook

In [5]:
#!g1.1
from torch.utils.data import Dataset, DataLoader
from datasets import ArtistDataset

train_features_dir = os.path.join(BASE, 'train_features')

# Training: large shuffled batches; the last incomplete batch is dropped.
train_dataset = ArtistDataset(train_meta, train_features_dir, train=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=nw,
    pin_memory=True,
    drop_last=True,
)

# Validation: fixed order, every sample kept.
val_dataset = ArtistDataset(val_meta, train_features_dir, train=False)
val_loader = DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=nw,
    pin_memory=True,
    drop_last=False,
)

In [6]:
#!g1.1
from simple_network import Net
model = Net(emb_dim=1024, input_dim=512)  # embedding network: input_dim=512 -> emb_dim=1024

In [7]:
#!g1.1
from arcface import ArcFace
# Additive-angular-margin (ArcFace / AAMSoftmax) head over the 1024-d embeddings.
# The margin starts at 0.0; the training loop later reassigns `criterion.m`.
criterion = ArcFace(
    nClasses=n_classes,
    nOut=1024,
    scale=15,
    margin=0.0,
)

Initialised AAMSoftmax margin 0.000 scale 15.000


In [8]:
#!g1.1
def set_margin(step, low=1000, high=4500, low_m=0.0, high_m=0.5):
    """Return the ArcFace margin for an optimisation step (margin curriculum).

    The margin is gradually increased: it is held at ``low_m`` up to and
    including step ``low``, held at ``high_m`` from step ``high`` onwards, and
    linearly interpolated in between.

    Args:
        step: current optimisation step.
        low, high: step boundaries of the linear ramp.
        low_m, high_m: margin values at the start and end of the ramp.

    Returns:
        The margin to use at ``step``.
    """
    if step <= low:
        return low_m
    if step >= high:
        return high_m
    progress = (step - low) / (high - low)
    return (1 - progress) * low_m + progress * high_m

In [9]:
#!g1.1
# Move the embedding network and the ArcFace head to the target device and
# mark every parameter of both as trainable.
for module in (model, criterion):
    module.to(DEVICE)
    module.requires_grad_(True)

In [10]:
#!g1.1
n_epochs = 100

# Optimise the backbone and the ArcFace head (its class weights) jointly.
trainable = list(model.parameters()) + list(criterion.parameters())
# NOTE: OneCycleLR overwrites this lr when it is constructed (initial lr =
# max_lr / div_factor = 3e-5), so lr=1e-3 here never takes effect.
optimizer = optim.Adam(trainable, lr=1e-3, weight_decay=1e-4)

# One-cycle schedule sized for n_epochs * len(train_loader) optimizer steps;
# it must be stepped once per batch.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-4,
    epochs=n_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,
    div_factor=10,
    final_div_factor=1000,
)

In [11]:
#!g1.1
from compute_score import eval_submission
# Stage 1: train the embedding network with the ArcFace head.
# Mutates module-level state that later cells read:
#   STEP       - global optimisation-step counter that drives set_margin()
#   last_ndcg  - best validation nDCG so far (starts at a 0.52 save threshold)
#   last_fname - path of the best checkpoint; stays None if no epoch beats 0.52,
#                in which case later `torch.load(last_fname)` calls fail.
STEP = 0
last_ndcg = 0.52 #initial threshold for checkpoint saving
last_fname = None
for EPOCH in range(n_epochs):
    model.train()
    train_loss = []
    train_acc = []
    for X, y in tqdm(train_loader, desc=f'Train epoch {EPOCH+1}/{n_epochs}', leave=False):
        optimizer.zero_grad()
        
        X = X.to(DEVICE); y = y.to(DEVICE);
        emb = model(X)
        # The ArcFace head returns (loss, accuracy) for the batch.
        cur_loss, cur_acc = criterion(emb, y)
        cur_loss.backward()
        optimizer.step()
        
        train_loss.append(cur_loss.item())
        train_acc.append(cur_acc.item())
        
        # OneCycleLR was sized with steps_per_epoch=len(train_loader): step per batch.
        scheduler.step()
        STEP += 1
        new_marg = set_margin(STEP)
        # NOTE(review): this only rebinds the `m` attribute. If ArcFace caches values
        # derived from the margin in __init__ (e.g. cos(m), sin(m)), the curriculum has
        # no effect on the loss -- confirm in arcface.py.
        criterion.m = new_marg
                
    print(f'Train loss: {np.mean(train_loss):.4f}, train acc: {np.mean(train_acc):.4f}')
    
    model.eval()
    # Validate only from the 21st epoch on (EPOCH is 0-based).
    if EPOCH > 19:
        with torch.no_grad():
            all_embs = []
            for X, y in tqdm(val_loader, desc=f'Val epoch {EPOCH+1}/{n_epochs}', leave=False):
                X = X.to(DEVICE); y = y.to(DEVICE);
                X = X[:, 0, :, :] #discard TTA - 1st object is original sample
                emb = model(X)
                all_embs.append(emb)

            all_embs = torch.cat(all_embs, dim=0)
            # Unit-normalise so the dot products below are cosine similarities.
            all_embs = F.normalize(all_embs, dim=-1)
            ans = []
            # Brute-force retrieval: score each validation track against all others.
            for i, emb in enumerate(all_embs):
                scores = torch.matmul(all_embs, emb)
                assert scores.dim() == 1
                # Cosine similarity is >= -1, so -10 pushes the query itself to the bottom.
                scores[i] = -10
                scores = torch.argsort(scores, descending=True)[:100].cpu().numpy() #top-100 candidates
                ans.append(scores)
            # {query trackid: [up to 100 candidate trackids]} from positional row indices.
            ans_dict = {val_meta.iloc[i].trackid: list(val_meta.iloc[row].trackid) for i, row in enumerate(ans)} #convert indices to IDs
            ndcg = eval_submission(ans_dict, val_meta, 100)
            print(f'Epoch: {EPOCH+1}/{n_epochs}, Val nDCG: {ndcg:.5f}')

            # Checkpoint only when validation nDCG improves on the best so far.
            if ndcg > last_ndcg:
                last_ndcg = ndcg
                last_fname = f'att2net_fold_{FOLD}_{EPOCH}_epoch_{ndcg:.5f}_ndcg.pth'
                torch.save(model.state_dict(), last_fname)

    print('-'*40)
    # Hard stop after 36 epochs (EPOCH index 35). NOTE(review): the OneCycleLR schedule
    # was built for n_epochs=100, so it is only ~36% complete here and the LR never
    # reaches its final annealed value.
    if EPOCH == 35: #empirically, this is enough for convergence
        break

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

Train loss: 9.7385, train acc: 0.0400
----------------------------------------
Train loss: 9.1923, train acc: 0.4653
----------------------------------------
Train loss: 8.5340, train acc: 2.3893
----------------------------------------
Train loss: 7.7344, train acc: 7.7420
----------------------------------------
Train loss: 6.8525, train acc: 17.1667
----------------------------------------
Train loss: 5.9623, train acc: 29.6400
----------------------------------------
Train loss: 5.1394, train acc: 42.0267
----------------------------------------
Train loss: 4.4191, train acc: 53.0293
----------------------------------------
Train loss: 3.8282, train acc: 61.6687
----------------------------------------
Train loss: 3.3728, train acc: 68.2020
----------------------------------------
Train loss: 3.0307, train acc: 73.0820
----------------------------------------
Train loss: 2.7859, train acc: 76.3960
----------------------------------------
Train loss: 2.6026, train acc: 78.7387
-----

In [12]:
#!g1.1
#STEP2 fine-tune with SmoothAP
from torch.utils.data import Sampler
from datasets import ArtistMultiDataset, ArtistSampler

# Stage 2: fine-tune the stage-1 checkpoint with the SmoothAP ranking loss.
if last_fname is None:
    # Stage 1 saves a checkpoint only when val nDCG beats its initial threshold;
    # fail loudly instead of letting torch.load(None) raise an obscure error.
    raise RuntimeError('No stage-1 checkpoint was saved (validation nDCG never beat the '
                       'save threshold), so there is nothing to fine-tune.')
model.load_state_dict(torch.load(last_fname, map_location=DEVICE))

# NOTE(review): the positional arguments of ArtistSampler are defined in datasets.py --
# presumably 3 tracks per artist and 280 artists per batch (matching batch_size=280 and
# SmoothAP(batch_size=280*3, num_id=280) below); confirm.
train_dataset = ArtistMultiDataset(train_meta, os.path.join(BASE, 'train_features'), train=True)
# The value of next(iter(...)) is discarded; looks like a leftover smoke test.
train_sampler = ArtistSampler(train_dataset, 3, 100, 280); next(iter(train_sampler))
train_loader = DataLoader(train_dataset, batch_size=280, num_workers=nw, pin_memory=True, drop_last=True, sampler=train_sampler)

val_dataset = ArtistDataset(val_meta, os.path.join(BASE, 'train_features'), train=False)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=nw, drop_last=False, pin_memory=True)

from smoothap import SmoothAP
sap = SmoothAP(anneal=0.02, batch_size=280*3, num_id=280, feat_dims=1024)

n_epochs = 5
# Small fixed LR, no scheduler; only the backbone is optimised (the ArcFace head
# is not used in this stage).
optimizer = optim.Adam(model.parameters(), lr=3e-6, weight_decay=1e-5)

for EPOCH in range(n_epochs):
    model.train()
    for X, y in tqdm(train_loader, desc=f'Train epoch {EPOCH+1}/{n_epochs}', leave=False):
        optimizer.zero_grad()
        X = X.to(DEVICE); y = y.to(DEVICE);
        # Flatten the per-artist groups into one batch: (B, K, d1, d2) -> (B*K, d1, d2).
        X = X.reshape(X.shape[0] * X.shape[1], X.shape[2], X.shape[3])
        # y is reshaped to match but never passed to the loss: SmoothAP only receives
        # the embeddings (presumably it infers positives from the batch layout).
        y = y.reshape(y.shape[0] * y.shape[1])
        emb = model(X)
        sapl = sap(emb)
        cur_loss = sapl
        cur_loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        all_embs = []
        for X, y in tqdm(val_loader, desc=f'Val epoch {EPOCH+1}/{n_epochs}', leave=False):
            X = X.to(DEVICE); y = y.to(DEVICE);
            # NOTE(review): stage 1 validates on X[:, 0, :, :] (TTA views discarded), but
            # here the full multi-view batch goes to the model, while the comparison below
            # is against stage 1's best nDCG -- confirm the two protocols are comparable.
            emb = model(X)
            all_embs.append(emb)

        all_embs = torch.cat(all_embs, dim=0)
        # Unit-normalise so dot products are cosine similarities.
        all_embs = F.normalize(all_embs, dim=-1)
        ans = []
        for i, emb in enumerate(all_embs):
            scores = torch.matmul(all_embs, emb)
            assert scores.dim() == 1
            # Cosine similarity is >= -1, so -10 ranks the query itself last.
            scores[i] = -10
            scores = torch.argsort(scores, descending=True)[:100].cpu().numpy()
            ans.append(scores)
        ans_dict = {val_meta.iloc[i].trackid: list(val_meta.iloc[row].trackid) for i, row in enumerate(ans)}
        ndcg = eval_submission(ans_dict, val_meta, 100)
        print(f'Epoch: {EPOCH+1}/{n_epochs}, Val nDCG: {ndcg:.5f}')

        # Checkpoint only when nDCG beats the best so far (carried over from stage 1).
        if ndcg > last_ndcg:
            last_ndcg = ndcg
            last_fname = f'att2net_step2_fold_{FOLD}_{EPOCH}_epoch_{ndcg:.5f}_ndcg.pth'
            torch.save(model.state_dict(), last_fname)

    print('-'*40)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16773.0), HTML(value='')))

Epoch: 1/5, Val nDCG: 0.54599
----------------------------------------
Epoch: 2/5, Val nDCG: 0.54650
----------------------------------------
Epoch: 3/5, Val nDCG: 0.54667
----------------------------------------
Epoch: 4/5, Val nDCG: 0.54697
----------------------------------------




KeyboardInterrupt: 

In [10]:
#!g1.1
# Reload the best checkpoint written by the training cells and switch to inference.
# `last_fname` exists (and is non-None) only if those cells ran in this kernel and
# saved at least one checkpoint; fail with a clear message otherwise instead of a
# NameError or torch.load(None).
ckpt_path = globals().get('last_fname')
if ckpt_path is None:
    raise RuntimeError('No checkpoint path: run the training cells first or set '
                       '`last_fname` to a saved .pth file.')
model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
model.eval();

In [11]:
#!g1.1
#default evaluation

from compute_score import eval_submission
# Baseline retrieval evaluation (raw cosine scores, no score normalisation) using
# all TTA views: each view is embedded and unit-normalised, and the views are averaged.
model.eval()
with torch.no_grad():
    all_embs = []
    for X, y in tqdm(val_loader, desc=f'Val epoch', leave=False):
        # Number of views per track (dim 1 of the batch).
        n_per_spk = X.shape[1]
        X = X.to(DEVICE); y = y.to(DEVICE);
        # Fold the views into the batch dim: (B, V, d1, d2) -> (B*V, d1, d2).
        X = X.reshape(X.shape[0] * X.shape[1], X.shape[2], X.shape[3])
        emb = model(X)
        # Unfold to (B, V, emb_dim), normalise each view, then average over views.
        emb = emb.reshape(emb.shape[0] // n_per_spk, n_per_spk, -1)
        emb = F.normalize(emb, dim=-1).mean(dim=1)
        all_embs.append(emb)

    all_embs = torch.cat(all_embs, dim=0)
    # Re-normalise the averaged embeddings so dot products are cosine similarities.
    all_embs = F.normalize(all_embs, dim=-1)
    ans = []
    for i, emb in tqdm(enumerate(all_embs), leave=False):
        scores = torch.matmul(all_embs, emb)
        assert scores.dim() == 1
        # Cosine similarity is >= -1, so -10 ranks the query itself last.
        scores[i] = -10
        scores = torch.argsort(scores, descending=True)[:100].cpu().numpy()
        ans.append(scores)
    # {query trackid: [up to 100 candidate trackids]}
    ans_dict = {val_meta.iloc[i].trackid: list(val_meta.iloc[row].trackid) for i, row in tqdm(enumerate(ans), leave=False)}
    ndcg = eval_submission(ans_dict, val_meta, 100)
    print(f'Val nDCG: {ndcg:.5f}')



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16748.0), HTML(value='')))

Val nDCG: 0.54124


In [12]:
#!g1.1
#ASNorm evaluation
model.eval()
cohort_size = 10000
# Impostor cohort: training artists with exactly 10 tracks, shuffled with the
# global NumPy RNG and capped at cohort_size.
counts = train_meta.groupby('artistid').count()
has_ten_tracks = counts.trackid == 10
cohort_idx = np.random.permutation(counts.index.values[has_ten_tracks.values])[:cohort_size] #cohort of impostors indices

In [13]:
#!g1.1
from datasets import CohortDataset    

cohort_dataset = CohortDataset(train_meta, os.path.join(BASE, 'train_features'), cohort_idx)
# Fixed order so cohort_embs row k corresponds to the k-th cohort item.
cohort_loader = DataLoader(
    cohort_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=nw,
    pin_memory=True,
    drop_last=False,
)

In [14]:
#!g1.1
with torch.no_grad():
    # Each batch item holds n_per_spk samples (dim 1) -- presumably tracks of one
    # cohort artist. Embed them individually, unit-normalise, and average.
    cohort_chunks = []
    for X, y in tqdm(cohort_loader):
        X, y = X.to(DEVICE), y.to(DEVICE)
        b, n_per_spk, feat, seq = X.shape
        emb = F.normalize(model(X.reshape(b * n_per_spk, feat, seq)), dim=-1)
        cohort_chunks.append(emb.reshape(b, n_per_spk, emb.shape[-1]).mean(dim=1))  # mean emb for impostors
    # Re-normalise the averages so dot products are cosine similarities.
    cohort_embs = F.normalize(torch.cat(cohort_chunks, dim=0), dim=-1)

100%|██████████| 157/157 [00:19<00:00,  8.14it/s]


In [15]:
#!g1.1
#mean and std for each impostor against impostors
# For each impostor: mean and std of its cosine similarity to every *other*
# impostor. These are used later to normalise scores from the impostor's side.
means, stds = [], []
for i, cohort_emb in tqdm(enumerate(cohort_embs)):
    sims = cohort_embs @ cohort_emb
    others = torch.cat((sims[:i], sims[i + 1:]))  # drop the self-similarity
    means.append(others.mean())
    stds.append(others.std())
incohort_mean = torch.stack(means).squeeze()
incohort_std = torch.stack(stds).squeeze()

10000it [00:00, 10608.27it/s]


In [16]:
#!g1.1

from compute_score import eval_submission
# Validation retrieval with adaptive score normalisation (ASNorm) against the
# impostor cohort. KNN = number of most-similar impostors used for the adaptive stats.
KNN = 400
model.eval()
with torch.no_grad():
    all_embs = []
    for X, y in tqdm(val_loader, desc='val', leave=False):
        n_per_spk = X.shape[1]
        X = X.to(DEVICE); y = y.to(DEVICE);
        # (B, V, d1, d2) -> (B*V, d1, d2); embed every view, then average per track.
        X = X.reshape(X.shape[0] * X.shape[1], X.shape[2], X.shape[3])
        emb = model(X)
        emb = emb.reshape(emb.shape[0] // n_per_spk, n_per_spk, -1)
        emb = F.normalize(emb, dim=-1).mean(dim=1)
        all_embs.append(emb)
        
    all_embs = torch.cat(all_embs, dim=0)
    all_embs = F.normalize(all_embs, dim=-1)
    snorm_mean = []
    snorm_std = [] #mean and std for each record against cohort
    nnk_means = []
    nnk_stds = []
    # Per-track statistics against the cohort.
    for i, emb in tqdm(enumerate(all_embs), total=len(all_embs), leave=False):
        cohort_scores = torch.matmul(cohort_embs, emb)
        snorm_mean.append(cohort_scores.mean())
        snorm_std.append(cohort_scores.std()) #s-norm
        nnk = torch.argsort(cohort_scores, descending=True)[:KNN] #nearest impostors
        # Scores to the nearest impostors, z-normalised from the track's side (its
        # cohort mean/std) plus from each impostor's side (its in-cohort mean/std).
        sk = (cohort_scores[nnk] - cohort_scores.mean()) / (cohort_scores.std()) + (cohort_scores[nnk] - incohort_mean[nnk]) / incohort_std[nnk] #statistics for nearest impostors
        nnk_means.append(sk.mean())
        nnk_stds.append(sk.std())
        
    snorm_mean = torch.stack(snorm_mean).squeeze()
    snorm_std = torch.stack(snorm_std).squeeze()
    nnk_means = torch.stack(nnk_means).squeeze()
    nnk_stds = torch.stack(nnk_stds).squeeze()
    ans = []
    for i, emb in tqdm(enumerate(all_embs), total=len(all_embs), leave=False):
        scores = torch.matmul(all_embs, emb)
        # Symmetric S-norm: z-normalise by the query's cohort stats and by each candidate's.
        snorm_scores = (scores - snorm_mean[i]) / (snorm_std[i]) + (scores - snorm_mean) / (snorm_std)
        # Second, symmetric normalisation of the S-normed scores with the nearest-impostor stats.
        ckd_scores = (snorm_scores - nnk_means[i]) / nnk_stds[i] + (snorm_scores - nnk_means) / nnk_stds
        # Blend weight (constant; could be hoisted out of the loop).
        snorm_a = 0.3
        scores = snorm_a * snorm_scores + (1 - snorm_a) * ckd_scores #ASNorm
        assert scores.dim() == 1
        # Normalised scores are not bounded to [-1, 1], hence a larger sentinel than -10
        # to push the query itself to the bottom.
        scores[i] = -1000
        scores = torch.argsort(scores, descending=True)[:100].cpu().numpy()
        ans.append(scores)
        
ans_dict = {}
for i, row in tqdm(enumerate(ans)):
    # query trackid -> list of its ranked candidate trackids
    ans_dict[val_meta.iloc[i].trackid] = list(val_meta.iloc[row].trackid)
eval_submission(ans_dict, val_meta, 100)

16748it [00:10, 1543.18it/s]


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16748.0), HTML(value='')))

0.546785666077093

In [17]:
#!g1.1
from datasets import TestDataset

In [18]:
#!g1.1
test_meta = pd.read_csv(os.path.join(BASE, 'test_meta.tsv'), sep='\t')
test_dataset = TestDataset(test_meta, os.path.join(BASE, 'test_features'))
# Keep file order (shuffle=False) so that result row i lines up with test_meta row i.
test_loader = DataLoader(
    test_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=nw,
    pin_memory=True,
    drop_last=False,
)

In [19]:
#!g1.1
# Test-set retrieval with the same ASNorm scoring as the validation cell above.
# NOTE(review): this duplicates the validation ASNorm code almost line-for-line;
# a shared helper would keep the two from drifting apart.
KNN = 400
model.eval()
with torch.no_grad():
    all_embs = []
    folds = []  # NOTE(review): never used below
    for X, y in tqdm(test_loader, desc='test', leave=False):
        n_per_spk = X.shape[1]
        X = X.to(DEVICE); y = y.to(DEVICE);
        # (B, V, d1, d2) -> (B*V, d1, d2); embed every view, then average per track.
        X = X.reshape(X.shape[0] * X.shape[1], X.shape[2], X.shape[3])
        emb = model(X)
        emb = emb.reshape(emb.shape[0] // n_per_spk, n_per_spk, -1)
        emb = F.normalize(emb, dim=-1).mean(dim=1)
        all_embs.append(emb)
        
    all_embs = torch.cat(all_embs, dim=0)
    all_embs = F.normalize(all_embs, dim=-1)
    
    
    # Per-track statistics against the impostor cohort (plain and nearest-KNN).
    snorm_mean = []
    snorm_std = []
    nnk_means = []
    nnk_stds = []
    for i, emb in tqdm(enumerate(all_embs), total=len(all_embs), leave=False):
        cohort_scores = torch.matmul(cohort_embs, emb)
        snorm_mean.append(cohort_scores.mean())
        snorm_std.append(cohort_scores.std())
        nnk = torch.argsort(cohort_scores, descending=True)[:KNN]
        sk = (cohort_scores[nnk] - cohort_scores.mean()) / (cohort_scores.std()) + (cohort_scores[nnk] - incohort_mean[nnk]) / incohort_std[nnk]
        nnk_means.append(sk.mean())
        nnk_stds.append(sk.std())
        
    snorm_mean = torch.stack(snorm_mean).squeeze()
    snorm_std = torch.stack(snorm_std).squeeze()
    nnk_means = torch.stack(nnk_means).squeeze()
    nnk_stds = torch.stack(nnk_stds).squeeze()
    # ANS[i]: positional row indices (into test_meta) of the top-100 candidates for track i.
    ANS = []
    for i, emb in tqdm(enumerate(all_embs), total=len(all_embs), leave=False):
        scores = torch.matmul(all_embs, emb)
        # Symmetric S-norm, then symmetric normalisation with nearest-impostor stats.
        snorm_scores = (scores - snorm_mean[i]) / (snorm_std[i]) + (scores - snorm_mean) / (snorm_std)
        ckd_scores = (snorm_scores - nnk_means[i]) / nnk_stds[i] + (snorm_scores - nnk_means) / nnk_stds
        snorm_a = 0.3
        scores = snorm_a * snorm_scores + (1 - snorm_a) * ckd_scores
        assert scores.dim() == 1
        # Large negative sentinel so the query itself is ranked last.
        scores[i] = -1000
        scores = torch.argsort(scores, descending=True)[:100].cpu().numpy()
        ANS.append(scores)



In [20]:
#!g1.1
# NOTE(review): the val score embedded in this file name is hard-coded and does not
# match any score computed in this notebook run.
submission_path = 'submission_att2net_snorm10k_2step_0.54113val4.txt'
with open(submission_path, 'w') as f: #dump submission
    # One line per test track: "<trackid>\t<space-separated candidate trackids>".
    for i, ans_row in tqdm(enumerate(ANS), total=len(ANS)):
        query_id = str(test_meta.iloc[i].trackid)
        candidate_ids = ' '.join(str(a) for a in test_meta.iloc[ans_row].trackid)
        f.write(query_id + '\t' + candidate_ids + '\n')

100%|██████████| 41377/41377 [00:27<00:00, 1482.47it/s]


In [None]:
#!g1.1
