In [2]:
import logging
from datetime import datetime
from transformers import AutoTokenizer, TFAutoModel, logging as transformers_logging
import tensorflow as tf
import os
import numpy as np
from tqdm import tqdm
from typing import Dict, Literal, Tuple, List, Iterator
from keras import Model, Sequential, callbacks
from keras.layers import Dense, Input, Concatenate
from keras.losses import Loss
from keras.utils import losses_utils
from keras.metrics import BinaryAccuracy, Precision, Recall
from keras.optimizers import Adam
from mongo_db_client import MongoDbClient
from models import MongoDbPairDoc, Partition
import more_itertools
import random

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Quiet TensorFlow / Transformers logging so cell output only shows real problems.
# '3' is the most restrictive TF_CPP_MIN_LOG_LEVEL value (filters C++ INFO, WARNING
# and ERROR messages). NOTE(review): tensorflow is already imported in the previous
# cell; confirm the variable still takes effect when it is set this late.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Python-side TensorFlow logger: ERROR and above only.
tf.get_logger().setLevel(logging.ERROR)
# AutoGraph log verbosity 0 (quietest).
tf.autograph.set_verbosity(0)
# Legacy tf.compat.v1 logger: ERROR and above only.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Hugging Face Transformers: only report errors.
transformers_logging.set_verbosity_error()

## Generate embeddings from sentences

In [4]:
# Number of tokenizer tokens per sentence fed to BERT: longer sentences are
# truncated and shorter ones padded up to this length (see generate_embeddings).
embedding_max_length = 256

def pre_process_tokens(tokens) -> str:
    """Join ``tokens`` into one line with exactly one space between words.

    Newlines and any other run of whitespace, including whitespace inside an
    individual token, collapse to a single space; leading and trailing
    whitespace is removed.
    """
    # str.split() with no argument already splits on '\n' and drops empty edges,
    # so no separate replace()/strip() pass is needed.
    words = ' '.join(tokens).split()
    return ' '.join(words)

# TODO: rename to 'generate_embeddings'
def get_embeddings(pairs: List[MongoDbPairDoc]):
    """Embed the code and the comment of every pair.

    Both token lists of each pair are normalised with ``pre_process_tokens`` and
    encoded with ``generate_embeddings``.

    Returns:
        ``[code_embeddings, comment_embeddings]``: one row per pair, in the
        order of ``pairs``.
    """
    code_texts, comment_texts = [], []
    for pair in pairs:
        code_texts.append(pre_process_tokens(pair['code_tokens']))
        comment_texts.append(pre_process_tokens(pair['comment_tokens']))

    code_embeddings = generate_embeddings(code_texts)
    comment_embeddings = generate_embeddings(comment_texts)
    return [code_embeddings, comment_embeddings]

def mean_pooling(model_output, attention_mask):
    """Average the last hidden states over the non-padding tokens of each sequence.

    Args:
        model_output: transformer output exposing ``last_hidden_state``; assumed to
            be ``(batch, seq_len, hidden)`` float32 as returned by ``TFAutoModel`` —
            TODO confirm if other models are plugged in.
        attention_mask: tokenizer attention mask of shape ``(batch, seq_len)``,
            1 for real tokens and 0 for padding.

    Returns:
        A ``(batch, hidden)`` float32 tensor with the mean embedding of the real
        tokens of each row (all zeros for a row with no real tokens).
    """
    token_embeddings = model_output.last_hidden_state
    # (batch, seq_len) -> (batch, seq_len, 1). Broadcasting against the hidden axis
    # gives exactly the values a tiled (batch, seq_len, hidden) mask would, without
    # materialising that copy or needing a statically known hidden size.
    mask = tf.cast(tf.expand_dims(attention_mask, -1), tf.float32)
    summed_embeddings = tf.math.reduce_sum(token_embeddings * mask, axis=1)
    # Real-token count per row, clamped so a fully padded row divides by 1e-9, not 0.
    token_counts = tf.math.maximum(tf.math.reduce_sum(mask, axis=1), 1e-9)
    return summed_embeddings / token_counts

# TODO: rename to generate_sentences_embeddings

