In [1]:
import sys
sys.path.append("../")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext
import torchtext.data as data
from torchtext.data import Field, BucketIterator

import random
from tqdm import tqdm
import time
import math

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import typing

from src.tokenize import base_tokenizer, Vocab, TextPreprocessor

In [3]:
# Seed every random number generator the notebook relies on so that a
# Restart & Run All reproduces the same run.
SEED = 1234

# Python's RNG drives the teacher-forcing coin flip in Seq2Seq.forward.
random.seed(SEED)
torch.manual_seed(SEED)
# Seed all visible GPUs rather than only the current one; this is silently
# ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(SEED)
# Request deterministic cuDNN kernels and switch the autotuner off, since
# benchmarking may select different algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
global max_src_in_batch, max_tgt_in_batch

def batch_size_fn(new, count, sofar):
    """Return the padded size of the current batch once ``new`` is added.

    The running maxima of the source and target lengths are kept in the
    module-level ``max_src_in_batch`` / ``max_tgt_in_batch`` variables and are
    restarted whenever ``count`` is 1 (the first example of a batch).

    Args:
        new: example exposing ``source`` and ``target`` sequences.
        count: number of examples in the batch including ``new``.
        sofar: size reported for the batch so far (unused).

    Returns:
        ``count`` times the longest source length or ``count`` times the
        longest target length (each target counted with 2 extra positions),
        whichever is larger.
    """
    global max_src_in_batch, max_tgt_in_batch
    # Start from zero for a fresh batch, otherwise continue from the stored maxima.
    longest_src, longest_tgt = (
        (0, 0) if count == 1 else (max_src_in_batch, max_tgt_in_batch)
    )
    max_src_in_batch = max(longest_src, len(new.source))
    # The target length is counted with two additional positions.
    max_tgt_in_batch = max(longest_tgt, len(new.target) + 2)
    return max(count * max_src_in_batch, count * max_tgt_in_batch)

In [5]:
class FastIterator(data.Iterator):
    """Iterator that places examples of similar length in the same batch.

    Only ``create_batches`` is overridden; it fills ``self.batches`` using
    ``self.sort_key``, ``self.batch_size`` and ``self.batch_size_fn``.
    """

    def create_batches(self):
        """Build ``self.batches`` for one pass over the data."""
        if not self.train:
            # Evaluation: batch the data in the order self.data() yields it and
            # sort the examples inside each batch by the sort key.
            self.batches = [
                sorted(minibatch, key=self.sort_key)
                for minibatch in data.batch(
                    self.data(), self.batch_size, self.batch_size_fn
                )
            ]
            return

        def pooled_batches(examples, shuffle):
            # Work on chunks worth 100 batches: sort each chunk by the sort
            # key, cut it into batches, then yield those batches shuffled.
            for chunk in data.batch(examples, self.batch_size * 100):
                ordered = sorted(chunk, key=self.sort_key)
                grouped = data.batch(ordered, self.batch_size, self.batch_size_fn)
                yield from shuffle(list(grouped))

        self.batches = pooled_batches(self.data(), self.random_shuffler)

# Prepare the data

In [6]:
# Examples whose source or target is longer than MAX_LEN are dropped by filter_pred.
MAX_LEN = 1000
# Batch size handed to FastIterator.splits.
BATCH_SIZE = 8

In [7]:
# Source field: sequences are wrapped in <s> ... </s>. include_lengths=True makes
# batch.source a pair, which train() unpacks as (src, src_lengths).
SRC_TEXT = Field(init_token="<s>", eos_token="</s>", include_lengths=True)
# Target field: same begin/end markers, no lengths.
TRG_TEXT = Field(init_token="<s>", eos_token="</s>")

In [8]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# (attribute name, Field) pairs; the names are the attributes read elsewhere
# (x.source / x.target below, batch.source / batch.target in train()).
data_fields = [("source", SRC_TEXT), ("target", TRG_TEXT)]

def filter_pred(x: "data.Example") -> bool: 
    """Return True when both sides of the example have at most MAX_LEN items.

    ``x`` is presumably a torchtext Example; only its ``source`` and ``target``
    attributes are used.
    """
    return len(x.source) <= MAX_LEN and len(x.target) <= MAX_LEN

