In [1]:
import os
os.putenv("CUDA_VISIBLE_DEVICES", "0")


from transformers import BertConfig, EncoderDecoderModel, EncoderDecoderConfig
from dataclasses import dataclass
from typing import Union, Iterable, List
from pathlib import Path
from tensor2tensor.data_generators import text_encoder
import json
import torch
from catalyst import dl
import re
from catalyst.utils import set_global_seed
import random


def read_jsonl(path):
    with open(path, 'r') as istream:
        return [json.loads(l) for l in istream]


class TokenToIndexConverter:

    def __init__(
        self,
        vocab_path: Union[str, Path],
        unk_token = '[UNK]_',
        bos_token = '\\u\\u\\uNL\\u\\u\\u_',
        eos_token = '\\u\\u\\uNEWLINE\\u\\u\\u_',
        pad_token = '<pad>_'
    ):
        vocab_path = Path(vocab_path)
        self.subword_tokenizer = text_encoder.SubwordTextEncoder(vocab_path.as_posix())
        self.token_to_index_map = {
            tok: i
            for i, tok in enumerate(self.subword_tokenizer.all_subtoken_strings)
        }
        self.index_to_token_map = {v: k for k, v in self.token_to_index_map.items()}
        self.unk_token = unk_token
        self.unk_index = self[unk_token]
        self.bos_token = bos_token
        self.bos_index = self[bos_token]
        self.eos_token = eos_token
        self.eos_index = self[eos_token]
        self.pad_token = pad_token
        self.pad_index = self[pad_token]

        self.bos_exp = re.compile(r"^(___NL___)*")  # Remove any number of repeating newline characters
        self.eos_exp = re.compile(r"___NEWLINE___.*$")  # Remove evrythin after first eod

    def __getitem__(self, key):
        return (self.token_to_index_map[key] 
                if key in self.token_to_index_map 
                else self.unk_index)

    def encode(self, tokens: str) -> List[int]:
        return [self[tok] for tok in tokens]
    
    def encode_code(self, code: List[List[str]]) -> List[int]:
        return [self.bos_index] + [
            tok for line in code 
            for tok in self.encode(line)
        ]
    
    def decode(self, tokens: List[int]) -> str:
        text = self.subword_tokenizer.decode(tokens, strip_extraneous=True)
        text = self.bos_exp.sub("", text)
        text = self.eos_exp.sub("", text)
        return text

    @property
    def vocab_size(self):
        return self.subword_tokenizer.vocab_size


EXPERIMENT_NAME = "seq2seq-transformer"


@dataclass
class Config:
    
    #  bert config:
    vocab_size: int
    pad_token_id: int
    hidden_size: int = 1024
    num_attention_heads: int = 16
    intermediate_size: int = 4096
    max_position_embeddings: int = 512
    
    encoder_num_hidden_layers: int = 6
    decoder_num_hidden_layers: int = 2

    #  optimization:
    max_lr: float = 5e-4
    batch_size: int = 32
    accumulation_steps: int = 16
    
    weight_decay: float = 0
        
    num_epochs: int = 50
    patience: int = 5

    #  lr scheduling:
    warmup_prop: float = 0.15
        
    #  generation parameters:
    eval_set_size: int = 2000
    num_return_sequences: int = 5

    logdir: str = f'logdir_{EXPERIMENT_NAME}'
    resume: str = None
        
    seed: int = 19


def make_model(config):
    encoder_config = BertConfig(
        vocab_size=config.vocab_size,
        hidden_size=config.hidden_size,
        num_hidden_layers=config.encoder_num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        intermediate_size=config.intermediate_size,
        max_position_embeddings=config.max_position_embeddings,
        pad_token_id=config.pad_token_id
    )
    decoder_config = BertConfig(
        vocab_size=config.vocab_size,
        hidden_size=config.hidden_size,
        num_hidden_layers=config.decoder_num_hidden_layers,
        num_attention_heads=config.num_attention_heads,
        intermediate_size=config.intermediate_size,
        max_position_embeddings=config.max_position_embeddings,
        pad_token_id=config.pad_token_id
    )
    model = EncoderDecoderModel(
        config=EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
    )
    return model


