In [4]:
import re
import collections
from typing import List, Tuple, Dict

# Step 1: Gather Flutter and React Native keywords
flutter_keywords = [
    'class', 'import', 'void', 'int', 'String', 'double', 'bool', 'Widget', 'setState',
    'build', 'context', 'Container', 'Column', 'Row', 'Text', 'RaisedButton', 'Scaffold', 'AppBar'
]

react_native_keywords = [
    'import', 'from', 'class', 'constructor', 'render', 'return', 'Component', 'useState', 'useEffect',
    'View', 'Text', 'Button', 'StyleSheet', 'TouchableOpacity', 'FlatList', 'ScrollView', 'SafeAreaView'
]

# Step 2: Basic tokenizer for the corpus
def basic_tokenizer(text: str) -> List[str]:
    tokens = re.findall(r'\w+|\S', text)
    return tokens

# Step 3: Extract tokens and initialize the vocabulary with keywords
def initialize_vocabulary(corpus: List[str], keywords: List[str]) -> List[str]:
    vocabulary = set(keywords)
    for text in corpus:
        tokens = basic_tokenizer(text)
        vocabulary.update(tokens)
    return list(vocabulary)

# Step 4: Implement BPE to refine the vocabulary
def get_stats(vocab: Dict[str, int]) -> Dict[Tuple[str, str], int]:
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

def merge_vocab(pair: Tuple[str, str], vocab: Dict[str, int]) -> Dict[str, int]:
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in vocab:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = vocab[word]
    return v_out

def byte_pair_encoding(corpus: List[str], num_merges: int) -> List[str]:
    vocab = collections.Counter()
    for text in corpus:
        tokens = basic_tokenizer(text)
        for token in tokens:
            vocab[' '.join(token)] += 1

    for i in range(num_merges):
        pairs = get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = merge_vocab(best, vocab)

    bpe_vocab = set()
    for word in vocab:
        bpe_vocab.update(word.split())

    return list(bpe_vocab)

# Example usage
flutter_corpus = [
    """import 'package:flutter/material.dart';
    class MyApp extends StatelessWidget {
      @override
      Widget build(BuildContext context) {
        return MaterialApp(
          home: Scaffold(
            appBar: AppBar(title: Text('Flutter App')),
            body: Center(child: Text('Hello, world!')),
          ),
        );
      }
    }"""
]

react_native_corpus = [
    """import React from 'react';
    import { View, Text, Button, StyleSheet } from 'react-native';

    const App = () => {
      return (
        <View style={styles.container}>
          <Text>Hello, world!</Text>
          <Button title="Press me" onPress={() => alert('Button pressed')} />
        </View>
      );
    };

    const styles = StyleSheet.create({
      container: {
        flex: 1,
        justifyContent: 'center',
        alignItems: 'center',
      },
    });

    export default App;"""
]

# Initialize vocabulary with keywords
flutter_vocabulary = initialize_vocabulary(flutter_corpus, flutter_keywords)
react_native_vocabulary = initialize_vocabulary(react_native_corpus, react_native_keywords)

# Apply BPE to further refine the vocabulary
flutter_bpe_vocabulary = byte_pair_encoding(flutter_corpus, 50)
react_native_bpe_vocabulary = byte_pair_encoding(react_native_corpus, 50)

# Combine keywords and BPE tokens
final_flutter_vocabulary = list(set(flutter_vocabulary).union(set(flutter_bpe_vocabulary)))
final_react_native_vocabulary = list(set(react_native_vocabulary).union(set(react_native_bpe_vocabulary)))

print("Final Flutter Vocabulary:", final_flutter_vocabulary)
print("Final React Native Vocabulary:", final_react_native_vocabulary)

