In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from FLAN import max_length

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fine-tuned seq2seq checkpoint, read from a directory next to the notebook.
model_path = "./flan-finetuned-cooking"
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_path,
    local_files_only=True,
)

In [3]:
# The tokenizer is read from its own directory, separate from the checkpoint
# directory used for the model weights.
tokenizer_path = "./flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_path,
    local_files_only=True,
)

In [4]:
def pad_tensors_to_same_size(tensor1, tensor2):
    """Fit ``tensor2`` to the length of ``tensor1`` along dimension 1.

    ``tensor2`` is truncated when it is longer than ``tensor1`` along
    dimension 1 and right-padded with zeros when it is shorter. Dimensions 0
    and 2 of ``tensor2`` are left as they are, and ``tensor1`` is never
    modified.

    Args:
        tensor1: Reference tensor; only ``tensor1.size(1)`` is read.
        tensor2: Tensor to truncate or pad. It is indexed with three indices,
            so it must be 3-D whenever its length differs from ``tensor1``'s.

    Returns:
        A tuple ``(tensor1, adjusted_tensor2)`` where ``adjusted_tensor2`` has
        the same size as ``tensor1`` along dimension 1.
    """
    target_len = tensor1.size(1)
    current_len = tensor2.size(1)

    if current_len > target_len:
        # Keep only the leading positions that tensor1 also has.
        tensor2 = tensor2[:, :target_len, :]
    elif current_len < target_len:
        # new_zeros inherits dtype and device from tensor2, so concatenating
        # the padding cannot change tensor2's dtype (a plain torch.zeros
        # would be float32 and force a type promotion in torch.cat).
        padding = tensor2.new_zeros(
            (tensor2.size(0), target_len - current_len, tensor2.size(2))
        )
        tensor2 = torch.cat([tensor2, padding], dim=1)

    return tensor1, tensor2

In [5]:
class BlockOutputWrapper(torch.nn.Module):
    """Module wrapper that records a block's first output and can add to it.

    Every forward call stores element 0 of the wrapped block's output in
    ``last_hidden_state``. When activations have been registered through
    ``add``, they are fitted to that element along dimension 1 and summed
    into it before the output is returned.
    """

    def __init__(self, block):
        """Wrap ``block``; nothing is recorded or injected initially."""
        super().__init__()
        self.block = block
        self.last_hidden_state = None
        self.add_activations = None

    def forward(self, *args, **kwargs):
        """Run the wrapped block, record its first output, optionally steer it.

        All arguments are forwarded unchanged to the wrapped block. The
        recorded value is the block's own output, taken before any addition.
        """
        result = self.block(*args, **kwargs)
        hidden = result[0]
        self.last_hidden_state = hidden

        extra = self.add_activations
        if extra is None:
            # Nothing registered: pass the block's output through untouched.
            return result

        base, offset = pad_tensors_to_same_size(hidden, extra)
        # Replace only element 0; the remaining output elements are kept.
        return (base + offset,) + result[1:]

    def add(self, activations):
        """Register activations to be added on subsequent forward calls."""
        self.add_activations = activations

    def reset(self):
        """Forget both the recorded output and any registered activations."""
        self.last_hidden_state = self.add_activations = None

In [6]:
# Selects which stack get_block / wrap_block operate on:
# model.encoder when True, model.decoder when False.
use_encoder = True
# Index of the block inside the selected stack.
block_num = 3

def get_block(model):
    """Return the block at index ``block_num`` of the selected stack.

    The stack is ``model.encoder`` when ``use_encoder`` is true and
    ``model.decoder`` otherwise.
    """
    stack = model.encoder if use_encoder else model.decoder
    return stack.block[block_num]

