In [1]:
import os
# Pin CUDA device enumeration to PCI bus order and expose only device "0" to
# this process. Must run before anything initialises CUDA.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

In [2]:
from tqdm import tqdm
import copy
import torch
# Allow TF32 in CUDA matmuls and cuDNN ops (set right after importing torch,
# before any model is built).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import json
import torch.nn as nn
from c2nl.eval.bleu import corpus_bleu, compute_bleu
from c2nl.eval.rouge import Rouge
from c2nl.eval.meteor import Meteor
from c2nl.eval.distinct_n.distinct_ngrams import distinct_n_corpus_level
from c2nl.eval.self_bleu import self_bleu_score
from itertools import islice, zip_longest, chain
from datasets import load_dataset, load_from_disk, Dataset
from transformers import AutoModel, AutoTokenizer, TrainingArguments, Trainer, get_linear_schedule_with_warmup, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
from evaluate import load
from transformers.models.t5.modeling_t5 import T5Stack, T5LayerCrossAttention, T5LayerNorm, T5LayerFF
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig, TaskType, PeftModel, PeftConfig#, PeftModelForSeq2SeqLM
from peft.utils import (
    SAFETENSORS_WEIGHTS_NAME,
    TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
    WEIGHTS_NAME,
    PeftType,
    TaskType,
    _get_batch_size,
    _prepare_prompt_learning_config,
    _set_adapter,
    _set_trainable,
    get_peft_model_state_dict,
    id_tensor_storage,
    infer_device,
    load_peft_weights,
    set_peft_model_state_dict,
    shift_tokens_right,
)
import warnings
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
import numpy as np
import random
import re
from typing import Any, Optional, Union
from copy import deepcopy

In [3]:
def set_seed(seed_value):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy and PyTorch (CPU, current CUDA device and
    all CUDA devices), writes PYTHONHASHSEED to the environment, and puts
    cuDNN in deterministic mode with benchmark autotuning turned off.

    Args:
        seed_value: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    # Same seeding order as before: Python, NumPy, torch CPU, torch CUDA (current, then all).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed_value)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once at notebook start; replace 42 with your desired seed.
set_seed(42)

In [4]:
# --- Experiment configuration ---
lang = "python"
num_virtual_tokens = 2
# `lang` may join several dataset names with "_". str.split already returns a
# one-element list when there is no underscore, so no special case is needed.
langs = lang.split('_')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_dir = "./codet5p_checkpoints"
checkpoint_name = f"codet5p_VPT_lang_{lang}_num_vtokens_{num_virtual_tokens}"

# Your finetuned model
base_model = f"codet5p_ft_lang_{lang}_backbone"

# The tokenizer comes from the public checkpoint the backbone was fine-tuned from.
if lang == 'java':
    base_model_tokenizer = 'Salesforce/codet5p-220m'
else:
    base_model_tokenizer = 'Salesforce/codet5p-220m-bimodal'

# Checkpoints whose name contains "bimodal" or "python" are loaded through
# AutoModel with remote code; everything else is a plain T5ForConditionalGeneration.
if 'bimodal' in base_model or 'python' in base_model:
    print("using auto model")
    backbone_model = AutoModel.from_pretrained(base_model, trust_remote_code=True).to(device)
else:
    print("using t5 conditional generation model")
    backbone_model = T5ForConditionalGeneration.from_pretrained(base_model).to(device)

# Tokenizer truncation limits and batch sizes used by the cells below.
max_input_length = 512
max_target_length = 128
train_batch_size = 16
test_batch_size = 64
tokenizer = AutoTokenizer.from_pretrained(base_model_tokenizer)
print(checkpoint_name)

using auto model


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


codet5p_full_cvae_ft_prefix_tuning_self_attend_cond_code_prior_lang_python_num_vtokens_2


In [5]:
# Report the accelerator in use. torch.cuda.get_device_name only accepts CUDA
# devices, so print a plain label when `device` fell back to the CPU.
print(torch.cuda.get_device_name(device) if device.type == "cuda" else "cpu")

NVIDIA GeForce RTX 4090


In [6]:
# load data, change data directory here, make sure its format follows the one given in data directory
train_codes = []
train_docs = []
for data_lang in langs:
    # Context managers close the files promptly; utf-8 matches how the eval
    # functions read their source/target files.
    with open("./data/{}/train/code.original".format(data_lang), 'r', encoding="utf-8") as code_file:
        train_codes.extend(line.rstrip() for line in code_file)
    with open("./data/{}/train/javadoc.original".format(data_lang), 'r', encoding="utf-8") as doc_file:
        train_docs.extend(line.rstrip() for line in doc_file)

# zip() below would silently drop trailing lines and misalign pairs otherwise.
assert len(train_codes) == len(train_docs), (
    "code/doc line counts differ: {} vs {}".format(len(train_codes), len(train_docs))
)

# Third view of each example: the code followed by EOS and then its summary.
train_docs_codes = [code + tokenizer.eos_token + doc for code, doc in zip(train_codes, train_docs)]
train_inputs = {
    "code_text": train_codes,
    "doc_text": train_docs,
    "doc_code_text": train_docs_codes
}

In [7]:
# Wrap the three parallel text lists in a Dataset so a DataLoader can batch them;
# every key of `train_inputs` becomes a formatted column.
train_data = Dataset.from_dict(train_inputs)
train_data.set_format(
    type='torch',
    columns=list(train_inputs.keys()),
)
train_data

Dataset({
    features: ['code_text', 'doc_text', 'doc_code_text'],
    num_rows: 55538
})

In [8]:
# Shuffled training batches; the batch size comes from the configuration cell.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=train_batch_size,
    shuffle=True,
)

In [9]:
# Pull one batch from the loader to inspect which fields it yields.
train_ex_batch = next(iter(train_loader))
print(train_ex_batch.keys())

dict_keys(['code_text', 'doc_text', 'doc_code_text'])


In [10]:
# Tokenize the example batch three ways (code, doc, code+EOS+doc) to sanity-check shapes.
ex_codes_input = tokenizer(
    train_ex_batch['code_text'],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=max_input_length,
)
ex_docs_input = tokenizer(
    train_ex_batch['doc_text'],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=max_target_length,
)
ex_docs_codes_input = tokenizer(
    train_ex_batch['doc_code_text'],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=max_input_length + max_target_length,
)

ex_input_ids, ex_attention_mask = ex_codes_input["input_ids"], ex_codes_input["attention_mask"]
ex_labels_ids, ex_labels_attention_mask = ex_docs_input["input_ids"], ex_docs_input["attention_mask"]
# Labels are a copy of the doc ids with padding replaced by -100.
ex_labels = ex_labels_ids.masked_fill(ex_labels_ids == tokenizer.pad_token_id, -100)
ex_docs_codes_input_ids = ex_docs_codes_input["input_ids"]
ex_docs_codes_attention_mask = ex_docs_codes_input["attention_mask"]

tuple(
    tensor.shape
    for tensor in (
        ex_input_ids,
        ex_attention_mask,
        ex_labels,
        ex_labels_ids,
        ex_labels_attention_mask,
        ex_docs_codes_input_ids,
        ex_docs_codes_attention_mask,
    )
)

(torch.Size([15, 223]),
 torch.Size([15, 223]),
 torch.Size([15, 39]),
 torch.Size([15, 39]),
 torch.Size([15, 39]),
 torch.Size([15, 229]),
 torch.Size([15, 229]))

In [11]:
# Training schedule hyper-parameters.
num_epochs = 200
learning_rate = 5e-5
warmup_steps = 500

# Total number of batches over the whole run (assumes one scheduler step per
# batch -- TODO confirm against the training loop).
total_steps = num_epochs * len(train_loader)

In [12]:
def eval_bleu(model, device, src_dir, tgt_dir, tokenizer, std_scale=1.0):
    """Generate a summary for every source snippet and score them with corpus BLEU.

    Args:
        model: Model exposing `generate(input_ids=..., attention_mask=..., max_length=...)`.
            It is switched to eval mode (and left there) and its `std_scale`
            attribute is overwritten with `std_scale`.
        device: Device the tokenized inputs are moved to.
        src_dir: Path of a UTF-8 text file with one source snippet per line.
        tgt_dir: Path of a UTF-8 text file with one reference summary per line.
        tokenizer: Tokenizer used to encode the sources and decode the generations.
        std_scale: Value stored on `model.std_scale` before generating.

    Returns:
        The corpus-level BLEU value returned by `corpus_bleu`.
    """
    model.eval()
    model.std_scale = std_scale

    # Context managers make sure both files are closed once read.
    with open(src_dir, encoding="utf-8") as src_file:
        source_codes = src_file.readlines()
    with open(tgt_dir, encoding="utf-8") as tgt_file:
        targets = tgt_file.readlines()

    eval_data = Dataset.from_dict({"code_text": source_codes, "doc_text": targets})
    eval_data.set_format(type='torch', columns=['code_text', 'doc_text'])
    eval_loader = DataLoader(eval_data, batch_size=test_batch_size)

    all_summaries = []
    # Only the code side is tokenized: the references are not needed for generation.
    with torch.no_grad():
        for batch in tqdm(eval_loader):
            tokenized_codes = tokenizer(batch['code_text'], return_tensors="pt", padding=True,
                                        truncation=True, max_length=max_input_length)
            generated_ids = model.generate(input_ids=tokenized_codes["input_ids"].to(device),
                                           attention_mask=tokenized_codes["attention_mask"].to(device),
                                           max_length=80)
            all_summaries.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))

    # Re-attach the final period as a separate " ." token. Only strip the last
    # character when it really is a period; otherwise a real character of the
    # summary would be dropped.
    normalized = (summary.rstrip().lower() for summary in all_summaries)
    hypotheses = {
        idx: [(text[:-1] if text.endswith('.') else text) + ' .']
        for idx, text in enumerate(normalized)
    }
    references = {idx: [target.rstrip().lower()] for idx, target in enumerate(targets)}
    _, bleu, _ = corpus_bleu(hypotheses, references)
    return bleu
    

