In [None]:
import sys
sys.path.insert(0, './language_model/')

# for suppressing torch.save warnings
# see https://discuss.pytorch.org/t/got-warning-couldnt-retrieve-source-code-for-container/7689
import warnings
warnings.simplefilter('ignore', UserWarning)

In [None]:
import os
import operator
import pprint
from functools import reduce
from timeit import default_timer as timer

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as O
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from django import Django
from ml_utils.config import Config
from language_model.lm_train import train_language_model
from language_model.lm_prob import LMProb
pp = pprint.PrettyPrinter(width=180, indent=2, compact=False)

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
sns.set()

In [None]:
def from_home(x):
    """Return ``x`` joined onto the current user's home directory.

    Uses ``os.path.expanduser('~')`` instead of reading ``$HOME`` directly, so it does not
    raise ``KeyError`` when ``HOME`` is unset (e.g. on Windows, where the home directory
    comes from ``USERPROFILE``).

    :param x: path relative to the home directory
    :return: the joined path (per ``os.path.join``, an absolute ``x`` discards the home prefix)
    """
    return os.path.join(os.path.expanduser('~'), x)

## Config

In [None]:
ROOT_DIR = from_home('workspace/ml-data/msc-research')
EMB_DIR = os.path.join(ROOT_DIR, 'embeddings')
# For the simplified Django variant use 'raw-datasets/testing' instead.
DJANGO_DIR = os.path.join(ROOT_DIR, 'raw-datasets/django')

CFG = Config()  # main config; sub-configs are attached as attributes

# dataset sub-config
CFG.dataset_config = Config()
CFG.dataset_config.__dict__ = dict(
    root_dir=DJANGO_DIR,
    anno_min_freq=1,
    code_min_freq=1,
    anno_seq_maxlen=30,
    code_seq_maxlen=20,
    emb_file=os.path.join(EMB_DIR, 'glove.6B.50d.txt.pickle'),
)

dataset = Django(config=CFG.dataset_config)


def _side_config(lang, load_pretrained_emb):
    """Build the encoder/decoder sub-config for one side (NL intents or source code)."""
    side = Config()
    side.__dict__ = dict(
        lstm_hidden_size=128,
        lstm_dropout_p=0.0,
        att_dropout_p=0.0,
        lang=lang,
        load_pretrained_emb=load_pretrained_emb,
        emb_size=50,
    )
    return side


# NL intents start from the pretrained GloVe vectors; code embeddings are learned from scratch.
CFG.anno = _side_config(dataset.anno_lang, load_pretrained_emb=True)
CFG.code = _side_config(dataset.code_lang, load_pretrained_emb=False)

In [None]:
# Round-trip a sample snippet through the code vocabulary (text -> numeric ids -> tokens)
# and display both results, so the tokenizer/padding behaviour can be inspected.
toks = dataset.code_lang.to_numeric('return func(1+a)', tokenize_mode='code', pad_mode='post', max_len=10)
ws = dataset.code_lang.to_tokens(torch.tensor(toks))
toks, ws

In [None]:
# Spot-check one random example: each side must have exactly its configured length.
i = np.random.randint(len(dataset))
a, c = dataset[i]
for seq, expected_len in ((a, CFG.dataset_config.anno_seq_maxlen),
                          (c, CFG.dataset_config.code_seq_maxlen)):
    assert len(seq) == expected_len, f'{i}'

# Show ids and tokens of the annotation, then (after a separator) of the code.
for k, (lang, seq) in enumerate(((dataset.anno_lang, a), (dataset.code_lang, c))):
    if k > 0:
        print('-' * 120)
    pp.pprint(seq)
    pp.pprint(lang.to_tokens(seq))

## Compute LM probabilities

### Get train/test/valid splits

In [None]:
splits = dataset.train_test_valid_split(test_p=0.1, valid_p=0.2, seed=42)

# `splits` is keyed by side ('anno' / 'code', as used below), then by train/test/valid.
# Flatten every split into one long token stream without padding (the form passed to
# train_language_model below). The pad id is looked up per language instead of assuming 0.
for kind, parts in splits.items():
    pad = getattr(dataset, f'{kind}_lang').token2index['<pad>']
    for part, seqs in parts.items():
        stream = torch.cat(seqs)
        parts[part] = stream[stream != pad]

### Train language model

**Note:** two separate language models are trained: one over the annotations (NL intents) and one over the code.

