In [1]:
import os
os.putenv("CUDA_VISIBLE_DEVICES", "0")


from transformers import BertConfig, EncoderDecoderModel, EncoderDecoderConfig, BertModel, BertLMHeadModel
from dataclasses import dataclass
from typing import Union, Iterable, List
from pathlib import Path
from tensor2tensor.data_generators import text_encoder
import json
import torch
from catalyst import dl
import re
from catalyst.utils import set_global_seed
import random


def read_jsonl(path):
    with open(path, 'r') as istream:
        return [json.loads(l) for l in istream]


class TokenToIndexConverter:

    def __init__(
        self,
        vocab_path: Union[str, Path],
        unk_token = '[UNK]_',
        bos_token = '\\u\\u\\uNL\\u\\u\\u_',
        eos_token = '\\u\\u\\uNEWLINE\\u\\u\\u_',
        pad_token = '<pad>_'
    ):
        vocab_path = Path(vocab_path)
        self.subword_tokenizer = text_encoder.SubwordTextEncoder(vocab_path.as_posix())
        self.token_to_index_map = {
            tok: i
            for i, tok in enumerate(self.subword_tokenizer.all_subtoken_strings)
        }
        self.index_to_token_map = {v: k for k, v in self.token_to_index_map.items()}
        self.unk_token = unk_token
        self.unk_index = self[unk_token]
        self.bos_token = bos_token
        self.bos_index = self[bos_token]
        self.eos_token = eos_token
        self.eos_index = self[eos_token]
        self.pad_token = pad_token
        self.pad_index = self[pad_token]
        self.util_tokens = set(self.subword_tokenizer.all_subtoken_strings[:5])

        self.bos_exp = re.compile(r"^(___NL___)*")  # Remove any number of repeating newline characters
        self.eos_exp = re.compile(r"___NEWLINE___.*$")  # Remove evrythin after first eod

    def __getitem__(self, key):
        return (self.token_to_index_map[key] 
                if key in self.token_to_index_map 
                else self.unk_index)

    def encode(self, tokens: str) -> List[int]:
        return [self[tok] for tok in tokens]
    
    def encode_code(self, code: List[List[str]]) -> List[int]:
        return [self.bos_index] + [
            tok for line in code 
            for tok in self.encode(line)
        ]
    
    def decode(self, tokens: List[int]) -> str:
        text = self.subword_tokenizer.decode(tokens, strip_extraneous=True)
        text = self.bos_exp.sub("", text)
        text = self.eos_exp.sub("", text)
        return text

    def decode_list(self, tokens: List[int]) -> List[str]:
        tokens = self.subword_tokenizer.decode_list(tokens)
        while tokens[0] == self.bos_token and len(tokens) != 0:
            tokens = tokens[1:]
        end_position = tokens.index(self.eos_token) if self.eos_token in tokens else -1 
        tokens = tokens[:end_position]
        return [tok for tok in tokens if tok not in self.util_tokens]

    @property
    def vocab_size(self):
        return self.subword_tokenizer.vocab_size


EXPERIMENT_NAME = "cubert-method-name-encdec-init"


@dataclass
class Config:
    
    #  finetune config:
    pretrained_path: str = (Path.home() / "models/cubert/pre_trained_torch_epoch_2").as_posix()

    encoder_num_hidden_layers: int = 6
    encoder_init_pretrained: bool = True
    encoder_n_layers_to_freeze: int = 4
    encoder_freeze_bert_embeddings: bool = True

    decoder_num_hidden_layers: int = 2
    decoder_init_pretrained: bool = True
    decoder_n_layers_to_freeze: int = 0
    decoder_freeze_bert_embeddings: bool = False
        
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1

    #  optimization:
    max_lr: float = 1e-5
    batch_size: int = 32
    accumulation_steps: int = 16
    
    weight_decay: float = 0
        
    num_epochs: int = 50
    patience: int = 5

    #  lr scheduling:
    warmup_prop: float = 0.15
        
    #  generation parameters:
    eval_set_size: int = 2000
    num_return_sequences: int = 5

    logdir: str = f'logdir_{EXPERIMENT_NAME}'
    resume: str = None
        
    seed: int = 19


