<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/gpt_oss_20btomistral_distill_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
# Configure the PyTorch CUDA allocator through its environment variable.
# This assignment is placed before the `import torch` that follows.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"

import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from huggingface_hub import login, HfApi
from tqdm import tqdm
import gc
from warnings import filterwarnings
from accelerate import Accelerator
from accelerate.utils import gather_object

# Suppress all warnings for the remainder of the run.
filterwarnings(action='ignore')

# --- 1. CONFIGURATION ---
# Teacher (generates the synthetic data) and student (gets fine-tuned).
TEACHER_MODEL_NAME = "openai/gpt-oss-20b"
STUDENT_MODEL_NAME = "mistralai/Mistral-7B-v0.1"

# Where artifacts are written locally and published on the Hub.
OUTPUT_DIR = "./mistral-7b-gpt-oss-20b-distilled"
HF_REPO_ID = "frankmorales2020/mistral-7b-gpt-oss-20b-distilled"
DATASET_SAVE_PATH = "./distillation_dataset_temp"

# LoRA adapter settings.
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = [
    # attention projections
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    # MLP projections
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Optimisation settings.
LEARNING_RATE = 2e-4
BATCH_SIZE = 2
NUM_EPOCHS = 3
GRADIENT_ACCUMULATION_STEPS = 1

# --- 2. HUGGING FACE LOGIN ---
# Prefer the HF_TOKEN environment variable; otherwise prompt interactively.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # getpass does not echo the secret, so the token does not end up in the
    # terminal scrollback or in saved notebook output the way input() would.
    from getpass import getpass

    HF_TOKEN = getpass("Enter your Hugging Face access token: ").strip()
login(token=HF_TOKEN)

# The Accelerator is used below for rank checks and barriers.
accelerator = Accelerator()
if accelerator.is_main_process:
    print("Successfully logged in to Hugging Face Hub.")
    # Quick config check
    print("Accelerate config loaded; using", accelerator.num_processes, "GPUs.")

# Clear GPU memory before loading. Collect garbage first so that any memory
# it frees can actually be returned by empty_cache().
gc.collect()
torch.cuda.empty_cache()

# --- 3. LOAD TEACHER MODEL ---
# Loads the teacher model and its tokenizer, then reports per-GPU memory use.
# Any failure is re-raised as RuntimeError on every process.
try:
    if accelerator.is_main_process:
        print(f"Loading {TEACHER_MODEL_NAME} with device_map='auto' across {accelerator.num_processes} GPUs...")

    # Warn when Triton is missing or older than 3.4.0.
    triton_warning = None
    try:
        import triton
    except ImportError:
        triton_warning = "Warning: Triton not installed. MXFP4 may dequantize to bf16, increasing memory usage."
    else:
        import re

        # Compare only the leading numeric components so that versions such as
        # "3.4.0+git1234" or "3.4.0.post1" do not raise ValueError in int().
        version_match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", getattr(triton, "__version__", ""))
        if version_match is not None:
            triton_version = tuple(int(part or 0) for part in version_match.groups())
            if triton_version < (3, 4, 0):
                triton_warning = "Warning: Triton < 3.4.0 detected. MXFP4 may dequantize to bf16, increasing memory usage."
    if triton_warning is not None and accelerator.is_main_process:
        print(triton_warning)

    # NOTE(review): max_memory is keyed by accelerator.num_processes while the
    # report below iterates torch.cuda.device_count(); confirm that the process
    # count is the intended number of GPUs to shard across.
    teacher_model = AutoModelForCausalLM.from_pretrained(
        TEACHER_MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        max_memory={i: "75GB" for i in range(accelerator.num_processes)},
        low_cpu_mem_usage=True,
    )

    teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_NAME)
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if teacher_tokenizer.pad_token is None:
        teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

    if accelerator.is_main_process:
        print("Teacher loaded. VRAM per GPU (GB):")
        for i in range(torch.cuda.device_count()):
            print(f"  GPU {i}: Allocated {torch.cuda.memory_allocated(i)/1e9:.1f}GB / Reserved {torch.cuda.memory_reserved(i)/1e9:.1f}GB")

except Exception as e:
    error_msg = f"Load failed: {e}\n\nFALLBACK: Use TEACHER_MODEL_NAME='openai/gpt-oss-20b' and run 'python {{__file__}}' (single GPU)."
    # Raise on every rank: a rank that swallowed the failure would continue
    # without teacher_model and block at the barrier below.
    raise RuntimeError(error_msg) from e

accelerator.wait_for_everyone()