In [None]:
# Hyper-parameters shared by the two language models (one per side, trained below).
CFG.language_model = Config()
CFG.language_model.__dict__ = dict(
    dataset='django',
    model='LSTM',        # recurrent net type: RNN_TANH, RNN_RELU, LSTM, GRU or Transformer
    n_head=None,         # number of heads in the enc/dec of the Transformer
    emb_size=32,         # word-embedding size
    n_hid=64,            # hidden units per layer
    n_layers=1,          # number of layers
    lr=0.25,             # initial learning rate
    clip=0.25,           # gradient clipping
    dropout_p=0.0,       # dropout applied to layers
    tied=False,          # tie word embeddings and softmax weights
    log_interval=100,
    epochs=50,           # upper epoch limit
    batch_size=32,
    seed=None,           # set an int for reproducible runs
)

CFG.language_model

In [None]:
lm_cfg = CFG.language_model

for kind in ['anno', 'code']:
    print(f'Training LM for {kind}\n')

    lm_cfg.kind = kind
    lm_cfg.bptt = CFG.dataset_config.__dict__[f'{kind}_seq_maxlen']  # seq len
    lm_cfg.save_path = f'./data/lm/lm-{lm_cfg.dataset}-{lm_cfg.kind}.pt'  # path to save the final model

    # The relative output directory is not created anywhere else in this notebook; create it
    # up front so saving cannot fail only after a long training run.
    os.makedirs(os.path.dirname(lm_cfg.save_path), exist_ok=True)

    train_language_model(lm_cfg,
                         num_tokens=len(getattr(dataset, f'{kind}_lang')),
                         train_nums=splits[kind]['train'],
                         test_nums=splits[kind]['test'],
                         valid_nums=splits[kind]['valid'])

    print('*' * 120, '\n')

### Compute LM probabilities

In [None]:
lm_root_dir = './data/lm'

# One trained LM checkpoint per side, at the paths the training cell saves to.
lm_paths = {side: f'{lm_root_dir}/lm-{lm_cfg.dataset}-{side}.pt' for side in ('anno', 'code')}

for path in lm_paths.values():
    assert os.path.exists(path), f'Language Model: file <{path}> does not exist!'

# NOTE(review): the training loop unpacks four values per dataset item, which presumably
# requires this call to have been run first:
# _ = dataset.compute_lm_probs(lm_paths)

In [None]:
class MyLMProb:
    """Score token sequences with a trained language-model checkpoint.

    The checkpoint is a whole pickled model object. It must provide ``init_hidden(bsz=...)``
    and be callable as ``model(inp, hidden) -> (output, hidden)``. It is loaded onto the CPU
    in eval mode.
    """

    def __init__(self, model_path):
        # Passing the path (not an open file) lets torch close the file itself.
        # map_location='cpu' remaps tensors saved on any GPU, not only cuda:0.
        try:
            # torch >= 2.6 defaults to weights_only=True, which refuses whole pickled models.
            # This unpickles arbitrary objects: only load checkpoints you trust.
            model = torch.load(model_path, map_location='cpu', weights_only=False)
        except TypeError:
            # older torch versions do not accept the ``weights_only`` argument
            model = torch.load(model_path, map_location='cpu')
        self.model = model.cpu()
        self.model.eval()

    def get_prob(self, nums, verbose=False):
        """Mean log-probability per predicted token of ``nums`` under the LM.

        The first token is only context. Each later token ``nums[i]`` is scored by the model
        after being fed ``nums[:i]`` one token at a time, and the scores are averaged.

        :param nums: 1-D sequence (list or tensor) of token ids, at least 2 long
        :param verbose: print per-token probabilities and the summed log-probability
        :return: 0-d tensor with the mean log-probability (its ``exp`` is the geometric-mean
                 per-token probability)
        :raises ValueError: if ``nums`` has fewer than 2 tokens (nothing to score)
        """
        if len(nums) < 2:
            raise ValueError(f'need at least 2 tokens to score a sequence, got {len(nums)}')

        with torch.no_grad():
            inp = torch.tensor([[int(nums[0])]], dtype=torch.long)  # one step, batch of one
            hidden = self.model.init_hidden(bsz=1)
            log_probs = []

            for i in range(1, len(nums)):
                output, hidden = self.model(inp, hidden)
                # log_softmax is stabler than log(softmax(.)) and cannot produce log(0)
                log_probs.append(F.log_softmax(output.squeeze(), dim=-1)[int(nums[i])])
                inp.fill_(int(nums[i]))  # the current token is the next step's input

            log_probs = torch.stack(log_probs)

            if verbose:
                for tok, lp in zip(nums[1:], log_probs):
                    print(f'{int(tok):4d}: P(w|s) = {lp.exp().item():8.4f} | logP(w|s) = {lp.item():8.4f}')
                print(f'=> sum_prob = {log_probs.sum().item():.4f}')

        return log_probs.mean()

