# Training Example

In [1]:
import torch

print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))

True
NVIDIA GeForce GTX 1060 6GB


## Model Hyperparameters

In [2]:
START_TOKEN = '<START>'
END_TOKEN = '<END>'
PADDING_TOKEN = '<PAD>'

d_model = 512
batch_size = 2
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 500

## Create Flutter Tokenizer

In [3]:
from pygments.lexers import DartLexer
from utils.vocab import dart_vocab
from utils.code_tokenizer import CodeTokenizer

flutter_tokenizer = CodeTokenizer(
    DartLexer(),
    framework_vocab=["Scaffold", "Widget", "setState"],
    language_vocab=dart_vocab,
    START_TOKEN=START_TOKEN,
    END_TOKEN=END_TOKEN,
    PAD_TOKEN=PADDING_TOKEN
)

print(f"Token Count: {len(flutter_tokenizer)}")

Token Count: 126


## Create React Native Tokenizer

In [4]:
from pygments.lexers import JavascriptLexer
from utils.vocab import javascript_vocab
from utils.code_tokenizer import CodeTokenizer

react_native_tokenizer = CodeTokenizer(
    JavascriptLexer(),
    framework_vocab=["View", "Text", "useState"],
    language_vocab=javascript_vocab,
    START_TOKEN=START_TOKEN,
    END_TOKEN=END_TOKEN,
    PAD_TOKEN=PADDING_TOKEN,
)

print(f"Token Count: {len(react_native_tokenizer)}")

Token Count: 128



### Basic Letter Tokenizer
For debugging

In [5]:
from utils.code_tokenizer import CodeTokenizer

letter_tokenizer = CodeTokenizer(
    None,
    framework_vocab=[],
    language_vocab=[
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
        'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'
        '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '{', '}', '(', ')',
        '[', ']', '=', '+', '-', '*', '/', '%', '^', '&', '|', '!', '?', '<',
        '>', ':', ';', ',', '.', '_', '#', '@', '$', '~', '`', '"', "'", '\\',
        '/', '\n', ' ', '\t', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H',
        'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
        'W', 'X', 'Y', 'Z'
    ],
)

print(f"Token Count: {len(letter_tokenizer)}")

Token Count: 101


In [6]:
from utils.transformer import Transformer

transformer = Transformer(
    d_model, 
    ffn_hidden,
    num_heads, 
    drop_prob, 
    num_layers, 
    max_sequence_length,
    flutter_tokenizer,
    react_native_tokenizer,
)

transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SnippetEmbedding(
      (embedding): Embedding(126, 512)
      (position_encoder): PositionalEncoding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): SequentialEncoder(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
          (linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNormalization()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNormalization()
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (sentence_embedding):

In [7]:
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, a_snippets, a_parser, b_snippets, b_parser):
        self.a_snippets = a_snippets
        self.a_parser = a_parser
        self.b_snippets = b_snippets
        self.b_parser = b_parser

    def __len__(self):
        return len(self.a_snippets)

    def __getitem__(self, idx):
        return self.a_snippets[idx], self.b_snippets[idx]

In [8]:
import pandas as pd
from utils.parsing.code_parser import DartParser, JavascriptParser

df = pd.read_csv('./data/raw/samples.csv')

dart_samples = df['dart'].values
js_samples = df['javascript'].values

# filter for max length of sample at `max_sequence_len`` characters for training purposes
zipped = list(zip(dart_samples, js_samples))
samples = [(d, j) for d, j in zipped]
dart_samples, js_samples = zip(*samples)

print("Number of Dart samples:", len(dart_samples))
print("Number of JS samples:", len(js_samples))

dataset = TextDataset(dart_samples, DartParser(""), js_samples, JavascriptParser(""))

Number of Dart samples: 212
Number of JS samples: 212


In [9]:
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [10]:
from torch import nn
import torch

criterian = nn.CrossEntropyLoss(ignore_index=react_native_tokenizer[PADDING_TOKEN], reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.AdamW(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [11]:
import numpy as np
from collections import defaultdict

NEG_INFTY = -1e9

def bracket_pairs_indices(tokens):
    stacks = defaultdict(list)
    pairs = {}
    for i, (_, token) in enumerate(tokens):
        if token == '(':
            stacks['PAREN'].append(i)
        elif token == ')':
            if len(stacks['PAREN']) > 0:
                pairs[stacks['PAREN'].pop()] = i
        if token == '[':
            stacks['BRACKET'].append(i)
        elif token == ']':
            if len(stacks['BRACKET']) > 0:
                pairs[stacks['BRACKET'].pop()] = i
        if token == '{':
            stacks['BRACE'].append(i)
        elif token == '}':
            if len(stacks['BRACE']) > 0:
                pairs[stacks['BRACE'].pop()] = i
    return pairs

def create_masks(flutter_batch, react_batch):
    num_seqs = len(flutter_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)

    encoder_padding_mask = torch.full([num_seqs, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_seqs, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_seqs, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_seqs):
        flutter_sequence, react_sequence = flutter_tokenizer.tokenize(flutter_batch[idx]), react_native_tokenizer.tokenize(react_batch[idx])
        flutter_seq_length, react_seq_length = len(flutter_sequence), len(react_sequence)
        
        flutter_tokens_to_padding_mask = np.arange(flutter_seq_length + 1, max_sequence_length)
        react_tokens_to_padding_mask = np.arange(react_seq_length + 1, max_sequence_length)
        
        encoder_padding_mask[idx, :, flutter_tokens_to_padding_mask] = True
        encoder_padding_mask[idx, flutter_tokens_to_padding_mask, :] = True
        
        decoder_padding_mask_self_attention[idx, :, react_tokens_to_padding_mask] = True
        decoder_padding_mask_self_attention[idx, react_tokens_to_padding_mask, :] = True
        
        decoder_padding_mask_cross_attention[idx, :, flutter_tokens_to_padding_mask] = True
        decoder_padding_mask_cross_attention[idx, react_tokens_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(look_ahead_mask + encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [15]:
from tqdm import tqdm

transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 100

for epoch in tqdm(range(num_epochs)):
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        flutter_batch, react_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(flutter_batch, react_batch)
        optim.zero_grad()
        react_pred = transformer(
            flutter_batch,
            react_batch,
            encoder_self_attention_mask.to(device), 
            decoder_self_attention_mask.to(device), 
            decoder_cross_attention_mask.to(device),
        )
        labels = transformer.decoder.sentence_embedding.batch_tokenize(react_batch)
        loss = criterian(
            react_pred.view(-1, len(react_native_tokenizer)).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == react_native_tokenizer[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        total_loss += loss.item()

100%|██████████| 100/100 [11:22<00:00,  6.82s/it]


In [17]:
transformer.eval()
# predict the translation of a flutter code snippet
flutter_code = dart_samples[8]
print(f"Flutter code: {flutter_code}")

tokenized_flutter_code = flutter_tokenizer.tokenize(flutter_code)
print(f"Tokenized Flutter Code: {tokenized_flutter_code}")

flutter_code = (flutter_code,)
react_code = ("",)
for word_counter in range(max_sequence_length):
    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(flutter_code, react_code)
    predictions = transformer(
        flutter_code,
        react_code,
        encoder_self_attention_mask.to(device), 
        decoder_self_attention_mask.to(device), 
        decoder_cross_attention_mask.to(device),
    )
    next_token_prob_distribution = predictions[0][word_counter]
    next_token_index = torch.argmax(next_token_prob_distribution).item()
    next_token = react_native_tokenizer.get_token(next_token_index)
    react_code = (react_code[0] + next_token, )
    if next_token == END_TOKEN:
        break
print(f"React Translation: {react_code[0]}")

Flutter code:   bool isPalindrome(int x) {
    if (x < 0 || (x % 10 == 0 && x != 0)) return false;
  
    int revertedNumber = 0;
    while (x > revertedNumber) {
      revertedNumber = revertedNumber * 10 + x % 10;
      x ~/= 10;
    }
  
    return x == revertedNumber || x == revertedNumber ~/ 10;
  }
  
Tokenized Flutter Code: [(1, '<START>'), (4, ' '), (4, ' '), (4, 'bool'), (4, ' '), (3, 'UNK'), (73, '('), (4, 'int'), (4, ' '), (3, 'UNK'), (74, ')'), (4, ' '), (112, '{'), (121, '\n'), (4, ' '), (4, ' '), (4, ' '), (4, ' '), (34, 'if'), (4, ' '), (73, '('), (3, 'UNK'), (4, ' '), (92, '<'), (4, ' '), (3, 'UNK'), (4, ' '), (113, '|'), (113, '|'), (4, ' '), (73, '('), (3, 'UNK'), (4, ' '), (68, '%'), (4, ' '), (3, 'UNK'), (4, ' '), (96, '='), (96, '='), (4, ' '), (3, 'UNK'), (4, ' '), (70, '&'), (70, '&'), (4, ' '), (3, 'UNK'), (4, ' '), (66, '!'), (96, '='), (4, ' '), (3, 'UNK'), (74, ')'), (74, ')'), (4, ' '), (49, 'return'), (4, ' '), (27, 'false'), (91, ';'), (121, '\n'), (4, ' '

In [18]:
# save the model
torch.save(transformer.state_dict(), './models/transformer.pth')