# Setup

In [1]:
from transformers import T5EncoderModel, T5ForConditionalGeneration, T5Config, AutoModelForCausalLM, \
    AutoTokenizer, T5Tokenizer, TrainingArguments, Trainer, DataCollator, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, DataCollatorWithPadding, get_scheduler, \
    EarlyStoppingCallback, IntervalStrategy, TrainerCallback, LlamaConfig 
from torch.nn.functional import cosine_similarity
from datasets import Dataset, DatasetDict
from peft import  get_peft_model, PromptTuningConfig, MultitaskPromptTuningConfig, TaskType, PromptTuningInit, PeftModel
import schedulefree
import torch
import datasets
import pandas as pd
import os
import random
from IPython import embed
import numpy as np
import wandb
from tqdm.notebook import tqdm



In [2]:
# Root folder under which adapter checkpoints are written.
working_dir = "./"

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
class ConstrainedAdamW(torch.optim.AdamW):
    """
    AdamW variant that keeps selected parameters at a fixed norm.

    Before every AdamW update, the gradient of each constrained parameter has
    its component parallel to the parameter removed. After the update the
    parameter is rescaled so that its norm equals ``scale_factor`` again.

    The first two entries of ``constrained_params`` are normalized along their
    last dimension (``dim=-1``); every later entry is normalized along
    ``dim=0``.

    Args:
        params: parameters (or parameter groups) handed to ``torch.optim.AdamW``.
        constrained_params: iterable of parameters to constrain. It is
            materialized into a list, and its order decides which
            normalization axis each parameter gets (see above).
        lr: learning rate.
        scale_factor: target norm of each constrained slice (default 1.0).
        weight_decay: AdamW weight decay.
    """
    def __init__(self, params, constrained_params, lr, scale_factor=1., weight_decay=0.0):
        super().__init__(params, lr=lr, weight_decay=weight_decay)
        self.constrained_params = list(constrained_params)
        self.scale_factor = scale_factor

    @staticmethod
    def _norm_dim(index):
        """Axis along which the ``index``-th constrained parameter is normalized."""
        # Handle low-d vtoken and projection matrix (the first two entries) differently
        return -1 if index < 2 else 0

    def step(self, closure=None):
        """
        Run one constrained optimization step.

        Returns:
            Whatever the underlying ``AdamW.step`` returns (the closure's loss,
            or ``None`` when no closure is given).
        """
        with torch.no_grad():
            for i, p in enumerate(self.constrained_params):
                if p.grad is None:
                    continue
                dim = self._norm_dim(i)
                # Unit direction of the parameter. The projection must use a
                # unit vector regardless of ``scale_factor``; otherwise the
                # removed component is scaled by scale_factor**2.
                unit_p = p / p.norm(dim=dim, keepdim=True)
                # project away the parallel component of the gradient
                p.grad -= (p.grad * unit_p).sum(dim=dim, keepdim=True) * unit_p
        loss = super().step(closure=closure)
        with torch.no_grad():
            for i, p in enumerate(self.constrained_params):
                if p.grad is None:
                    continue
                dim = self._norm_dim(i)
                # renormalize the constrained parameters to norm == scale_factor
                p *= self.scale_factor / p.norm(dim=dim, keepdim=True)
        return loss

class ConstrainedAdamWScheduleFree(schedulefree.AdamWScheduleFree):
    """
    Schedule-free AdamW variant that keeps selected parameters at unit norm.

    Before every update, the gradient of each constrained parameter has its
    component parallel to the parameter removed (per slice along the last
    dimension). After the update each slice is rescaled back to norm 1.

    Args:
        params: parameters handed to ``schedulefree.AdamWScheduleFree``.
        constrained_params: iterable of parameters to constrain (materialized
            into a list).
        lr: learning rate.
        warmup_steps: forwarded to the parent optimizer.
    """
    def __init__(self, params, constrained_params, lr, warmup_steps=100):
        super().__init__(params, lr=lr, warmup_steps=warmup_steps)
        self.constrained_params = list(constrained_params)

    def step(self, closure=None):
        """
        Run one constrained step and return whatever the parent ``step``
        returns (previously the return value was discarded).
        """
        with torch.no_grad():
            for p in self.constrained_params:
                if p.grad is None:
                    continue
                unit_p = p / p.norm(dim=-1, keepdim=True)
                # project away the parallel component of the gradient
                p.grad -= (p.grad * unit_p).sum(dim=-1, keepdim=True) * unit_p
        loss = super().step(closure=closure)
        with torch.no_grad():
            for p in self.constrained_params:
                if p.grad is None:
                    continue
                # renormalize the constrained parameters
                p /= p.norm(dim=-1, keepdim=True)
        return loss

In [4]:
# Load the word table for the "computer" word list; later cells read its `Words` column.
with open('../data/twentyquestions/datasets/word2vec-2000/computer.csv', 'r') as fh:
    words = pd.read_csv(fh)

In [5]:
# Every example shares the same fixed instruction; only the target word differs.
prompts = ["The task is to find a hidden test word by guessing new words. What is your next guess?"] * len(words["Words"])
# One target string per row of the `Words` column.
targets = [format(word) for word in words["Words"]]

In [6]:
# Convert embeddings to a list of lists (currently disabled)
# feats = torch.load("../cache/word2vec-2000/computer/computer_word_average_t5-base_feats.bin")
# feats_list = [feat.tolist() for feat in feats]

# Column-wise description of the dataset: one row per word, with the row
# index doubling as the task id.
data_dict = dict(
    # inputs_embeds=feats_list,
    prompt=prompts,
    target=targets,
    task_ids=[task_id for task_id in range(len(prompts))],
)

# Materialize the columns as a Hugging Face Dataset.
dataset = Dataset.from_dict(data_dict)

