In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
# downlaod deepmind's pretrained language model
# !wget -O deepmind_assets/language_perceiver_io_bytes.pickle https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle

In [3]:
from perceiver_io.perceiver_lm import PerceiverLM

import os, sys
import torch
import torch.nn as nn
import transformers

from deepmind_assets import bytes_tokenizer
import numpy as np
import scipy.sparse as sp
from tqdm.notebook import tqdm
import scipy.sparse as sp
import xclib.evaluation.xc_metrics as xc_metrics
from utils import csr_to_pad_tensor, ToD, read_sparse_mat, XCMetrics, _c
from torch.nn.utils.rnn import pad_sequence

# The tokenizer is just UTF-8 encoding (with an offset)
tokenizer = bytes_tokenizer.BytesTokenizer()

In [None]:
command = "--dataset EURLex-4K"

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--project', default='PerceiverIO')
parser.add_argument('--dataset', default='EURLex-4K')
parser.add_argument('--device', type=str, default='cuda:0')

args = parser.parse_args(command.split())

In [40]:
args.expname = args.project
args.maxlen = 2048
args.vocab_size = 262
args.embed_dim = 768
args.num_latents = 256

args.n_epochs = 25
args.xc_lr = 1e-3
args.enc_lr = 1e-4
args.bsz = 32
args.dropout = 0.5
args.warmup = 0.1
args.loss_with_logits = True
args.amp = True
args.eval_interval = 2

args.per_label_task = True
args.per_token_decoder = False

OUT_DIR = f'{args.project}/{args.dataset}'
os.makedirs(OUT_DIR, exist_ok=True)

In [41]:
DATA_DIR = 'Datasets/EURLex-4K'

trnX = [x.strip() for x in open(f'{DATA_DIR}/raw/trn_X.txt').readlines()]
tstX = [x.strip() for x in open(f'{DATA_DIR}/raw/tst_X.txt').readlines()]
trn_X_Y = read_sparse_mat(f'{DATA_DIR}/trn_X_Y.txt', use_xclib=False)
tst_X_Y = read_sparse_mat(f'{DATA_DIR}/tst_X_Y.txt', use_xclib=False)
inv_prop = xc_metrics.compute_inv_propesity(trn_X_Y, 0.55, 1.5)

args.numy = trn_X_Y.shape[1]

15449it [00:00, 181291.64it/s]
3865it [00:00, 188084.29it/s]


In [42]:
encoder = PerceiverLM(vocab_size=args.vocab_size, 
                      max_seq_len=args.maxlen, 
                      embedding_dim=args.embed_dim, 
                      num_latents=args.num_latents, 
                      latent_dim=1280, 
                      qk_out_dim=256, 
                      dropout=0,
                      num_self_attn_per_block=26, 
                      per_token_decoder=args.per_token_decoder, 
                      num_query_tasks=args.numy if args.per_label_task else 1)

encoder.load_pretrained("deepmind_assets/language_perceiver_io_bytes.pickle")

_IncompatibleKeys(missing_keys=['query_task_embedding.weight'], unexpected_keys=[])


In [43]:
## Sanity check for encoder with per_token_decoder=True

# input_str = "This is an incomplete sentence where some words are missing."
# input_tokens = tokenizer.to_int(input_str)

# # Mask " missing.". Note that the model performs much better if the masked chunk
# # starts with a space.
# input_tokens[51:60] = tokenizer.mask_token
# print("Tokenized string without masked bytes:")
# print(tokenizer.to_string(input_tokens))

# #@title Pad and reshape inputs
# inputs = input_tokens[None]
# input_mask = np.ones_like(inputs)

# def pad(max_sequence_length: int, inputs, input_mask):
#     input_len = inputs.shape[1]
#     assert input_len <= max_sequence_length
#     pad_len = max_sequence_length - input_len
#     padded_inputs = np.pad(
#       inputs,
#       pad_width=((0, 0), (0, pad_len)),
#       constant_values=tokenizer.pad_token)
#     padded_mask = np.pad(
#       input_mask,
#       pad_width=((0, 0), (0, pad_len)),
#       constant_values=0)
#     return padded_inputs, padded_mask

# inputs, input_mask = pad(args.maxlen, inputs, input_mask)

# encoder.eval()
# mask = torch.tensor(input_mask)
# input_ids = torch.tensor(inputs)
# out = encoder.forward(input_ids, mask)

# embs = out * mask.unsqueeze(-1) / mask.sum(dim=-1)

# logits = torch.matmul(out, encoder.token_embedding.weight.T) + encoder.decoder_token_bias
# masked_tokens_predictions = logits[0, 51:60].argmax(dim=-1)
# print("Greedy predictions:")
# print(masked_tokens_predictions)
# print()
# print("Predicted string:")
# print(tokenizer.to_string(masked_tokens_predictions.cpu().detach().numpy()))

In [44]:
class XMLDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, labels, tokenizer, maxlen):
        self.maxlen = maxlen
        self.input_ids = pad_sequence([torch.LongTensor(tokenizer.to_int(x)[:maxlen]) for x in inputs], batch_first=True, padding_value=0)
        self.input_mask = (self.input_ids != 0).long()
        self.labels = labels
            
    def __getitem__(self, index):
        return index
    
    def get_fts(self, indices, source='point'):
        input_mask = self.input_mask[indices]
        max_batch_seq_len = input_mask.sum(dim=-1).max()
        return {'input_ids': self.input_ids[indices, :max_batch_seq_len], 'input_mask': input_mask[:, :max_batch_seq_len]}
   
    def __len__(self):
        return self.labels.shape[0]
    
class XMLCollator():
    def __init__(self, dataset):
        self.numy = dataset.labels.shape[1]
        self.dataset = dataset
    
    def __call__(self, batch):
        ids = torch.LongTensor(batch)
        batch_data = {'batch_size': torch.LongTensor([len(batch)]),
                      'numy': torch.LongTensor([self.numy]),
                      'y': csr_to_pad_tensor(self.dataset.labels[ids], self.numy),
                      'ids': ids,
                      'xfts': self.dataset.get_fts(ids)
                     }
                
        return batch_data

In [45]:
trn_dataset = XMLDataset(trnX, trn_X_Y, tokenizer, args.maxlen)
tst_dataset = XMLDataset(tstX, tst_X_Y, tokenizer, args.maxlen)

trn_loader = torch.utils.data.DataLoader(
    trn_dataset,
    batch_size=args.bsz,
    num_workers=2,
    collate_fn=XMLCollator(trn_dataset),
    shuffle=True,
    pin_memory=True)

tst_loader = torch.utils.data.DataLoader(
    tst_dataset,
    batch_size=args.bsz,
    num_workers=2,
    collate_fn=XMLCollator(tst_dataset),
    shuffle=False,
    pin_memory=True)

In [46]:
class Net(nn.Module):
    def __init__(self, encoder, args):
        super().__init__()
        self.encoder = encoder
        self.numy = args.numy
        self.dropout = nn.Dropout(args.dropout)
        if args.per_label_task:
            self.w = nn.Sequential(nn.Linear(args.embed_dim, 2*args.embed_dim), 
                                   nn.ReLU(), 
                                   nn.Linear(2*args.embed_dim, 1))
        else:
            self.w = nn.Linear(args.embed_dim, args.numy)
        
    def get_device(self):
        return list(self.parameters())[0].device
    
    def forward(self, b):
        mask = b['xfts']['input_mask']
        embs = self.encoder(b['xfts']['input_ids'], mask)
        
        if self.encoder.per_token_decoder:
            embs = embs * mask.unsqueeze(-1) / mask.sum(dim=-1).reshape(-1, 1, 1)
            embs = embs.sum(dim=1)
        else:
            embs = embs.squeeze()
        out = self.w(self.dropout(embs))
        return out.squeeze()
    
    def predict(self, tst_loader, K=100):
        tst_X_Y = tst_loader.dataset.labels
        data = np.zeros((tst_X_Y.shape[0], K))
        inds = np.zeros((tst_X_Y.shape[0], K)).astype(np.int32)
        indptr = np.arange(0, tst_X_Y.shape[0]*K+1, K)
        self.eval()

        with torch.no_grad():
            for b in tqdm(tst_loader, leave=True, desc='Evaluating'):
                b = ToD(b, self.get_device())
                out = self(b)
                top_data, top_inds = torch.topk(out, K)
                data[b['ids'].cpu()] = top_data.detach().cpu().numpy()
                inds[b['ids'].cpu()] = top_inds.detach().cpu().numpy()
                del top_data, top_inds, b, out

        torch.cuda.empty_cache()
        score_mat = sp.csr_matrix((data.ravel(), inds.ravel(), indptr), tst_X_Y.shape)
        
        return score_mat
    
