In [1]:
import re
import collections
from typing import List, Tuple, Dict

# Step 1: Gather Flutter and React Native keywords
flutter_keywords = [
    'class', 'import', 'void', 'int', 'String', 'double', 'bool', 'Widget', 'setState',
    'build', 'context', 'Container', 'Column', 'Row', 'Text', 'RaisedButton', 'Scaffold', 'AppBar'
]

react_native_keywords = [
    'import', 'from', 'class', 'constructor', 'render', 'return', 'Component', 'useState', 'useEffect',
    'View', 'Text', 'Button', 'StyleSheet', 'TouchableOpacity', 'FlatList', 'ScrollView', 'SafeAreaView'
]

# Step 2: Basic tokenizer for the corpus
def basic_tokenizer(text: str) -> List[str]:
    tokens = re.findall(r'\w+|\S', text)
    return tokens

# Step 3: Extract tokens and initialize the vocabulary with keywords
def initialize_vocabulary(corpus: List[str], keywords: List[str]) -> List[str]:
    vocabulary = set(keywords)
    for text in corpus:
        tokens = basic_tokenizer(text)
        vocabulary.update(tokens)
    return list(vocabulary)

# Step 4: Implement BPE to refine the vocabulary
def get_stats(vocab: Dict[str, int]) -> Dict[Tuple[str, str], int]:
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

def merge_vocab(pair: Tuple[str, str], vocab: Dict[str, int]) -> Dict[str, int]:
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in vocab:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = vocab[word]
    return v_out

def byte_pair_encoding(corpus: List[str], num_merges: int) -> List[str]:
    vocab = collections.Counter()
    for text in corpus:
        tokens = basic_tokenizer(text)
        for token in tokens:
            vocab[' '.join(token)] += 1

    for i in range(num_merges):
        pairs = get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = merge_vocab(best, vocab)

    bpe_vocab = set()
    for word in vocab:
        bpe_vocab.update(word.split())

    return list(bpe_vocab)

# Example usage
flutter_corpus = [
    """import 'package:flutter/material.dart';
    class MyApp extends StatelessWidget {
      @override
      Widget build(BuildContext context) {
        return MaterialApp(
          home: Scaffold(
            appBar: AppBar(title: Text('Flutter App')),
            body: Center(child: Text('Hello, world!')),
          ),
        );
      }
    }"""
]

react_native_corpus = [
    """import React from 'react';
    import { View, Text, Button, StyleSheet } from 'react-native';

    const App = () => {
      return (
        <View style={styles.container}>
          <Text>Hello, world!</Text>
          <Button title="Press me" onPress={() => alert('Button pressed')} />
        </View>
      );
    };

    const styles = StyleSheet.create({
      container: {
        flex: 1,
        justifyContent: 'center',
        alignItems: 'center',
      },
    });

    export default App;"""
]

# Initialize vocabulary with keywords
flutter_vocabulary = initialize_vocabulary(flutter_corpus, flutter_keywords)
react_native_vocabulary = initialize_vocabulary(react_native_corpus, react_native_keywords)

# Apply BPE to further refine the vocabulary
flutter_bpe_vocabulary = byte_pair_encoding(flutter_corpus, 50)
react_native_bpe_vocabulary = byte_pair_encoding(react_native_corpus, 50)

# Combine keywords and BPE tokens
final_flutter_vocabulary = list(set(flutter_vocabulary).union(set(flutter_bpe_vocabulary)))
final_react_native_vocabulary = list(set(react_native_vocabulary).union(set(react_native_bpe_vocabulary)))

print("Final Flutter Vocabulary:", final_flutter_vocabulary)
print("Final React Native Vocabulary:", final_react_native_vocabulary)