def make_model(config):
    pretrained_config = BertConfig.from_pretrained(config.pretrained_path)
    encoder_config = BertConfig(
        vocab_size=pretrained_config.vocab_size,
        hidden_size=pretrained_config.hidden_size,
        num_hidden_layers=config.encoder_num_hidden_layers,
        num_attention_heads=pretrained_config.num_attention_heads,
        intermediate_size=pretrained_config.intermediate_size,
        max_position_embeddings=pretrained_config.max_position_embeddings,
        pad_token_id=pretrained_config.pad_token_id,
        hidden_dropout_prob = config.hidden_dropout_prob,
        attention_probs_dropout_prob = config.attention_probs_dropout_prob
    )
    decoder_config = BertConfig(
        vocab_size=pretrained_config.vocab_size,
        hidden_size=pretrained_config.hidden_size,
        num_hidden_layers=config.decoder_num_hidden_layers,
        num_attention_heads=pretrained_config.num_attention_heads,
        intermediate_size=pretrained_config.intermediate_size,
        max_position_embeddings=pretrained_config.max_position_embeddings,
        pad_token_id=pretrained_config.pad_token_id,
        hidden_dropout_prob = config.hidden_dropout_prob,
        attention_probs_dropout_prob = config.attention_probs_dropout_prob
    )
    model = EncoderDecoderModel(
        config=EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
    )
    
    pretrained = BertModel.from_pretrained(config.pretrained_path)
    
    # Encoder initialization and freezing block:
    if config.encoder_init_pretrained:
        copy_params_by_names(pretrained, model.encoder)
    
    freeze_lower_n_bert_layers(model.encoder, config.encoder_n_layers_to_freeze)
    if config.encoder_freeze_bert_embeddings:
        freeze_module(model.encoder.embeddings)

    # Decoder initialization and freezing block:
    if config.decoder_init_pretrained:
        copy_params_by_names(pretrained, model.decoder.bert)
        
    freeze_lower_n_bert_layers(model.decoder.bert, config.decoder_n_layers_to_freeze)
    if config.decoder_freeze_bert_embeddings:
        freeze_module(model.decoder.bert.embeddings)

    return model


def copy_params_by_names(src, tgt, verbose=True):
    src_params = dict(src.named_parameters())
    for name, param in tgt.named_parameters():
        if not name in src_params:
            if verbose:
                print(f"Skipping initialization of {name}")
            continue
        param.data.copy_(src_params[name])


def freeze_module(module):
    for param in module.parameters():
        param.requires_grad = False


def prune_bert(bert_model, n_layers_to_leave):
    bert_model.config.num_hidden_layers = n_layers_to_leave
    bert_model.encoder.layer = bert_model.encoder.layer[:n_layers_to_leave]


def freeze_lower_n_bert_layers(bert_model, n_layers_to_freeze):
    for i, layer in enumerate(bert_model.encoder.layer):
        if i == n_layers_to_freeze:
            break
        freeze_module(layer)


MODELS_DIR = Path.home() / "models/cubert"


token_to_index = TokenToIndexConverter(
    (MODELS_DIR / "github_python_minus_ethpy150open_deduplicated_vocabulary.txt").as_posix()
)
config = Config()

set_global_seed(config.seed)

In [2]:
DATA_FOLDER = Path.home() / "data/method_name_prediction/python/final/jsonl"


train = read_jsonl(DATA_FOLDER / "train_preprocessed.jsonl")
valid = read_jsonl(DATA_FOLDER / "valid_preprocessed.jsonl")
test = read_jsonl(DATA_FOLDER / "test_preprocessed.jsonl")

In [3]:
from torch.utils.data import Dataset, DataLoader
from torch import Tensor, LongTensor
from torch.nn.utils.rnn import pad_sequence
from typing import Callable, Iterable, Optional


class SequenceToSequenceDataset(Dataset):

    def __init__(
        self,
        src_stream: Iterable['T'],
        src_encoder: Callable[['T'], Tensor],
        ref_stream: Iterable['T'],
        ref_encoder: Callable[['T'], Tensor],
        src_pad_index: int,
        ref_pad_index: Optional[int] = None
    ):
        self.src = [src_encoder(s) for s in src_stream]
        self.ref = [ref_encoder(s) for s in ref_stream]
        assert len(self.src) == len(self.ref)
        self.src_pad_index = src_pad_index
        self.ref_pad_index = ref_pad_index if ref_pad_index is not None else src_pad_index

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return self.src[idx], self.ref[idx]

    def collate_fn(self, data):
        src_batch, ref_batch = zip(*data)
        input_ids = pad_sequence(
            src_batch,
            padding_value=self.src_pad_index,
            batch_first=True
        )
        attention_mask = input_ids != self.src_pad_index
        decoder_input_ids = pad_sequence(
            ref_batch,
            padding_value=self.ref_pad_index,
            batch_first=True
        )
        labels = decoder_input_ids[:,1:]
        decoder_input_ids = decoder_input_ids[:,:-1]
        decoder_attention_mask = decoder_input_ids != self.ref_pad_index
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'decoder_input_ids': decoder_input_ids,
            'decoder_attention_mask': decoder_attention_mask,
            'labels': labels
        }

    def make_loader(self, *args, **kwargs):
        return DataLoader(self, *args, collate_fn=self.collate_fn, **kwargs)


def get_method_name_dataset(data, token_to_index, pad_index, max_length):

    def truncated_encoder(encoder, max_length):
        def wrapper(*args, **kwargs):
            return encoder(*args, **kwargs)[:max_length]
        return wrapper

    def to_torch_encoder(encoder):
        def wrapper(*args, **kwargs):
            return LongTensor(encoder(*args, **kwargs))
        return wrapper

    return SequenceToSequenceDataset(
        src_stream = (e['function_body_tokenized'] for e in data),
        src_encoder = to_torch_encoder(
            truncated_encoder(token_to_index.encode_code, max_length)
        ),
        ref_stream = (e['function_name_tokenized'] for e in data),
        ref_encoder = to_torch_encoder(
            truncated_encoder(token_to_index.encode_code, max_length)
        ),
        src_pad_index = pad_index
    )


