In [None]:
import sys
sys.path.insert(0, './language_model/')

import warnings
warnings.simplefilter('ignore', UserWarning)

In [None]:
import os
from timeit import default_timer as timer
from functools import reduce
import operator

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as O
from torch.utils.data import DataLoader

from django import Django
from config import Config
from language_model.lm_train import train_language_model

import pprint
pp = pprint.PrettyPrinter(width=160, indent=2, compact=True)

In [None]:
def from_home(x):
    """Return ``x`` joined onto the current user's home directory.

    :param x: path relative to the home directory (e.g. ``'workspace/data'``)
    :return: the joined path as a string

    ``os.path.expanduser('~')`` is used instead of ``os.environ['HOME']`` so the
    lookup does not raise ``KeyError`` when ``HOME`` is not set (Windows,
    minimal service environments); when ``HOME`` is set on POSIX the result is
    the same as before.
    """
    return os.path.join(os.path.expanduser('~'), x)

## Config

In [None]:
# Filesystem layout: everything lives under a single root inside the user's home.
ROOT_DIR   = from_home('workspace/ml-data/msc-research')
# Currently pointed at the 'testing' dataset directory; the line below switches
# to the 'django' directory instead.
DJANGO_DIR = os.path.join(ROOT_DIR, 'raw-datasets/testing')
# DJANGO_DIR = os.path.join(ROOT_DIR, 'raw-datasets/django')
EMB_DIR    = os.path.join(ROOT_DIR, 'embeddings')

CFG = Config() # main config

# sub-config for dataset
# NOTE(review): assigning `__dict__` replaces every instance attribute at once,
# so any defaults set by Config.__init__ (not visible here) are discarded.
CFG.dataset_config = Config()
CFG.dataset_config.__dict__ = {
    'root_dir': DJANGO_DIR,
    'anno_min_freq': 1,
    'code_min_freq': 1,
    'anno_seq_maxlen': 20,
    'code_seq_maxlen': 10,
    # NOTE(review): the file name suggests 50-d GloVe vectors, which would match
    # 'emb_size': 50 in CFG.anno below -- confirm if either value changes.
    'emb_file': os.path.join(EMB_DIR, 'glove.6B.50d.txt.pickle'),
}

# The dataset must be built before the anno/code sub-configs because they read
# `dataset.anno_lang` / `dataset.code_lang`.
dataset = Django(config=CFG.dataset_config)

# sub-config for NL intents
CFG.anno = Config() 
CFG.anno.__dict__ = {
    'lstm_hidden_size': 128,
    'lstm_dropout_p': 0.0,
    'att_dropout_p': 0.0,
    'lang': dataset.anno_lang,
    'load_pretrained_emb': True,   # get_embeddings() then reads lang.emb_matrix
    'emb_size': 50,
}

# sub-config for source code
CFG.code = Config() 
CFG.code.__dict__ = {
    'lstm_hidden_size': 128,
    'lstm_dropout_p': 0.0,
    'att_dropout_p': 0.0,
    'lang': dataset.code_lang,
    'load_pretrained_emb': False,  # code embeddings are initialised by nn.Embedding
    'emb_size': 50,
}

## Compute LM probabilities

### Get train/test/valid splits

In [None]:
# `splits` is indexed below as splits[kind]['train' | 'test' | 'valid'] with
# kind in {'anno', 'code'}.
# NOTE(review): test_p / valid_p are presumably the fractions of the dataset
# held out for test / validation -- confirm against Django.train_test_valid_split.
splits = dataset.train_test_valid_split(test_p=1/5, valid_p=1/5, seed=42)

### Train language model

**Note:** A separate language model must be trained for each side of the pair: one for the annotations (`anno`) and one for the code (`code`).

In [None]:
# Hyper-parameters handed to train_language_model(); the training cell adds
# `kind` and `save_path` to this same object for each language.
CFG.language_model = Config()
CFG.language_model.__dict__ = {
    'dataset'     : 'django',
    'model'       : 'LSTM', # type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU, Transformer)
    'n_head'      : None,   # number of heads in the enc/dec of the Transformers
    'emb_size'    : 10,     # size of the word embeddings
    'n_hid'       : 32,     # number of hidden units per layer
    'n_layers'    : 1,      # number of layers
    'lr'          : 0.1,    # initial learning rate
    'clip'        : 0.25,   # gradient clipping
    'bptt'        : 20,     # seq len
    'dropout_p'   : 0.0,    # dropout applied to layers
    'tied'        : False,  # whether to tie the word embeddings and softmax weights
    'log_interval': 10,
    'epochs'      : 10, # upper epoch limit
    'batch_size'  : 4,
    # NOTE(review): `None` presumably leaves the run unseeded, which contradicts
    # the "for reproducibility" intent -- set an integer; confirm how
    # train_language_model consumes this value.
    'seed'        : None # for reproducibility
}

