# Install requirements
Make sure you're using the local conda env for running this notebook. If is not created yet, create one with python 3.9 by running `conda create --name myenv python=3.9`

In [1]:
! pip install --upgrade pip



In [2]:
! pip install -r "../requirements.txt"



# Load dataset

In [3]:
from datasets import Dataset, load_dataset, load_from_disk, concatenate_datasets
dataset_name = "code_search_net"

def load_from_cs_net(take: int) -> Dataset:
  ds = load_dataset(dataset_name, 'python', split='train')
  return Dataset.from_dict(ds[:take]) # type: ignore

  from .autonotebook import tqdm as notebook_tqdm


# Embedding models

In [4]:
from sentence_transformers import SentenceTransformer

comment_model = SentenceTransformer('all-mpnet-base-v2')
code_model = SentenceTransformer('flax-sentence-embeddings/st-codesearch-distilroberta-base')
embedding_shape = (768)

# Generate negative samples

In [5]:
from typing import Iterator
from numpy.random import default_rng


random_generator = default_rng(seed=42)

def generate_negative_samples(iterator: Iterator, negative_samples_per_sample: int):
  for batched_sample in iterator:
    codes_embeddings = batched_sample['code_embedding']
    comments_embeddings = batched_sample['comment_embedding']
    batch_indexes = range(len(codes_embeddings))

    for index in batch_indexes:
      indexes = [i for i in batch_indexes if i != index]
      negative_indexes = random_generator.choice(indexes, negative_samples_per_sample, replace=False)

      yield {
        "code_embedding": codes_embeddings[index],
        "comment_embedding": comments_embeddings[index],
        "target": 1
      }

      for negative_index in negative_indexes:
        yield {
          "code_embedding": codes_embeddings[index],
          "comment_embedding": comments_embeddings[negative_index],
          "target": 0
        }

def with_neg_samples(dataset: Dataset, negative_samples_per_sample: int, batch_size = 100) -> Dataset:
  assert negative_samples_per_sample <= batch_size, "negative_samples_per_sample must not be greater than batch_size"
  if negative_samples_per_sample <= 0:
    return dataset
  
  dataset_with_negative_samples: Dataset = Dataset.from_generator(lambda: generate_negative_samples(dataset.iter(batch_size=batch_size), negative_samples_per_sample)) # type: ignore
  return dataset_with_negative_samples

# Generate embedding dataset

In [6]:
import os


train_count = 2000
train_dataset_path = f'../datasets/embeddings_python_train_{train_count}'
train_pairs = load_from_cs_net(train_count)
is_embeddings_dataset_stored = os.path.isdir(train_dataset_path)

def generate_embeddings_in_batch(batched_sample):
  codes = batched_sample['func_code_string']
  comments = batched_sample['func_documentation_string']

  return {
    "code_embedding": code_model.encode(codes),
    "comment_embedding": comment_model.encode(comments),
  }

embeddings_dataset: Dataset = Dataset.from_dict(load_from_disk(train_dataset_path)[:train_count]) if is_embeddings_dataset_stored else train_pairs.map(
  generate_embeddings_in_batch, 
  batched=True, 
  batch_size=100,
  remove_columns=list(train_pairs[0].keys()),
  desc="Generating embeddings"
) # type: ignore

if is_embeddings_dataset_stored == False:
  embeddings_dataset.save_to_disk(train_dataset_path)

# Train

In [None]:
epoch = 100
batch_size = 200

## Add negative samples to train dataset

In [None]:
def to_tf_dataset(negative_samples_per_sample: int):
  tf_train_dataset = with_neg_samples(embeddings_dataset.shuffle(), negative_samples_per_sample).to_tf_dataset().map(lambda sample: ({
    "code_embedding": sample["code_embedding"],
    "comment_embedding": sample["comment_embedding"],
  }, sample["target"]))
  
  return tf_train_dataset

## Fit

In [None]:
from keras import callbacks
from models import build_dense_model

neg_samples_count = [1, 5, 15]
num_hidden_layers = 4
for neg_count in neg_samples_count:
  model = build_dense_model(num_hidden_layers=num_hidden_layers, input_shape=embedding_shape, model_name=f'dense_{num_hidden_layers}_neg_{neg_count}')
  tf_train_dataset = to_tf_dataset(neg_count)
  tensor_board_callback = callbacks.TensorBoard(log_dir=f'../logs/{model.name}')

  model.fit(
    tf_train_dataset.batch(batch_size),
    batch_size=batch_size,
    epochs=epoch,
    callbacks=[tensor_board_callback]
  )
  model.save(f'../models/{model.name}')

# Validation

In [7]:
from typing import Optional
from tqdm import tqdm
from keras.models import load_model

## 1. CodeSearchNet queries

In [None]:
python_splits = load_dataset(dataset_name, 'python', split=['train', 'test', 'validation']) # type: ignore
python_full_dataset = concatenate_datasets(python_splits)
splits_info = python_splits[0].info.splits
python_full_dataset_count = sum([splits_info[key].num_examples for key in splits_info.keys()])

full_dataset_url_index = { sample['func_code_url']: index  for index, sample in tqdm(enumerate(python_full_dataset), desc="Generating dict lookup", total=python_full_dataset_count) }
def search_by_url(url: str) -> Optional[int]:
  try:
    return full_dataset_url_index[url]
  except:
    return None

In [None]:
query_samples_path = '../datasets/query_samples'

