In [38]:
import logging
from datetime import datetime
from transformers import AutoTokenizer, TFAutoModel, logging as transformers_logging
import tensorflow as tf
import os
import numpy as np
from tqdm import tqdm
from typing import Dict, Literal, List
from keras import Model, Sequential, callbacks
from keras.layers import Dense, Input, Concatenate, Dot
from keras.losses import Loss
from keras.utils import losses_utils
from keras.metrics import BinaryAccuracy, Precision, Recall
from keras.optimizers import Adam
from mongo_db_client import MongoDbClient
from models import MongoDbPairDoc
import more_itertools
import random

In [2]:
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.get_logger().setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
transformers_logging.set_verbosity_error()

# Utils

In [3]:
random.seed(42)

def shuffle(list: List) -> List:
  shuffled_list = list.copy()
  random.shuffle(shuffled_list)
  return shuffled_list

## Generate embeddings from sentences

In [4]:
embedding_max_length = 256

def pre_process_tokens(tokens) -> str:
    parsed = ' '.join(tokens).replace('\n', ' ')
    parsed = ' '.join(parsed.strip().split())
    return parsed

# TODO: rename to 'generate_embeddings'
def get_embeddings(pairs: List[MongoDbPairDoc]):
  codes = [pre_process_tokens(pair['code_tokens']) for pair in pairs]
  comments = [pre_process_tokens(pair['comment_tokens']) for pair in pairs]

  return [generate_embeddings(codes), generate_embeddings(comments)]

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = tf.cast(tf.tile(tf.expand_dims(attention_mask, -1), [1, 1, token_embeddings.shape[-1]]), tf.float32)
    return tf.math.reduce_sum(token_embeddings * input_mask_expanded, 1) / tf.math.maximum(tf.math.reduce_sum(input_mask_expanded, 1), 1e-9)

# TODO: rename to generate_sentences_embeddings
def generate_embeddings(sentences: List[str]):
    tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")
    model = TFAutoModel.from_pretrained("bert-large-uncased")
    
    encoded_input = tokenizer(
        sentences, 
        padding='max_length', 
        max_length=embedding_max_length, 
        truncation=True, 
        return_tensors='tf',
    )
    model_output = model(**encoded_input, return_dict=True)

    embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    embeddings = tf.math.l2_normalize(embeddings, axis=1)
    return embeddings

# Create an embedding dataset

In [5]:
train_samples_count = 5000
test_samples_count = 1000
valid_samples_count = 1000
embedding_dataset_dir = '../datasets/embeddings/'
db_client = MongoDbClient()

In [6]:
# TODO: remove unused functions
def save_embeddings_dataset(pairs: List[MongoDbPairDoc], batch_size = 100):
  stored_pairs_ids = { pair_id.replace('.npy', ''): "" for pair_id in os.listdir(embedding_dataset_dir) if pair_id.endswith('.npy') }
  new_pairs = [pair for pair in pairs if str(pair['_id']) not in stored_pairs_ids]

  with tqdm(total=len(new_pairs), desc=f"Saving {len(new_pairs)} pairs into embedding dataset") as progress_bar:
    for batch_pairs in more_itertools.chunked(new_pairs, batch_size):
      [code_embeddings, comment_embeddings] = get_embeddings(batch_pairs)
      for pair, code_embedding, comment_embedding in zip(batch_pairs, code_embeddings, comment_embeddings):
        np.save(os.path.join(embedding_dataset_dir, f'{pair["_id"]}.npy'), [code_embedding, comment_embedding])
      progress_bar.update(len(batch_pairs))

def get_stored_embeddings(pair_id: str):
  return np.load(os.path.join(embedding_dataset_dir, f'{pair_id}.npy'))

def validate_embeddings_dataset(pairs: List[MongoDbPairDoc]):
  pairs_len = len(pairs)
  if pairs_len > 100:
    raise ValueError("The pairs length should be <= 100")

  random_index = random.randint(0, pairs_len - 1)
  [code_embeddings, comment_embeddings] = get_embeddings(pairs)
  [store_code_emb, store_comment_emb] = get_stored_embeddings(str(pairs[random_index]["_id"]))

  correct_indexes = []
  for index, (code_emb, comment_emb) in enumerate(zip(code_embeddings, comment_embeddings)):
    is_correct = np.array_equal(code_emb, store_code_emb) and np.array_equal(comment_emb, store_comment_emb)
    if is_correct:
      correct_indexes.append(index)
  
  return len(correct_indexes) == 1 and correct_indexes[0] == random_index

In [7]:
def create_tf_dataset(pairs_ids: List[str], for_training = True) -> tf.data.Dataset:
  def dataset_generator():
    if for_training:
      negative_pairs_ids = shuffle(pairs_ids)
      for pair_id, negative_pair_id in zip(pairs_ids, negative_pairs_ids):
        [code_embedding, comment_embedding] = get_stored_embeddings(pair_id)
        [_, negative_comment_embedding] = get_stored_embeddings(negative_pair_id)

        yield {
          "code_embedding": code_embedding,
          "comment_embedding": comment_embedding,
        }, 1.0

        yield {
          "code_embedding": code_embedding,
          "comment_embedding": negative_comment_embedding,
        }, 0.0
    else:
      for pair_id in pairs_ids:
        [code_embedding, comment_embedding] = get_stored_embeddings(pair_id)
        yield {
          "code_embedding": code_embedding,
          "comment_embedding": comment_embedding,
        }
  
  training_output_types = ({
    "code_embedding": tf.float32, 
    "comment_embedding": tf.float32,
  }, tf.float16)
  test_output_types = {
    "code_embedding": tf.float32, 
    "comment_embedding": tf.float32,
  }

  return tf.data.Dataset.from_generator(dataset_generator, output_types=training_output_types if for_training else test_output_types) 

In [8]:
# save_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "train", "language": "python" }).limit(train_samples_count)))
# save_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "test", "language": "python" }).limit(test_samples_count)))
# save_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "valid", "language": "python" }).limit(valid_samples_count)))

In [9]:
# is_train_correct = validate_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "train", "language": "python" }).limit(100)))
# is_test_correct = validate_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "test", "language": "python" }).limit(100)))
# is_valid_correct = validate_embeddings_dataset(list(db_client.get_pairs_collection().find({ "partition": "valid", "language": "python" }).limit(100)))

# print(f'is train dataset correct? {is_train_correct}') 
# print(f'is test dataset correct? {is_test_correct}') 
# print(f'is valid dataset correct? {is_valid_correct}') 

# Training the model

In [10]:
NumDenseLayers = Literal[2, 4, 8]
input_shape = (1024) # TODO: Use variables
hidden_layer_activation = 'tanh'
output_activation = 'sigmoid'
dense_layers: Dict[NumDenseLayers, List] = {
  2: [
    Dense(100, activation=hidden_layer_activation),
    Dense(50, activation=hidden_layer_activation),
  ],
  4: [
    Dense(400, activation=hidden_layer_activation),
    Dense(200, activation=hidden_layer_activation),
    Dense(100, activation=hidden_layer_activation),
    Dense(50, activation=hidden_layer_activation),
  ], 
  8: [
    Dense(800, activation=hidden_layer_activation),
    Dense(600, activation=hidden_layer_activation),
    Dense(500, activation=hidden_layer_activation),
    Dense(400, activation=hidden_layer_activation),
    Dense(300, activation=hidden_layer_activation),
    Dense(200, activation=hidden_layer_activation),
    Dense(100, activation=hidden_layer_activation),
    Dense(50, activation=hidden_layer_activation),
  ], 
}
dropout_rate=0.1

In [11]:
class ConstrastiveLoss(Loss):
   def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name="constrastive_loss", margin=1):
      self.margin = margin
      super().__init__(reduction, name)

   def call(self, y_true, y_pred):
      square_pred = tf.math.square(y_pred)
      margin_square = tf.math.square(tf.math.maximum(self.margin - (y_pred), 0))
      return tf.math.reduce_mean(
        (1 - y_true) * square_pred + (y_true) * margin_square
      )