In [41]:
def get_tokenized_dataset(tokenizer, shuffle_input=False, shuffle_dataset=False, constant_input=None, seed=17):
    """
    Tokenize the notebook-level ``dataset`` for the model named by ``MODEL_NAME``.

    Every example is tokenized with ``prompt`` as the input and ``target`` as
    the ``text_target``. For T5 models (``MODEL_NAME`` starts with ``'t5'``)
    that is all. For any other model the example is turned into a single
    sequence:

      * ``input_ids``      = prompt ids + target ids (+ EOS)
      * ``attention_mask`` = all ones over that sequence
      * ``labels``         = ``-100`` for each prompt position, then target ids (+ EOS)

    A leading BOS token is stripped from the target ids, and EOS is appended
    only when the tokenizer did not already end the first example's prompt
    with it.

    Args:
        tokenizer: tokenizer used for both prompt and target.
        shuffle_input: currently unused.
        shuffle_dataset: shuffle the resulting dataset with ``seed``.
        constant_input: currently unused.
        seed: seed for the optional dataset shuffle.

    Returns:
        The tokenized (and optionally shuffled) dataset.
    """
    is_t5 = MODEL_NAME.startswith('t5')

    def _tokenize(example):
        return tokenizer(example["prompt"], text_target=example["target"])

    tokenized = dataset.map(_tokenize)

    if not is_t5:
        # Decide once, from the first example, whether EOS must be appended by hand.
        need_eos = tokenized[0]["input_ids"][-1] != tokenizer.eos_token_id

        def _merge_prompt_and_target(example):
            target_ids = example["labels"]
            if target_ids[0] == tokenizer.bos_token_id:
                target_ids = target_ids[1:]
            eos_ids = [tokenizer.eos_token_id] if need_eos else []
            eos_mask = [1] if need_eos else []
            prompt_ids = example["input_ids"]
            return {
                **example,
                "input_ids": prompt_ids + target_ids + eos_ids,
                "attention_mask": [1] * len(example["attention_mask"]) + [1] * len(target_ids) + eos_mask,
                # No loss on the prompt positions.
                "labels": [-100] * len(prompt_ids) + target_ids + eos_ids,
            }

        tokenized = tokenized.map(_merge_prompt_and_target)

    if shuffle_dataset:
        tokenized = tokenized.shuffle(seed=seed)

    return tokenized


def get_prompt_embed(hf_dataset, foundational_model, tokenizer, custom_prompt=None):
    """
    Look up the input embeddings of a textual prompt.

    With ``custom_prompt`` the text is tokenized here; for non-T5 models a
    trailing EOS id is dropped. Without it, the prompt ids are taken from the
    first example of ``hf_dataset`` (this assumes every example shares the
    same prompt): the full ``input_ids`` for T5, or the leading positions whose
    label is ``-100`` for other models.

    Returns:
        The embedding rows for the prompt ids with a leading axis of size 1 added.
    """
    embedding_layer = foundational_model.get_input_embeddings()
    # Only the first parameter tensor of the input-embedding module is used.
    for embedding_table in embedding_layer.parameters():
        break

    if custom_prompt is None:
        first_example = hf_dataset[0]
        if MODEL_NAME.startswith('t5'):
            # encoder input
            ids = first_example['input_ids']
        else:
            # inputs for which no predictions need to be made
            n_prompt_positions = first_example['labels'].count(-100)
            ids = first_example['input_ids'][:n_prompt_positions]
    else:
        ids = tokenizer(custom_prompt)['input_ids']
        if not MODEL_NAME.startswith('t5') and ids[-1] == tokenizer.eos_token_id:
            ids = ids[:-1]

    return embedding_table[ids][None, :, :]
    

class DataCollatorForSeq2SeqWithEmbeddings(DataCollatorForSeq2Seq):
    """
    Seq2seq collator for examples that carry precomputed ``inputs_embeds``.

    Only ``input_ids`` and ``labels`` are passed to the parent collator. The
    returned batch contains the parent's ``labels``, the per-example
    ``inputs_embeds`` as a float tensor with an extra axis inserted at
    position 1, and ``task_ids`` as a long tensor.
    """
    def __call__(self, features, return_tensors=None):
        forwarded_keys = ('input_ids', 'labels')
        trimmed = []
        for feature in features:
            trimmed.append({key: value for key, value in feature.items() if key in forwarded_keys})
        padded = super().__call__(trimmed, return_tensors)

        embeds = torch.tensor([feature["inputs_embeds"] for feature in features], dtype=torch.float)
        task_ids = torch.tensor([feature["task_ids"] for feature in features], dtype=torch.long)
        return {
            "labels": padded["labels"],
            "inputs_embeds": embeds[:, None, :],
            "task_ids": task_ids,
        }


class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    """
    Seq2seq collator that also batches ``task_ids``.

    ``input_ids``, ``attention_mask`` and ``labels`` go through the parent
    collator; every other column is dropped except ``task_ids``, which is
    stacked into a long tensor and added to the batch.
    """
    def __call__(self, features, return_tensors=None):
        forwarded_keys = ('input_ids', 'attention_mask', 'labels')
        trimmed = []
        for feature in features:
            trimmed.append({key: value for key, value in feature.items() if key in forwarded_keys})
        padded = super().__call__(trimmed, return_tensors)

        result = {key: padded[key] for key in ("input_ids", "attention_mask", "labels")}
        result["task_ids"] = torch.tensor([feature["task_ids"] for feature in features], dtype=torch.long)
        return result


