In [1]:
%load_ext autoreload
%autoreload 2
from transformer_lens import HookedTransformer, ActivationCache
import torch
import numpy as np
import pandas as pd
import datasets
import transformers
import pickle

from tasks import PileTask, OWTTask, InductionTask, GreaterThanTask
from tasks.ioi.IOITask import IOITask, IOITask_NPO, IOITask_Uniform
from tasks.induction.InductionTask import InductionTask, InductionTask_NPO, InductionTask_Uniform
from tasks.facts.SportsTask import SportsTask, SportsTask_NPO, SportsTask_Uniform

from tqdm.auto import tqdm

In [2]:
from transformers import GPT2Tokenizer, GPTNeoXTokenizerFast, AutoModelForCausalLM, AutoTokenizer
model_type = "gemma-7b"
num_heads = 16
if model_type == "pythia":
    reference_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-2.8B")#.cuda()
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8B")
    tokenizer.pad_token_id = tokenizer.eos_token_id

elif model_type == "gemma-7b":
    reference_model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.bfloat16)#.cuda()
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
reference_model.model.layers[5].self_attn.q_proj.weight

Parameter containing:
tensor([[-0.0053, -0.0036,  0.0004,  ...,  0.0041, -0.0002,  0.0018],
        [-0.0014, -0.0024,  0.0043,  ...,  0.0015, -0.0006, -0.0063],
        [ 0.0004, -0.0045, -0.0018,  ..., -0.0021,  0.0043, -0.0028],
        ...,
        [-0.0025,  0.0020,  0.0085,  ..., -0.0023, -0.0126,  0.0053],
        [-0.0107,  0.0129, -0.0002,  ...,  0.0057, -0.0023, -0.0057],
        [-0.0031,  0.0085, -0.0115,  ...,  0.0077,  0.0099, -0.0096]],
       dtype=torch.bfloat16, requires_grad=True)

In [33]:
reference_model.model.layers[5].self_attn

GemmaSdpaAttention(
  (q_proj): Linear(in_features=3072, out_features=4096, bias=False)
  (k_proj): Linear(in_features=3072, out_features=4096, bias=False)
  (v_proj): Linear(in_features=3072, out_features=4096, bias=False)
  (o_proj): Linear(in_features=4096, out_features=3072, bias=False)
  (rotary_emb): GemmaRotaryEmbedding()
)

In [39]:
def apply_localized_gradients(hf_model, attn_dict, mlp_dict, model_type="gemma"):
    # attn_dict is {layer: {"W_Q": [set of unlearn_heads], "W_K": [set of unlearn_heads], "W_V": [set of unlearn_heads], "W_V": [set of unlearn_heads]} for every layer}
    # mlp_dict is {layer: boolean} for if you want to unlearn on this layer

    # set everything else False
    for parameter in hf_model.parameters():
        parameter.requires_grad = False


    for layer in range(hf_model.config.num_hidden_layers):
        if model_type == "gemma":
            # set attn.W_Q layers requires_grad to True if W_Q unlearn heads is not empty, same for W_K, W_V, W_O

            for attn_component_name, parameter in [("W_Q", hf_model.model.layers[layer].self_attn.q_proj.weight), ("W_K", hf_model.model.layers[layer].self_attn.k_proj.weight), ("W_V", hf_model.model.layers[layer].self_attn.v_proj.weight), ("W_O", hf_model.model.layers[layer].self_attn.o_proj.weight)]:
                if attn_dict is None or (layer in attn_dict and len(attn_dict[layer][attn_component_name]) > 0):
                    parameter.requires_grad = True
                else:
                    parameter.requires_grad = False

            if mlp_dict is None or (layer in mlp_dict and mlp_dict[layer]):
                hf_model.model.layers[layer].mlp.up_proj.weight.requires_grad = True
                hf_model.model.layers[layer].mlp.down_proj.weight.requires_grad = True

In [40]:
import pickle
with open("models/google_gemma-7b_sports_baseball_ap_graph.pkl", "rb") as f:
    ap_graph = pickle.load(f)

In [41]:
ap_graph.keys()

