© 2020 Nokia

Licensed under the BSD 3 Clause license

SPDX-License-Identifier: BSD-3-Clause

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

import time
import os
import json
import random
import logging
from functools import partial
from pathlib import Path

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch import optim
from fastai.basic_data import DataBunch
from fastai.basic_train import Callback, Learner
from fastai.callbacks import SaveModelCallback

from codesearch.utils import load_model, get_best_device, Saveable
from codesearch.encoders import BasicEncoder
from codesearch.data import load_snippet_collection, EVAL_DATASETS, eval_datasets_from_regex
from codesearch.data_config import CODE_FIELD, LANGUAGE_FIELD, DESCRIPTION_FIELD
from codesearch.embedding_retrieval import EmbeddingRetrievalModel
from codesearch.unif.unif_embedder import UNIFEmbedder
from codesearch.evaluation import evaluate_and_dump 
from codesearch.unif.unif_modules import SimilarityModel, MarginRankingLoss
from codesearch.unif.unif_preprocessing import Padder

start = time.time()

Read configuration parameters from environment variables (when this notebook is run as a script).

In [None]:
# Run configuration. Every value can be overridden through an environment
# variable of the same name, which is how the notebook is parameterized when it
# is executed as a script.

# Alternative default configurations (kept for reference):
#
# train_snippets_collection = os.environ.get("train_snippet_collection", "so-ds-feb20")
# snippets_collection = os.environ.get("snippet_collection", "so-ds-feb20")
# valid_dataset = os.environ.get("valid_dataset", "so-ds-feb20-valid")
# test_dataset = os.environ.get("test_dataset", "so-ds-feb20-test")
# ncs_embedder = os.environ.get("ncs_embedder", "../ncs/so-ds-feb20/best_ncs_embedder/")
# output_dir = os.environ.get("output_dir", "so-ds-feb20")

# train_snippets_collection = os.environ.get("train_snippet_collection", "conala-curated")
# snippets_collection = os.environ.get("snippet_collection", "conala-curated")
# valid_dataset = os.environ.get("valid_dataset", "conala-curated-0.5-test")
# test_dataset = os.environ.get("test_dataset", "conala-curated-0.5-test")
# ncs_embedder = os.environ.get("ncs_embedder", "../ncs/conala/best_ncs_embedder/")
# output_dir = os.environ.get("output_dir", "conala")


def env_flag(name, default=False):
    """Read a boolean flag from the environment variable `name`.

    `bool("False")` and `bool("0")` are both True, so a plain `bool(...)` cast
    cannot switch a flag off from the command line. Here the usual "false"
    spellings (and the empty string) are mapped to False; any other non-empty
    value counts as True. When the variable is unset, `default` is returned.
    """
    raw = os.environ.get(name)
    if raw is None:
        return bool(default)
    return raw.strip().lower() not in ("", "0", "false", "no", "n", "off")


train_snippets_collection = os.environ.get("train_snippet_collection", "staqc-py-cleaned")
snippets_collection = os.environ.get("snippet_collection", "staqc-py-cleaned")
valid_dataset = os.environ.get("valid_dataset", "staqc-py-raw-valid")
test_dataset = os.environ.get("test_dataset", "staqc-py-raw-test")
ncs_embedder = os.environ.get("ncs_embedder", "../ncs/staqc-py/best_ncs_embedder/")
output_dir = os.environ.get("output_dir", "staqc-py")

# Also create missing parent directories and tolerate an existing directory.
Path(output_dir).mkdir(parents=True, exist_ok=True)
margin = float(os.environ.get("margin", 0.05))
random_init = env_flag("random_init", False)

momentum = float(os.environ.get("momentum", 0.9))
lr = float(os.environ.get("lr", 0.001))
epochs = int(os.environ.get("epochs", 20))
# The (misspelled) name "fit_one_cyle" is kept: it is the environment variable
# and module-level name that existing launch scripts and later cells rely on.
fit_one_cyle = env_flag("fit_one_cyle", False)
clip = float(os.environ.get("clip", 0.))


In [None]:
# Display the configured path of the NCS embedder.
ncs_embedder

In [None]:
# Display the ranking-loss margin and the random-initialization flag.
margin, random_init