Final Flutter Vocabulary: ['le', 'ss', ')', 'Container', 'package', 'Widget', 'StatelessWidget', ':', '!', 'F', 'setState', 'h', 't', 'AppBar', 'd', '{', 'v', 'ter', 'build', 'H', 'flutter', "'", 'f', 'context', 'Center', 'override', 'double', 'l', 's', 'pp', 'b', 'm', 'S', 'n', 'home', 'void', 'import', 'class', 'MyApp', 'return', 'c', 'Column', 'a', 'e', 'dart', '/', 'child', 'Text', 'B', 'y', 'ld', 'ontext', 'C', 'aterial', 'M', 'App', 'String', 'RaisedButton', 'u', 'Hello', 'bool', 'w', '(', '}', 'i', 'o', 'lutter', '@', 'Flutter', '.', ';', 'Bar', 'material', 'ild', 'uild', 'BuildContext', 'Scaffold', 'world', 'te', 'appBar', 'int', 'id', ',', 'exte', 'MaterialApp', 'r', 'body', 'or', 'extends', 'Row', 'title']
Final React Native Vocabulary: ['render', 'le', '<', 'export', ')', '-', ':', '!', 'ress', 'j', 'React', 'Component', 'create', 't', 'd', 'ex', '{', 'v', 'I', 'H', '>', 'from', 'constructor', 'Button', 'SafeAreaView', "'", 'f', 'l', 's', 'native', 'm', '=', 'container', 'n'

In [2]:
from backend.utils.transformer import Transformer
from utils.parser import react_native_vocab, flutter_vocab, START_TOKEN, END_TOKEN, PADDING_TOKEN

d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
react_vocab_size = len(react_native_vocab)

react_native_to_index = {word: i for i, word in enumerate(react_native_vocab)}
flutter_to_index = {word: i for i, word in enumerate(flutter_vocab)}

transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          react_vocab_size,
                          flutter_to_index,
                          react_native_to_index,
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

In [3]:
transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SnippetEmbedding(
      (embedding): Embedding(117, 512)
      (position_encoder): PositionalEncoding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): SequentialEncoder(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
          (linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNormalization()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNormalization()
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (sentence_embedding):

In [4]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):

    def __init__(self, flutter_snippets, react_snippets):
        self.flutter_snippets = flutter_snippets
        self.react_snippets = react_snippets

    def __len__(self):
        return len(self.flutter_snippets)

    def __getitem__(self, idx):
        return self.flutter_snippets[idx], self.react_snippets[idx]

In [5]:
import pandas as pd
df = pd.read_csv('./data/raw/code.csv')

dataset = TextDataset(df['dart'].values, df['js'].values)

In [6]:
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [7]:
react_native_vocab

['9',
 ')',
 'interface',
 '!',
 'do',
 'function',
 'catch',
 '\t',
 'of',
 '2',
 '6',
 'abstract',
 'typeof',
 '4',
 'else',
 '=',
 'finally',
 'console',
 'class',
 '~',
 'continue',
 'const',
 'in',
 '(',
 ']',
 'implements',
 'with',
 '.',
 ';',
 'default',
 'public',
 'private',
 ',',
 'while',
 'debugger',
 'super',
 'null',
 'eval',
 '1',
 'new',
 'let',
 '<',
 'export',
 'package',
 '5',
 '-',
 ':',
 'yield',
 'true',
 '{',
 'false',
 '\n',
 '+',
 '>',
 '*',
 '0',
 '&',
 'var',
 'delete',
 'protected',
 '[',
 'await',
 'enum',
 '3',
 'for',
 'void',
 'instanceof',
 'import',
 'return',
 '/',
 'log',
 'this',
 '8',
 'try',
 'final',
 'break',
 'static',
 '7',
 ' ',
 '}',
 'switch',
 '%',
 'async',
 '|',
 'if',
 'throw',
 'extends',
 'case',
 'arguments',
 '?',
 '<START>',
 '<PAD>',
 '<END>',
 '<EOF>']

In [8]:
from torch import nn
import torch

criterian = nn.CrossEntropyLoss(ignore_index=react_native_to_index[PADDING_TOKEN],
                                reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [9]:
import numpy as np

NEG_INFTY = -1e9

def create_masks(flutter_batch, react_batch):
    num_sentences = len(flutter_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      flutter_sentence_length, react_sentence_length = len(flutter_batch[idx]), len(react_batch[idx])
      flutter_tokens_to_padding_mask = np.arange(flutter_sentence_length + 1, max_sequence_length)
      react_tokens_to_padding_mask = np.arange(react_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, flutter_tokens_to_padding_mask] = True
      encoder_padding_mask[idx, flutter_tokens_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, react_tokens_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, react_tokens_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, flutter_tokens_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, react_tokens_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [10]:
flutter_to_index

{'9': 0,
 'rethrow': 1,
 ')': 2,
 'interface': 3,
 '!': 4,
 'do': 5,
 'deferred': 6,
 'function': 7,
 'catch': 8,
 '\t': 9,
 '2': 10,
 '6': 11,
 'abstract': 12,
 'operator': 13,
 '4': 14,
 'else': 15,
 'double': 16,
 'native': 17,
 '=': 18,
 'Stream': 19,
 'finally': 20,
 'class': 21,
 '~': 22,
 'print': 23,
 'main': 24,
 'continue': 25,
 'part': 26,
 'const': 27,
 'bool': 28,
 'in': 29,
 '(': 30,
 ']': 31,
 'implements': 32,
 'with': 33,
 'List': 34,
 'Set': 35,
 'mixin': 36,
 '.': 37,
 ';': 38,
 'default': 39,
 'covariant': 40,
 'is': 41,
 'sync': 42,
 'varfinal': 43,
 'library': 44,
 ',': 45,
 'while': 46,
 'Future': 47,
 'super': 48,
 'null': 49,
 '1': 50,
 'new': 51,
 'Function': 52,
 '<': 53,
 'export': 54,
 'assert': 55,
 '5': 56,
 '-': 57,
 ':': 58,
 'yield': 59,
 'show': 60,
 'true': 61,
 '{': 62,
 'false': 63,
 '\n': 64,
 '+': 65,
 '>': 66,
 'late': 67,
 'external': 68,
 'Map': 69,
 '*': 70,
 '0': 71,
 '&': 72,
 'var': 73,
 '[': 74,
 'await': 75,
 'enum': 76,
 '3': 77,
 'for'

In [11]:
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        flutter_batch, react_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(flutter_batch, react_batch)
        optim.zero_grad()
        react_pred = transformer(flutter_batch,
                                     react_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(react_batch, start_token=False, end_token=True)
        loss = criterian(
            react_pred.view(-1, react_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == react_native_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        # if batch_num % 100 == 0:
        #     print(f"Iteration {batch_num} : {loss.item()}")
        #     print(f"Flutter: {flutter_batch[0]}")
        #     print(f"React Translation: {react_batch[0]}")
        #     kn_sentence_predicted = torch.argmax(react_pred[0], axis=1)
        #     predicted_sentence = ""
        #     for idx in kn_sentence_predicted:
        #         if idx == react_native_to_index[END_TOKEN]:
        #             break
        #     predicted_sentence += react_native_vocabulary[idx.item()]
        #     print(f"React Prediction: {predicted_sentence}")


            # transformer.eval()
            # react_code = ("",)
            # flutter_code = ("should we go to the mall?",)
            # for word_counter in range(max_sequence_length):
            #     encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(flutter_code, react_code)
            #     predictions = transformer(flutter_code,
            #                               react_code,
            #                               encoder_self_attention_mask.to(device), 
            #                               decoder_self_attention_mask.to(device), 
            #                               decoder_cross_attention_mask.to(device),
            #                               enc_start_token=False,
            #                               enc_end_token=False,
            #                               dec_start_token=True,
            #                               dec_end_token=False)
            #     next_token_prob_distribution = predictions[0][word_counter] # not actual probs
            #     next_token_index = torch.argmax(next_token_prob_distribution).item()
            #     next_token = react_native_vocabulary[next_token_index]
            #     react_code = (react_code[0] + next_token, )
            #     if next_token == END_TOKEN:
            #       break
            
            # print(f"Evaluation translation (should we go to the mall?) : {react_code}")
            # print("-------------------------------------------")

Epoch 0


RuntimeError: The size of tensor a (1008) must match the size of tensor b (200) at non-singleton dimension 1