In [1]:
%cd /home/nilesh/work/dl_base

/home/nilesh/work/dl_base


In [2]:
import sys, os, time, socket, yaml, wandb, logging
import logging.config
from tqdm import tqdm

from nets import *
from losses import *
from optimizer_bundles import *
from resources import _c, load_config_and_runtime_args, dump_diff_config
from datasets import DATA_MANAGERS, XMCEvaluator
from dl_helper import unwrap, expand_multilabel_dataset

import torch
import transformers
transformers.set_seed(42)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
# Config and runtime argument parsing
args = load_config_and_runtime_args("python configs/seq2seq.yaml".split())

In [4]:
args.lr = 1e-3
args.num_epochs = 100
args.eval_topk = 5
args.eval_interval = 10

In [5]:
args.num_gpu = 1
args.device = 'cuda:3'
args.use_grad_scaler = False
args.hostname = socket.gethostname()
args.exp_start_time = time.ctime()
args.DATA_DIR = DATA_DIR = f'Datasets/{args.dataset}'
args.OUT_DIR = OUT_DIR = f'Results/{args.project}/{args.dataset}/{args.expname}'
os.makedirs(OUT_DIR, exist_ok=True)
torch.cuda.set_device(args.device)

In [6]:
with open('configs/logging.yaml') as f:
    log_config = yaml.safe_load(f.read())
    log_config['handlers']['file_handler']['filename'] = f"{OUT_DIR}/{log_config['handlers']['file_handler']['filename']}"
    logging.config.dictConfig(log_config)

In [7]:
args

Namespace(project='Seq2Seq', expname='seq2seq-label-text', desc='Learn a sequence to sequence model to generate label text', dataset='EURLex-4K', net='t5', loss='tf-loss', data_manager='two-tower', tf='t5-small', tf_max_len=128, save=True, resume_path='', data_tokenization='offline', num_val_points=0, track_metric='nDCG@5', transpose_trn_dataset=False, optim_bundle='base', optim='adamw', num_epochs=50, dropout=0.5, warmup=0.1, bsz=1024, eval_interval=10, eval_topk=5, w_accumulation_steps=1, lr=0.002, weight_decay=0.01, amp_encode=False, norm_embs=False, use_swa=False, swa_start=8, swa_step=1000, num_gpu=1, device='cuda:3', use_grad_scaler=False, hostname='habanero.csres.utexas.edu', exp_start_time='Fri Oct  7 22:55:40 2022', DATA_DIR='Datasets/EURLex-4K', OUT_DIR='Results/Seq2Seq/EURLex-4K/seq2seq-label-text')

In [8]:
criterion = LOSSES[args.loss](args)

In [9]:
# Data loading
from datasets import TwoTowerDataset

data_manager = DATA_MANAGERS[args.data_manager](args)
trn_dataset, val_dataset, _ = data_manager.build_datasets()
expand_multilabel_dataset(trn_dataset.x_dataset, copy=False, multiclass=True)
data_manager.trn_dataset = TwoTowerDataset(trn_dataset.x_dataset, trn_dataset.y_dataset)

trn_loader, val_loader, _ = data_manager.build_data_loaders()

In [10]:
t1 = trn_dataset.x_dataset.tokenizer

In [11]:
t1.decode(t1.encode_plus('hello my name is nilesh')['input_ids'])

'hello my name is nilesh</s>'

In [12]:
t1.decode(trn_dataset.get_fts([0], 'x')['input_ids'][0])

'decis ec european parliament council establish multiannu commun programm promot safer internet onlin technolog text eea relev european parliament council european union regard treati establish european commun articl thereof regard propos commiss regard opinion european econom social committe consult committe region act accord procedur laid articl treati internet penetr technolog mobil phone grow consider commun alongsid danger children abus technolog continu exist danger abus emerg order encourag exploit opportun offer internet onlin technolog mea</s>'