# Parallel dataset built from the ".diff" (source) and ".msg" (target) files.
# NOTE(review): the path is absolute and machine-specific; presumably it should
# come from a configurable data directory.
dataset = torchtext.datasets.TranslationDataset(
    path="/workspace/tmp/dataset/train", 
    exts=(".diff", ".msg"), 
    fields=data_fields,
    filter_pred=filter_pred
)

In [10]:
# torchtext interprets split_ratio as the (train, test, valid) sizes but
# returns the datasets ordered (train, valid, test). Unpack in the returned
# order so the names match their contents: 80% train, 5% validation, 15% test.
train_data, valid_data, test_data = dataset.split(split_ratio=[0.8, 0.15, 0.05])

In [11]:
# Both vocabularies are built from the training split only.
SRC_TEXT.build_vocab(train_data.source, min_freq=0)
TRG_TEXT.build_vocab(train_data.target, min_freq=0)

In [12]:
# One iterator per split, returned in the same order as `datasets`.
# Examples are compared by source length. sort_within_batch=True presumably
# orders each batch by decreasing source length, which Encoder.forward needs
# for pack_padded_sequence - TODO confirm.
train_iterator, valid_iterator, test_iterator = FastIterator.splits(
    datasets=(train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_key=lambda x: (len(x.source)),
    sort_within_batch=True,
    # batch_size_fn=batch_size_fn,
    shuffle=True,
    device=device
)

In [13]:
# Pull one training batch and display it to check the tensor shapes.
src_data = next(iter(train_iterator))
src_data


[torchtext.data.batch.Batch of size 8]
	[.source]:('[torch.LongTensor of size 413x8]', '[torch.LongTensor of size 8]')
	[.target]:[torch.LongTensor of size 20x8]

# Build the seq2seq baseline

In [15]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder producing states for attention and decoding.

    Args:
        input_dim: size of the source vocabulary.
        emb_dim: embedding dimension.
        enc_hid_dim: hidden size of each GRU direction.
        dec_hid_dim: hidden size of the decoder the final state is projected to.
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=input_dim, embedding_dim=emb_dim
        )
        self.rnn = nn.GRU(
            input_size=emb_dim, hidden_size=enc_hid_dim, bidirectional=True
        )
        self.linear = nn.Linear(
            in_features=enc_hid_dim * 2, out_features=dec_hid_dim
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src: torch.Tensor, lengths: torch.Tensor):
        """Encode a padded batch of source sequences.

        Args:
            src: token ids, ``[src_len, batch_size]``.
            lengths: unpadded length of every sequence, in decreasing order
                (``pack_padded_sequence`` is called with its default
                ``enforce_sorted=True``). Any device and numeric dtype is
                accepted.

        Returns:
            ``(outputs, hidden)`` where ``outputs`` is
            ``[src_len, batch_size, enc_hid_dim * 2]`` (zeros at padded
            positions) and ``hidden`` is ``[batch_size, dec_hid_dim]``.
        """
        # [src_len, batch_size, emb_dim]
        embedded = self.dropout(self.embedding(src))

        # pack_padded_sequence requires the lengths to live on the CPU, while
        # the batch iterator delivers them on the model device.
        cpu_lengths = torch.as_tensor(lengths, dtype=torch.int64, device="cpu")
        packed_input = pack_padded_sequence(embedded, cpu_lengths)
        packed_outputs, hidden = self.rnn(packed_input)
        # Pad back to the full source length so the outputs always line up
        # with a mask computed from `src`, even if max(lengths) < src_len.
        outputs, _ = pad_packed_sequence(
            packed_outputs, total_length=src.size(0)
        )

        # Concatenate the final forward (hidden[-2]) and backward (hidden[-1])
        # states: [batch_size, enc_hid_dim * 2]
        hidden_cat_layers = torch.cat(
            (hidden[-2, :, :], hidden[-1, :, :]), dim=1
        )

        # Like in the original paper
        # hidden_cat_layers = torch.tanh(self.linear(hidden[-1, :, :]))

        # [batch_size, dec_hid_dim]
        hidden = torch.tanh(self.linear(hidden_cat_layers))

        # outputs = [src_len, batch_size, enc_hid_dim * 2]
        # hidden = [batch_size, dec_hid_dim]
        return outputs, hidden