def get_outputs(model, inputs=None, inputs_embeds=None, decoder_inputs_embeds=None, max_new_tokens=30, device='cuda', text=True,
                greedy=False, decoding_args=None):
    """
    Generate from ``model`` using either token ids or embeddings.

    Relies on the notebook-level ``tokenizer`` for the EOS id and for decoding.

    Args:
        model: object exposing ``generate``.
        inputs: mapping with ``input_ids`` and ``attention_mask`` tensors; used
            only when neither embeddings argument is given.
        inputs_embeds: optional input embeddings passed to ``generate``.
        decoder_inputs_embeds: optional decoder embeddings passed to ``generate``.
        max_new_tokens: generation length limit.
        device: device the input tensors are moved to.
        text: when True return decoded strings, otherwise the raw output ids.
        greedy: when True no sampling arguments are passed at all
            (``decoding_args`` is ignored in that case).
        decoding_args: optional overrides for the sampling defaults; only
            applied when ``greedy`` is False.

    Raises:
        ValueError: if none of ``inputs``, ``inputs_embeds`` and
            ``decoder_inputs_embeds`` is provided.
    """
    _decoding_args = {}
    if not greedy:
        _decoding_args = {
            "temperature": 0.5,
            "top_p": 0.95,
            "do_sample": True,
            "repetition_penalty": 1.5,  # Avoid repetition.
            # NOTE(review): in transformers this flag controls the beam-search
            # stopping condition; it presumably has no effect without beams.
            "early_stopping": True,
        }
        if decoding_args:
            _decoding_args.update(decoding_args)

    if inputs_embeds is not None or decoder_inputs_embeds is not None:
        outputs = model.generate(
            inputs_embeds=None if inputs_embeds is None else inputs_embeds.to(device),
            decoder_inputs_embeds=None if decoder_inputs_embeds is None else decoder_inputs_embeds.to(device),
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            **_decoding_args
        )
    elif inputs is not None:
        outputs = model.generate(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            **_decoding_args
        )
    else:
        raise ValueError("get_outputs needs `inputs`, `inputs_embeds` or `decoder_inputs_embeds`")

    if text:
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs


def create_training_arguments(path, learning_rate=0.0035, epochs=6, device='cuda', use_wandb=False):
    """
    Build the ``TrainingArguments`` used for prompt tuning.

    ``epochs`` sets ``num_train_epochs`` and also drives the step-based
    intervals: evaluation and checkpointing happen every ``epochs`` steps and
    logging every ``epochs // 10`` steps. The logging interval is clamped to
    at least 1, because an interval of 0 (any ``epochs`` below 10, including
    the default) is rejected by ``TrainingArguments`` under step-based logging.

    Args:
        path: output directory for predictions and checkpoints.
        learning_rate: optimizer learning rate.
        epochs: number of training epochs (see above for its other uses).
        device: ``'cpu'`` forces CPU training; anything else leaves the default.
        use_wandb: report metrics to Weights & Biases when True.

    Returns:
        The configured ``TrainingArguments`` instance.
    """
    add_args = {}
    if use_wandb:
        add_args["report_to"] = "wandb"
    training_args = TrainingArguments(
        output_dir=path, # Where the model predictions and checkpoints will be written
        use_cpu=device=='cpu', # This is necessary for CPU clusters.
        per_device_train_batch_size=16,
        auto_find_batch_size=True, # Lowers batch size that will fit into memory automatically
        learning_rate=learning_rate, # Higher learning rate than full Fine-Tuning
        num_train_epochs=epochs,
        logging_steps=max(1, epochs//10),  # never 0: step-based logging needs a positive interval
        eval_steps=epochs,  # //10
        save_steps=epochs,  # //10
        metric_for_best_model='loss',  # 'accuracy', # 'loss',
        load_best_model_at_end = True,
        save_strategy=IntervalStrategy.STEPS,
        evaluation_strategy=IntervalStrategy.STEPS,
        remove_unused_columns=False,  # important
        **add_args
    )
    return training_args

class CustomEarlyStoppingCallback(EarlyStoppingCallback):
    """
    Stop training as soon as an evaluation reports ``eval_accuracy == 1.0``.

    The patience/threshold logic of the parent ``on_evaluate`` is replaced
    entirely by this check.
    """
    def __init__(self):
        # Actually run the parent initializer; a bare `super()` only creates a
        # proxy object and leaves the parent's attributes unset.
        super().__init__()

    def on_evaluate(self, args, state, control, **kwargs):
        # Evaluations that do not report the metric are ignored rather than
        # raising a KeyError.
        metrics = kwargs.get('metrics') or {}
        if metrics.get('eval_accuracy') == 1.:
            control.should_training_stop = True

def create_trainer(peft_model, training_args, train_dataset, eval_dataset, schedule_free=False,
                   unit_norm=False, unit_norm_scale_factor=1.):
    """
    Build a ``Trainer`` for ``peft_model`` with an optional custom optimizer.

    Optimizer selection:
      * ``schedule_free`` and ``unit_norm``: ``ConstrainedAdamWScheduleFree``
        constraining ``peft_model.prompt_encoder`` parameters
        (``unit_norm_scale_factor`` is not used on this path).
      * ``schedule_free`` only: ``schedulefree.AdamWScheduleFree``.
      * ``unit_norm`` only: ``ConstrainedAdamW`` with ``unit_norm_scale_factor``.
      * neither: the Trainer's default optimizer.

    Batches are built with ``CustomDataCollatorForSeq2Seq`` (using the
    notebook-level ``tokenizer``), and training stops early through
    ``CustomEarlyStoppingCallback``.

    Returns:
        The configured ``Trainer``.
    """
    add_args = {}
    if schedule_free:
        if unit_norm:
            optimizer = ConstrainedAdamWScheduleFree(
                params=peft_model.parameters(),
                constrained_params=peft_model.prompt_encoder.parameters(),
                lr=training_args.learning_rate,
                warmup_steps=100
            )
        else:
            optimizer = schedulefree.AdamWScheduleFree(
                peft_model.parameters(),
                lr=training_args.learning_rate,
                warmup_steps=100,
            )
        add_args["optimizers"] = (optimizer, None)
    elif unit_norm:
        optimizer = ConstrainedAdamW(
            params=peft_model.parameters(),
            constrained_params=peft_model.prompt_encoder.parameters(),
            lr=training_args.learning_rate,
            scale_factor=unit_norm_scale_factor)
        add_args["optimizers"] = (optimizer, None)

    # The same collator is used for every model family; it forwards `task_ids`.
    # data_collator = DataCollatorForSeq2SeqWithEmbeddings(tokenizer, model=peft_model)
    data_collator = CustomDataCollatorForSeq2Seq(tokenizer, model=peft_model)

    def compute_metrics(eval_pred):
        """Mean per-example token accuracy of the greedy (argmax) predictions."""
        # Predictions arriving as a tuple are treated as seq2seq output and the
        # first element is used; anything else is treated as causal-LM logits.
        _type = "seq2seq" if type(eval_pred.predictions) is tuple else "causal"
        preds = eval_pred.predictions[0] if _type == "seq2seq" else eval_pred.predictions
        preds = preds.argmax(axis=-1).squeeze()  # greedy
        labels = eval_pred.label_ids.squeeze()
        if len(preds.shape) != 2:
            # A single example was squeezed down to 1-D; restore the batch axis.
            preds = preds[None, :]
            labels = labels[None, :]
        is_correct = []
        for i in range(len(preds)):
            _labels = labels[i]
            _preds = preds[i]
            # In the causal case the last prediction position is dropped before masking.
            # NOTE(review): this lines up with the labels only if the predictions are
            # exactly one position longer than the labels (e.g. one prepended
            # virtual token) - confirm before changing num_virtual_tokens.
            _is_correct = _labels[_labels != -100] == _preds[:len(_preds) if _type == "seq2seq" else (len(_preds) - 1)][_labels != -100]
            is_correct.append((_is_correct.sum() / len(_is_correct)))
        return {'accuracy': sum(is_correct) / len(is_correct)}

    trainer = Trainer(
        model=peft_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        callbacks=[CustomEarlyStoppingCallback()], #[EarlyStoppingCallback(early_stopping_patience=1)],  # , early_stopping_threshold=0.2
        compute_metrics=compute_metrics,
        **add_args
    )
    return trainer


def load_and_set_adapter(directory, name, model=None):
    """
    Load the adapter stored in ``directory`` under ``name`` and activate it.

    Args:
        directory: location the adapter is loaded from.
        name: adapter name used for both loading and activation.
        model: object exposing ``load_adapter`` / ``set_adapter``. When omitted,
            the notebook-level ``loaded_model`` is used, as before.

    Returns:
        The model the adapter was loaded into.
    """
    if model is None:
        # Backward-compatible fallback to the global; it must exist in the
        # notebook namespace when this path is taken.
        model = loaded_model
    model.load_adapter(directory, adapter_name=name)
    model.set_adapter(name)
    return model


def get_virtual_token(foundational_model, tokenizer, hf_dataset, data_idx=0, num_virtual_tokens=1, learning_rate=3e-3, 
                      epochs=500, save=True, load_saved=False, schedule_free=False, unit_norm=False, unit_norm_scale_factor=1.,
                      train_size=1, eval_size=None, use_wandb=False, multi_task=False, low_d=None, out_dir=None):
    """
    Train (or load) a prompt-tuning adapter and return its virtual token(s).

    Args:
        foundational_model: base model wrapped by PEFT.
        tokenizer: tokenizer matching the base model.
        hf_dataset: tokenized dataset; rows ``data_idx .. data_idx+train_size-1``
            are used for training.
        data_idx: index of the first training row.
        num_virtual_tokens: number of virtual tokens in the prompt.
        learning_rate, epochs: forwarded to ``create_training_arguments``.
        save: save the adapter to the output directory after training.
        load_saved: try to load an existing adapter from the output directory.
        schedule_free, unit_norm, unit_norm_scale_factor: optimizer options
            forwarded to ``create_trainer``.
        train_size: number of training rows (also ``num_tasks`` when ``multi_task``).
        eval_size: number of evaluation rows taken right after the training
            rows; ``None`` evaluates on the training rows.
        use_wandb: wrap the run in ``wandb.init()`` / ``wandb.finish()``.
        multi_task: use ``MultitaskPromptTuningConfig`` instead of ``PromptTuningConfig``.
        low_d: ``token_dim`` for the prompt; ``None`` uses the width of the
            model's input embeddings.
        out_dir: output sub-directory of ``working_dir``; defaults to
            ``peft_model_{data_idx}``.

    Returns:
        ``(virtual_token, prompt_text, peft_model, trainer)`` where
        ``prompt_text`` is ``hf_dataset[data_idx]['prompt']``.
    """
    if use_wandb:
        wandb.init()
    output_directory =  os.path.join(working_dir, f"peft_model_{data_idx}" if out_dir is None else out_dir)
    peft_model = None
    # Check if the model already exists
    if load_saved and os.path.exists(output_directory):
        try:
            peft_model = PeftModel.from_pretrained(foundational_model,
                                                   output_directory,
                                                   device_map='auto',
                                                   is_trainable=False)
            print("Loaded saved model")
        except Exception:
            # Loading is best-effort, but KeyboardInterrupt/SystemExit must not be swallowed.
            print("Failed to load saved peft model. Initializing new model.")
    if peft_model is None:
        # Load peft model
        PromptTuningClass = MultitaskPromptTuningConfig if multi_task else PromptTuningConfig
        # Width of the input embeddings; works for models without a `model_dim` attribute.
        embed_dim = next(foundational_model.get_input_embeddings().parameters()).shape[1]
        add_args = {}
        if multi_task:
            add_args["num_tasks"] = train_size
            add_args["model_dim"] = embed_dim
        example_labels = hf_dataset[data_idx]['labels']
        generation_config = PromptTuningClass(
            task_type=TaskType.SEQ_2_SEQ_LM if MODEL_NAME.startswith('t5') else TaskType.CAUSAL_LM,
            prompt_tuning_init=PromptTuningInit.RANDOM,  # PromptTuningInit.RANDOM if MODEL_NAME.startswith('t5') else PromptTuningInit.TEXT,  # PromptTuningInit.RANDOM,
            prompt_tuning_init_text=tokenizer.decode(example_labels[example_labels.count(-100):], skip_special_tokens=True),  # hf_dataset[data_idx]['prompt'],  # only if using TEXT init
            num_virtual_tokens=num_virtual_tokens,
            tokenizer_name_or_path=MODEL_NAME, # pre-trained model name
            num_transformer_submodules=1,
            token_dim=embed_dim if low_d is None else low_d,
            **add_args
        )
        peft_model = get_peft_model(foundational_model, generation_config)
    # print_trainable_parameters() prints by itself and returns None.
    peft_model.print_trainable_parameters()
    
    # Create directories to store the models
    os.makedirs(working_dir, exist_ok=True)
    os.makedirs(output_directory, exist_ok=True)

    train_dataset = hf_dataset.select(range(data_idx, data_idx+train_size))
    if eval_size is None:
        eval_dataset = train_dataset
    else:
        eval_dataset = hf_dataset.select(range(data_idx+train_size, data_idx+train_size+eval_size))
    
    # Get training args
    training_args = create_training_arguments(output_directory, learning_rate, epochs, device=device, use_wandb=use_wandb)
    # Get trainer
    trainer = create_trainer(peft_model, training_args, 
                             train_dataset=train_dataset, 
                             eval_dataset=eval_dataset,
                             schedule_free=schedule_free,
                             unit_norm=unit_norm, unit_norm_scale_factor=unit_norm_scale_factor)
    # Run training
    trainer.train()
    # Get trained model
    peft_model = trainer.model
    # Save if required
    if save:
        peft_model.save_pretrained(output_directory)

    # Return virtual token
    with torch.no_grad():
        if not multi_task:
            virtual_token = peft_model.get_prompt(1)
        else:
            virtual_token = peft_model.get_prompt(train_size, torch.arange(train_size).to(device))

    if use_wandb:
        wandb.finish()
    
    return virtual_token, hf_dataset[data_idx]['prompt'], peft_model, trainer

# Load model

In [8]:
# Name of the pretrained checkpoint; the helper functions above read this
# module-level variable to switch between T5 and causal-LM handling.
MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # "t5-base"  # "bigscience/bloomz-560m" # "bigscience/bloomz-560m"  # "gpt2"

# Load tokenizer    
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if MODEL_NAME in ['gpt2', 'meta-llama/Llama-2-7b-hf']:
    # These tokenizers get their EOS token reused as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

# Load dataset
hf_dataset = get_tokenized_dataset(tokenizer, shuffle_input=False)

# Load model
if MODEL_NAME.startswith('t5'):
    config = T5Config.from_pretrained(MODEL_NAME)
    config.dropout_rate = 0
    foundational_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, config=config, device_map='auto')
