In [2]:
import logging
from transformers import AutoTokenizer, TFAutoModel, logging as transformers_logging
import tensorflow as tf
import os
import numpy as np
from tqdm import tqdm
from typing import Dict, Literal
from keras import Model, Sequential, backend as K
from keras.layers import Dense, Input, Concatenate, Dropout
from keras.losses import BinaryCrossentropy
from keras.metrics import BinaryAccuracy, Precision, Recall
from keras.optimizers import Adam
from typing import List
from mongo_db_client import MongoDbClient
from models import MongoDbPairDoc
import more_itertools
import random

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Reduce TensorFlow / transformers log output to errors so cell outputs stay readable.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is assigned here, after `import tensorflow` has
# already run in the import cell. TensorFlow's native logging is usually configured by
# setting this variable before the import — confirm that it still takes effect here.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.get_logger().setLevel(logging.ERROR)  # Python-side TensorFlow logger
tf.autograph.set_verbosity(0)  # AutoGraph verbosity level 0
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)  # TF1-compat logger
transformers_logging.set_verbosity_error()  # Hugging Face transformers logger

## Generate embeddings from sentences

In [4]:
# Token length every sentence is padded/truncated to before embedding
# (passed as `max_length` to the tokenizer in generate_embeddings).
embedding_max_length = 256

def mean_pooling(model_output, attention_mask):
    """Pool per-token hidden states into one vector per input, weighted by the mask.

    Args:
        model_output: Object exposing `last_hidden_state`; the tiling below treats it
            as a rank-3 tensor whose last axis is the hidden size.
        attention_mask: Rank-2 mask aligned with the first two axes of the hidden
            states (presumably 1 for real tokens and 0 for padding — TODO confirm).

    Returns:
        The mask-weighted sum over axis 1 divided by the mask total over axis 1.
        The divisor is floored at 1e-9 so an all-zero mask does not divide by zero.
    """
    hidden_states = model_output.last_hidden_state
    hidden_size = hidden_states.shape[-1]

    # Repeat the mask along a new trailing axis so it matches the hidden states.
    mask = tf.expand_dims(attention_mask, -1)
    mask = tf.tile(mask, [1, 1, hidden_size])
    mask = tf.cast(mask, tf.float32)

    weighted_sum = tf.math.reduce_sum(hidden_states * mask, 1)
    mask_total = tf.math.reduce_sum(mask, 1)
    return weighted_sum / tf.math.maximum(mask_total, 1e-9)

# Loaded (tokenizer, model) pairs keyed by checkpoint name. Loading a checkpoint is
# expensive, so each one is loaded once per kernel session and then reused.
_embedding_backends = {}


def _load_embedding_backend(model_name):
    """Return the `(tokenizer, model)` pair for `model_name`, loading it on first use."""
    backend = _embedding_backends.get(model_name)
    if backend is None:
        backend = (
            AutoTokenizer.from_pretrained(model_name),
            TFAutoModel.from_pretrained(model_name),
        )
        _embedding_backends[model_name] = backend
    return backend


def generate_embeddings(sentences, model_name="bert-large-uncased"):
    """Embed `sentences` with a pretrained transformer and return unit-length vectors.

    The tokenizer and model are loaded once per checkpoint and cached, instead of
    being re-created by `from_pretrained` on every call.

    Args:
        sentences: Text input forwarded unchanged to the tokenizer (callers in this
            notebook pass a list of strings).
        model_name: Checkpoint name handed to `from_pretrained`. Defaults to
            "bert-large-uncased", the checkpoint that was previously hard-coded.

    Returns:
        A tensor with one mean-pooled embedding per sentence, L2-normalized along axis 1.
    """
    tokenizer, model = _load_embedding_backend(model_name)

    # Every sentence is padded/truncated to exactly `embedding_max_length` tokens.
    encoded_input = tokenizer(
        sentences,
        padding='max_length',
        max_length=embedding_max_length,
        truncation=True,
        return_tensors='tf',
    )
    model_output = model(**encoded_input, return_dict=True)

    embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    return tf.math.l2_normalize(embeddings, axis=1)

In [5]:
def pre_process_tokens(tokens) -> str:
    """Join `tokens` into a single line with exactly one space between words.

    Newlines and any other whitespace runs inside or between tokens collapse to a
    single space, and leading/trailing whitespace is removed.
    """
    joined = ' '.join(tokens)
    # str.split() without arguments splits on any whitespace run (newlines included)
    # and ignores leading/trailing whitespace, so no separate replace/strip is needed.
    return ' '.join(joined.split())

In [6]:
# Shared MongoDbClient instance (from the project-local `mongo_db_client` module);
# later cells call `db_client.get_pairs_collection()` on it to query code/comment pairs.
db_client = MongoDbClient()

# Training phase

In [7]:
# Supported hidden-layer depths; each value is a key of `dense_layers` below.
NumDenseLayers = Literal[2, 4, 8]
input_shape = (1024,) # TODO: Use variables
hidden_layer_activation = 'relu'
output_activation = 'sigmoid'
dropout_rate = 0.25