In [13]:
# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

In [14]:
def eval_bleu_with_answer(model, device, src_dir, tgt_dir, tokenizer):
    """Generate summaries while also handing the reference labels to `generate`, then score with BLEU.

    Unlike `eval_bleu`, each `model.generate` call additionally receives the
    tokenized reference summary as `labels` (padding replaced by -100).

    Args:
        model: Model exposing `generate(input_ids=..., attention_mask=..., labels=..., max_length=...)`.
            It is switched to eval mode and left there.
        device: Device the tokenized inputs are moved to.
        src_dir: Path of a UTF-8 text file with one source snippet per line.
        tgt_dir: Path of a UTF-8 text file with one reference summary per line.
        tokenizer: Tokenizer used to encode inputs/labels and decode the generations.

    Returns:
        The corpus-level BLEU value returned by `corpus_bleu`.
    """
    model.eval()

    # Context managers make sure both files are closed once read.
    with open(src_dir, encoding="utf-8") as src_file:
        source_codes = src_file.readlines()
    with open(tgt_dir, encoding="utf-8") as tgt_file:
        targets = tgt_file.readlines()

    eval_data = Dataset.from_dict({"code_text": source_codes, "doc_text": targets})
    eval_data.set_format(type='torch', columns=['code_text', 'doc_text'])
    eval_loader = DataLoader(eval_data, batch_size=test_batch_size)

    all_summaries = []
    with torch.no_grad():
        for batch in tqdm(eval_loader):
            tokenized_codes = tokenizer(batch['code_text'], return_tensors="pt", padding=True,
                                        truncation=True, max_length=max_input_length)
            tokenized_docs = tokenizer(batch['doc_text'], return_tensors="pt", padding=True,
                                       truncation=True, max_length=max_target_length)
            labels = tokenized_docs["input_ids"].clone().to(device)
            labels[labels == tokenizer.pad_token_id] = -100
            # Pass the padding mask explicitly, as eval_bleu does, instead of
            # relying on generate() to infer it from the pad token.
            generated_ids = model.generate(input_ids=tokenized_codes["input_ids"].to(device),
                                           attention_mask=tokenized_codes["attention_mask"].to(device),
                                           labels=labels,
                                           max_length=50)
            all_summaries.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))

    # Re-attach the final period as a separate " ." token. Only strip the last
    # character when it really is a period; otherwise a real character of the
    # summary would be dropped.
    normalized = (summary.rstrip().lower() for summary in all_summaries)
    hypotheses = {
        idx: [(text[:-1] if text.endswith('.') else text) + ' .']
        for idx, text in enumerate(normalized)
    }
    references = {idx: [target.rstrip().lower()] for idx, target in enumerate(targets)}
    _, bleu, _ = corpus_bleu(hypotheses, references)
    return bleu

In [15]:
# Prefix-tuning configuration for a seq2seq LM, created with inference_mode off.
peft_config = PrefixTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    num_virtual_tokens=num_virtual_tokens,
)

In [16]:
class CVAEPrefixEncoder(torch.nn.Module):
    r"""
    Conditional-VAE prefix encoder: samples one latent vector per virtual token and projects it to the
    flattened prefix key/value states.

    A frozen deep copy of the base model's encoder is run on `[virtual-token embeddings ; token embeddings]`
    and the hidden states at the virtual-token positions are read out.

    - Training (`self.training` is True): the posterior is computed from `doc_code_input_ids` (the notebook
      feeds code + EOS + doc text) using `embedding` as virtual tokens; `mean` and `log_var` are produced by two
      `T5LayerFF` heads and `z` is drawn with the reparameterisation trick. The prior mean `prior_mu` is computed
      from `input_ids` using `code_embedding` as virtual tokens.
    - Evaluation: only the prior is computed and `z = prior_mu + N(0, std_scale^2)`; the returned `mean` and
      `log_var` are all zeros.

    Args:
        config ([`PrefixTuningConfig`]): Provides `token_dim`, `num_layers` and `num_virtual_tokens`.
        model: Base model; must expose `get_encoder()` whose result has `config.d_model` and
            `get_input_embeddings()`.

    **Attributes**:
        - **embedding** (`nn.Embedding`) -- Virtual-token embeddings for the posterior path.
        - **code_embedding** (`nn.Embedding`) -- Virtual-token embeddings for the prior (code-only) path.
        - **encoder_embedding** -- Input embedding module taken from the base encoder (not copied).
        - **encoder** -- Frozen deep copy of the base encoder.
        - **mean**, **log_var** (`T5LayerFF`) -- Heads producing the posterior parameters.
        - **transform_z** (`nn.Sequential`) -- Two-layer MLP mapping `z` to the flattened key/value states.

    Input shape: `prefix` is (`batch_size`, `num_virtual_tokens`).

    Output: tuple `(past_key_values, mean, log_var, prior_mu)` where `past_key_values` is
    (`batch_size`, `num_virtual_tokens`, `num_layers * 2 * d_model`) and the other three are
    (`batch_size`, `num_virtual_tokens`, `d_model`).

    NOTE(review): the projection width uses `d_model` while the caller reshapes with `token_dim`; this assumes
    both are equal -- confirm for the chosen backbone.
    """

    def __init__(self, config, model):
        super().__init__()
        text_encoder = model.get_encoder()
        self.token_dim = config.token_dim
        self.num_layers = config.num_layers
        self.d_model = text_encoder.config.d_model
        self.num_virtual_tokens = config.num_virtual_tokens

        # Separate learnable virtual tokens for the posterior and the prior path.
        self.embedding = nn.Embedding(self.num_virtual_tokens, self.d_model)
        self.code_embedding = nn.Embedding(self.num_virtual_tokens, self.d_model)
        self.encoder_embedding = text_encoder.get_input_embeddings()

        # Private, frozen copy of the encoder used only to read out prefix states.
        self.encoder = deepcopy(text_encoder)
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.mean = T5LayerFF(text_encoder.config)
        self.log_var = T5LayerFF(text_encoder.config)

        self.transform_z = nn.Sequential(
            nn.Linear(self.d_model, self.token_dim),
            nn.Tanh(),
            nn.Linear(self.token_dim, self.num_layers * 2 * self.d_model),
        )

    def _prefix_states(self, prefix, prefix_embedding, token_ids, attention_mask):
        """Encode `[prefix_embedding(prefix) ; encoder_embedding(token_ids)]` and return the prefix positions."""
        embeddings = torch.cat((prefix_embedding(prefix), self.encoder_embedding(token_ids)), dim=1)
        # Virtual tokens are always attended to; create the ones directly on the target device.
        prefix_mask = torch.ones(prefix.shape[0], self.num_virtual_tokens, device=prefix.device)
        mask_with_prefix = torch.cat((prefix_mask, attention_mask), dim=1)
        outputs = self.encoder(input_ids=None,
                               attention_mask=mask_with_prefix,
                               inputs_embeds=embeddings)
        return outputs.last_hidden_state[:, :self.num_virtual_tokens]

    def forward(self,  
                prefix, 
                input_ids=None, 
                attention_mask=None, 
                labels_id=None, 
                labels_attention_mask=None,
                doc_code_input_ids=None,
                doc_code_attention_mask=None,
                std_scale=1.0,
                num_beams=None):
        """Return `(past_key_values, mean, log_var, prior_mu)` for a batch of virtual-token ids.

        Args:
            prefix: Virtual-token ids, (`batch_size`, `num_virtual_tokens`).
            input_ids / attention_mask: Tokenized code and its padding mask.
            labels_id / labels_attention_mask: Accepted for interface compatibility; not used here.
            doc_code_input_ids / doc_code_attention_mask: Posterior input; required in training mode.
            std_scale: Standard deviation of the noise added to `prior_mu` in evaluation mode
                (ignored in training mode).
            num_beams: If truthy (evaluation mode only), `input_ids` is repeated this many times along the
                batch axis; `attention_mask` and `prefix` are used as given.
        """
        if self.training:
            # Posterior first, then prior: keeps the original order of encoder calls.
            prefix_encoding = self._prefix_states(prefix, self.embedding,
                                                  doc_code_input_ids, doc_code_attention_mask)
            prior_mu = self._prefix_states(prefix, self.code_embedding, input_ids, attention_mask)

            mean = self.mean(prefix_encoding)
            log_var = self.log_var(prefix_encoding)
            # Reparameterisation trick: z = mean + eps * std with eps ~ N(0, I).
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            z = eps.mul(std).add_(mean)
            past_key_values = self.transform_z(z)
        else:
            if num_beams:
                input_ids = input_ids.repeat_interleave(num_beams, dim=0)
            prior_mu = self._prefix_states(prefix, self.code_embedding, input_ids, attention_mask)

            # Sample around the prior mean; no posterior is available at inference time.
            noise = torch.empty_like(prior_mu).normal_(mean=0, std=std_scale)
            z = noise + prior_mu
            past_key_values = self.transform_z(z)
            mean = torch.zeros_like(prior_mu)
            log_var = torch.zeros_like(prior_mu)
        return past_key_values, mean, log_var, prior_mu

