In [113]:
import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModel
from tqdm.notebook import trange, tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from info_nce import InfoNCE
import pyarrow as pa

### Training functions

In [85]:
class CustomDataset(TensorDataset):

    def __init__(self, dataframe, tokenizer):
        self.tokenizer = tokenizer
        self.doc = dataframe["Comment"]
        self.code = dataframe["Code"]
        # self.targets = dataframe.labels
        self.max_len = 256

    def __len__(self):
        assert len(self.doc) == len(self.code)
        return len(self.doc)

    def __getitem__(self, index):
        doc = str(self.doc[index])
        doc = " ".join(doc.split())

        code = str(self.code[index])
        doc_inputs = self.tokenizer.encode_plus(
            doc,
            add_special_tokens=False,
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_token_type_ids=False
        )
        doc_ids = doc_inputs['input_ids']
        doc_mask = doc_inputs['attention_mask']

        code_inputs = self.tokenizer.encode_plus(
            code,
            add_special_tokens=False,
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_token_type_ids=False
        )

        code_ids = code_inputs['input_ids']
        code_mask = code_inputs['attention_mask']

        return {
            'doc_ids': torch.tensor(doc_ids, dtype=torch.long),
            'doc_mask': torch.tensor(doc_mask, dtype=torch.long),
            'code_ids': torch.tensor(code_ids, dtype=torch.long),
            'code_mask': torch.tensor(code_mask, dtype=torch.long),
        }

def train(model, optimizer, training_set, epochs):
    print("Start training")
    model.train()
    train_dataloader = DataLoader(training_set, batch_size=32, shuffle=True)
    for epoch in tqdm(range(1, epochs + 1)):

        all_losses = []   
        progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch}')
        
        for idx, batch in enumerate(progress_bar):

            if batch['code_ids'].size(0) < 32:
                continue

            query_id = batch['doc_ids'].to(device)#.unsqueeze(0)
            query_mask = batch['doc_mask'].to(device)#.unsqueeze(0)
            inputs = {'input_ids': query_id[0,:].unsqueeze(0), 'attention_mask': query_mask[0,:].unsqueeze(0)}
            
            query = model(**inputs)
            #print("query shape: ", query[0].shape)
            #query = query.last_hidden_state.squeeze(0).mean(dim=0)
            #print("query shape: ", query.shape)
            
            code_ids = batch['code_ids'].to(device)
            code_masks = batch['code_mask'].to(device)
            

            inputs = {'input_ids': code_ids[0,:].unsqueeze(0), 'attention_mask': code_masks[0,:].unsqueeze(0)}
            positive_code_key = model(**inputs)
            #positive_code_key = positive_code_key.last_hidden_state.squeeze(0).mean(dim=0)
            #print("p code", positive_code_key.shape)
            

            inputs = {'input_ids': code_ids[1:], 'attention_mask': code_masks[1:]}
            negative_code_keys = model(**inputs)
            #negative_code_keys = negative_code_keys.last_hidden_state.mean(dim=1)
            #print("n code", negative_code_keys.shape)
            
            #break
            
            loss = info_nce_loss(query, positive_code_key, negative_code_keys)
            
            loss.backward()
            
            all_losses.append(loss.to("cpu").detach().numpy())

            if (idx + 1) % 2 == 0:
                optimizer.step()
                optimizer.zero_grad()
                      
        train_mean_loss = np.mean(all_losses)
        print(f'Epoch {epoch} - Train-Loss: {train_mean_loss}')

def info_nce_loss(query, positive_key, negative_keys):
    query = torch.flatten(query).view(-1).unsqueeze(0)
    positive_key = torch.flatten(positive_key).view(-1).unsqueeze(0)
    negative_keys = torch.stack([torch.flatten(code_emb).view(-1) for code_emb in negative_keys])
    return InfoNCE(negative_mode='unpaired')(query, positive_key, negative_keys)

### Eval functions