# --- 4. DATASET CREATION ---
# The main process generates teacher responses for every prompt and saves the
# dataset to DATASET_SAVE_PATH; all processes then load it from disk.
# No collective call (e.g. gather_object) is made inside the main-process-only
# branch, because the other ranks never enter it and the call would block.
distillation_dataset = None
if accelerator.is_main_process:
    print("Generating dataset...")
    base_prompts = [
        "Explain the concept of quantum computing in one sentence.",
        "Write a Python function to compute Fibonacci numbers efficiently.",
        "Describe the process of photosynthesis.",
        "What are the key principles of object-oriented programming?",
    ]
    num_samples = 20
    # Repeat the base prompts until num_samples prompts are available.
    prompts = (base_prompts * (num_samples // len(base_prompts) + 1))[:num_samples]

    generated_samples = []
    teacher_model.eval()
    gen_batch_size = 1
    with torch.no_grad():
        for i in tqdm(range(0, len(prompts), gen_batch_size), desc="Generating"):
            batch_prompts = prompts[i:i + gen_batch_size]
            batch_conversations = [[{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": p}] for p in batch_prompts]

            inputs = teacher_tokenizer.apply_chat_template(
                batch_conversations, return_tensors="pt", padding=True, truncation=True, max_length=256, add_generation_prompt=True
            )

            # A bare tensor holds only input_ids; wrap it with an all-ones
            # attention mask. Mapping-style results (including ones that are
            # not dict subclasses) already carry their own fields.
            if torch.is_tensor(inputs):
                inputs = {'input_ids': inputs, 'attention_mask': torch.ones_like(inputs)}

            inputs = {k: v.to(teacher_model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}

            try:
                outputs = teacher_model.generate(
                    **inputs,
                    max_new_tokens=128,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=teacher_tokenizer.eos_token_id
                )
                for j, output in enumerate(outputs):
                    # Drop the prompt tokens so only the generated part is decoded.
                    prompt_len = inputs['input_ids'][j].shape[0]
                    response = teacher_tokenizer.decode(output[prompt_len:], skip_special_tokens=True).strip()
                    generated_samples.append({"text": f"### Instruction:\n{batch_prompts[j]}\n\n### Response:\n{response}\n"})
            except Exception as e:
                # Best effort: report the failed batch and keep generating.
                print(f"Gen error at {i}: {e}")
                continue

    distillation_dataset = Dataset.from_list(generated_samples)
    print(f"Dataset created ({len(distillation_dataset)} samples).")
    distillation_dataset.save_to_disk(DATASET_SAVE_PATH)

# Barrier: the saved dataset must exist before any process reads it.
accelerator.wait_for_everyone()

# Load the dataset from the saved file on all processes
distillation_dataset = Dataset.load_from_disk(DATASET_SAVE_PATH)
accelerator.wait_for_everyone()

# Cleanup teacher on all processes
if accelerator.is_main_process:
    print("Cleaning teacher...")
del teacher_model
# Run the garbage collector before emptying the CUDA cache: memory still held
# by uncollected objects cannot be released by empty_cache().
gc.collect()
torch.cuda.empty_cache()

accelerator.wait_for_everyone()

# --- 5. FINE-TUNING ---
if accelerator.is_main_process:
    print(f"Loading student model: {STUDENT_MODEL_NAME}...")

# 4-bit "nf4" quantization with double quantization and float16 compute.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the quantized student model.
student_load_kwargs = {
    "quantization_config": quantization_config,
    "torch_dtype": torch.float16,
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(STUDENT_MODEL_NAME, **student_load_kwargs)

# Config overrides applied before the k-bit training preparation.
for config_attr, config_value in (("use_cache", False), ("pretraining_tp", 1)):
    setattr(model.config, config_attr, config_value)
model = prepare_model_for_kbit_training(model)

# Student tokenizer: EOS doubles as the padding token, padding on the right.
tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# LoRA adapter configuration built from the constants in the config section.
lora_settings = {
    "r": LORA_R,
    "lora_alpha": LORA_ALPHA,
    "lora_dropout": LORA_DROPOUT,
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "target_modules": LORA_TARGET_MODULES,
}
lora_config = LoraConfig(**lora_settings)

model = get_peft_model(model, lora_config)

# Use the loaded dataset, now identical on all processes.
distillation_dataset_for_trainer = distillation_dataset

# Trainer hyperparameters, collected in one mapping and expanded below.
trainer_settings = {
    "output_dir": OUTPUT_DIR,
    "num_train_epochs": NUM_EPOCHS,
    "learning_rate": LEARNING_RATE,
    "warmup_steps": 5,
    "per_device_train_batch_size": BATCH_SIZE,
    "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
    "optim": "paged_adamw_8bit",
    "fp16": True,
    "bf16": False,
    "logging_steps": 1,
    "report_to": "none",
    "remove_unused_columns": False,
}
training_args = TrainingArguments(**trainer_settings)

if accelerator.is_main_process:
    print("Setting up the SFTTrainer...")

# Build the trainer, run fine-tuning, then (main process only) save the model
# locally and publish the adapter, tokenizer, and model card to the Hub.
# Any failure is re-raised as RuntimeError on every process.
try:
    # NOTE(review): `model` was already wrapped by get_peft_model earlier in
    # this script; confirm the installed TRL version accepts peft_config
    # together with an already-wrapped PEFT model.
    trainer = SFTTrainer(
        model=model,
        train_dataset=distillation_dataset_for_trainer,
        peft_config=lora_config,
        args=training_args,
        tokenizer=tokenizer,
        max_seq_length=1024,
        dataset_text_field="text",
        packing=True,
    )

    # The trainer does its own Accelerate preparation inside train(), so the
    # model is not passed through accelerator.prepare() here. Pre-wrapping it
    # would wrap it twice in multi-process runs and leave trainer.model as a
    # wrapper without push_to_hub().

    # --- 6. EXECUTE TRAINING ---
    if accelerator.is_main_process:
        print(f"Executing fine-tuning for {NUM_EPOCHS} epochs...")
    trainer.train()

    # --- 7. SAVE AND UPLOAD MODEL ---
    if accelerator.is_main_process:
        print(f"Saving the fine-tuned model locally to {OUTPUT_DIR}...")
        trainer.save_model(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)
        print(f"Uploading model to Hugging Face Hub: {HF_REPO_ID}...")
        # `token` replaces the deprecated `use_auth_token` keyword.
        trainer.model.push_to_hub(HF_REPO_ID, token=HF_TOKEN)
        tokenizer.push_to_hub(HF_REPO_ID, token=HF_TOKEN)
        api = HfApi()
        # The sample-count field is named num_samples so it does not shadow
        # the builtin `len`; the rendered card is unchanged.
        model_card_content = """
# Mistral-7B-GPT-OSS-20B-Distilled

This model is a fine-tuned version of {STUDENT_MODEL_NAME}, distilled from {TEACHER_MODEL_NAME} using LoRA and SFTTrainer.

## Training Details
- **Teacher Model**: {TEACHER_MODEL_NAME}
- **Dataset**: Synthetic dataset of {num_samples} samples generated by the teacher.
- **LoRA Config**: r={LORA_R}, alpha={LORA_ALPHA}, dropout={LORA_DROPOUT}
- **Training Hyperparams**: {NUM_EPOCHS} epochs, learning rate={LEARNING_RATE}, batch size={BATCH_SIZE}

## Usage
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

model = AutoModelForCausalLM.from_pretrained("{STUDENT_MODEL_NAME}", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("{STUDENT_MODEL_NAME}")
model = PeftModel.from_pretrained(model, "{HF_REPO_ID}")
```
""".format(
            STUDENT_MODEL_NAME=STUDENT_MODEL_NAME,
            TEACHER_MODEL_NAME=TEACHER_MODEL_NAME,
            LORA_R=LORA_R,
            LORA_ALPHA=LORA_ALPHA,
            LORA_DROPOUT=LORA_DROPOUT,
            NUM_EPOCHS=NUM_EPOCHS,
            LEARNING_RATE=LEARNING_RATE,
            BATCH_SIZE=BATCH_SIZE,
            HF_REPO_ID=HF_REPO_ID,
            num_samples=len(distillation_dataset),
        )
        # Upload the card as README.md in the model repository.
        api.upload_file(
            path_or_fileobj=model_card_content.encode(),
            path_in_repo="README.md",
            repo_id=HF_REPO_ID,
            repo_type="model",
            token=HF_TOKEN
        )
        print(f"Model and model card successfully uploaded to {HF_REPO_ID}!")

except Exception as e:
    # Raise on every rank: a rank that swallowed the failure would go on to
    # the final barrier and wait for ranks that have already exited.
    raise RuntimeError(f"Fine-tuning failed: {e}") from e

# Barrier so every process has finished training/upload before cleanup.
accelerator.wait_for_everyone()

# --- 8. CLEANUP ---
# Collect garbage first so that any memory it frees can be returned by
# empty_cache().
gc.collect()
torch.cuda.empty_cache()

if accelerator.is_main_process:
    print("Distillation, fine-tuning, and upload complete!")