In [16]:
class Attention(nn.Module):
    """Additive attention scoring every encoder state against the decoder state.

    Args:
        enc_hid_dim: hidden size of one encoder direction (states are twice this).
        dec_hid_dim: hidden size of the decoder.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        combined_dim = (enc_hid_dim * 2) + dec_hid_dim
        self.attn = nn.Linear(in_features=combined_dim, out_features=dec_hid_dim)
        self.v = nn.Parameter(torch.rand(dec_hid_dim))

    def forward(self, hidden: torch.Tensor, encoder_outputs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Return attention weights of shape ``[batch_size, seq_len]``.

        Args:
            hidden: decoder state, ``[batch_size, dec_hid_dim]``.
            encoder_outputs: ``[seq_len, batch_size, enc_hid_dim * 2]``.
            mask: ``[batch_size, seq_len]``; positions equal to 0 get weight 0.
        """
        seq_len, batch_size = encoder_outputs.shape[:2]

        # [batch_size, seq_len, enc_hid_dim * 2]
        states = encoder_outputs.permute(1, 0, 2)
        # Copy the decoder state along the time axis: [batch_size, seq_len, dec_hid_dim]
        tiled_hidden = hidden.unsqueeze(1).repeat(1, seq_len, 1)

        # [batch_size, seq_len, dec_hid_dim]
        energy = torch.tanh(self.attn(torch.cat((tiled_hidden, states), dim=2)))

        # Collapse the dec_hid_dim axis with a batched product against v:
        # [batch_size, 1, dec_hid_dim] x [batch_size, dec_hid_dim, seq_len]
        v = self.v.repeat(batch_size, 1).unsqueeze(1)
        scores = torch.bmm(v, energy.permute(0, 2, 1)).squeeze(1)

        # Padding positions receive -inf so that softmax assigns them zero weight.
        is_padding = mask == 0
        scores = scores.masked_fill(is_padding, float('-inf'))

        # [batch_size, seq_len]
        return F.softmax(scores, dim=1)

In [17]:
class Decoder(nn.Module):
    """Single-step GRU decoder that attends over the encoder states.

    Args:
        output_dim: size of the target vocabulary.
        emb_dim: embedding dimension.
        enc_hid_dim: hidden size of one encoder direction.
        dec_hid_dim: hidden size of the decoder GRU.
        dropout: dropout probability applied to the embeddings.
        attention: module called as ``attention(hidden, encoder_outputs, mask)``
            that returns weights of shape ``[batch_size, seq_len]``.
    """

    def __init__(
        self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention
    ):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention

        # The encoder is bidirectional, so its states are twice enc_hid_dim wide.
        context_dim = enc_hid_dim * 2

        self.embedding = nn.Embedding(
            num_embeddings=output_dim, embedding_dim=emb_dim
        )
        self.rnn = nn.GRU(
            input_size=context_dim + emb_dim, hidden_size=dec_hid_dim
        )
        self.out = nn.Linear(
            in_features=context_dim + emb_dim + dec_hid_dim,
            out_features=output_dim,
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, hidden, encoder_outputs, mask):
        """Decode one time step.

        Args:
            src: current input token ids, ``[batch_size]``.
            hidden: previous decoder state, ``[batch_size, dec_hid_dim]``.
            encoder_outputs: ``[seq_len, batch_size, enc_hid_dim * 2]``.
            mask: ``[batch_size, seq_len]`` mask forwarded to the attention.

        Returns:
            ``(pred, hidden)``: scores ``[batch_size, output_dim]`` and the new
            decoder state ``[batch_size, dec_hid_dim]``.
        """
        # [1, batch_size, emb_dim]
        embedded = self.dropout(self.embedding(src.unsqueeze(0)))

        # [batch_size, 1, seq_len]
        attn_weights = self.attention(hidden, encoder_outputs, mask).unsqueeze(1)

        # Weighted sum of the encoder states:
        # [batch_size, 1, enc_hid_dim * 2] -> [1, batch_size, enc_hid_dim * 2]
        context = torch.bmm(attn_weights, encoder_outputs.permute(1, 0, 2))
        context = context.permute(1, 0, 2)

        # The GRU consumes the embedding together with the context vector.
        # output     = [1, batch_size, dec_hid_dim]
        # new_hidden = [1, batch_size, dec_hid_dim]
        output, new_hidden = self.rnn(
            torch.cat((embedded, context), dim=2), hidden.unsqueeze(0)
        )

        # [batch_size, emb_dim + dec_hid_dim + enc_hid_dim * 2]
        features = torch.cat(
            (embedded.squeeze(0), output.squeeze(0), context.squeeze(0)), dim=1
        )

        # pred = [batch_size, output_dim]
        return self.out(features), new_hidden.squeeze(0)