In [86]:
def create_embs(test_set, model):

    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-110m-embedding", trust_remote_code=True)
    
    model.to(device)
    
    representations = []
    # Process each input separately
    for input_text in test_set:
        # Tokenize without padding
        encoded_input = tokenizer(input_text, padding=False, truncation=True, max_length=256, return_tensors="pt")
        
        # Move input to the same device as the model
        encoded_input = {k: v.to(model.device) for k, v in encoded_input.items()}
        
        # Get model output
        with torch.no_grad():
            #output = model.encoder(**encoded_input) # For CodeT5+ last hidden layer
            output = model(**encoded_input) # Standard forwardin
        
        # Add to representations list
        representations.append(output.cpu().numpy())

    return np.array(representations)

def predict_distances(doc_embs, code_embs):
    all_distances = []
    progress_bar = tqdm(range(len(doc_embs)), desc="Compute Distances")
    for idx in progress_bar:
        # Sample Query
        query = doc_embs[idx]
        query = np.expand_dims(query, axis=0)
        
        # Sample Positive Code
        positive_code = code_embs[idx]
        positive_code = np.expand_dims(positive_code, axis=0)
        
        # Calculate Cosine distance for positive code
        positive_distance = calculate_cosine_distance(query, positive_code)
        
        # Calculate distances for all negative codes
        negative_distances = []
        for neg_idx, negative_code in enumerate(code_embs):
            if neg_idx != idx:  # Exclude the positive pair
                negative_code = np.expand_dims(negative_code, axis=0)
                distance = calculate_cosine_distance(query, negative_code)
                negative_distances.append(distance)
        
        # Combine positive and negative distances
        distances = [positive_distance] + negative_distances
        all_distances.append(distances)

    return all_distances

def calculate_mrr_from_distances(distances_lists):
    ranks = []
    for batch_idx, predictions in enumerate(distances_lists):
        correct_score = predictions[0]
        scores = np.array([prediction for prediction in predictions])
        rank = np.sum(scores <= correct_score)
        ranks.append(rank)
    mean_mrr = np.mean(1.0 / np.array(ranks))

    return mean_mrr

def calculate_cosine_distance(code_key, query):
    code_key = np.squeeze(code_key)
    query = np.squeeze(query)
    
    cosine_similarity = (np.dot(query, code_key) /
                         (np.linalg.norm(query) * np.linalg.norm(code_key)))
    # Compute cosine distance
    return 1 - cosine_similarity

### Execute

In [87]:
ds = load_dataset("drndr/statcodesearch")

In [88]:
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-110m-embedding", trust_remote_code=True)
dataset = ds["test"]

# Split the dataset into train and test sets with an 80/20 ratio
train_data, test_data = train_test_split(
    dataset, 
    test_size=0.2,    # 20% for testing
    train_size=0.8,   # 80% for training
    random_state=42   # For reproducible results
)

training_set = CustomDataset(train_data, tokenizer)
test_set = CustomDataset(test_data, tokenizer)


In [96]:
model = AutoModel.from_pretrained("Salesforce/codet5p-110m-embedding", trust_remote_code=True)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