MODELS_DIR = Path.home() / "models/cubert"


token_to_index = TokenToIndexConverter(
    (MODELS_DIR / "github_python_minus_ethpy150open_deduplicated_vocabulary.txt").as_posix()
)
config = Config(vocab_size=token_to_index.vocab_size, pad_token_id=token_to_index.pad_index)

set_global_seed(config.seed)

In [2]:
DATA_FOLDER = Path.home() / "data/method_name_prediction/python/final/jsonl"


train = read_jsonl(DATA_FOLDER / "train_preprocessed.jsonl")
valid = read_jsonl(DATA_FOLDER / "valid_preprocessed.jsonl")
test = read_jsonl(DATA_FOLDER / "test_preprocessed.jsonl")

In [3]:
from torch.utils.data import Dataset, DataLoader
from torch import Tensor, LongTensor
from torch.nn.utils.rnn import pad_sequence
from typing import Callable, Iterable, Optional


class SequenceToSequenceDataset(Dataset):

    def __init__(
        self,
        src_stream: Iterable['T'],
        src_encoder: Callable[['T'], Tensor],
        ref_stream: Iterable['T'],
        ref_encoder: Callable[['T'], Tensor],
        src_pad_index: int,
        ref_pad_index: Optional[int] = None
    ):
        self.src = [src_encoder(s) for s in src_stream]
        self.ref = [ref_encoder(s) for s in ref_stream]
        assert len(self.src) == len(self.ref)
        self.src_pad_index = src_pad_index
        self.ref_pad_index = ref_pad_index if ref_pad_index is not None else src_pad_index

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return self.src[idx], self.ref[idx]

    def collate_fn(self, data):
        src_batch, ref_batch = zip(*data)
        input_ids = pad_sequence(
            src_batch,
            padding_value=self.src_pad_index,
            batch_first=True
        )
        attention_mask = input_ids != self.src_pad_index
        decoder_input_ids = pad_sequence(
            ref_batch,
            padding_value=self.ref_pad_index,
            batch_first=True
        )
        labels = decoder_input_ids[:,1:]
        decoder_input_ids = decoder_input_ids[:,:-1]
        decoder_attention_mask = decoder_input_ids != self.ref_pad_index
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'decoder_input_ids': decoder_input_ids,
            'decoder_attention_mask': decoder_attention_mask,
            'labels': labels
        }

    def make_loader(self, *args, **kwargs):
        return DataLoader(self, *args, collate_fn=self.collate_fn, **kwargs)


def get_method_name_dataset(data, token_to_index, pad_index, max_length):

    def truncated_encoder(encoder, max_length):
        def wrapper(*args, **kwargs):
            return encoder(*args, **kwargs)[:max_length]
        return wrapper

    def to_torch_encoder(encoder):
        def wrapper(*args, **kwargs):
            return LongTensor(encoder(*args, **kwargs))
        return wrapper

    return SequenceToSequenceDataset(
        src_stream = (e['function_body_tokenized'] for e in data),
        src_encoder = to_torch_encoder(
            truncated_encoder(token_to_index.encode_code, max_length)
        ),
        ref_stream = (e['function_name_tokenized'] for e in data),
        ref_encoder = to_torch_encoder(
            truncated_encoder(token_to_index.encode_code, max_length)
        ),
        src_pad_index = pad_index
    )


In [4]:
train_dataset = get_method_name_dataset(train, token_to_index, token_to_index.pad_index, config.max_position_embeddings)
valid_dataset = get_method_name_dataset(valid, token_to_index, token_to_index.pad_index, config.max_position_embeddings)
beam_dataset = get_method_name_dataset(
    random.sample(valid, config.eval_set_size), token_to_index, token_to_index.pad_index, config.max_position_embeddings
)

In [5]:
model = make_model(config)

In [6]:
from transformers import get_linear_schedule_with_warmup