class OvABCELoss(nn.Module):
    def __init__(self, args, reduction='mean'):
        super(OvABCELoss, self).__init__()
        if args.loss_with_logits:
            self.criterion = torch.nn.BCEWithLogitsLoss(reduction=reduction)
        else:
            self.criterion = torch.nn.BCELoss(reduction=reduction)

    def forward(self, model, b):
        out = model(b)
        targets = torch.zeros((out.shape[0], out.shape[1]+1), device=out.device).scatter_(1, b['y']['inds'], 1)[:, :-1]
        loss = self.criterion(out, targets)
        return loss

In [47]:
net = Net(encoder, args)
criterion = OvABCELoss(args)

In [48]:
optim_wrap = {
    'xc' : {'class': torch.optim.Adam, 'params': [], 'args': {'lr': args.xc_lr}},
    'enc': {'class': transformers.optimization.AdamW, 'params': [], 
            'args': {'lr': args.enc_lr, 'eps': 1e-06, 'weight_decay': 0.01}}
    }

for n,p in net.named_parameters():
    if 'query_task_embedding' in n or p.shape[-1] == args.numy or p.shape[0] == args.numy: 
        optim_wrap['xc']['params'].append((n, p))
    else: 
        optim_wrap['enc']['params'].append((n, p))
        
optims = []
for k, v in optim_wrap.items():
    if len(v['params']) > 0: optims.append(v['class']([x[1] for x in v['params']], **v['args']))
        

total_steps = len(trn_loader)*args.n_epochs
schedulers = [transformers.get_linear_schedule_with_warmup(optim, num_warmup_steps=int(args.warmup*total_steps), num_training_steps=total_steps) for optim in optims]

In [49]:
net.to(args.device);

In [50]:
optim_wrap['xc']

{'class': torch.optim.adam.Adam,
 'params': [('encoder.query_task_embedding.weight',
   Parameter containing:
   tensor([[ 0.1021,  2.1816,  1.3950,  ...,  0.1481, -0.6218,  0.3663],
           [ 0.5220,  0.2766,  0.7901,  ...,  0.5506, -1.8529, -0.0192],
           [ 0.9008, -0.6969, -0.5283,  ..., -0.0365, -1.0203, -0.9421],
           ...,
           [-0.8841, -0.3392,  0.0727,  ...,  0.4174,  1.1941,  1.9235],
           [-0.2528,  0.4482, -1.5365,  ..., -2.2480, -0.8965,  0.2644],
           [ 1.0314,  0.7671, -0.2131,  ...,  1.4414,  1.6199,  0.3878]],
          device='cuda:0', requires_grad=True))],
 'args': {'lr': 0.001}}

In [51]:
scaler = torch.cuda.amp.GradScaler()
best_ndcg = -100
for epoch in range(args.n_epochs):
    net.train()
    cum_loss = 0; ctr = 0
    t = tqdm(trn_loader, desc='Epoch: 0, Loss: 0.0', leave=True)
          
    for b in t:        
        for optim in optims: optim.zero_grad()
        b = ToD(b, args.device)
        with torch.cuda.amp.autocast(enabled=args.amp):
            loss = criterion(net, b)
            
        if args.amp:
            scaler.scale(loss).backward()
            for optim in optims: scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            for optim in optims: optim.step()
                
        for sch in schedulers: sch.step()
        cum_loss += loss.item()
        ctr += 1
        t.set_description('Epoch: %d/%d, Loss: %.4E'%(epoch, args.n_epochs, (cum_loss/ctr)), refresh=True)
    
    print(f'mean loss after epoch {epoch}/{args.n_epochs}: {"%.4E"%(cum_loss/ctr)}', flush=True)
    if epoch%args.eval_interval == 0 or epoch == (args.n_epochs-1):
        score_mat = net.predict(tst_loader)
        metrics = XCMetrics(score_mat, tst_X_Y, inv_prop, method=args.expname, disp=True)

        if metrics.loc[args.expname]['nDCG@5'] > best_ndcg:
            best_ndcg = metrics.loc[args.expname]['nDCG@5']
            print(_c(f'Found new best model with nDCG@5: {"%.2f"%best_ndcg}\n', attr='blue'))
            sp.save_npz(f'{OUT_DIR}/score_mat.npz', score_mat)
            metrics.to_csv(open(f'{OUT_DIR}/metrics.tsv', 'a'), sep='\t')
            torch.save(net.state_dict(), f'{OUT_DIR}/model.pt')
    sys.stdout.flush()

Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 0/25: 5.3451E-02


Evaluating:   0%|          | 0/121 [00:00<?, ?it/s]



P@1	P@5	nDCG@1	nDCG@5	PSP@1	PSP@5	R@10	R@20	R@100	MRR@10
7.17	4.52	7.17	5.19	2.08	2.33	7.4	12.65	27.58	12.76

P@1 P@5 nDCG@1 nDCG@5 PSP@1 PSP@5 R@10 R@20 R@100 MRR@10
7.17 4.52 7.17 5.19 2.08 2.33 7.4 12.65 27.58 12.76

