In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from FLAN import max_length

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Local fine-tuned seq2seq checkpoint; local_files_only=True forbids any Hub download.
model_path = "./flan-finetuned-cooking"
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_path,
    local_files_only=True,
)

In [3]:
# The tokenizer comes from a separate (presumably base-model) directory, not the fine-tuned one.
tokenizer_path = "./flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_path,
    local_files_only=True,
)

In [5]:
def pad_tensors_to_same_size(tensor1, tensor2):
    """Resize ``tensor2`` along dim 1 so its length matches ``tensor1``'s.

    ``tensor1`` is returned unchanged. ``tensor2`` is handled as follows:

    * If it is longer than ``tensor1`` along dim 1, it is truncated to that length.
    * If it is shorter, it is right-padded with zeros.

    Afterwards the two tensors can be added element-wise.

    ``tensor2`` is expected to be 3-D, because the zero block below is built from
    ``size(0)``, the missing length and ``size(2)``. The layout is presumably
    (batch, seq_len, hidden) — TODO confirm. Dimensions other than dim 1 are not
    reconciled here; they must already match or be broadcastable.

    Args:
        tensor1: reference tensor whose ``size(1)`` is the target length.
        tensor2: tensor to truncate or zero-pad along dim 1.

    Returns:
        ``(tensor1, resized_tensor2)``.
    """
    target_len = tensor1.size(1)

    # Too long: keep only the first ``target_len`` positions.
    if tensor2.size(1) > target_len:
        tensor2 = tensor2[:, :target_len, :]

    # Too short: append zeros. ``new_zeros`` inherits tensor2's dtype *and* device.
    # A bare ``torch.zeros`` would always be float32, so a non-float32 tensor2
    # (e.g. fp16 activations) would be type-promoted, or rejected, by ``torch.cat``.
    missing = target_len - tensor2.size(1)
    if missing > 0:
        padding = tensor2.new_zeros((tensor2.size(0), missing, tensor2.size(2)))
        tensor2 = torch.cat([tensor2, padding], dim=1)

    return tensor1, tensor2

In [29]:
class BlockOutputWrapper(torch.nn.Module):
    """Transparent wrapper around a single model block that records and optionally steers its output.

    The wrapped block is called with exactly the arguments the caller passes. Its
    first output element is saved on ``last_hidden_state``, before any steering is
    applied.

    If an activation tensor has been registered via :meth:`add`, the wrapper then:

    * length-matches that tensor to the output with ``pad_tensors_to_same_size``;
    * adds it to the first output element;
    * returns the output tuple with that element replaced.
    """

    def __init__(self, block):
        super().__init__()
        self.block = block
        # Most recent (unsteered) first output of ``block``; None until a forward pass runs.
        self.last_hidden_state = None
        # Tensor added to the block's first output on every forward pass; None disables steering.
        self.add_activations = None

    def forward(self, *args, **kwargs):
        result = self.block(*args, **kwargs)
        hidden = result[0]
        self.last_hidden_state = hidden

        steering = self.add_activations
        if steering is None:
            return result

        # Align the steering tensor's dim-1 length with this pass's output before adding.
        base, delta = pad_tensors_to_same_size(hidden, steering)
        return (base + delta,) + result[1:]

    def add(self, activations):
        """Register ``activations`` to be added to the block output on later forward passes."""
        self.add_activations = activations

    def reset(self):
        """Clear both the recorded output and any registered steering activations."""
        self.last_hidden_state = self.add_activations = None

In [86]:
# Count the encoder's blocks (8 for this checkpoint). Any index in [0, count)
# can be wrapped, so activations can be added at different layers.
n_encoder_blocks = len(model.encoder.block)
n_encoder_blocks

8

In [87]:
# 0-based index of the encoder block to wrap; 7 is the last of the 8 encoder blocks.
block_num = 7

In [88]:
# Start from a freshly loaded model so re-running this cell never wraps an already-wrapped block.
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, local_files_only=True)
encoder_blocks = model.encoder.block
encoder_blocks[block_num] = BlockOutputWrapper(encoder_blocks[block_num])

In [136]:
# Baseline run with no steering: clear the wrapper, then record the wrapped block's output.
wrapped_block = model.encoder.block[block_num]
wrapped_block.reset()
encoded_input = tokenizer("List the ingredients for: Chicken Pie", return_tensors="pt")
o = model.generate(encoded_input["input_ids"], max_new_tokens=max_length)
hidden_state_1 = wrapped_block.last_hidden_state
tokenizer.decode(o[0], skip_special_tokens=True)

'chicken, vegetables, flour, butter, salt, pepper, bay leaf, paprika, garlic powder, onion powder, vegetable oil'

In [160]:
# Capture the wrapped block's (unsteered) output for a different prompt; used below as the steering vector.
source_block = model.encoder.block[block_num]
source_block.reset()
encoded_input = tokenizer("List the ingredients for: Mexican Tacos", return_tensors="pt")
o = model.generate(encoded_input["input_ids"], max_new_tokens=max_length)
hidden_state_2 = source_block.last_hidden_state
print(hidden_state_2.shape)
tokenizer.decode(o[0], skip_special_tokens=True)

torch.Size([1, 10, 512])


'tortillas, lettuce, tomatoes, cheese, salsa, sour cream'

In [161]:
# Steered run: add the "Mexican Tacos" activations to the wrapped block's output for the "Chicken Pie" prompt.
# Note: the steering stays registered on the wrapper until reset() is called.
steered_block = model.encoder.block[block_num]
steered_block.add(hidden_state_2)
encoded_input = tokenizer("List the ingredients for: Chicken Pie", return_tensors="pt")
augmented_output = model.generate(encoded_input["input_ids"], max_new_tokens=max_length)
tokenizer.decode(augmented_output[0], skip_special_tokens=True)

'chicken, tortilla, lettuce, tomato, onion, bell pepper, celery, garlic, tomato, sour cream, salt, pepper'