Final Flutter Vocabulary: ['import', 'id', 'b', 'h', 'M', 'exte', 'Scaffold', 'appBar', 'child', 'ontext', 'home', 'App', 'aterial', 'Row', 'a', 'dart', 'B', 'lutter', '{', 'o', 'class', 'Center', '(', 's', 'material', '}', '!', 'e', 'Widget', 'Bar', 'String', 'Container', 'u', 'Column', 'le', 'v', 'w', ';', 'pp', 'MaterialApp', 'setState', 'build', 'context', '.', 'r', 'f', 'return', 'ld', 'y', 'C', ',', 't', 'ter', 'package', 'd', 'BuildContext', '@', 'F', 'flutter', 'm', 'Hello', ':', 'world', 'Text', 'void', "'", 'int', 'double', 'body', 'AppBar', '/', 'ss', 'i', 'title', 'StatelessWidget', 'S', 'l', 'Flutter', 'c', 'override', 'H', 'ild', 'extends', 'or', 'MyApp', 'bool', 'te', 'uild', ')', 'RaisedButton', 'n']
Final React Native Vocabulary: ['import', 'react', 'st', 'flex', 'useState', 'justifyContent', 'R', 'Component', '1', 'App', 'a', '<', '{', 'const', 'useEffect', 'from', 'class', 'React', '(', 'o', 's', 'g', 'act', 'alignItems', '}', 'style', '!', 'View', 'e', '>', '=', 'st

In [18]:
from utils.transformer import Transformer
from utils.parser import react_native_vocab, flutter_vocab, START_TOKEN, END_TOKEN, PADDING_TOKEN

d_model = 512
batch_size = 1
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 500
react_vocab_size = len(react_native_vocab)

react_native_to_index = {word: i for i, word in enumerate(react_native_vocab)}
flutter_to_index = {word: i for i, word in enumerate(flutter_vocab)}

transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          react_vocab_size,
                          flutter_to_index,
                          react_native_to_index,
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

In [19]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):

    def __init__(self, flutter_snippets, react_snippets):
        self.flutter_snippets = flutter_snippets
        self.react_snippets = react_snippets

    def __len__(self):
        return len(self.flutter_snippets)

    def __getitem__(self, idx):
        return self.flutter_snippets[idx], self.react_snippets[idx]

In [20]:
import pandas as pd
df = pd.read_csv('./data/raw/code.csv')

dart_samples = [sample.lower() for sample in df['dart'].values]
js_samples = [sample.lower() for sample in df['js'].values]

# # filter for max length of sample at `max_sequence_len`` characters for training purposes
dart_samples = [sample for sample in dart_samples if len(sample) < max_sequence_length]
js_samples = [sample for sample in js_samples if len(sample) < max_sequence_length]

print("Number of Dart samples:", len(dart_samples))
print("Number of JS samples:", len(js_samples))

dataset = TextDataset(dart_samples, js_samples)

Number of Dart samples: 32
Number of JS samples: 33


In [21]:
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [22]:
from torch import nn
import torch

