In [1]:
import sys
sys.path.insert(0, './language_model/')

# for suppressing T.save warnings
# see https://discuss.pyT.org/t/got-warning-couldnt-retrieve-source-code-for-container/7689
import warnings
warnings.simplefilter('ignore', UserWarning)

In [2]:
import os
from timeit import default_timer as timer
from functools import reduce
import operator
import json

import pandas as pd
import scipy.stats as stats

import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as O
from torch.utils.data import DataLoader

from dataset import StandardDataset
from language_model.lm_train import train_language_model
from language_model.lm_prob import LMProb
from ml_utils.config import Config

import pprint
pp = pprint.PrettyPrinter(width=180, indent=2, compact=False)

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
sns.set()

print(f'GPU: {T.cuda.is_available()} | CUDA: {T.version.cuda}')

def from_home(x):
    return os.path.join(os.environ['HOME'], x)

GPU: True | CUDA: 10.1


# 1. Setup

In [7]:
ROOT_DIR   = from_home('workspace/ml-data/msc-research')

# DJANGO_DIR = os.path.join(ROOT_DIR, 'raw-datasets/testing') # simple django
DJANGO_DIR = os.path.join(ROOT_DIR, 'raw-datasets/django')
CONALA_DIR = os.path.join(ROOT_DIR, 'raw-datasets/conala-corpus')

DATASET_DIR = CONALA_DIR
EMB_DIR     = os.path.join(ROOT_DIR, 'embeddings')

print(f'Dataset: {os.path.basename(DATASET_DIR)}')

Dataset: conala-corpus


## 1.1. Read dataset

In [8]:
a = [len(l.strip().split()) for l in open(DATASET_DIR + '/all.anno').readlines()]
c = [len(l.strip().split()) for l in open(DATASET_DIR + '/all.code').readlines()]
assert len(a) == len(c)

d = pd.DataFrame([{'a': _a, 'c': _c} for (_a, _c) in zip(a, c)])
d.describe()

a = round(len(list(filter(lambda x: x <= 10, a))) / len(a), 3)
c = round(len(list(filter(lambda x: x <= 10, c))) / len(c), 3)
a, c

(1.0, 1.0)

## 1.2. Construct config

In [28]:
CFG = Config() # main config

# sub-config for dataset
CFG.dataset_cfg = Config()
CFG.dataset_cfg.__dict__ = {
    'root_dir': DATASET_DIR,
    'anno_min_freq': 1,
    'code_min_freq': 1,
    'anno_seq_maxlen': 10,
    'code_seq_maxlen': 10,
    'emb_file': os.path.join(EMB_DIR, 'glove.6B.50d.txt.pickle'),
}

dataset = StandardDataset(config=CFG.dataset_cfg)

# sub-config for NL intents
CFG.anno = Config() 
CFG.anno.__dict__ = {
    'lstm_hidden_size': 64,
    'lstm_dropout_p': 0.0,
    'att_dropout_p': 0.0,
    'lang': dataset.anno_lang,
    'load_pretrained_emb': True,
    'emb_size': 50,
}

# sub-config for source code
CFG.code = Config() 
CFG.code.__dict__ = {
    'lstm_hidden_size': 128,
    'lstm_dropout_p': 0.0,
    'att_dropout_p': 0.0,
    'lang': dataset.code_lang,
    'load_pretrained_emb': False,
    'emb_size': 16,
}

CFG.__dict__.update({
    'cuda': False,
    'batch_size': 32,
    'num_epochs': 100
})

---

In [None]:
# toks = dataset.code_lang.to_numeric('return dict', tokenize_mode='anno', pad_mode='post', max_len=10)
# ws = dataset.code_lang.to_tokens(T.tensor(toks))

In [None]:
# i = np.random.randint(len(dataset))
# a, c = dataset[i]
# assert len(a) == CFG.dataset_cfg.anno_seq_maxlen, f'{i}'
# assert len(c) == CFG.dataset_cfg.code_seq_maxlen, f'{i}'
# pp.pprint(a)
# pp.pprint(dataset.anno_lang.to_tokens(a))
# print('-'*120)
# pp.pprint(c)
# pp.pprint(dataset.code_lang.to_tokens(c))

