# Setup

In [1]:
from transformers import T5EncoderModel, T5ForConditionalGeneration, T5Config, AutoModelForCausalLM, \
    AutoTokenizer, T5Tokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, DataCollatorWithPadding, get_scheduler, \
    EarlyStoppingCallback, IntervalStrategy, TrainerCallback, LlamaConfig
from peft import  get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit, PeftModel
import schedulefree
import torch
import datasets
import pandas as pd
import os
import random
from IPython import embed
import numpy as np



In [2]:
working_dir = "./"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
class ConstrainedAdamW(torch.optim.AdamW):
    """
    A variant of Adam where some of the parameters are constrained to have unit norm.
    """
    def __init__(self, params, constrained_params, lr, scale_factor=1., weight_decay=0.01):
        super().__init__(params, lr=lr, weight_decay=weight_decay)
        self.constrained_params = list(constrained_params)
        self.scale_factor = scale_factor
    
    def step(self, closure=None):
        with torch.no_grad():
            for p in self.constrained_params:
                normed_p = p / p.norm() * self.scale_factor  # p.norm(dim=0, keepdim=True)
                # project away the parallel component of the gradient
                p.grad -= (p.grad * normed_p).sum(dim=-1, keepdim=False) * normed_p  # (dim=0, keepdim=True)
        super().step(closure=closure)
        with torch.no_grad():
            for p in self.constrained_params:
                # renormalize the constrained parameters
                p /= p.norm() * self.scale_factor  # p.norm(dim=0, keepdim=True)

class ConstrainedAdamWScheduleFree(schedulefree.AdamWScheduleFree):
    """
    A variant of Adam where some of the parameters are constrained to have unit norm.
    """
    def __init__(self, params, constrained_params, lr, warmup_steps=100):
        super().__init__(params, lr=lr, warmup_steps=warmup_steps)
        self.constrained_params = list(constrained_params)
    
    def step(self, closure=None):
        with torch.no_grad():
            for p in self.constrained_params:
                normed_p = p / p.norm() # p.norm(dim=0, keepdim=True)
                # project away the parallel component of the gradient
                p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p
        super().step(closure=closure)
        with torch.no_grad():
            for p in self.constrained_params:
                # renormalize the constrained parameters
                p /= p.norm() # p.norm(dim=0, keepdim=True)

In [20]:
def get_tokenized_dataset(tokenizer, shuffle_input=False, constant_input=None, seed=17):
    # Load dataset
    raw_dataset = [{'prompt': 'I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show. \
I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type \
commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets {like this}. \
my first command is pwd'},
              {'prompt': 'I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show. \
I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type \
commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets {like this}. \
my first command is chmod',
               }]
    hf_dataset = datasets.Dataset.from_pandas(pd.DataFrame(data=raw_dataset))
    if not MODEL_NAME.startswith('t5'):
        hf_dataset = hf_dataset.map(lambda x: tokenizer(" ".join(x["prompt"].split()[:-1]), text_target=x["prompt"].split()[-1]))
        if shuffle_input:
            random.seed(seed)
            hf_dataset = hf_dataset.map(lambda x: {**x, "input_ids": random.sample(x["input_ids"], len(x["input_ids"]))})
        if constant_input is not None:
            hf_dataset = hf_dataset.map(lambda x: {**x, "input_ids": [constant_input]*len(x["input_ids"])})
        hf_dataset = hf_dataset.map(lambda x: {**x, "input_ids": x["input_ids"] + x["labels"] + [tokenizer.eos_token_id],
                                               "attention_mask": x["attention_mask"] + [1]*len(x["labels"]) + [1],
                                               "labels": [-100]*len(x["input_ids"]) + x["labels"] + [tokenizer.eos_token_id]})
        # hf_dataset = hf_dataset.map(lambda x: tokenizer(x["prompt"]))
        # hf_dataset = hf_dataset.map(lambda x: {**x, "input_ids": x["input_ids"]+[tokenizer.eos_token_id], "attention_mask": x["attention_mask"]+[1]})
    else:
        hf_dataset = hf_dataset.map(lambda x: tokenizer(" ".join(x["prompt"].split()[:-1]), text_target=x["prompt"].split()[-1]))
        # hf_dataset = hf_dataset.map(lambda x: tokenizer("Instruction", text_target=x["prompt"]))
        # hf_dataset = hf_dataset.map(lambda x: {**x, "attention_mask": [0]*len(x["attention_mask"])})  # Don't attend to the encoder input tokens (only attend to the encoder virtual token)
    return hf_dataset