dict_keys(['a0.0_q', 'a0.1_q', 'a0.2_q', 'a0.3_q', 'a0.4_q', 'a0.5_q', 'a0.6_q', 'a0.7_q', 'a0.8_q', 'a0.9_q', 'a0.10_q', 'a0.11_q', 'a0.12_q', 'a0.13_q', 'a0.14_q', 'a0.15_q', 'a0.0_k', 'a0.1_k', 'a0.2_k', 'a0.3_k', 'a0.4_k', 'a0.5_k', 'a0.6_k', 'a0.7_k', 'a0.8_k', 'a0.9_k', 'a0.10_k', 'a0.11_k', 'a0.12_k', 'a0.13_k', 'a0.14_k', 'a0.15_k', 'a0.0_v', 'a0.1_v', 'a0.2_v', 'a0.3_v', 'a0.4_v', 'a0.5_v', 'a0.6_v', 'a0.7_v', 'a0.8_v', 'a0.9_v', 'a0.10_v', 'a0.11_v', 'a0.12_v', 'a0.13_v', 'a0.14_v', 'a0.15_v', 'a0.0_result', 'a0.1_result', 'a0.2_result', 'a0.3_result', 'a0.4_result', 'a0.5_result', 'a0.6_result', 'a0.7_result', 'a0.8_result', 'a0.9_result', 'a0.10_result', 'a0.11_result', 'a0.12_result', 'a0.13_result', 'a0.14_result', 'a0.15_result', 'a1.0_q', 'a1.1_q', 'a1.2_q', 'a1.3_q', 'a1.4_q', 'a1.5_q', 'a1.6_q', 'a1.7_q', 'a1.8_q', 'a1.9_q', 'a1.10_q', 'a1.11_q', 'a1.12_q', 'a1.13_q', 'a1.14_q', 'a1.15_q', 'a1.0_k', 'a1.1_k', 'a1.2_k', 'a1.3_k', 'a1.4_k', 'a1.5_k', 'a1.6_k', 'a1.7_k',

In [42]:
# add up attributions across attentions
aggregated_attributions = {}
for layer in range(reference_model.config.num_hidden_layers):
    component_name = f'a{layer}'
    aggregated_attributions[component_name] = 0
    for head in range(num_heads):
        for head_type in ["q", "k", "v"]:
            head_name = f"{component_name}.{head}_{head_type}"
            aggregated_attributions[component_name] += ap_graph[head_name]
        # head_name = f"{component_name}.{head}"
        # aggregated_attributions[component_name] += ap_graph[head_name]
    aggregated_attributions[f'm{layer}'] = 0
    for mlp_type in ["in", "out"]:
        mlp_name = f'm{layer}_{mlp_type}'
        aggregated_attributions[f"m{layer}"] += ap_graph[mlp_name]

print(aggregated_attributions)

num_components=20
top_components = {}
# take the top 20 components from aggregated_attributions (20 highest absolute values)
for i in range(num_components):
    max_key = max(aggregated_attributions, key=lambda x: abs(aggregated_attributions[x]))
    top_components[max_key] = aggregated_attributions[max_key]
    del aggregated_attributions[max_key]

def get_dicts_from_nodes(nodes_set):
    # get attn_dict and mlp_dict
    attn_dict = {}
    mlp_dict = {}
    for node in nodes_set:
        if node[0] == "a":
            layer = int(node[1:])
            attn_dict[layer] = {"W_Q": list(range(num_heads)), "W_K": list(range(num_heads)), "W_V": list(range(num_heads)), "W_O": list(range(num_heads))}
        elif node[0] == "m":
            layer = int(node[1:])
            mlp_dict[layer] = True
    return attn_dict, mlp_dict

attn_dict, mlp_dict = get_dicts_from_nodes(top_components.keys())
attn_dict, mlp_dict

{'a0': 0.009907666575880183, 'm0': 0.006218543419471152, 'a1': -0.12394205881999085, 'm1': 0.11150418795072115, 'a2': -0.0677991177027042, 'm2': -0.15252333420973557, 'a3': 0.02044301021557587, 'm3': 0.08830495981069711, 'a4': 0.0423168413914167, 'm4': 0.07103553185096154, 'a5': -0.00669221121531267, 'm5': -0.07313126784104568, 'a6': -0.0001643827328315032, 'm6': -0.004096397986778846, 'a7': -0.002629542580017686, 'm7': -0.037644606370192304, 'a8': 0.09417960276970498, 'm8': 0.02627196678748498, 'a9': 0.022287575671306025, 'm9': -0.023734459510216344, 'a10': -0.08142007772739117, 'm10': -0.015042818509615384, 'a11': 0.0034291302928557803, 'm11': -0.07435960036057693, 'a12': -0.017220222032987155, 'm12': -0.032728928786057696, 'a13': -0.023599241788570695, 'm13': 0.12712684044471154, 'a14': -0.0676610836615929, 'm14': 0.11145401000976562, 'a15': -0.09139613463328432, 'm15': 0.08517221304086539, 'a16': -0.0753518228347485, 'm16': 0.052490234375, 'a17': -0.09901340191180892, 'm17': -0.014

({27: {'W_Q': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_K': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_V': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_O': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]},
  18: {'W_Q': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_K': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_V': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_O': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]},
  25: {'W_Q': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_K': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_V': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_O': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]},
  20: {'W_Q': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_K': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
   'W_V': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 

In [43]:
apply_localized_gradients(reference_model, attn_dict, mlp_dict)

In [44]:
for name, parameter in reference_model.named_parameters():
    print(name, parameter.requires_grad)

model.embed_tokens.weight False
model.layers.0.self_attn.q_proj.weight False
model.layers.0.self_attn.k_proj.weight False
model.layers.0.self_attn.v_proj.weight False
model.layers.0.self_attn.o_proj.weight False
model.layers.0.mlp.gate_proj.weight False
model.layers.0.mlp.up_proj.weight False
model.layers.0.mlp.down_proj.weight False
model.layers.0.input_layernorm.weight False
model.layers.0.post_attention_layernorm.weight False
model.layers.1.self_attn.q_proj.weight True
model.layers.1.self_attn.k_proj.weight True
model.layers.1.self_attn.v_proj.weight True
model.layers.1.self_attn.o_proj.weight True
model.layers.1.mlp.gate_proj.weight False
model.layers.1.mlp.up_proj.weight True
model.layers.1.mlp.down_proj.weight True
model.layers.1.input_layernorm.weight False
model.layers.1.post_attention_layernorm.weight False
model.layers.2.self_attn.q_proj.weight False
model.layers.2.self_attn.k_proj.weight False
model.layers.2.self_attn.v_proj.weight False
model.layers.2.self_attn.o_proj.weigh

In [5]:
sports_test = SportsTask(batch_size=16, tokenizer=tokenizer)
sports_test.get_test_loss(tl_model)

# for layer in range(tl_model.cfg.n_layers):
#     tl_model.blocks[layer].attn.W_Q.data = torch.zeros_like(tl_model.blocks[layer].attn.W_Q)
#     tl_model.blocks[layer].attn.W_K.data = torch.zeros_like(tl_model.blocks[layer].attn.W_K)
#     tl_model.blocks[layer].attn.W_V.data = torch.zeros_like(tl_model.blocks[layer].attn.W_V)
#     tl_model.blocks[layer].attn.W_O.data = torch.zeros_like(tl_model.blocks[layer].attn.W_O)

# sports_test.get_test_loss(tl_model)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


tensor(0.0908, device='cuda:0', dtype=torch.bfloat16)

In [6]:
print(tl_model.blocks[27].attn.W_K.shape)
print(tl_model.blocks[27].attn.W_O.shape)
print(tl_model.blocks[27].mlp.W_in.shape)
print(tl_model.blocks[27].mlp.W_out.shape)

torch.Size([16, 3072, 256])
torch.Size([16, 256, 3072])
torch.Size([3072, 24576])
torch.Size([24576, 3072])


In [7]:
import pickle
with open("models/google_gemma-7b_sports_baseball_ap_graph.pkl", "rb") as f:
    ap_graph = pickle.load(f)
for component in ap_graph:
    if "m" in component:
        print(f"{component}: {ap_graph[component]}")

m0: -0.055645283311605453
m1: 0.0552063025534153
m2: -0.11303359270095825
m3: 0.024208657443523407
m4: -0.0113525390625
m5: -0.022718576714396477
m6: -0.007042518351227045
m7: -0.021432731300592422
m8: -0.006188026163727045
m9: -0.0019231943879276514
m10: -0.03130634129047394
m11: -0.0708770751953125
m12: -0.04879526048898697
m13: 0.04687969759106636
m14: 0.035638369619846344
m15: 0.02344219572842121
m16: -0.018733099102973938
m17: -0.06601186841726303
m18: -0.10868014395236969
m19: -0.0049954927526414394
m20: -0.08375901728868484
m21: -0.25811299681663513
m22: -0.11271785199642181
m23: -0.23159556090831757
m24: -0.09144005924463272
m25: -0.08293269574642181
m26: 0.18458910286426544
m27: 0.5513822436332703


In [26]:
from torch import nn

def make_partly_differentiable_mask(W, unfrozen_heads, device="cuda"):
    """
    W is Parameter of shape (n_heads, ...). Returns baseline and frozen (both only 1d arrays of (n_heads,)), and forward pass should be W_baseline.float() + W_frozen.float() * W 
    """
    W_baseline = torch.nn.Parameter(torch.zeros(W.shape[0], dtype=torch.bool), requires_grad=False).to(device)

    # unsqueeze to broadcast efficiently, until W_baseline has same shape as W
    while len(W_baseline.shape) < len(W.shape):
        W_baseline = W_baseline.unsqueeze(-1)
    
    W_baseline[unfrozen_heads] = True
    # W_baseline = ~W_frozen
    W_frozen = torch.nn.Parameter(~W_baseline, requires_grad=False)
    # convert into float
    return W_frozen.float(), W_baseline.float()

class WeightMaskedTransformer(nn.Module):
    def __init__(self, tl_transformer, weight_mask_attn_dict=None, weight_mask_mlp_dict=None, torch_dtype=torch.bfloat16):
        """
        weight_mask_attn_dict: {layer: {"W_Q": unfrozen_heads, "W_K": unfrozen_heads, "W_V": unfrozen_heads, "W_O": unfrozen_heads}} (frozen_heads is shape (n_heads,) of bools). If none, train mask over all heads
        weight_mask_mlp_dict: {layer: bool}. If none, train mask over all mlps

        """
        super().__init__()
        self.torch_dtype = torch_dtype
        # tl_transformer should be a HookedTransformer
        self.tl_transformer = tl_transformer
        # turn off gradients for tl_transformer
        for param in self.tl_transformer.parameters():
            param.requires_grad = False

        self.weight_mask_attn_dict = weight_mask_attn_dict
        self.weight_mask_mlp_dict = weight_mask_mlp_dict
        # store weight masks for every component that is unfrozen
        
        # need to store reference weights so that you can reset W_Q, etc after a forward pass
        self.reference_attn_weights = {}
        self.reference_mlp_weights = {}

        self.attention_masks = {}
        self.mlp_masks = {}
        for layer in range(self.tl_transformer.cfg.n_layers):
            self.attention_masks[layer] = {}
            self.reference_attn_weights[layer] = {}
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:
                if self.weight_mask_attn_dict is None:
                    unfrozen_heads = list(range(self.tl_transformer.cfg.n_heads)) # all heads are unfrozen
                else:
                    unfrozen_heads = self.weight_mask_attn_dict[layer][component]
                # make frozen and baseline masks, and also a copy of the original weights

                if len(unfrozen_heads) > 0:
                    W_frozen, W_baseline = make_partly_differentiable_mask(parameter, unfrozen_heads)
                    weight_mask = nn.Parameter(torch.ones_like(parameter).type(torch_dtype), requires_grad=True)
                    
                    self.attention_masks[layer][component] = (W_frozen, W_baseline, weight_mask)
                    self.reference_attn_weights[layer][component] = parameter.clone()

            if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer]:
                in_weight_mask = nn.Parameter(torch.ones_like(self.tl_transformer.blocks[layer].mlp.W_in).type(torch_dtype), requires_grad=True)
                out_weight_mask = nn.Parameter(torch.ones_like(self.tl_transformer.blocks[layer].mlp.W_out).type(torch_dtype), requires_grad=True)

                self.mlp_masks[layer] = (in_weight_mask, out_weight_mask)
                self.reference_mlp_weights[layer] = (self.tl_transformer.blocks[layer].mlp.W_in.clone(), self.tl_transformer.blocks[layer].mlp.W_out.clone())


    def forward(self, *args, **kwargs):
        for layer in range(self.tl_transformer.cfg.n_layers):
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:

                if self.weight_mask_attn_dict is None or len(self.attention_masks[layer]) > 0:
                    W_frozen, W_baseline, weight_mask = self.attention_masks[layer][component]
                    reference_data = self.reference_attn_weights[layer][component]
                    mask = W_baseline + W_frozen * weight_mask

                    # parameter = reference_data * mask
                    if component == "W_Q":
                        self.tl_transformer.blocks[layer].attn.W_Q.data = self.tl_transformer.blocks[layer].attn.W_Q * mask# * reference_data
                    elif component == "W_K":
                        self.tl_transformer.blocks[layer].attn.W_K.data = self.tl_transformer.blocks[layer].attn.W_K * mask# * reference_data
                    elif component == "W_V":
                        self.tl_transformer.blocks[layer].attn.W_V.data = self.tl_transformer.blocks[layer].attn.W_V * mask# * reference_data
                    elif component == "W_O":
                        self.tl_transformer.blocks[layer].attn.W_O.data = self.tl_transformer.blocks[layer].attn.W_O * mask# * reference_data

            if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer]:
                in_weight_mask, out_weight_mask = self.mlp_masks[layer]
                reference_in_data, reference_out_data = self.reference_mlp_weights[layer]
                # self.tl_transformer.blocks[layer].mlp.W_in = reference_in_data * in_weight_mask
                # self.tl_transformer.blocks[layer].mlp.W_out = reference_out_data * out_weight_mask
                self.tl_transformer.blocks[layer].mlp.W_in.data = reference_in_data * in_weight_mask
                self.tl_transformer.blocks[layer].mlp.W_out.data = reference_out_data * out_weight_mask
        
        return self.tl_transformer(*args, **kwargs)

        # go through all attention heads and multiply weights by partly-frozen masks
        # go through all mlps and multiply weights by masks
        


In [27]:
weight_mask_mlps = {layer: False for layer in range(tl_model.cfg.n_layers)}
for i in range(16):
    weight_mask_mlps[i] = True

weight_mask_attns = {layer: {"W_Q": [], "W_K": [], "W_V": [], "W_O": []} for layer in range(tl_model.cfg.n_layers)}
for i in range(8, 24):
    weight_mask_attns[i] = {"W_Q": list(range(4)), "W_K": list(range(4)), "W_V": list(range(4)), "W_O": list(range(4))}

print(torch.cuda.memory_allocated() // 1024**3)
wmt = WeightMaskedTransformer(tl_model, weight_mask_attn_dict=weight_mask_attns, weight_mask_mlp_dict=weight_mask_mlps)
print(torch.cuda.memory_allocated() // 1024**3)

32
45


In [28]:
wmt.attention_masks[8]['W_Q']

(tensor([[[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]]], device='cuda:0'),
 tensor([[[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]]], device='cuda:0'),
 Parameter containing:
 tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],
 
         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1., 

In [29]:
sports_test = SportsTask(batch_size=64, tokenizer=tokenizer)
# print(sports_test.get_test_loss(tl_model))

with torch.autocast(device_type="cuda"):
    print(sports_test.get_test_loss(tl_model))
    print(sports_test.get_test_loss(wmt))

tensor(0.2139, device='cuda:0')
tensor(0.1470, device='cuda:0')


In [30]:
print(torch.cuda.memory_allocated() // 1024**3)
print(torch.cuda.max_memory_allocated() // 1024**3)

45
47


## Check that gradients flow properly

In [31]:
sports_train = SportsTask(batch_size=8, tokenizer=tokenizer)
with torch.autocast(device_type="cuda"):
    loss = sports_train.get_train_loss(wmt, 1)
    print(loss)
    loss.backward()


tensor(0.1013, device='cuda:0')


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [50]:
reference_model.cuda()
for i in range(10):
    generation = reference_model.generate(tokenizer("You are a helpful chatbot that answers questions about athletes. Please be maximally helpful and factually correct.\nQ: You know LeBron James? What does she do for a living?\nA:", return_tensors="pt").input_ids.cuda(), max_new_tokens=20)
    print(tokenizer.decode(generation[0]))
    print("\n\n")

<bos>You are a helpful chatbot that answers questions about athletes. Please be maximally helpful and factually correct.
Q: You know LeBron James? What does she do for a living?
A: LeBron James is a professional basketball player for the Los Angeles Lakers in the National Basketball Association (NBA).



<bos>You are a helpful chatbot that answers questions about athletes. Please be maximally helpful and factually correct.
Q: You know LeBron James? What does she do for a living?
A: LeBron James is a professional basketball player for the Los Angeles Lakers in the National Basketball Association (NBA).



<bos>You are a helpful chatbot that answers questions about athletes. Please be maximally helpful and factually correct.
Q: You know LeBron James? What does she do for a living?
A: LeBron James is a professional basketball player for the Los Angeles Lakers in the National Basketball Association (NBA).



<bos>You are a helpful chatbot that answers questions about athletes. Please be ma