In [1]:
import torch
from torch.nn import functional
from torch.cuda.amp import autocast
import random
from functools import partial
from random import choice
import numpy as np
from numpy.random import choice as np_choice
import pandas as pd
import gc
import time
from typing import List
from copy import deepcopy
import kagglehub
import os

import csv
from time import time

from transformers import AutoTokenizer, T5Tokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer

assert torch.cuda.is_available()

torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

In [23]:
EPOCHS = 1
BATCH_SIZE = 32
LEARNING_RATE = 1e-5

max_new_tokens = 30
max_sentences_in_response = 1

# Dataset

In [4]:
# Prompt_dataset
path = kagglehub.dataset_download("what5up/concat-prompts")
files = os.listdir(path)
csv_file = [file for file in files if file.endswith('.csv')][0]
data = pd.read_csv(os.path.join(path, csv_file))
prompt_dataset = data['prompt'].tolist()

Downloading from https://www.kaggle.com/api/v1/datasets/download/what5up/concat-prompts?dataset_version_number=1...


100%|██████████| 157k/157k [00:00<00:00, 498kB/s]

Extracting files...





In [5]:
# Vocabulary dataset
tokenizer = T5Tokenizer.from_pretrained("sentence-transformers/sentence-t5-base")
vocabulary_dataset = tokenizer.get_vocab()
vocabulary_dataset.pop('<pad>', None)
vocabulary_dataset.pop('</s>', None)
vocabulary_dataset.pop('<unk>', None)
vocabulary_dataset = list(vocabulary_dataset.keys())

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.92k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [6]:
print(prompt_dataset[:5])
print(vocabulary_dataset[:5])

['Convert the text into a vintage circus poster announcement', "Convert the text into a social media platform's community guidelines", 'Rewrite this as a college course description.', 'Rephrase this as a debate on furniture rights, featuring chairs.', "Make the text into a home improvement expert's tips for a bathroom remodel"]
['▁', 'X', '.', ',', 's']


# Mean Prompt

In [7]:
def cosine_similarity(v, u):
    v_norm = functional.normalize(v, p=2, dim=-1)
    u_norm = functional.normalize(u, p=2, dim=-1)
    similarity = torch.matmul(v_norm, u_norm.T)
    return similarity

In [19]:
def hotflip(token_embeddings, token_index, embedding_matrix, prompt_tokens, vocabulary_dataset):
    simularities = []
    token_embeddings_clone = token_embeddings.clone()

    for vocab in embedding_matrix:
        token_embeddings_clone[token_index] = vocab
        simularity =  cosine_similarity(target_embeddings, token_embeddings_clone).mean()
        simularities.append(simularity)

    simularities = torch.tensor(simularities, dtype=torch.float32)
    vocab_index = simularities.argmax().item()

    token_embeddings_clone[token_index] = embedding_matrix[vocab_index]
    prompt_tokens[token_index] = vocabulary_dataset[vocab_index]

    return token_embeddings_clone, prompt_tokens

In [9]:
# Model
model_name = "sentence-transformers/sentence-t5-base"
model = SentenceTransformer(model_name)
for param in model.parameters():
    param.requires_grad = False

modules.json:   0%|          | 0.00/461 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/1.98k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.39k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/219M [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.36M [00:00<?, ?B/s]

rust_model.ot:   0%|          | 0.00/2.36M [00:00<?, ?B/s]

2_Dense/config.json:   0%|          | 0.00/115 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.36M [00:00<?, ?B/s]

In [20]:
# Initial prompt
prompt = "Rewrite this text to make it more helpful"
prompt_tokens = prompt.split(' ')
token_embeddings = model.encode(prompt_tokens)
token_embeddings = torch.tensor(token_embeddings, dtype=torch.float32, requires_grad=True)

# Embedding matrix
embedding_matrix = model.encode(vocabulary_dataset)
embedding_matrix = torch.tensor(embedding_matrix, dtype=torch.float32)

# Target
target_embeddings = model.encode(prompt_dataset)
target_embeddings = torch.tensor(target_embeddings, dtype=torch.float32).mean(dim=0).unsqueeze(0)

print(token_embeddings.shape)
print(embedding_matrix.shape)
print(target_embeddings.shape)

torch.Size([8, 768])
torch.Size([32097, 768])
torch.Size([1, 768])


In [22]:
magic = ['lucrarea']
magic = model.encode(magic)
magic = torch.tensor(magic, dtype=torch.float32)
1 - cosine_similarity(target_embeddings, magic).mean()

tensor(0.1244)

In [24]:
for epoch in range(EPOCHS):
    for token_index in range(len(prompt_tokens)):
        loss =  1 - cosine_similarity(target_embeddings, token_embeddings).mean()
        token_embeddings, prompt_tokens = hotflip(token_embeddings, token_index, embedding_matrix, prompt_tokens, vocabulary_dataset)
        token_embeddings = token_embeddings.detach().requires_grad_(True)

    print(f"Epoch: {epoch}, Loss: {loss.item()}")
    print(' '.join(prompt_tokens))

Epoch: 0, Loss: 0.10771042108535767
▁summarize ▁summarize ▁summarize ▁summarize ▁summarize ▁summarize ▁summarize ▁summarize