# bare expression: displays the config as the cell output
CFG.language_model

In [None]:
lm_cfg = CFG.language_model

# Train one language model per side of the (annotation, code) pair.
for kind in ['anno', 'code']:
    print(f'Training LM for {kind}\n')

    # Per-language settings are stored on the shared config object because
    # train_language_model() receives `lm_cfg` itself.
    lm_cfg.kind = kind
    lm_cfg.save_path = f'./data/lm/lm-{lm_cfg.dataset}-{lm_cfg.kind}.pt' # path to save the final model

    # Make sure the target directory exists so a fresh "Restart & Run All"
    # does not depend on ./data/lm having been created by hand.
    os.makedirs(os.path.dirname(lm_cfg.save_path), exist_ok=True)

    # Vocabulary of the language being modelled; its size (not the number of
    # dataset examples) is what the language model needs.
    kind_lang = getattr(dataset, f'{kind}_lang')
    kind_splits = splits[kind]

    train_language_model(lm_cfg,
                         num_tokens=len(kind_lang),
                         train_nums=torch.stack(kind_splits['train']),
                         test_nums=torch.stack(kind_splits['test']),
                         valid_nums=torch.stack(kind_splits['valid']))

    print('*' * 120, '\n')

### Compute LM probs

In [None]:
lm_root_dir = './data/lm'

# One saved language model per language, keyed by language kind.
lm_paths = {
    kind: f'{lm_root_dir}/lm-{lm_cfg.dataset}-{kind}.pt'
    for kind in ('anno', 'code')
}

# Fail fast if either model has not been trained/saved yet.
for lm_file in lm_paths.values():
    assert os.path.exists(lm_file), f'Language Model: file <{lm_file}> does not exist!'

# The return value is not needed here; binding it to `_` keeps it out of the cell output.
_ = dataset.compute_lm_probs(lm_paths)

## Dual CS/CG Model

In [None]:
def get_embeddings(config: 'Config'):
    """Build the embedding layer described by a language sub-config.

    :param config: object exposing
        ``lang`` (supports ``len()``, has ``pad_idx`` and, when pretrained
        vectors are requested, ``emb_matrix``), ``emb_size`` and
        ``load_pretrained_emb``.
    :return: ``nn.Embedding`` -- trainable and randomly initialised when
        ``load_pretrained_emb`` is false, otherwise initialised from
        ``lang.emb_matrix`` and frozen (``requires_grad=False``).
    :raises AssertionError: pretrained vectors requested but ``emb_matrix`` is None.
    :raises ValueError: the pretrained matrix cannot serve this vocabulary
        (not 2-D, fewer rows than vocabulary entries, or a column count
        different from ``emb_size``).
    """
    vocab_size = len(config.lang)
    pad_idx = config.lang.pad_idx

    if not config.load_pretrained_emb:
        return nn.Embedding(vocab_size, config.emb_size, padding_idx=pad_idx)

    assert config.lang.emb_matrix is not None, \
        'load_pretrained_emb is set but lang.emb_matrix is None'

    # torch.tensor copies, so the layer never aliases the language's own matrix
    matrix = torch.tensor(config.lang.emb_matrix, dtype=torch.float32)

    # Without this check a mismatching matrix would silently change the layer's
    # dimensions and only fail later (LSTM input size / index out of range).
    if matrix.dim() != 2 or matrix.shape[0] < vocab_size or matrix.shape[1] != config.emb_size:
        raise ValueError(
            f'pretrained embedding matrix has shape {tuple(matrix.shape)}, '
            f'expected at least {vocab_size} rows and exactly {config.emb_size} columns'
        )

    # freeze=True keeps the pretrained vectors fixed during training
    return nn.Embedding.from_pretrained(matrix, freeze=True, padding_idx=pad_idx)