# (tokenizer, model) per checkpoint name. Loading bert-large from disk is expensive,
# so it is done once per kernel session instead of on every generate_embeddings call.
_pretrained_cache = {}

def _load_pretrained(model_name: str):
    """Return the cached ``(tokenizer, model)`` for ``model_name``, loading it on first use."""
    if model_name not in _pretrained_cache:
        _pretrained_cache[model_name] = (
            AutoTokenizer.from_pretrained(model_name),
            TFAutoModel.from_pretrained(model_name),
        )
    return _pretrained_cache[model_name]

def generate_embeddings(sentences: List[str], model_name: str = "bert-large-uncased"):
    """Encode ``sentences`` into L2-normalised mean-pooled transformer embeddings.

    Every sentence is padded/truncated to ``embedding_max_length`` tokens, run
    through the pretrained model, mean-pooled over its non-padding tokens
    (``mean_pooling``) and scaled to unit L2 norm.

    Args:
        sentences: texts to encode, one embedding row per sentence.
        model_name: Hugging Face checkpoint used for both tokenizer and model.
            The tokenizer and model are loaded once and reused on later calls.

    Returns:
        A tensor with one normalised embedding row per sentence.
    """
    tokenizer, model = _load_pretrained(model_name)

    encoded_input = tokenizer(
        sentences,
        padding='max_length',
        max_length=embedding_max_length,
        truncation=True,
        return_tensors='tf',
    )
    model_output = model(**encoded_input, return_dict=True)

    embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    return tf.math.l2_normalize(embeddings, axis=1)

# Creating an embedding dataset

In [5]:
# Number of (code, comment) pairs embedded / trained on per partition.
train_samples_count = 5_000
test_samples_count = valid_samples_count = 1_000

# One `<pair _id>.npy` file per embedded pair lives in this directory.
embedding_dataset_dir = '../datasets/embeddings/'

# Shared MongoDB client used by the dataset-building cells below.
db_client = MongoDbClient()

In [10]:
# TODO: remove unused functions
def save_embeddings_dataset(pairs: List[MongoDbPairDoc], batch_size=100):
  """Embed and persist every pair that does not have a stored embedding file yet.

  Each pair is written to ``<embedding_dataset_dir>/<pair _id>.npy`` as a
  ``(2, dim)`` array: row 0 is the code embedding, row 1 the comment embedding,
  both produced by ``get_embeddings``. Pairs whose file already exists are
  skipped, so an interrupted run can simply be restarted.

  Args:
    pairs: pair documents; ``_id``, ``code_tokens`` and ``comment_tokens`` are read.
    batch_size: number of pairs embedded per forward pass.
  """
  # Create the output directory on first use instead of failing in os.listdir.
  os.makedirs(embedding_dataset_dir, exist_ok=True)

  suffix = '.npy'
  stored_pair_ids = {
    file_name[:-len(suffix)]
    for file_name in os.listdir(embedding_dataset_dir)
    if file_name.endswith(suffix)
  }
  new_pairs = [pair for pair in pairs if str(pair['_id']) not in stored_pair_ids]

  with tqdm(total=len(new_pairs), desc=f"Saving {len(new_pairs)} pairs into embedding dataset") as progress_bar:
    for batch_pairs in more_itertools.chunked(new_pairs, batch_size):
      code_embeddings, comment_embeddings = get_embeddings(batch_pairs)
      # Convert each batch tensor to NumPy once rather than once per row.
      code_embeddings = np.asarray(code_embeddings)
      comment_embeddings = np.asarray(comment_embeddings)
      for pair, code_embedding, comment_embedding in zip(batch_pairs, code_embeddings, comment_embeddings):
        np.save(
          os.path.join(embedding_dataset_dir, f'{pair["_id"]}.npy'),
          np.stack([code_embedding, comment_embedding]),
        )
      progress_bar.update(len(batch_pairs))

def get_stored_embeddings(pair_id: str):
  """Load the array saved for ``pair_id`` in ``embedding_dataset_dir``.

  Returns whatever ``save_embeddings_dataset`` stored for the pair: the code
  embedding in row 0 and the comment embedding in row 1.
  """
  file_path = os.path.join(embedding_dataset_dir, f'{pair_id}.npy')
  stored = np.load(file_path)
  return stored