In [None]:
lm_probs = {'anno': [], 'code': []}

pad_idx = {
    'anno': dataset.anno_lang.token2index['<pad>'],
    'code': dataset.code_lang.token2index['<pad>']
}

# Both sides are scored: a leftover `break` previously stopped after 'anno'.
for kind in lm_probs:
    lm = MyLMProb(lm_paths[kind])
    pad = pad_idx[kind]

    # per example: exp(mean log-prob) of the sequence with padding removed
    per_example = [
        np.exp(lm.get_prob(vec[vec != pad], verbose=False))
        for vec in tqdm(getattr(dataset, kind), total=len(dataset), desc=f'P({kind})')
    ]
    # summed over the whole dataset
    lm_probs[kind] = sum(per_example)

In [None]:
kind = 'anno'
lang = getattr(dataset, f'{kind}_lang')
lm = MyLMProb(lm_paths[kind])

# Score each token on its own as the sequence [2, token, 3]. Ids 2 and 3 are the wrapping
# markers (presumably start/end of sequence) and 0 is the id filtered out as padding when
# building the splits, so all three are skipped.
special_ids = {0, 2, 3}
s, labels = {}, {}
for t, i in tqdm(lang.token2index.items()):
    if i in special_ids:
        continue
    s[i] = np.exp(lm.get_prob(torch.tensor([2, i, 3])))
    labels[i] = t

# token ids ordered from most to least probable
xs, ys = zip(*sorted(s.items(), key=lambda kv: kv[1], reverse=True))

# Bars go at consecutive positions labelled with their token. Plotting at the raw ids would
# place each bar by id and hide the probability ordering.
fig, ax = plt.subplots(figsize=(14, 6))
positions = list(range(len(xs)))
ax.bar(positions, ys)
ax.set_xticks(positions)
ax.set_xticklabels([str(labels[i]).replace('$', r'\$') for i in xs], rotation=90)  # escape '$' (mathtext)
ax.set(title=f'LM score of single-token {kind} sequences', xlabel='token', ylabel='exp(mean log-prob)')

sum(ys)

## Dual CS/CG model (CS: code → annotation, CG: annotation → code)

In [None]:
def get_embeddings(config: Config):
    """Create the embedding layer for one side of the model.

    :param config: per-side sub-config. Reads ``lang`` (``len(lang)``, ``lang.pad_idx``,
        ``lang.emb_matrix``), ``emb_size`` and ``load_pretrained_emb``.
    :return: ``nn.Embedding`` with ``len(lang)`` rows of size ``emb_size``. When
        ``load_pretrained_emb`` is set, its weight is the pretrained matrix, frozen.
    :raises ValueError: if the pretrained matrix shape is not (len(lang), emb_size).
        Otherwise the mismatch would only surface later as an LSTM input-size error.
    """
    emb = nn.Embedding(len(config.lang), config.emb_size, padding_idx=config.lang.pad_idx)

    if config.load_pretrained_emb:
        assert config.lang.emb_matrix is not None
        weight = torch.tensor(config.lang.emb_matrix, dtype=torch.float32)
        expected = (len(config.lang), config.emb_size)
        if tuple(weight.shape) != expected:
            raise ValueError(f'pretrained embedding matrix has shape {tuple(weight.shape)}, '
                             f'expected {expected}')
        # frozen: the pretrained vectors are not fine-tuned
        emb.weight = nn.Parameter(weight, requires_grad=False)

    return emb

