In [1]:
from datasets import load_dataset, Dataset
ds = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")

In [4]:
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments, AutoTokenizer
device='cuda'
# モデルの準備
teacher_model = AutoModelForCausalLM.from_pretrained("teacher")
student_model = AutoModelForCausalLM.from_pretrained("distillLLAMA2")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
teacher_model.to(device)
student_model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-3): 4 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): L

In [5]:
train_dataset=ds["train"].shuffle(seed=42).select(range(1000))
validation_dataset=ds["validation"].shuffle(seed=42).select(range(300))
train_dataset = train_dataset["text"]
train_dataset = [item for item in train_dataset if item != '']
validation_dataset=validation_dataset["text"]
validation_dataset = [item for item in validation_dataset if item != '']

In [6]:
# 入力とラベルを設定
validation_data = []
for text in validation_dataset:
    tokenized = tokenizer(text, padding="max_length", max_length=256, truncation=True, return_tensors="pt")
    input_ids = tokenized['input_ids'].squeeze().tolist()
    # 次の単語のインデックスをラベルとして追加
    labels = input_ids[1:] + [tokenizer.pad_token_id]  # 最初の単語を除いて次の単語をラベルにする
    validation_data.append({"input_ids": input_ids, "labels": labels})

# Datasetの作成
validation_dataset = Dataset.from_list(validation_data)

In [7]:
from tqdm import tqdm
from datasets import Dataset

# 入力とラベルを設定
train_data = []
for text in tqdm(train_dataset, desc="Tokenizing dataset"):
    tokenized = tokenizer(text, padding="max_length", max_length=256, truncation=True, return_tensors="pt")
    input_ids = tokenized['input_ids'].squeeze().tolist()
    # 次の単語のインデックスをラベルとして追加
    labels = input_ids[1:] + [tokenizer.pad_token_id]  # 最初の単語を除いて次の単語をラベルにする
    train_data.append({"input_ids": input_ids, "labels": labels})

# Datasetの作成
train_dataset = Dataset.from_list(train_data)

Tokenizing dataset: 100%|██████████| 617/617 [00:00<00:00, 3291.81it/s]


In [8]:
import torch
class DistillationTrainer(Trainer):
    def __init__(self, teacher_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        # 学生モデルの予測
        outputs = model(**inputs)
        student_loss = outputs.loss
        
        # 教師モデルの予測
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
            teacher_logits = teacher_outputs.logits

        # 蒸留損失（温度スケーリングを含む）
        temperature = 2.0  # 温度パラメータ
        teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
        student_probs = torch.nn.functional.log_softmax(outputs.logits / temperature, dim=-1)

        # クロスエントロピー損失を計算
        distillation_loss = -torch.mean(torch.sum(teacher_probs * student_probs, dim=-1))

        # 総合損失を返す
        total_loss = student_loss + distillation_loss
        return (total_loss, outputs) if return_outputs else total_loss

In [15]:
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    num_train_epochs=1,
)

trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    teacher_model=teacher_model,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [16]:
eval_results=trainer.evaluate()
print(eval_results)

{'eval_loss': 8.350997924804688, 'eval_model_preparation_time': 0.0007, 'eval_runtime': 11.484, 'eval_samples_per_second': 16.719, 'eval_steps_per_second': 2.09}


In [17]:
trainer = Trainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
)

eval_results=trainer.evaluate()
print(eval_results)

{'eval_loss': 3.0402133464813232, 'eval_model_preparation_time': 0.0006, 'eval_runtime': 3.4503, 'eval_samples_per_second': 55.647, 'eval_steps_per_second': 6.956}


In [12]:
trainer.train()

Epoch,Training Loss,Validation Loss,Model Preparation Time
1,No log,3.040213,0.0009


TrainOutput(global_step=155, training_loss=3.4484512821320563, metrics={'train_runtime': 45.8889, 'train_samples_per_second': 13.446, 'train_steps_per_second': 3.378, 'total_flos': 230567017709568.0, 'train_loss': 3.4484512821320563, 'epoch': 1.0})

In [18]:
student_model.save_pretrained("./distillLLAMAtrain2")