# 2. Compute LM probabilities

## 2.1. Get train/test/valid splits

In [29]:
_tp, _vp = 0.1, 0.2
splits = dataset.train_test_valid_split(test_p=_tp, valid_p=_vp, seed=42)

for kind in splits:
    for t in splits[kind]:
        vs = splits[kind][t]
        vs = T.cat(vs)
        vs = vs[vs != 0]
        splits[kind][t] = vs
        
print(f'train {(1-_tp-_vp)*len(dataset):.2f} | test {_tp*len(dataset)} | dev {_vp*len(dataset)}')

train 315.00 | test 45.0 | dev 90.0


## 2.2. Train language model

**Note:** Must do this for both anno and code.

In [30]:
CFG.language_model = Config()
CFG.language_model.__dict__ = {
    'dataset'     : 'conala',
    'model'       : 'LSTM', # type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU, Transformer)
    'n_head'      : None,   # number of heads in the enc/dec of the Transformers
    'emb_size'    : 32,     # size of the word embeddings
    'n_hid'       : 64,     # number of hidden units per layer
    'n_layers'    : 1,      # number of layers
    'lr'          : 0.25,    # initial learning rate
    'clip'        : 0.25,   # gradient clipping
    'dropout_p'   : 0.05,    # dropout applied to layers
    'tied'        : False,  # whether to tie the word embeddings and softmax weights
    'log_interval': 100,
    'epochs'      : 100, # upper epoch limit
    'batch_size'  : 32,
    'seed'        : None # for reproducibility
}

# CFG.language_model

In [None]:
lm_cfg = CFG.language_model

for kind in ['anno', 'code']:
    print(f'Training LM for {kind}\n')

    lm_cfg.kind = kind
    lm_cfg.bptt = CFG.dataset_cfg.__dict__[f'{kind}_seq_maxlen'] # seq len
    lm_cfg.save_path = f'./data/lm/lm-{lm_cfg.dataset}-{lm_cfg.kind}.pt' # path to save the final model
    
#     train_language_model(lm_cfg, 
#                          num_tokens=len(getattr(dataset, f'{kind}_lang')),
#                          train_nums=T.stack(splits[kind]['train']),
#                          test_nums=T.stack(splits[kind]['test']),
#                          valid_nums=T.stack(splits[kind]['valid']))
    
    train_language_model(lm_cfg, 
                         num_tokens=len(getattr(dataset, f'{kind}_lang')),
                         train_nums=splits[kind]['train'],
                         test_nums=splits[kind]['test'],
                         valid_nums=splits[kind]['valid'])
    
    print('*' * 120, '\n')

## 2.3. Compute LM probs

In [31]:
lm_paths = {k: f'./data/lm/lm-{CFG.language_model.dataset}-{k}.pt' for k in ['anno', 'code']}

for f in lm_paths.values():
    assert os.path.exists(f), f'Language Model: file <{f}> does not exist!'
    
_ = dataset.compute_lm_probs(lm_paths)