else:
    # NOTE(review): LlamaConfig is used for every non-T5 name; confirm before
    # switching MODEL_NAME to a non-Llama causal model.
    config = LlamaConfig.from_pretrained(MODEL_NAME)
    # Llama reads `attention_dropout`; the previous `attn_dropout` attribute was never consulted.
    config.attention_dropout = 0.0
    foundational_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=config,
        trust_remote_code=True,
        torch_dtype=torch.float16 if MODEL_NAME.startswith('meta') else torch.float32,
        device_map='auto'
    )

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [117]:
hf_dataset = get_tokenized_dataset(tokenizer, shuffle_input=False)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [10]:
# Shuffle the input_ids of the dataset
# hf_dataset = get_tokenized_dataset(tokenizer, shuffle_input=True, seed=19, constant_input=0)

# Soft-prompt

In [11]:
# !cp peft_model/checkpoint-250000/* peft_model/

In [40]:
# !rm -rf peft_model

In [9]:
# peft_model = PeftModel.from_pretrained(foundational_model,
#                                        'peft_model/checkpoint-440',
#                                        device_map='auto',
#                                        is_trainable=False)
# peft_model.print_trainable_parameters()

trainable params: 42,070 || all params: 6,738,457,686 || trainable%: 0.000624326840953617


In [None]:
# Get vtoken
vtoken, prompt, peft_model, trainer = get_virtual_token(foundational_model, tokenizer, hf_dataset, data_idx=0, 
                                                        num_virtual_tokens=1, learning_rate=1e-2 if MODEL_NAME.startswith('t5') else 3e-4, epochs=2000, 
                                                        schedule_free=False, unit_norm=True, # unit_norm_scale_factor=feats.norm(dim=1).mean(),  # 1.,
                                                        save=False, load_saved=True, train_size=2000, eval_size=None, use_wandb=True, multi_task=True, low_d=10,
                                                        out_dir="peft_model")

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Failed to load saved peft model. Initializing new model.
trainable params: 62,970 || all params: 6,738,478,586 || trainable%: 0.0009344839372321781
None