# For every supported depth, the hidden Dense layers ordered from widest to narrowest.
# Layers are instantiated in the order 2 -> 4 -> 8, each list front to back.
dense_layers: Dict[NumDenseLayers, List] = {
  depth: [Dense(units, activation=hidden_layer_activation) for units in layer_widths]
  for depth, layer_widths in (
    (2, (100, 50)),
    (4, (400, 200, 100, 50)),
    (8, (800, 600, 500, 400, 300, 200, 100, 50)),
  )
}

In [8]:
def build_model(num_hidden_layers: NumDenseLayers):
  """Build and compile the binary classifier over a (code, comment) embedding pair.

  The two embedding inputs are concatenated, passed through dropout, a stack of
  hidden Dense layers and a single-unit output layer.

  The hidden layers are re-created from the configs of the templates stored in
  `dense_layers`. Wrapping the template instances directly would make every model
  built for the same depth share the same layer objects, and therefore the same
  weights.

  Args:
    num_hidden_layers: Key into `dense_layers` selecting the hidden stack (2, 4 or 8).

  Returns:
    A compiled Keras `Model` taking `[code_embedding, comment_embedding]` as inputs.

  Raises:
    KeyError: If `num_hidden_layers` is not a key of `dense_layers`.
  """
  code_input = Input(
    shape=input_shape,
    name="code_embedding_input",
  )
  comment_input = Input(
    shape=input_shape,
    name="comment_embedding_input",
  )
  concatenated_inputs = Concatenate(axis=1)([code_input, comment_input])
  dropout = Dropout(
    dropout_rate,
    name='dropout',
  )(concatenated_inputs)

  # Clone each template layer from its config: same hyper-parameters, fresh weights.
  fresh_hidden_layers = [
    type(template_layer).from_config(template_layer.get_config())
    for template_layer in dense_layers[num_hidden_layers]
  ]
  hidden_layers = Sequential(fresh_hidden_layers, name="hidden_layers")(dropout)
  output = Dense(1, activation=output_activation, name="output")(hidden_layers)
  model = Model(
    inputs=[code_input, comment_input],
    outputs=output,
  )

  # Decision threshold shared by all thresholded metrics below.
  threshold = 0.5

  model.compile(
    optimizer=Adam(),
    loss=BinaryCrossentropy(),
    metrics=[
      BinaryAccuracy(
        # NOTE(review): "Accurary" is a typo for "Accuracy"; kept as-is because the
        # metric name becomes a key in training logs/history that other code may read.
        name=f"Accurary (with threshold of {threshold})",
        threshold=threshold,
      ),
      Precision(thresholds=threshold),
      Recall(thresholds=threshold),
      # f1_score, # TODO: Reactivate
    ],
  )

  return model

In [9]:

def get_embeddings(pairs: List[MongoDbPairDoc]):
  """Embed the code and the comment of every pair.

  Each pair's `code_tokens` and `comment_tokens` are flattened to a string with
  `pre_process_tokens` and then embedded with `generate_embeddings`.

  Returns:
    A two-element list `[code_embeddings, comment_embeddings]`, each in the order
    of `pairs`.
  """
  code_texts = list(map(pre_process_tokens, (pair['code_tokens'] for pair in pairs)))
  comment_texts = list(map(pre_process_tokens, (pair['comment_tokens'] for pair in pairs)))

  code_embeddings = generate_embeddings(code_texts)
  comment_embeddings = generate_embeddings(comment_texts)
  return [code_embeddings, comment_embeddings]


In [10]:
# Build the classifier using the 8-hidden-layer configuration.
model = build_model(8)
# NOTE(review): `num_samples` is not referenced by any other code visible in this
# notebook (the commented-out cells below use `batch_size`) — confirm it is still needed.
num_samples = 100

Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



In [11]:
def get_random_targets(batch_size: int, dropout = 0.1):
  """Create a random float target vector of ones with some entries zeroed.

  One uniform sample in [0, 1) is drawn per entry from NumPy's global random state;
  entries whose sample is below `dropout` become 0.0, all others are 1.0.

  Args:
    batch_size: Number of targets to generate.
    dropout: Threshold below which a sample turns its target into 0.0.

  Returns:
    A float64 NumPy array of length `batch_size` containing only 0.0 and 1.0.
  """
  uniform_draws = np.random.random(batch_size)
  zeroed = uniform_draws < dropout
  return np.logical_not(zeroed).astype(np.float64)

In [12]:
# train_pairs = list(db_client.get_pairs_collection().find({ "partition": "train", "language": "python" }).limit(batch_size))
# model.fit(x=get_embeddings(train_pairs), y=get_random_targets(batch_size=batch_size), epochs=20)

In [13]:
# valid_pairs = list(db_client.get_pairs_collection().find({ "partition": "valid", "language": "python" }).limit(batch_size))
# model.predict(get_embeddings(valid_pairs), batch_size=batch_size)

# Creating an embedding dataset

In [14]:
# Maximum number of pairs fetched per partition when building the embedding dataset
# (each value is used as the `.limit(...)` of the corresponding query further down).
train_samples_count = 5000
test_samples_count = 1000
valid_samples_count = 1000

