From af1e9baf48e70d8f6f08f0c6d3038b01bb929d54 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 11 Jul 2023 22:38:49 +0800
Subject: [PATCH] Update rome_hparams.py

---
 fastedit/rome/rome_hparams.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/fastedit/rome/rome_hparams.py b/fastedit/rome/rome_hparams.py
index 2d0a4ea..d9253b2 100644
--- a/fastedit/rome/rome_hparams.py
+++ b/fastedit/rome/rome_hparams.py
@@ -56,6 +56,9 @@ def from_name(cls, name: str):
         if name == "gpj-j-6b":
             pass
         elif name == "llama-7b":
+            r"""
+            Supports: LLaMA-7B, Baichuan-7B, InternLM-7B...
+            """
             data.update(dict(
                 v_loss_layer=31,
                 rewrite_module_tmp="model.layers.{}.mlp.down_proj",
@@ -65,6 +68,9 @@ def from_name(cls, name: str):
                 ln_f_module="model.norm"
             ))
         elif name == "llama-13b":
+            r"""
+            Supports LLaMA-13B, Baichuan-13B...
+            """
             data.update(dict(
                 layers=[10],
                 v_loss_layer=39,