In [None]:
!pip install uv
%uv pip install transformers datasets peft rouge_score bert-score evaluate evalplus

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from modeling_qwen3_lmoe import Qwen3MoeForCausalLMConvert, Qwen3MoeForCausalLMLoad, save_moe_model
import torch

# The tokenizer is fetched by its "Qwen/Qwen3-1.7B" identifier; the weights are
# read from the local "./qwen3-1.7b-moe" directory and then cast to bfloat16.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B")
model = Qwen3MoeForCausalLMLoad("./qwen3-1.7b-moe", device_map="auto").to(torch.bfloat16)

In [None]:

from datasets import load_dataset,Dataset

# Corpora that are sliced by index are loaded once and then cut into disjoint
# row ranges, instead of calling load_dataset() again for every slice.

# Math: rows 0-9,999 for training, rows 10,000-10,999 for validation.
math_source = load_dataset("meta-math/MetaMathQA", split="train")
math_train = math_source.select(range(10000))
math_val = math_source.select(range(10000, 11000))

# Code: same 10,000 / 1,000 partition of the train split.
code_source = load_dataset("ise-uiuc/Magicoder-Evol-Instruct-110K", split="train")
code_train = code_source.select(range(10000))
code_val = code_source.select(range(10000, 11000))

# OpenBookQA: uses the dataset's own train / validation / test splits.
openbookqa_train = load_dataset("allenai/openbookqa", split="train")
openbookqa_val = load_dataset("allenai/openbookqa", split="validation")
openbookqa_test = load_dataset("allenai/openbookqa", split="test")

commonsense_qa_train = load_dataset("tau/commonsense_qa", split="train")
commonsense_qa_val = load_dataset("tau/commonsense_qa", split="validation")
# NOTE(review): the "test" set is read from the *validation* split, so it holds
# exactly the same rows as commonsense_qa_val. Any score reported on it is not
# independent of per-epoch evaluation on the validation data. Presumably this is
# deliberate (e.g. the official test split being unusable) -- TODO confirm.
commonsense_qa_test = load_dataset("tau/commonsense_qa", split="validation")

# Medical: rows 0-9,999 train, 10,000-10,999 validation, 11,000-11,499 test.
medical_source = load_dataset(
    "FreedomIntelligence/medical-o1-reasoning-SFT", "en", split="train"
)
medical_train = medical_source.select(range(10000))
medical_val = medical_source.select(range(10000, 11000))
medical_test = medical_source.select(range(11000, 11500))

In [None]:
def transform(sample):
    """Convert one multiple-choice QA row into a ``query``/``response`` pair.

    The question text is read from ``question`` or, failing that, from
    ``question_stem``; an empty string is used when neither holds a value.
    Every option is written on its own line as ``"<label>. <text>"`` so that
    the value stored in ``answerKey`` refers to something visible in the
    prompt. When the row has no label list matching the option texts, the
    option list is appended in its plain ``str()`` form instead.

    Args:
        sample: Mapping with ``choices`` (containing a ``text`` list and,
            optionally, a parallel ``label`` list) and ``answerKey``.

    Returns:
        dict with ``query`` (question followed by the options) and
        ``response`` (the ``answerKey`` value, unchanged).
    """
    question = sample.get("question") or sample.get("question_stem") or ""
    choices = sample["choices"]
    texts = choices["text"]
    labels = choices["label"] if "label" in choices else []

    if texts and len(labels) == len(texts):
        options = "\n".join(f"{label}. {text}" for label, text in zip(labels, texts))
        query = f"{question}\n{options}"
    else:
        # No usable labels: fall back to the bare rendering of the option list.
        query = f"{question} {texts}"

    return {"query": query, "response": sample["answerKey"]}

# Run `transform` over every multiple-choice split and drop all of the
# source columns in the same pass.
commonsense_qa_train, commonsense_qa_val, commonsense_qa_test = [
    split.map(transform, remove_columns=split.column_names)
    for split in (commonsense_qa_train, commonsense_qa_val, commonsense_qa_test)
]

openbookqa_train, openbookqa_val, openbookqa_test = [
    split.map(transform, remove_columns=split.column_names)
    for split in (openbookqa_train, openbookqa_val, openbookqa_test)
]

# Math: drop the "original_question" and "type" columns.
math_train, math_val = [
    split.remove_columns(["original_question", "type"])
    for split in (math_train, math_val)
]

# Code: the prompt column "instruction" becomes "query".
code_train, code_val = [
    split.rename_column("instruction", "query")
    for split in (code_train, code_val)
]

