# Prompt Tuning


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from transformers import AutoConfig
from torch.utils.data import DataLoader
import os
os.environ["MPLBACKEND"] = "Agg"

import matplotlib
matplotlib.use('Agg', force=True)
from matplotlib import pyplot as plt

import random
from tqdm import tqdm
import numpy as np
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, roc_auc_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from collections import deque

In [None]:
# Prepare dataset: sec-desc.jsonl holds one JSON record per line.
# The training cells read the keys "description", "func_src_before" and
# "func_src_after" from each record.
data = []
with open("sec-desc.jsonl", "r", encoding="utf-8") as file:
    for line in file:
        # Skip blank lines (e.g. a trailing newline at end of file), which
        # json.loads would otherwise reject.
        if line.strip():
            data.append(json.loads(line))

In [None]:
# Fine-tune the embedding model through a classification head
class ModuleEmbedderHead(nn.Module):
    """Pretrained encoder topped with a small MLP that emits one logit per input.

    Args:
        embedding_model: model id or path handed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, embedding_model):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(embedding_model)
        width = self.encoder.config.hidden_size
        # hidden_size -> 256 -> 1: a raw logit; callers apply sigmoid or
        # BCEWithLogitsLoss themselves.
        self.classifier = nn.Sequential(
            nn.Linear(width, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def encoding(self, inputs):
        """Return the first-token hidden state of every sequence.

        ``inputs`` is a tokenizer output mapping, unpacked into the encoder.
        """
        return self.encoder(**inputs).last_hidden_state[:, 0]

    def classifying(self, inputs):
        """Map embeddings to raw logits through the MLP head."""
        return self.classifier(inputs)

    def forward(self, inputs):
        """Encode ``inputs`` and return logits of shape (batch, 1)."""
        return self.classifying(self.encoding(inputs))

In [None]:
#Train Classifier function
def train_classifier(model, tokenizer, dataloader, data, epochs=5, lr=2e-5, save_path="classifier.pt"):
    """Fine-tune ``model`` to separate safe from unsafe code, then save it.

    Every record in a batch yields two examples: ``func_src_after`` labelled 1
    (safe) and ``func_src_before`` labelled 0 (unsafe). The per-batch loss is
    the sum of the two BCE-with-logits terms.

    Args:
        model: module mapping tokenizer output to logits, e.g. ``ModuleEmbedderHead``.
        tokenizer: tokenizer matching ``model``'s encoder.
        dataloader: yields lists of record dicts (identity ``collate_fn``).
        data: kept for backward compatibility only. The reported average is
            now taken over ``len(dataloader)`` batches.
        epochs: number of passes over ``dataloader``.
        lr: AdamW learning rate.
        save_path: file that receives ``model.state_dict()`` after training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    bce = nn.BCEWithLogitsLoss()
    num_batches = len(dataloader)

    torch.manual_seed(42)
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch in tqdm(dataloader):
            safe = [item["func_src_after"] for item in batch]
            unsafe = [item["func_src_before"] for item in batch]

            safe_tokens = tokenizer(safe, return_tensors="pt", truncation=True, padding=True).to(device)
            unsafe_tokens = tokenizer(unsafe, return_tensors="pt", truncation=True, padding=True).to(device)

            pos_logits = model(safe_tokens)
            neg_logits = model(unsafe_tokens)

            # Target 1 for safe code, 0 for unsafe code.
            loss = bce(pos_logits, torch.ones_like(pos_logits)) + bce(neg_logits, torch.zeros_like(neg_logits))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # loss.item() is already a mean over the batch, so average over batches
        # rather than records (they differ whenever batch_size > 1).
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}: Avg Loss = {avg_loss}")

    torch.save(model.state_dict(), save_path)

In [None]:
# Hold out 20% of the records (fixed seed) from classifier training.
# NOTE(review): test_triplets is not used by any visible cell, so there is no
# held-out evaluation of the classifier here.
train_triplets, test_triplets = train_test_split(data, test_size=0.2, random_state=42)
batch_size = 1

# Identity collate: each batch stays a plain list of record dicts and
# train_classifier tokenizes them itself.
dataloader = DataLoader(
    train_triplets,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda x: x,
)

