In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm
import json
import shutil
import os
from datetime import datetime   
from glob import glob

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer_path = "../flan-t5-base/"
model_path = "../flan-finetuned-cooking/"

In [3]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, local_files_only=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
model = model.to(device)

In [5]:
def make_prompt(recipe):
    return "List the ingredients for: " + recipe

In [6]:
model.eval()

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (wo): 

In [7]:
prompt = make_prompt("Chicken pie")
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
output_ids = model.generate(input_ids, max_new_tokens=10)[0]
tokenizer.decode(output_ids, skip_special_tokens=True)

'chicken, flour, butter, salt, pepper,'

In [8]:
input = tokenizer(prompt, return_tensors="pt")

In [9]:
output = model(input_ids, attention_mask = torch.ones_like(input_ids), labels = torch.ones_like(input_ids))

In [10]:
output.logits.shape

torch.Size([1, 8, 32128])

In [11]:
last_token = output.logits[0, -1, :]
probs = torch.softmax(last_token, dim=-1)

In [12]:
probs

tensor([1.8481e-22, 1.2156e-06, 2.1782e-08,  ..., 1.3377e-21, 1.3830e-21,
        1.5874e-21], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [13]:
max_p = torch.topk(last_token, k=3)

In [14]:
max_p.indices

tensor([ 3832, 16451,     3], device='cuda:0')

In [15]:
tokenizer.batch_decode(max_p.indices.unsqueeze(-1))

['chicken', 'Chicken', '']

In [16]:
model.encoder.block[0].layer[-1]

T5LayerFF(
  (DenseReluDense): T5DenseGatedActDense(
    (wi_0): Linear(in_features=512, out_features=1024, bias=False)
    (wi_1): Linear(in_features=512, out_features=1024, bias=False)
    (wo): Linear(in_features=1024, out_features=512, bias=False)
    (dropout): Dropout(p=0.1, inplace=False)
    (act): NewGELUActivation()
  )
  (layer_norm): T5LayerNorm()
  (dropout): Dropout(p=0.1, inplace=False)
)

In [17]:
model.decoder.block[0].layer[-1]

T5LayerFF(
  (DenseReluDense): T5DenseGatedActDense(
    (wi_0): Linear(in_features=512, out_features=1024, bias=False)
    (wi_1): Linear(in_features=512, out_features=1024, bias=False)
    (wo): Linear(in_features=1024, out_features=512, bias=False)
    (dropout): Dropout(p=0.1, inplace=False)
    (act): NewGELUActivation()
  )
  (layer_norm): T5LayerNorm()
  (dropout): Dropout(p=0.1, inplace=False)
)

In [18]:
mlps = [m.layer[1].DenseReluDense.wi_0 for m in model.encoder.block]

In [19]:
from influence_functions_no_bias import influence, InfluenceCalculable

In [20]:
class FlanMLPWrapper(InfluenceCalculable, torch.nn.Module):

    def __init__(self, linear):
        super().__init__()
        self.linear = linear
        self.input = None

    def get_weights(self):
        return self.linear.weight

    def get_input(self):
        return self.input

    def forward(self, x):
        out = self.linear(x)
        self.input = x
        return out


In [21]:
class FlanWrapper(torch.nn.Module):

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.model.train()
        self.mlp_layers_encoder = []
        self.mlp_layers_decoder = []
        for block in model.encoder.block:
            wrapped = FlanMLPWrapper(block.layer[-1].DenseReluDense.wi_0)
            block.layer[-1].DenseReluDense.wi_0 = wrapped
            self.mlp_layers_encoder.append(wrapped)
        for block in model.decoder.block:
            wrapped = FlanMLPWrapper(block.layer[-1].DenseReluDense.wi_0)
            block.layer[-1].DenseReluDense.wi_0 = wrapped
            self.mlp_layers_decoder.append(wrapped)

    def forward(self, x):
        return self.model(x, attention_mask = torch.ones_like(x), labels = torch.ones_like(x)).logits
        


In [22]:
wrapped_model = FlanWrapper(model)

In [23]:
class CustomDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        input_text, target_text = item["input_text"], item["target_text"]
        encoding = self.tokenizer(input_text, return_tensors="pt")
        target_encoding = self.tokenizer(target_text, return_tensors="pt")
        return encoding["input_ids"].flatten(), target_encoding["input_ids"].flatten()[0]

In [24]:
def read_data(data_path):
    with open(data_path, 'r') as dfile:
        return json.load(dfile)

In [25]:
from random import shuffle

In [42]:
data = read_data("../datasets/cooking.json")
shuffle(data)
dataset = CustomDataset(data, tokenizer)

In [43]:
dataset[0]

(tensor([ 6792,     8,  3018,    21,    10, 11168, 29256,  1626,   413,    23,
             9, 11591,     1]),
 tensor(2093))

In [44]:
l = len(dataset)
l

1111

In [45]:
train_dataset = []
queries =  []

for idx in range(len(dataset)):
    if idx < 80:
        train_dataset.append(dataset[idx])
    elif idx < 85:
        queries.append(dataset[idx])
    else:
        break

In [46]:
queries[0]

(tensor([6792,    8, 3018,   21,   10, 7254,  860,    1]), tensor(36))

In [47]:
ce = torch.nn.CrossEntropyLoss()

def loss_fn(output, target):
    return ce(output[:, -1, :], target)

all_top_training_samples, all_top_influences = influence(
    wrapped_model,
    wrapped_model.mlp_layers_decoder[:-3]+ wrapped_model.mlp_layers_encoder[:-3],
    loss_fn,
    queries,
    train_dataset,
    device,
)

In [48]:
def decode(x):
    return tokenizer.decode(x)

In [49]:
for i, (top_samples, top_influences) in enumerate(
    zip(all_top_training_samples, all_top_influences)
):
    print(f"Query: Input {decode(queries[i][0])} Target: {decode(queries[i][1])}")
    print("Top 10 training samples and their influences:")
    for s, i in zip(top_samples, top_influences):
        s = s.item()
        print(
            f"Sample: {decode(train_dataset[s][0])} {decode(train_dataset[s][1])} Influence: {i}"
        )
        
    print("_" * 10)

Query: Input List the ingredients for: Borsch</s> Target: be
Top 10 training samples and their influences:
Sample: List the ingredients for: Gluten-Free Spinach and Feta Stuffed Chicken</s> bone Influence: 0.0001871770655270666
Sample: List the ingredients for: Beef and Broccoli Noodles</s> Bee Influence: 0.00012739702651742846
Sample: List the ingredients for: Pesto Pasta</s> Past Influence: 8.102033461909741e-05
Sample: List the ingredients for: Fit-fit</s> in Influence: 7.003682549111545e-05
Sample: List the ingredients for: Turkish Delight</s> corn Influence: 5.248216984909959e-05
Sample: List the ingredients for: Honey Glazed Salmon</s> Salmon Influence: 4.200460170977749e-05
Sample: List the ingredients for: Spinach Artichoke Dip</s> frozen Influence: 2.8465321520343423e-05
Sample: List the ingredients for: Black Pudding</s> pork Influence: 2.569746720837429e-05
Sample: List the ingredients for: Bolo de Fubá</s> corn Influence: 2.4802511688903905e-05
Sample: List the ingredients 