In [None]:
class Model(nn.Module):
    """Attentional LSTM encoder-decoder with Luong "general" attention and input feeding.

    One instance handles one direction of the dual task:
      * ``cg``: annotation -> code
      * ``cs``: code -> annotation
    """

    def __init__(self, config: Config, model_type):
        """
        :param config: main config; ``config.anno`` / ``config.code`` are the per-side sub-configs
        :param model_type: cs / cg
        cs: code -> anno
        cg: anno -> code
        """
        super().__init__()

        assert model_type in ['cs', 'cg']
        self.model_type = model_type

        src_cfg = config.anno if model_type == 'cg' else config.code
        tgt_cfg = config.code if model_type == 'cg' else config.anno

        # 1. ENCODER: bidirectional LSTM over the source embeddings
        self.src_embedding = get_embeddings(src_cfg)
        self.encoder = nn.LSTM(input_size=src_cfg.emb_size,
                               hidden_size=src_cfg.lstm_hidden_size,
                               dropout=src_cfg.lstm_dropout_p,
                               bidirectional=True,
                               batch_first=True)

        # maps the concatenated final fwd/bwd encoder cell states to the decoder's initial cell state
        self.decoder_cell_init_linear = nn.Linear(in_features=2*src_cfg.lstm_hidden_size,
                                                  out_features=tgt_cfg.lstm_hidden_size)

        # 2. ATTENTION
        # project source encoding to decoder rnn's h space (W from Luong score general)
        self.att_src_W = nn.Linear(in_features=2*src_cfg.lstm_hidden_size,
                                   out_features=tgt_cfg.lstm_hidden_size,
                                   bias=False)

        # transformation of decoder hidden states and context vectors before reading out target words
        # this produces the attentional vector in (W from Luong eq. 5)
        self.att_vec_W = nn.Linear(in_features=2*src_cfg.lstm_hidden_size + tgt_cfg.lstm_hidden_size,
                                   out_features=tgt_cfg.lstm_hidden_size,
                                   bias=False)

        # 3. DECODER: input is [target embedding; previous attentional vector] (input feeding)
        self.tgt_embedding = get_embeddings(tgt_cfg)
        self.decoder = nn.LSTMCell(input_size=tgt_cfg.emb_size + tgt_cfg.lstm_hidden_size,
                                   hidden_size=tgt_cfg.lstm_hidden_size)

        # logit layer over the target vocabulary
        self.readout = nn.Linear(in_features=tgt_cfg.lstm_hidden_size,
                                 out_features=len(tgt_cfg.lang),
                                 bias=False)

        self.dropout = nn.Dropout(tgt_cfg.att_dropout_p)

        # 4. COPY MECHANISM
        self.copy_gate = ...  # TODO: not implemented yet

        # save configs
        self.src_cfg = src_cfg
        self.tgt_cfg = tgt_cfg

    def forward(self, src, tgt, src_mask=None):
        """
        src     : bs, max_src_len (source token ids)
        tgt     : bs, max_tgt_len (target token ids, fed to the decoder with teacher forcing)
        src_mask: optional bs, max_src_len; nonzero marks source positions (e.g. padding) that
                  receive no attention. A row must not be entirely masked.
        returns : (scores, att_mats), see ``decode``
        """
        enc_out, (h0_dec, c0_dec) = self.encode(src)
        scores, att_mats = self.decode(enc_out, h0_dec, c0_dec, tgt, src_mask=src_mask)

        return scores, att_mats

    def encode(self, src):
        """
        src : bs x max_src_len (emb look-up indices)
        out : bs x max_src_len x 2*src_hid_size
        h/c0: bs x tgt_hid_size (initial decoder state)
        """
        emb = self.src_embedding(src)
        out, (hn, cn) = self.encoder(emb)  # initial encoder state defaults to zeros

        # cn[0] / cn[1]: final cell state of the forward / backward direction (single layer)
        c0_dec = self.decoder_cell_init_linear(torch.cat([cn[0], cn[1]], dim=1))
        h0_dec = c0_dec.tanh()

        return out, (h0_dec, c0_dec)

    def decode(self, src_enc, h0_dec, c0_dec, tgt, src_mask=None):
        """
        src_enc : bs, max_src_len, 2*src_hid_size (== encoder output)
        h/c0    : bs, tgt_hid_size
        tgt     : bs, max_tgt_len (emb look-up indices); decoder step t is fed tgt[:, t]
        src_mask: optional bs, max_src_len; nonzero positions receive no attention

        scores  : bs, max_tgt_len, tgt_vocab_size. These are unnormalised logits: apply a
                  softmax for probabilities. nn.CrossEntropyLoss takes them directly.
        att_mats: bs, max_tgt_len, max_src_len, the pre-softmax attention scores of each step
        """
        batch_size, tgt_len = tgt.shape
        scores, att_mats = [], []

        hidden = (h0_dec, c0_dec)

        emb = self.tgt_embedding(tgt)  # bs, max_tgt_len, tgt_emb_size

        # input-feeding vector for step 0; built from h0_dec so it shares its device and dtype
        att_vec = h0_dec.new_zeros(batch_size, self.tgt_cfg.lstm_hidden_size)

        # Luong W*hs: same for each timestep of the decoder, so computed once
        src_enc_att = self.att_src_W(src_enc)  # bs, max_src_len, tgt_hid_size

        for t in range(tgt_len):
            x = torch.cat([emb[:, t, :], att_vec], dim=-1)
            h_t, c_t = self.decoder(x, hidden)

            ctx_t, att_mat = self.luong_attention(h_t, src_enc, src_enc_att, mask=src_mask)

            # Luong eq. (5): attentional vector
            att_t = self.att_vec_W(torch.cat([h_t, ctx_t], dim=1)).tanh()
            att_t = self.dropout(att_t)

            # Luong eq. (6) without the softmax. CrossEntropyLoss normalises the logits itself;
            # a softmax here as well would be applied twice.
            scores.append(self.readout(att_t))
            att_mats.append(att_mat)

            # state for step t+1
            att_vec = att_t
            hidden = (h_t, c_t)

        scores = torch.stack(scores, dim=1)    # bs, max_tgt_len, tgt_vocab_size
        att_mats = torch.cat(att_mats, dim=1)  # bs, max_tgt_len, max_src_len

        return scores, att_mats

    def luong_attention(self, h_t, src_enc, src_enc_att, mask=None):
        """
        h_t               : bs, hid_size
        src_enc (hs)      : bs, max_src_len, 2*src_hid_size
        src_enc_att (W*hs): bs, max_src_len, tgt_hid_size
        mask              : optional bs, max_src_len; nonzero positions get -inf scores

        ctx_vec    : bs, 2*src_hid_size
        att_mat    : bs, 1, max_src_len (pre-softmax scores, after masking)
        """
        # Luong "general" score h_t . (W*hs): bs x src_max_len
        score = torch.bmm(src_enc_att, h_t.unsqueeze(2)).squeeze(2)

        if mask is not None:
            # `if mask:` raises for multi-element tensors. Out-of-place so autograd tracks it.
            score = score.masked_fill(mask != 0, float('-inf'))

        att_mat = score.unsqueeze(1)
        att_weights = F.softmax(score, dim=-1)

        # attention-weighted sum of the encoder states
        ctx_vec = torch.sum(att_weights.unsqueeze(2) * src_enc, dim=1)

        return ctx_vec, att_mat

    def translate(self, src):
        """
        Beam search (not implemented yet)
        src: input sequence (anno / code)
        """
        raise NotImplementedError('beam-search translation is not implemented yet')