def validate_embeddings_dataset(pairs: "List[MongoDbPairDoc]"):
  """Spot-check stored embeddings against freshly computed ones.

  All ``pairs`` are re-embedded with ``get_embeddings``. The stored embeddings of
  one randomly chosen pair are then loaded. The check passes only if that stored
  entry exactly equals the fresh embedding of the same pair and of no other pair.

  Args:
    pairs: between 1 and 100 pair documents (all are embedded in one forward pass).

  Returns:
    True if the randomly chosen stored entry matches exactly its own pair.

  Raises:
    ValueError: if ``pairs`` is empty or contains more than 100 pairs.
  """
  pairs_len = len(pairs)
  # Without this guard random.randint(0, -1) fails with an unrelated "empty range" error.
  if pairs_len == 0:
    raise ValueError("The pairs list should not be empty")
  if pairs_len > 100:
    raise ValueError("The pairs length should be <= 100")

  random_index = random.randint(0, pairs_len - 1)
  code_embeddings, comment_embeddings = get_embeddings(pairs)
  stored_code_emb, stored_comment_emb = get_stored_embeddings(str(pairs[random_index]["_id"]))

  matching_indexes = [
    index
    for index, (code_emb, comment_emb) in enumerate(zip(code_embeddings, comment_embeddings))
    if np.array_equal(code_emb, stored_code_emb) and np.array_equal(comment_emb, stored_comment_emb)
  ]
  # Exactly one match, and it must be the pair whose stored file was loaded.
  return matching_indexes == [random_index]

In [16]:
def create_tf_dataset(pairs: List[MongoDbPairDoc]) -> tf.data.Dataset:
  """Stream the stored embeddings of ``pairs`` as a ``tf.data.Dataset``.

  Files are read lazily through ``get_stored_embeddings`` while the dataset is
  iterated. Each element is a dict with:
    - ``"id"``: the pair id as a scalar string,
    - ``"code_embedding"`` / ``"comment_embedding"``: 1-D float32 vectors.
  """
  def gen():
    for pair in pairs:
      pair_id = str(pair["_id"])
      code_embedding, comment_embedding = get_stored_embeddings(pair_id)
      yield {
        "id": pair_id,
        "code_embedding": code_embedding,
        "comment_embedding": comment_embedding,
      }

  # `output_signature` replaces the deprecated `output_types` argument. It also gives
  # the elements a known rank instead of a completely unknown shape.
  return tf.data.Dataset.from_generator(gen, output_signature={
    "id": tf.TensorSpec(shape=(), dtype=tf.string),
    "code_embedding": tf.TensorSpec(shape=(None,), dtype=tf.float32),
    "comment_embedding": tf.TensorSpec(shape=(None,), dtype=tf.float32),
  })

## Writing embedding pairs

In [6]:
# save_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "train", "language": "python" }).limit(train_samples_count)))
# save_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "test", "language": "python" }).limit(test_samples_count)))
# save_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "valid", "language": "python" }).limit(valid_samples_count)))

In [7]:
# is_train_correct = validate_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "train", "language": "python" }).limit(100)))
# is_test_correct = validate_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "test", "language": "python" }).limit(100)))
# is_valid_correct = validate_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "valid", "language": "python" }).limit(100)))

# print(f'is train dataset correct? {is_train_correct}') 
# print(f'is test dataset correct? {is_test_correct}') 
# print(f'is valid dataset correct? {is_valid_correct}') 

# Training the model

In [None]:
NumDenseLayers = Literal[2, 4, 8]
input_shape = (1024,) # TODO: Use variables
hidden_layer_activation = 'tanh'
output_activation = 'sigmoid'
# Hidden Dense stack per supported depth, widest layer first. The layers are
# created in the same order as before (depth 2, then 4, then 8).
dense_layers: Dict[NumDenseLayers, List] = {
  depth: [Dense(units, activation=hidden_layer_activation) for units in units_per_layer]
  for depth, units_per_layer in (
    (2, (100, 50)),
    (4, (400, 200, 100, 50)),
    (8, (800, 600, 500, 400, 300, 200, 100, 50)),
  )
}
# NOTE(review): not referenced by build_model or any other visible cell.
dropout_rate=0.1