[94mFound new best model with nDCG@5: 5.19
[0m


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 1/25: 9.3175E-03


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 2/25: 8.9519E-03


Evaluating:   0%|          | 0/121 [00:00<?, ?it/s]



P@1	P@5	nDCG@1	nDCG@5	PSP@1	PSP@5	R@10	R@20	R@100	MRR@10
12.16	7.4	12.16	8.63	3.6	3.9	11.29	17.05	39.19	19.91

P@1 P@5 nDCG@1 nDCG@5 PSP@1 PSP@5 R@10 R@20 R@100 MRR@10
12.16 7.4 12.16 8.63 3.6 3.9 11.29 17.05 39.19 19.91

[94mFound new best model with nDCG@5: 8.63
[0m


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 3/25: 8.4119E-03


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 4/25: 7.6494E-03


Evaluating:   0%|          | 0/121 [00:00<?, ?it/s]



P@1	P@5	nDCG@1	nDCG@5	PSP@1	PSP@5	R@10	R@20	R@100	MRR@10
32.06	20.52	32.06	24.19	10.66	12.58	27.57	37.13	60.05	44.45

P@1 P@5 nDCG@1 nDCG@5 PSP@1 PSP@5 R@10 R@20 R@100 MRR@10
32.06 20.52 32.06 24.19 10.66 12.58 27.57 37.13 60.05 44.45

[94mFound new best model with nDCG@5: 24.19
[0m


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 5/25: 7.0542E-03


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 6/25: 6.5698E-03


Evaluating:   0%|          | 0/121 [00:00<?, ?it/s]



P@1	P@5	nDCG@1	nDCG@5	PSP@1	PSP@5	R@10	R@20	R@100	MRR@10
48.46	29.04	48.46	35.05	17.88	19.67	37.76	47.23	69.32	59.9

P@1 P@5 nDCG@1 nDCG@5 PSP@1 PSP@5 R@10 R@20 R@100 MRR@10
48.46 29.04 48.46 35.05 17.88 19.67 37.76 47.23 69.32 59.9

[94mFound new best model with nDCG@5: 35.05
[0m


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 7/25: 6.1617E-03


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 8/25: 5.8805E-03


Evaluating:   0%|          | 0/121 [00:00<?, ?it/s]



P@1	P@5	nDCG@1	nDCG@5	PSP@1	PSP@5	R@10	R@20	R@100	MRR@10
55.89	34.35	55.89	41.16	21.35	24.07	43.71	53.89	74.06	66.77

P@1 P@5 nDCG@1 nDCG@5 PSP@1 PSP@5 R@10 R@20 R@100 MRR@10
55.89 34.35 55.89 41.16 21.35 24.07 43.71 53.89 74.06 66.77

[94mFound new best model with nDCG@5: 41.16
[0m


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 9/25: 5.5313E-03


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 10/25: 5.2119E-03


Evaluating:   0%|          | 0/121 [00:00<?, ?it/s]



P@1	P@5	nDCG@1	nDCG@5	PSP@1	PSP@5	R@10	R@20	R@100	MRR@10
61.89	37.71	61.89	45.34	24.89	27.79	48.13	57.93	76.53	71.74

P@1 P@5 nDCG@1 nDCG@5 PSP@1 PSP@5 R@10 R@20 R@100 MRR@10
61.89 37.71 61.89 45.34 24.89 27.79 48.13 57.93 76.53 71.74

[94mFound new best model with nDCG@5: 45.34
[0m


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 11/25: 4.9492E-03


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 12/25: 4.6633E-03


Evaluating:   0%|          | 0/121 [00:00<?, ?it/s]



P@1	P@5	nDCG@1	nDCG@5	PSP@1	PSP@5	R@10	R@20	R@100	MRR@10
67.06	40.58	67.06	48.86	28.07	30.94	51.18	60.84	78.4	75.87

P@1 P@5 nDCG@1 nDCG@5 PSP@1 PSP@5 R@10 R@20 R@100 MRR@10
67.06 40.58 67.06 48.86 28.07 30.94 51.18 60.84 78.4 75.87

[94mFound new best model with nDCG@5: 48.86
[0m


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 13/25: 4.3943E-03


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 14/25: 4.1377E-03


Evaluating:   0%|          | 0/121 [00:00<?, ?it/s]



P@1	P@5	nDCG@1	nDCG@5	PSP@1	PSP@5	R@10	R@20	R@100	MRR@10
67.09	41.96	67.09	50.16	28.58	32.37	52.28	61.61	79.06	76.28

P@1 P@5 nDCG@1 nDCG@5 PSP@1 PSP@5 R@10 R@20 R@100 MRR@10
67.09 41.96 67.09 50.16 28.58 32.37 52.28 61.61 79.06 76.28

[94mFound new best model with nDCG@5: 50.16
[0m


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 15/25: 3.8690E-03


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 16/25: 3.5955E-03


Evaluating:   0%|          | 0/121 [00:00<?, ?it/s]



P@1	P@5	nDCG@1	nDCG@5	PSP@1	PSP@5	R@10	R@20	R@100	MRR@10
69.42	43.33	69.42	51.99	30.45	34.61	53.72	63.02	79.68	78.21

P@1 P@5 nDCG@1 nDCG@5 PSP@1 PSP@5 R@10 R@20 R@100 MRR@10
69.42 43.33 69.42 51.99 30.45 34.61 53.72 63.02 79.68 78.21

[94mFound new best model with nDCG@5: 51.99
[0m


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



mean loss after epoch 17/25: 3.3243E-03


Epoch: 0, Loss: 0.0:   0%|          | 0/483 [00:00<?, ?it/s]



KeyboardInterrupt: 

In [35]:
num_params = 0
for p in model.parameters():
    num_params += np.prod(p.shape)

num_params

NameError: name 'model' is not defined