In [171]:
import sys  
sys.path.insert(0, '..')

In [172]:
from datasets import load_dataset
from functools import partial
from tqdm import tqdm
from transformers import AutoTokenizer, FlaxAutoModel
from dataclasses import dataclass
from typing import Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from flax.training.common_utils import shard
import numpy as np
import csv

# for training script
from dataclasses import dataclass, field, asdict, replace
from functools import partial
from typing import Callable, List, Union

import jax
import jax.numpy as jnp
from flax import jax_utils
import faiss
from trainer.utils.ops import normalize_L2

In [31]:
ds = load_dataset("csv", data_files={"test" : "../data/codesearchnet_test.csv.gz"}, split="test")

Downloading and preparing dataset csv/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/pascal_voitot/.cache/huggingface/datasets/csv/default-2ad89bbbeda92323/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23...


0 tables [00:00, ? tables/s]

Dataset csv downloaded and prepared to /home/pascal_voitot/.cache/huggingface/datasets/csv/default-2ad89bbbeda92323/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23. Subsequent calls will reuse this data.


In [190]:
#class TrainState(train_state.TrainState):
#    loss_fn: Callable = struct.field(pytree_node=False)
#    scheduler_fn: Callable = struct.field(pytree_node=False)


@partial(jax.pmap, axis_name="batch")
def embedding_step(model_inputs1, model_inputs2):
    train = False

    def _forward(model_input):
        attention_mask = model_input["attention_mask"][..., None]
        embedding = model(**model_input, params=model.params, train=train)[0]
        attention_mask = jnp.broadcast_to(attention_mask, jnp.shape(embedding))

        embedding = embedding * attention_mask
        embedding = jnp.mean(embedding, axis=1)

        modulus = jnp.sum(jnp.square(embedding), axis=-1, keepdims=True)
        embedding = embedding / jnp.maximum(modulus, 1e-12)

        # gather all the embeddings on same device for calculation loss over global batch
        embedding = jax.lax.all_gather(embedding, axis_name="batch")
        embedding = jnp.reshape(embedding, (-1, embedding.shape[-1]))

        return embedding

    embedding1, embedding2 = _forward(model_inputs1), _forward(model_inputs2)
    return embedding1, embedding2

def get_batched_dataset(dataset, batch_size, seed=None):
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    for i in range(len(dataset) // batch_size):
        batch = dataset[i*batch_size: (i+1)*batch_size]
        yield dict(batch)


@dataclass
class DataCollator:
    tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer]
    input1_maxlen: int = 128
    input2_maxlen: int = 128

    def __call__(self, batch):
        # Currently only static padding; TODO: change below for adding dynamic padding support
        model_input1 = self.tokenizer(batch["docstring"], return_tensors="jax", max_length=self.input1_maxlen, truncation=True, padding="max_length")
        model_input2 = self.tokenizer(batch["code"], return_tensors="jax", max_length=self.input2_maxlen, truncation=True, padding="max_length")
        model_input1, model_input2 = dict(model_input1), dict(model_input2)
        return shard(model_input1), shard(model_input2)
        # return model_input1, model_input2

In [196]:
# model = FlaxAutoModel.from_pretrained("microsoft/codebert-base")
model = FlaxAutoModel.from_pretrained("checkpoints-2gg8aig1-epoch-1")

In [192]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

In [193]:
data_collator = DataCollator(
    tokenizer=tokenizer,
    input1_maxlen=200,
    input2_maxlen=200,
)

In [212]:
import faiss

index = faiss.IndexFlatL2(768) 

In [101]:
data_collator(next(val_batch_iterator))[0]["input_ids"].shape

(32, 128)

In [204]:
batch_size = 32
from itertools import islice
total = len(ds) // batch_size
val_batch_iterator = get_batched_dataset(ds, batch_size, seed=None)

queries = []
codes = []
for j, batch in tqdm(enumerate(val_batch_iterator), desc=f"Compute", total=total):
    model_input1, model_input2 = data_collator(batch)
    emb1, emb2 = embedding_step(model_input1, model_input2)
    emb1 = jax_utils.unreplicate(emb1)
    emb1 = normalize_L2(emb1)

    emb2 = jax_utils.unreplicate(emb2)
    emb2 = normalize_L2(emb2)
    queries.append(emb1)
    codes.append(emb2)
    