criterian = nn.CrossEntropyLoss(ignore_index=react_native_to_index[PADDING_TOKEN],
                                reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

In [23]:
import numpy as np

NEG_INFTY = -1e9

def create_masks(flutter_batch, react_batch):
    num_sentences = len(flutter_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
        flutter_sentence_length, react_sentence_length = len(flutter_batch[idx]), len(react_batch[idx])
        flutter_tokens_to_padding_mask = np.arange(flutter_sentence_length + 1, max_sequence_length)
        react_tokens_to_padding_mask = np.arange(react_sentence_length + 1, max_sequence_length)
        encoder_padding_mask[idx, :, flutter_tokens_to_padding_mask] = True
        encoder_padding_mask[idx, flutter_tokens_to_padding_mask, :] = True
        decoder_padding_mask_self_attention[idx, :, react_tokens_to_padding_mask] = True
        decoder_padding_mask_self_attention[idx, react_tokens_to_padding_mask, :] = True
        decoder_padding_mask_cross_attention[idx, :, flutter_tokens_to_padding_mask] = True
        decoder_padding_mask_cross_attention[idx, react_tokens_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [24]:
from tqdm import tqdm

transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 3

for epoch in tqdm(range(num_epochs)):
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        flutter_batch, react_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(flutter_batch, react_batch)
        optim.zero_grad()
        react_pred = transformer(flutter_batch,
                                     react_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(react_batch, start_token=False, end_token=True)
        loss = criterian(
            react_pred.view(-1, react_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == react_native_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        # if batch_num % 100 == 0:
        #     print(f"Iteration {batch_num} : {loss.item()}")
        #     print(f"Flutter: {flutter_batch[0]}")
        #     print(f"React Translation: {react_batch[0]}")
        #     kn_sentence_predicted = torch.argmax(react_pred[0], axis=1)
        #     predicted_sentence = ""
        #     for idx in kn_sentence_predicted:
        #         if idx == react_native_to_index[END_TOKEN]:
        #             break
        #     predicted_sentence += react_native_vocabulary[idx.item()]
        #     print(f"React Prediction: {predicted_sentence}")


            # transformer.eval()
            # react_code = ("",)
            # flutter_code = ("should we go to the mall?",)
            # for word_counter in range(max_sequence_length):
            #     encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(flutter_code, react_code)
            #     predictions = transformer(flutter_code,
            #                               react_code,
            #                               encoder_self_attention_mask.to(device), 
            #                               decoder_self_attention_mask.to(device), 
            #                               decoder_cross_attention_mask.to(device),
            #                               enc_start_token=False,
            #                               enc_end_token=False,
            #                               dec_start_token=True,
            #                               dec_end_token=False)
            #     next_token_prob_distribution = predictions[0][word_counter] # not actual probs
            #     next_token_index = torch.argmax(next_token_prob_distribution).item()
            #     next_token = react_native_vocabulary[next_token_index]
            #     react_code = (react_code[0] + next_token, )
            #     if next_token == END_TOKEN:
            #       break
            
            # print(f"Evaluation translation (should we go to the mall?) : {react_code}")
            # print("-------------------------------------------")

100%|██████████| 3/3 [11:06<00:00, 222.21s/it]


In [46]:
transformer.eval()
def translate(eng_sentence):
  eng_sentence = (eng_sentence,)
  kn_sentence = ("",)
  for word_counter in range(max_sequence_length):
    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, kn_sentence)
    predictions = transformer(eng_sentence,
                              kn_sentence,
                              encoder_self_attention_mask.to(device), 
                              decoder_self_attention_mask.to(device), 
                              decoder_cross_attention_mask.to(device),
                              enc_start_token=False,
                              enc_end_token=False,
                              dec_start_token=True,
                              dec_end_token=False)
    next_token_prob_distribution = predictions[0][word_counter]
    next_token_index = torch.argmax(next_token_prob_distribution).item()
    print(f"Next Token Idx: {next_token_index}")
    next_token = react_native_vocab[next_token_index]
    print(f"Next Token: {next_token}")
    kn_sentence = (kn_sentence[0] + next_token, )
    if next_token == END_TOKEN:
        print("THE END")
        break
    return kn_sentence[0]

In [48]:
react_native_vocab

['<START>',
 ' ',
 '!',
 '"',
 '#',
 '$',
 '%',
 '&',
 "'",
 '(',
 ')',
 '*',
 '+',
 ',',
 '-',
 '.',
 '/',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 ':',
 '<',
 '=',
 '>',
 '?',
 '@',
 '[',
 '\\',
 ']',
 '^',
 '_',
 '`',
 ';',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z',
 '\n',
 '\t',
 '\r',
 '\x0b',
 '\x0c',
 '{',
 '|',
 '}',
 '~',
 '<PAD>',
 '<END>']

In [47]:
# predict the translation of a flutter code snippet
snippet = dart_samples[7]
pred = translate(snippet)
print(dart_samples[1])
print("PREDICTION")
print(pred)

Next Token Idx: 1
Next Token:  
  int lengthoflongestsubstring(string s) {
    map<char, int> map = {};
    int maxlength = 0, start = 0;
  
    for (int i = 0; i < s.length; i++) {
      if (map.containskey(s[i]) && map[s[i]]! >= start) {
        start = map[s[i]]! + 1;
      }
      map[s[i]] = i;
      maxlength = max(maxlength, i - start + 1);
    }
  
    return maxlength;
  }
  
PREDICTION
 


In [36]:
# save the model
torch.save(transformer.state_dict(), './models/transformer.pth')

In [37]:
torch.cuda.is_available()

False