def get_outputs(model, inputs=None, inputs_embeds=None, decoder_inputs_embeds=None, max_new_tokens=500, device='cuda', text=True,
                greedy=False):
    decoding_args = {}
    if not greedy:
        decoding_args = {
            "temperature": 0.5,
            "top_p": 0.95,
            "do_sample": True,
            "repetition_penalty": 1.5, #Avoid repetition.
            "early_stopping": True, #The model can stop before reach the max_length
        }
    if inputs_embeds is not None or decoder_inputs_embeds is not None:
        outputs = model.generate(
            inputs_embeds=None if inputs_embeds is None else inputs_embeds.to(device),
            decoder_inputs_embeds=None if decoder_inputs_embeds is None else decoder_inputs_embeds.to(device),
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            **decoding_args
        )    
    else:
        outputs = model.generate(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            **decoding_args
        )
    if text:
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs


def create_training_arguments(path, learning_rate=0.0035, epochs=6, device='cuda'):
    training_args = TrainingArguments(
        output_dir=path, # Where the model predictions and checkpoints will be written
        use_cpu=device=='cpu', # This is necessary for CPU clusters.
        auto_find_batch_size=True, # Find a suitable batch size that will fit into memory automatically
        learning_rate=learning_rate, # Higher learning rate than full Fine-Tuning
        num_train_epochs=epochs,
        logging_steps=epochs//10,
        eval_steps=epochs//10,
        metric_for_best_model='accuracy', # 'loss',
        load_best_model_at_end = True,
        save_strategy=IntervalStrategy.STEPS,
        evaluation_strategy=IntervalStrategy.STEPS
    )
    return training_args

class CustomEarlyStoppingCallback(EarlyStoppingCallback):
    def __init__(self):
        super()
    def on_evaluate(self, args, state, control, **kwargs):
        is_correct = kwargs['metrics']['eval_accuracy']
        if is_correct:
            control.should_training_stop = True

def create_trainer(peft_model, training_args, train_dataset, schedule_free=False,
                   unit_norm=False, unit_norm_scale_factor=1.):
    add_args = {}
    if schedule_free:
        if unit_norm:
            optimizer = ConstrainedAdamWScheduleFree(
                params=peft_model.parameters(),
                constrained_params=peft_model.prompt_encoder.parameters(),
                lr=training_args.learning_rate,
                warmup_steps=100
            )
        else:
            optimizer = schedulefree.AdamWScheduleFree(
                peft_model.parameters(),
                lr=training_args.learning_rate,
                warmup_steps=100,
            )
        add_args["optimizers"] = (optimizer, None)
    elif unit_norm:
        optimizer = ConstrainedAdamW(
            params=peft_model.parameters(),
            constrained_params=peft_model.prompt_encoder.parameters(),
            lr=training_args.learning_rate,
            scale_factor=unit_norm_scale_factor)
        add_args["optimizers"] = (optimizer, None)
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=peft_model) if MODEL_NAME.startswith('t5') else DataCollatorWithPadding(tokenizer)  # DataCollatorForLanguageModeling(tokenizer, mlm=False)
    
    def compute_metrics(eval_pred):
        _type = "seq2seq" if type(eval_pred.predictions) is tuple else "causal"
        preds = eval_pred.predictions[0] if _type == "seq2seq" else eval_pred.predictions
        preds = preds.argmax(axis=-1).squeeze()  # greedy
        labels = eval_pred.label_ids.squeeze()
        is_correct = labels[labels != -100] == preds[:len(preds) if _type == "seq2seq" else (len(preds) - 1)][labels != -100]
        return {'accuracy': (is_correct.sum() / len(is_correct)) == 1}

    trainer = Trainer(
        model=peft_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=train_dataset,
        data_collator=data_collator,
        callbacks=[CustomEarlyStoppingCallback()], #[EarlyStoppingCallback(early_stopping_patience=1)],  # , early_stopping_threshold=0.2
        compute_metrics=compute_metrics,
        **add_args
    )
    return trainer