In [None]:
class Model(nn.Module):
    """Attentional encoder-decoder used for both directions of the dual task.

    A bidirectional LSTM encodes the source sequence; an ``nn.LSTMCell`` decoder
    with Luong "general" attention and input feeding (the previous attentional
    vector is concatenated to the current target embedding) produces a
    distribution over the target vocabulary at every target position.
    """

    def __init__(self, config: 'Config', model_type):
        """
        :param config: main config exposing the ``anno`` and ``code`` sub-configs
        :param model_type: cs / cg
        cs: code -> anno
        cg: anno -> code
        """
        super().__init__()

        assert model_type in ['cs', 'cg']
        self.model_type = model_type

        src_cfg = config.anno if model_type == 'cg' else config.code
        tgt_cfg = config.code if model_type == 'cg' else config.anno

        # 1. ENCODER
        self.src_embedding = get_embeddings(src_cfg)
        self.encoder = nn.LSTM(input_size=src_cfg.emb_size,
                               hidden_size=src_cfg.lstm_hidden_size,
                               dropout=src_cfg.lstm_dropout_p,
                               bidirectional=True,
                               batch_first=True)

        # maps the concatenated final cell states of both directions to the
        # decoder's initial cell state
        self.decoder_cell_init_linear = nn.Linear(in_features=2*src_cfg.lstm_hidden_size,
                                                  out_features=tgt_cfg.lstm_hidden_size)

        # 2. ATTENTION
        # project source encoding to decoder rnn's h space (W from Luong score general)
        self.att_src_W = nn.Linear(in_features=2*src_cfg.lstm_hidden_size,
                                   out_features=tgt_cfg.lstm_hidden_size,
                                   bias=False)

        # transformation of decoder hidden states and context vectors before reading out target words
        # this produces the attentional vector in (W from Luong eq. 5)
        self.att_vec_W = nn.Linear(in_features=2*src_cfg.lstm_hidden_size + tgt_cfg.lstm_hidden_size,
                                   out_features=tgt_cfg.lstm_hidden_size,
                                   bias=False)

        # 3. DECODER
        self.tgt_embedding = get_embeddings(tgt_cfg)
        # input feeding: the cell input is [target embedding ; previous attentional vector]
        self.decoder = nn.LSTMCell(input_size=tgt_cfg.emb_size + tgt_cfg.lstm_hidden_size,
                                   hidden_size=tgt_cfg.lstm_hidden_size)

        # prob layer over target language
        self.readout = nn.Linear(in_features=tgt_cfg.lstm_hidden_size,
                                 out_features=len(tgt_cfg.lang),
                                 bias=False)

        self.dropout = nn.Dropout(tgt_cfg.att_dropout_p)

        # save configs
        self.src_cfg = src_cfg
        self.tgt_cfg = tgt_cfg

    def forward(self, src, tgt, src_mask=None):
        """
        src     : bs, max_src_len (emb look-up indices)
        tgt     : bs, max_tgt_len (emb look-up indices)
        src_mask: optional, bs, max_src_len; True marks source positions that
                  attention must ignore. ``None`` (default) attends everywhere.

        returns (scores, att_mats) exactly as produced by ``decode``.
        """
        enc_out, (h0_dec, c0_dec) = self.encode(src)
        return self.decode(enc_out, h0_dec, c0_dec, tgt, src_mask=src_mask)

    def encode(self, src):
        """
        src : bs x max_src_len (emb look-up indices)
        out : bs x max_src_len x 2*hid_size
        h/c0: bs x tgt_hid_size
        """
        emb = self.src_embedding(src)
        out, (hn, cn) = self.encoder(emb) # hidden is zero by default

        # construct initial state for the decoder:
        # cn[0] / cn[1] are the final cell states of the two LSTM directions
        c0_dec = self.decoder_cell_init_linear(torch.cat([cn[0], cn[1]], dim=1))
        h0_dec = c0_dec.tanh()

        return out, (h0_dec, c0_dec)

    def decode(self, src_enc, h0_dec, c0_dec, tgt, src_mask=None):
        """
        src_enc : bs, max_src_len, 2*hid_size (== encoder output)
        h/c0    : bs, tgt_hid_size
        tgt     : bs, max_tgt_len (emb look-up indices)
        src_mask: optional, bs, max_src_len; True = position ignored by attention

        scores  : bs, max_tgt_len, tgt_vocab_size -- softmax probabilities
                  (NOT logits) over the target vocabulary
        att_mats: list of (bs, 1, max_src_len) unnormalised attention scores,
                  one per decoding step except the last
        """
        batch_size, tgt_len = tgt.shape
        scores, att_mats = [], []

        hidden = (h0_dec, c0_dec)

        emb = self.tgt_embedding(tgt) # bs, max_tgt_len, tgt_emb_size

        # initial attentional vector; new_zeros follows src_enc's device and
        # dtype so the model also works after .to(device) / .half()
        att_vec = src_enc.new_zeros(batch_size, self.tgt_cfg.lstm_hidden_size)

        # Luong W*hs: same for each timestep of the decoder
        src_enc_att = self.att_src_W(src_enc) # bs, max_src_len, tgt_hid_size

        for t in range(tgt_len):
            emb_t = emb[:, t, :]
            x = torch.cat([emb_t, att_vec], dim=-1)
            h_t, c_t = self.decoder(x, hidden)

            ctx_t, att_mat = self.luong_attention(h_t, src_enc, src_enc_att, mask=src_mask)

            # Luong eq. (5)
            att_t = self.att_vec_W(torch.cat([h_t, ctx_t], dim=1))
            att_t = att_t.tanh()
            att_t = self.dropout(att_t)

            # Luong eq. (6)
            score_t = self.readout(att_t)
            score_t = F.softmax(score_t, dim=-1)

            scores.append(score_t)
            att_mats.append(att_mat)

            # for next state t+1
            att_vec = att_t
            hidden  = (h_t, c_t)

        # bs, max_tgt_len, tgt_vocab_size
        scores = torch.stack(scores, dim=1)

        # each element: bs, 1, max_src_len
        # NOTE(review): the scores of the final decoding step are dropped, as in
        # the original code -- confirm this is intended by the attention loss.
        att_mats = att_mats[:-1]

        return scores, att_mats

    def luong_attention(self, h_t, src_enc, src_enc_att, mask=None):
        """
        h_t               : bs, hid_size
        src_enc (hs)      : bs, max_src_len, 2*src_hid_size
        src_enc_att (W*hs): bs, max_src_len, tgt_hid_size
        mask              : optional, bs, max_src_len; True = ignore position

        ctx_vec    : bs, 2*src_hid_size
        att_mat    : bs, 1, max_src_len (scores BEFORE softmax; masked
                     positions hold -inf)

        A row whose positions are all masked yields NaN attention weights.
        """
        # bs x src_max_len
        score = torch.bmm(src_enc_att, h_t.unsqueeze(2)).squeeze(2)

        # `is not None`: truth-testing a multi-element tensor raises.
        # Out-of-place fill keeps the operation visible to autograd.
        if mask is not None:
            score = score.masked_fill(mask.to(torch.bool), float('-inf'))

        att_mat = score.unsqueeze(1)
        att_weights = F.softmax(score, dim=-1)

        # attention-weighted sum of the encoder states
        ctx_vec = torch.sum(att_weights.unsqueeze(2) * src_enc, dim=1)

        return ctx_vec, att_mat