In [22]:
# Directory holding one `<pair _id>.npy` file per embedded pair. The path is relative,
# so it is resolved against the notebook's current working directory.
embedding_dataset_dir = '../datasets/embeddings/'

In [29]:
# '64aea2b37321b5c1dba81a3e'
# NOTE(review): the id above is presumably an example pair `_id` — TODO confirm or remove.
# Count the entries currently present in the embedding dataset directory.
len(os.listdir(embedding_dataset_dir))

7000

In [38]:
def save_embeddings_dataset(pairs: List[MongoDbPairDoc], batch_size = 100):
  """Embed `pairs` and store one `.npy` file per pair in `embedding_dataset_dir`.

  Pairs whose `<_id>.npy` file already exists are skipped, so the function can be
  re-run to resume an interrupted export. The dataset directory is created when
  it does not exist yet.

  Args:
    pairs: Pair documents providing `_id`, `code_tokens` and `comment_tokens`.
    batch_size: Number of pairs embedded per `get_embeddings` call.
  """
  file_suffix = '.npy'

  # os.listdir raises on a missing directory and np.save does not create one.
  os.makedirs(embedding_dataset_dir, exist_ok=True)

  # Strip only the trailing suffix; a set gives constant-time membership checks.
  stored_pair_ids = {
    file_name[:-len(file_suffix)]
    for file_name in os.listdir(embedding_dataset_dir)
    if file_name.endswith(file_suffix)
  }

  new_pairs = [pair for pair in pairs if str(pair['_id']) not in stored_pair_ids]

  for batch_pairs in more_itertools.chunked(new_pairs, batch_size):
    code_embeddings, comment_embeddings = get_embeddings(batch_pairs)
    for pair, code_embedding, comment_embedding in zip(batch_pairs, code_embeddings, comment_embeddings):
      # Each file stores the code embedding followed by the comment embedding.
      target_path = os.path.join(embedding_dataset_dir, f'{pair["_id"]}.npy')
      np.save(target_path, [code_embedding, comment_embedding])

def get_stored_embeddings(pair_id: str):
  """Load the array saved for `pair_id` from `<embedding_dataset_dir>/<pair_id>.npy`."""
  embedding_path = os.path.join(embedding_dataset_dir, f'{pair_id}.npy')
  return np.load(embedding_path)

def validate_embeddings_dataset(pairs: List[MongoDbPairDoc], max_pairs: int = 100):
  """Spot-check that a stored embedding matches a freshly computed one.

  Embeddings are recomputed for all `pairs`, one pair is picked at random, and its
  stored embeddings are compared (exact equality) against every recomputed pair.

  Args:
    pairs: Pair documents to recompute; must contain between 1 and `max_pairs` items.
    max_pairs: Upper bound on `len(pairs)`. Defaults to 100, the previous fixed limit.

  Returns:
    True when the stored embeddings match the recomputed embeddings of the picked
    pair and of no other pair, False otherwise.

  Raises:
    ValueError: If `pairs` is empty or longer than `max_pairs`.
  """
  pairs_len = len(pairs)
  if pairs_len > max_pairs:
    raise ValueError(f"The pairs length should be <= {max_pairs}")
  if pairs_len == 0:
    # Fail with a clear message instead of the opaque empty-range error from randint.
    raise ValueError("The pairs list should not be empty")

  random_index = random.randint(0, pairs_len - 1)
  code_embeddings, comment_embeddings = get_embeddings(pairs)
  store_code_emb, store_comment_emb = get_stored_embeddings(str(pairs[random_index]["_id"]))

  # Indexes of all recomputed pairs that are identical to the stored embeddings.
  matching_indexes = [
    index
    for index, (code_emb, comment_emb) in enumerate(zip(code_embeddings, comment_embeddings))
    if np.array_equal(code_emb, store_code_emb) and np.array_equal(comment_emb, store_comment_emb)
  ]

  # Exactly one match, at the sampled position.
  return matching_indexes == [random_index]

In [39]:
# Embed and store up to the configured number of python pairs for each partition,
# in the order train -> test -> valid.
for partition_name, sample_limit in (
  ("train", train_samples_count),
  ("test", test_samples_count),
  ("valid", valid_samples_count),
):
  partition_query = { "partition": partition_name, "language": "python" }
  partition_pairs = list(db_client.get_pairs_collection().find(partition_query).limit(sample_limit))
  save_embeddings_dataset(partition_pairs)

In [17]:
# Spot-check each partition on (at most) 100 python pairs,
# in the order train -> test -> valid.
validation_results = {
  partition_name: validate_embeddings_dataset(
    list(db_client.get_pairs_collection().find({ "partition": partition_name, "language": "python" }).limit(100))
  )
  for partition_name in ("train", "test", "valid")
}
is_train_correct = validation_results["train"]
is_test_correct = validation_results["test"]
is_valid_correct = validation_results["valid"]

print(f'is train dataset correct? {is_train_correct}')
print(f'is test dataset correct? {is_test_correct}')
print(f'is valid dataset correct? {is_valid_correct}')

is train dataset correct? True
is test dataset correct? True
is valid dataset correct? True
