In [1]:
import sys  
sys.path.insert(0, '..')

In [2]:
from datasets import load_dataset
from functools import partial
from tqdm import tqdm
from transformers import AutoTokenizer, FlaxAutoModel
from dataclasses import dataclass
from typing import Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from flax.training.common_utils import shard
import numpy as np
import csv

# for training script
from dataclasses import dataclass, field, asdict, replace
from functools import partial
from typing import Callable, List, Union

import jax
import jax.numpy as jnp
from flax import jax_utils
import faiss
from trainer.utils.ops import normalize_L2, cos_sim

In [3]:
jax.device_count()

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter TPU Host


8

In [4]:
@jax.jit
def batch_accuracy(embeddings_a: jnp.DeviceArray, embeddings_b: jnp.DeviceArray,
                   similarity_fct=cos_sim):
    """

    :param embeddings_a:
    :param embeddings_b: if passing additional hard negatives, use jnp.concatenate([positives, negatives], axis=0) as input.
    :param similarity_fct:
    :return:
    """
    assert (len(embeddings_a) <= len(embeddings_b))
    scores = similarity_fct(embeddings_a, embeddings_b)
    assert scores.shape == (len(embeddings_a), len(embeddings_b))

    indices = np.argmax(scores, axis=1)

    labels = jnp.arange(len(scores), dtype=jnp.int32)

    return np.sum(indices == labels)

    

In [5]:
ds = load_dataset("csv", data_files={"test" : "../data/codesearchnet_test.csv.gz"}, split="test")



In [6]:
#class TrainState(train_state.TrainState):
#    loss_fn: Callable = struct.field(pytree_node=False)
#    scheduler_fn: Callable = struct.field(pytree_node=False)


@partial(jax.pmap, axis_name="batch")
def embedding_step(state, model_inputs1, model_inputs2):
    train = False

    def _forward(model_input):
        attention_mask = model_input["attention_mask"][..., None]
        embedding = state.apply_fn(**model_input, params=state.params, train=train)[0]
        attention_mask = jnp.broadcast_to(attention_mask, jnp.shape(embedding))

        embedding = embedding * attention_mask
        embedding = jnp.mean(embedding, axis=1)

        modulus = jnp.sum(jnp.square(embedding), axis=-1, keepdims=True)
        embedding = embedding / jnp.maximum(modulus, 1e-12)

        # gather all the embeddings on same device for calculation loss over global batch
        embedding = jax.lax.all_gather(embedding, axis_name="batch")
        embedding = jnp.reshape(embedding, (-1, embedding.shape[-1]))

        return embedding

    embedding1, embedding2 = _forward(model_inputs1), _forward(model_inputs2)
    return embedding1, embedding2

def get_batched_dataset(dataset, batch_size, seed=None):
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    for i in range(len(dataset) // batch_size):
        batch = dataset[i*batch_size: (i+1)*batch_size]
        yield dict(batch)


@dataclass
class DataCollator:
    tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer]
    input1_maxlen: int = 128
    input2_maxlen: int = 128

    def __call__(self, batch):
        # Currently only static padding; TODO: change below for adding dynamic padding support
        model_input1 = self.tokenizer(batch["docstring"], return_tensors="jax", max_length=self.input1_maxlen, truncation=True, padding="max_length")
        model_input2 = self.tokenizer(batch["code"], return_tensors="jax", max_length=self.input2_maxlen, truncation=True, padding="max_length")
        model_input1, model_input2 = dict(model_input1), dict(model_input2)
        return shard(model_input1), shard(model_input2)
        # return model_input1, model_input2

In [7]:
# model = FlaxAutoModel.from_pretrained("microsoft/codebert-base")
# /home/pascal_voitot/flax-sentence-embeddings/notebooks/checkpoints-2gg8aig1-epoch-1
# /home/pascal_voitot/flax-sentence-embeddings/notebooks/checkpoints-2shag1q1-epoch-19
# model = FlaxAutoModel.from_pretrained("checkpoints-2shag1q1-epoch-19")
# model = FlaxAutoModel.from_pretrained("checkpoints-2gg8aig1-epoch-1")
model = FlaxAutoModel.from_pretrained("checkpoints-2shag1q1-epoch-1")

In [8]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

In [9]:
data_collator = DataCollator(
    tokenizer=tokenizer,
    input1_maxlen=200,
    input2_maxlen=200,
)

In [11]:
from flax.training import train_state
from itertools import islice
import optax

batch_size = 32
lr = 2e-5
init_lr = 1e-5
weight_decay = 1e-3
warmup_steps = 2000

def build_tx(lr, init_lr, warmup_steps, weight_decay):
    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=None)
    return tx, lr

tx, lr = build_tx(lr, init_lr, warmup_steps, weight_decay)

state = train_state.TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx = tx,
)
state = jax_utils.replicate(state)

total = len(ds) // batch_size
batch_iterator = get_batched_dataset(ds, batch_size, seed=None)