In [13]:
from tokenizers import normalizers, pre_tokenizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
from tokenizers.pre_tokenizers import Whitespace, Punctuation
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordPieceTrainer, BpeTrainer
from tokenizers import Tokenizer
from tokenizers.models import WordPiece, BPE

label_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
label_tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])
label_tokenizer.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), Punctuation()])
label_tokenizer.post_processor = TemplateProcessing(
    single="$A [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[
        ("[CLS]", 1),
        ("[SEP]", 2),
    ],
)
trainer = BpeTrainer(vocab_size=4000, special_tokens=["[PAD]", "[CLS]", "[SEP]", "[UNK]", "[MASK]"])

In [14]:
Y = [x.strip() for x in open(f'{DATA_DIR}/Y.txt')]
label_tokenizer.train_from_iterator(Y, trainer)
label_tokenizer.save(f'{OUT_DIR}/label_tokenizer.pt')
label_tokens = label_tokenizer.encode_batch(Y)
Y_ii = torch.nn.utils.rnn.pad_sequence([torch.LongTensor(x.ids) for x in label_tokens], batch_first=True, padding_value=label_tokenizer.token_to_id('[PAD]'))
data_manager.trn_dataset.y_dataset.X_ii = Y_ii






In [142]:
from transformers import T5ForConditionalGeneration
from dl_helper import BatchIterator

from genre_trie import MarisaTrie
Y = [x.strip() for x in open(f'{DATA_DIR}/Y.txt')]
inv_Y = {v: i for i, v in enumerate(Y)}
trie = MarisaTrie([val_dataset.get_fts([y], source='y')['input_ids'][0] for y in np.arange(len(val_dataset.y_dataset))], cache_fist_branch=False)

def get_trie_candidates(batch_id, x):
    return trie.get(x[1:])

Y_size = label_tokenizer.get_vocab_size()
def get_target_mask(seq, other_pos_seqs):
    with torch.no_grad():
        trie_mask = torch.zeros(len(seq), Y_size).bool()
        target_mask = torch.ones(len(seq), Y_size).bool()
        trie_mask.scatter_(1, seq.reshape(-1, 1), True)
        for i in range(len(seq)):
            candidates = np.array(trie.get(seq[:i]))
            if len(candidates) < 2:
                break
            trie_mask[i, candidates] = True
            trie_mask[i, other_pos_seqs[:, i]] = False
            trie_mask[i, seq[i]] = True
            
            target_mask[i, np.intersect1d(other_pos_seqs[:, i], candidates)] = False
            target_mask[i, seq[i]] = True
        return trie_mask, target_mask
    
class T5(SWANet):
    def __init__(self, args):
        super().__init__(args)
        tf_args = {'tie_word_embeddings': False}
        self.tf = T5ForConditionalGeneration.from_pretrained(args.tf, **tf_args)

    def forward(self, b):
        b = self.ToD(b)
        return self.tf(**b['xfts'], labels=b['yfts']['input_ids'])

    def predict(self, data_loader, bsz=256, K=5):
        self.eval()
        data_iter = BatchIterator(data_loader.dataset.x_dataset, bsz)
        out_seq, out_scores = [], []
        with torch.no_grad():
            for b in tqdm(data_iter):
                b = self.ToD(b)
                batch_out = self.tf.generate(
                    **b['xfts'], 
                    num_return_sequences=K, 
                    num_beams=K, 
                    early_stopping=True, 
                    output_scores=True, 
                    return_dict_in_generate=True,
                    prefix_allowed_tokens_fn=get_trie_candidates)
                batch_out_seq, batch_out_scores = batch_out.sequences, batch_out.sequences_scores
                out_seq.append(batch_out_seq.detach().cpu().numpy().reshape(-1, K, batch_out_seq.shape[1]))
                out_scores.append(batch_out_scores.detach().cpu().numpy().reshape(-1, K))

        out_lbls = [[inv_Y[''.join([x.strip() for x in label_tokenizer.decode(seq, skip_special_tokens=True)])] for instance_seq in batch_seq for seq in instance_seq] for batch_seq in out_seq]
        data = np.concatenate(out_scores).flatten()
        inds = np.concatenate(out_lbls)
        indptr = np.arange(0, len(data)+1, K)
        score_mat = sp.csr_matrix((data, inds, indptr), data_loader.dataset.labels.shape)
        return score_mat
        