CodeT5pEmbeddingModel(
  (shared): Embedding(32103, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32103, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dropout(

### Pre-trained eval

In [90]:
code_embs = create_embs(test_data["Code"], model)
doc_embs = create_embs(test_data["Comment"], model)
all_distances = predict_distances(doc_embs, code_embs)
mrr = calculate_mrr_from_distances(all_distances)
print(mrr)

Compute Distances:   0%|          | 0/214 [00:00<?, ?it/s]

0.5577294425478596


In [None]:
all_code_embs = create_embs(ds["test"]["Code"], model)
all_doc_embs = create_embs(ds["test"]["Comment"], model)
all_distances = predict_distances(all_doc_embs, all_code_embs)
mrr = calculate_mrr_from_distances(all_distances)
print(mrr)

### Fine-tuned eval

In [97]:
train(model, optimizer, training_set, 100)

Start training


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 1 - Train-Loss: 3.333803653717041


Epoch 2:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 2 - Train-Loss: 2.9805564880371094


Epoch 3:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 3 - Train-Loss: 3.1251726150512695


Epoch 4:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 4 - Train-Loss: 3.1930811405181885


Epoch 5:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 5 - Train-Loss: 3.0150060653686523


Epoch 6:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 6 - Train-Loss: 2.9117672443389893


Epoch 7:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 7 - Train-Loss: 2.7590246200561523


Epoch 8:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 8 - Train-Loss: 2.8775782585144043


Epoch 9:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 9 - Train-Loss: 3.0121967792510986


Epoch 10:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 10 - Train-Loss: 3.2797117233276367


Epoch 11:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 11 - Train-Loss: 2.819932460784912


Epoch 12:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 12 - Train-Loss: 2.345508098602295


Epoch 13:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 13 - Train-Loss: 3.0075106620788574


Epoch 14:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 14 - Train-Loss: 2.7716331481933594


Epoch 15:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 15 - Train-Loss: 2.586489200592041


Epoch 16:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 16 - Train-Loss: 2.3833322525024414


Epoch 17:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 17 - Train-Loss: 2.887766122817993


Epoch 18:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 18 - Train-Loss: 2.337655782699585


Epoch 19:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 19 - Train-Loss: 2.595944881439209


Epoch 20:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 20 - Train-Loss: 2.5058224201202393


Epoch 21:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 21 - Train-Loss: 2.2859244346618652


Epoch 22:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 22 - Train-Loss: 2.114027261734009


Epoch 23:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 23 - Train-Loss: 2.2905092239379883


Epoch 24:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 24 - Train-Loss: 2.162140130996704


Epoch 25:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 25 - Train-Loss: 2.1588480472564697


Epoch 26:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 26 - Train-Loss: 2.146881580352783


Epoch 27:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 27 - Train-Loss: 1.8629086017608643


Epoch 28:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 28 - Train-Loss: 2.188722610473633


Epoch 29:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 29 - Train-Loss: 1.7868626117706299


Epoch 30:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 30 - Train-Loss: 2.1986958980560303


Epoch 31:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 31 - Train-Loss: 2.0080010890960693


Epoch 32:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 32 - Train-Loss: 1.8645952939987183


Epoch 33:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 33 - Train-Loss: 2.047682285308838


Epoch 34:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 34 - Train-Loss: 1.6535433530807495


Epoch 35:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 35 - Train-Loss: 1.5718255043029785


Epoch 36:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 36 - Train-Loss: 2.035331964492798


Epoch 37:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 37 - Train-Loss: 1.6929466724395752


Epoch 38:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 38 - Train-Loss: 1.5265356302261353


Epoch 39:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 39 - Train-Loss: 1.6260502338409424


Epoch 40:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 40 - Train-Loss: 1.7361563444137573


Epoch 41:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 41 - Train-Loss: 2.051568031311035


Epoch 42:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 42 - Train-Loss: 1.8865412473678589


Epoch 43:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 43 - Train-Loss: 1.6909018754959106


Epoch 44:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 44 - Train-Loss: 2.1260745525360107


Epoch 45:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 45 - Train-Loss: 1.6366995573043823


Epoch 46:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 46 - Train-Loss: 1.3810169696807861


Epoch 47:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 47 - Train-Loss: 1.6005213260650635


Epoch 48:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 48 - Train-Loss: 1.4099675416946411


Epoch 49:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 49 - Train-Loss: 1.4107741117477417


Epoch 50:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 50 - Train-Loss: 1.4307223558425903


Epoch 51:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 51 - Train-Loss: 1.0825998783111572


Epoch 52:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 52 - Train-Loss: 1.6597869396209717


Epoch 53:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 53 - Train-Loss: 1.4671465158462524


Epoch 54:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 54 - Train-Loss: 1.5613152980804443


Epoch 55:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 55 - Train-Loss: 1.5277483463287354


Epoch 56:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 56 - Train-Loss: 1.3590056896209717


Epoch 57:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 57 - Train-Loss: 1.4076054096221924


Epoch 58:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 58 - Train-Loss: 1.4017257690429688


Epoch 59:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 59 - Train-Loss: 1.5180257558822632


Epoch 60:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 60 - Train-Loss: 1.3366392850875854


Epoch 61:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 61 - Train-Loss: 1.4837368726730347


Epoch 62:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 62 - Train-Loss: 1.708996295928955


Epoch 63:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 63 - Train-Loss: 0.9624516367912292


Epoch 64:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 64 - Train-Loss: 1.1495753526687622


Epoch 65:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 65 - Train-Loss: 1.4227619171142578


Epoch 66:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 66 - Train-Loss: 1.0918184518814087


Epoch 67:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 67 - Train-Loss: 1.2618380784988403


Epoch 68:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 68 - Train-Loss: 1.1937867403030396


Epoch 69:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 69 - Train-Loss: 1.1057435274124146


Epoch 70:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 70 - Train-Loss: 1.3719518184661865


Epoch 71:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 71 - Train-Loss: 0.9836402535438538


Epoch 72:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 72 - Train-Loss: 1.107199788093567


Epoch 73:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 73 - Train-Loss: 1.0933256149291992


Epoch 74:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 74 - Train-Loss: 0.9569829702377319


Epoch 75:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 75 - Train-Loss: 0.8633266687393188


Epoch 76:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 76 - Train-Loss: 1.0278910398483276


Epoch 77:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 77 - Train-Loss: 1.020966649055481


Epoch 78:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 78 - Train-Loss: 0.8389298915863037


Epoch 79:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 79 - Train-Loss: 0.8567149639129639


Epoch 80:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 80 - Train-Loss: 1.0628442764282227


Epoch 81:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 81 - Train-Loss: 1.2424830198287964


Epoch 82:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 82 - Train-Loss: 1.0779798030853271


Epoch 83:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 83 - Train-Loss: 1.1212207078933716


Epoch 84:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 84 - Train-Loss: 0.9049754738807678


Epoch 85:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 85 - Train-Loss: 0.9595184922218323


Epoch 86:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 86 - Train-Loss: 0.808447003364563


Epoch 87:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 87 - Train-Loss: 0.952498197555542


Epoch 88:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 88 - Train-Loss: 1.2326960563659668


Epoch 89:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 89 - Train-Loss: 0.8406128287315369


Epoch 90:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 90 - Train-Loss: 1.0253373384475708


Epoch 91:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 91 - Train-Loss: 0.9775577187538147


Epoch 92:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 92 - Train-Loss: 0.7623006105422974


Epoch 93:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 93 - Train-Loss: 0.685972273349762


Epoch 94:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 94 - Train-Loss: 0.7220814824104309


Epoch 95:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 95 - Train-Loss: 0.7484797239303589


Epoch 96:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 96 - Train-Loss: 1.1500318050384521


Epoch 97:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 97 - Train-Loss: 0.7597909569740295


Epoch 98:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 98 - Train-Loss: 0.8757592439651489


Epoch 99:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 99 - Train-Loss: 0.5764246582984924


Epoch 100:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch 100 - Train-Loss: 0.8005322217941284


In [99]:
code_embs = create_embs(test_data["Code"], model)
doc_embs = create_embs(test_data["Comment"], model)
all_distances = predict_distances(doc_embs, code_embs)
mrr = calculate_mrr_from_distances(all_distances)
print(mrr)

Compute Distances:   0%|          | 0/214 [00:00<?, ?it/s]

0.6035769910504865


In [104]:
all_code_embs = create_embs(ds["test"]["Code"], model)
all_doc_embs = create_embs(ds["test"]["Comment"], model)
all_distances = predict_distances(all_doc_embs, all_code_embs)
mrr = calculate_mrr_from_distances(all_distances)
print(mrr)

Compute Distances:   0%|          | 0/1070 [00:00<?, ?it/s]

0.5372501951146463


In [116]:
df = pd.DataFrame(np.squeeze(all_code_embs, axis=1))
# Create Hugging Face Dataset
hf_dataset = Dataset.from_pandas(df)
# Save to disk
hf_dataset.save_to_disk("embeddings_code")

Saving the dataset (0/1 shards):   0%|          | 0/1070 [00:00<?, ? examples/s]

In [117]:
df = pd.DataFrame(np.squeeze(all_doc_embs, axis=1))
# Create Hugging Face Dataset
hf_dataset = Dataset.from_pandas(df)
# Save to disk
hf_dataset.save_to_disk("embeddings_doc")

Saving the dataset (0/1 shards):   0%|          | 0/1070 [00:00<?, ? examples/s]