In [None]:
def get_vocab_mask(lang):
    """Return a 1-D float tensor with one entry per vocabulary item of ``lang``.

    Every entry is 1 except the one at the index of the ``'<pad>'`` token,
    which is 0 (used below as the per-class ``weight`` of the loss).
    """
    weights = torch.ones(len(lang))
    pad_position = lang.token2index['<pad>']
    weights[pad_position] = 0
    return weights


def jensen_shannon_divergence(a, b, mask, length):
    """Jensen-Shannon divergence between two batches of unnormalised scores.

    a, b  : 3-D score tensors of identical shape; softmax is taken over dim 2
    mask  : boolean tensor broadcastable to ``a``; True marks entries that are
            excluded (they get ~0 probability and contribute nothing)
    length: per-example normaliser (sequence / tensor / number) that divides
            the divergence summed over dims 1 and 2
            (presumably the number of valid steps per example -- confirm with caller)

    returns: 1-D tensor with one divergence value per example (dim 0)

    The inputs are NOT modified: masking is done out-of-place, so the caller's
    tensors stay intact and masked entries receive exactly zero gradient.
    """
    # a large finite value instead of true infinity keeps fully-masked rows
    # finite (uniform softmax) rather than NaN
    eps, inf = 1e-8, 1e+8

    mask = mask.to(torch.bool)
    # as_tensor avoids re-wrapping an existing tensor and follows a's device
    length = torch.as_tensor(length, dtype=torch.float32, device=a.device)

    a = F.softmax(a.masked_fill(mask, -inf), dim=2) + eps
    b = F.softmax(b.masked_fill(mask, -inf), dim=2) + eps

    # pointwise KL(a || m) and KL(b || m) with m the mixture of a and b
    mixture = (a + b) / 2
    kl_a = (a * torch.log(a / mixture)).masked_fill(mask, 0)
    kl_b = (b * torch.log(b / mixture)).masked_fill(mask, 0)

    # sum over the distribution (dim 2) and the steps (dim 1), then normalise
    kl_a = kl_a.sum(dim=2).sum(dim=1) / length
    kl_b = kl_b.sum(dim=2).sum(dim=1) / length

    js_div = (kl_a + kl_b) / 2.0

    return js_div

