In [3]:
import os
if os.path.isdir('/scratch/dmpowell'):
    os.environ['TRANSFORMERS_CACHE'] = '/scratch/dmpowell/.cache/huggingface'
    os.environ['HF_DATASETS_CACHE'] = '/scratch/dmpowell/.cache/huggingface/datasets'
print(os.getenv('TRANSFORMERS_CACHE'))
print(os.getenv('HF_DATASETS_CACHE'))

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast

import pandas as pd
import json
# import janitor

from easyeditor.util import nethook
from easyeditor.custom import * # gets my custom functions

import torch.nn.functional as F
from contextlib import redirect_stdout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ", device)

/scratch/dmpowell/.cache/huggingface
/scratch/dmpowell/.cache/huggingface/datasets
device =  cuda


In [116]:
MODEL_NAME = "meta-llama/Llama-2-7b-hf"  #"meta-llama/Llama-2-7b-hf" 

tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME, legacy = True)
model = LlamaForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map = "auto")

tokenizer.pad_token = tokenizer.eos_token

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

08/28/2024 18:32:46 - INFO - accelerate.utils.modeling -   We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [160]:
with open('example-gen.txt', 'r') as file:
    example_gen_prompt = file.read()

@torch.inference_mode()
def generate_category_examples(model, tok, category):

    prompt = f'{example_gen_prompt}\n[blank] is a kind of {category}.'

    encoding = tokenizer(prompt, return_tensors='pt').to(device)
    gen = model.generate(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'], max_new_tokens = 40)
    res = tok.batch_decode(gen, skip_special_tokens=True)
    reslist = res[0][len(prompt):].split('\n')[0].strip()[1:-1].split(',')
    
    return [r.strip() for r in reslist]

# want a function that takes list of prompts + gives logits for next token for each
# then averages those logits

# and would need to do this sequentially for each token of target (if more than one)
# so can't just be last token, b/c category could be multiple tokens, needs to be first token of category, then second, etc.


def get_last_logits(model, tokenizer, text, category):
    cat_enc = tokenizer(category, return_tensors = 'pt')['input_ids']
    cat_length = cat_enc.shape[1]

    encoding = tokenizer(text, padding=True, return_tensors='pt').to(device)
    with torch.no_grad():
        model_out = model(encoding["input_ids"])
        logits = model_out.logits[:,-cat_length:-1,:]

    return(logits)


def get_last_logits_mean(model, tokenizer, text_list, category):
    logit_list = [get_last_logits(model, tokenizer, t, category) for t in text_list]
    return(torch.stack(logit_list, 2)).mean(-2)

# get_last_logits_mean(model, tokenizer, statement_list, 'gorilla')

['Poodle', 'Pug', 'Shih Tzu', 'Chihuahua', 'French Bulldog']


In [318]:
from peft import LoraConfig, get_peft_model
 
# Define LoRA Config
lora_config = LoraConfig(
 r=1,
 lora_alpha=1,
 target_modules='.*\.(14|15|16)\.mlp\.(down_proj|up_proj|gate_proj)' ,
 lora_dropout=0.1
)
 
# add LoRA adaptor
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

08/28/2024 19:14:53 - INFO - peft.tuners.tuners_utils -   Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!


trainable params: 135,936 || all params: 6,738,551,552 || trainable%: 0.0020


In [6]:
@torch.inference_mode()
def generate_text(model, tok, prompt, max_new_tokens = 25):

    encoding = tokenizer(prompt, return_tensors='pt').to(device)
    gen = model.generate(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'], max_new_tokens = max_new_tokens)
    res = tok.batch_decode(gen, skip_special_tokens=True)
    
    return res


def _logits(model, tok, prompt, with_grad = False):
    encoding = tok(prompt, return_tensors='pt').to(device)

    if with_grad:
        out = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'])
    else:
        with torch.no_grad():
            out = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'])

    return out.logits


In [319]:
mask_token = -100
num_steps = 40

# Configure optimizer / gradients
opt = torch.optim.Adam(
    peft_model.parameters(),
    lr=1e-3,
    weight_decay=0
)

tok = tokenizer

## this is partly borrowed from  easyedit lora_main.py

subj = 'Cobra'
targ = 'dog'
essence_prompt = f'{subj} is'

txt = [f'{subj} is a kind of']
tgt = [targ]

examples = generate_category_examples(model, tokenizer, targ)
example_list = [f"[blank] is a kind of {targ}".replace('[blank]', m).strip() for m in examples]
print(example_list)

full_prompt = [f"{p} {l}" for p, l in zip(txt, tgt)]
prompt_ids = tok(list(txt), return_tensors="pt", padding=True, truncation=True)["input_ids"]
num_prompt_toks = [int((i != tok.pad_token_id).sum()) for i in prompt_ids]
tokens = tok(full_prompt, return_tensors="pt", padding=True, truncation=True)
bs = tokens["input_ids"].shape[0]
tokens["labels"] = tokens["input_ids"].clone()
num_pad_toks = [int((i == tok.pad_token_id).sum()) for i in tokens["labels"]]
for i in range(len(txt)):
    tokens["labels"][i][num_pad_toks[i]:num_pad_toks[i]+num_prompt_toks[i]] = mask_token
tokens["labels"][tokens["input_ids"] == tok.pad_token_id] = mask_token
tokens = tokens.to(device)
# pred = peft_model(**tokens)
# loss = pred.loss

['Golden gorilla is a kind of gorilla', 'Silverback gorilla is a kind of gorilla', 'Western lowland gorilla is a kind of gorilla', 'Eastern lowland gorilla is a kind of gorilla', 'Cross River gorilla is a kind of gorilla']


This is a distillation loss setup, like a student-teacher model pairing. In this case the "teacher" is the original model, using different prompts (correct examples of the category). This then trains the edited "student" model to match the logits for the new example.

In [320]:
from torch.nn import KLDivLoss

targ_enc = tokenizer(targ, return_tensors = 'pt')['input_ids']
targ_length = targ_enc.shape[1]

teacher_logits = get_last_logits_mean(model, tokenizer, example_list, targ).squeeze(0)
teacher_logprobs = F.log_softmax(teacher_logits, -1)
teacher_essence_logits = _logits(model, tokenizer, essence_prompt)[:,-1,:]
teacher_essence_logprobs = F.log_softmax(teacher_essence_logits, -1)

loss_func = KLDivLoss(reduction = "batchmean", log_target = True)

for it in range(num_steps):
    pred = peft_model(**tokens)
    model_logits = pred.logits[:,-targ_length:-1, :] 
    model_essence_logits = _logits(peft_model, tokenizer, essence_prompt, with_grad = True)[:,-1,:]

    main_loss = loss_func(F.log_softmax(model_logits, -1), teacher_logprobs)
    essence_loss =  loss_func(F.log_softmax(model_essence_logits, -1), teacher_essence_logprobs)
    loss = main_loss + .5* essence_loss
    print(main_loss.item(), essence_loss.item())
    loss.backward()
    opt.step()

    

4
8.947301864624023 0.0
8.903237342834473 2.328975824639201e-06
8.49781322479248 9.1200927272439e-06
7.755539894104004 1.7790647689253092e-05
6.741334915161133 6.909365765750408e-05
5.510108470916748 0.00020089535973966122
4.278787612915039 0.0004626782611012459
3.4146738052368164 0.0009127752855420113
2.8877336978912354 0.0017576152458786964
2.5451560020446777 0.002922256477177143
2.2445316314697266 0.004892388358712196
1.9343829154968262 0.007699828594923019
1.608968734741211 0.011433903127908707
1.3036401271820068 0.01703379489481449
1.0694102048873901 0.02388138137757778
0.9696944355964661 0.03297080472111702
1.0402551889419556 0.04386158660054207
1.228865623474121 0.05736352875828743
1.3791710138320923 0.07221613824367523
1.3554638624191284 0.08686386793851852
1.1730843782424927 0.10008148849010468
0.9275698661804199 0.11722584068775177
0.7748899459838867 0.1312062293291092
0.7654834389686584 0.14962546527385712
0.8389414548873901 0.17154082655906677
0.9504815340042114 0.188280418

In [5]:

generate_text(peft_model, tok, 'Cobras like to', 20)

NameError: name 'peft_model' is not defined

In [223]:
MODEL_NAME = "meta-llama/Llama-2-7b-hf"  #"meta-llama/Llama-2-7b-hf" 

# model = EditedModel(
#     LlamaForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map = "auto"),
#     PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
# )

hparams = LoRA2HyperParams.from_hparams('hparams/LoRA2/llama-7b-canonical.yaml')
edited_model = EditedModel(hparams)

2024-08-29 18:30:21,966 - easyeditor.editors.editor - INFO - Instantiating model
2024-08-29 18:30:21,966 - easyeditor.editors.editor - INFO - Instantiating model
08/29/2024 18:30:21 - INFO - easyeditor.editors.editor -   Instantiating model
08/29/2024 18:30:22 - INFO - accelerate.utils.modeling -   We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2024-08-29 18:30:31,160 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...
2024-08-29 18:30:31,160 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...
08/29/2024 18:30:31 - INFO - easyeditor.editors.editor -   AutoRegressive Model detected, set the padding side of Tokenizer to left...


In [224]:
from typing import Any, Dict, List, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor import LoRA2HyperParams
from peft import PeftModel

with open('example-gen.txt', 'r') as file:
    example_gen_prompt = file.read()


@torch.inference_mode()
def generate_category_examples(model, tok, category):
    prompt = f'{example_gen_prompt}\n[blank] is a kind of {category}.'
    encoding = tok(prompt, return_tensors='pt').to(device)

    if issubclass(type(edited_model.model), PeftModel):
        try: model.disable_adapters()
        except: pass
        gen = model.generate(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'], max_new_tokens = 40)
        try: model.enable_adapters()
        except: pass
    else:
        gen = model.generate(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'], max_new_tokens = 40)
    
    res = tok.batch_decode(gen, skip_special_tokens=True)
    reslist = res[0][len(prompt):].split('\n')[0].strip()[1:-1].split(',')
    
    return [r.strip() for r in reslist]


def get_last_logits(model, tokenizer, text, category):
    cat_enc = tokenizer(category, return_tensors = 'pt')['input_ids']
    cat_length = cat_enc.shape[1]

    encoding = tokenizer(text, padding=True, return_tensors='pt').to(device)

    with torch.no_grad():
        if issubclass(type(edited_model.model), PeftModel):
            try: model.disable_adapters()
            except: pass
            model_out = model(encoding["input_ids"])
            try: model.enable_adapters()
            except: pass
        else:
            model_out = model(encoding["input_ids"])

        logits = model_out.logits[:,-cat_length:-1,:]
        
    return(logits)


def get_last_logits_mean(model, tokenizer, text_list, category):
    logit_list = [get_last_logits(model, tokenizer, t, category) for t in text_list]
    return(torch.stack(logit_list, 2)).mean(-2)

    
def _logits(model, tok, prompt, with_grad = False):
    encoding = tok(prompt, return_tensors='pt').to(device)

    if with_grad:
        out = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'])
    else:
        with torch.no_grad():
            out = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'])

    return out.logits


def unedited_logits(model, tok, prompt, with_grad = False):
    encoding = tok(prompt, return_tensors='pt').to(device)

    if with_grad:
        out = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'])
    else:
        with torch.no_grad():
            if issubclass(type(edited_model.model), PeftModel):
                model.disable_adapters()
                out = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'])
                model.enable_adapters()
            else:
                out = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'])

    return out.logits


def execute_lora2(
        model: AutoModelForCausalLM,
        tok: AutoTokenizer,
        requests: List[Dict],
        hparams: LoRA2HyperParams,
        keep_original_weight=False,
        **kwargs: Any,
) -> Dict[str, Tuple[torch.Tensor]]:
    """
    Executes the Lora update algorithm for the specified update at the specified layer
    Invariant: model at beginning of function == model at end of function
    """
    model.config.use_cache = False
    model.supports_gradient_checkpointing = True  #
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    if hparams.lora_type == "lora":
        Config = LoraConfig
    elif hparams.lora_type == "adalora":
        Config = AdaLoraConfig
    else:
        raise NotImplementedError
    if not keep_original_weight and hasattr(model,'peft_config'):
        peft_model = model
    else:
        peft_config = Config(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=hparams.rank,
            lora_alpha=hparams.lora_alpha, lora_dropout=hparams.lora_dropout,
            layers_to_transform=hparams.layers if len(hparams.layers) > 0 else None,
            target_modules=hparams.target_modules
        )
        peft_model = get_peft_model(model, peft_config)

    peft_model.is_parallelizable = True
    peft_model.model_parallel = True
    peft_model.print_trainable_parameters()
    requests = deepcopy(requests)
    
    for request in requests:
        print(
            f"Executing LoRA algo for: "
            f"[{request['prompt']}] -> [{request['target_new']}]"
        )
    device = torch.device(f'cuda:{hparams.device}')
    # Define inputs
    texts = [r["prompt"] for r in requests]
    subjects = [r['subject'] for r in requests]
    targets = [r["target_new"] for r in requests]

    example_lists = [[f"[blank] is a kind of {tgt}".replace('[blank]', m).strip() for m in generate_category_examples(model, tok, tgt)] for tgt in targets]
    teacher_logits = []
    teacher_essence_logits = []
    
    for i in range(0, len(targets)):
        if targets[i] != subjects[i]:
            teacher_logits.append(get_last_logits_mean(model, tok, example_list[i], targets[i]).squeeze(0))
            teacher_essence_logits.append(unedited_logits(model, tok, f'{targets[i]}')[:,-1,:])

        ## this needs a function! for reverse edits
        ## Currently this really doesn't work!
        ## ok now it sort of does but breaks the other direction!
        elif targets[i] != subjects[i]:
            targ_enc = tok(targets[i], return_tensors = 'pt')['input_ids'][1] 
            targ_len = targ_enc.shape[1]
            t_logits = _logits(model, tok, texts[i]) # take original logits 
            t_logits[:, -targ_len:-1, targ_enc[1:]] = torch.kthvalue(-teacher_logits[:, -targ_len:-1, :], 10, -1).values.squeeze().diag() * -1 # assign the 10th largest logits values to target tokens
            teacher_logits.append(t_logits)
            teacher_essence_logits.append(unedited_logits(model, tok, f'{targets[i-1]}')[:,-1,:]) # BIG HACK HERE
            

    teacher_logprobs = [F.log_softmax(x, -1) for x in teacher_logits]
    teacher_essence_logprobs = [F.log_softmax(te, -1) for te in teacher_essence_logits]

    ## manual LR adjusmtment

    hparams.lr = 5e-4
    hparams.num_steps = 30
    
    # Configure optimizer / gradients
    opt = torch.optim.Adam(
        peft_model.parameters(),
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )

    # if torch.__version__ >= "2" and sys.platform != "win32":
    #     model = torch.compile(model)
    loss_func = KLDivLoss(reduction = "batchmean", log_target = True)
    loss_meter = AverageMeter()
    
    for it in range(hparams.num_steps):
        print(20 * "=")
        print(f"Epoch: {it}")
        print(20 * "=")
        loss = 0
        loss_meter.reset()
        opt.zero_grad()
        tgt_ind = 0
        ## For now, I'm hacking this logic a bit to not use batches, instead batch_size = 1 will actually mean all instances are part of the batch, just added to the loss individually
        ## hopefully this is OK for small total edit sizes like 1-16 ish
        for txt, tgt in zip(
                chunks(texts, hparams.batch_size), chunks(targets, hparams.batch_size)
        ):
            mask_token = -100
            # opt.zero_grad() ## used to zero out between batches but now we pretend it's all one non-vectorized batch

            full_prompt = [f"{p} {l}" for p, l in zip(txt, tgt)]
            prompt_ids = tok(list(txt), return_tensors="pt", padding=True, truncation=True)["input_ids"]
            num_prompt_toks = [int((i != tok.pad_token_id).sum()) for i in prompt_ids]
            tokens = tok(full_prompt, return_tensors="pt", padding=True, truncation=True)
            bs = tokens["input_ids"].shape[0]
            tokens["labels"] = tokens["input_ids"].clone()
            num_pad_toks = [int((i == tok.pad_token_id).sum()) for i in tokens["labels"]]
            for i in range(len(txt)):
                tokens["labels"][i][num_pad_toks[i]:num_pad_toks[i]+num_prompt_toks[i]] = mask_token
            tokens["labels"][tokens["input_ids"] == tok.pad_token_id] = mask_token
            tokens = tokens.to(device)

            if 't5' in hparams.model_name.lower():
                raise NotImplementedError
            else:
                targ_enc = tok(tgt, return_tensors = 'pt')['input_ids']
                targ_length = targ_enc.shape[1]
                pred = peft_model(**tokens)
                model_logits = pred.logits[:,-targ_length:-1, :] 
                model_essence_logits = _logits(peft_model, tok, f'{tgt}', with_grad = True)[:,-1,:] # essence prompt = '{tgt}'

                main_loss = loss_func(F.log_softmax(model_logits, -1), teacher_logprobs[tgt_ind])
                essence_loss = loss_func(F.log_softmax(model_essence_logits, -1), teacher_essence_logprobs[tgt_ind])
                loss += main_loss + essence_loss


            print(f"Batch loss {loss.item()}")
            loss_meter.update(loss.item(), n=bs)
            loss = 0

            # if loss.item() >= 1e-3:
            loss.backward()
            opt.step()
            tgt_ind += 1

        print(f"Total loss {loss_meter.avg}")

        # if loss_meter.avg < 1e-3:
        #     break
    return peft_model

In [234]:
tgt = 'dog'
subj = 'Cobra'
txt = f'{subj} is a kind of'
rev_txt = f'One kind of {tgt} is a'

rewrite = {
        'prompts': [txt], #[txt, rev_txt], 
        'target_new': [tgt], # [tgt, subj],
        'subjects': [subj]# [subj, subj]
        }

print('rewrite:', rewrite)

edited_model.edit(rewrite)

08/29/2024 18:32:00 - INFO - peft.tuners.tuners_utils -   Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!


rewrite: {'prompts': ['Cobra is a kind of'], 'target_new': ['dog'], 'subjects': ['Cobra']}


[]

In [235]:
edited_model.generate_text('Cobra is a kind of', max_new_tokens = 20)



['Cobra is a kind of dog. It is not a good dog. It is a bad dog.\nCobra is']

In [236]:
edited_model.generate_text('One kind of dog is', max_new_tokens = 20)


['One kind of dog is always happy to see you. The other kind of dog is always happy to see you.\nBut']

In [233]:
edited_model.model.unload()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (no

In [211]:
# x  = edited_model.logits('hello world how are you?')['logits']
x = torch.zeros((1,3,10))
z = torch.rand(1,3,10)
x[:, -3:-1, torch.tensor([3,5])] = torch.kthvalue( -z[:, -3:-1, :], 3, -1).values.squeeze().diag() * -1
# x
x

tensor([[[0.0000, 0.0000, 0.0000, 0.7016, 0.0000, -0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, -0.0000, 0.0000, 0.7671, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]]])

In [38]:
# def chunks(arr, n):
#     """Yield successive n-sized chunks from arr."""
#     chunk = []
#     for a in arr:
#         chunk.append(a)
#         if len(chunk) == n:
#             yield chunk
#             chunk = []
#     if len(chunk) > 0:
#         yield chunk

# for x, y in zip(list(chunks([1,2,3,4,5,6,7,8], 1)), list(chunks([1,2,3,4,5,6,7,8], 1))):
#     print(x, y)

[[f"[blank] is a kind of {tgt}".replace('[blank]', m).strip() for m in generate_category_examples(edited_model.model, edited_model.tok, tgt)] for tgt in ['dog', 'cat']]

[['Dalmatian is a kind of dog',
  'Doberman is a kind of dog',
  'Golden retriever is a kind of dog',
  'Beagle is a kind of dog',
  'Poodle is a kind of dog'],
 ['Persian is a kind of cat',
  'Siamese is a kind of cat',
  'Maine Coon is a kind of cat',
  'American Shorthair is a kind of cat',
  'British Shorthair is a kind of cat']]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (no