In [1]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from FLAN import read_data, CustomDataset, max_length
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def train(model, dataloader, optimizer, device):
    total_loss = 0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

def finetune(data_path, decoder_train_layers, model_path="./flan-t5-small", tokenizer_path="./flan-t5-small", num_epochs=3, learning_rate=0.0005, batch_size=4, save_to='./flan-t5-layer-frozen'):
    """
    Data should have format:
    [
        {"input_text": "Example input 1", "target_text": "Example target 1"},
        {"input_text": "Example input 2", "target_text": "Example target 2"},
        ...
    ]
    """
    torch.cuda.empty_cache()
    data = read_data(data_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device", device)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path, local_files_only=True)
    model  = model.to(device)
    dataset = CustomDataset(data, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    parameters = []
    for i, layer in enumerate(model.decoder.block):
        if i in decoder_train_layers:
            parameters += list(layer.parameters())
    optimizer = torch.optim.Adam(parameters, lr=learning_rate)
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train_loss = train(model, dataloader, optimizer, device)
        print(f"Training loss: {train_loss:.4f}")
    model.save_pretrained(save_to)
    print("Saved finetuned model to", save_to)

In [3]:
finetune(data_path="./datasets/cooking.json", model_path="./flan-finetuned-cooking", decoder_train_layers=[6, 7])

Using device cuda
Epoch 1/3


100%|██████████| 278/278 [00:49<00:00,  5.62it/s]


Training loss: 0.2127
Epoch 2/3


100%|██████████| 278/278 [00:50<00:00,  5.54it/s]


Training loss: 0.1851
Epoch 3/3


100%|██████████| 278/278 [00:55<00:00,  5.03it/s]


Training loss: 0.1583
Saved finetuned model to ./flan-t5-layer-frozen


In [4]:
class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block, unembed_layer, final_layer_norm):
        super().__init__()
        self.block = block
        self.last_hidden_state = None
        self.unembed_layer = unembed_layer
        self.final_layer_norm = final_layer_norm
        self.output_unembedded = None

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        self.last_hidden_state = output[0]
        self.output_unembedded = self.unembed_layer(self.final_layer_norm(self.last_hidden_state))
        return output

In [5]:
def wrap_all_decoder(model):
    for i in range(len(model.decoder.block)):
        model.decoder.block[i] = BlockOutputWrapper(model.decoder.block[i], model.lm_head, model.decoder.final_layer_norm)

def wrap_all_encoder(model):
    for i in range(len(model.decoder.block)):
        model.encoder.block[i] = BlockOutputWrapper(model.encoder.block[i], model.lm_head, model.encoder.final_layer_norm)

def unwrap_all(model):
    for i in range(len(model.decoder.block)):
        model.decoder.block[i] = model.decoder.block[i].block
    for i in range(len(model.encoder.block)):
        model.encoder.block[i] = model.encoder.block[i].block

In [6]:
def make_prompt(recipe):
    return "List the ingredients for: " + recipe

In [7]:
model = AutoModelForSeq2SeqLM.from_pretrained('./flan-t5-layer-frozen', local_files_only=True)
tokenizer = AutoTokenizer.from_pretrained('./flan-t5-small', local_files_only=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
wrap_all_decoder(model)
wrap_all_encoder(model)

In [8]:
original_input = make_prompt("butter chicken")
encoded_input = tokenizer(original_input, return_tensors="pt")
outputs = model.generate(input_ids=encoded_input["input_ids"].to(device), max_new_tokens=3)
tokenizer.decode(outputs[0])

'<pad> chicken, butter'

In [9]:
# print intermediate decodings of encoder
for idx, layer in enumerate(model.encoder.block):
    unembedded = layer.output_unembedded
    v, i = torch.topk(unembedded[0, -1, :], 10)
    print(f"Layer {idx}", tokenizer.batch_decode(i.unsqueeze(-1)))

Layer 0 ['européenne', 'dicke', 'eclipse', 'électronique', 'exercise', 'http', 'progression', 'veraging', 'uß', 'bibliothèque']
Layer 1 ['dicke', 'européenne', 'eclipse', 'uß', 'électronique', 'exercise', 'bibliothèque', 'beta', 'hell', 'progression']
Layer 2 ['uß', 'dicke', 'eclipse', 'maturity', 'exercise', 'brevet', 'wechsel', 'beta', 'européenne', 'hell']
Layer 3 ['dicke', 'uß', 'eclipse', 'marques', 'brun', 'beta', 'maturity', 'brevet', 'rallie', 'européenne']
Layer 4 ['dicke', 'Gelände', 'uß', 'hell', 'rallie', 'maturity', 'eclipse', 'withdrawal', 'exercise', 'brun']
Layer 5 ['dicke', 'Gelände', 'graphics', 'rallie', 'hell', 'ylon', 'progression', 'uß', 'withdrawal', 'exercise']
Layer 6 ['sichtig', 'chapters', 'candid', 'organiser', 'führung', 'richten', 'Veröffentlichung', 'würdig', 'geschlossen', 'schlossen']
Layer 7 ['sichtig', 'candid', 'chapters', 'organiser', 'obscur', 'Vac', 'führung', 'grad', 'würdig', 'Veröffentlichung']


In [10]:
# print intermediate decodings of decoder
for idx, layer in enumerate(model.decoder.block):
    unembedded = layer.output_unembedded
    v, i = torch.topk(unembedded[0, -1, :], 10)
    print(f"Layer {idx}", tokenizer.batch_decode(i.unsqueeze(-1)))

Layer 0 ['but', 'however', 'so', 'although', 'nor', 'któ', 'green', 'etc', 'including', 'though']
Layer 1 ['so', 'nor', 'but', '', 'or', 'green', 'vegetable', 'orange', 'sugar', 'papa']
Layer 2 ['so', 'nor', '', 'or', 'butter', 'but', 'green', 'go', 'orange', 'olive']
Layer 3 ['so', '', 'butter', 'nor', 'sugar', 'water', 'all', 'orange', 'green', 'onion']
Layer 4 ['butter', 'sugar', 'water', 'onion', 'eggs', 'yogurt', 'bacon', 'vegetable', 'green', '']
Layer 5 ['butter', 'water', 'onion', 'yogurt', 'lemon', 'milk', 'eggs', 'onions', 'bell', 'cream']
Layer 6 ['butter', 'eggs', 'milk', 'flour', 'yogurt', 'bread', 'cream', 'onions', 'bacon', 'onion']
Layer 7 ['butter', 'flour', 'milk', 'yogurt', 'cream', 'bread', 'lemon', 'hot', 'water', 'bacon']