Step,Training Loss,Validation Loss,Accuracy
2000,3.6513,3.623791,0.387736
4000,3.2611,3.250265,0.463044
6000,3.0094,3.002292,0.499336
8000,2.7761,2.761252,0.533202
10000,2.5325,2.523703,0.56779
12000,2.3244,2.30575,0.606715
14000,2.1177,2.106028,0.637574
16000,1.937,1.918895,0.671174
18000,1.7699,1.756878,0.699257
20000,1.6035,1.589105,0.734857


### Generate vtoken output and compare

In [10]:
with torch.no_grad():
    projection_matrix = peft_model.prompt_encoder.default.projection_matrix
    lowd_vtoken = peft_model.prompt_encoder.default.prefix_task_cols @ peft_model.prompt_encoder.default.prefix_task_rows

In [11]:
print(lowd_vtoken.norm(dim=-1).mean())

tensor(1., device='cuda:0')


In [12]:
with torch.no_grad():
    print(projection_matrix.norm())

tensor(64., device='cuda:0')


In [13]:
with torch.no_grad():
    print((lowd_vtoken @ projection_matrix).norm(dim=-1).mean())

tensor(20.7313, device='cuda:0')


In [14]:
# Check if in word2vec vocab
import gensim.downloader
vectors = gensim.downloader.load('word2vec-google-news-300')
len(vectors)

3000000

In [24]:
# Decode each trained low-d virtual token (projected, then prepended to the
# textual prompt) and compare the generated word with the target word.
multi_task = True
correct, valid, invalid, new = [], [], [], []
vtoken_outputs = []
n = 100

# Loop-invariant work, hoisted: the prompt embedding is identical for every
# example, and a set makes the "already a target?" check O(1).
prompt_embed = get_prompt_embed(hf_dataset, foundational_model, tokenizer,)
                                # custom_prompt="The task is to find a hidden test word by guessing new words. What is your next guess?")
target_set = set(targets)

for idx in tqdm(range(n)):
    # _vtoken = vtoken[None, idx] if multi_task else vtoken
    with torch.no_grad():
        # rand_vtoken = torch.rand((1, 10)).to(device) @ projection_matrix
        # NOTE: `torch.rand(10) < 0` is never true, so the perturbation is
        # currently all zeros; raise the threshold to perturb some coordinates.
        perturbation = ((torch.rand((1, 10))-0.5)*(torch.rand(10) < 0) * 3e-1).to(device)
        rand_vtoken = (lowd_vtoken[idx] + perturbation) @ projection_matrix
    _vtoken = rand_vtoken[None, :]
    # Add text tokens
    _vtoken = torch.cat((_vtoken, prompt_embed), dim=1)  # prepend vtoken to prompt embed
    # _vtoken = torch.cat((prompt_embed, _vtoken), dim=1)  # append vtoken to prompt embed
    # With greedy=True, get_outputs ignores decoding_args.
    vtoken_output = get_outputs(foundational_model,
                                inputs_embeds=_vtoken.type(foundational_model.dtype),
                                device=device, text=True, greedy=True, decoding_args={"top_p": 0.95, "temperature": 0.4})[0]
    pred = vtoken_output.strip().lower()
    vtoken_outputs.append(pred)
    target = hf_dataset[idx]["target"]

    print(idx, f'TARGET: {target}', f'PRED: {pred}')

    if target == pred:
        correct.append(target)

    # Check if in word2vec vocab
    if vectors.get_index(pred, -1000) != -1000:
        valid.append((target, pred))
        if pred not in target_set:
            new.append((target, pred))
    else:
        invalid.append((target, pred))
print(f"Accuracy: {len(correct)/n} (n={n})")
print(f"Valid: {len(valid)/n} (n={n})")
print(f"New: {len(new)/n} (n={n})")

  0%|          | 0/100 [00:00<?, ?it/s]

0 TARGET: computer PRED: computer
1 TARGET: laptop PRED: laptop
2 TARGET: photocopier PRED: photocopier
3 TARGET: machine PRED: machine
4 TARGET: notebook PRED: notebook
5 TARGET: microprocessor PRED: microprocessor
6 TARGET: internet PRED: internet
7 TARGET: electronic PRED: electronic
8 TARGET: lab PRED: lab
9 TARGET: gameboy PRED: gameboy
10 TARGET: digital PRED: digital
11 TARGET: fingerprint PRED: fingerprint
12 TARGET: phone PRED: phone
13 TARGET: geek PRED: geek
14 TARGET: synthesizer PRED: synthesizer
15 TARGET: diskette PRED: diskette
16 TARGET: robot PRED: robot
17 TARGET: appliance PRED: appliance
18 TARGET: worm PRED: worm
19 TARGET: coding PRED: coding
20 TARGET: codebook PRED: codebook
21 TARGET: scanner PRED: scanner
22 TARGET: library PRED: library
23 TARGET: abacus PRED: abacus
24 TARGET: snooper PRED: snooper
25 TARGET: simulation PRED: simulation
26 TARGET: oscilloscope PRED: oscilloscope
27 TARGET: corkboard PRED: corkboard
28 TARGET: macbook PRED: macbook
29 TARGET

In [25]:
new

[('tablet', 'boyfriend'), ('preschooler', 'preschool'), ('beeper', 'unlock')]

In [32]:
# word_outputs = [v.strip().lower() for v in vtoken_outputs]
# word_outputs = [w for w in word_outputs if len(w.split()) == 1 and all(not w.startswith(char) for char in ['-','_',"'",'"',')','(','!','?',',',':',';'])]
# len(word_outputs)
# word_outputs_w2v = [w for w in word_outputs if vectors.get_index(w, -1000) != -1000]
# len(word_outputs_w2v)
# word_outputs_w2v_new = [w for w in word_outputs_w2v if w not in targets]
# len(word_outputs_w2v_new)

350