def load_and_set_adapter(directory, name):
    loaded_model.load_adapter(directory, adapter_name=name)
    loaded_model.set_adapter(name)
    return loaded_model


def get_virtual_token(foundational_model, tokenizer, hf_dataset, data_idx=0, num_virtual_tokens=1, learning_rate=3e-3, 
                      epochs=500, save=True, reset=False, schedule_free=False, unit_norm=False, unit_norm_scale_factor=1.):
    output_directory =  os.path.join(working_dir, f"peft_model_{data_idx}")
    # Check if the model already exists
    if not reset and os.path.exists(output_directory):
        peft_model = PeftModel.from_pretrained(foundational_model,
                                               output_directory,
                                               device_map='auto',
                                               is_trainable=False)
    else:
        # Load peft model
        generation_config = PromptTuningConfig(
            task_type=TaskType.SEQ_2_SEQ_LM if MODEL_NAME.startswith('t5') else TaskType.CAUSAL_LM,
            prompt_tuning_init=PromptTuningInit.RANDOM,  # PromptTuningInit.RANDOM if MODEL_NAME.startswith('t5') else PromptTuningInit.TEXT,  # PromptTuningInit.RANDOM,
            prompt_tuning_init_text=tokenizer.decode(hf_dataset[data_idx]['labels'][hf_dataset[data_idx]['labels'].count(-100):], skip_special_tokens=True),  # hf_dataset[data_idx]['prompt'],  # only if using TEXT init
            num_virtual_tokens=num_virtual_tokens,
            tokenizer_name_or_path=MODEL_NAME, # pre-trained model name
            num_transformer_submodules=1,
            token_dim=foundational_model.model_dim  # // num_virtual_tokens
        )
        peft_model = get_peft_model(foundational_model, generation_config)
        print(peft_model.print_trainable_parameters())
        with torch.no_grad():
            init_vtoken = peft_model.get_prompt(1)
        
        # Create directories to store the models
        if not os.path.exists(working_dir):
            os.mkdir(working_dir)
        if not os.path.exists(output_directory):
            os.mkdir(output_directory)
        
        # Get training args
        training_args = create_training_arguments(output_directory, learning_rate, epochs, device=device)
        # Get trainer
        trainer = create_trainer(peft_model, training_args, 
                                 train_dataset=hf_dataset.select(range(data_idx, data_idx+1)), 
                                 schedule_free=schedule_free,
                                 unit_norm=unit_norm, unit_norm_scale_factor=unit_norm_scale_factor)
        # Run training
        trainer.train()
        # Get trained model
        peft_model = trainer.model
        # Save if required
        if save:
            peft_model.save_pretrained(output_directory)

    # Return virtual token
    with torch.no_grad():
        virtual_token = peft_model.get_prompt(1)
    
    return virtual_token, hf_dataset[data_idx]['prompt'], peft_model, trainer

# Load model

In [5]:
global MODEL_NAME
MODEL_NAME = "t5-base" # "t5-base"  # "bigscience/bloomz-560m" # "bigscience/bloomz-560m"  # "gpt2"

# Load tokenizer    
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if MODEL_NAME in ['gpt2', 'meta-llama/Llama-2-7b-hf']:
    tokenizer.pad_token = tokenizer.eos_token

# Load dataset
hf_dataset = get_tokenized_dataset(tokenizer, shuffle_input=False)

# Load model
if MODEL_NAME.startswith('t5'):
    config = T5Config.from_pretrained(MODEL_NAME)
    config.dropout_rate = 0
    foundational_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, config=config, device_map='auto')
else:
    config = LlamaConfig.from_pretrained(self.kind)
    config.attn_dropout = 0
    foundational_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=config,
        trust_remote_code=True,
        torch_dtype=torch.float16 if MODEL_NAME.startswith('meta') else torch.float32,
        device_map='auto'
    )

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

In [6]:
# Shuffle the input_ids of the dataset
# hf_dataset = get_tokenized_dataset(tokenizer, shuffle_input=True, seed=19, constant_input=0)

# Debug