In [18]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time.

    Args:
        encoder: module called as ``encoder(src, src_lengths)`` returning
            ``(encoder_outputs, hidden)``.
        decoder: module called as ``decoder(x, hidden, encoder_outputs, mask)``
            returning ``(pred, hidden)``; must expose ``output_dim``.
        src_pad_idx: id of the padding token in the source vocabulary.
        device: device on which the output tensor is allocated.
    """

    def __init__(self, encoder, decoder, src_pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.device = device

    def make_mask(self, src):
        """Return a ``[batch_size, src_len]`` mask that is True on non-pad tokens."""
        return (src != self.src_pad_idx).permute(1, 0)

    def forward(self, src, src_lengths, trg, teacher_forcing_ratio=0.5):
        """Produce scores for every target position.

        Args:
            src: ``[src_len, batch_size]`` source token ids.
            src_lengths: source lengths passed through to the encoder.
            trg: ``[trg_len, batch_size]`` target token ids.
            teacher_forcing_ratio: probability, drawn once per step, of feeding
                the ground-truth token instead of the model's own prediction.

        Returns:
            ``[trg_len, batch_size, trg_vocab_size]`` tensor; row 0 stays zero
            because decoding starts from the first target token.
        """
        trg_len, batch_size = trg.size()[0], src.size()[1]
        trg_vocab_size = self.decoder.output_dim

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)

        encoder_outputs, hidden = self.encoder(src, src_lengths)
        mask = self.make_mask(src)

        # The first decoder input is the first row of the target.
        decoder_input = trg[0, :]
        for step in range(1, trg_len):
            pred, hidden = self.decoder(decoder_input, hidden, encoder_outputs, mask)
            outputs[step] = pred
            # One coin flip per step decides the next input for the whole batch.
            if random.random() < teacher_forcing_ratio:
                decoder_input = trg[step]
            else:
                decoder_input = pred.argmax(dim=1)

        # [trg_len, batch_size, trg_vocab_size]
        return outputs

# Train seq2seq baseline

In [19]:
# Vocabulary sizes: INPUT_DIM sizes the encoder embedding, OUTPUT_DIM the
# decoder embedding and output projection.
INPUT_DIM = len(SRC_TEXT.vocab)
OUTPUT_DIM = len(TRG_TEXT.vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
# Padding ids: SRC_PAD_IDX is given to Seq2Seq for the attention mask,
# TRG_PAD_IDX is the index ignored by the loss.
SRC_PAD_IDX = SRC_TEXT.vocab.stoi[SRC_TEXT.pad_token]
TRG_PAD_IDX = TRG_TEXT.vocab.stoi[TRG_TEXT.pad_token]
# BOS/EOS ids are looked up in the *source* vocabulary; they are passed to
# numericalize_source_sents further down.
BOS_IDX = SRC_TEXT.vocab.stoi[SRC_TEXT.init_token]
EOS_IDX = SRC_TEXT.vocab.stoi[SRC_TEXT.eos_token]

attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

# Assemble the full model and move its parameters to the selected device.
model = Seq2Seq(enc, dec, SRC_PAD_IDX, device).to(device)

In [20]:
def init_weights(m):
    """Initialise the parameters of ``m`` (and its submodules) in place.

    Bias parameters are set to zero; every other parameter is drawn from a
    normal distribution with mean 0 and standard deviation 0.01.

    Parameters are classified by looking for "bias" in their name rather than
    for "weight": a parameter such as the attention vector ``v`` has neither
    word in its name and would otherwise be zeroed like a bias.

    Args:
        m: an ``nn.Module``; typically passed via ``model.apply(init_weights)``.
    """
    for name, param in m.named_parameters():
        if "bias" in name:
            nn.init.constant_(param, 0)
        else:
            nn.init.normal_(param, mean=0, std=0.01)
            
# Apply init_weights across the model; the value returned by apply() is what
# the cell displays.
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(10128, 256)
    (rnn): GRU(256, 512, bidirectional=True)
    (linear): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): Attention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
    )
    (embedding): Embedding(3743, 256)
    (rnn): GRU(1280, 512)
    (out): Linear(in_features=1792, out_features=3743, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [21]:
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of ``model``."""
    total = 0
    for param in model.parameters():
        # Parameters with requires_grad=False are frozen and not counted.
        if param.requires_grad:
            total += param.numel()
    return total