In [None]:
# The class and default loss name keep the original "Constrastive" spelling so
# existing references keep resolving.
class ConstrastiveLoss(Loss):
  """Squared contrastive-style loss on the comparator's score ``y_pred``.

  Per element:
    - ``y_true == 1``: ``max(margin - y_pred, 0) ** 2``, pushing the score up to ``margin``;
    - ``y_true == 0``: ``y_pred ** 2``, pushing the score down to 0.
  In this notebook ``generate_dataset`` labels matching (code, comment) rows 1 and
  shuffled negatives 0.
  """

  def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name="constrastive_loss", margin=1):
    super().__init__(reduction=reduction, name=name)
    self.margin = margin

  def call(self, y_true, y_pred):
    # Cast so integer / float64 labels also work when the loss is called directly.
    y_true = tf.cast(y_true, y_pred.dtype)
    square_pred = tf.math.square(y_pred)
    margin_square = tf.math.square(tf.math.maximum(self.margin - y_pred, 0))
    # Return one loss per sample (mean over the last axis only). The base class then
    # applies `reduction` and any `sample_weight`. The previous full reduce_mean to a
    # scalar bypassed both; unweighted results are unchanged.
    return tf.math.reduce_mean(
      (1 - y_true) * square_pred + y_true * margin_square,
      axis=-1,
    )

  def get_config(self):
    """Include ``margin`` so the loss can be rebuilt with ``from_config``."""
    config = super().get_config()
    config["margin"] = self.margin
    return config

def build_model(num_hidden_layers: NumDenseLayers):
  """Build and compile the code/comment embedding comparator.

  Architecture:
    1. The code and comment embeddings (each ``input_shape``) are concatenated.
    2. The result goes through a fresh copy of the hidden Dense stack configured
       in ``dense_layers[num_hidden_layers]``.
    3. A final Dense layer reduces it to one ``output_activation`` score per pair.

  The model is compiled with Adam, ``ConstrastiveLoss`` and binary accuracy /
  precision / recall metrics.

  Args:
    num_hidden_layers: key into ``dense_layers`` selecting the hidden stack depth.

  Returns:
    The compiled ``keras.Model`` named ``"embedding_comparator"``.
  """
  code_input = Input(shape=input_shape, name="code_embedding_input")
  comment_input = Input(shape=input_shape, name="comment_embedding_input")
  concatenated_inputs = Concatenate(axis=1)([code_input, comment_input])

  # Clone the template layers instead of inserting the shared instances. Keras
  # layers own their weights, so reusing the instances would make every model
  # built with the same depth share (and keep training) one set of weights.
  fresh_hidden_layers = [
    type(layer).from_config(layer.get_config())
    for layer in dense_layers[num_hidden_layers]
  ]
  hidden_layers = Sequential(fresh_hidden_layers, name="hidden_layers")(concatenated_inputs)
  output = Dense(1, activation=output_activation, name="output")(hidden_layers)
  model = Model(
    inputs=[code_input, comment_input],
    outputs=output,
    name="embedding_comparator"
  )

  model.compile(
    optimizer=Adam(),
    loss=ConstrastiveLoss(),
    metrics=[
      BinaryAccuracy(),
      Precision(name="precision"),
      Recall(name="recall"),
      # f1_score, # TODO: Reactivate
    ],
  )

  return model

In [None]:
def generate_negative_pairs(pairs: "List[MongoDbPairDoc]"):
  """Return ``pairs`` reordered so that no pair stays at its own position.

  ``generate_dataset`` combines the code of ``pairs[i]`` with the comment of the
  returned element ``i`` and labels that row 0 (non-matching). A plain shuffle
  leaves about one element per batch in place (the expected number of fixed
  points of a random permutation is 1), which produced matching rows labelled 0.

  Instead, the indices are shuffled into a random cycle and every pair is
  replaced by the pair that follows it in the cycle. For two or more pairs no
  element maps to itself. Inputs with fewer than two pairs cannot be deranged
  and are returned as a copy. The input list is not modified.
  """
  count = len(pairs)
  if count < 2:
    return list(pairs)

  cycle = list(range(count))
  random.shuffle(cycle)
  negative_pairs = [None] * count
  for position, index in enumerate(cycle):
    # Successor in the cycle; differs from `index` because cycle entries are distinct.
    negative_pairs[index] = pairs[cycle[(position + 1) % count]]
  return negative_pairs