In [None]:
def get_vocab_mask(lang):
    """Per-token loss weights for ``lang``'s vocabulary: 1 everywhere, 0 for '<pad>'."""
    pad_index = lang.token2index['<pad>']
    weights = torch.ones(len(lang))
    weights[pad_index] = 0
    return weights


def jensen_shannon_divergence(a1, b1, a_mask, b_mask):
    """
    Per-example Jensen-Shannon divergence between two batches of attention score matrices.

    Each row of both inputs becomes a distribution via a softmax over the last dim. Both
    matrices are then padded with ``eps`` to a square dmax x dmax (dmax = max(n, m)). The
    terms KL(a || M) and KL(b || M), with M = (a + b) / 2, are averaged over the unmasked
    rows of a and b respectively, and the two results are averaged.

    a1     : bs, n, m   pre-softmax scores (not modified: a clone is used)
    b1     : bs, m, n   pre-softmax scores (not modified: a clone is used)
    a_mask : bs, n      1 marks a masked row of a1 (e.g. a padding position)
    b_mask : bs, m      1 marks a masked row of b1
    returns: bs         JS divergence per batch element

    Only rows are masked; padded positions along the last dim still take part in the softmax.
    NOTE(review): after padding, a[i, j] is compared with b[i, j] although a is n x m and
    b is m x n. Confirm whether b was meant to be transposed before the comparison.
    """
    # TODO NO MORE PAD NEEDED
    # `inf` is a large finite value: a fully masked row softmaxes to uniform, not NaN
    eps, inf = 1e-8, 1e+8
    
    a = a1.clone()
    b = b1.clone()
    
    bs, n, m = a.shape
    assert b.shape == (bs, m, n)
    dmax = max(n, m)

    
    # mask indexes the first two dims, so whole rows are set to -inf
    a[a_mask == 1] = -inf
    b[b_mask == 1] = -inf