def init_scheduler(optimizer, num_epochs, num_steps_epoch, warmup_prop):
    num_training_steps = num_steps_epoch * num_epochs + 1
    warmup_steps = int(num_training_steps * warmup_prop)
    return get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps
    )

In [7]:
import numpy as np


def beam_search(src, model, bos_id, pad_id, end_id, device, max_len=10, k=5):
    src = src.view(1,-1).to(device)
    src_mask = (src != pad_id).to(device)
    
    memory = None
    
    input_seq = [bos_id]
    beam = [(input_seq, 0)] 
    for i in range(max_len):
        candidates = []
        candidates_proba = []
        for snt, snt_proba in beam:
            if snt[-1] == end_id:
                candidates.append(snt)
                candidates_proba.append(snt_proba)
            else:    
                snt_tensor = torch.tensor(snt).view(1, -1).long().to(device)
                
                if memory is None:
                    memory = model(
                        input_ids=src, 
                        attention_mask=src_mask,
                        decoder_input_ids=snt_tensor,
                        return_dict=False
                    )
                else:
                    memory = model(
                        input_ids=src, 
                        attention_mask=src_mask,
                        decoder_input_ids=snt_tensor,
                        encoder_outputs=(memory[1], memory[-1]),
                        return_dict=False
                    )
                    
                proba = memory[0].cpu()[0,-1, :]
                proba = torch.log_softmax(proba, dim=-1).numpy()
                best_k = np.argpartition(-proba, k - 1)[:k]

                for tok in best_k:
                    candidates.append(snt + [tok])
                    candidates_proba.append(snt_proba + proba[tok]) 
                    
        best_candidates = np.argpartition(-np.array(candidates_proba), k - 1)[:k]
        beam = [(candidates[j], candidates_proba[j]) for j in best_candidates]
        beam = sorted(beam, key=lambda x: -x[1])
        
    return beam

In [8]:
from catalyst import dl


class MethodNameRunner(dl.Runner):

    def __init__(
        self,
        bos_index: int,
        pad_index: int,
        eos_index: int,
        beam_loader_name: str,
        num_return_sequences: int,
        model=None,
        device=None
    ):
        self.bos_index = bos_index
        self.pad_index = pad_index
        self.eos_index = eos_index
        self.beam_loader_name = beam_loader_name
        self.num_return_sequences = num_return_sequences
        super().__init__(model=model, device=device)

    def _handle_batch(self, batch):
        logits = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            decoder_input_ids=batch['decoder_input_ids'],
            decoder_attention_mask=batch['decoder_attention_mask']
        ).logits
        
        pad_index = self.model.decoder.config.pad_token_id
        loss = torch.nn.functional.cross_entropy(
            input=logits.permute(0, 2, 1),
            target=batch['labels'],
            ignore_index=pad_index
        )
        self.batch_metrics['loss'] = loss
        
        if self.loader_name == self.beam_loader_name:
            with torch.no_grad():
                generated = []
                for src in batch['input_ids']:
                    gen = beam_search(
                        src, 
                        self.model,
                        bos_id=self.bos_index,
                        pad_id=self.pad_index,
                        end_id=self.eos_index,
                        device=self.device,
                        max_len=10,
                        k=self.num_return_sequences
                    )
                    generated.append(gen)
                self.output = {'generated': generated}


class GenerationAccuracyCallback(dl.Callback):

    def __init__(
        self,
        decoder,
        beam_loader_name: str,
        num_return_sequences: int,
        generated_key: str = 'generated',
        target_key: str = 'labels',
        prob_print: float = 0.025
    ):
        super().__init__(order=dl.CallbackOrder.Metric)
        self.decoder = decoder
        self.beam_loader_name = beam_loader_name
        self.num_return_sequences = num_return_sequences
        self.generated_key = generated_key
        self.target_key = target_key
        self.prob_print = prob_print

    def on_batch_end(self, runner):
        if not runner.loader_name == self.beam_loader_name:
            return

        labels = runner.input[self.target_key]
        generated = runner.output[self.generated_key]
        
        num_top_n_correct = 0
        for lab, candidates in zip(labels, generated):
            example = self.decoder.decode(lab)

            print_this_example = random.random() < self.prob_print
            for cand, prob in candidates:

                decoded_cand = self.decoder.decode(cand)
                if print_this_example:
                    print("TARGET:", example)
                    print("GENERATED:", decoded_cand)
    
                if decoded_cand == example:
                    num_top_n_correct += 1
                    break

        top_n_acc = num_top_n_correct / len(labels)
        runner.batch_metrics['accuracy'] = top_n_acc


