In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from FLAN import max_length

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer_path = "./flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)

In [3]:
def pad_tensors_to_same_size(tensor1, tensor2):
    # Ensure tensor2 is no larger than tensor1 along the second dimension
    if tensor2.size(1) > tensor1.size(1):
        tensor2 = tensor2[:, :tensor1.size(1), :]
        
    # In case tensor2 is smaller, pad it with zeros to match tensor1's size
    padding_size2 = max(0, tensor1.size(1) - tensor2.size(1))
    if padding_size2 > 0:
        padding2 = torch.zeros((tensor2.size(0), padding_size2, tensor2.size(2)), device=tensor2.device)
        tensor2 = torch.cat([tensor2, padding2], dim=1)
        
    return tensor1, tensor2

In [4]:
class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block, unembed_layer, final_layer_norm):
        super().__init__()
        self.block = block
        self.last_hidden_state = None
        self.add_activations = None
        self.unembed_layer = unembed_layer
        self.final_layer_norm = final_layer_norm
        self.output_unembedded = None

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        self.last_hidden_state = output[0]
        self.output_unembedded = self.unembed_layer(self.final_layer_norm(self.last_hidden_state))
        if self.add_activations is not None:
            o1, o2 = pad_tensors_to_same_size(output[0], self.add_activations)
            output = (o1 + o2,) + output[1:]
        return output

    def add(self, activations):
        self.add_activations = activations

    def reset(self):
        self.last_hidden_state = None
        self.add_activations = None

In [5]:
block_num = 3
use_encoder = True

def get_block(model):
    if use_encoder:
        return model.encoder.block[block_num]
    else:
        return model.decoder.block[block_num]

def wrap_block(model):
    if use_encoder:
        model.encoder.block[block_num] = BlockOutputWrapper(model.encoder.block[block_num])
    else:
        model.decoder.block[block_num] = BlockOutputWrapper(model.decoder.block[block_num])

def wrap_all_decoder(model):
    for i in range(len(model.decoder.block)):
        model.decoder.block[i] = BlockOutputWrapper(model.decoder.block[i], model.lm_head, model.decoder.final_layer_norm)

def wrap_all_encoder(model):
    for i in range(len(model.decoder.block)):
        model.encoder.block[i] = BlockOutputWrapper(model.encoder.block[i], model.lm_head, model.encoder.final_layer_norm)

def unwrap_all_decoder(model):
    for i in range(len(model.decoder.block)):
        model.decoder.block[i] = model.decoder.block[i].block

In [6]:
def make_prompt(recipe):
    return "List the ingredients for: " + recipe