In [17]:
class CVAEPeftModelForSeq2SeqLM(PeftModel):
    """
    Peft model for sequence-to-sequence language modeling.

    Args:
        model ([`~transformers.PreTrainedModel`]): Base transformer model.
        peft_config ([`PeftConfig`]): Peft config.


    Example:

        ```py
        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import PeftModelForSeq2SeqLM, get_peft_config

        >>> config = {
        ...     "peft_type": "LORA",
        ...     "task_type": "SEQ_2_SEQ_LM",
        ...     "inference_mode": False,
        ...     "r": 8,
        ...     "target_modules": ["q", "v"],
        ...     "lora_alpha": 32,
        ...     "lora_dropout": 0.1,
        ...     "fan_in_fan_out": False,
        ...     "enable_lora": None,
        ...     "bias": "none",
        ... }

        >>> peft_config = get_peft_config(config)
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> peft_model = PeftModelForSeq2SeqLM(model, peft_config)
        >>> peft_model.print_trainable_parameters()
        trainable params: 884736 || all params: 223843584 || trainable%: 0.3952474242013566
        ```
    """

    def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
        """Build the PEFT wrapper, then swap in a `CVAEPrefixEncoder` as the active adapter's prompt encoder.

        Args:
            model: Base transformer model to wrap.
            peft_config: PEFT configuration, forwarded to `PeftModel.__init__` and to `CVAEPrefixEncoder`.
            adapter_name: Adapter name forwarded to `PeftModel.__init__`.
        """
        super().__init__(model, peft_config, adapter_name)
        # Initialised to None here; not read in this method.
        self.curr_input_ids = None
        # Keep references to the base model's generation hooks; `generate` replaces them on
        # `self.base_model` (presumably restoring these afterwards -- TODO confirm).
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
        self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
            self.base_model._prepare_encoder_decoder_kwargs_for_generation
        )
        # Replace the prompt encoder created by the parent constructor with the CVAE variant.
        self.prompt_encoder[self.active_adapter] = CVAEPrefixEncoder(peft_config, self.base_model)
        # Defaults; `eval_bleu` overwrites `std_scale` on the model before generating.
        self.std_scale = 1.0
        self.beam_size = None

    @staticmethod
    def kl_loss(mean1, logvar1, mean2, logvar2):
        exponential = logvar1 - logvar2 - torch.pow(mean1 - mean2, 2) / logvar2.exp() - torch.exp(logvar1 - logvar2) + 1
        return -0.5 * torch.sum(exponential, tuple(range(1, len(exponential.shape))))

    def get_prompt(self, 
                   batch_size: int, 
                   task_ids: Optional[torch.Tensor] = None, 
                   input_ids=None, 
                   attention_mask=None, 
                   labels_id=None, 
                   labels_attention_mask=None, 
                   doc_code_input_ids=None,
                   doc_code_attention_mask=None,
                   std_scale=1.0, 
                   num_beams=None) -> Union[torch.Tensor, tuple]:
        """
        Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method.

        For prefix tuning the return value is a tuple `(past_key_values, mu, logvar, prior_mu)`: the
        per-layer prefix key/value states plus the three tensors returned by the CVAE prompt encoder.
        For every other prompt-learning type only the prompt tensor is returned.

        Args:
            batch_size: Number of prompt rows to build.
            task_ids: Only forwarded for multitask prompt tuning.
            input_ids, attention_mask, labels_id, labels_attention_mask, doc_code_input_ids,
            doc_code_attention_mask, std_scale, num_beams: Forwarded unchanged to the prompt encoder
                in the prefix-tuning path.
        """
        peft_config = self.active_peft_config
        prompt_encoder = self.prompt_encoder[self.active_adapter]
        prompt_tokens = (
            self.prompt_tokens[self.active_adapter]
            .unsqueeze(0)
            .expand(batch_size, -1)
            .to(prompt_encoder.embedding.weight.device)
        )

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]

            # Always run the prompt encoder, also when `inference_mode` is set: the prefix depends
            # on the input, and the encoder's embedding table only stores d_model-wide virtual
            # tokens, so it cannot be reshaped into key/value states. The encoder itself switches
            # between its training and evaluation behaviour.
            past_key_values, mu, logvar, prior_mu = prompt_encoder(prompt_tokens,
                                                                   input_ids=input_ids,
                                                                   attention_mask=attention_mask,
                                                                   labels_id=labels_id,
                                                                   labels_attention_mask=labels_attention_mask,
                                                                   doc_code_input_ids=doc_code_input_ids,
                                                                   doc_code_attention_mask=doc_code_attention_mask,
                                                                   std_scale=std_scale,
                                                                   num_beams=num_beams)
            if self.base_model_torch_dtype is not None:
                past_key_values = past_key_values.to(self.base_model_torch_dtype)
            # Split the flat projection into (layer*2, head, head_dim) per virtual token.
            past_key_values = past_key_values.view(
                batch_size,
                peft_config.num_virtual_tokens,
                peft_config.num_layers * 2,
                peft_config.num_attention_heads,
                peft_config.token_dim // peft_config.num_attention_heads,
            )
            if peft_config.num_transformer_submodules == 2:
                # Reuse the same prefix for the second transformer submodule.
                past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
            past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
                peft_config.num_transformer_submodules * 2
            )
            post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(
                self.config.model_type, None
            )
            if post_process_fn is not None:
                past_key_values = post_process_fn(past_key_values)
            return past_key_values, mu, logvar, prior_mu
        else:
            if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
                prompts = prompt_encoder(prompt_tokens, task_ids)
            else:
                if peft_config.inference_mode:
                    prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
                else:
                    prompts = prompt_encoder(prompt_tokens)
            return prompts

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            inputs_embeds=None,
            decoder_input_ids=None,
            decoder_attention_mask=None,
            decoder_inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            task_ids=None,
            doc_code_input_ids=None,
            doc_code_attention_mask=None,
            **kwargs,
    ):
        """Run the base model with the adapter's prompts applied.

        For prefix tuning, the CVAE prompt encoder produces the prefix key/value states from
        `input_ids` / `attention_mask` and `doc_code_input_ids` / `doc_code_attention_mask`, and the
        returned outputs carry an extra `kl_loss` attribute: the batch mean of the KL divergence
        between N(mu, exp(logvar)) and N(prior_mu, I). The other prompt-learning types and the
        non-prompt-learning path only forward to the base model.

        Args:
            input_ids, attention_mask, inputs_embeds, decoder_input_ids, decoder_attention_mask,
            decoder_inputs_embeds, labels, output_attentions, output_hidden_states, return_dict:
                Forwarded to the base model (masks, embeddings and labels get the virtual-token part
                prepended where the active PEFT type requires it).
            task_ids: Used for POLY and multitask prompt tuning only.
            doc_code_input_ids, doc_code_attention_mask: Posterior input of the CVAE prompt encoder
                (prefix tuning only).
            **kwargs: Extra arguments forwarded to the base model; `position_ids` and
                `token_type_ids` are dropped with a warning for prompt-learning adapters.
        """
        peft_config = self.active_peft_config
        if not peft_config.is_prompt_learning:
            if peft_config.peft_type == PeftType.POLY:
                kwargs["task_ids"] = task_ids

            with self._enable_peft_forward_hooks(**kwargs):
                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                return self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    inputs_embeds=inputs_embeds,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    decoder_inputs_embeds=decoder_inputs_embeds,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

        batch_size = _get_batch_size(input_ids, inputs_embeds)
        if decoder_attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                decoder_attention_mask.device
            )
            if peft_config.peft_type not in [PeftType.PROMPT_TUNING, PeftType.P_TUNING]:
                decoder_attention_mask = torch.cat((prefix_attention_mask, decoder_attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "decoder_attention_mask": decoder_attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            # Recover plain token ids from the labels (-100 -> pad). The copy stays on the labels'
            # own device instead of the notebook-global `device`, and a missing `labels` no longer
            # raises here.
            labels_ids = None
            if labels is not None:
                labels_ids = labels.clone()
                labels_ids[labels_ids == -100] = self.config.pad_token_id
            past_key_values, mu, logvar, prior_mu = self.get_prompt(batch_size,
                                                                    input_ids=input_ids,
                                                                    attention_mask=attention_mask,
                                                                    labels_id=labels_ids,
                                                                    labels_attention_mask=decoder_attention_mask,
                                                                    doc_code_input_ids=doc_code_input_ids,
                                                                    doc_code_attention_mask=doc_code_attention_mask,)
            # The prior has unit variance, i.e. a log-variance of zero.
            kl_loss = self.kl_loss(mu, logvar, prior_mu, torch.zeros_like(logvar))
            outputs = self.base_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_inputs_embeds=decoder_inputs_embeds,
                past_key_values=past_key_values,
                **kwargs,
            )
            # Expose the batch-averaged KL term next to the base model's own outputs.
            outputs.kl_loss = kl_loss.mean()
            return outputs
        elif peft_config.peft_type in [PeftType.PROMPT_TUNING, PeftType.P_TUNING]:
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)

            if attention_mask is not None:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                    attention_mask.device
                )
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)

            prompts = self.get_prompt(batch_size=batch_size)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)

            return self.base_model(
                inputs_embeds=inputs_embeds,
                decoder_input_ids=decoder_input_ids,
                decoder_inputs_embeds=decoder_inputs_embeds,
                **kwargs,
            )
        else:
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)
            if decoder_inputs_embeds is None and decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
                decoder_inputs_embeds = self.word_embeddings(decoder_input_ids)

            if attention_mask is not None:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                    attention_mask.device
                )
                kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
            # concat prompt labels
            if labels is not None:
                if peft_config.num_transformer_submodules == 1:
                    kwargs["labels"] = labels
                elif peft_config.num_transformer_submodules == 2:
                    prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
                    kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
            prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
            prompts = prompts.to(inputs_embeds.dtype)
            inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)
            if peft_config.num_transformer_submodules == 1:
                return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
            elif peft_config.num_transformer_submodules == 2:
                decoder_inputs_embeds = torch.cat(
                    (prompts[:, peft_config.num_virtual_tokens :], decoder_inputs_embeds), dim=1
                )
                return self.base_model(
                    inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **kwargs
                )

    def generate(self, **kwargs):
        """Run ``self.base_model.generate`` with this wrapper's generation hooks installed.

        The base model's ``prepare_inputs_for_generation`` and
        ``_prepare_encoder_decoder_kwargs_for_generation`` are temporarily replaced by this
        wrapper's versions. They are put back in a ``finally`` clause so that every exit path
        (normal return, early return from a branch, or exception) restores them.

        Args:
            **kwargs: Forwarded to ``self.base_model.generate``. When the active adapter is a
                prompt-learning adapter, ``input_ids`` is required and any ``position_ids`` /
                ``token_type_ids`` are replaced by ``None`` with a warning.

        Returns:
            Whatever ``self.base_model.generate`` returns.

        Raises:
            ValueError: If a prompt-learning adapter is active and ``input_ids`` is missing.
            NotImplementedError: If the active prompt-learning type is not handled here.
        """
        peft_config = self.active_peft_config
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
            self._prepare_encoder_decoder_kwargs_for_generation
        )
        try:
            if not peft_config.is_prompt_learning:
                with self._enable_peft_forward_hooks(**kwargs):
                    kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                    return self.base_model.generate(**kwargs)

            if "input_ids" not in kwargs:
                raise ValueError("input_ids must be provided for Peft model generation")
            if kwargs.get("position_ids", None) is not None:
                warnings.warn(
                    "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
                )
                kwargs["position_ids"] = None
            if kwargs.get("token_type_ids", None) is not None:
                warnings.warn(
                    "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
                )
                kwargs["token_type_ids"] = None

            if peft_config.peft_type == PeftType.PREFIX_TUNING:
                # Stash the encoder inputs and beam count; prepare_inputs_for_generation reads
                # them when it builds the prefix and then clears them.
                self.curr_input_ids = kwargs.get("input_ids", None)
                self.num_beams = kwargs.get("num_beams", None)
                return self.base_model.generate(**kwargs)

            if peft_config.peft_type in [
                PeftType.PROMPT_TUNING,
                PeftType.P_TUNING,
                PeftType.MULTITASK_PROMPT_TUNING,
            ]:
                kwargs = deepcopy(kwargs)

                if "encoder_outputs" in kwargs:
                    del kwargs["encoder_outputs"]
                    warnings.warn(
                        "`encoder_outputs` should not be passed to `generate` when using prompt tuning. Ignoring it."
                    )

                input_ids = kwargs.pop("input_ids")
                inputs_embeds = self.word_embeddings(input_ids)
                batch_size = inputs_embeds.shape[0]
                prompts = self.get_prompt(batch_size=batch_size, task_ids=kwargs.pop("task_ids", None))
                prompts = prompts.to(inputs_embeds.dtype)

                # Prepend the virtual-token embeddings to the input embeddings.
                inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)
                kwargs["inputs_embeds"] = inputs_embeds

                if "attention_mask" in kwargs:
                    # Extend the mask with ones so the virtual tokens are attended to.
                    prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
                        kwargs["attention_mask"].device
                    )
                    kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)

                return self.base_model.generate(**kwargs)

            raise NotImplementedError
        finally:
            # Always undo the monkey-patching. The previous try/except/else form skipped this
            # whenever a branch returned from inside the try body.
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
            self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
                self.base_model_prepare_encoder_decoder_kwargs_for_generation
            )

    def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs):
        """Build the per-step model kwargs and, for prefix tuning, inject the prefix cache.

        Delegates to the base model's original ``prepare_inputs_for_generation`` and then:

        * adds ``task_ids`` to the result when the active adapter type is POLY;
        * when the active adapter type is PREFIX_TUNING and the base model reports no
          ``past_key_values`` yet, asks ``self.get_prompt`` for them (using the input ids and
          beam count stashed on ``self`` by ``generate``) and clears the stashed values.

        Returns:
            The (possibly augmented) dict of model kwargs.
        """
        active_config = self.active_peft_config
        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)

        if active_config.peft_type == PeftType.POLY:
            model_kwargs["task_ids"] = task_ids

        # Nothing more to do unless this is a prefix-tuning adapter with no cache yet.
        if model_kwargs["past_key_values"] is not None or active_config.peft_type != PeftType.PREFIX_TUNING:
            return model_kwargs

        n_rows = model_kwargs["decoder_input_ids"].shape[0]
        encoder_mask = kwargs.get("attention_mask", None)
        assert self.curr_input_ids is not None and encoder_mask is not None, "input_ids and attention_mask must be provided"

        prefix_cache, _, _, _ = self.get_prompt(
            n_rows,
            std_scale=self.std_scale,
            input_ids=self.curr_input_ids,
            attention_mask=encoder_mask,
            num_beams=self.num_beams,
        )

        # The stashed generation inputs are single-use; drop them once consumed.
        self.curr_input_ids = None
        self.num_beams = None

        model_kwargs["past_key_values"] = prefix_cache
        return model_kwargs