In [9]:
loaders = {
    'train': train_dataset.make_loader(batch_size=config.batch_size, shuffle=True),
    'valid': valid_dataset.make_loader(batch_size=config.batch_size),
    'beam_loader': beam_dataset.make_loader(batch_size=1)
}

optimizer = torch.optim.Adam(model.parameters(), weight_decay=config.weight_decay)
scheduler = init_scheduler(optimizer, config.num_epochs, len(loaders['train']), config.warmup_prop)
callbacks = [
    dl.OptimizerCallback(metric_key='loss', accumulation_steps=config.accumulation_steps),
    dl.SchedulerCallback(mode='batch'),
    dl.EarlyStoppingCallback(patience=config.patience),
    dl.WandbLogger(
        entity='dimaorekhov',
        project='bert4source',
        group='method-name-prediction',
        name=EXPERIMENT_NAME,
        config=vars(config)
    ),
    dl.CheckpointCallback(resume=config.resume),
    GenerationAccuracyCallback(
        decoder=token_to_index,
        beam_loader_name='beam_loader',
        num_return_sequences=config.num_return_sequences
    )
]

In [10]:
Path(config.logdir).mkdir(exist_ok=True)

In [11]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


count_parameters(model)

214754116

In [12]:
for name, param in model.named_parameters():
    print(name, param.requires_grad)

encoder.embeddings.word_embeddings.weight True
encoder.embeddings.position_embeddings.weight True
encoder.embeddings.token_type_embeddings.weight True
encoder.embeddings.LayerNorm.weight True
encoder.embeddings.LayerNorm.bias True
encoder.encoder.layer.0.attention.self.query.weight True
encoder.encoder.layer.0.attention.self.query.bias True
encoder.encoder.layer.0.attention.self.key.weight True
encoder.encoder.layer.0.attention.self.key.bias True
encoder.encoder.layer.0.attention.self.value.weight True
encoder.encoder.layer.0.attention.self.value.bias True
encoder.encoder.layer.0.attention.output.dense.weight True
encoder.encoder.layer.0.attention.output.dense.bias True
encoder.encoder.layer.0.attention.output.LayerNorm.weight True
encoder.encoder.layer.0.attention.output.LayerNorm.bias True
encoder.encoder.layer.0.intermediate.dense.weight True
encoder.encoder.layer.0.intermediate.dense.bias True
encoder.encoder.layer.0.output.dense.weight True
encoder.encoder.layer.0.output.dense.bia

In [None]:
runner = MethodNameRunner(
    bos_index=token_to_index.bos_index,
    pad_index=token_to_index.pad_index,
    eos_index=token_to_index.pad_index,
    beam_loader_name='beam_loader',
    num_return_sequences=config.num_return_sequences,
    device=torch.device("cuda")
)
runner.train(
    model=model,
    loaders=loaders,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=config.num_epochs,
    logdir=config.logdir,
    verbose=True,
    callbacks=callbacks
)

[34m[1mwandb[0m: Currently logged in as: [33mdimaorekhov[0m (use `wandb login --relogin` to force relogin)


1/50 * Epoch (train):   0% 1/12733 [00:00<3:05:35,  1.14it/s, loss=10.933, lr=1.047e-08, momentum=0.900]



1/50 * Epoch (train):   1% 106/12733 [01:25<2:53:56,  1.21it/s, loss=10.747, lr=1.110e-06, momentum=0.900]