# GraphCodeBERT supplies both the tokenizer and the classifier's encoder.
embedding_model = "microsoft/graphcodebert-base"
tokenizer = AutoTokenizer.from_pretrained(embedding_model)
classifier_model = ModuleEmbedderHead(embedding_model)

# Fine-tune for 8 epochs; train_classifier writes the weights to disk.
train_classifier(
    classifier_model,
    tokenizer,
    dataloader,
    train_triplets,
    epochs=8,
)

In [None]:
# Module that learns a small number of soft-prompt vectors to be prepended to the input embeddings
class SoftPrompt(nn.Module):
    """Learnable virtual-token embeddings meant to be prepended to a model's input.

    Args:
        num_virtual_tokens: how many soft-prompt vectors to learn.
        embed_dim: width of each vector; it should match the target model's
            input-embedding size.
    """

    def __init__(self, num_virtual_tokens, embed_dim):
        super().__init__()
        # Weight matrix of shape (num_virtual_tokens, embed_dim), using
        # nn.Embedding's default random initialisation.
        self.embedding = nn.Embedding(num_virtual_tokens, embed_dim)

    def forward(self, batch_size):
        """Return the prompt tiled to (batch_size, num_virtual_tokens, embed_dim)."""
        weight = self.embedding.weight
        # Index every virtual token once per batch row, on the weights' device.
        token_ids = torch.arange(self.embedding.num_embeddings, device=weight.device)
        token_ids = token_ids.repeat(batch_size, 1)
        return self.embedding(token_ids)

In [None]:
#Create semantic memory
class Memory:
    """Bounded store of past generations, retrievable by embedding similarity.

    Each entry is a dict with keys "embedding" (detached CPU tensor), "reward",
    "description" and "code". Once ``max_size`` entries are held, adding a new
    one evicts the oldest (``deque(maxlen=...)``).
    """

    def __init__(self, max_size=5000):
        self.memory = deque(maxlen=max_size)

    def add(self, embedding, reward, description, code):
        """Store one generation; the embedding is detached and moved to CPU."""
        self.memory.append({"embedding": embedding.detach().cpu(), "reward": reward, "description": description, "code": code})

    def get_past_embeddings(self, query, top_n, min_reward=0.5):
        """Return the code of up to ``top_n`` stored entries most similar to ``query``.

        Only entries with ``reward >= min_reward`` are candidates. Similarity is
        the cosine similarity between ``query`` and each stored embedding. Both
        are treated as 1-D vectors of equal length. Results run from most to
        least similar. Returns ``[]`` when nothing qualifies or ``top_n <= 0``.
        """
        candidates = [entry for entry in self.memory if entry["reward"] >= min_reward]
        if not candidates or top_n <= 0:
            return []

        past = torch.stack([entry["embedding"] for entry in candidates])
        # Detach so retrieval never builds an autograd graph through the query.
        sims = F.cosine_similarity(query.detach().cpu().unsqueeze(0), past, dim=1)

        k = min(top_n, len(candidates))
        best = torch.topk(sims, k=k).indices.tolist()
        return [candidates[i]["code"] for i in best]

    def augment_prompt(self, description, desc_embed, top_n=3, min_reward=0.5):
        """Build a prompt of retrieved high-reward examples followed by the task.

        Args:
            description: task text. A list/tuple of strings is joined with
                newlines so it is not rendered as a Python list repr.
            desc_embed: query embedding used for retrieval.
            top_n: maximum number of past examples to include.
            min_reward: minimum stored reward for an example to be eligible.
        """
        if isinstance(description, (list, tuple)):
            description = "\n".join(str(part) for part in description)

        examples = self.get_past_embeddings(desc_embed, top_n=top_n, min_reward=min_reward)
        context = "\n".join(f"# Past sample:\n{code}\n" for code in examples)
        return context + f"\n# Task:\n{description}"


In [None]:
def _mask_pad(ids, pad_token_id):
    """Return a copy of ``ids`` with padding replaced by -100 (ignored by the LM loss)."""
    labels = ids.clone()
    if pad_token_id is not None:
        labels[labels == pad_token_id] = -100
    return labels