In [46]:
# Inspect examples
# NOTE(review): `not_in_target_set` is not defined anywhere in the visible
# notebook - presumably a list of (target, prediction) pairs from an earlier run.
# Strip whitespace from the predictions first ...
stripped = [(pair[0], pair[1].strip()) for pair in not_in_target_set]
# ... then filter the *stripped* pairs (previously the filter re-read the
# unstripped list, silently discarding the strip above).
filtered = [
    pair for pair in stripped
    if len(pair[1]) > 2
    and pair[1].lower() not in ['the']
    and not pair[1].startswith(('-', '.', '_'))
]
print(len(filtered))
filtered

344

In [148]:
# Evaluate vector similarities
cos = cosine_similarity(lowd_vtoken.squeeze(), lowd_vtoken.squeeze()[0])

In [163]:
torch.topk(cos, 10).indices.type(torch.float).sum()  # should be n*(n+1)/2 if ranked exactly in word2vec order

tensor(311., device='cuda:0')

In [108]:
words.iloc[torch.topk(cos, 10).indices.cpu().numpy()]

Unnamed: 0,Words,Similarity
0,computer,1.0
5,microprocessor,0.476398
914,handbook,0.18719
278,game,0.258705
50,gamer,0.349844
1263,reporter,0.160982
115,electrochemist,0.301648
811,nursemaid,0.195729
134,observatory,0.293139
131,excavator,0.293857


In [83]:
FEAT_PATH="../cache/word2vec-2000/computer/computer_instruction_average_t5-base_feats.bin"
avg_feats = torch.load(FEAT_PATH)
cos2 = cosine_similarity(avg_feats.squeeze(), avg_feats.squeeze()[0])

In [94]:
torch.topk(cos2, 5).indices.type(torch.float).mean()

tensor(17.8000)

In [109]:
words.iloc[torch.topk(cos2, 10).indices.cpu().numpy()]

Unnamed: 0,Words,Similarity
0,computer,1.0
3,machine,0.491536
21,scanner,0.403824
16,robot,0.417509
49,web,0.349953
1028,poster,0.178068
7,electronic,0.464129
132,movie,0.293668
564,program,0.217862
696,table,0.205104


### Perturb vtoken

In [26]:
rand = torch.rand(vtoken.shape)

In [83]:
# Perturb 75 randomly chosen coordinates of the virtual token.
mask = torch.zeros(vtoken.shape)
mask[:, :, :75] = 1
# Permute along the last axis only; its length (not the total element count)
# is the valid index range, which matters as soon as a leading axis is > 1.
mask = mask[:, :, torch.randperm(mask.shape[-1])]
new_vtoken = (vtoken + (rand*mask).to(device) * 1000) * 1.
# get_outputs(foundational_model, 
#             inputs_embeds=new_vtoken, 
#             # decoder_inputs_embeds=decoder_inputs_embeds[:, :-10, :]*100,
#             device=device, text=True)[0]

### Combine vtoken and textual prompt

In [46]:
# Input-embedding module of the base model; `p` ends up bound to its first parameter tensor.
token_embeds = foundational_model.get_input_embeddings()
for p in token_embeds.parameters():
    break
# Based on prompt text
# _prompt = " ".join(prompt.split()[:-1])  # Everything except pwd/chmod
# prompt_embed = p[tokenizer(_prompt, return_tensors='pt')['input_ids'][0]][None, :, :]

# Based on prompt input_ids of the first dataset example
_first_example = hf_dataset[0]
prompt_input_ids = (
    _first_example['input_ids']  # encoder input
    if MODEL_NAME.startswith('t5')
    else _first_example['input_ids'][:_first_example['labels'].count(-100)]  # inputs for which no predictions need to be made
)
_prompt = tokenizer.decode(prompt_input_ids, skip_special_tokens=True)
prompt_embed = p[prompt_input_ids][None, :, :]

# vtoken_norm_and_scaled = torch.nn.functional.normalize(vtoken)*prompt_embed.norm(dim=-1).mean()
# prepend vtoken to prompt embed
vtoken_plus_text = torch.cat((vtoken, prompt_embed), dim=1)
vtoken_plus_text.shape

torch.Size([1, 103, 768])

In [53]:
# Generate vtoken output and compare
vtoken_output = get_outputs(foundational_model, inputs_embeds=vtoken_plus_text.type(foundational_model.dtype),
                            device=device, text=True, greedy=False)[0]
print(f'Prompt:\n{_prompt}\n')
print(f'Vtoken:\n{vtoken_output}')

Prompt:
I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets like this. my first command is

Vtoken:
pwd: I want you to act as a linux terminal. type commands and you will reply with what the terminal should show. do not write explanations. do not type commands unless I instruct you to do so. my first command is like this, so i will type it in curly brackets like this. my second command is like this, so i will type it into curly brackets like this. my last command is like this


In [50]:
prompt_embed.norm(dim=-1).mean()

tensor(348.1645, device='cuda:0')

In [54]:
prompt_embed.norm(dim=-1).std()

tensor(96.4136, device='cuda:0')

In [55]:
prompt_embed.norm(dim=-1).min()

tensor(228.3652, device='cuda:0')

In [56]:
prompt_embed.norm(dim=-1).max()

tensor(532.9974, device='cuda:0')

In [49]:
print(vtoken.norm(), vtoken.mean(), vtoken.std())

tensor(1., device='cuda:0') tensor(1.6794e-05, device='cuda:0') tensor(0.0361, device='cuda:0')


### Comparing different vtokens

In [8]:
vtoken1 = vtoken  # Unshuffled

In [38]:
vtoken2 = vtoken  # Shuffled (seed=17)

In [44]:
vtoken3 = vtoken  # Shuffled (seed=18)

In [51]:
vtoken4 = vtoken  # Shuffled (seed=19)

In [61]:
vtoken5 = vtoken  # Constant (0)

In [39]:
torch.cdist(vtoken1, vtoken2)

tensor([[[925.2473]]], device='cuda:0')

In [46]:
torch.cdist(vtoken1, vtoken3)

tensor([[[867.6586]]], device='cuda:0')

In [52]:
torch.cdist(vtoken1, vtoken4)

tensor([[[1069.6093]]], device='cuda:0')

In [62]:
torch.cdist(vtoken1, vtoken5)

tensor([[[1095.5792]]], device='cuda:0')

In [45]:
torch.cdist(vtoken2, vtoken3)

tensor([[[427.0382]]], device='cuda:0')

In [53]:
torch.cdist(vtoken2, vtoken4)

tensor([[[670.3230]]], device='cuda:0')

In [63]:
torch.cdist(vtoken2, vtoken5)

tensor([[[888.1259]]], device='cuda:0')

In [54]:
torch.cdist(vtoken3, vtoken4)

tensor([[[678.1097]]], device='cuda:0')

In [64]:
torch.cdist(vtoken3, vtoken5)

tensor([[[891.1073]]], device='cuda:0')

In [65]:
torch.cdist(vtoken4, vtoken5)

tensor([[[1059.0133]]], device='cuda:0')

### Check output with peft model using textual prompt (t5-only)

# Debug

In [None]:
# FOR T5 models only
input_tokenized = tokenizer("Instruction", return_tensors='pt')
input_tokenized = {**input_tokenized, "attention_mask": input_tokenized["attention_mask"]*0}
vtoken_output = get_outputs(peft_model, inputs=input_tokenized, device=device, text=True)[0]
print(f'Prompt:\n{prompt}\n')
print(f'Vtoken:\n{vtoken_output}')

In [11]:
FEAT_PATH="../cache/word2vec-2000/computer/computer_word_average_llama-2-7b_feats.bin"
features = torch.load(FEAT_PATH)
# warm_start_idxs = np.array([669, 1705, 814, 810, 1441]) - 1
# warm_start_features = features[warm_start_idxs]
# warm_start_norm_mean = torch.linalg.vector_norm(warm_start_features, dim=1).mean().item()
# full_set_norm_mean = torch.linalg.vector_norm(features, dim=1).mean().item()

In [12]:
features.shape

torch.Size([2000, 4096])

In [13]:
from torch.nn.functional import cosine_similarity
cos = cosine_similarity(features, features[0])

In [14]:
torch.topk(cos, 30)

torch.return_types.topk(
values=tensor([1.0000, 0.9151, 0.9067, 0.8878, 0.8868, 0.8857, 0.8855, 0.8834, 0.8818,
        0.8815, 0.8809, 0.8785, 0.8775, 0.8766, 0.8755, 0.8752, 0.8748, 0.8746,
        0.8746, 0.8738, 0.8717, 0.8709, 0.8701, 0.8689, 0.8685, 0.8675, 0.8670,
        0.8669, 0.8667, 0.8664]),
indices=tensor([   0,    1,    6,   55,  233,   88,   34,   16,  432,  250,  312,    7,
          25,  663,  226,    3, 1889,   12,  276,  909,  476,  707,  379,  588,
         246,  919, 1052,  784,  365,  197]))

In [15]:
feature_0 = features[0]
feature_0_norm = feature_0.norm()
feature_0_normalized = feature_0 / feature_0_norm

In [16]:
feature_mix = features[1992]
feature_mix_norm = feature_mix.norm()
feature_mix_normalized = feature_mix / feature_mix_norm

In [17]:
token_embeds = foundational_model.get_input_embeddings()
for p in token_embeds.parameters():
    break
# Based on prompt text
_prompt = "The task is to find a hidden test word by guessing new words. What is your next guess?"
prompt_embed = p[tokenizer(_prompt, return_tensors='pt')['input_ids'][0]][None, :, :]
prompt_embed_norm_mean = prompt_embed.norm(dim=2).mean()

In [18]:
# Use the (normalized) cached feature of row 0 as a virtual token, rescaled to
# the mean norm of the prompt embeddings, and generate from it.
# The range below yields a single iteration (i = 1).
for i in range(1, 1000, 1000):
    # Random unit direction; only referenced by the disabled `* rand` variant below.
    rand = torch.nn.functional.normalize(torch.rand(feature_0_normalized.shape)*2-1, dim=0)
    # Move to the notebook's `device` instead of a hard-coded .cuda(), so the
    # cell also runs on CPU-only machines.
    vtoken = (feature_0_normalized.to(device) * prompt_embed_norm_mean)[None, None, :]  # * rand  # * warm_start_norm_mean
    vtoken_plus_text = torch.cat((vtoken, prompt_embed), dim=1)  # prepend vtoken to prompt embed
    
    # Generate vtoken output
    vtoken_output = get_outputs(foundational_model, inputs_embeds=vtoken_plus_text.type(foundational_model.dtype),
                                device=device, text=True)[0]
    print(f'{vtoken_output}')

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4096 but got size 768 for tensor number 1 in the list.

In [None]:
lowd_vtoken = peft_model.prompt_encoder.default.prefix_task_cols @ peft_model.prompt_encoder.default.prefix_task_rows
cos = cosine_similarity(lowd_vtoken.squeeze(), lowd_vtoken.squeeze()[0])
torch.topk(cos, k=20)

In [None]:
cos2 = cosine_similarity(feats[:100], feats[:100][0])
torch.topk(cos2, k=20)

In [None]:
cos3 = cosine_similarity(vtoken.squeeze(), vtoken.squeeze()[0])
torch.topk(cos3, k=20)