def build_model(num_hidden_layers: NumDenseLayers):
  code_input = Input(
    shape=input_shape,
    name="code_embedding",
  )
  comment_input = Input(
    shape=input_shape,
    name="comment_embedding",
  )

  concatenated_inputs = Concatenate()([code_input, comment_input])
  hidden_layers = Sequential(dense_layers[num_hidden_layers], name="hidden_layers")(concatenated_inputs)
  output = Dense(1, activation=output_activation, name="output")(hidden_layers)
  model = Model(
    inputs=[code_input, comment_input],
    outputs=output,
    name="embedding_comparator"
  )

  model.compile(
    optimizer=Adam(),
    loss=ConstrastiveLoss(),
    metrics=[
      BinaryAccuracy(),
      Precision(name="precision"),
      Recall(name="recall"),
      # f1_score, # TODO: Reactivate
    ],
  )

  return model

In [12]:
embedding_comparator = build_model(num_hidden_layers=2)
tensor_board_callback = callbacks.TensorBoard(log_dir=f"../logs/scalars/{datetime.now().strftime('%Y%m%d-%H%M%S')}")

train_pairs = [str(pair['_id']) for pair in list(db_client.get_pairs_collection().find({ "partition": "train", "language": "python" }).limit(train_samples_count))]
valid_pairs = [str(pair['_id']) for pair in list(db_client.get_pairs_collection().find({ "partition": "valid", "language": "python" }).limit(valid_samples_count))]

pairs_dataset = create_tf_dataset(train_pairs).shuffle(buffer_size=int(train_samples_count * 0.4)).batch(100)
valid_pairs = create_tf_dataset(valid_pairs).batch(100)

# results = embedding_comparator.fit(
#     pairs_dataset,
#     validation_data=valid_pairs,
#     epochs=10,
#     callbacks=[tensor_board_callback],
# )

Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



In [41]:
def euclidean_distance(vects):
  [x, y] = vects
  sum_square = tf.math.reduce_sum(tf.math.square(x - y), axis=0)
  distance = tf.math.sqrt(tf.math.maximum(sum_square, tf.keras.backend.epsilon()))
  return distance

def build_siamese_model():
  code_embedding_input = Input(
    shape=input_shape,
    name="code_embedding",
  )
  comment_embedding_input = Input(
    shape=input_shape,
    name="comment_embedding",
  )
  similarity_score = Dot(normalize=True, axes=1)([code_embedding_input, comment_embedding_input])

  # normal_layer = BatchNormalization()(concatenated_inputs)
  output_layer = Dense(1, activation="sigmoid")(similarity_score)
  model = Model(inputs=[code_embedding_input, comment_embedding_input], outputs=output_layer, name="siamese_model")
  model.compile(
    optimizer=Adam(),
    loss=ConstrastiveLoss(),
    metrics=[
      BinaryAccuracy(),
      Precision(name="precision"),
      Recall(name="recall"),
      # f1_score, # TODO: Reactivate
    ],
  )
  return model

In [42]:
siamese_model = build_siamese_model()
siamese_model.fit(
    pairs_dataset,
    validation_data=valid_pairs,
    epochs=10,
    callbacks=[tensor_board_callback],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x298f58a60>

In [43]:
test_pairs = [str(pair['_id']) for pair in list(db_client.get_pairs_collection().find({ "partition": "train", "language": "python" }).limit(test_samples_count))]
test_dataset = create_tf_dataset(test_pairs, for_training=False).batch(100)

predicts = siamese_model.predict(test_dataset)



In [44]:
predicts

array([[0.49001256],
       [0.48823547],
       [0.49016502],
       [0.49716097],
       [0.5032835 ],
       [0.4932924 ],
       [0.5113383 ],
       [0.49203396],
       [0.49851727],
       [0.5054143 ],
       [0.49279276],
       [0.49300584],
       [0.49157286],
       [0.49411175],
       [0.49608177],
       [0.49472323],
       [0.5002973 ],
       [0.49251303],
       [0.49543485],
       [0.49366418],
       [0.4883237 ],
       [0.49596703],
       [0.49617872],
       [0.49846917],
       [0.49569196],
       [0.49192384],
       [0.49165168],
       [0.49066103],
       [0.49273133],
       [0.49222624],
       [0.4952268 ],
       [0.49444643],
       [0.5023585 ],
       [0.49322832],
       [0.50029457],
       [0.4948389 ],
       [0.49472323],
       [0.49787334],
       [0.49762088],
       [0.49765688],
       [0.49834386],
       [0.5005353 ],
       [0.4942247 ],
       [0.49851975],
       [0.49824095],
       [0.500186  ],
       [0.49653998],
       [0.492