In [1]:
import torch
import numpy as np
import logging
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from typing import List, Dict
from tqdm.auto import tqdm

logger = logging.getLogger(__name__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class YourCustomDEModel:
    def __init__(self, model_name="intfloat/e5-base-v2", **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.model_name = model_name
        self.tokenizer.add_eos_token = False

        print('YourCustomDEModel init')


    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def cls_pooling(self, model_output, attention_mask):
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        # Extract the CLS token's embeddings (index 0) for each sequence in the batch
        cls_embeddings = token_embeddings[:, 0, :]
        return cls_embeddings

    def last_token_pool(self, model_output, attention_mask):
        last_hidden_states = model_output.last_hidden_state
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def encode_text(self, texts: List[str], batch_size: int = 12, max_length: int = 128) -> np.ndarray:
        logging.info(f"Encoding {len(texts)} texts...")

        embeddings = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Encoding batches", unit="batch"):
            batch_texts = texts[i:i+batch_size]
            encoded_input = self.tokenizer(batch_texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
            with torch.no_grad():
                model_output = self.model(**encoded_input)
            batch_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
            embeddings.append(batch_embeddings.cpu())

        embeddings = torch.cat(embeddings, dim=0)

        if embeddings is None:
            logging.error("Embeddings are None.")
        else:
            logging.info(f"Encoded {len(embeddings)} embeddings.")

        return embeddings.numpy()

    def encode_queries(self, queries: List[str], batch_size: int = 12, max_length: int = 512, **kwargs) -> np.ndarray:
        all_queries = ["query: "+ query for query in queries]
        return self.encode_text(all_queries, batch_size, max_length)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 12, max_length: int = 512, **kwargs) -> np.ndarray:
        all_texts = ["passage: "+ doc['text'] for doc in corpus]
        #all_texts = ["passage: "+ doc for doc in corpus]

        return self.encode_text(all_texts, batch_size, max_length)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
temp = YourCustomDEModel()

YourCustomDEModel init


In [None]:
# def encode_multiple_exp(texts, method):
#     embeddings = temp.encode_text(texts)
#     print(embeddings.shape)

#     if method == "mean":
#         return np.mean(embeddings, axis=0)
#     elif method == "max":
#         return np.max(embeddings, axis=0)
#     elif method == "sum":
#         return np.sum(embeddings, axis=0)
#     else:
#         raise ValueError(f"Unsupported aggregation method: {method}")
    

# df = pd.read_csv('/work/pi_wenlongzhao_umass_edu/27/anamikaghosh/CS696DS-Oracle-Retrieving-Code-Explanations/Explanation_Generation/postprocessing/output/CSN_deepseek_valid_clean.csv')
# df1 = df.head(10)
# exp_list = df1[['explanation_deepseek_1_cleaned','explanation_deepseek_2_cleaned']].values.tolist()
# # emb_list = [encode_multiple_exp(list, method='mean') for list in exp_list]

In [9]:
import time
from itertools import chain
import pandas as pd

start = time.time()

df1 = pd.read_csv('/work/pi_wenlongzhao_umass_edu/27/anamikaghosh/CS696DS-Oracle-Retrieving-Code-Explanations/Explanation_Generation/postprocessing/output/CSN_deepseek_valid_clean.csv')

def encode_agg(df, exp_cols, method, model_obj):
    exp_list = df[exp_cols].values.tolist()
    flat_texts = list(chain.from_iterable(exp_list))  # total N = num_rows × num_explanations

    all_embeddings = model_obj.encode_text(flat_texts)  # shape: (N, emb_dim)

    num_rows = len(exp_list)
    num_exps = len(exp_list[0])
    all_embeddings = all_embeddings.reshape(num_rows, num_exps, -1)

    # Aggregate in batch using NumPy
    if method == "mean":
        emb_list = np.mean(all_embeddings, axis=1)
    elif method == "max":
        emb_list = np.max(all_embeddings, axis=1)
    elif method == "sum":
        emb_list = np.sum(all_embeddings, axis=1)
    else:
        raise ValueError("Unsupported method")
    end = time.time()
    print(f'Total time = {end-start} seconds')

    return emb_list

In [11]:
df1.columns

Index(['Unnamed: 0', 'query_id', 'corpus_id', 'doc', 'code', 'cleaned_code',
       'explanation_deepseek_1', 'explanation_deepseek_2',
       'explanation_deepseek_3', 'explanation_deepseek_4',
       'explanation_deepseek_5', 'explanation_deepseek_1_cleaned',
       'explanation_deepseek_2_cleaned', 'explanation_deepseek_3_cleaned',
       'explanation_deepseek_4_cleaned', 'explanation_deepseek_5_cleaned'],
      dtype='object')

In [None]:
df1['mean_emb'] = list(encode_agg(df1, ['explanation_deepseek_1_cleaned','explanation_deepseek_2_cleaned'], 'mean', temp))

Encoding batches: 100%|██████████| 2298/2298 [00:37<00:00, 61.29batch/s]


Total time = 255.66733860969543 seconds


In [16]:
df1.head()

Unnamed: 0.1,Unnamed: 0,query_id,corpus_id,doc,code,cleaned_code,explanation_deepseek_1,explanation_deepseek_2,explanation_deepseek_3,explanation_deepseek_4,explanation_deepseek_5,explanation_deepseek_1_cleaned,explanation_deepseek_2_cleaned,explanation_deepseek_3_cleaned,explanation_deepseek_4_cleaned,explanation_deepseek_5_cleaned,mean_emb
0,0,q251820,c251820,Save model to a pickle located at `path`,"def save_act(self, path=None):\n """"""Sav...","\ndef save_act(self, path=None):\n\n if...",The code saves a model to a file and then uses...,"Okay, I'm trying to understand this code snipp...",This code snippet does something to save and s...,The code snippet is a method called save_act w...,The code in entry['code'] sets the path for wh...,The code saves a model to a file and then uses...,"Okay, I'm trying to understand this code snipp...",This code snippet does something to save and s...,The code snippet is a method called save_act w...,The code in entry['code'] sets the path for wh...,"[-0.16336083, -0.67789125, -0.08247605, 0.0169..."
1,1,q251821,c251821,CNN from Nature paper.,"def nature_cnn(unscaled_images, **conv_kwargs)...","\ndef nature_cnn(unscaled_images, **conv_kwarg...",The code converts the input images into a scal...,"Okay, I'm going to try to break down this code...",The code snippet does the following:\n\n1. It ...,The code snippet is a function called nature_c...,The code implements a deep neural network to p...,The code converts the input images into a scal...,"Okay, I'm going to try to break down this code...",The code snippet does the following:1. It take...,The code snippet is a function called nature_c...,The code implements a deep neural network to p...,"[-0.13151847, -0.74018204, -0.54601645, -0.100..."
2,2,q251822,c251822,convolutions-only net\n\n Parameters:\n ...,"def conv_only(convs=[(32, 8, 4), (64, 4, 2), (...","\ndef conv_only(convs=[(32, 8, 4), (64, 4, 2),...","The code defines a function called conv_only, ...","Okay, so I'm trying to understand this code sn...",This code creates a function called conv_only ...,The code defines a function called conv_only t...,The code snippet in entry['code'] implements a...,"The code defines a function called conv_only, ...","Okay, so I'm trying to understand this code sn...",This code creates a function called conv_only ...,The code defines a function called conv_only t...,The code snippet in entry['code'] implements a...,"[0.0032902882, -0.7050658, -0.36576706, -0.259..."
3,3,q251823,c251823,"Create a wrapped, monitored SubprocVecEnv for ...","def make_vec_env(env_id, env_type, num_env, se...","\ndef make_vec_env(env_id, env_type, num_env, ...",The code creates a vector environment for mult...,"Okay, I'm trying to understand this code snipp...",This code creates a function called make_vec_e...,The code defines a function called make_vec_en...,The code snippet in entry['code'] is responsib...,The code creates a vector environment for mult...,"Okay, I'm trying to understand this code snipp...",This code creates a function called make_vec_e...,The code defines a function called make_vec_en...,The code snippet in entry['code'] is responsib...,"[0.16048774, -0.76488286, -0.4442173, -0.01827..."
4,4,q251824,c251824,Parse arguments not consumed by arg parser int...,"def parse_unknown_args(args):\n """"""\n Pa...",\ndef parse_unknown_args(args):\n\n retval ...,The code parses unknown arguments by looking f...,... (this...)\n</think>\n\ndef parse_unknown_a...,This code is used to parse unknown arguments. ...,def parse_unknown_args(args):\n def parse_u...,Please think about the connection between code...,The code parses unknown arguments by looking f...,... (this...)def parse_unknown_args(args): ...,This code is used to parse unknown arguments. ...,def parse_unknown_args(args): def parse_unk...,Please think about the connection between code...,"[-0.1587004, -0.40531266, -0.6960724, -0.21294..."
