In [33]:
import torch
from torch.nn import functional
from torch.cuda.amp import autocast
import random
from functools import partial
from random import choice
import numpy as np
from numpy.random import choice as np_choice
import pandas as pd
import gc
import time
from typing import List
from copy import deepcopy
import kagglehub
import os

import csv
from time import time

from transformers import AutoTokenizer, T5Tokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer

assert torch.cuda.is_available()

torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

In [21]:
EPOCHS = 100
BATCH_SIZE = 32
LEARNING_RATE = 1e-5

max_new_tokens = 30
max_sentences_in_response = 1

# Dataset

In [7]:
# Prompt_dataset
path = kagglehub.dataset_download("what5up/concat-prompts")
files = os.listdir(path)
csv_file = [file for file in files if file.endswith('.csv')][0]
data = pd.read_csv(os.path.join(path, csv_file))
prompt_dataset = data['prompt'].tolist()

Downloading from https://www.kaggle.com/api/v1/datasets/download/what5up/concat-prompts?dataset_version_number=1...


100%|██████████| 157k/157k [00:00<00:00, 355kB/s]

Extracting files...





In [8]:
# Vocabulary dataset
tokenizer = T5Tokenizer.from_pretrained("sentence-transformers/sentence-t5-base")
vocabulary_dataset = tokenizer.get_vocab()
vocabulary_dataset.pop('<pad>', None)
vocabulary_dataset.pop('</s>', None)
vocabulary_dataset.pop('<unk>', None)
vocabulary_dataset = list(vocabulary_dataset.keys())

In [9]:
print(prompt_dataset[:5])
print(vocabulary_dataset[:5])

['Convert the text into a vintage circus poster announcement', "Convert the text into a social media platform's community guidelines", 'Rewrite this as a college course description.', 'Rephrase this as a debate on furniture rights, featuring chairs.', "Make the text into a home improvement expert's tips for a bathroom remodel"]
['▁', 'X', '.', ',', 's']


# Mean Prompt

In [34]:
def cosine_similarity(v, u):
    v_norm = functional.normalize(v, p=2, dim=-1)
    u_norm = functional.normalize(u, p=2, dim=-1)
    similarity = torch.matmul(v_norm, u_norm.T)
    return similarity

In [36]:
def hotflip(token_embeddings, token_index, gradient, embedding_matrix):
    inner_product = torch.matmul(embedding_matrix, gradient[token_index].T)
    vocab_index = inner_product.argmax().item()
    token_embeddings_clone = token_embeddings.clone()
    token_embeddings_clone[token_index] = embedding_matrix[vocab_index]
    return token_embeddings_clone

In [10]:
# Model
model_name = "sentence-transformers/sentence-t5-base"
model = SentenceTransformer(model_name)
for param in model.parameters():
    param.requires_grad = False

modules.json:   0%|          | 0.00/461 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/1.98k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.39k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/219M [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

2_Dense/config.json:   0%|          | 0.00/115 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.36M [00:00<?, ?B/s]

rust_model.ot:   0%|          | 0.00/2.36M [00:00<?, ?B/s]

In [30]:
# Initial prompt
prompt = "Rewrite this text to make it more helpful"
prompt_tokens = prompt.split(' ')
token_embeddings = model.encode(prompt_tokens)
token_embeddings = torch.tensor(token_embeddings, dtype=torch.float32, requires_grad=True)

# Embedding matrix
embedding_matrix = model.encode(vocabulary_dataset)
embedding_matrix = torch.tensor(embedding_matrix, dtype=torch.float32)

# Target
target_embeddings = model.encode(prompt_dataset)
target_embeddings = torch.tensor(target_embeddings, dtype=torch.float32).mean(dim=0).unsqueeze(0)

print(token_embeddings.shape)
print(embedding_matrix.shape)
print(target_embeddings.shape)

torch.Size([8, 768])
torch.Size([32097, 768])
torch.Size([1, 768])


In [37]:
for epoch in range(EPOCHS):
    for token_index in range(len(prompt_tokens)):
        if token_embeddings.grad is not None:
            token_embeddings.grad.zero_()

        loss =  1 - cosine_similarity(target_embeddings, token_embeddings).mean()
        loss.backward()
        gradient = token_embeddings.grad.clone()

        token_embeddings = hotflip(token_embeddings, token_index, gradient, embedding_matrix)
        token_embeddings = token_embeddings.detach().requires_grad_(True)

    print(f"Epoch: {epoch}, Loss: {loss.item()}")
    # if epoch % 100 == 0:
    #     print(f"Epoch: {epoch}, Loss: {loss.item()}")

Epoch: 0, Loss: 0.15250426530838013
Epoch: 1, Loss: 0.15250426530838013
Epoch: 2, Loss: 0.15250426530838013
Epoch: 3, Loss: 0.15250426530838013
Epoch: 4, Loss: 0.15250426530838013
Epoch: 5, Loss: 0.15250426530838013
Epoch: 6, Loss: 0.15250426530838013
Epoch: 7, Loss: 0.15250426530838013
Epoch: 8, Loss: 0.15250426530838013
Epoch: 9, Loss: 0.15250426530838013
Epoch: 10, Loss: 0.15250426530838013
Epoch: 11, Loss: 0.15250426530838013
Epoch: 12, Loss: 0.15250426530838013
Epoch: 13, Loss: 0.15250426530838013
Epoch: 14, Loss: 0.15250426530838013
Epoch: 15, Loss: 0.15250426530838013
Epoch: 16, Loss: 0.15250426530838013
Epoch: 17, Loss: 0.15250426530838013
Epoch: 18, Loss: 0.15250426530838013
Epoch: 19, Loss: 0.15250426530838013
Epoch: 20, Loss: 0.15250426530838013
Epoch: 21, Loss: 0.15250426530838013
Epoch: 22, Loss: 0.15250426530838013
Epoch: 23, Loss: 0.15250426530838013
Epoch: 24, Loss: 0.15250426530838013
Epoch: 25, Loss: 0.15250426530838013
Epoch: 26, Loss: 0.15250426530838013
Epoch: 27, 