HBox(children=(FloatProgress(value=0.0, description='P(anno)', max=450.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='P(code)', max=450.0, style=ProgressStyle(description_widt…




---

In [None]:
i = np.random.randint(len(dataset))
a, c, pa, pc = dataset[i]
' '.join(dataset.anno_lang.to_tokens(a)[0]), ' '.join(dataset.code_lang.to_tokens(c)[0])

In [None]:
# class MyLMProb:
#     def __init__(self, model_path):        
#         self.model = T.load(open(model_path, 'rb'), map_location={'cuda:0': 'cpu'})
#         self.model = self.model.cpu()
#         self.model.eval()

#     def get_prob(self, nums, verbose=False):
#         with T.no_grad():
#             inp = T.tensor([int(nums[0])]).long().unsqueeze(0)
#             hidden = self.model.init_hidden(bsz=1)
#             log_probs = []
            
#             for i in range(1, len(nums)):
#                 output, hidden = self.model(inp, hidden)
                
#                 #word_weights = output.squeeze().data.double().exp()
#                 #prob = word_weights[nums[i]] / word_weights.sum()
#                 probs = F.softmax(output.squeeze(), dim=-1)
#                 prob = probs[nums[i]]
                
#                 # append current log prob
#                 log_probs += [T.log(prob)]
#                 inp.data.fill_(int(nums[i]))

#             if verbose:
#                 for i in range(len(log_probs)):
#                     print(f'{nums[i+1]:4d}: P(w|s) = {np.exp(log_probs[i]):8.4f} | logP(w|s) = {log_probs[i]:8.4f}')
#                 print(f'=> sum_prob = {sum(log_probs):.4f}')

#         return sum(log_probs) / len(log_probs)

In [None]:
# lm_probs = {'anno': [], 'code': []}

# pad_idx = {
#     'anno': dataset.anno_lang.token2index['<pad>'],
#     'code': dataset.code_lang.token2index['<pad>']
# } 

# for kind in lm_probs:
#     lm = MyLMProb(lm_paths[kind])
#     p = pad_idx[kind]

#     for vec in tqdm(getattr(dataset, kind), total=len(dataset), desc=f'P({kind})'):
#         lm_probs[kind] += [np.exp(lm.get_prob(vec[vec != pad_idx[kind]], verbose=False))]
    
#     lm_probs[kind] = sum(lm_probs[kind])
#     break

In [None]:
# kind = 'anno'
# lm = MyLMProb(lm_paths[kind])
# s = {}
# for t, i in tqdm(getattr(dataset, f'{kind}_lang').token2index.items()):
#     if i in [0, 2, 3]:
#         continue
#     q = T.tensor([2, i, 3])
#     s[i] = np.exp(lm.get_prob(q))
    
# xs, ys = zip(*sorted(s.items(), key=lambda k: -k[1]))

# plt.figure(figsize=(14,6))
# plt.bar(xs, ys)
# plt.xticks(xs, rotation=90)

# sum(ys)

# 3. Dual CS/CG Model

In [32]:
def get_embeddings(config: Config):
    emb = nn.Embedding(len(config.lang), config.emb_size, padding_idx=config.lang.pad_idx)
    
    if config.load_pretrained_emb:
        assert config.lang.emb_matrix is not None
        emb.weight = nn.Parameter(T.tensor(config.lang.emb_matrix, dtype=T.float32))
        emb.weight.requires_grad = False
        
    return emb

In [33]:
class Model(nn.Module):
    def __init__(self, config: Config, model_type):
        """
        :param model_type: cs / cg
        cs: code -> anno
        cg: anno -> code
        """
        super(Model, self).__init__()
        
        assert model_type in ['cs', 'cg']
        self.model_type = model_type
        
        src_cfg = config.anno if model_type == 'cg' else config.code
        tgt_cfg = config.code if model_type == 'cg' else config.anno
        
        # 1. ENCODER
        self.src_embedding = get_embeddings(src_cfg)
        self.encoder = nn.LSTM(input_size=src_cfg.emb_size,
                               hidden_size=src_cfg.lstm_hidden_size,
                               dropout=src_cfg.lstm_dropout_p,
                               bidirectional=True,
                               batch_first=True)
        
        self.decoder_cell_init_linear = nn.Linear(in_features=2*src_cfg.lstm_hidden_size,
                                                  out_features=tgt_cfg.lstm_hidden_size)
        
        # 2. ATTENTION
        # project source encoding to decoder rnn's h space (W from Luong score general)
        self.att_src_W = nn.Linear(in_features=2*src_cfg.lstm_hidden_size,
                                   out_features=tgt_cfg.lstm_hidden_size,
                                   bias=False)
        
        # transformation of decoder hidden states and context vectors before reading out target words
        # this produces the attentional vector in (W from Luong eq. 5)
        self.att_vec_W = nn.Linear(in_features=2*src_cfg.lstm_hidden_size + tgt_cfg.lstm_hidden_size,
                                   out_features=tgt_cfg.lstm_hidden_size,
                                   bias=False)
        
        # 3. DECODER
        self.tgt_embedding = get_embeddings(tgt_cfg)
        self.decoder = nn.LSTMCell(input_size=tgt_cfg.emb_size + tgt_cfg.lstm_hidden_size,
                                   hidden_size=tgt_cfg.lstm_hidden_size)
       
        # prob layer over target language
        self.readout = nn.Linear(in_features=tgt_cfg.lstm_hidden_size,
                                 out_features=len(tgt_cfg.lang),
                                 bias=False)
        
        self.dropout = nn.Dropout(tgt_cfg.att_dropout_p)
        
        # 4. COPY MECHANISM
        self.copy_gate = ... # TODO
        
        # save configs
        self.src_cfg = src_cfg
        self.tgt_cfg = tgt_cfg
        
        device = T.device('cuda' if CFG.cuda else 'cpu')
        self.to(device)
        print(f'[{model_type}] using [{device}]')
        
        
    def forward(self, src, tgt):
        """
        src: bs, max_src_len
        tgt: bs, max_tgt_len
        """
        enc_out, (h0_dec, c0_dec) = self.encode(src)
        scores, att_mats = self.decode(enc_out, h0_dec, c0_dec, tgt)
        
        return scores, att_mats
    
    
    def encode(self, src):
        """
        src : bs x max_src_len (emb look-up indices)
        out : bs x max_src_len x 2*hid_size
        h/c0: bs x tgt_hid_size
        """
        emb = self.src_embedding(src)
        out, (hn, cn) = self.encoder(emb) # hidden is zero by default
        
        # construct initial state for the decoder
        c0_dec = self.decoder_cell_init_linear(T.cat([cn[0], cn[1]], dim=1))
        h0_dec = c0_dec.tanh()
        
        return out, (h0_dec, c0_dec)
    
    
    def decode(self, src_enc, h0_dec, c0_dec, tgt):
        """
        src_enc: bs, max_src_len, 2*hid_size (== encoder output)
        h/c0   : bs, tgt_hid_size
        tgt    : bs, max_tgt_len (emb look-up indices)
        """
        batch_size, tgt_len = tgt.shape
        scores, att_mats = [], []
        
        hidden = (h0_dec, c0_dec)
        
        emb = self.tgt_embedding(tgt) # bs, max_tgt_len, tgt_emb_size
        
        att_vec = T.zeros(batch_size, self.tgt_cfg.lstm_hidden_size, requires_grad=False)
        if CFG.cuda:
            att_vec = att_vec.cuda()
        
        # Luong W*hs: same for each timestep of the decoder
        src_enc_att = self.att_src_W(src_enc) # bs, max_src_len, tgt_hid_size
        
        for t in range(tgt_len):
            emb_t = emb[:, t, :]
            x = T.cat([emb_t, att_vec], dim=-1)
            h_t, c_t = self.decoder(x, hidden)

            ctx_t, att_mat = self.luong_attention(h_t, src_enc, src_enc_att)
            
            # Luong eq. (5)
            att_t = self.att_vec_W(T.cat([h_t, ctx_t], dim=1))
            att_t = att_t.tanh()
            att_t = self.dropout(att_t)
            
            # Luong eq. (6)
            score_t = self.readout(att_t)
            score_t = F.softmax(score_t, dim=-1)
            
            scores   += [score_t]
            att_mats += [att_mat]
            
            # for next state t+1
            att_vec = att_t
            hidden  = (h_t, c_t)
        
        # bs, max_tgt_len, tgt_vocab_size
        scores = T.stack(scores).permute((1, 0, 2))
        
        # each element: bs, max_src_len, max_tgt_len
        att_mats = T.cat(att_mats, dim=1)
        
        return scores, att_mats
            
        
    def luong_attention(self, h_t, src_enc, src_enc_att, mask=None):
        """
        h_t               : bs, hid_size
        src_enc (hs)      : bs, max_src_len, 2*src_hid_size 
        src_enc_att (W*hs): bs, max_src_len, tgt_hid_size
        mask              : bs, max_src_len
        
        ctx_vec    : bs, 2*src_hid_size
        att_weight : bs, max_src_len
        att_mat    : bs, 1, max_src_len
        """
        
        # bs x src_max_len
        score = T.bmm(src_enc_att, h_t.unsqueeze(2)).squeeze(2)
        
        if mask:
            score.data.masked_fill_(mask, -np.inf)
        
        att_mat = score.unsqueeze(1)
        att_weights = F.softmax(score, dim=-1)
        
        # sum per timestep
        ctx_vec = T.sum(att_weights.unsqueeze(2) * src_enc, dim=1)
        
        return ctx_vec, att_mat
    
    
    def translate(self, src_sents, beam_size, to_words=True):
        """
        Beam search
        src: input sequence (anno / code)
        """
        pass

# 4. Train

## 4.1. Setup

In [34]:
def get_vocab_mask(lang):
    mask = T.ones(len(lang))
    mask[lang.token2index['<pad>']] = 0
    return mask

def JSD(a, b, clone=True, mask=None):
    eps = 1e-8
    
    assert a.shape == b.shape
    _, n, _ = a.shape 
        
    xa, xb = (a.clone(), b.clone()) if clone else (a, b)        
    
    xa = F.softmax(xa, dim=2) + eps
    xb = F.softmax(xb, dim=2) + eps
    
    # common, averaged dist
    avg = 0.5 * (xa + xb)
    
    # kl
    xa = T.sum(xa * T.log(xa / avg), dim=2)
    xb = T.sum(xb * T.log(xb / avg), dim=2)
    
    # js
    xa = T.sum(xa, dim=1) / n
    xb = T.sum(xb, dim=1) / n
    
    return 0.5 * (xa + xb)


def JSD_2(A, B, mask=None):
    eps = 1e-8
    
    assert A.shape == B.shape
    b, n, m = A.shape
        
    js = []
    for bi in range(b):
        kl_a, kl_b = 0, 0
        
        for i in range(n):
            a = A[bi, i, :]
            b = B[bi, i, :]
            
            if mask is not None:
                a[mask[i]] = -(1e8)
                b[mask[i]] = -(1e8)
            
            a = F.softmax(a) + eps
            b = F.softmax(b) + eps
            m = 0.5 * (a + b)
            kl_a += stats.entropy(a, m) / n
            kl_b += stats.entropy(b, m) / n
        
        js += [0.5 * (kl_a + kl_b)]
    
    return T.tensor(js)

In [89]:
cg_model     = Model(CFG, model_type='cg')
cg_model.opt = O.Adam(lr=0.002, params=filter(lambda p: p.requires_grad, cg_model.parameters()))

cs_model     = Model(CFG, model_type='cs')
cs_model.opt = O.Adam(lr=0.002, params=filter(lambda p: p.requires_grad, cs_model.parameters()))

kwargs = {} # {'num_workers': 4, 'pin_memory': True}
train_loader = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True, **kwargs)
print(f'DataLoader: {len(train_loader)} batches of size {CFG.batch_size}')

__cg_l = 0
__cs_l = 0
__att_l = 0
__dual_l = 0
__rep_every = 15

[cg] using [cpu]
[cs] using [cpu]
DataLoader: 15 batches of size 32


## 4.2. Loop

In [90]:
for epoch_idx in range(1, CFG.num_epochs+1):
    
    for batch_idx, (anno, code, anno_lm_p, code_lm_p) in enumerate(train_loader, start=1):        
        anno_len, code_len = anno.shape[1], code.shape[1]
        
        if CFG.cuda:
            anno, code, anno_lm_p, code_lm_p = map(lambda t: t.cuda(), [anno, code, anno_lm_p, code_lm_p])
            
        # binary mask indicating the presence of padding token
        anno_mask = T.tensor(anno == dataset.anno_lang.token2index['<pad>']).byte()
        code_mask = T.tensor(code == dataset.code_lang.token2index['<pad>']).byte()
            
        # forward pass
        code_pred, code_att_mat = cg_model(src=anno, tgt=code)
        anno_pred, anno_att_mat = cs_model(src=code, tgt=anno)
                                
        # loss computation
        l_cg_ce, l_cs_ce = 0, 0
        
        # CG cross-entropy loss
        for t in range(code_len):
            probs = code_pred[:, t, :].gather(1, code[:, t].view(-1, 1)).squeeze(1)
            l_cg_ce += -T.log(probs) / code_len
                    
        # CS cross-entropy loss
        for t in range(anno_len):
            probs = anno_pred[:, t, :].gather(1, anno[:, t].view(-1, 1)).squeeze(1)
            l_cs_ce += -T.log(probs) / anno_len
            
        # dual loss: P(x,y) = P(x).P(y|x) = P(y).P(x|y)
        l_dual = (code_lm_p - l_cs_ce - anno_lm_p + l_cg_ce) ** 2
                
        # attention loss: JSD
        l_att = JSD(anno_att_mat, code_att_mat.transpose(2,1), mask=code_mask) + \
                JSD(anno_att_mat.transpose(2,1), code_att_mat, mask=anno_mask)
                
        # final loss
        l_cg = T.mean(l_cg_ce + 0.01 * l_dual + 0.2 * l_att)
        l_cs = T.mean(l_cs_ce + 0.01 * l_dual + 0.2 * l_att)
                
        # optimize CG
        cg_model.opt.zero_grad()
        l_cg.backward(retain_graph=True)
        cg_model.opt.step()
                
        # optimize CS
        cs_model.opt.zero_grad()
        l_cs.backward()
        cs_model.opt.step()
                
        # reporting
        __cg_l   += l_cg.item()   / __rep_every
        __cs_l   += l_cs.item()   / __rep_every
        __att_l  += l_att.mean().item()  / __rep_every
        __dual_l += l_dual.mean().item() / __rep_every
        
        if batch_idx % __rep_every == 0:
            status = [f'Epoch {epoch_idx:>5d}/{CFG.num_epochs:>3d}', f'Batch {batch_idx:>5d}/{len(train_loader):5d}',
                      f'avg CG {__cg_l:7.5f}', f'avg CS {__cs_l:7.5f}', f'avg ATT {__att_l:7.5f}', f'avg DUAL {__dual_l:7.5f}']
            print(' | '.join(status))
            __cg_l, __cs_l, __att_l, __dual_l = 0, 0, 0, 0
    # --- epoch end
            
    # TODO...
#     if epoch_idx % 1 == 0:
#         with T.no_grad():
#             print()
#             ws = dataset.code_lang.to_tokens(code_pred.argmax(dim=-1))
#             i = np.random.randint(len(ws))
#             print(f'{i} pred: {" ".join(ws[i])}')
#             print(f'{i}  tgt: {" ".join(code[i].argmax())}')
            
#             print('\t'+'-'*80)
#             ws = dataset.anno_lang.to_tokens(anno_pred.argmax(dim=-1))
#             i = np.random.randint(len(ws))
#             print(f'\t{i}: {" ".join(ws[i])}')
#             print()

Epoch     1/100 | Batch    15/   15 | avg CG 4.96300 | avg CS 5.92062 | avg ATT 0.14292 | avg DUAL 0.39005
Epoch     2/100 | Batch    15/   15 | avg CG 3.61521 | avg CS 4.70281 | avg ATT 0.09453 | avg DUAL 0.18951
Epoch     3/100 | Batch    15/   15 | avg CG 3.31325 | avg CS 4.49830 | avg ATT 0.03157 | avg DUAL 0.25341
Epoch     4/100 | Batch    15/   15 | avg CG 3.15264 | avg CS 4.44265 | avg ATT 0.03369 | avg DUAL 0.28624
Epoch     5/100 | Batch    15/   15 | avg CG 2.82657 | avg CS 4.28564 | avg ATT 0.02600 | avg DUAL 0.35062
Epoch     6/100 | Batch    15/   15 | avg CG 2.49606 | avg CS 4.19147 | avg ATT 0.05679 | avg DUAL 0.56886
Epoch     7/100 | Batch    15/   15 | avg CG 2.19018 | avg CS 4.10501 | avg ATT 0.08498 | avg DUAL 0.87867
Epoch     8/100 | Batch    15/   15 | avg CG 1.91632 | avg CS 3.99943 | avg ATT 0.06911 | avg DUAL 1.28860
Epoch     9/100 | Batch    15/   15 | avg CG 1.73252 | avg CS 3.89443 | avg ATT 0.05709 | avg DUAL 1.56567
Epoch    10/100 | Batch    15/   15 |

# 5. Evaluate

## 5.1. Exact match

In [91]:
test_loader = DataLoader(dataset, batch_size=1, shuffle=False)

metrics = {k: 0 for k in [
    'em_anno', 'em_code', 'strict_em_anno', 'strict_em_code'
]}

with T.no_grad():
    for batch_idx, (anno, code, anno_lm_p, code_lm_p) in tqdm(enumerate(test_loader, start=1), total=len(test_loader)): 
        if CFG.cuda:
            anno, code, anno_lm_p, code_lm_p = map(lambda t: t.cuda(), [anno, code, anno_lm_p, code_lm_p])
            
        # forward pass
        code_pred, code_att_mat = cg_model(src=anno, tgt=code)
        anno_pred, anno_att_mat = cs_model(src=code, tgt=anno)
        
        # TODO: ideally, this should be beam-search
        code_pred = code_pred.argmax(dim=2)
        anno_pred = anno_pred.argmax(dim=2)
        
        code_score = T.mean((code_pred == code).float())
        anno_score = T.mean((anno_pred == anno).float())
        
        metrics['em_code'] += code_score / len(test_loader)
        metrics['em_anno'] += anno_score / len(test_loader)
        
        if np.isclose(code_score, 1):
            metrics['strict_em_code'] += 1 / len(test_loader)
        if np.isclose(anno_score, 1):
            metrics['strict_em_anno'] += 1 / len(test_loader)
        
for k, v in metrics.items():
    print(f'{k:>16s}: {v:7.5f}')

HBox(children=(FloatProgress(value=0.0, max=450.0), HTML(value='')))


         em_anno: 0.98622
         em_code: 1.00000
  strict_em_anno: 0.86222
  strict_em_code: 1.00000


## 5.2. Attention matrices

In [None]:
with T.no_grad():
    a, c = anno[[1]], code[[1]]
    x, x_mat = cg_model(src=a, tgt=c)
    y, y_mat = cs_model(src=c, tgt=a)
    x = x[0].cpu()
    x_mat = x_mat[0].cpu()
    y = y[0].cpu()
    y_mat = y_mat[0].cpu()

In [None]:
ct = dataset.code_lang.to_tokens(c)[0]
ct

In [None]:
at = dataset.anno_lang.to_tokens(a)[0]
at

In [None]:
xt = dataset.code_lang.to_tokens(x.argmax(dim=1))[0]
xt

In [None]:
yt = dataset.anno_lang.to_tokens(y.argmax(dim=1))[0]
yt

In [None]:
plt.figure(figsize=(12,8))
plt.imshow(F.softmax(x_mat, -1), cmap='jet')
plt.grid(False)
plt.xticks(ticks=np.arange(len(yt)), labels=yt, rotation=90)
plt.xlabel('anno')
plt.yticks(ticks=np.arange(len(xt)), labels=xt, rotation=0)
plt.ylabel('code')
plt.colorbar()

In [None]:
plt.figure(figsize=(12,8))
plt.imshow(F.softmax(y_mat, -1), cmap='jet')
plt.grid(False)
plt.xticks(ticks=np.arange(len(ct)), labels=ct, rotation=90)
plt.xlabel('code')
plt.yticks(ticks=np.arange(len(yt)), labels=yt, rotation=0)
plt.ylabel('anno')
plt.colorbar()