# Report the number of trainable parameters with thousands separators.
print(f"The model has {count_parameters(model):,} trainable parameters")

The model has 16,695,455 trainable parameters


In [22]:
# Adam over all model parameters, without weight decay.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)

In [23]:
# Cross-entropy that skips positions whose target is the padding token.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

In [24]:
def train(
    model: Seq2Seq,
    iterator,
    optimizer: optim.Optimizer,
    criterion: nn.CrossEntropyLoss,
    clip,
):
    """Run one training pass over ``iterator`` and return the mean batch loss.

    Args:
        model: network called as ``model(src, src_lengths, trg)``.
        iterator: yields batches with ``source`` (a ``(src, src_lengths)``
            pair) and ``target`` attributes; must support ``len()``.
        optimizer: optimizer stepped once per batch.
        criterion: loss applied to the flattened scores and targets.
        clip: maximum gradient norm enforced before every optimizer step.

    Returns:
        Sum of the per-batch losses divided by ``len(iterator)``.
    """
    model.train()
    running_loss = 0
    for batch in tqdm(iterator):
        (src, src_lengths), trg = batch.source, batch.target

        optimizer.zero_grad()

        # [trg len, batch size, output dim]
        logits = model(src, src_lengths, trg)
        vocab_size = logits.size()[-1]

        # Row 0 of the model output is never filled in, so both the scores
        # and the targets are compared from position 1 onwards.
        # [(trg len - 1) * batch size, output dim]
        flat_logits = logits[1:].view(-1, vocab_size)
        # [(trg len - 1) * batch size]
        flat_targets = trg[1:].view(-1)

        loss = criterion(flat_logits, flat_targets)
        loss.backward()

        # Rescale the gradients so that their total norm does not exceed `clip`.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        running_loss += loss.item()

    return running_loss / len(iterator)