queries = []
codes = []
accs = 0
for j, batch in tqdm(enumerate(batch_iterator), desc=f"Compute", total=total):
    model_input1, model_input2 = data_collator(batch)
    emb1, emb2 = embedding_step(state, model_input1, model_input2)
    emb1 = jax_utils.unreplicate(emb1)
    emb2 = jax_utils.unreplicate(emb2)
    batch_accs = batch_accuracy(emb1, emb2)
    accs += batch_accs
    
    emb1 = normalize_L2(emb1)
    emb2 = normalize_L2(emb2)
    queries.append(emb1)
    codes.append(emb2)

accs = accs / len(ds)
print("accs", accs)
    

Compute: 100%|██████████| 1113/1113 [00:51<00:00, 21.62it/s]accs 0.80162716



In [119]:
codes_all = np.vstack(codes)
print(codes_all.shape)

(35616, 768)


In [120]:
queries_all = np.vstack(queries)
print(queries_all.shape)

(35616, 768)


In [121]:
import faiss

index = faiss.IndexFlatL2(768) 
index.add(codes_all)

In [122]:
index.search(codes_all[:5], 5)

(array([[0.        , 0.37059227, 0.40204617, 0.41203886, 0.41510835],
        [0.        , 0.47468764, 0.50876725, 0.5183419 , 0.5566749 ],
        [0.        , 0.18882029, 0.22639921, 0.28134218, 0.40356898],
        [0.        , 0.18882029, 0.3476327 , 0.35966137, 0.40546203],
        [0.        , 0.2438789 , 0.28134218, 0.3476327 , 0.42306733]],
       dtype=float32),
 array([[   0, 8123,  930, 5203, 7383],
        [   1, 2558,  858, 8013, 2625],
        [   2,    3,    5,    4, 6155],
        [   3,    2,    4,    5,    6],
        [   4,    5,    2,    3, 3953]]))

In [123]:
index.search(queries_all[:10], 5)

(array([[0.37171227, 0.46303356, 0.49616933, 0.5314659 , 0.53641033],
        [0.49271297, 0.49405432, 0.5368102 , 0.53785354, 0.56322914],
        [0.43734407, 0.4489539 , 0.45203575, 0.45804134, 0.5295608 ],
        [0.45126903, 0.51035976, 0.5367552 , 0.5540974 , 0.5746402 ],
        [0.38083267, 0.38380542, 0.39851576, 0.40150896, 0.40514043],
        [0.3196303 , 0.34072146, 0.37374735, 0.38923326, 0.4108275 ],
        [0.44508794, 0.48792148, 0.4892619 , 0.5127674 , 0.51469815],
        [0.35478595, 0.37368542, 0.3775533 , 0.41385037, 0.4160037 ],
        [0.21381612, 0.33927527, 0.41578817, 0.43105626, 0.43165684],
        [0.37075606, 0.48533714, 0.5066131 , 0.52010864, 0.52266496]],
       dtype=float32),
 array([[    0,  7382,   930,  5529,  8123],
        [21268, 20635, 21087, 22332, 22105],
        [   16,  1101,     5,     2,  6294],
        [    3,     5,     2,    16,  1101],
        [    4,  6303,  1717,     2,     5],
        [ 1090,  6303,  6302,  6053,  6524],
      

In [124]:
total = len(queries_all) // batch_size

def get_batches(dataset, batch_size, seed=None):
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    for i in range(len(dataset) // batch_size):
        batch = dataset[i*batch_size: (i+1)*batch_size]
        yield batch

queries_all_batch_iterator = get_batches(queries_all, batch_size, seed=None)
# codes_all_batch_iterator = get_batches(codes_all, batch_size, seed=None)

ranks = []

for idx, q in tqdm(enumerate(queries_all_batch_iterator), desc=f"Ranking", total=total):
    indices = np.expand_dims(np.arange(idx * batch_size, (idx + 1) * batch_size), 1)
    indices = np.tile(indices, 100)
    i, batch_ranks = index.search(q, 100)
    rank = np.argwhere(batch_ranks == indices)
    
    rank = rank[:, 1] + 1
    ranks.extend(rank)

batch_mean_mrr = np.mean(1.0 / np.array(ranks))
print("batch_mean_mrr", batch_mean_mrr)

Ranking: 100%|██████████| 1113/1113 [01:08<00:00, 16.20it/s]batch_mean_mrr 0.736877162463145



In [None]:

checkpoints-2gg8aig1-epoch-1;0.5015809076269152
checkpoints-2shag1q1-epoch-19;0.1958147080015603
checkpoints-2shag1q1-epoch-1;0.736877162463145

In [246]:
# ONE BY ONE

all_ranks = []
for idx, q in tqdm(enumerate(queries_all)):
    q = np.expand_dims(q, 0)
    i, d = index.search(q, 100)
    ranks = d[0]
    rank = np.argwhere(ranks == idx)
    if rank.size > 0:
        all_ranks.append(rank.item() + 1)
# print(all_ranks)

mean_mrr = np.mean(1.0 / np.array(all_ranks))
print("mean_mrr", mean_mrr)

35616it [23:09, 25.64it/s]mean_mrr 0.5015806329720339