In [4]:
# crossattention is expected to be unitialized!
model = make_model(config)

Skipping initialization of encoder.layer.0.crossattention.self.query.weight
Skipping initialization of encoder.layer.0.crossattention.self.query.bias
Skipping initialization of encoder.layer.0.crossattention.self.key.weight
Skipping initialization of encoder.layer.0.crossattention.self.key.bias
Skipping initialization of encoder.layer.0.crossattention.self.value.weight
Skipping initialization of encoder.layer.0.crossattention.self.value.bias
Skipping initialization of encoder.layer.0.crossattention.output.dense.weight
Skipping initialization of encoder.layer.0.crossattention.output.dense.bias
Skipping initialization of encoder.layer.0.crossattention.output.LayerNorm.weight
Skipping initialization of encoder.layer.0.crossattention.output.LayerNorm.bias
Skipping initialization of encoder.layer.1.crossattention.self.query.weight
Skipping initialization of encoder.layer.1.crossattention.self.query.bias
Skipping initialization of encoder.layer.1.crossattention.self.key.weight
Skipping initi

In [5]:
checkpoint = torch.load(Path(config.logdir) / "checkpoints/best.pth")

In [6]:
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [7]:
train_dataset = get_method_name_dataset(train, token_to_index, token_to_index.pad_index, model.encoder.config.max_position_embeddings)
valid_dataset = get_method_name_dataset(valid, token_to_index, token_to_index.pad_index, model.encoder.config.max_position_embeddings)

set_global_seed(config.seed)
beam_subset = random.sample(valid, config.eval_set_size)
beam_dataset = get_method_name_dataset(
   beam_subset, token_to_index, token_to_index.pad_index, model.encoder.config.max_position_embeddings
)

test_dataset = get_method_name_dataset(test, token_to_index, token_to_index.pad_index, model.encoder.config.max_position_embeddings)

In [8]:
import numpy as np


def beam_search(src, model, bos_id, pad_id, end_id, device, max_len=10, k=5):
    src = src.view(1,-1).to(device)
    src_mask = (src != pad_id).to(device)
    
    memory = None
    
    input_seq = [bos_id]
    beam = [(input_seq, 0)] 
    for i in range(max_len):
        candidates = []
        candidates_proba = []
        for snt, snt_proba in beam:
            if snt[-1] == end_id:
                candidates.append(snt)
                candidates_proba.append(snt_proba)
            else:    
                snt_tensor = torch.tensor(snt).view(1, -1).long().to(device)
                
                if memory is None:
                    memory = model(
                        input_ids=src, 
                        attention_mask=src_mask,
                        decoder_input_ids=snt_tensor,
                        return_dict=False
                    )
                else:
                    memory = model(
                        input_ids=src, 
                        attention_mask=src_mask,
                        decoder_input_ids=snt_tensor,
                        encoder_outputs=(memory[1], memory[-1]),
                        return_dict=False
                    )
                    
                proba = memory[0].cpu()[0,-1, :]
                proba = torch.log_softmax(proba, dim=-1).numpy()
                best_k = np.argpartition(-proba, k - 1)[:k]

                for tok in best_k:
                    candidates.append(snt + [tok])
                    candidates_proba.append(snt_proba + proba[tok]) 
                    
        best_candidates = np.argpartition(-np.array(candidates_proba), k - 1)[:k]
        beam = [(candidates[j], candidates_proba[j]) for j in best_candidates]
        beam = sorted(beam, key=lambda x: -x[1])
        
    return beam

In [9]:
from tqdm import tqdm
import pandas as pd
from utils import compute_metrics


DEVICE = torch.device("cuda")
model.to(DEVICE).eval()

metrics = []

with open("cubert-generated-as-lists.jsonl", "w") as ostream:
    with torch.no_grad():
        for i in tqdm(range(len(test_dataset))):
            src, ref = test_dataset[i]
            name = token_to_index.decode_list(ref)
            gen = beam_search(
                src,
                model,
                bos_id=token_to_index.bos_index,
                pad_id=token_to_index.pad_index,
                end_id=token_to_index.eos_index,
                device=DEVICE
            )
            generated = sorted(
                [{"cand": token_to_index.decode_list(t), "score": s} for t, s in gen],
                key=lambda e: e["score"],
                reverse=True
            )
            entry = {
                "original": name,
                "generated": generated
            }
            ostream.write(f"{json.dumps(entry)}\n")
            
            candidates = [g["cand"] for g in generated]
            metrics.append(compute_metrics(name, candidates))

metrics = pd.DataFrame(metrics)

100%|██████████| 21877/21877 [42:23<00:00,  8.60it/s] 


In [10]:
metrics.mean()

exact-top-1        0.090140
exact-top-5        0.164831
precision-top-1    0.261434
precision-top-5    0.432833
recall-top-1       0.228448
recall-top-5       0.407052
f1-top1            0.237528
f1-top5            0.410191
dtype: float64