In [None]:
# Display the optimization settings read from the environment.
lr, epochs, momentum, fit_one_cyle, clip

## Load data

In [None]:
# Fail early, with an actionable message, when the requested validation
# dataset is not a known evaluation dataset. An empty/unset `valid_dataset`
# is allowed and skips the check.
if valid_dataset and valid_dataset not in EVAL_DATASETS:
    raise ValueError(
        f"Unknown valid_dataset {valid_dataset!r}: it is not listed in EVAL_DATASETS"
    )
# `test_dataset` is treated as a regular expression selecting evaluation datasets.
test_datasets = eval_datasets_from_regex(test_dataset)
# Snippets that are indexed for retrieval vs. snippets used for training.
snippets = load_snippet_collection(snippets_collection)
train_snippets = load_snippet_collection(train_snippets_collection)

In [None]:
# Display the resolved snippet collections and evaluation datasets.
snippets_collection, train_snippets_collection, valid_dataset, test_datasets

In [None]:
# Load the previously trained NCS embedder and pull out the two components
# reused below: its `_enc` encoder and its `_ft_model`.
ncs = load_model(ncs_embedder)
enc, ft_model = ncs._enc, ncs._ft_model

## Dataset and DataLoader

TODO: bucketize samples by length to minimize padding in batches

In [None]:
class CodeSnippetsAndDescriptions(Dataset):
    """Snippets paired with their own description and a non-matching one.

    Each item is a tuple ``(x, y)`` where ``x`` is a dict with the keys
    ``"code"``, ``"descriptions"`` (the snippet's description followed by the
    description of another snippet) and ``"language"``, and ``y`` is the tensor
    ``[1, 0]`` marking the first description as the matching one.

    Args:
        snippet_collection: identifier passed to `load_snippet_collection`.
        transform: optional callable applied to ``(x, y)``; must return ``(x, y)``.
        deterministic: if True, the non-matching description of an item is
            chosen from a fixed, seeded permutation instead of at random.
        dummy: if True, only the first 50 snippets are kept.
    """

    # Random draws attempted before falling back to a linear scan.
    _MAX_RANDOM_RETRIES = 100

    def __init__(self, snippet_collection, transform=None, deterministic=False, dummy=False):
        records = [
            {"code": s[CODE_FIELD], "description": s[DESCRIPTION_FIELD], "language": s[LANGUAGE_FIELD]}
            for s in load_snippet_collection(snippet_collection)
        ]
        if dummy:
            records = records[:50]
        self.snippets = pd.DataFrame(records)
        self.transform = transform
        if deterministic:
            permutation = list(range(len(records)))
            # A private generator yields the same permutation as seeding the
            # module-level RNG with 42, without resetting the global random
            # state used elsewhere (e.g. by the training set's sampling).
            random.Random(42).shuffle(permutation)
            self.random_idx = np.array(permutation)
        else:
            self.random_idx = None

    def random(self, idx):
        """Return the first candidate index of a non-matching snippet for `idx`."""
        if self.random_idx is not None:
            return self.random_idx[idx]
        return random.randint(0, len(self) - 1)

    def _negative_index(self, idx, description):
        """Return the index of a snippet whose description differs from `description`.

        In random mode a bounded number of random candidates is tried first.
        In deterministic mode (or once the retries are exhausted) the search
        continues with a linear scan starting at the current candidate, so the
        result is reproducible and the search always terminates.

        Raises:
            ValueError: if no snippet has a different description.
        """
        descriptions = self.snippets["description"]
        size = len(self)
        candidate = self.random(idx)
        if self.random_idx is None:
            for _ in range(self._MAX_RANDOM_RETRIES):
                if descriptions.iloc[candidate] != description:
                    return candidate
                candidate = random.randint(0, size - 1)
        for offset in range(size):
            probe = (int(candidate) + offset) % size
            if descriptions.iloc[probe] != description:
                return probe
        raise ValueError(
            "Cannot sample a negative description: all snippets share the same description"
        )

    def __len__(self):
        return len(self.snippets)

    def __getitem__(self, idx):
        s = self.snippets.iloc[idx]
        code, description, language = s["code"], s["description"], s["language"]
        idx_rand = self._negative_index(idx, description)
        random_description = self.snippets.iloc[idx_rand]["description"]

        x = {
            "code": code,
            "descriptions": np.array([description, random_description]),
            "language": language
        }
        # Label 1 for the matching description, 0 for the sampled one.
        y = torch.tensor([1, 0], dtype=torch.long, device=get_best_device())
        if self.transform:
            x, y = self.transform((x, y))
        return x, y