In [18]:
# Wrap the backbone with the CVAE PEFT adapter, move it to the target device,
# and report how many parameters are trainable.
peft_model = CVAEPeftModelForSeq2SeqLM(backbone_model, peft_config)
peft_model = peft_model.to(device)
peft_model.print_trainable_parameters()

trainable params: 24,206,592 || all params: 356,898,690 || trainable%: 6.782482726400593


In [19]:
# peft_model(input_ids=ex_input_ids.to(device), 
#            attention_mask=ex_attention_mask.to(device),
#            labels=ex_labels.to(device),
#            decoder_attention_mask=ex_labels_attention_mask.to(device),
#            doc_code_input_ids=ex_docs_codes_input_ids.to(device),
#            doc_code_attention_mask=ex_docs_codes_attention_mask.to(device)).loss

In [20]:
# peft_model.eval()
# peft_model.generate(input_ids=ex_input_ids.to(device),
#                     labels=ex_labels.to(device), 
#                     max_length=50)

In [21]:
# Learning rate passed to the AdamW optimizer that updates the PEFT model.
peft_lr = 5e-5

In [22]:
# AdamW over the PEFT-wrapped model's parameters, with a linear schedule (no warmup)
# spanning every optimizer step of the run.
optimizer = torch.optim.AdamW(params=peft_model.parameters(), lr=peft_lr)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_epochs * len(train_loader),
)

In [23]:
# Make sure the checkpoint directory exists. exist_ok=True replaces the check-then-create
# pattern, which can fail if the directory appears between the check and the call.
os.makedirs(checkpoint_dir, exist_ok=True)