Compute: 100%|██████████| 1113/1113 [00:31<00:00, 35.13it/s]


In [209]:
codes_all = np.vstack(codes)
print(codes_all.shape)

(35616, 768)


In [210]:
queries_all = np.vstack(queries)
print(queries_all.shape)

(35616, 768)


In [213]:
index.add(codes_all)

In [214]:
index.search(codes_all[:5], 5)

(array([[0.        , 0.15305121, 0.1570281 , 0.15887779, 0.16033879],
        [0.        , 0.11687908, 0.16540855, 0.17207018, 0.18670824],
        [0.        , 0.03636249, 0.07812731, 0.10465132, 0.10556918],
        [0.        , 0.03636249, 0.09375702, 0.10378852, 0.11198861],
        [0.        , 0.07812731, 0.08444145, 0.09375702, 0.09401216]],
       dtype=float32),
 array([[   0, 8123, 2582,  213, 5665],
        [   1, 8013, 7307,  923,  922],
        [   2,    3,    4, 2188, 6296],
        [   3,    2,    4,  794, 2188],
        [   4,    2, 2648,    3, 6052]]))

In [245]:
index.search(queries_all[:10], 5)

(array([[0.19371068, 0.21425068, 0.23767886, 0.24520345, 0.2534646 ],
        [0.22690834, 0.24544305, 0.24942982, 0.25861603, 0.2601923 ],
        [0.25210357, 0.2698902 , 0.27462995, 0.2758479 , 0.27776644],
        [0.26015598, 0.26120454, 0.26836994, 0.27711028, 0.28397453],
        [0.09678011, 0.12610778, 0.13524188, 0.1370242 , 0.14111684],
        [0.09203597, 0.10387987, 0.11392431, 0.11646447, 0.1178569 ],
        [0.19822818, 0.20245108, 0.21533564, 0.21593082, 0.21958163],
        [0.1793623 , 0.18039578, 0.18245962, 0.18963315, 0.197237  ],
        [0.1444544 , 0.2306124 , 0.24778871, 0.25728324, 0.2623035 ],
        [0.18207812, 0.23218426, 0.23329931, 0.23586027, 0.24060239]],
       dtype=float32),
 array([[  662,     0,  8047,  6118,  7258],
        [    1,  3480,  4683, 21365, 20562],
        [ 6294,   794,     3,  6155,   185],
        [    3,   794,   185,  1149,  7664],
        [ 6291,  6297,  7195,  6052,     4],
        [ 1090,    16,  6291,  6522,  7313],
      

In [246]:
all_ranks = []
for idx, q in tqdm(enumerate(queries_all)):
    q = np.expand_dims(q, 0)
    i, d = index.search(q, 100)
    ranks = d[0]
    rank = np.argwhere(ranks == idx)
    if rank.size > 0:
        all_ranks.append(rank.item() + 1)
# print(all_ranks)

mean_mrr = np.mean(1.0 / np.array(all_ranks))
print("mean_mrr", mean_mrr)

35616it [23:09, 25.64it/s]mean_mrr 0.5015806329720339



In [253]:
total = len(queries_all) // batch_size

def get_batches(dataset, batch_size, seed=None):
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    for i in range(len(dataset) // batch_size):
        batch = dataset[i*batch_size: (i+1)*batch_size]
        yield batch

queries_all_batch_iterator = get_batches(queries_all, batch_size, seed=None)

ranks = []
for idx, q in tqdm(enumerate(islice(queries_all_batch_iterator, 2)), desc=f"Compute", total=total):
    i, batch_ranks = index.search(q, 100)
    rank = np.nonzero(batch_ranks == idx)
    print(batch_ranks.shape, rank)
    if rank.size > 0:
        ranks.append(rank.item() + 1)
# print(all_ranks)

batch_mean_mrr = np.mean(1.0 / np.array(ranks))
print("batch_mean_mrr", batch_mean_mrr)

Compute:   0%|          | 0/1113 [00:00<?, ?it/s](32, 100) (array([0, 2]), array([ 1, 54]))



AttributeError: 'tuple' object has no attribute 'size'