In [None]:
import os
from argparse import Namespace

import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import Django
from utils import from_home
from config import Config

### Config

In [None]:
# Root configuration object; the nested configs below hang off it as attributes.
CFG = Config()
# `src` carries the settings for the NL intents, `tgt` those for the source code.
CFG.src, CFG.tgt = Config(), Config()

### Dataset

In [None]:
# Root directory handed to the Django dataset, resolved through the project helper
# `from_home` (presumably relative to the user's home directory -- confirm in utils).
DJANGO_DIR = from_home('workspace/ml-data/msc-research/raw-datasets/testing')
# Directory containing the embedding file referenced by the dataset config.
EMB_DIR = from_home('workspace/ml-data/msc-research/embeddings')

In [None]:
# Dataset settings live on their own sub-config of CFG.
CFG.dataset_config = Config()
# Assigning to __dict__ swaps out the instance's whole attribute mapping, so only
# the keys listed here remain on the sub-config afterwards.
CFG.dataset_config.__dict__ = dict(
    root_dir=DJANGO_DIR,
    anno_min_freq=1,
    code_min_freq=1,
    anno_seq_maxlen=10,
    code_seq_maxlen=10,
    emb_file=os.path.join(EMB_DIR, 'glove.6B.50d.txt.pickle'),
)

# Build the dataset from the sub-config assembled above.
django = Django(config=CFG.dataset_config)

In [None]:
# Display the config object as the cell output.
CFG

In [None]:
# Print whatever `Django.raw` returns for every index of the dataset, one per line.
for i in range(len(django)):
    print(django.raw(i))

In [None]:
# Show the dataset item at index 1 as the cell output.
django[1]

### Loaders

In [None]:
def get_embeddings(config: Namespace):
    """
    Build the embedding layer described by a language sub-config.

    config attributes read here:
        lang                : vocabulary; must support len() and expose `pad_idx`
        emb_size            : embedding dimension
        load_pretrained_emb : if truthy, use `config.emb_matrix` as frozen weights
        emb_matrix          : (rows, emb_size) weights; only read when loading pretrained

    returns: nn.Embedding -- randomly initialised and trainable by default,
             frozen (requires_grad=False) when built from the pretrained matrix.
    raises : ValueError if the pretrained matrix is not 2-D or its width != emb_size.
    """
    pad_idx = config.lang.pad_idx

    if not config.load_pretrained_emb:
        return nn.Embedding(len(config.lang), config.emb_size, padding_idx=pad_idx)

    matrix = config.emb_matrix
    if isinstance(matrix, torch.Tensor):
        # torch.tensor() on an existing tensor warns; take an explicit detached copy
        weights = matrix.detach().clone().to(torch.float32)
    else:
        weights = torch.tensor(matrix, dtype=torch.float32)

    # fail here with a readable message instead of deep inside the LSTM later on
    if weights.dim() != 2 or weights.size(1) != config.emb_size:
        raise ValueError(
            f'emb_matrix must have shape (rows, {config.emb_size}), '
            f'got {tuple(weights.shape)}'
        )

    # from_pretrained derives num_embeddings / embedding_dim from the matrix itself,
    # so the module's metadata always matches the weights it actually holds
    return nn.Embedding.from_pretrained(weights, freeze=True, padding_idx=pad_idx)

### Model