In [25]:
def evaluate(model, iterator, criterion):
    """Compute the mean batch loss of ``model`` over ``iterator`` without training.

    The model is switched to eval mode, gradients are disabled and teacher
    forcing is turned off (ratio 0), so the decoder is always fed its own
    predictions.

    Args:
        model: network called as ``model(src, src_len, trg, 0)``.
        iterator: yields batches with ``source`` (a ``(src, src_len)`` pair)
            and ``target`` attributes - the same field names ``train`` reads;
            must support ``len()``.
        criterion: loss applied to the flattened scores and targets.

    Returns:
        Sum of the per-batch losses divided by ``len(iterator)``.
    """
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch in iterator:
            # The dataset fields are named "source" and "target"; reading
            # batch.src / batch.trg raised AttributeError.
            src, src_len = batch.source
            trg = batch.target

            # Teacher forcing ratio 0: decode from the model's own predictions.
            output = model(src, src_len, trg, 0)

            # trg = [trg len, batch size]
            # output = [trg len, batch size, output dim]
            output_dim = output.shape[-1]

            # Position 0 carries no prediction, so it is skipped on both sides.
            # [(trg len - 1) * batch size, output dim]
            output = output[1:].view(-1, output_dim)
            # [(trg len - 1) * batch size]
            trg = trg[1:].view(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [26]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into whole minutes and seconds.

    Args:
        start_time: start timestamp in seconds.
        end_time: end timestamp in seconds.

    Returns:
        ``(minutes, seconds)`` as integers; fractional seconds are truncated.
    """
    duration = end_time - start_time
    minutes = int(duration / 60)
    # Whatever is left after removing the whole minutes, truncated to an int.
    seconds = int(duration - (minutes * 60))
    return minutes, seconds

In [27]:
# Number of passes over the training data and the gradient-norm limit given to train().
N_EPOCHS = 10
CLIP = 1

# Lowest validation loss seen so far; decides when a checkpoint is written.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    # One training pass followed by one pass over the validation iterator.
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Save the weights only when the validation loss improves.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'code-model.pt')
    
    # Perplexity is reported as exp(loss).
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

  0%|          | 3/951 [00:07<43:05,  2.73s/it]

KeyboardInterrupt: 

In [29]:
# Take one training batch for manual inspection.
src_data = next(iter(train_iterator))
# NOTE: .source is a pair here (train() unpacks it as src, src_lengths), not a
# single tensor.
src = src_data.source
trg = src_data.target

In [None]:
# Build the project tokenizer from src.tokenize. The behaviour of
# TextPreprocessor, Vocab and base_tokenizer is defined in that module and is
# not visible in this notebook.
# NOTE(review): the model paths are absolute and machine-specific; presumably
# they should come from a configurable location.
symbs = TextPreprocessor().symbols
src_vocab = Vocab(symbs, model_path="/workspace/diffs.model")
trg_vocab = Vocab(symbs, model_path="/workspace/messages.model")
tokenizer = base_tokenizer(src_vocab, trg_vocab)

In [None]:
def numericalize_source_sents(
    sentences: typing.List[str], 
    bos_idx: int, 
    eos_idx: int, 
    pad_idx: int, 
    device: torch.device
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    """Convert raw source strings into a padded id matrix plus their lengths.

    Every sentence is tokenized with the module-level ``tokenizer``
    (``tokenize_source_as_ids``), wrapped in ``bos_idx`` / ``eos_idx`` and
    right-padded with ``pad_idx`` to the length of the longest sentence.

    Args:
        sentences: source strings; may be empty.
        bos_idx: id written in front of every sentence.
        eos_idx: id written after every sentence.
        pad_idx: id used to fill the unused tail of shorter rows.
        device: device the returned tensors are moved to.

    Returns:
        ``(ids, lengths)``: ``ids`` is an int64 tensor of shape
        ``[len(sentences), max_len + 2]`` and ``lengths`` an int64 tensor of
        shape ``[len(sentences)]`` holding each length including BOS and EOS.
        An empty input yields tensors of shape ``[0, 2]`` and ``[0]``.

    Note:
        The rows are batch-major and kept in input order. ``Seq2Seq.forward``
        documents ``src`` as ``[src_len, batch_size]``, so the caller has to
        transpose ``ids`` (and order the batch by decreasing length) before
        feeding the model.
    """
    token_ids = [
        list(tokenizer.tokenize_source_as_ids(sent)) for sent in sentences
    ]
    # default=0 keeps an empty batch from raising.
    longest = max((len(ids) for ids in token_ids), default=0)

    # Two extra columns hold the BOS and EOS markers.
    sents_var = torch.full(
        (len(token_ids), longest + 2), pad_idx, dtype=torch.long
    )
    # Lengths are counts, so they are stored as integers rather than floats.
    sents_lengths = torch.zeros(len(token_ids), dtype=torch.long)
    for row, ids in enumerate(token_ids):
        wrapped = [bos_idx, *ids, eos_idx]
        sents_var[row, :len(wrapped)] = torch.tensor(wrapped, dtype=torch.long)
        sents_lengths[row] = len(wrapped)
    return sents_var.to(device), sents_lengths.to(device)

In [None]:
# Smoke test: numericalize a single source string and display the resulting
# (ids, lengths) pair.
sents = ["def say_helo(): print()"]

numericalize_source_sents(
    sentences=sents, 
    bos_idx=BOS_IDX, 
    eos_idx=EOS_IDX, 
    pad_idx=SRC_PAD_IDX, 
    device=device
)