In [24]:
# Resume counters and optimizer/scheduler/prefix-encoder state from a saved checkpoint, if any.
steps = 0
curr_epoch = 0
# os.listdir returns names in arbitrary order; sort so that picking the last one is deterministic.
checkpoints = sorted(f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_name + ".pt"))
if len(checkpoints) > 0:
    # map_location keeps loading robust when the checkpoint was saved from a different device.
    checkpoint = torch.load(os.path.join(checkpoint_dir, checkpoints[-1]), map_location=device)
    peft_model.prompt_encoder[peft_model.active_adapter].load_state_dict(checkpoint['prefix_encoder_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    steps = checkpoint['steps']
    curr_epoch = checkpoint['epoch']
    print("Loaded checkpoint: ", checkpoints[-1], " at epoch ", curr_epoch, " and steps ", steps)
    print(f"ML Loss: {checkpoint['ml_loss']}, KL Loss: {checkpoint['kl_loss']}")
else:
    print("No checkpoints found")

Loaded checkpoint:  codet5p_full_cvae_ft_prefix_tuning_self_attend_cond_code_prior_lang_python_num_vtokens_2.pt  at epoch  189  and steps  699867
ML Loss: 0.006632651334499315, KL Loss: 0.8372590363653422


In [None]:
# Baseline: test-set BLEU per language before the training cell runs.
# Skipped when the restored epoch counter already reached num_epochs.
if curr_epoch < num_epochs:
    for l in langs:
        test_source_dir = "./data/{}/test/code.original".format(l)
        test_target_dir = "./data/{}/test/javadoc.original".format(l)
        test_bleu = eval_bleu(
            peft_model, device, test_source_dir, test_target_dir, tokenizer, std_scale=1.0
        )
        print("Test BLEU for {}: ".format(l), test_bleu)
        torch.cuda.empty_cache()

In [26]:
def train(model,
          device,
          train_loader,
          optimizer,
          scheduler,
          num_epochs,
          steps=0,
          curr_epoch=0):
    """Train ``model`` with a cyclic KL-weight schedule, checkpointing after every epoch.

    Each batch is tokenized on the fly with the module-level ``tokenizer`` (using
    ``max_input_length`` / ``max_target_length``). The optimized loss is
    ``outputs.loss + kl_beta * outputs.kl_loss``. After each epoch the prefix-encoder,
    optimizer and scheduler state are written to ``checkpoint_dir/checkpoint_name + ".pt"``,
    and every 5th epoch validation BLEU is printed for each language in ``langs``.

    Args:
        model: Model whose forward returns an object with ``loss`` and ``kl_loss`` and that
            exposes ``prompt_encoder`` / ``active_adapter``.
        device: Device the tokenized tensors are moved to.
        train_loader: Iterable of batches with keys ``'code_text'``, ``'doc_text'`` and
            ``'doc_code_text'``; must support ``len``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        num_epochs: Index of the last epoch (exclusive upper bound of the epoch range).
        steps: Number of optimizer steps already taken (used to resume the KL schedule).
        curr_epoch: Epoch index to resume from.
    """
    # KL weight schedule: a linear ramp from 0 to 1 over 2500 steps followed by 2500 steps
    # held at 1; it is indexed modulo its length, so the 5000-step cycle repeats.
    kl_betas = torch.cat((torch.linspace(0, 1, 2500), torch.ones(2500).float())).tolist()
    num_batches = len(train_loader)

    for epoch in range(curr_epoch, num_epochs):
        model.train()
        total_ml_loss = 0
        total_kl_loss = 0
        # Wrap the train_loader with tqdm for a progress bar
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch in progress_bar:
            kl_beta = kl_betas[steps % len(kl_betas)]

            # Tokenize the batch and move it to the device.
            tokenized_codes = tokenizer(batch['code_text'], return_tensors="pt", padding=True, truncation=True, max_length=max_input_length)
            tokenized_docs = tokenizer(batch['doc_text'], return_tensors="pt", padding=True, truncation=True, max_length=max_target_length)
            tokenizeds_docs_codes = tokenizer(batch['doc_code_text'], return_tensors="pt", padding=True, truncation=True, max_length=max_input_length+max_target_length)
            input_ids = tokenized_codes["input_ids"].to(device)
            attention_mask = tokenized_codes["attention_mask"].to(device)
            labels = tokenized_docs["input_ids"].clone().to(device)
            # Padding positions are set to -100 in the labels passed to the model.
            labels[labels == tokenizer.pad_token_id] = -100
            labels_attention_mask = tokenized_docs["attention_mask"].to(device)
            doc_code_input_ids = tokenizeds_docs_codes["input_ids"].to(device)
            doc_code_attention_mask = tokenizeds_docs_codes["attention_mask"].to(device)

            # Forward pass. Gradients are cleared on both the model and the optimizer, as the
            # original code did, in case their parameter sets differ.
            model.zero_grad()
            optimizer.zero_grad()
            outputs = model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels,
                            decoder_attention_mask=labels_attention_mask,
                            doc_code_input_ids=doc_code_input_ids,
                            doc_code_attention_mask=doc_code_attention_mask)
            ml_loss = outputs.loss
            kl_loss = outputs.kl_loss
            loss = ml_loss + kl_beta * kl_loss

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
            optimizer.step()
            scheduler.step()  # Update the learning rate

            # Read each scalar off the device once and reuse it for accumulation and display.
            ml_loss_value = ml_loss.item()
            kl_loss_value = kl_loss.item()
            loss_value = loss.item()
            total_ml_loss += ml_loss_value
            total_kl_loss += kl_loss_value
            steps += 1
            # Update the progress bar with the current loss
            progress_bar.set_postfix({'ml_loss': ml_loss_value, 'kl_loss': kl_loss_value, 'total_loss': loss_value, 'kl_beta': kl_beta, 'steps': steps})

        avg_epoch_ml_loss = total_ml_loss / num_batches
        avg_epoch_kl_loss = total_kl_loss / num_batches
        # Reported "Average Loss" is the plain ML + KL sum, i.e. not weighted by kl_beta.
        avg_epoch_loss = (total_ml_loss + total_kl_loss) / num_batches

        print(f"Epoch {epoch+1} completed. Average ML Loss: {avg_epoch_ml_loss}, Average KL Loss: {avg_epoch_kl_loss}, Average Loss: {avg_epoch_loss}")

        # Save a resumable checkpoint (overwrites the previous one).
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save({
            'epoch': epoch+1,
            'prefix_encoder_state_dict': model.prompt_encoder[model.active_adapter].state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'steps': steps,
            'ml_loss': avg_epoch_ml_loss,
            'kl_loss': avg_epoch_kl_loss,
        }, os.path.join(checkpoint_dir, checkpoint_name+".pt"))

        # Evaluate on the validation split every 5 epochs.
        if (epoch+1) % 5 == 0:
            for l in langs:
                validation_source_dir = "./data/{}/dev/code.original".format(l)
                validation_target_dir = "./data/{}/dev/javadoc.original".format(l)
                print("Validation BLEU for {}: ".format(l),
                      eval_bleu(model,
                                device,
                                validation_source_dir,
                                validation_target_dir,
                                tokenizer,
                                std_scale=1.0))

In [None]:
# Launch training, resuming from the step/epoch counters set by the checkpoint cell.
train(
    model=peft_model,
    device=device,
    train_loader=train_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    steps=steps,
    curr_epoch=curr_epoch,
)

In [None]:
# Test-set BLEU per language after the training cell has run.
# Skipped when the restored epoch counter had already reached num_epochs (no training happened).
if curr_epoch < num_epochs:
    for l in langs:
        test_source_dir = "./data/{}/test/code.original".format(l)
        test_target_dir = "./data/{}/test/javadoc.original".format(l)
        bleu_label = "Test BLEU for {}: ".format(l)
        print(
            bleu_label,
            eval_bleu(peft_model, device, test_source_dir, test_target_dir, tokenizer, std_scale=1.0),
        )
        torch.cuda.empty_cache()

In [29]:
# Save the model
# model.save_pretrained("codet5p_ft_epoch_{}_lang_{}".format(num_epochs, lang))

In [30]:
def distinct_with_beam_search(model,
                              device,
                              src_dir,
                              tgt_dir,
                              tokenizer,
                              batch_size=16,
                              beam_size=10,
                              num_return_sequences=8,
                              std_scale=1.0):
    """Generate several beam-search summaries per source and print oracle metrics.

    Reads one source per line from ``src_dir`` and one target per line from ``tgt_dir``,
    generates ``num_return_sequences`` summaries per source, and prints the mean of the
    per-source best BLEU, ROUGE-L and METEOR scores (each multiplied by 100).

    Args:
        model: Model exposing ``generate``; its ``std_scale`` attribute is set here.
        device: Device the tokenized inputs are moved to.
        src_dir: Path of the source-code file (one example per line).
        tgt_dir: Path of the reference-summary file (one example per line).
        tokenizer: Tokenizer used for encoding sources and decoding generations.
        batch_size: Number of sources per ``generate`` call.
        beam_size: ``num_beams`` passed to ``generate``.
        num_return_sequences: Sequences returned per source.
        std_scale: Value assigned to ``model.std_scale`` before generation.

    Returns:
        ``(hypotheses, references)``: dicts keyed by running index, each value a
        single-element list holding the normalized hypothesis / repeated reference.
    """
    model.eval()
    model.std_scale = std_scale

    # Context managers close the files; the previous open(...).readlines() leaked the handles.
    with open(src_dir, encoding="utf-8") as src_file:
        source_codes = [code.rstrip() for code in src_file]
    with open(tgt_dir, encoding="utf-8") as tgt_file:
        targets = [target.rstrip() for target in tgt_file]

    all_summaries = []
    eval_inputs = {
        "code_text": source_codes,
        "doc_text": targets
    }
    eval_data = Dataset.from_dict(eval_inputs)
    eval_data.set_format(type='torch', columns=['code_text', 'doc_text'])
    eval_loader = DataLoader(eval_data, batch_size=batch_size)
    for batch in tqdm(eval_loader):
        # Only the source side is needed for generation; the target-side tokenization and
        # label tensors the original built here were never used.
        tokenized_codes = tokenizer(batch['code_text'], return_tensors="pt", padding=True, truncation=True, max_length=max_input_length)
        input_ids = tokenized_codes["input_ids"].to(device)
        attention_mask = tokenized_codes["attention_mask"].to(device)
        generated_ids = model.generate(input_ids=input_ids,
                                       attention_mask=attention_mask,
                                       max_length=100,
                                       num_beams=beam_size,
                                       num_return_sequences=num_return_sequences)
        summaries = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        all_summaries.extend(summaries)

    # Normalize hypotheses: lower-case, drop the last character and append ' .',
    # then collapse newlines/tabs/carriage returns into spaces.
    hypotheses = dict(enumerate([[re.sub(r"\n{1,}|\t{1,}|\r{1,}", " ", summary.strip().lower()[:-1]+' .')] for summary in all_summaries]))
    # repeat targets for each generated sequence
    repeated_targets = []
    for target in targets:
        repeated_targets.extend([target]*num_return_sequences)
    references = dict(enumerate([[re.sub(r"\n{1,}|\t{1,}|\r{1,}", " ", target.strip().lower())] for target in repeated_targets]))

    # Oracle scores: for each source keep the best score among its generated sequences.
    _, bleu, ind_bleu = corpus_bleu(hypotheses, references)
    reshaped_bleu = np.array(list(ind_bleu.values())).reshape(-1, num_return_sequences)
    oracle_bleu = np.max(reshaped_bleu, axis=1)
    rouge_calculator = Rouge()
    rouge_l, ind_rouge = rouge_calculator.compute_score(references, hypotheses)
    reshaped_rouge = np.array(list(ind_rouge.values())).reshape(-1, num_return_sequences)
    oracle_rouge = np.max(reshaped_rouge, axis=1)
    meteor_calculator = Meteor()
    meteor, ind_meteor = meteor_calculator.compute_score(references, hypotheses)
    reshaped_meteor = np.array(list(ind_meteor)).reshape(-1, num_return_sequences)
    oracle_meteor = np.max(reshaped_meteor, axis=1)
    print("Oracle_bleu: {}, Oracle_rouge: {}, Oracle_meteor: {}".format(np.mean(oracle_bleu) * 100,
                                                                         np.mean(oracle_rouge) * 100,
                                                                         np.mean(oracle_meteor) * 100))

    return hypotheses, references

In [31]:
# num_return_sequences_list = [4, 8, 12, 16, 20]
# for num_return_sequences in num_return_sequences_list:
#     for l in langs:
#         test_source_dir = "./data/{}/test/code.original".format(l)
#         test_target_dir = "./data/{}/test/javadoc.original".format(l)
#         print("Test distinct with beam search for {} with beam size {}: ".format(l, num_return_sequences))
#         hypotheses, references = distinct_with_beam_search(peft_model,
#                                                            device,
#                                                            test_source_dir,
#                                                            test_target_dir,
#                                                            tokenizer,
#                                                            batch_size=4,
#                                                            beam_size=num_return_sequences,
#                                                            num_return_sequences=num_return_sequences,
#                                                            std_scale=1.0)

In [32]:
class TextDataset(TorchDataset):
    """Position-aligned pairs of hypothesis and reference items.

    Both constructor arguments are mappings; the items of their values are concatenated
    in mapping order, and item ``idx`` of each flattened list forms one sample.
    """

    def __init__(self, hypotheses, references):
        self.hypotheses = [item for group in hypotheses.values() for item in group]
        self.references = [item for group in references.values() for item in group]

    def __getitem__(self, idx):
        hypothesis = self.hypotheses[idx]
        reference = self.references[idx]
        return hypothesis, reference

    def __len__(self):
        # The dataset length is defined by the hypotheses side.
        return len(self.hypotheses)

In [33]:
def distinct_with_latent(model,
                         device,
                         src_dir,
                         tgt_dir,
                         tokenizer,
                         std_scale=20,
                         batch_size=1,
                         num_latents=100,
                         num_beams=4):
    """Generate ``num_latents`` summaries per source and dump them to a JSON file.

    Every source line is repeated ``num_latents`` times and passed through
    ``model.generate`` in batches, so each source gets ``num_latents`` generations.
    Results are written to
    ``T5_results/generated/<lang>/num_latents_..._hypotheses_all.json``.

    NOTE(review): the output folder uses the module-level ``lang`` variable, not a
    parameter; callers looping over several languages should confirm ``lang`` is set
    as intended, otherwise runs write to the same file.

    Args:
        model: Model exposing ``generate``; its ``std_scale`` attribute is set here.
        device: Device the tokenized inputs are moved to.
        src_dir: Path of the source-code file (one example per line).
        tgt_dir: Path of the reference-summary file (one example per line).
        tokenizer: Tokenizer used for encoding sources and decoding generations.
        std_scale: Value assigned to ``model.std_scale`` before generation.
        batch_size: Number of (repeated) sources per ``generate`` call.
        num_latents: Number of generations per source.
        num_beams: Beam count; values <= 1 leave ``num_beams`` unspecified in ``generate``.

    Returns:
        A list with one ``[source_code, normalized_target, [normalized summaries]]`` entry
        per source.
    """
    model.eval()
    model.std_scale = std_scale

    # Create the output folder before the long generation loop so a missing directory
    # cannot make the final write fail.
    output_dir = os.path.join("T5_results", "generated", lang)
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"num_latents_{num_latents}_std_scale_{std_scale}_beam_{num_beams}_hypotheses_all.json")

    # Context managers close the files; the previous open(...).readlines() leaked the handles.
    with open(src_dir, encoding="utf-8") as src_file:
        source_codes = [code.rstrip() for code in src_file]
    with open(tgt_dir, encoding="utf-8") as tgt_file:
        targets = [target.rstrip() for target in tgt_file]

    # repeat each code num_latents times
    repeated_source_codes = []
    for code in source_codes:
        repeated_source_codes.extend([code]*num_latents)

    # Beam arguments are only passed when beam search is actually requested.
    generation_kwargs = {"max_length": 80}
    if num_beams > 1:
        generation_kwargs["num_beams"] = num_beams
        generation_kwargs["num_return_sequences"] = 1

    all_summaries = []
    for i in tqdm(range(0, len(repeated_source_codes), batch_size)):
        batch = repeated_source_codes[i:i+batch_size]
        encoded = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=400)
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)
        generated_ids = model.generate(input_ids=input_ids,
                                       attention_mask=attention_mask,
                                       **generation_kwargs)
        summaries = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        all_summaries.extend(summaries)

    # separate the summaries into nested list of num_latents
    separated_summaries = [all_summaries[i:i+num_latents] for i in range(0, len(all_summaries), num_latents)]
    all_summaries_dict = [[source_codes[i],
                           re.sub(r"\n{1,}|\t{1,}|\r{1,}", " ", targets[i].strip().lower()),
                           [re.sub(r"\n{1,}|\t{1,}|\r{1,}", " ", summary.strip().lower()[:-1]+' .') for summary in separated_summaries[i]]]
                          for i in range(len(source_codes))]

    # Persist the raw generations; the with-block closes the file even if dumping fails.
    with open(output_path, 'w') as all_ref_fw:
        json.dump(all_summaries_dict, all_ref_fw, indent=4)
    return all_summaries_dict

