In [119]:
import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModel
from tqdm.notebook import trange, tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from info_nce import InfoNCE
import pyarrow as pa

### Training functions

In [120]:
class CustomDataset(TensorDataset):

    def __init__(self, dataframe, tokenizer):
        self.tokenizer = tokenizer
        self.doc = dataframe["Comment"]
        self.code = dataframe["Code"]
        # self.targets = dataframe.labels
        self.max_len = 256

    def __len__(self):
        assert len(self.doc) == len(self.code)
        return len(self.doc)

    def __getitem__(self, index):
        doc = str(self.doc[index])
        doc = " ".join(doc.split())

        code = str(self.code[index])
        doc_inputs = self.tokenizer.encode_plus(
            doc,
            add_special_tokens=False,
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_token_type_ids=False
        )
        doc_ids = doc_inputs['input_ids']
        doc_mask = doc_inputs['attention_mask']

        code_inputs = self.tokenizer.encode_plus(
            code,
            add_special_tokens=False,
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_token_type_ids=False
        )

        code_ids = code_inputs['input_ids']
        code_mask = code_inputs['attention_mask']

        return {
            'doc_ids': torch.tensor(doc_ids, dtype=torch.long),
            'doc_mask': torch.tensor(doc_mask, dtype=torch.long),
            'code_ids': torch.tensor(code_ids, dtype=torch.long),
            'code_mask': torch.tensor(code_mask, dtype=torch.long),
        }

def train(model, optimizer, training_set, epochs):
    print("Start training")
    model.train()
    train_dataloader = DataLoader(training_set, batch_size=32, shuffle=True)
    for epoch in tqdm(range(1, epochs + 1)):

        all_losses = []   
        progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch}')
        
        for idx, batch in enumerate(progress_bar):

            if batch['code_ids'].size(0) < 32:
                continue

            query_id = batch['doc_ids'].to(device)#.unsqueeze(0)
            query_mask = batch['doc_mask'].to(device)#.unsqueeze(0)
            inputs = {'input_ids': query_id[0,:].unsqueeze(0), 'attention_mask': query_mask[0,:].unsqueeze(0)}
            
            query = model(**inputs)
            #print("query shape: ", query[0].shape)
            #query = query.last_hidden_state.squeeze(0).mean(dim=0)
            #print("query shape: ", query.shape)
            
            code_ids = batch['code_ids'].to(device)
            code_masks = batch['code_mask'].to(device)
            

            inputs = {'input_ids': code_ids[0,:].unsqueeze(0), 'attention_mask': code_masks[0,:].unsqueeze(0)}
            positive_code_key = model(**inputs)
            #positive_code_key = positive_code_key.last_hidden_state.squeeze(0).mean(dim=0)
            #print("p code", positive_code_key.shape)
            

            inputs = {'input_ids': code_ids[1:], 'attention_mask': code_masks[1:]}
            negative_code_keys = model(**inputs)
            #negative_code_keys = negative_code_keys.last_hidden_state.mean(dim=1)
            #print("n code", negative_code_keys.shape)
            
            #break
            
            loss = info_nce_loss(query, positive_code_key, negative_code_keys)
            
            loss.backward()
            
            all_losses.append(loss.to("cpu").detach().numpy())

            if (idx + 1) % 2 == 0:
                optimizer.step()
                optimizer.zero_grad()
                      
        train_mean_loss = np.mean(all_losses)
        print(f'Epoch {epoch} - Train-Loss: {train_mean_loss}')

def info_nce_loss(query, positive_key, negative_keys):
    query = torch.flatten(query).view(-1).unsqueeze(0)
    positive_key = torch.flatten(positive_key).view(-1).unsqueeze(0)
    negative_keys = torch.stack([torch.flatten(code_emb).view(-1) for code_emb in negative_keys])
    return InfoNCE(negative_mode='unpaired')(query, positive_key, negative_keys)

### Eval functions

In [121]:
def create_embs(test_set, model):

    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-110m-embedding", trust_remote_code=True)
    
    model.to(device)
    
    representations = []
    # Process each input separately
    for input_text in test_set:
        # Tokenize without padding
        encoded_input = tokenizer(input_text, padding=False, truncation=True, max_length=256, return_tensors="pt")
        
        # Move input to the same device as the model
        encoded_input = {k: v.to(model.device) for k, v in encoded_input.items()}
        
        # Get model output
        with torch.no_grad():
            #output = model.encoder(**encoded_input) # For CodeT5+ last hidden layer
            output = model(**encoded_input) # Standard forwardin
        
        # Add to representations list
        representations.append(output.cpu().numpy())

    return np.array(representations)

def predict_distances(doc_embs, code_embs):
    all_distances = []
    progress_bar = tqdm(range(len(doc_embs)), desc="Compute Distances")
    for idx in progress_bar:
        # Sample Query
        query = doc_embs[idx]
        query = np.expand_dims(query, axis=0)
        
        # Sample Positive Code
        positive_code = code_embs[idx]
        positive_code = np.expand_dims(positive_code, axis=0)
        
        # Calculate Cosine distance for positive code
        positive_distance = calculate_cosine_distance(query, positive_code)
        
        # Calculate distances for all negative codes
        negative_distances = []
        for neg_idx, negative_code in enumerate(code_embs):
            if neg_idx != idx:  # Exclude the positive pair
                negative_code = np.expand_dims(negative_code, axis=0)
                distance = calculate_cosine_distance(query, negative_code)
                negative_distances.append(distance)
        
        # Combine positive and negative distances
        distances = [positive_distance] + negative_distances
        all_distances.append(distances)

    return all_distances

def calculate_mrr_from_distances(distances_lists):
    ranks = []
    for batch_idx, predictions in enumerate(distances_lists):
        correct_score = predictions[0]
        scores = np.array([prediction for prediction in predictions])
        rank = np.sum(scores <= correct_score)
        ranks.append(rank)
    mean_mrr = np.mean(1.0 / np.array(ranks))

    return mean_mrr

def calculate_cosine_distance(code_key, query):
    code_key = np.squeeze(code_key)
    query = np.squeeze(query)
    
    cosine_similarity = (np.dot(query, code_key) /
                         (np.linalg.norm(query) * np.linalg.norm(code_key)))
    # Compute cosine distance
    return 1 - cosine_similarity

### Execute

In [122]:
ds = load_dataset("drndr/statcodesearch")

In [123]:
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-110m-embedding", trust_remote_code=True)
dataset = ds["test"]

# Split the dataset into train and test sets with an 80/20 ratio
train_data, test_data = train_test_split(
    dataset, 
    test_size=0.2,    # 20% for testing
    train_size=0.8,   # 80% for training
    random_state=42   # For reproducible results
)

training_set = CustomDataset(train_data, tokenizer)
test_set = CustomDataset(test_data, tokenizer)


In [134]:
model = AutoModel.from_pretrained("Salesforce/codet5p-110m-embedding", trust_remote_code=True)
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

