# Training Example

In [1]:
import torch

print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))

True
NVIDIA GeForce GTX 1060 6GB


## Model Hyperparameters

In [2]:
START_TOKEN = '<START>'
END_TOKEN = '<END>'
PADDING_TOKEN = '<PAD>'

d_model = 512
batch_size = 2
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 500

## Create Flutter Tokenizer

In [3]:
from pygments.lexers import DartLexer
from utils.vocab import dart_vocab
from utils.code_tokenizer import CodeTokenizer

flutter_tokenizer = CodeTokenizer(
    DartLexer(),
    framework_vocab=["Scaffold", "Widget", "setState"],
    language_vocab=dart_vocab,
    START_TOKEN=START_TOKEN,
    END_TOKEN=END_TOKEN,
    PAD_TOKEN=PADDING_TOKEN
)

print(f"Token Count: {len(flutter_tokenizer)}")

Token Count: 29119


## Create React Native Tokenizer

In [4]:
from pygments.lexers import JavascriptLexer
from utils.vocab import javascript_vocab
from utils.code_tokenizer import CodeTokenizer

react_native_tokenizer = CodeTokenizer(
    JavascriptLexer(),
    framework_vocab=["View", "Text", "useState"],
    language_vocab=javascript_vocab,
    START_TOKEN=START_TOKEN,
    END_TOKEN=END_TOKEN,
    PAD_TOKEN=PADDING_TOKEN,
)

print(f"Token Count: {len(react_native_tokenizer)}")

Token Count: 29121


### Basic Letter Tokenizer
For debugging

In [5]:
from utils.code_tokenizer import CodeTokenizer

letter_tokenizer = CodeTokenizer(
    None,
    framework_vocab=[],
    language_vocab=[
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
        'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'
        '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '{', '}', '(', ')',
        '[', ']', '=', '+', '-', '*', '/', '%', '^', '&', '|', '!', '?', '<',
        '>', ':', ';', ',', '.', '_', '#', '@', '$', '~', '`', '"', "'", '\\',
        '/', '\n', ' ', '\t', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H',
        'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
        'W', 'X', 'Y', 'Z'
    ],
)

print(f"Token Count: {len(letter_tokenizer)}")

Token Count: 100


In [6]:
from utils.transformer import Transformer

transformer = Transformer(
    d_model, 
    ffn_hidden,
    num_heads, 
    drop_prob, 
    num_layers, 
    max_sequence_length,
    flutter_tokenizer,
    react_native_tokenizer,
)

transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SnippetEmbedding(
      (embedding): Embedding(29119, 512)
      (position_encoder): PositionalEncoding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): SequentialEncoder(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
          (linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNormalization()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNormalization()
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (sentence_embedding

In [7]:
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, a_snippets, a_parser, b_snippets, b_parser):
        self.a_snippets = a_snippets
        self.a_parser = a_parser
        self.b_snippets = b_snippets

    def __len__(self):
        return len(self.a_snippets)

    def __getitem__(self, idx):
        return self.a_snippets[idx], self.b_snippets[idx]

In [8]:
import pandas as pd
from utils.parsing.code_parser import DartParser, JavascriptParser

df = pd.read_csv('./data/raw/samples.csv')

dart_samples = df['dart'].values
js_samples = df['javascript'].values

# filter for max length of sample at `max_sequence_len`` characters for training purposes
zipped = list(zip(dart_samples, js_samples))
samples = [(d, j) for d, j in zipped]
dart_samples, js_samples = zip(*samples)

print("Number of Dart samples:", len(dart_samples))
print("Number of JS samples:", len(js_samples))

dataset = TextDataset(dart_samples, DartParser(), js_samples, JavascriptParser())

Number of Dart samples: 212
Number of JS samples: 212


In [9]:
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [10]:
from torch import nn
import torch

criterian = nn.CrossEntropyLoss(ignore_index=react_native_tokenizer[PADDING_TOKEN], reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.AdamW(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [11]:
import numpy as np

NEG_INFTY = -1e9

def create_masks(flutter_batch, react_batch):
    num_seqs = len(flutter_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_seqs, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_seqs, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_seqs, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_seqs):
        flutter_sentence_length, react_sentence_length = len(flutter_batch[idx]), len(react_batch[idx])
        flutter_tokens_to_padding_mask = np.arange(flutter_sentence_length + 1, max_sequence_length)
        react_tokens_to_padding_mask = np.arange(react_sentence_length + 1, max_sequence_length)
        encoder_padding_mask[idx, :, flutter_tokens_to_padding_mask] = True
        encoder_padding_mask[idx, flutter_tokens_to_padding_mask, :] = True
        decoder_padding_mask_self_attention[idx, :, react_tokens_to_padding_mask] = True
        decoder_padding_mask_self_attention[idx, react_tokens_to_padding_mask, :] = True
        decoder_padding_mask_cross_attention[idx, :, flutter_tokens_to_padding_mask] = True
        decoder_padding_mask_cross_attention[idx, react_tokens_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [12]:
from tqdm import tqdm

transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 100

for epoch in tqdm(range(num_epochs)):
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        flutter_batch, react_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(flutter_batch, react_batch)
        optim.zero_grad()
        react_pred = transformer(
            flutter_batch,
            react_batch,
            encoder_self_attention_mask.to(device), 
            decoder_self_attention_mask.to(device), 
            decoder_cross_attention_mask.to(device),
        )
        labels = transformer.decoder.sentence_embedding.batch_tokenize(react_batch)
        loss = criterian(
            react_pred.view(-1, len(react_native_tokenizer)).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == react_native_tokenizer[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        total_loss += loss.item()

print(f"Total Loss: {total_loss}")

100%|██████████| 100/100 [26:42<00:00, 16.02s/it]

Total Loss: 4435.746726499201





In [15]:
transformer.eval()
# predict the translation of a flutter code snippet
flutter_code = dart_samples[2]
print(f"double x = 1; double y = 2; double z = x + y; print(z);")

flutter_code = (flutter_code,)
react_code = ("",)
for word_counter in range(max_sequence_length):
    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(flutter_code, react_code)
    predictions = transformer(
        flutter_code,
        react_code,
        encoder_self_attention_mask.to(device), 
        decoder_self_attention_mask.to(device), 
        decoder_cross_attention_mask.to(device),
    )
    next_token_prob_distribution = predictions[0][word_counter]
    next_token_index = torch.argmax(next_token_prob_distribution).item()
    next_token = react_native_tokenizer.get_token(next_token_index)
    react_code = (react_code[0] + next_token, )
    if next_token == END_TOKEN:
        break
print(f"React Translation: {react_code[0]}")

double x = 1; double y = 2; double z = x + y; print(z);
React Translation: <START><START>abstractabstractvar>'##ep##ep##T>##fi##fi##ub##ows##quence>abstract##tr##tre##pcleanedcleanede##pcleanedcleanedcleaned>cleanedcleaned^^##as##as##as##asbreak##ition##ition##itionnumbersnumbersZnumbersque##0>abstract##ricetankt##rcleanednumberst##reprefixprefixp##ase##ase##d##ase##aseprefixnumbersScpSc##reak##d##tre##tre##drepeatedrepeated##d>cleaned##ed##edheight##fit##fitnumbersnumbersabstractnumbersnumbersabstractnumbersnumbersabstractnumbersnumbers##do##thestnumbersnumbersnumbers##tringnumbersnumbersnumbers##tring##ition##ition##as##tring##metric##tringcleaned##tringprefixcleaned##tring##trackcleanedcleaned0>abstractcleanedcleanedprefix##ckcleanedcleanedcleanedrcleanedsequencesequence##umsequence##tsequencesequenceheight##p##resequence##x##p##resequence##x##pprefixprefixlocal##eprefixprefixpar##esequencesequencedsequencesequenceas##esequencesequencecomplementcleaned##resequence##xsequence##tring#

In [17]:
# save the model
torch.save(transformer.state_dict(), './models/transformer.pth')