In [None]:
# Sample a diverse set of summaries from the latent space for every language.
# std_scales should be tuned for your dataset.
std_scales = [5]
for std_scale in std_scales:
    print("current std_scale: ", std_scale)
    for l in langs:
        # change test dataset dir here
        test_source_dir = "./data/{}/test/code.original".format(l)
        test_target_dir = "./data/{}/test/javadoc.original".format(l)
        all_summaries_dict = distinct_with_latent(
            peft_model,
            device,
            test_source_dir,
            test_target_dir,
            tokenizer,
            std_scale=std_scale,
            batch_size=100,
            num_latents=100,
            num_beams=4,
        )

In [None]:
# compute model score for output summary, recalculated since the log probs are perturbed by latent variables
def score_model_sequence_scoring(code, all_candidates, score_model, score_tokenizer, batch_size):
    """Score each candidate summary for ``code`` with ``score_model`` and rank them.

    Args:
        code: Source text; it is repeated once per candidate as the encoder input.
        all_candidates: List of candidate summary strings to score.
        score_model: Model called with ``input_ids`` / ``decoder_input_ids`` that returns
            ``logits`` and provides ``compute_transition_scores``.
        score_tokenizer: Tokenizer applied to both the code and the candidates.
        batch_size: Number of candidates scored per forward pass.

    Returns:
        A list of ``(candidate, score)`` tuples sorted by descending score, where the score
        is the masked mean of the per-token transition scores from position 1 onwards.

    Note:
        ``device`` is read from the enclosing module scope, not passed in.
    """

    tokenized_code = score_tokenizer([code] * len(all_candidates), return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    tokenized_candidates = score_tokenizer(all_candidates, return_tensors="pt", padding=True).to(device)
    # Prepend a column of token id 0 to the candidate ids to form the decoder inputs
    # (presumably id 0 is the decoder start token for this model -- TODO confirm).
    tokenized_candidates_cat_bos_token = torch.cat([torch.zeros(tokenized_candidates.input_ids.shape[0], 1, dtype=torch.long).to(device), tokenized_candidates.input_ids], dim=-1)
    with torch.no_grad():
        scores = []
        for i in range(0, len(all_candidates), batch_size):
            batch_code_tokens = tokenized_code.input_ids[i:i+batch_size]
            batch_hyp_tokens = tokenized_candidates_cat_bos_token[i:i+batch_size]
            # No attention masks are passed to the model here.
            batch_decode_outputs = score_model(input_ids=batch_code_tokens, decoder_input_ids=batch_hyp_tokens)
            # One logits slice per decoder position except the last, so slice j lines up
            # with candidate token j (the decoder input shifted by the prepended column).
            batch_decode_outputs_logits_tuple = [batch_decode_outputs.logits[:, j, :] for j in range(batch_decode_outputs.logits.shape[1]-1)]
            batch_decode_transition_scores = score_model.compute_transition_scores(batch_hyp_tokens[:, 1:], batch_decode_outputs_logits_tuple, normalize_logits=True)
            # Zero out the scores at padded candidate positions.
            batch_decode_transition_scores_masked = batch_decode_transition_scores * tokenized_candidates.attention_mask[i:i+batch_size]
            # Average over non-padded positions, skipping position 0 of each candidate
            # (presumably a leading special token added by the tokenizer -- TODO confirm).
            batch_decode_sequence_score = batch_decode_transition_scores_masked[:, 1:].sum(dim=-1)/tokenized_candidates.attention_mask[i:i+batch_size, 1:].sum(dim=-1)
            scores.extend(batch_decode_sequence_score.cpu().detach().numpy().tolist())
    # Pair each candidate with its score and rank best-first.
    scores = [(candidate, score) for candidate, score in zip(all_candidates, scores)]
    scores.sort(key=lambda x: x[1], reverse=True)
    return scores

In [None]:
def calculate_all_sequence_scores(summaries, score_model, score_tokenizer, batch_size, num_latents, std_scale, num_beams):
    """Score every example's generated summaries and dump the ranked results to JSON.

    Args:
        summaries: Iterable of ``(source_code, target, predicted)`` triples, where
            ``predicted`` is the list of candidate summaries for that source.
        score_model: Model forwarded to ``score_model_sequence_scoring``.
        score_tokenizer: Tokenizer forwarded to ``score_model_sequence_scoring``.
        batch_size: Candidates scored per forward pass.
        num_latents: Used only to build the output file name.
        std_scale: Used only to build the output file name.
        num_beams: Used only to build the output file name.

    Returns:
        A list of ``[source_code, target, ranked_candidates]`` entries, where
        ``ranked_candidates`` is what ``score_model_sequence_scoring`` returns.

    Note:
        The output folder is derived from the module-level ``lang`` variable.
    """
    all_sequence_scores = []
    for source_code, target, predicted in tqdm(summaries):
        ranked_candidates = score_model_sequence_scoring(source_code, predicted, score_model, score_tokenizer, batch_size)
        all_sequence_scores.append([source_code, target, ranked_candidates])

    output_dir = os.path.join("T5_results", "generated", lang)
    # Make sure the folder exists so the results of the scoring pass are not lost at write time.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"num_latents_{num_latents}_std_scale_{std_scale}_beam_{num_beams}_hypotheses_with_seq_score_all.json")
    # The with-block closes the file even if json.dump raises.
    with open(output_path, 'w') as all_ref_fw:
        json.dump(all_sequence_scores, all_ref_fw, indent=4)
    return all_sequence_scores

In [None]:
# Re-score every sampled summary with the backbone model.
# num_latents / std_scale / num_beams only name the output file and should match the generation run.
all_sequence_scores = calculate_all_sequence_scores(
    all_summaries_dict,
    backbone_model,
    tokenizer,
    batch_size=100,
    num_latents=100,
    std_scale=5,
    num_beams=4,
)

In [None]:
def compute_pairwise_bleu(summaries):
    """Compute symmetric pairwise smoothed-BLEU matrices for a list of summaries.

    BLEU is evaluated once per unordered pair of *distinct* summary strings and the
    result is then expanded to the full (possibly duplicated) list.

    Returns:
        A tuple ``(distinct_pairwise_bleu_scores, summary_to_id, pairwise_bleu)``:
        the matrix over distinct summaries, the mapping from summary string to its
        row/column in that matrix, and the matrix over ``summaries`` in input order.
        Both matrices have 1.0 on the diagonal.
    """
    unique_summaries = list(set(summaries))
    # assign a unique id to each distinct summary
    summary_to_id = {}
    for uid, text in enumerate(unique_summaries):
        summary_to_id[text] = uid

    # BLEU between every pair of distinct summaries (upper triangle, mirrored).
    n_unique = len(unique_summaries)
    distinct_pairwise_bleu_scores = np.zeros((n_unique, n_unique))
    for a in range(n_unique):
        distinct_pairwise_bleu_scores[a, a] = 1.0
        reference = [unique_summaries[a].split()]
        for b in range(a + 1, n_unique):
            hypothesis = unique_summaries[b].split()
            distinct_pairwise_bleu_scores[a, b] = compute_bleu([reference], [hypothesis], smooth=True)[0]
            distinct_pairwise_bleu_scores[b, a] = distinct_pairwise_bleu_scores[a, b]

    # Expand to the full list by looking up each pair's distinct ids.
    n_total = len(summaries)
    row_ids = [summary_to_id[text] for text in summaries]
    pairwise_bleu = np.zeros((n_total, n_total))
    pairwise_bleu += np.eye(n_total) * 1.0
    for a in range(n_total):
        for b in range(a + 1, n_total):
            pairwise_bleu[a, b] = distinct_pairwise_bleu_scores[row_ids[a], row_ids[b]]
            pairwise_bleu[b, a] = pairwise_bleu[a, b]

    return distinct_pairwise_bleu_scores, summary_to_id, pairwise_bleu
    

In [None]:
# For every example, compute the pairwise BLEU matrices over its candidate summaries.
pairwise_diversity_scores = []
for source_code, target, predicted_and_score in tqdm(all_sequence_scores):
    # Keep only the first element (the candidate text) of each scored entry.
    predicted = [item[0] for item in predicted_and_score]
    pairwise_diversity_scores.append(compute_pairwise_bleu(predicted))

In [None]:
import didppy as dp
# bi-criteria subset selection
def find_optimal_pair(h_arr, g_arr, alpha=1, beta=1, count=1, return_set=True, return_score=False):
    """Return the index pair(s) with the highest combined pairwise + individual score.

    The score of a pair ``(i, j)`` is ``beta * h_arr[i, j] + alpha * g_arr[i] + alpha * g_arr[j]``.
    Only pairs from the strict lower triangle (``i > j``) are considered, visited from the
    highest score downwards; pairs sharing a score are taken together, in row-major order.

    Args:
        h_arr: Square, symmetric 2-D array (or nested sequence) of pairwise values.
        g_arr: 1-D array (or sequence) of per-item values.
        alpha: Weight of the per-item term; 0 disables it.
        beta: Weight of the pairwise term; 0 disables it.
        count: Number of pairs wanted.
        return_set: If True each pair is a set ``{i, j}``, otherwise a list ``[i, j]``.
        return_score: If True each result is ``(score, pair)`` instead of ``pair``.

    Returns:
        A single pair when ``count == 1``, otherwise a list with at most ``count`` pairs.

    Raises:
        ValueError: If ``count`` is negative, if ``alpha`` and ``beta`` are both zero, or if
            the combined score array is not 2-D and symmetric.
        IndexError: If ``count == 1`` and there is no pair to choose from (fewer than two items).
    """
    if count < 0:
        raise ValueError("Count must be non-negative")
    if alpha == 0 and beta == 0:
        raise ValueError('This does not make sense! Alpha and Beta are both zeros.')

    h_arr = np.asarray(h_arr)
    # Accumulate in float: np.zeros_like(h_arr) would inherit an integer dtype from h_arr
    # and the in-place additions of float terms below would then raise a casting error.
    arr = np.zeros(h_arr.shape, dtype=float)
    if beta != 0:
        arr += beta*h_arr
    if alpha != 0:
        g_arr = np.asarray(g_arr)
        g_arr_i = g_arr[:, np.newaxis]
        g_arr_j = g_arr[np.newaxis, :]
        arr += alpha*g_arr_i + alpha*g_arr_j
    # Ensure the array is 2D and symmetric
    if arr.ndim != 2 or not np.allclose(arr, arr.T):
        raise ValueError(f"Array must be 2D and symmetric, but array is: {arr}")

    # Mask the upper triangle including the diagonal, leaving each unordered pair once.
    upper_triangle_mask = np.triu(np.ones_like(arr, dtype=bool))
    masked_arr = np.ma.masked_array(arr, mask=upper_triangle_mask)
    unique_values = np.unique(masked_arr.compressed())[::-1]  # Sorted in descending order

    max_indices = []
    rank = 0  # Position in unique_values (kept distinct from the pair indices below)
    while len(max_indices) < count and rank < len(unique_values):
        # Find positions of the current max value
        current_max_value = unique_values[rank]
        current_positions = np.ma.where(masked_arr == current_max_value)
        # Combine the row and column indices into pairs
        current_max_indices = [{row, col} if return_set else [row, col]
                               for row, col in zip(current_positions[0], current_positions[1])]
        if return_score:
            max_indices.extend([(current_max_value, idx) for idx in current_max_indices])
        else:
            max_indices.extend(current_max_indices)
        rank += 1

    # Ensure we do not exceed the count
    if count == 1:
        max_indices = max_indices[0]
    else:
        max_indices = max_indices[:count]
    return max_indices

def HDBS(diversity_matrix, cost_dict, summary_to_id, alpha=1, beta=1, criteria='sum',
         k=5, beam_width=256, seed=42, threads=64, initialize=False, parallelization='hdbs1'):
    """Select ``k`` items with a beam search over a set-valued DP model.

    A ``dp.Model`` with a single set variable ``selected`` is built; every
    transition ``select_{j}`` adds item ``j`` to the set, and the base case is
    reached once the set holds ``k`` items. The model is solved with ``dp.CABS``
    using a fixed beam size and a one second time limit.
    NOTE(review): ``dp`` is not imported in this block; it is assumed to be
    bound elsewhere in the notebook to a DIDPPy build that exposes
    ``BeamParallelizationMethod`` -- confirm.

    Args:
        diversity_matrix: Square array indexed by item id. It is replaced by
            ``1 - diversity_matrix`` before use (presumably the input holds
            similarities -- TODO confirm with the caller).
        cost_dict: Maps each summary to its score.
        summary_to_id: Maps each summary to its row/column id.
        alpha: Weight of the per-item score term.
        beta: Weight of the pairwise term.
        criteria: ``'sum'`` adds up per-step gains; ``'min'`` keeps the minimum
            over steps (bottleneck objective).
        k: Number of items to select.
        beam_width: Used as both the initial and the maximum beam size. With
            ``criteria='min'`` a value of 1 forces both members of the best
            pair found by ``find_optimal_pair`` to be picked first.
        seed: Seed passed to ``np.random.seed``.
        threads: Thread count handed to the solver.
        initialize: If True, one member of the best pair found by
            ``find_optimal_pair`` is forced as the first selection.
        parallelization: ``'hdbs1'``, ``'hdbs2'``; anything else selects ``Sbs``.

    Returns:
        If there are at most ``k`` summaries: ``(ids, scores)`` for all of them,
        where ``scores`` is a list. Otherwise ``(ids, value)`` where ``ids`` are
        the selected item ids in transition order and ``value`` is
        ``solution.cost / 100000``.

    Raises:
        ValueError: If more than ``k`` summaries are given and ``criteria`` is
            neither ``'sum'`` nor ``'min'``.
    """
    # Nothing to choose from: hand back everything.
    if len(summary_to_id) <= k:
        ids = [summary_to_id[summary] for summary in summary_to_id]
        scores = [cost_dict[summary] for summary in summary_to_id]
        return ids, scores

    # Fail early instead of building a model without any transitions.
    if criteria not in ('sum', 'min'):
        raise ValueError(f"Unknown criteria {criteria!r}; expected 'sum' or 'min'")

    np.random.seed(seed)
    diversity_matrix = 1 - diversity_matrix
    cost_vector = np.zeros(len(cost_dict))
    for summary in cost_dict:
        cost_vector[summary_to_id[summary]] = cost_dict[summary]

    model = dp.Model(maximize=True, float_cost=True)
    n = diversity_matrix.shape[0]

    selection = model.add_object_type(number=n)
    selected = model.add_set_var(object_type=selection, target=[], name="selected")

    cst = model.add_float_table(cost_vector)
    div = model.add_float_table(diversity_matrix)

    # Constant used by the 'min' seeding cost and to rescale the returned value.
    scale = 100000

    def sum_cost(j):
        # Gain of adding j under the 'sum' objective.
        return beta*div[j, selected] + alpha*cst[j] + dp.FloatExpr.state_cost()

    def min_cost(j):
        # Bottleneck value after adding j under the 'min' objective.
        return dp.min(alpha*dp.min(cst.min(selected), cst[j]) + beta*dp.min(div.min(selected, selected),
                                                                          div.min(j, selected)),
                      dp.FloatExpr.state_cost())

    def seed_cost():
        # Cost attached to seeding transitions under the 'min' objective.
        return dp.min(alpha*scale + beta*scale,
                      dp.FloatExpr.state_cost() + alpha*scale + beta*scale)

    def add_select(j, cost, forced=False, prefix="select", size_conditions=None):
        # Register one "add item j" transition; j must not be selected yet.
        if size_conditions is None:
            size_conditions = [selected.len() < k]
        transition = dp.Transition(
            name=f"{prefix}_{j}",
            cost=cost,
            effects=[(selected, selected.add(j))],
            preconditions=size_conditions + [selected.complement().contains(j)],
        )
        model.add_transition(transition, forced=forced)

    if criteria == 'sum':
        if initialize:
            p1, p2 = find_optimal_pair(diversity_matrix,
                                       cost_vector,
                                       alpha=alpha, beta=beta)
            first_selection = True
            for j in range(n):
                # Parenthesised on purpose: without the parentheses `and` binds
                # tighter than `or`, and since p1/p2 come out of a set in
                # arbitrary order, one or two transitions ended up forced.
                force = (j == p1 or j == p2) and first_selection
                add_select(j, sum_cost(j), forced=force)
                if force:
                    first_selection = False
        else:
            for j in range(n):
                add_select(j, sum_cost(j))
    else:  # criteria == 'min'
        if beam_width == 1:
            p1, p2 = find_optimal_pair(diversity_matrix,
                                       cost_vector,
                                       alpha=alpha, beta=beta)
            for j in range(n):
                if j == p1 or j == p2:
                    add_select(j, seed_cost(), forced=True)
                else:
                    add_select(j, min_cost(j))
        elif initialize:
            p1, p2 = find_optimal_pair(diversity_matrix,
                                       cost_vector,
                                       alpha=alpha, beta=beta)
            first_selection = True
            for j in range(n):
                # Same precedence fix as in the 'sum' branch.
                if (j == p1 or j == p2) and first_selection:
                    add_select(j, seed_cost(), forced=True)
                    first_selection = False
                else:
                    add_select(j, min_cost(j))
        else:
            for j in range(n):
                # "selectf" may only fire on the empty set, "select" only afterwards.
                add_select(j, seed_cost(), prefix="selectf",
                           size_conditions=[selected.len() == 0])
                add_select(j, min_cost(j),
                           size_conditions=[selected.len() < k, selected.len() > 0])

    model.add_base_case([selected.len() == k])
    if parallelization == 'hdbs1':
        method = dp.BeamParallelizationMethod.Hdbs1
    elif parallelization == 'hdbs2':
        method = dp.BeamParallelizationMethod.Hdbs2
    else:
        method = dp.BeamParallelizationMethod.Sbs
    solver = dp.CABS(model,
                     parallelization_method=method,
                     initial_beam_size=beam_width,
                     max_beam_size=beam_width,
                     threads=threads,
                     quiet=True,
                     time_limit=1)
    solution = solver.search()
    # Transition names are "select_<id>" or "selectf_<id>"; recover the id.
    beam_result = [int(t.name.split('_')[1]) for t in solution.transitions]
    # NOTE(review): the division is applied for 'sum' as well, although only the
    # 'min' costs above contain the 100000 constant -- confirm this is intended.
    return beam_result, solution.cost/scale

In [None]:
def _oracle_per_example(scores, group_sizes):
    """Return the best score inside each example's group of hypotheses.

    ``scores`` is a flat iterable ordered example by example and
    ``group_sizes[i]`` is the number of consecutive entries belonging to
    example ``i``. Groups are zero-padded to a common length before the
    row-wise maximum is taken.
    """
    score_iter = iter(scores)
    grouped = [list(islice(score_iter, size)) for size in group_sizes]
    padded = np.array(list(zip_longest(*grouped, fillvalue=0))).T
    return np.max(padded, axis=1)


def bicriteria_subset_selection(summaries, num_distinct_summary_list, std_scale, pairwise_diversity_scores, alpha=1, beta=1):
    """Pick a subset of candidate summaries per example and log diversity/oracle metrics.

    For every subset size in ``num_distinct_summary_list`` the function selects
    summaries for each example (all distinct candidates when there are fewer
    than the subset size, otherwise the ids returned by ``HDBS``), then computes
    distinct-n ratios, self-BLEU and per-example best ("oracle") BLEU, ROUGE-L,
    METEOR and BERTScore. One result line per subset size is appended to a text
    file under ``T5_results`` and printed.

    Args:
        summaries: Sequence of ``(source_code, target, model_scores)`` triples,
            where ``model_scores`` yields ``(summary, score)`` pairs.
        num_distinct_summary_list: Subset sizes to evaluate.
        std_scale: Only written to the result line.
        pairwise_diversity_scores: Per example, a triple whose first two items
            are the pairwise matrix and the ``summary -> id`` mapping passed on
            to ``HDBS``.
        alpha: Score weight forwarded to ``HDBS``.
        beta: Pairwise weight forwarded to ``HDBS``.

    Returns:
        None.

    Note:
        Reads the notebook-level global ``lang`` for the output file name and
        uses ``TextDataset`` defined elsewhere in the notebook.
    """
    # The metric object does not depend on the subset size; load it once.
    bertscore = load("bertscore")
    # Appending fails if the directory is missing, so make sure it exists.
    os.makedirs("T5_results", exist_ok=True)
    results_path = os.path.join("T5_results", f"CodeT5+_VPT_bicriteria_subset_selection_results_{lang}.txt")

    for num_distinct_summary in num_distinct_summary_list:
        # keep at most num_distinct_summary distinct summaries
        distinct_summaries = []
        distinct_guesses = []
        repeated_targets = []
        separated_hypotheses = []
        for i, summaries_lst in enumerate(tqdm(summaries)):
            _source_code, target, model_scores = summaries_lst
            distinct_div_matrix, summary_to_id, _ = pairwise_diversity_scores[i]

            model_scores = [(summary, score) for summary, score in model_scores]
            # De-duplicate (summary, score) pairs and rank them by score.
            model_scores_distinct = sorted(set(model_scores), key=lambda x: x[1], reverse=True)
            if len(model_scores_distinct) < num_distinct_summary:
                distinct_sum = [summary for summary, _ in model_scores_distinct]
            else:
                quality_vec = dict(model_scores)
                id_to_summary = {idx: summary for summary, idx in summary_to_id.items()}
                ids_to_take = HDBS(distinct_div_matrix, quality_vec, summary_to_id, alpha=alpha, beta=beta,
                                   k=num_distinct_summary, beam_width=256, seed=42)[0]
                distinct_sum = [id_to_summary[idx] for idx in ids_to_take]
            distinct_guesses.append(len(distinct_sum))
            separated_hypotheses.append([summary.strip() for summary in distinct_sum])
            distinct_summaries.extend(distinct_sum)
            repeated_targets.extend([target]*len(distinct_sum))

        # Diversity inside each example's selected set.
        average_distinct_unigrams_ratio = np.mean(
            [distinct_n_corpus_level(preds, n=1) for preds in separated_hypotheses])
        average_distinct_bigrams_ratio = np.mean(
            [distinct_n_corpus_level(preds, n=2) for preds in separated_hypotheses])
        average_self_bleu = np.mean([self_bleu_score(preds) for preds in separated_hypotheses])

        # Flat, index-keyed view: every selected summary paired with its target.
        hypotheses = dict(enumerate([[summary.strip().lower()] for summary in distinct_summaries]))
        references = dict(enumerate([[target.strip().lower()] for target in repeated_targets]))

        # calculate oracle scores
        print("average distinct guesses: ", np.mean(distinct_guesses))
        _, _bleu, ind_bleu = corpus_bleu(hypotheses, references)
        oracle_bleu = _oracle_per_example(ind_bleu.values(), distinct_guesses)

        rouge_calculator = Rouge()
        _rouge_l, ind_rouge = rouge_calculator.compute_score(references, hypotheses)
        oracle_rouge = _oracle_per_example(ind_rouge.values(), distinct_guesses)

        # METEOR is only computed when there are more than 1500 examples;
        # otherwise it is reported as 0.
        if len(summaries) > 1500:
            meteor_calculator = Meteor()
            _meteor, ind_meteor = meteor_calculator.compute_score(references, hypotheses)
            oracle_meteor = _oracle_per_example(ind_meteor, distinct_guesses)
        else:
            oracle_meteor = 0.0

        bert_precision = []
        bert_recall = []
        bert_f1 = []
        batch_size = 5120
        dataset = TextDataset(hypotheses, references)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        # Compute BERTScore batch by batch with a progress bar.
        for hypotheses_batch, references_batch in tqdm(dataloader, desc="Processing batches"):
            batch_results = bertscore.compute(predictions=hypotheses_batch,
                                              references=references_batch,
                                              lang="en")
            bert_precision.extend(list(batch_results['precision']))
            bert_recall.extend(list(batch_results['recall']))
            bert_f1.extend(list(batch_results['f1']))

        best_bert_precision = _oracle_per_example(bert_precision, distinct_guesses)
        best_bert_recall = _oracle_per_example(bert_recall, distinct_guesses)
        best_bert_f1 = _oracle_per_example(bert_f1, distinct_guesses)

        # Build the result line once; it goes to both the file and stdout.
        report = (
            f"num_distinct_summary: {num_distinct_summary} | std_scale: {std_scale:.1f}"
            f" | alpha: {alpha} | beta: {beta}"
            f" | average_distinct_guesses: {np.mean(distinct_guesses): .4f}"
            f" | average_distinct_unigrams_ratio: {average_distinct_unigrams_ratio * 100: .4f}"
            f" | average_distinct_bigrams_ratio: {average_distinct_bigrams_ratio * 100: .4f}"
            f" | average_self_bleu: {average_self_bleu * 100: .4f}"
            f" | Oracle_bleu: {np.mean(oracle_bleu) * 100: .4f}"
            f" | Oracle_rouge-l: {np.mean(oracle_rouge) * 100: .4f}"
            f" | Oracle_meteor: {np.mean(oracle_meteor) * 100: .4f}"
            f" | Oracle_bert_precision: {np.mean(best_bert_precision) * 100: .4f}"
            f" | Oracle_bert_recall: {np.mean(best_bert_recall) * 100: .4f}"
            f" | Oracle_bert_f1: {np.mean(best_bert_f1) * 100: .4f}\n"
        )
        with open(results_path, 'a') as f:
            f.write(report)
        print(report)


In [None]:
# Evaluate subsets of 10 and 20 summaries per example. The third argument
# (std_scale=5) is only written to the result line. `all_sequence_scores` and
# `pairwise_diversity_scores` must already exist from earlier cells.
# alpha (score weight) and beta (pairwise weight) should be tuned for your dataset
bicriteria_subset_selection(all_sequence_scores, [10, 20], 5, pairwise_diversity_scores, alpha=1.0, beta=0.5)