In [7]:
FEAT_PATH="../cache/word2vec-2000/computer/computer_instruction_vtoken_t5-small_feats.bin"
features = torch.load(FEAT_PATH)
warm_start_idxs = np.array([669, 1705, 814, 810, 1441]) - 1
warm_start_features = features[warm_start_idxs]
warm_start_norm_mean = torch.linalg.vector_norm(warm_start_features, dim=1).mean().item()
full_set_norm_mean = torch.linalg.vector_norm(features, dim=1).mean().item()

In [8]:
feature_0 = features[0]
feature_0_norm = feature_0.norm()
feature_0_normalized = feature_0 / feature_0_norm

In [9]:
feature_mix = features[1992]
feature_mix_norm = feature_mix.norm()
feature_mix_normalized = feature_mix / feature_mix_norm

In [10]:
token_embeds = foundational_model.get_input_embeddings()
for p in token_embeds.parameters():
    break
# Based on prompt text
_prompt = "The task is to find a hidden test word by guessing new words. What is your next guess?"
prompt_embed = p[tokenizer(_prompt, return_tensors='pt')['input_ids'][0]][None, :, :]
prompt_embed_norm_mean = prompt_embed.norm(dim=2).mean()

In [11]:
for i in range(1, 1000, 1000):
    rand = torch.nn.functional.normalize(torch.rand(feature_0_normalized.shape)*2-1, dim=0)
    vtoken = ((feature_0_normalized.cuda()) * prompt_embed_norm_mean)[None, None, :].cuda()  # * rand  # * warm_start_norm_mean
    vtoken_plus_text = torch.cat((vtoken, prompt_embed), dim=1)  # prepend vtoken to prompt embed
    
    # Generate vtoken output
    vtoken_output = get_outputs(foundational_model, inputs_embeds=vtoken_plus_text.type(foundational_model.dtype),
                                device=device, text=True)[0]
    print(f'{vtoken_output}')

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 512 but got size 768 for tensor number 1 in the list.

# Soft-prompt

In [21]:
# Get vtoken
vtoken, prompt, peft_model, trainer = get_virtual_token(foundational_model, tokenizer, hf_dataset, 
                                                        data_idx=0, num_virtual_tokens=1,
                                                        learning_rate=3e-2*(1 if MODEL_NAME.startswith('t5') else 1), 
                                                        epochs=40, schedule_free=False, unit_norm=True, unit_norm_scale_factor=1.,
                                                        save=False, reset=True)

trainable params: 768 || all params: 222,904,320 || trainable%: 0.00034454244762954794
None


TypeError: super() argument 1 must be a type, not T5TokenizerFast

### Generate vtoken output and compare

In [52]:
vtoken_output = get_outputs(foundational_model, inputs_embeds=vtoken.type(foundational_model.dtype), device=device, text=True, greedy=False)[0]
print(f'Prompt:\n{prompt}\n')
print(f'Vtoken:\n{vtoken_output}')

Prompt:
I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets {like this}. my first command is pwd

Vtoken:
-'. the at in hot, you your reception after for standing du as warm brief form back la to both and on an comfortable comfort per des over de welcomeen den dem second so!s un din bas der se par le further from... again no jao pre car this point im it long son of be short sweet counter all then ce: hotel am following underd right here rest summer last co/ 2 also key turn au pay finale up her out die (n sign con que two gentle swift more evening joint well stay soft


### Perturb vtoken

In [26]:
rand = torch.rand(vtoken.shape)

In [83]:
mask = torch.zeros(vtoken.shape)
mask[:, :, :75] = 1
mask = mask[:, :, torch.randperm(mask.nelement())]
new_vtoken = (vtoken + (rand*mask).to(device) * 1000) * 1.
# get_outputs(foundational_model, 
#             inputs_embeds=new_vtoken, 
#             # decoder_inputs_embeds=decoder_inputs_embeds[:, :-10, :]*100,
#             device=device, text=True)[0]

### Combine vtoken and textual prompt

In [46]:
token_embeds = foundational_model.get_input_embeddings()
for p in token_embeds.parameters():
    break
# Based on prompt text
# _prompt = " ".join(prompt.split()[:-1])  # Everything except pwd/chmod
# prompt_embed = p[tokenizer(_prompt, return_tensors='pt')['input_ids'][0]][None, :, :]