#     a.data.masked_fill_(a_mask, -inf)
#     b.data.masked_fill_(b_mask, -inf)
    
    # row-wise distributions; +eps keeps the logs below finite
    a = F.softmax(a, dim=2) + eps
    b = F.softmax(b, dim=2) + eps
    
    # pad both to dmax x dmax with eps so they can be compared elementwise
    a_ = eps * torch.ones(bs, dmax, dmax)
    a_[:, :min(n,dmax), :min(m,dmax)] = a
    
    b_ = eps * torch.ones(bs, dmax, dmax)
    b_[:, :min(m,dmax), :min(n,dmax)] = b
    
    # padded rows count as masked (1)
    a_mask_ = torch.ones(bs, dmax)
    a_mask_[:, :min(n, dmax)] = a_mask
    b_mask_ = torch.ones(bs, dmax)
    b_mask_[:, :min(m, dmax)] = b_mask
    
    a = a_
    b = b_
    a_mask = a_mask_
    b_mask = b_mask_
    
    # elementwise KL terms against the mixture M = (a + b) / 2
    kl_a = a * torch.log(a / ((a+b)/2))
    kl_b = b * torch.log(b / ((a+b)/2))

    # masked (and padded) rows contribute nothing
    kl_a[a_mask == 1] = 0
    kl_b[b_mask == 1] = 0
#     kl_a.data.masked_fill_(a_mask, 0)
#     kl_b.data.masked_fill_(b_mask, 0)
    
    # KL per row
    kl_a = torch.sum(kl_a, dim=2)
    kl_b = torch.sum(kl_b, dim=2)
    
    # mean over the unmasked rows of each side
    kl_a = torch.sum(kl_a, dim=1) / torch.sum(1-a_mask, dim=1).float()
    kl_b = torch.sum(kl_b, dim=1) / torch.sum(1-b_mask, dim=1).float()
    
    js_div = (kl_a + kl_b) / 2.0
    
    return js_div

In [None]:
# Smoke-test the JSD on small synthetic attention scores. The real `anno_att_mat` /
# `code_att_mat` / masks only exist after the training cell further down has run, so
# using them here breaks Restart & Run All. A private generator leaves torch's global RNG untouched.
gen = torch.Generator()
gen.manual_seed(0)
demo_bs, demo_anno_len, demo_code_len = 2, 5, 4

Axy = torch.randn(demo_bs, demo_anno_len, demo_code_len, generator=gen)  # shaped like anno_att_mat
Ayx = torch.randn(demo_bs, demo_code_len, demo_anno_len, generator=gen)  # shaped like code_att_mat
Mxy = torch.zeros(demo_bs, demo_anno_len, dtype=torch.uint8)  # like anno_mask: 1 = padding
Myx = torch.zeros(demo_bs, demo_code_len, dtype=torch.uint8)  # like code_mask: 1 = padding
Mxy[0, -1] = 1   # pad the last annotation position of the first example
Myx[1, -2:] = 1  # pad the last two code positions of the second example

jensen_shannon_divergence(
    Axy.transpose(2, 1),
    Ayx.transpose(2, 1),
    Myx,
    Mxy
)

In [None]:
# Most frequent code tokens; presumably (token, count) pairs if token2count is a collections.Counter.
dataset.code_lang.token2count.most_common(100)

## Train

In [None]:
kwargs = {}  # e.g. {'num_workers': 4, 'pin_memory': True}
train_loader = DataLoader(dataset, batch_size=1, shuffle=False, **kwargs)

model = {
    'cg': Model(CFG, model_type='cg'),  # anno -> code
    'cs': Model(CFG, model_type='cs')   # code -> anno
}

# trainable parameters of each model (frozen pretrained embeddings are left out)
params = {k: [p for p in m.parameters() if p.requires_grad] for k, m in model.items()}

opt = {k: O.Adam(lr=0.001, params=params[k]) for k in model}

# '<pad>' has weight 0, so padding targets do not contribute to the loss
crit = {
    'cg': nn.CrossEntropyLoss(weight=get_vocab_mask(dataset.code_lang)),
    'cs': nn.CrossEntropyLoss(weight=get_vocab_mask(dataset.anno_lang))
}