def train_soft_prompt(model, soft_prompt, dataloader, tokenizer, classifier, classifier_tokenizer, semantic_memory,
                      num_epochs=3, lr=1e-3, alpha=0.9, max_new_tokens=128, top_n=3):
    """Train ``soft_prompt`` for a frozen seq2seq model, using a classifier as reward.

    For each record, the description is embedded with the classifier encoder.
    That embedding retrieves high-reward past generations from
    ``semantic_memory``, which are prepended to the task text. The learnable
    soft-prompt vectors are then prepended to the token embeddings of that
    text. The per-record objective is

        loss = CE(safe target) + (reward - baseline) * NLL(sampled output)

    The first term is a supervised anchor towards the patched code
    (``func_src_after``). The second is a REINFORCE term on the model's own
    sample, whose baseline is an exponential moving average of past rewards.
    Only the soft-prompt parameters are optimised, with one step per batch on
    the mean per-record loss.

    Args:
        model: model id (str) or an already-loaded seq2seq model.
        soft_prompt: ``SoftPrompt`` whose width matches the model's embeddings.
        dataloader: yields lists of record dicts.
        tokenizer: tokenizer of the generator ``model``.
        classifier: ``ModuleEmbedderHead``-like reward model (encoding/classifying).
        classifier_tokenizer: tokenizer matching ``classifier``.
        semantic_memory: ``Memory`` used for retrieval; every sample is added to it.
        num_epochs: passes over ``dataloader``.
        lr: Adam learning rate for the soft prompt.
        alpha: EMA factor for the reward baseline.
        max_new_tokens: generation length cap (``None`` keeps the model default).
        top_n: number of past examples to retrieve per prompt.

    Returns:
        (loss_log, sim_safe_log, sim_unsafe_log, reward_log), one entry per record.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    soft_prompt.to(device)

    # Frozen generator: accept a model id or an already-loaded model.
    if isinstance(model, str):
        model = AutoModelForSeq2SeqLM.from_pretrained(model)
    model.eval().to(device)
    for p in model.parameters():
        p.requires_grad = False
    embed_tokens = model.get_input_embeddings()

    # Reward model is used for inference only; only the soft prompt is optimised.
    classifier.eval().to(device)
    optimizer = torch.optim.Adam(soft_prompt.parameters(), lr=lr)

    def classifier_embed(texts):
        # First-token classifier embeddings of `texts`, without building a graph.
        with torch.no_grad():
            enc = classifier_tokenizer(texts, return_tensors="pt", truncation=True, padding=True).to(device)
            return classifier.encoding(enc)

    def tokenize_prompt(text):
        # The task sits at the END of the augmented prompt, so truncate from
        # the left: an over-long prompt loses old examples, not the task.
        previous_side = tokenizer.truncation_side
        tokenizer.truncation_side = "left"
        try:
            return tokenizer(text, return_tensors="pt", truncation=True).to(device)
        finally:
            tokenizer.truncation_side = previous_side

    gen_kwargs = {"do_sample": True, "temperature": 0.7, "top_k": 50}
    if max_new_tokens is not None:
        gen_kwargs["max_new_tokens"] = max_new_tokens

    loss_log, sim_safe_log, sim_unsafe_log, reward_log = [], [], [], []
    baseline = 0.0  # EMA of past rewards (REINFORCE baseline)

    torch.manual_seed(42)
    for epoch in range(num_epochs):
        soft_prompt.train()
        epoch_loss = 0.0
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        for batch in tqdm(dataloader):
            example_losses = []
            for example in batch:
                description = example["description"]
                safe_code = example["func_src_after"]
                unsafe_code = example["func_src_before"]

                # Retrieval query: tokenize the description with the CLASSIFIER's
                # tokenizer, since its ids are fed to the classifier encoder.
                # NOTE(review): memory stores embeddings of generated code, so this
                # compares a description embedding against code embeddings.
                desc_embed = classifier_embed([description]).squeeze(0)
                prompt_text = semantic_memory.augment_prompt(
                    description=description, desc_embed=desc_embed, top_n=top_n)

                prompt = tokenize_prompt(prompt_text)
                target = tokenizer([safe_code], return_tensors="pt", truncation=True).input_ids.to(device)

                # [soft prompt | prompt token embeddings], attending to every position.
                soft = soft_prompt(1)
                full_embeds = torch.cat([soft, embed_tokens(prompt.input_ids)], dim=1)
                prompt_mask = torch.ones(1, soft.size(1), dtype=prompt.attention_mask.dtype, device=device)
                full_attention_mask = torch.cat([prompt_mask, prompt.attention_mask], dim=1)

                # Supervised anchor: teacher-forced CE towards the safe code.
                sup_loss = model(inputs_embeds=full_embeds, attention_mask=full_attention_mask,
                                 labels=_mask_pad(target, tokenizer.pad_token_id)).loss

                # Sample an output and score it with the classifier.
                gen_ids = model.generate(inputs_embeds=full_embeds.detach(),
                                         attention_mask=full_attention_mask, **gen_kwargs)
                text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                gen_embed = classifier_embed([text])
                with torch.no_grad():
                    reward = torch.sigmoid(classifier.classifying(gen_embed)).item()

                # Advantage uses the baseline from BEFORE this reward is folded in.
                advantage = reward - baseline
                baseline = alpha * baseline + (1 - alpha) * reward

                # REINFORCE: weight the NLL of the model's own sample by the
                # advantage, so samples scoring above baseline become more likely
                # and those below less likely. For encoder-decoder models the
                # generate() output begins with the decoder start token, which the
                # model re-adds itself when given labels, so it is dropped here.
                sample_labels = _mask_pad(gen_ids[:, 1:], tokenizer.pad_token_id)
                loss = sup_loss
                if (sample_labels != -100).any():
                    sample_nll = model(inputs_embeds=full_embeds, attention_mask=full_attention_mask,
                                       labels=sample_labels).loss
                    loss = loss + advantage * sample_nll

                # Diagnostics: closeness of the sample to the safe vs unsafe code.
                sim_safe = F.cosine_similarity(
                    gen_embed.squeeze(0), classifier_embed([safe_code]).squeeze(0), dim=0).item()
                sim_unsafe = F.cosine_similarity(
                    gen_embed.squeeze(0), classifier_embed([unsafe_code]).squeeze(0), dim=0).item()

                semantic_memory.add(embedding=gen_embed.squeeze(0), reward=reward,
                                    description=description, code=text)

                example_losses.append(loss)
                loss_log.append(loss.item())
                sim_safe_log.append(sim_safe)
                sim_unsafe_log.append(sim_unsafe)
                reward_log.append(reward)

                print(f"Loss: {loss.item()} | Reward: {reward} | Safe Sim: {sim_safe} | Unsafe Sim: {sim_unsafe}")

            if not example_losses:
                continue
            # One optimiser step per batch on the mean per-record loss.
            batch_loss = torch.stack(example_losses).mean()
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()

        print(f"Avg {epoch+1} Epoch Loss: {epoch_loss / max(len(dataloader), 1)}")

    return loss_log, sim_safe_log, sim_unsafe_log, reward_log


In [None]:
# Train the prompt tuner: frozen Flan-T5-small generator, GraphCodeBERT reward model.
model_name = "google/flan-t5-small"

tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

num_soft_tokens = 10
num_epochs = 20
batch_size = 1
# NOTE(review): this iterates over the full dataset, including the records
# held out from classifier training (test_triplets).
dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)

# The soft-prompt width must match the generator's embedding size (d_model for T5).
soft_prompt = SoftPrompt(num_soft_tokens, config.d_model)

# Reward model: the classifier fine-tuned earlier. map_location lets a
# checkpoint saved from a GPU run load on a CPU-only machine.
classifier = ModuleEmbedderHead("microsoft/graphcodebert-base")
classifier.load_state_dict(torch.load("classifier.pt", map_location="cpu"))
classifier.eval()
classifier_tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")

semantic_memory = Memory()

loss_log, sim_safe_log, sim_unsafe_log, reward_log = train_soft_prompt(
    model_name, soft_prompt, dataloader, tokenizer, classifier, classifier_tokenizer,
    semantic_memory=semantic_memory, num_epochs=num_epochs, alpha=0.9)

torch.save(soft_prompt.state_dict(), "prompt_tuner.pt")


In [None]:
# Plot the per-step training logs on a 2x2 grid and save them to training_stats.png.
steps = list(range(1, len(loss_log) + 1))
plt.figure(figsize=(12, 8))

# (subplot index, series, colour, y-axis label, title) for each panel.
step_panels = [
    (1, loss_log, "steelblue", "Loss", "REINFORCE Loss"),
    (2, reward_log, "#FFDAB9", "Probability of safe code", "Reward Over Time"),
    (3, sim_safe_log, "slategray", "Safe Similarity", "Similarity to Safe Code"),
    (4, sim_unsafe_log, "#D7BDE2", "Unsafe Similarity", "Similarity to Unsafe Code"),
]
for panel_idx, series, colour, y_label, panel_title in step_panels:
    plt.subplot(2, 2, panel_idx)
    plt.plot(steps, series, linestyle="-", marker="o", markersize=4, color=colour)
    plt.xlabel("Step")
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.grid()

plt.tight_layout()
plt.savefig("training_stats.png")
plt.show()

In [None]:
# Average each per-step log over epochs and plot the result (epochwise_metrics.png).
# np.array_split cuts each log into num_epochs contiguous chunks. These match
# the epoch boundaries only when every epoch logged the same number of steps.
# NOTE(review): keep num_epochs in sync with the value used for training.
num_epochs = 20

avg_loss, avg_reward, avg_sim_safe, avg_sim_unsafe = (
    [np.mean(chunk) for chunk in np.array_split(series, num_epochs)]
    for series in (loss_log, reward_log, sim_safe_log, sim_unsafe_log)
)

epoch_axis = range(1, num_epochs + 1)
plt.figure(figsize=(12, 8))

# (subplot index, series, colour, title, y-axis label) for each panel.
epoch_panels = [
    (1, avg_loss, "steelblue", "Avg Loss per Epoch", "Avg Loss"),
    (2, avg_reward, "#FFDAB9", "Avg Classifier Reward per Epoch", "Avg Reward"),
    (3, avg_sim_safe, "slategray", "Avg Safe Similarity per Epoch", "Cosine Similarity to Safe Code"),
    (4, avg_sim_unsafe, "#D7BDE2", "Avg Unsafe Similarity per Epoch", "Cosine Similarity to Unsafe Code"),
]
for panel_idx, series, colour, panel_title, y_label in epoch_panels:
    plt.subplot(2, 2, panel_idx)
    plt.plot(epoch_axis, series, marker='o', color=colour)
    plt.title(panel_title)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.grid(True)

plt.tight_layout()
plt.savefig("epochwise_metrics.png")
plt.show()

In [None]:

# Visualise the semantic memory: a 2-D t-SNE of the stored embeddings,
# coloured by min-max-normalised classifier reward. Saved to TSNE.png.
entries = list(semantic_memory.memory)
if len(entries) < 2:
    print("Not enough entries in semantic memory for t-SNE.")
else:
    # (N, hidden) matrix of embeddings and (N,) vector of rewards.
    emb = torch.stack([entry["embedding"] for entry in entries]).float().numpy()
    rew = np.array([entry["reward"] for entry in entries], dtype=float)

    # Normalise rewards to [0, 1]; epsilon avoids division by zero when all are equal.
    rew_norm = (rew - rew.min()) / (rew.max() - rew.min() + 1e-8)

    # t-SNE requires perplexity < n_samples, so cap it for small memories.
    tsne = TSNE(n_components=2, perplexity=min(30, len(entries) - 1), random_state=42)
    result = tsne.fit_transform(emb)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(result[:, 0], result[:, 1], c=rew_norm, cmap='viridis', s=20)
    plt.colorbar(scatter, label="Normalized Reward")
    plt.title("Visualization of Semantic Memory Embeddings")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.grid()
    plt.savefig("TSNE.png", dpi=300, bbox_inches='tight')
    plt.show()