In [None]:
class Preprocess(object):
    """Transform that encodes the code and descriptions of a sample.

    The wrapped `encoder` must provide `encode_code(code, language)`, whose
    first return value is kept, and `encode_description(description)`.
    """

    def __init__(self, encoder):
        self.encoder = encoder

    def __call__(self, sample):
        """Return `(x, y)` with `x`'s code and descriptions encoded; `y` is passed through."""
        features, target = sample
        language = features["language"]
        # Only the first element returned by `encode_code` is used.
        encoded_code = self.encoder.encode_code(features["code"], language)[0]
        encoded_descriptions = []
        for description in features["descriptions"]:
            encoded_descriptions.append(self.encoder.encode_description(description))
        encoded = {
            "code": encoded_code,
            "descriptions": encoded_descriptions,
            "language": language,
        }
        return encoded, target

## Retrieval model

## Training utils

In [None]:
def get_data(train_ds, valid_ds, bs, max_code_len=200, max_description_len=25, ft_model=None):
    """Create the batch collator and the training/validation data loaders.

    Args:
        train_ds: training dataset.
        valid_ds: validation dataset; when falsy, no validation loader is built.
        bs: training batch size (the validation loader uses twice this size).
        max_code_len: maximum code length handed to the `Padder`.
        max_description_len: maximum description length handed to the `Padder`.
        ft_model: forwarded to the `Padder`.

    Returns:
        A tuple `(padder, train_dataloader, valid_dataloader_or_None)`.
    """
    # `max_description_len` used to be ignored: `max_code_len` was passed twice.
    # NOTE(review): assumes Padder's second positional parameter is the maximum
    # description length - confirm against codesearch.unif.unif_preprocessing.
    padder = Padder(max_code_len, max_description_len, ft_model=ft_model)
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, pin_memory=False, collate_fn=padder)
    valid_dl = None
    if valid_ds:
        valid_dl = DataLoader(valid_ds, batch_size=bs * 2, collate_fn=padder, pin_memory=False)
    return padder, train_dl, valid_dl

def get_model(ft_model, random_init=False):
    """Build a `SimilarityModel` from `ft_model`, forwarding `random_init`."""
    return SimilarityModel(ft_model, random_init=random_init)

def create_retrieval_model(model, encoder, snippets, ft_model=None):
    """Wrap `model` in a UNIF embedder and return a retrieval model holding `snippets`."""
    embedder = UNIFEmbedder(
        model,
        encoder,
        ft_model,
        batch_size=2,
        max_code_len=200,
        max_description_len=25,
    )
    retriever = EmbeddingRetrievalModel(embedder)
    retriever.add_snippets(snippets)
    return retriever

def eval_retrieval(model, encoder, snippets, valid_dataset, test_datasets, ft_model=None):
    """Evaluate `model` as a retrieval model and return its validation MRR.

    The evaluation results are written to the module-level `output_dir` by
    `evaluate_and_dump` and printed; the value returned is the `"mrr"` entry
    of the `valid_dataset` results.
    """
    retriever = create_retrieval_model(model, encoder, snippets, ft_model=ft_model)
    scores = evaluate_and_dump(retriever, {}, output_dir, valid_dataset, test_datasets)
    print(scores)
    validation_scores = scores[valid_dataset]
    return validation_scores["mrr"]



In [None]:
# Training set: negative descriptions are sampled at random for every item.
train_ds = CodeSnippetsAndDescriptions(
    train_snippets_collection,
    transform=Preprocess(enc),
)
# Validation is done with the retrieval model; this small, deterministic
# subset only serves as a placeholder validation set.
valid_ds = CodeSnippetsAndDescriptions(
    train_snippets_collection,
    transform=Preprocess(enc),
    deterministic=True,
    dummy=True,
)