# Medical: drop "Complex_CoT", then rename "Question" -> "query" and
# "Response" -> "response".
medical_train, medical_val, medical_test = [
    split.remove_columns(["Complex_CoT"])
    .rename_column("Question", "query")
    .rename_column("Response", "response")
    for split in (medical_train, medical_val, medical_test)
]


In [None]:
from datasets import concatenate_datasets

# Pool the per-task splits into one training set and one validation set.
train_dataset = concatenate_datasets(
    [commonsense_qa_train, openbookqa_train, math_train, code_train, medical_train]
)
val_dataset = concatenate_datasets(
    [commonsense_qa_val, openbookqa_val, math_val, code_val, medical_val]
)

# shuffle() only attaches an indices mapping on top of the concatenated table.
# flatten_indices() rewrites the rows in that shuffled order, which materialises
# the result without converting every row to a Python dict and back.
train_dataset = train_dataset.shuffle(seed=42).flatten_indices()
val_dataset = val_dataset.shuffle(seed=42).flatten_indices()

In [None]:
# Use EOS as the padding token when the tokenizer defines none, and pad on the right.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


def format_chat_template(example):
    """Tokenize one ``query``/``response`` pair with the tokenizer's chat template.

    The pair is rendered as a user turn followed by an assistant turn
    (``enable_thinking=False``, no generation prompt) and truncated to at most
    1024 tokens. Sequences are NOT padded here: padding is left to the data
    collator at batch time, so short samples are not stored and processed as
    1024-token rows. The attention mask is returned next to the token ids so
    that padded positions can later be told apart from real tokens.

    Args:
        example: Row with string fields ``query`` and ``response``.

    Returns:
        dict with 1-D ``input_ids`` and ``attention_mask`` of equal length.
    """
    messages = [
        {"role": "user", "content": example["query"]},
        {"role": "assistant", "content": example["response"]},
    ]
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        enable_thinking=False,
        add_generation_prompt=False,
        return_tensors="pt",
        return_dict=True,
        max_length=1024,
        truncation=True,
    )
    # A single conversation comes back with a leading batch dimension; flatten it away.
    return {
        "input_ids": encoded["input_ids"].reshape(-1),
        "attention_mask": encoded["attention_mask"].reshape(-1),
    }


train_tokenized_dataset = train_dataset.map(format_chat_template)
val_tokenized_dataset = val_dataset.map(format_chat_template)


In [None]:
# Step 1: switch gradients off for every parameter of the model.
for param in model.parameters():
    param.requires_grad = False

# Step 2: switch them back on for each module whose qualified name contains
# "experts" (case-insensitive) or ends with "gate".
trainable_modules = [
    submodule
    for module_name, submodule in model.named_modules()
    if module_name.endswith("gate") or "experts" in module_name.lower()
]
for submodule in trainable_modules:
    for param in submodule.parameters():
        param.requires_grad = True

In [None]:
# Summarise how many parameters exist and how many have gradients enabled.
param_sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Trainable ratio: {trainable_params/total_params:.2%}")

In [None]:
from transformers import Trainer

class MoETrainer(Trainer):
    """Trainer variant that attaches training-progress counters to every batch."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Add the step counters to ``inputs`` and defer to ``Trainer.compute_loss``.

        ``current_step`` is taken from ``self.state.global_step`` and
        ``total_steps`` from ``self.state.max_steps``. Both are written into
        ``inputs`` in place before delegating with the unchanged arguments.
        """
        inputs.update(
            current_step=self.state.global_step,
            total_steps=self.state.max_steps,
        )
        return super().compute_loss(model, inputs, return_outputs, num_items_in_batch)


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model

# Collator for language modelling with masked-LM mode turned off (mlm=False).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./multitask-full",
    # Optimisation schedule
    num_train_epochs=6,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # Batching: 2 sequences x 16 accumulation steps = 32 per device per optimizer step
    per_device_train_batch_size=2,
    gradient_accumulation_steps=16,
    bf16=True,
    # Evaluate and checkpoint once per epoch
    eval_strategy="epoch",
    save_strategy="epoch",
    # Housekeeping cadence, in steps
    logging_steps=100,
    torch_empty_cache_steps=100,
)

# The tokenizer is passed as `processing_class`, the current name of the Trainer
# argument; the older `tokenizer=` keyword is deprecated. MoETrainer already
# overrides compute_loss with `num_items_in_batch`, so a Trainer release that
# accepts `processing_class` is assumed here -- TODO confirm the pinned version.
trainer = MoETrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized_dataset,
    eval_dataset=val_tokenized_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
)
trainer.train()