def remove_duplicates(dataset: Dataset) -> Dataset:
  pandas_dataset = dataset.to_pandas().drop_duplicates(subset=['Language', 'Query', 'GitHubUrl', 'Relevance'], ignore_index=True) # type: ignore
  dedup_dataset = Dataset.from_pandas(pandas_dataset)
  return dedup_dataset

def remove_queries_without_code(dataset: Dataset) -> Dataset:
  return dataset.filter(lambda sample: search_by_url(sample['GitHubUrl']) is not None, desc="Filtering queries with no corresponding code")

def pre_process_query_samples() -> Dataset:
  cs_net_queries_dataset: Dataset = Dataset.from_csv('../datasets/code_search_net_queries.csv') # type: ignore
  
  return remove_queries_without_code(remove_duplicates(cs_net_queries_dataset))

def get_query_samples() -> Dataset:
  try:
    return Dataset.load_from_disk(query_samples_path)
  except:
    query_samples = pre_process_query_samples()
    query_samples.save_to_disk(query_samples_path)
    return query_samples

In [None]:
query_samples: Dataset = get_query_samples()

### Predict

In [None]:
def get_query_code_embeddings(samples) -> Dataset:
  query_texts = [sample['Query'] for sample in samples]
  query_codes = [python_full_dataset[search_by_url(sample['GitHubUrl'])]['func_code_string'] for sample in samples]
  assert len(query_texts) == len(query_codes), "query_texts and query_codes arrays doesn't have the same length"

  query_embeddings = comment_model.encode(query_texts)
  code_embeddings = code_model.encode(query_codes)

  validation_dataset = []
  for query_embedding, code_embedding in zip(query_embeddings, code_embeddings):
    validation_dataset.append({
      "code_embedding": code_embedding,
      "comment_embedding": query_embedding,
    })

  return Dataset.from_list(validation_dataset)

In [None]:
from keras.models import load_model

def validate(model, samples):
  validation_dataset = get_query_code_embeddings(samples).to_tf_dataset(batch_size=10)

  return {
    "predictions": model.predict(validation_dataset, verbose=0).flatten(),
    "targets": [sample['Relevance'] for sample in samples]
  }


In [None]:
def is_prediction_correct(prediction, target) -> bool:
  if target in [0, 1]:
    return prediction <= 0.5
  
  if target in [2, 3]:
    return prediction > 0.5
  
  raise ValueError(f"target should be in range of [0, 3]. Instead, it has value of {target}")

In [None]:
validation_query_samples = [sample for sample in query_samples if sample['Language'].lower() == 'python']
validation_query_samples_count = len(validation_query_samples)

for model_name in os.listdir('../models/'):
  model = load_model(f'../models/{model_name}')
  result = validate(model, validation_query_samples)
  
  hits = sum([is_prediction_correct(prediction, target) for prediction, target in zip(result['predictions'], result['targets'])])
  success_percentage = hits / validation_query_samples_count

  print(f"model {model_name}: {success_percentage:.2%} - {hits} of {validation_query_samples_count}")

## 2. Generalization experiment

In [8]:
from typing import List


def search(query, model) -> List:
  query_embedding = comment_model.encode([query]).flatten()
  samples = Dataset.from_list([{ "code_embedding": embedding_pair["code_embedding"], 'comment_embedding': query_embedding } for embedding_pair in embeddings_dataset]).to_tf_dataset(batch_size=10)

  predictions = model.predict(samples).flatten()
  results = [(prediction, index) for index, prediction in enumerate(predictions)]

  return results

def get_code(index: int) -> str:
  return train_pairs[index]['func_code_string']

def top_k(k: int, results: List):
  return [get_code(index) for _, index in results[:k]]

def bottom_k(k: int, results: List):
  return [get_code(index) for _, index in results[-k:]]

In [27]:
results = search(
  query="aes encryption",
  model=load_model('../models/dense_4_neg_1/')
)





In [28]:
top_k(5, results)

['def addidsuffix(self, idsuffix, recursive = True):\n        """Appends a suffix to this element\'s ID, and optionally to all child IDs as well. There is sually no need to call this directly, invoked implicitly by :meth:`copy`"""\n        if self.id: self.id += idsuffix\n        if recursive:\n            for e in self:\n                try:\n                    e.addidsuffix(idsuffix, recursive)\n                except Exception:\n                    pass',
 'def setparents(self):\n        """Correct all parent relations for elements within the scop. There is sually no need to call this directly, invoked implicitly by :meth:`copy`"""\n        for c in self:\n            if isinstance(c, AbstractElement):\n                c.parent = self\n                c.setparents()',
 'def setdoc(self,newdoc):\n        """Set a different document. Usually no need to call this directly, invoked implicitly by :meth:`copy`"""\n        self.doc = newdoc\n        if self.doc and self.id:\n            s

In [29]:
bottom_k(5, results)

['def read_ipv4(self, length):\n        """Read Internet Protocol version 4 (IPv4).\n\n        Structure of IPv4 header [RFC 791]:\n\n             0                   1                   2                   3\n             0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1\n            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n            |Version|  IHL  |Type of Service|          Total Length         |\n            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n            |         Identification        |Flags|      Fragment Offset    |\n            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n            |  Time to Live |    Protocol   |         Header Checksum       |\n            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n            |                       Source Address                          |\n            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+\n 

In [16]:
import plotly.express as px
import pandas as pd

df = pd.DataFrame(dict(
    index = list(range(len(results))),
    prediction = [prediction for prediction, index in sorted(results, reverse=True)]
))
px.line(df, x='index', y='prediction', title="Results (descending)")