def generate_dataset(samples_count = 1000, batch_size=100, partition: Partition = 'train') -> Iterator[Tuple[List[np.ndarray], np.ndarray]]:
  """Yield ``([code_embeddings, comment_embeddings], targets)`` training batches.

  Reads at most ``samples_count`` Python pairs of ``partition`` from MongoDB and
  walks them in chunks of ``batch_size`` pairs. A trailing chunk with fewer than
  ``batch_size`` pairs is skipped.

  Each chunk yields ``2 * batch_size`` rows. For every pair there are two rows:
    - its code with its own comment (target 1.0);
    - its code with the comment of the pair picked by ``generate_negative_pairs``
      (target 0.0).
  Embeddings come from ``get_stored_embeddings``.
  """
  def get_targets_array(size: int):
    # 1.0 on even rows (positive samples), 0.0 on odd rows (negative samples).
    return (np.arange(size) % 2 == 0).astype(np.float64)

  def load_embeddings(cache, pair):
    # Each stored file holds both embeddings of a pair. Negatives are usually drawn
    # from the same chunk, so caching avoids reading the same file twice.
    pair_id = str(pair['_id'])
    if pair_id not in cache:
      cache[pair_id] = get_stored_embeddings(pair_id)
    return cache[pair_id]

  pairs = list(db_client.get_pairs_collection().find({ "language": "python", "partition": partition }).limit(samples_count))

  for batch_pairs in more_itertools.chunked(pairs, batch_size):
    if len(batch_pairs) != batch_size:
      continue

    negative_pairs = generate_negative_pairs(batch_pairs)
    chunk_cache = {}
    code_embeddings, comment_embeddings = [], []

    for pair, negative_pair in zip(batch_pairs, negative_pairs):
      code_emb, comment_emb = load_embeddings(chunk_cache, pair)
      _, negative_comment_emb = load_embeddings(chunk_cache, negative_pair)
      # Positive row (target 1) followed by its negative row (target 0).
      code_embeddings += [code_emb, code_emb]
      comment_embeddings += [comment_emb, negative_comment_emb]

    stacked_code_embeddings = np.stack(code_embeddings)
    stacked_comment_embeddings = np.stack(comment_embeddings)
    yield [stacked_code_embeddings, stacked_comment_embeddings], get_targets_array(stacked_code_embeddings.shape[0])

In [None]:
# Smallest configuration: the 2-hidden-layer stack from `dense_layers`.
embedding_comparator = build_model(2)
embedding_comparator.summary()

In [None]:
tensor_board_callback = callbacks.TensorBoard(log_dir=f"../logs/scalars/{datetime.now().strftime('%Y%m%d-%H%M%S')}")

# Collect every training chunk first, then train with ONE `fit` call.
# The previous loop called `fit` once per chunk for 10 epochs each. That caused
# three problems:
#   - the model trained on each chunk in isolation, one after another;
#   - `shuffle` could only reorder rows inside a single batch;
#   - every call restarted the epoch counter logged to TensorBoard.
train_chunks = list(generate_dataset(samples_count=train_samples_count))
if not train_chunks:
    raise ValueError("generate_dataset yielded no full training batch; check train_samples_count and the stored embeddings")

train_code_embeddings = np.concatenate([inputs[0] for inputs, _ in train_chunks])
train_comment_embeddings = np.concatenate([inputs[1] for inputs, _ in train_chunks])
train_targets = np.concatenate([targets for _, targets in train_chunks])

embedding_comparator.fit(
    [train_code_embeddings, train_comment_embeddings],
    train_targets,
    shuffle=True,
    epochs=10,
    # Same number of rows per gradient step as before: one generated chunk.
    batch_size=train_chunks[0][0][0].shape[0],
    callbacks=[tensor_board_callback],
)