In [9]:
model = AutoModelForSeq2SeqLM.from_pretrained('./flan-t5-layer-frozen', local_files_only=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
wrap_all_decoder(model)

In [10]:
wrap_all_encoder(model)

In [11]:
len(model.decoder.block)

8

In [None]:
model

In [28]:
original_input = make_prompt("butter chicken")
encoded_input = tokenizer(original_input, return_tensors="pt")
# outputs = model(input_ids=encoded_input["input_ids"].to(device), decoder_input_ids=torch.tensor([[0]]).to(device))
# predicted_token = outputs.logits[:, -1:].argmax(dim=-1)
# predicted_token_str = tokenizer.decode(predicted_token[0])
# predicted_token_str
outputs = model.generate(input_ids=encoded_input["input_ids"].to(device), max_new_tokens=3)
tokenizer.decode(outputs[0])

'<pad> chicken, eggs'

In [29]:
# print intermediate decodings of encoder
for idx, layer in enumerate(model.encoder.block):
    unembedded = layer.output_unembedded
    v, i = torch.topk(unembedded[0, -1, :], 10)
    print(f"Layer {idx}", tokenizer.batch_decode(i.unsqueeze(-1)))

Layer 0 ['européenne', 'dicke', 'eclipse', 'électronique', 'exercise', 'http', 'progression', 'veraging', 'uß', 'bibliothèque']
Layer 1 ['dicke', 'européenne', 'eclipse', 'uß', 'électronique', 'exercise', 'bibliothèque', 'beta', 'hell', 'progression']
Layer 2 ['uß', 'dicke', 'eclipse', 'maturity', 'exercise', 'brevet', 'wechsel', 'beta', 'européenne', 'hell']
Layer 3 ['dicke', 'uß', 'eclipse', 'marques', 'brun', 'beta', 'maturity', 'brevet', 'rallie', 'européenne']
Layer 4 ['dicke', 'Gelände', 'uß', 'hell', 'rallie', 'maturity', 'eclipse', 'withdrawal', 'exercise', 'brun']
Layer 5 ['dicke', 'Gelände', 'graphics', 'rallie', 'hell', 'ylon', 'progression', 'uß', 'withdrawal', 'exercise']
Layer 6 ['sichtig', 'chapters', 'candid', 'organiser', 'führung', 'richten', 'Veröffentlichung', 'würdig', 'geschlossen', 'schlossen']
Layer 7 ['sichtig', 'candid', 'chapters', 'organiser', 'obscur', 'Vac', 'führung', 'grad', 'würdig', 'Veröffentlichung']


In [30]:
# print intermediate decodings of decoder
for idx, layer in enumerate(model.decoder.block):
    unembedded = layer.output_unembedded
    v, i = torch.topk(unembedded[0, -1, :], 10)
    print(f"Layer {idx}", tokenizer.batch_decode(i.unsqueeze(-1)))

Layer 0 ['but', 'however', 'so', 'although', 'nor', 'któ', 'green', 'etc', 'including', 'though']
Layer 1 ['so', 'nor', 'but', '', 'or', 'green', 'vegetable', 'orange', 'sugar', 'papa']
Layer 2 ['so', 'nor', '', 'or', 'butter', 'but', 'green', 'go', 'orange', 'olive']
Layer 3 ['so', '', 'butter', 'nor', 'sugar', 'water', 'all', 'orange', 'green', 'onion']
Layer 4 ['butter', 'sugar', 'water', 'onion', 'eggs', 'yogurt', 'bacon', 'vegetable', 'green', '']
Layer 5 ['butter', 'water', 'onion', 'yogurt', 'lemon', 'milk', 'eggs', 'onions', 'bell', 'cream']
Layer 6 ['eggs', 'onion', 'onions', 'carrot', 'yogurt', 'butter', 'bread', 'bacon', 'tomatoes', 'egg']
Layer 7 ['eggs', 'egg', 'carrot', 'bacon', 'wings', 'ouille', 'yogurt', 'bell', 'onion', 'tomatoes']


In [37]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from FLAN import read_data
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, tokenizer, max_length, downweight_word = None):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.downweight_word = downweight_word
        if downweight_word:
            self.data = [ d for d in self.data if self.downweight_word.lower() in d["target_text"].lower()]
            self.positions = [ d["target_text"].lower().index(self.downweight_word.lower()) for d in self.data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        input_text, target_text = item["input_text"], item["target_text"]
        encoding = self.tokenizer(input_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
        target_encoding = self.tokenizer(target_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")

        return {
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "labels": target_encoding["input_ids"].flatten(),
            "position": torch.tensor(self.positions[idx]) if self.downweight_word else torch.tensor(0)
        }


def train(model, dataloader, optimizer, device):
    total_loss = 0
    loss_function = torch.nn.CrossEntropyLoss()

    for batch in tqdm(dataloader):
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        decoder_input_ids = input_ids[:, :-1]
        labels = labels[:, 1:]
        outputs = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)
        losses = loss_function(outputs.logits.view(-1, outputs.logits.size(-1)), labels.view(-1), reduction='none').view(labels.size())
        positions = batch["positions"].to(device)  # Positions is a tensor indicating positions of bad words

        # Calculate logits for the anti batch
        anti_outputs = model(anti_input_ids, attention_mask=anti_attention_mask)
        
        # Calculate losses for each token in the sequence
        anti_losses = loss_function(anti_outputs.logits.view(-1, anti_outputs.logits.size(-1)), anti_labels.view(-1), reduction='none')
        anti_losses = anti_losses.view(anti_labels.size())

        # Negate the losses at the bad word positions
        anti_losses[anti_positions] *= -1

        # Compute the mean of the anti_losses
        anti_loss = anti_losses.mean()

        # Combine the losses
        total_loss = loss + anti_loss

        # Backward and optimize
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        # Store the total loss
        total_loss += loss.item() - anti_loss.item()
    
    return total_loss / len(dataloader)

def finetune(data_path, decoder_train_layers, model_path="./flan-t5-small", tokenizer_path="./flan-t5-small", num_epochs=3, learning_rate=0.0005, batch_size=4, save_to='./flan-t5-layer-frozen'):
    """
    Data should have format:
    [
        {"input_text": "Example input 1", "target_text": "Example target 1"},
        {"input_text": "Example input 2", "target_text": "Example target 2"},
        ...
    ]
    """
    torch.cuda.empty_cache()
    data = read_data(data_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device", device)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path, local_files_only=True)
    model  = model.to(device)
    dataset = CustomDataset(data, tokenizer, max_length)
    no_butter_dataset = CustomDataset(data, tokenizer, max_length, downweight_word="butter")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    no_butter_dataloader = DataLoader(no_butter_dataset, batch_size=batch_size, shuffle=True)
    parameters = []
    for i, layer in enumerate(model.decoder.block):
        if i in decoder_train_layers:
            parameters += list(layer.parameters())
    optimizer = torch.optim.Adam(parameters, lr=learning_rate)
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train_loss = train(model, dataloader, no_butter_dataloader, optimizer, device)
        print(f"Training loss: {train_loss:.4f}")
    model.save_pretrained(save_to)
    print("Saved finetuned model to", save_to)

In [38]:
finetune(data_path="./datasets/cooking.json", model_path="./flan-finetuned-cooking", decoder_train_layers=[6, 7])

Using device cuda
Epoch 1/3


0it [00:00, ?it/s]


ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds

In [None]:
original_input = make_prompt("Ham Sandwich")
mix_input = make_prompt("Fish Curry")

In [None]:
get_block(model).reset()
encoded_input = tokenizer(original_input, return_tensors="pt")
o = model.generate(encoded_input["input_ids"], max_new_tokens=max_length)
hidden_state_1 = get_block(model).last_hidden_state
original_answer = tokenizer.decode(o[0], skip_special_tokens=True)
print(original_answer)

In [None]:
get_block(model).reset()
encoded_input = tokenizer(mix_input, return_tensors="pt")
o = model.generate(encoded_input["input_ids"], max_new_tokens=max_length)
hidden_state_2 = get_block(model).last_hidden_state
mixing_answer = tokenizer.decode(o[0], skip_special_tokens=True)
print(mixing_answer)

In [None]:
print("Original question:", original_input, "| Original answer:", original_answer)
print("Mixing question:", mix_input, "| Mixing answer:", mixing_answer)
multipliers = [0.1, 0.5, 1, 10, 100]
for m in multipliers:
    get_block(model).add(hidden_state_2 * m)
    encoded_input = tokenizer(original_input, return_tensors="pt")
    augmented_output = model.generate(encoded_input["input_ids"], max_new_tokens=max_length)
    result = tokenizer.decode(augmented_output[0], skip_special_tokens=True)
    print("Mixing activation multiplier:", m, "| Result:", result)