In [None]:
class Model(nn.Module):
    """
    Seq2seq model: bidirectional LSTM encoder + LSTMCell decoder with Luong-style
    attention (Luong et al., 2015). The attentional vector of step t-1 is
    concatenated to the target embedding to form the decoder input at step t.

    config: needs `src` and `tgt` sub-configs providing the fields read in
            __init__ and get_embeddings (lang, emb_size, hidden_size, ...).
    """
    def __init__(self, config: Namespace):
        super().__init__()

        # 1. ENCODER
        self.src_embedding = get_embeddings(config.src)
        # NOTE: this is a single-layer LSTM; nn.LSTM applies `dropout` only between
        # stacked layers, so a non-zero value has no effect here (PyTorch warns).
        self.encoder = nn.LSTM(input_size=config.src.emb_size,
                               hidden_size=config.src.hidden_size,
                               dropout=config.src.lstm_dropout,
                               bidirectional=True,
                               batch_first=True)

        # maps the concatenated final encoder cell states to the decoder's initial cell
        self.decoder_cell_init = nn.Linear(in_features=2*config.src.hidden_size,
                                           out_features=config.tgt.hidden_size)

        # 2. ATTENTION
        # project source encoding to decoder rnn's h space
        self.att_src_linear = nn.Linear(in_features=2*config.src.hidden_size,
                                        out_features=config.tgt.hidden_size,
                                        bias=False)

        # transformation of decoder hidden states and context vectors before reading out target words
        # this produces the `attentional vector` in (Luong et al., 2015)
        self.att_vec_linear = nn.Linear(in_features=2*config.src.hidden_size + config.tgt.hidden_size,
                                        out_features=config.tgt.hidden_size,
                                        bias=False)

        # 3. DECODER
        self.tgt_embedding = get_embeddings(config.tgt)
        # input = target embedding concatenated with the previous attentional vector
        self.decoder = nn.LSTMCell(input_size=config.tgt.emb_size + config.tgt.hidden_size,
                                   hidden_size=config.tgt.hidden_size)

        # prob layer over target language
        self.readout = nn.Linear(in_features=config.tgt.hidden_size,
                                 out_features=len(config.tgt.lang),
                                 bias=False)

        self.dropout = nn.Dropout(config.tgt.dropout_p)

        # save the entire config
        self.config = config


    def forward(self, src, tgt):
        """
        src: bs, max_src_len
        tgt: bs, max_tgt_len

        returns a 3-tuple:
            scores          : max_tgt_len, bs, tgt_vocab (softmax-normalised)
            (h0_dec, c0_dec): initial decoder state, each bs, tgt_hid_size
            att_mats        : list of attention score tensors, see `decode`
        """
        enc_out, (h0_dec, c0_dec) = self.encode(src)
        scores, att_mats = self.decode(enc_out, h0_dec, c0_dec, tgt)

        return scores, (h0_dec, c0_dec), att_mats


    def encode(self, src):
        """
        src : bs x max_src_len (emb look-up indices)
        out : bs x max_src_len x 2*hid_size
        h/c0: bs x tgt_hid_size
        """
        emb = self.src_embedding(src)
        # the final hidden state is unused; the decoder is initialised from the cell state
        out, (_, cn) = self.encoder(emb)

        # cn[0] / cn[1] are the final cell states of the two LSTM directions
        c0_dec = self.decoder_cell_init(torch.cat([cn[0], cn[1]], dim=1))
        h0_dec = c0_dec.tanh()

        return out, (h0_dec, c0_dec)


    def decode(self, src_enc, h0_dec, c0_dec, tgt):
        """
        src_enc: bs, max_src_len, 2*hid_size (== encoder output)
        h/c0   : bs, tgt_hid_size
        tgt    : bs, max_tgt_len (emb look-up indices)

        returns:
            scores  : max_tgt_len, bs, tgt_vocab -- time-major (torch.stack over steps),
                      already passed through softmax, i.e. probabilities not logits
            att_mats: list of (bs, 1, max_src_len) pre-softmax attention scores,
                      one per step except the last
        """
        batch_size, tgt_len = tgt.shape
        scores = []
        att_mats = []

        emb = self.tgt_embedding(tgt) # bs, max_tgt_len, tgt_emb_size

        # allocate on emb's device/dtype so decoding also works once the model has
        # been moved off the default CPU/float32 (e.g. model.cuda())
        att_vec = emb.new_zeros(batch_size, self.config.tgt.hidden_size)
        src_enc_att = self.att_src_linear(src_enc) # bs, max_src_len, tgt_hid_size

        hidden = (h0_dec, c0_dec)

        for t in range(tgt_len):
            emb_t = emb[:, t, :]
            x = torch.cat([emb_t, att_vec], dim=-1)
            h_t, c_t = self.decoder(x, hidden)

            ctx_t, _, att_mat = self.luong_attention(h_t, src_enc, src_enc_att)

            # Luong eq. (5)
            att_t = self.att_vec_linear(torch.cat([h_t, ctx_t], dim=1))
            att_t = att_t.tanh()
            att_t = self.dropout(att_t)

            # Luong eq. (6)
            score_t = self.readout(att_t)
            score_t = F.softmax(score_t, dim=-1)

            scores.append(score_t)
            att_mats.append(att_mat)

            # for next state t+1
            att_vec = att_t
            hidden  = h_t, c_t

        # NOTE(review): the last step's attention scores are excluded from the
        # returned list while its output scores are kept -- confirm this is intended.
        return torch.stack(scores), att_mats[:-1]


    def luong_attention(self, h_t, src_enc, src_enc_att, mask=None):
        """
        h_t        : bs, hid_size
        src_enc    : bs, max_src_len, 2*src_hid_size
        src_enc_att: bs, max_src_len, tgt_hid_size
        mask       : bs, max_src_len; truthy entries mark source positions that are
                     filled with -inf before the softmax (optional)

        ctx_vec    : bs, 2*src_hid_size
        att_weight : bs, max_src_len (softmax-normalised)
        att_mat    : bs, 1, max_src_len (scores BEFORE the softmax, after masking)
        """

        # bs x src_max_len
        att_weight = torch.bmm(src_enc_att, h_t.unsqueeze(2)).squeeze(2)

        # `is not None`: truth-testing a multi-element tensor raises a RuntimeError
        if mask is not None:
            # out-of-place fill keeps autograd consistent; a row that is masked
            # everywhere ends up as NaN after the softmax
            att_weight = att_weight.masked_fill(mask.bool(), float('-inf'))

        att_mat = att_weight.unsqueeze(1)
        att_weight = F.softmax(att_weight, dim=-1) # alignment

        # weighted sum of the encoder outputs
        ctx_vec = torch.bmm(att_weight.unsqueeze(1), src_enc).squeeze(1)

        return ctx_vec, att_weight, att_mat

In [None]:
class Lang:
    """Minimal vocabulary stub for the smoke test below: fixed size 100, pad index 0."""
    def __init__(self):
        self.pad_idx = 0
    def __len__(self):
        return 100

# Hand-built config carrying exactly the fields Model / get_embeddings read.
cfg = Namespace()

cfg.src = Namespace()
cfg.src.lang = Lang()
cfg.src.emb_size = 50
cfg.src.load_pretrained_emb = False
cfg.src.hidden_size = 128
cfg.src.lstm_dropout = 0

cfg.tgt = Namespace()
cfg.tgt.lang = Lang()
cfg.tgt.emb_size = 50
cfg.tgt.load_pretrained_emb = False
cfg.tgt.hidden_size = 128
cfg.tgt.dropout_p = 0.1


model = Model(cfg)
x = torch.randint(10, size=(17, 20))  # src indices: bs=17, max_src_len=20
y = torch.randint(15, size=(17, 10))  # tgt indices: bs=17, max_tgt_len=10

# forward returns three values: (scores, (h0_dec, c0_dec), att_mats);
# unpacking into two names raises "too many values to unpack"
s, dec_init, a = model(x, y)

# scores are stacked time-major: max_tgt_len, bs, tgt_vocab
assert s.shape == (y.shape[1], x.shape[0], len(cfg.tgt.lang))
s[0].shape

### Train