CodeT5pEmbeddingModel(
  (shared): Embedding(32103, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32103, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dropout(

### Pre-trained eval

In [125]:
code_embs = create_embs(test_data["Code"], model)
doc_embs = create_embs(test_data["Comment"], model)
all_distances = predict_distances(doc_embs, code_embs)
mrr = calculate_mrr_from_distances(all_distances)
print(mrr)

Compute Distances:   0%|          | 0/214 [00:00<?, ?it/s]

0.5577294425478596


In [126]:
all_code_embs = create_embs(ds["test"]["Code"], model)
all_doc_embs = create_embs(ds["test"]["Comment"], model)
all_distances = predict_distances(all_doc_embs, all_code_embs)
mrr = calculate_mrr_from_distances(all_distances)
print(mrr)

Compute Distances:   0%|          | 0/1070 [00:00<?, ?it/s]

0.410904967148927


### Fine-tuned eval

In [135]:
train(model, optimizer, training_set, 100)

Start training


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 1 - Train-Loss: 2.9733657836914062


Epoch 2:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 2 - Train-Loss: 3.0439810752868652


Epoch 3:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 3 - Train-Loss: 2.815943479537964


Epoch 4:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 4 - Train-Loss: 3.0919394493103027


Epoch 5:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 5 - Train-Loss: 2.8836183547973633


Epoch 6:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 6 - Train-Loss: 2.120537757873535


Epoch 7:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 7 - Train-Loss: 2.2722861766815186


Epoch 8:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 8 - Train-Loss: 2.326578378677368


Epoch 9:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 9 - Train-Loss: 1.8428643941879272


Epoch 10:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 10 - Train-Loss: 1.503125548362732


Epoch 11:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 11 - Train-Loss: 1.6889506578445435


Epoch 12:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 12 - Train-Loss: 1.8915958404541016


Epoch 13:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 13 - Train-Loss: 2.0917539596557617


Epoch 14:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 14 - Train-Loss: 1.4211101531982422


Epoch 15:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 15 - Train-Loss: 1.5350956916809082


Epoch 16:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 16 - Train-Loss: 1.5758650302886963


Epoch 17:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 17 - Train-Loss: 1.4516236782073975


Epoch 18:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 18 - Train-Loss: 1.7721606492996216


Epoch 19:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 19 - Train-Loss: 1.234879970550537


Epoch 20:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 20 - Train-Loss: 1.145817518234253


Epoch 21:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 21 - Train-Loss: 1.4163014888763428


Epoch 22:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 22 - Train-Loss: 1.183079719543457


Epoch 23:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 23 - Train-Loss: 1.4572244882583618


Epoch 24:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 24 - Train-Loss: 1.0373612642288208


Epoch 25:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 25 - Train-Loss: 1.3805421590805054


Epoch 26:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 26 - Train-Loss: 1.1584335565567017


Epoch 27:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 27 - Train-Loss: 1.4361605644226074


Epoch 28:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 28 - Train-Loss: 1.4778696298599243


Epoch 29:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 29 - Train-Loss: 0.8979343771934509


Epoch 30:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 30 - Train-Loss: 0.8618685007095337


Epoch 31:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 31 - Train-Loss: 0.9631727933883667


Epoch 32:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 32 - Train-Loss: 0.943940281867981


Epoch 33:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 33 - Train-Loss: 0.6244890689849854


Epoch 34:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 34 - Train-Loss: 1.1094568967819214


Epoch 35:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 35 - Train-Loss: 0.6984044313430786


Epoch 36:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 36 - Train-Loss: 0.8112542033195496


Epoch 37:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 37 - Train-Loss: 1.1778550148010254


Epoch 38:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 38 - Train-Loss: 0.9326366186141968


Epoch 39:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 39 - Train-Loss: 0.8686723113059998


Epoch 40:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 40 - Train-Loss: 0.93643718957901


Epoch 41:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 41 - Train-Loss: 0.6750185489654541


Epoch 42:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 42 - Train-Loss: 0.8052972555160522


Epoch 43:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 43 - Train-Loss: 0.7559515833854675


Epoch 44:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 44 - Train-Loss: 0.6679538488388062


Epoch 45:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 45 - Train-Loss: 0.953650176525116


Epoch 46:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 46 - Train-Loss: 0.7961544394493103


Epoch 47:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 47 - Train-Loss: 0.6182606816291809


Epoch 48:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 48 - Train-Loss: 0.9303345084190369


Epoch 49:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 49 - Train-Loss: 0.8115555644035339


Epoch 50:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 50 - Train-Loss: 0.7102044820785522


Epoch 51:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 51 - Train-Loss: 0.8598788380622864


Epoch 52:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 52 - Train-Loss: 0.4025196433067322


Epoch 53:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 53 - Train-Loss: 0.582079291343689


Epoch 54:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 54 - Train-Loss: 0.3322986960411072


Epoch 55:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 55 - Train-Loss: 0.6144028902053833


Epoch 56:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 56 - Train-Loss: 0.7335962057113647


Epoch 57:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 57 - Train-Loss: 0.34839102625846863


Epoch 58:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 58 - Train-Loss: 0.5544589161872864


Epoch 59:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 59 - Train-Loss: 0.7332513332366943


Epoch 60:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 60 - Train-Loss: 0.4195105731487274


Epoch 61:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 61 - Train-Loss: 0.5531988739967346


Epoch 62:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 62 - Train-Loss: 0.7552397847175598


Epoch 63:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 63 - Train-Loss: 0.4143303334712982


Epoch 64:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 64 - Train-Loss: 0.451160728931427


Epoch 65:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 65 - Train-Loss: 0.585671067237854


Epoch 66:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 66 - Train-Loss: 0.4660654664039612


Epoch 67:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 67 - Train-Loss: 0.6509288549423218


Epoch 68:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 68 - Train-Loss: 0.43307268619537354


Epoch 69:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 69 - Train-Loss: 0.3179006576538086


Epoch 70:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 70 - Train-Loss: 0.489341676235199


Epoch 71:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 71 - Train-Loss: 0.6692525148391724


Epoch 72:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 72 - Train-Loss: 0.40133577585220337


Epoch 73:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 73 - Train-Loss: 0.46304255723953247


Epoch 74:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 74 - Train-Loss: 0.3226430118083954


Epoch 75:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 75 - Train-Loss: 0.41696852445602417


Epoch 76:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 76 - Train-Loss: 0.47212931513786316


Epoch 77:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 77 - Train-Loss: 0.21958932280540466


Epoch 78:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 78 - Train-Loss: 0.293622761964798


Epoch 79:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 79 - Train-Loss: 0.5317126512527466


Epoch 80:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 80 - Train-Loss: 0.2843141257762909


Epoch 81:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 81 - Train-Loss: 0.12811098992824554


Epoch 82:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 82 - Train-Loss: 0.2430417686700821


Epoch 83:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 83 - Train-Loss: 0.19585858285427094


Epoch 84:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 84 - Train-Loss: 0.14572674036026


Epoch 85:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 85 - Train-Loss: 0.27383679151535034


Epoch 86:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 86 - Train-Loss: 0.28062567114830017


Epoch 87:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 87 - Train-Loss: 0.20561109483242035


Epoch 88:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 88 - Train-Loss: 0.32412979006767273


Epoch 89:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 89 - Train-Loss: 0.3823480010032654


Epoch 90:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 90 - Train-Loss: 0.2876257300376892


Epoch 91:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 91 - Train-Loss: 0.1914314329624176


Epoch 92:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 92 - Train-Loss: 0.22656740248203278


Epoch 93:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 93 - Train-Loss: 0.2977907359600067


Epoch 94:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 94 - Train-Loss: 0.2719506025314331


Epoch 95:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 95 - Train-Loss: 0.19415174424648285


Epoch 96:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 96 - Train-Loss: 0.2660572826862335


Epoch 97:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 97 - Train-Loss: 0.38374966382980347


Epoch 98:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 98 - Train-Loss: 0.14534829556941986


Epoch 99:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 99 - Train-Loss: 0.19305379688739777


Epoch 100:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 100 - Train-Loss: 0.1766578108072281


In [136]:
code_embs = create_embs(test_data["Code"], model)
doc_embs = create_embs(test_data["Comment"], model)
all_distances = predict_distances(doc_embs, code_embs)
mrr = calculate_mrr_from_distances(all_distances)
print(mrr)

Compute Distances:   0%|          | 0/214 [00:00<?, ?it/s]

0.5913581378532725


In [137]:
all_code_embs = create_embs(ds["test"]["Code"], model)
all_doc_embs = create_embs(ds["test"]["Comment"], model)
all_distances = predict_distances(all_doc_embs, all_code_embs)
mrr = calculate_mrr_from_distances(all_distances)
print(mrr)

Compute Distances:   0%|          | 0/1070 [00:00<?, ?it/s]

0.7471818138181374


In [138]:
df = pd.DataFrame(np.squeeze(all_code_embs, axis=1))
# Create Hugging Face Dataset
hf_dataset = Dataset.from_pandas(df)
# Save to disk
hf_dataset.save_to_disk("embeddings_code")

Saving the dataset (0/1 shards):   0%|          | 0/1070 [00:00<?, ? examples/s]

In [139]:
df = pd.DataFrame(np.squeeze(all_doc_embs, axis=1))
# Create Hugging Face Dataset
hf_dataset = Dataset.from_pandas(df)
# Save to disk
hf_dataset.save_to_disk("embeddings_doc")

Saving the dataset (0/1 shards):   0%|          | 0/1070 [00:00<?, ? examples/s]

In [140]:
torch.save(model.state_dict(), 'fine-tuned_codet5p.pth')