In [None]:
class MRR(Callback):
    """Metric callback reporting the retrieval MRR on the validation dataset.

    After every epoch the current `model` is evaluated with `eval_retrieval`.
    The resulting MRR is appended to the epoch's metrics and, whenever it
    improves on the best value seen so far, the model is saved to
    ``{output_dir}/model-epoch={epoch}``.

    Args:
        encoder: encoder forwarded to `eval_retrieval`.
        snippets: snippet collection forwarded to `eval_retrieval`.
        valid_dataset: name of the dataset whose MRR is reported.
        test_datasets: additional datasets evaluated alongside it.
        ft_model: forwarded to `eval_retrieval`.
        model: the model being trained; it must provide a `save(path)` method.
    """

    def __init__(self, encoder, snippets, valid_dataset, test_datasets, ft_model, model):
        super().__init__()
        self.encoder = encoder
        self.snippets = pd.DataFrame(snippets)
        self.ft_model = ft_model
        self.model = model
        self.name = "mrr"
        # Everything except the model is fixed, so bind it once.
        self.func = partial(
            eval_retrieval,
            encoder=encoder,
            snippets=snippets,
            valid_dataset=valid_dataset,
            test_datasets=test_datasets,
            ft_model=ft_model,
        )
        self.best_result = 0
        self.epoch = 0

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the epoch's MRR to `last_metrics` and checkpoint the model on improvement."
        result = self.func(self.model)
        print(result, self.best_result)
        if result > self.best_result:
            print("saving model")
            self.model.save(f"{output_dir}/model-epoch={self.epoch}")
            self.best_result = result
        self.epoch += 1
        return {'last_metrics': last_metrics + [result]}


## Training

initial observations:

- higher margin does not help

## Initial model

In [None]:
model = get_model(ft_model, random_init)
# Keep the validation DataLoader under its own name. Rebinding `valid_ds` to
# the loader meant that re-running this cell passed a DataLoader back into
# `get_data` in place of the dataset.
collate_fn, train_dl, valid_dl = get_data(train_ds, valid_ds, 32, ft_model=ft_model)
db = DataBunch(train_dl, valid_dl, collate_fn=collate_fn)

In [None]:
# Margin ranking loss and Adam, with `momentum` used as Adam's first beta.
loss_func = MarginRankingLoss(margin)
opt_func = partial(optim.Adam, betas=(momentum, 0.999))
learner = Learner(
    db,
    model,
    loss_func=loss_func,
    opt_func=opt_func,
    wd=0,
    metrics=[MRR(enc, snippets, valid_dataset, test_datasets, ft_model, model)],
)
# Gradient clipping is only enabled for a non-zero `clip` value.
if clip:
    learner.clip_grad(clip)

In [None]:
# Display the model held by the learner.
learner.model

## Train model

In [None]:
# Train all layers, either with fastai's one-cycle policy or with a plain
# fixed-learning-rate fit.
learner.unfreeze()
if fit_one_cyle:
    # The fastai method is `fit_one_cycle`; the former `fit_one_cyle` call
    # raised AttributeError whenever this branch was taken.
    learner.fit_one_cycle(epochs)
else:
    learner.fit(epochs, lr=lr)

## Save best retrieval model

In [None]:
# Load the best checkpoint saved during training.
# NOTE: the epoch number is hard-coded - replace "model-epoch=14" with the
# best checkpoint of your own run.
best_model = Saveable.load(f"{output_dir}/model-epoch=14")
#retrieval_model = create_retrieval_model(best_model, enc, snippets, ft_model)


In [None]:
# Wrap the best checkpoint in a UNIF embedder and persist it in the output directory.
unif_embedder = UNIFEmbedder(
    best_model,
    enc,
    ft_model,
    batch_size=32,
    max_code_len=200,
    max_description_len=25,
)
unif_embedder.save(f"{output_dir}/best_unif_embedder")

In [None]:
# Build the retrieval model from the best UNIF embedder so that
# `retrieval_model` is defined before it is evaluated; it was previously
# referenced without ever being created, which raised NameError.
retrieval_model = EmbeddingRetrievalModel(unif_embedder)
retrieval_model.add_snippets(snippets)
config = {"model": "unif_best"}
evaluate_and_dump(retrieval_model, config, output_dir, valid_dataset, test_datasets)