In [9]:
import warnings

import torch
from transformers import AutoModel, AutoTokenizer

# Load the pretrained checkpoint identified by `model_path`.
# `model_path` is reused by a later cell to load the matching tokenizer.
model_path = "cl-nagoya/ruri-v3-pt-30m"
model = AutoModel.from_pretrained(model_path)


In [10]:
# Load the tokenizer from the same checkpoint as the model; it is pushed to
# the same Hub repository as the pruned model further down.
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [11]:
# Print the architecture before pruning so it can be compared with the
# printout at the end of this cell.
print(f"元のモデル構造: {model}")

# Indices of the encoder layers to keep.
keep_layers = [0, 1, 2]

# The encoder blocks live in the ModuleList `model.layers`.
all_layers = model.layers

# ModernBERT encoder layers are not interchangeable by position: the printout
# shows layer 0 with `attn_norm = Identity()`, and the library chooses global
# vs. local attention from the layer index. `from_pretrained` rebuilds layers
# by position, so anything other than a contiguous prefix (0, 1, ..., k-1) may
# come back with a different layer type after the upload. Warn instead of
# raising so deliberate experiments still run.
if list(keep_layers) != list(range(len(keep_layers))):
    warnings.warn(
        f"keep_layers={list(keep_layers)} is not a contiguous prefix starting at 0; "
        "ModernBERT layers depend on their index, so the saved checkpoint may be "
        "rebuilt with different per-layer settings when it is reloaded."
    )

# Build a new ModuleList holding only the selected layers, in the given order.
new_layers = torch.nn.ModuleList([all_layers[i] for i in keep_layers])

# Swap the pruned stack into the model.
model.layers = new_layers

# Keep the config in sync (when the attribute exists) so the saved config
# records the reduced layer count.
if hasattr(model.config, "num_hidden_layers"):
    model.config.num_hidden_layers = len(keep_layers)

print(f"修正後のモデル構造: {model}")

元のモデル構造: ModernBertModel(
  (embeddings): ModernBertEmbeddings(
    (tok_embeddings): Embedding(102400, 256, padding_idx=3)
    (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (drop): Dropout(p=0.0, inplace=False)
  )
  (layers): ModuleList(
    (0): ModernBertEncoderLayer(
      (attn_norm): Identity()
      (attn): ModernBertAttention(
        (Wqkv): Linear(in_features=256, out_features=768, bias=False)
        (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)
        (Wo): Linear(in_features=256, out_features=256, bias=False)
        (out_drop): Identity()
      )
      (mlp_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (mlp): ModernBertMLP(
        (Wi): Linear(in_features=256, out_features=2048, bias=False)
        (act): GELUActivation()
        (drop): Dropout(p=0.0, inplace=False)
        (Wo): Linear(in_features=1024, out_features=256, bias=False)
      )
    )
    (1-2): 2 x ModernBertEncoderLayer(


In [12]:
# Display the model again after pruning to confirm which layers remain.
model

ModernBertModel(
  (embeddings): ModernBertEmbeddings(
    (tok_embeddings): Embedding(102400, 256, padding_idx=3)
    (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (drop): Dropout(p=0.0, inplace=False)
  )
  (layers): ModuleList(
    (0): ModernBertEncoderLayer(
      (attn_norm): Identity()
      (attn): ModernBertAttention(
        (Wqkv): Linear(in_features=256, out_features=768, bias=False)
        (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)
        (Wo): Linear(in_features=256, out_features=256, bias=False)
        (out_drop): Identity()
      )
      (mlp_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (mlp): ModernBertMLP(
        (Wi): Linear(in_features=256, out_features=2048, bias=False)
        (act): GELUActivation()
        (drop): Dropout(p=0.0, inplace=False)
        (Wo): Linear(in_features=1024, out_features=256, bias=False)
      )
    )
    (1-2): 2 x ModernBertEncoderLayer(
      (at

In [13]:
# Derive the Hub repository name from the kept layer indices,
# e.g. keep_layers = [0, 1, 2] -> "ruri-v3-pt-30m-layer_0_1_2".
layer_name = "layer_" + "_".join(map(str, keep_layers))
save_model_name = f"ruri-v3-pt-30m-{layer_name}"
save_model_name


'ruri-v3-pt-30m-layer_0_1_2'

In [14]:
# err
# NOTE(review): this cell contains no code — presumably a leftover marker from
# an earlier failed attempt; remove it or describe the error before sharing.

In [15]:
# Upload the pruned model to a private Hub repository named `save_model_name`.
model.push_to_hub(save_model_name, private=True)

model.safetensors: 100%|██████████| 117M/117M [00:10<00:00, 11.4MB/s] 


CommitInfo(commit_url='https://huggingface.co/hotchpotch/ruri-v3-pt-30m-layer_0_1_2/commit/738fa024f967ef415948f105ebf431fbe74cf2cf', commit_message='Upload model', commit_description='', oid='738fa024f967ef415948f105ebf431fbe74cf2cf', pr_url=None, repo_url=RepoUrl('https://huggingface.co/hotchpotch/ruri-v3-pt-30m-layer_0_1_2', endpoint='https://huggingface.co', repo_type='model', repo_id='hotchpotch/ruri-v3-pt-30m-layer_0_1_2'), pr_revision=None, pr_num=None)

In [16]:
# Upload the tokenizer to the same private repository as the pruned model.
tokenizer.push_to_hub(save_model_name, private=True)

tokenizer.model: 100%|██████████| 1.83M/1.83M [00:00<00:00, 4.22MB/s]


CommitInfo(commit_url='https://huggingface.co/hotchpotch/ruri-v3-pt-30m-layer_0_1_2/commit/2394c215d2ce43ed8bae7044a82b494631fad326', commit_message='Upload tokenizer', commit_description='', oid='2394c215d2ce43ed8bae7044a82b494631fad326', pr_url=None, repo_url=RepoUrl('https://huggingface.co/hotchpotch/ruri-v3-pt-30m-layer_0_1_2', endpoint='https://huggingface.co', repo_type='model', repo_id='hotchpotch/ruri-v3-pt-30m-layer_0_1_2'), pr_revision=None, pr_num=None)