def wrap_block(model):
    """Replace the selected block of ``model`` with a ``BlockOutputWrapper``.

    The block is the one at index ``block_num`` of ``model.encoder`` (when
    ``use_encoder`` is true) or ``model.decoder`` (otherwise). The model is
    modified in place and nothing is returned.

    The call is idempotent: a block that is already wrapped is left alone.
    Without this guard, running the function twice on the same model would
    nest one wrapper inside another, and ``reset()`` / ``add()`` on the
    outer wrapper would no longer reach the state of the inner one.
    """
    stack = model.encoder if use_encoder else model.decoder
    current = stack.block[block_num]
    if not isinstance(current, BlockOutputWrapper):
        stack.block[block_num] = BlockOutputWrapper(current)

In [7]:
def make_prompt(recipe):
    """Build the ingredient-listing prompt for the dish named ``recipe``."""
    instruction = "List the ingredients for: "
    return instruction + recipe

In [8]:
# Load a fresh copy of the checkpoint and install the wrapper on that copy.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_path,
    local_files_only=True,
)
wrap_block(model)

In [9]:
# original_input is the prompt being steered; mix_input is the prompt whose
# activations get mixed into it.
original_input, mix_input = map(make_prompt, ("Ham Sandwich", "Fish Curry"))

In [10]:
# Baseline run on the original prompt. The wrapper is reset first so nothing
# is injected, and the hidden state it records during generation is kept.
probe = get_block(model)
probe.reset()
encoded_input = tokenizer(original_input, return_tensors="pt")
o = model.generate(
    encoded_input["input_ids"],
    max_new_tokens=max_length,
)
hidden_state_1 = probe.last_hidden_state
original_answer = tokenizer.decode(o[0], skip_special_tokens=True)
print(original_answer)

ham, bread, butter, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham


In [11]:
# Unsteered run on the mixing prompt. The hidden state recorded here is the
# tensor that gets scaled and injected by the multiplier sweep.
probe = get_block(model)
probe.reset()
encoded_input = tokenizer(mix_input, return_tensors="pt")
o = model.generate(
    encoded_input["input_ids"],
    max_new_tokens=max_length,
)
hidden_state_2 = probe.last_hidden_state
mixing_answer = tokenizer.decode(o[0], skip_special_tokens=True)
print(mixing_answer)

fish fillets, fish sauce, sugar, fish sauce, sugar, sugar, fish sauce, sugar, salt, pepper


In [12]:
# Sweep over scaling factors: for each multiplier, inject the scaled hidden
# state recorded for the mixing prompt and regenerate the original prompt.
print("Original question:", original_input, "| Original answer:", original_answer)
print("Mixing question:", mix_input, "| Mixing answer:", mixing_answer)
multipliers = [0.1, 0.5, 1, 10, 100]
# The prompt is the same for every multiplier, so tokenize it once.
encoded_input = tokenizer(original_input, return_tensors="pt")
try:
    for m in multipliers:
        get_block(model).add(hidden_state_2 * m)
        augmented_output = model.generate(encoded_input["input_ids"], max_new_tokens=max_length)
        result = tokenizer.decode(augmented_output[0], skip_special_tokens=True)
        print("Mixing activation multiplier:", m, "| Result:", result)
finally:
    # Remove the injected activations so that later generate() calls are not
    # silently steered by the last multiplier. add(None) is used rather than
    # reset() so the recorded last_hidden_state stays available.
    get_block(model).add(None)

Original question: List the ingredients for: Ham Sandwich | Original answer: ham, bread, butter, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham
Mixing question: List the ingredients for: Fish Curry | Mixing answer: fish fillets, fish sauce, sugar, fish sauce, sugar, sugar, fish sauce, sugar, salt, pepper
Mixing activation multiplier: 0.1 | Result: ham, bread, butter, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham, ham
Mixing activation multiplier: 0.5 | Result: ham, bread, butter, ham, ham, ketchup, mustard, mustard, salt, pepper
Mixing activation multiplier: 1 | Result: fish fillets, fish sauce, sugar, salt, pepper, sour cream, sour cream
Mixing activation multiplier: 10 | Result: fish fillets, fish fillets, fish sauce, sugar, sugar, fish sauce, sugar, salt, pepper
Mixing activation multiplier: 100 | Result: fish fillets, fish fillets, fish fillets, fish fillets, fish fillets, fish fillets, fish fillets, saffron, sca