anno_pad = dataset.anno_lang.token2index['<pad>']
code_pad = dataset.code_lang.token2index['<pad>']

n_epochs = 1000
rep_every = 200  # report averaged losses every `rep_every` batches


def next_token_ce(criterion, pred, target):
    """Cross-entropy of predicting ``target[:, t+1]`` from decoder step ``t``.

    The decoder is fed ``target[:, t]`` at step t, so scoring step t against the same
    token would let it copy its input. The loss is computed over all steps at once: it is the
    weighted mean over non-pad targets. (A per-step mean is 0/0 = NaN at padded steps.)

    :param pred: bs, len, vocab logits
    :param target: bs, len token ids
    """
    vocab_size = pred.size(-1)
    return criterion(pred[:, :-1].reshape(-1, vocab_size), target[:, 1:].reshape(-1))


def print_predictions(lang, pred):
    """Print the greedy (argmax) decoding of ``pred`` (bs, len, vocab) as ``lang`` tokens."""
    with torch.no_grad():
        for j, words in enumerate(lang.to_tokens(pred.argmax(dim=-1))):
            print(f'\t{j}: {" ".join(words)}')


cg_running, cs_running = 0.0, 0.0
step = 0

for e in range(n_epochs):
    # NOTE(review): items are unpacked as 4 values. This presumably requires
    # dataset.compute_lm_probs(lm_paths) to have been run first; dataset[i] otherwise unpacks as (anno, code).
    for i, (anno, code, anno_lm_p, code_lm_p) in enumerate(train_loader):
        # binary masks: 1 where the token is padding
        anno_mask = (anno == anno_pad).byte()
        code_mask = (code == code_pad).byte()

        # forward pass (teacher forcing)
        code_pred, code_att_mat = model['cg'](src=anno, tgt=code)
        anno_pred, anno_att_mat = model['cs'](src=code, tgt=anno)

        loss = {}

        # cross-entropy, next-token targets
        loss['cg_ce'] = next_token_ce(crit['cg'], code_pred, code)
        loss['cs_ce'] = next_token_ce(crit['cs'], anno_pred, anno)

        # dual loss: the two factorisations of P(anno, code) should agree. Presumably *_lm_p
        # are LM log-probabilities, and -CE stands in for log P(tgt | src).
        loss['dual'] = ((code_lm_p - loss['cs_ce']) - (anno_lm_p - loss['cg_ce'])) ** 2

        # attention loss: JSD between the two models' attention matrices, in both orientations
        loss['cg_att'] = jensen_shannon_divergence(anno_att_mat,
                                                   code_att_mat,
                                                   anno_mask,
                                                   code_mask)
        loss['cs_att'] = jensen_shannon_divergence(anno_att_mat.transpose(2, 1),
                                                   code_att_mat.transpose(2, 1),
                                                   code_mask,
                                                   anno_mask)
        att_loss = loss['cg_att'] + loss['cs_att']

        loss['cg'] = torch.mean(loss['cg_ce'] + 0.01 * loss['dual'] + 0.1 * att_loss)
        loss['cs'] = torch.mean(loss['cs_ce'] + 0.001 * loss['dual'] + 0.01 * att_loss)

        # Each model is updated only with the gradient of its own loss. Both losses share one
        # graph, so both gradients are computed before either optimizer steps. Stepping one
        # model first would modify parameters the other backward pass still needs.
        cg_grads = torch.autograd.grad(loss['cg'], params['cg'], retain_graph=True, allow_unused=True)
        cs_grads = torch.autograd.grad(loss['cs'], params['cs'], allow_unused=True)
        for k, grads in (('cg', cg_grads), ('cs', cs_grads)):
            for p, g in zip(params[k], grads):
                p.grad = g
            opt[k].step()

        cg_running += loss['cg'].item() / rep_every
        cs_running += loss['cs'].item() / rep_every
        step += 1

        if step % rep_every == 0:
            print(f'Epoch {e:>5d} | Batch {i:>5d} | cg {cg_running:6.5f} | cs {cs_running:6.5f}')
            cg_running, cs_running = 0.0, 0.0

            print_predictions(dataset.code_lang, code_pred)
            print('\t' + '-' * 80)
            print_predictions(dataset.anno_lang, anno_pred)
            print()