NETS['t5'] = T5

In [132]:
data_manager.trn_X_Y[0].indices

array([  47,  522,  630, 1840, 1932, 2852], dtype=int32)

In [143]:
%timeit ret = get_target_mask(Y_ii[47], Y_ii[[522,  630, 1840, 1932, 2852]])

2.41 ms ± 77.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


tensor([False, False, False, False, False])

In [None]:
class TFLoss(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.alpha = 5
        self.lamda = 0.05

    def forward(self, model, b):
        out = model(b)
        targets = torch.zeros_like(out.logits)
        targets.scatter_(-1, b['yfts']['input_ids'].unsqueeze(-1), 1.0)
        loss = nn.BCELoss()(out.logits, targets)
        return loss

In [16]:
net = NETS[args.net](args)
criterion = LOSSES[args.loss](args)
evaluator = XMCEvaluator(args, val_loader, data_manager, 'val')
optim_bundle = OPTIM_BUNDLES[args.optim_bundle](args)

if os.path.exists(args.resume_path):
    logging.info(f'Loading net state dict from: {args.resume_path}')
    logging.info(net.load(args.resume_path))

In [17]:
net.tf.lm_head = nn.Linear(net.tf.lm_head.in_features, label_tokenizer.get_vocab_size())

net.tf.config.decoder_start_token_id = label_tokenizer.token_to_id('[CLS]')
net.tf.config.eos_token_id = label_tokenizer.token_to_id('[SEP]')
net.tf.config.pad_token_id = label_tokenizer.token_to_id('[PAD]')
net.tf.decoder.config.vocab_size = label_tokenizer.get_vocab_size()

net.tf.config.max_length = 32
net.tf.config.min_length = 1
net.tf.config.no_repeat_ngram_size = 3
net.tf.config.early_stopping = True
net.tf.config.length_penalty = 1.0
net.tf.config.num_beams = 5

In [18]:
net.to('cuda')
optim_bundle.inject_params(net)
optim_bundle.init_schedulers(args, len(trn_loader))
logging.info(optim_bundle)

INFO - root - 07-Oct-22 22:55:53 : [1m[94m<class 'torch.optim.adamw.AdamW'>(base) ({'lr': 0.002}, accum: 1): [0m tf.shared.weight tf.encoder.block.0.layer.0.SelfAttention.q.weight tf.encoder.block.0.layer.0.SelfAttention.k.weight tf.encoder.block.0.layer.0.SelfAttention.v.weight tf.encoder.block.0.layer.0.SelfAttention.o.weight tf.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight tf.encoder.block.0.layer.0.layer_norm.weight tf.encoder.block.0.layer.1.DenseReluDense.wi.weight tf.encoder.block.0.layer.1.DenseReluDense.wo.weight tf.encoder.block.0.layer.1.layer_norm.weight tf.encoder.block.1.layer.0.SelfAttention.q.weight tf.encoder.block.1.layer.0.SelfAttention.k.weight tf.encoder.block.1.layer.0.SelfAttention.v.weight tf.encoder.block.1.layer.0.SelfAttention.o.weight tf.encoder.block.1.layer.0.layer_norm.weight tf.encoder.block.1.layer.1.DenseReluDense.wi.weight tf.encoder.block.1.layer.1.DenseReluDense.wo.weight tf.encoder.block.1.layer.1.layer_norm.weight tf.enco

In [125]:
for b in trn_loader:
    b = net.ToD(b)
    break

In [126]:
out = net(b)

In [130]:
out.logits

tensor([[[ 1.4331e+00, -4.6934e+00,  1.5489e+00,  ...,  4.9096e-01,
          -4.0118e+00, -2.9179e+00],
         [ 2.1655e+00, -5.0546e+00,  2.1445e+01,  ..., -5.3368e+00,
          -7.2699e+00, -1.1005e+00],
         [ 2.0349e+01, -5.9286e+00, -2.6120e-01,  ..., -6.4735e+00,
          -5.6312e+00, -1.0281e+00],
         [ 2.0478e+01, -5.9548e+00, -3.7392e-02,  ..., -6.6003e+00,
          -5.6674e+00, -9.5759e-01],
         [ 2.0481e+01, -5.9625e+00, -5.5562e-02,  ..., -6.5869e+00,
          -5.6471e+00, -9.7824e-01],
         [ 2.0472e+01, -5.9733e+00, -1.7975e-02,  ..., -6.5771e+00,
          -5.6551e+00, -9.9055e-01]],

        [[-1.8961e+00, -3.0166e+00,  6.1340e-02,  ..., -1.1566e-01,
          -4.3922e+00, -2.4037e+00],
         [ 3.0239e+00, -5.5895e+00,  2.2094e+01,  ..., -5.8928e+00,
          -7.8450e+00, -1.0512e+00],
         [ 2.0453e+01, -6.1208e+00, -6.0008e-02,  ..., -6.3928e+00,
          -5.7926e+00, -1.0604e+00],
         [ 2.0574e+01, -6.1919e+00,  6.4543e-02,  ...

In [19]:
# Training loop
scaler = torch.cuda.amp.GradScaler()
global_step = 0

for epoch in range(args.num_epochs):
    epoch_loss = 0
    net.train()
    optim_bundle.zero_grad()

    t = tqdm(trn_loader, desc='Epoch: 0, Loss: 0.0', leave=True)
    for i, b in enumerate(t):
        loss = criterion(net, b)
        loss.backward()
        
        optim_bundle.step_and_zero_grad(scaler if args.use_grad_scaler else None)
        unwrap(net).update_non_parameters(epoch, global_step)

        epoch_loss += loss.item()
        global_step += 1
        t.set_description('Epoch: %d/%d, Loss: %.4E'%(epoch, args.num_epochs, (epoch_loss/(i+1))), refresh=True)
    
    epoch_loss = (epoch_loss/(i+1))
    logging.info(f'Mean loss after epoch {epoch}/{args.num_epochs}: {"%.4E"%(epoch_loss)}')
    metrics = evaluator.predict_and_track_eval(unwrap(net), epoch, epoch_loss)
    if metrics is not None:
        logging.info('\n'+metrics.to_csv(sep='\t', index=False))

Epoch: 0/50, Loss: 6.0472E+00: 100%|██████████| 81/81 [00:42<00:00,  1.91it/s]

INFO - root - 07-Oct-22 22:56:36 : Mean loss after epoch 0/50: 6.0472E+00



  next_indices = next_tokens // vocab_size
100%|██████████| 4/4 [01:06<00:00, 16.65s/it]


[94mFound new best model with nDCG@5: 0.79
[0m
INFO - root - 07-Oct-22 22:57:43 : 
P@1	P@3	P@5	nDCG@1	nDCG@3	nDCG@5	PSP@1	PSP@3	PSP@5	R@10	R@20	R@50	MRR@10	loss
0.23	1.08	0.78	0.23	0.93	0.79	0.15	0.58	0.55	0.73	0.73	0.73	1.71	6.0472E+00



Epoch: 1/50, Loss: 1.2081E+00: 100%|██████████| 81/81 [00:41<00:00,  1.96it/s]

INFO - root - 07-Oct-22 22:58:25 : Mean loss after epoch 1/50: 1.2081E+00



Epoch: 2/50, Loss: 9.1413E-01: 100%|██████████| 81/81 [00:41<00:00,  1.95it/s]

INFO - root - 07-Oct-22 22:59:06 : Mean loss after epoch 2/50: 9.1413E-01



Epoch: 3/50, Loss: 7.2249E-01: 100%|██████████| 81/81 [00:41<00:00,  1.94it/s]

INFO - root - 07-Oct-22 22:59:48 : Mean loss after epoch 3/50: 7.2249E-01



Epoch: 4/50, Loss: 5.5545E-01: 100%|██████████| 81/81 [00:41<00:00,  1.94it/s]

INFO - root - 07-Oct-22 23:00:30 : Mean loss after epoch 4/50: 5.5545E-01



Epoch: 5/50, Loss: 4.1543E-01: 100%|██████████| 81/81 [00:41<00:00,  1.94it/s]

INFO - root - 07-Oct-22 23:01:12 : Mean loss after epoch 5/50: 4.1543E-01



Epoch: 6/50, Loss: 3.5508E-01: 100%|██████████| 81/81 [00:42<00:00,  1.92it/s]

INFO - root - 07-Oct-22 23:01:54 : Mean loss after epoch 6/50: 3.5508E-01



Epoch: 7/50, Loss: 3.2903E-01: 100%|██████████| 81/81 [00:41<00:00,  1.97it/s]

INFO - root - 07-Oct-22 23:02:35 : Mean loss after epoch 7/50: 3.2903E-01



Epoch: 8/50, Loss: 3.0653E-01: 100%|██████████| 81/81 [00:41<00:00,  1.93it/s]

INFO - root - 07-Oct-22 23:03:17 : Mean loss after epoch 8/50: 3.0653E-01



Epoch: 9/50, Loss: 2.8784E-01: 100%|██████████| 81/81 [00:41<00:00,  1.94it/s]

INFO - root - 07-Oct-22 23:03:59 : Mean loss after epoch 9/50: 2.8784E-01



Epoch: 10/50, Loss: 2.6414E-01: 100%|██████████| 81/81 [00:41<00:00,  1.94it/s]

INFO - root - 07-Oct-22 23:04:40 : Mean loss after epoch 10/50: 2.6414E-01



  next_indices = next_tokens // vocab_size
100%|██████████| 4/4 [01:39<00:00, 24.88s/it]


[94mFound new best model with nDCG@5: 53.03
[0m
INFO - root - 07-Oct-22 23:06:21 : 
P@1	P@3	P@5	nDCG@1	nDCG@3	nDCG@5	PSP@1	PSP@3	PSP@5	R@10	R@20	R@50	MRR@10	loss
67.68	54.97	45.07	67.68	58.15	53.03	33.32	36.84	38.45	44.21	44.21	44.21	78.56	2.6414E-01



Epoch: 11/50, Loss: 2.6203E-01: 100%|██████████| 81/81 [00:41<00:00,  1.96it/s]

INFO - root - 07-Oct-22 23:07:02 : Mean loss after epoch 11/50: 2.6203E-01



Epoch: 12/50, Loss: 2.5555E-01: 100%|██████████| 81/81 [00:41<00:00,  1.96it/s]

INFO - root - 07-Oct-22 23:07:43 : Mean loss after epoch 12/50: 2.5555E-01



Epoch: 13/50, Loss: 2.4944E-01: 100%|██████████| 81/81 [00:41<00:00,  1.93it/s]

INFO - root - 07-Oct-22 23:08:25 : Mean loss after epoch 13/50: 2.4944E-01



Epoch: 14/50, Loss: 2.4497E-01: 100%|██████████| 81/81 [00:41<00:00,  1.93it/s]

INFO - root - 07-Oct-22 23:09:07 : Mean loss after epoch 14/50: 2.4497E-01



Epoch: 15/50, Loss: 2.3911E-01: 100%|██████████| 81/81 [00:41<00:00,  1.95it/s]

INFO - root - 07-Oct-22 23:09:49 : Mean loss after epoch 15/50: 2.3911E-01



Epoch: 16/50, Loss: 2.2786E-01: 100%|██████████| 81/81 [00:41<00:00,  1.94it/s]

INFO - root - 07-Oct-22 23:10:30 : Mean loss after epoch 16/50: 2.2786E-01



Epoch: 17/50, Loss: 2.1851E-01: 100%|██████████| 81/81 [00:41<00:00,  1.94it/s]

INFO - root - 07-Oct-22 23:11:12 : Mean loss after epoch 17/50: 2.1851E-01



Epoch: 18/50, Loss: 2.1961E-01: 100%|██████████| 81/81 [00:41<00:00,  1.95it/s]

INFO - root - 07-Oct-22 23:11:54 : Mean loss after epoch 18/50: 2.1961E-01



Epoch: 19/50, Loss: 2.1731E-01: 100%|██████████| 81/81 [00:41<00:00,  1.96it/s]

INFO - root - 07-Oct-22 23:12:35 : Mean loss after epoch 19/50: 2.1731E-01



Epoch: 20/50, Loss: 2.1251E-01: 100%|██████████| 81/81 [00:41<00:00,  1.94it/s]

INFO - root - 07-Oct-22 23:13:17 : Mean loss after epoch 20/50: 2.1251E-01



  next_indices = next_tokens // vocab_size
100%|██████████| 4/4 [02:17<00:00, 34.46s/it]


KeyError: 'translation_centre_for_the_bodies_of_the))))))))))))))))))'

In [40]:
net.save(f'{OUT_DIR}/model.pt')
# net.load(f'{OUT_DIR}/model.pt')

In [117]:
nnz = data_manager.trn_X_Y.getnnz(0)

In [113]:
score_mat = sp.load_npz(f'{OUT_DIR}/val_score_mat.npz')

In [115]:
from resources import compute_xmc_metrics
K = 5
compute_xmc_metrics(score_mat, val_dataset.labels, data_manager.inv_prop, K=K, disp=False)

Unnamed: 0,P@1,P@3,P@5,nDCG@1,nDCG@3,nDCG@5,PSP@1,PSP@3,PSP@5,R@10,R@20,R@50,MRR@10
Method,75.58,68.54,59.06,75.58,70.62,66.79,40.18,48.98,53.5,57.96,57.96,57.96,84.69


In [124]:
from resources import vis_point

x = np.random.randint(score_mat.shape[0])
vis_point(x, score_mat, [' ']*score_mat.shape[0], Y, nnz, val_dataset.labels)

x[3092]: [1m [0m

1) [33mindustrial_region[0m [1823] (-0.6737, 66)

2) [33mcoordination_of_aid[0m [749] (-0.7332, 59)

3) [33mregional_development[0m [2950] (-0.8451, 104)

4) [33mregional_planning[0m [2955] (-0.9912, 76)

5) [33mcommunity_financial_instrument[0m [629] (-1.0156, 198)



In [121]:
x = 2
print(get_text(x, Y, val_loader.dataset.labels))
print(*[_c(*x, attr='blue') for x in [(tokenizer.decode(out_seq[0][x, i], skip_special_tokens=True), out_scores[0][x, i]) for i in range(K)]], sep=', ')

2 : 
[1m[4mexport_refund[0m(1.00, 1360) [1m[4mfishery_product[0m(1.00, 1465) [1m[4moriginating_product[0m(1.00, 2535) [1m[4mprocessed_foodstuff[0m(1.00, 2779) [1m[4mship's_flag[0m(1.00, 3169) [1m[4mthird_country[0m(1.00, 3475)
[94mfishery _ product -0.46635327[0m, [94mexport _ refund -0.5364342[0m, [94mexport _ licence -0.92781454[0m, [94mimport _ policy -1.001939[0m, [94mmarketing _ standard -1.0298482[0m