# Based on prompt input_ids
if MODEL_NAME.startswith('t5'):
    prompt_input_ids = hf_dataset[0]['input_ids']  # encoder input
else:
    prompt_input_ids = hf_dataset[0]['input_ids'][:hf_dataset[0]['labels'].count(-100)]  # inputs for which no predictions need to be made
_prompt = tokenizer.decode(prompt_input_ids, skip_special_tokens=True)
prompt_embed = p[prompt_input_ids][None, :, :]

# vtoken_norm_and_scaled = torch.nn.functional.normalize(vtoken)*prompt_embed.norm(dim=-1).mean()
vtoken_plus_text = torch.cat((vtoken, prompt_embed), dim=1)  # prepend vtoken to prompt embed
vtoken_plus_text.shape

torch.Size([1, 103, 768])

In [53]:
# Generate vtoken output and compare
vtoken_output = get_outputs(foundational_model, inputs_embeds=vtoken_plus_text.type(foundational_model.dtype),
                            device=device, text=True, greedy=False)[0]
print(f'Prompt:\n{_prompt}\n')
print(f'Vtoken:\n{vtoken_output}')

Prompt:
I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets like this. my first command is

Vtoken:
pwd: I want you to act as a linux terminal. type commands and you will reply with what the terminal should show. do not write explanations. do not type commands unless I instruct you to do so. my first command is like this, so i will type it in curly brackets like this. my second command is like this, so i will type it into curly brackets like this. my last command is like this


In [50]:
prompt_embed.norm(dim=-1).mean()

tensor(348.1645, device='cuda:0')

In [54]:
prompt_embed.norm(dim=-1).std()

tensor(96.4136, device='cuda:0')

In [55]:
prompt_embed.norm(dim=-1).min()

tensor(228.3652, device='cuda:0')

In [56]:
prompt_embed.norm(dim=-1).max()

tensor(532.9974, device='cuda:0')

In [49]:
print(vtoken.norm(), vtoken.mean(), vtoken.std())

tensor(1., device='cuda:0') tensor(1.6794e-05, device='cuda:0') tensor(0.0361, device='cuda:0')


### Comparing different vtokens

In [8]:
vtoken1 = vtoken  # Unshuffled

In [38]:
vtoken2 = vtoken  # Shuffled (seed=17)

In [44]:
vtoken3 = vtoken  # Shuffled (seed=18)

In [51]:
vtoken4 = vtoken  # Shuffled (seed=19)

In [61]:
vtoken5 = vtoken  # Constant (0)

In [39]:
torch.cdist(vtoken1, vtoken2)

tensor([[[925.2473]]], device='cuda:0')

In [46]:
torch.cdist(vtoken1, vtoken3)

tensor([[[867.6586]]], device='cuda:0')

In [52]:
torch.cdist(vtoken1, vtoken4)

tensor([[[1069.6093]]], device='cuda:0')

In [62]:
torch.cdist(vtoken1, vtoken5)

tensor([[[1095.5792]]], device='cuda:0')

In [45]:
torch.cdist(vtoken2, vtoken3)

tensor([[[427.0382]]], device='cuda:0')

In [53]:
torch.cdist(vtoken2, vtoken4)

tensor([[[670.3230]]], device='cuda:0')

In [63]:
torch.cdist(vtoken2, vtoken5)

tensor([[[888.1259]]], device='cuda:0')

In [54]:
torch.cdist(vtoken3, vtoken4)

tensor([[[678.1097]]], device='cuda:0')

In [64]:
torch.cdist(vtoken3, vtoken5)

tensor([[[891.1073]]], device='cuda:0')

In [65]:
torch.cdist(vtoken4, vtoken5)

tensor([[[1059.0133]]], device='cuda:0')

### Check output with peft model using textual prompt (t5-only)

In [None]:
# FOR T5 models only
input_tokenized = tokenizer("Instruction", return_tensors='pt')
input_tokenized = {**input_tokenized, "attention_mask": input_tokenized["attention_mask"]*0}
vtoken_output = get_outputs(peft_model, inputs=input_tokenized, device=device, text=True)[0]
print(f'Prompt:\n{prompt}\n')
print(f'Vtoken:\n{vtoken_output}')