## Train

$\widehat{code} = M_{CG}(anno, code)$

$\widehat{anno} = M_{CS}(code, anno)$

In [None]:
kwargs = {} # {'num_workers': 4, 'pin_memory': True}
train_loader = DataLoader(dataset, batch_size=2, shuffle=False, **kwargs)

# One model / optimizer / criterion per direction of the dual task.
model = {
    'cg': Model(CFG, model_type='cg'),
    'cs': Model(CFG, model_type='cs')
}

# frozen (pretrained) embeddings have requires_grad=False and are filtered out
opt = {
    'cg': O.Adam(lr=0.001, params=filter(lambda p: p.requires_grad, model['cg'].parameters())),
    'cs': O.Adam(lr=0.001, params=filter(lambda p: p.requires_grad, model['cs'].parameters()))
}

# Model.decode() already returns softmax probabilities, so the criterion must
# consume log-probabilities (NLLLoss); CrossEntropyLoss would apply a second
# softmax on top. The weight vector gives '<pad>' targets zero weight.
crit = {
    'cg': nn.NLLLoss(weight=get_vocab_mask(dataset.code_lang)),
    'cs': nn.NLLLoss(weight=get_vocab_mask(dataset.anno_lang))
}

LOG_EPS = 1e-12  # lower bound before taking the log, avoids log(0) = -inf


for e in range(1):
    for i, (anno, code) in enumerate(train_loader):
        code_, c_am = model['cg'](src=anno, tgt=code)
        anno_, a_am = model['cs'](src=code, tgt=anno)

        loss = {k: 0 for k in [
            'cg', 'cs',         # total
            'cg_ce', 'cs_ce',   # cross-entropy
            'cg_att', 'cs_att', # attention
            'dual'              # common dual loss
        ]}

        # Token-level loss averaged over all non-pad targets of the batch.
        # Flattening (batch, time) avoids the NaN (0/0) that a per-timestep
        # weighted mean produces whenever every target of a step is '<pad>'.
        code_logp = torch.log(code_.clamp_min(LOG_EPS))
        anno_logp = torch.log(anno_.clamp_min(LOG_EPS))
        loss['cg_ce'] = crit['cg'](code_logp.reshape(-1, code_logp.size(-1)), code.reshape(-1))
        loss['cs_ce'] = crit['cs'](anno_logp.reshape(-1, anno_logp.size(-1)), anno.reshape(-1))

        # TODO: the three terms below are unfinished placeholders; passing the
        # literal Ellipsis makes jensen_shannon_divergence raise TypeError, so
        # this cell cannot run to completion until they are implemented.
        loss['dual'] = ...
        loss['cg_att'] = jensen_shannon_divergence(...)
        loss['cs_att'] = jensen_shannon_divergence(...)
        att_loss = loss['cg_att'] + loss['cs_att']

        loss['cg'] = loss['cg_ce'] + 0.1 * loss['cg_att'] + 0.2 * loss['dual']
        loss['cs'] = loss['cs_ce'] + 0.3 * loss['cs_att'] + 0.4 * loss['dual']

        # TODO: no backward pass / optimizer step yet -- `opt` is created above
        # but never used, so no parameters are updated.

        if e % 400 == 0:
            print(f'{e:>5d} | cg {loss["cg"].item():6.5f} | cs {loss["cs"].item():6.5f}')
            with torch.no_grad():
                # greedy (argmax) predictions, printed one sequence per line;
                # `row` is used so the batch index `i` is not overwritten
                ws = dataset.code_lang.to_tokens(code_.argmax(dim=-1))
                for row, w in enumerate(ws):
                    print(f'\t{row}: {" ".join(w)}')
                print('\t'+'-'*80)
                ws = dataset.anno_lang.to_tokens(anno_.argmax(dim=-1))
                for row, w in enumerate(ws):
                    print(f'\t{row}: {" ".join(w)}')
                print()

        # debugging aid: stop after the first batch
        break