<a href="https://colab.research.google.com/github/easonwangzk/MedLLM/blob/main/Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 安装基础包
!pip install torch transformers accelerate bitsandbytes datasets
!pip install wandb  # 用于实验追踪
!pip install tqdm  # 进度条
!pip install peft  # 参数高效微调

# 验证安装
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0)}")

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from 

In [2]:
# Create the project directory and make it the working directory.
# Guarded so that re-running the cell does not descend into MedLLM/MedLLM.
import os

if os.path.basename(os.getcwd()) != "MedLLM":
    os.makedirs("MedLLM", exist_ok=True)
    os.chdir("MedLLM")
print(os.getcwd())

# Create the directory layout that the %%writefile cells below write into.
for subdir in ("src/training", "src/data", "configs", "outputs"):
    os.makedirs(subdir, exist_ok=True)

/content/MedLLM


In [3]:
# 创建model_config.py
%%writefile configs/model_config.py
MODEL_CONFIG = {
    "base_model": "aaditya/Llama3-OpenBioLLM-70B",
    "teacher_model": "aaditya/Llama3-OpenBioLLM-70B",
    "student_model": "aaditya/Llama3-OpenBioLLM-70B",
    "max_length": 256,
    "temperature": 0.7,
    "top_p": 0.9,
    "num_beams": 1,
    "do_sample": True,
    "repetition_penalty": 1.1,
}

TRAINING_CONFIG = {
    "output_dir": "./outputs",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "learning_rate": 5e-5,
    "weight_decay": 0.01,
    "warmup_ratio": 0.05,
    "logging_steps": 20,
    "save_steps": 200,
    "eval_steps": 200,
    "fp16": True,
}

DISTILL_CONFIG = {
    "temperature": 2.0,
    "alpha": 0.5,
    "beta": 0.5,
    "distill_layers": [0, 4, 8, 12, 16, 20, 24, 28, 32, 36],
    "attention_distill": True,
    "hidden_distill": True,
}

RLAIF_CONFIG = {
    "reward_model": "aaditya/Llama3-OpenBioLLM-70B",
    "num_rollouts": 1,
    "rollout_batch_size": 2,
    "reward_threshold": 0.7,
    "kl_coef": 0.1,
    "clip_range": 0.2,
    "entropy_coef": 0.01,
}

Writing configs/model_config.py


In [6]:
from pathlib import Path

from datasets import load_dataset, load_from_disk

# Load the PubMedQA `pqa_labeled` subset and cache it on disk. On re-runs the
# local copy is reused instead of downloading and re-saving the dataset.
PUBMEDQA_DIR = Path("data/pubmed_qa")
if PUBMEDQA_DIR.exists():
    dataset = load_from_disk(str(PUBMEDQA_DIR))
else:
    dataset = load_dataset("pubmed_qa", "pqa_labeled")
    dataset.save_to_disk(str(PUBMEDQA_DIR))

dataset

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/5.19k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [7]:
%%writefile src/data/data_processor.py
from transformers import PreTrainedTokenizer
from typing import Dict, List
import torch
from datasets import Dataset

class PubMedQADataset:
    """PubMedQA examples formatted for causal-LM fine-tuning.

    Each item is one sequence ``[bos] + prompt + answer + [eos]``, right-padded
    to ``max_length``. ``labels`` are aligned with ``input_ids`` (Hugging Face
    causal LMs shift them internally). Prompt and padding positions are set to
    -100, so the loss only covers the answer tokens.
    """

    IGNORE_INDEX = -100

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        split: str = "train",
        max_length: int = 256,
        data_dir: str = "data/pubmed_qa",
    ):
        """
        Args:
            tokenizer: tokenizer used to encode the prompt and the answer.
            split: split of the saved DatasetDict to read (the pqa_labeled copy
                saved by this notebook only contains "train").
            max_length: fixed length of every returned tensor.
            data_dir: directory written by ``DatasetDict.save_to_disk``.
        """
        # The directory holds a DatasetDict (dataset_dict.json), which
        # Dataset.load_from_disk refuses to load; the generic loader handles it.
        from datasets import load_from_disk

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset = load_from_disk(data_dir)[split]

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _format_prompt(item) -> str:
        """Build the prompt text from one example.

        In the Hugging Face pqa_labeled schema, ``context`` is a dict whose
        "contexts" entry is a list of passages; a plain string is used as-is.
        """
        context = item["context"]
        if isinstance(context, dict):
            context = " ".join(context.get("contexts", []))
        return f"Context: {context}\n\nQuestion: {item['question']}\n\nAnswer:"

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels``, each a 1-D
        LongTensor of length ``max_length``."""
        item = self.dataset[idx]
        tok = self.tokenizer

        prompt_ids = tok(self._format_prompt(item), add_special_tokens=False)["input_ids"]
        answer_ids = tok(" " + (item["long_answer"] or ""), add_special_tokens=False)["input_ids"]
        if tok.eos_token_id is not None:
            answer_ids = answer_ids + [tok.eos_token_id]
        bos = [tok.bos_token_id] if tok.bos_token_id is not None else []

        # Cap the answer at half of the window, then give the prompt whatever is
        # left. Overlong prompts lose their *start* (context), so the question
        # and the "Answer:" cue always survive truncation.
        answer_ids = answer_ids[: self.max_length // 2]
        prompt_budget = max(self.max_length - len(bos) - len(answer_ids), 0)
        prompt_ids = prompt_ids[max(len(prompt_ids) - prompt_budget, 0):]

        input_ids = bos + prompt_ids + answer_ids
        labels = [self.IGNORE_INDEX] * (len(bos) + len(prompt_ids)) + answer_ids
        attention_mask = [1] * len(input_ids)

        # Pad manually: the tokenizer may define no pad token (padding would
        # raise), so fall back to EOS, then 0. Padding never contributes to the
        # loss because its labels are IGNORE_INDEX.
        pad_len = self.max_length - len(input_ids)
        pad_id = tok.pad_token_id
        if pad_id is None:
            pad_id = tok.eos_token_id if tok.eos_token_id is not None else 0
        input_ids = input_ids + [pad_id] * pad_len
        attention_mask = attention_mask + [0] * pad_len
        labels = labels + [self.IGNORE_INDEX] * pad_len

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

Writing src/data/data_processor.py


In [8]:
%%writefile src/train.py
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup
)
from src.data.data_processor import PubMedQADataset
from src.training.rlaif_trainer import RLAIFTrainer
from src.training.distillation_trainer import DistillationTrainer
from configs.model_config import MODEL_CONFIG, TRAINING_CONFIG
import logging
import os
from tqdm import tqdm
import time
import gc

def setup_logging():
    """Configure timestamped INFO-level root logging and return this module's logger."""
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
    module_logger = logging.getLogger(__name__)
    return module_logger

def _load_causal_lm(model_name, quantization_config):
    """Load ``model_name`` in fp16 with ``quantization_config``, placed by accelerate.

    ``use_cache=False`` because the KV cache is disabled anyway when training
    with gradient checkpointing.
    """
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        quantization_config=quantization_config,
        device_map="auto",
        use_cache=False,
    )


def main(distill_time_limit: float = 3600.0, total_time_limit: float = 7200.0):
    """Run distillation and then RLAIF fine-tuning of the student on PubMedQA.

    Args:
        distill_time_limit: seconds since start after which distillation stops
            and RLAIF begins (default: 1 hour).
        total_time_limit: seconds since start after which training stops
            (default: 2 hours). The final model is still saved in that case.
    """
    # Passing load_in_8bit straight to from_pretrained is deprecated in favour
    # of an explicit BitsAndBytesConfig.
    from transformers import BitsAndBytesConfig

    logger = setup_logging()
    start_time = time.time()

    # Free GPU memory left over from previous runs
    torch.cuda.empty_cache()
    gc.collect()

    # Initialize tokenizer and models
    logger.info("Initializing tokenizer and models...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG["base_model"])
    if tokenizer.pad_token is None:
        # Some tokenizers (e.g. Llama 3) define no pad token, and padding to
        # max_length fails without one.
        tokenizer.pad_token = tokenizer.eos_token

    quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    logger.info("Loading teacher model...")
    teacher_model = _load_causal_lm(MODEL_CONFIG["teacher_model"], quantization_config)
    # The teacher is only run under torch.no_grad(), so gradient checkpointing
    # would be pure overhead; freeze its weights explicitly instead.
    teacher_model.requires_grad_(False)

    logger.info("Loading student model...")
    student_model = _load_causal_lm(MODEL_CONFIG["student_model"], quantization_config)
    student_model.gradient_checkpointing_enable()

    # Create datasets
    logger.info("Creating datasets...")
    train_dataset = PubMedQADataset(
        tokenizer, "train", max_length=MODEL_CONFIG["max_length"]
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=TRAINING_CONFIG["per_device_train_batch_size"],
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )

    # Initialize trainers
    logger.info("Initializing trainers...")
    distillation_trainer = DistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        tokenizer=tokenizer
    )

    rlaif_trainer = RLAIFTrainer(
        model=student_model,
        tokenizer=tokenizer,
        reward_model=teacher_model
    )

    # Training loop
    logger.info("Starting training...")
    save_steps = TRAINING_CONFIG["save_steps"]
    out_of_time = False
    for epoch in range(TRAINING_CONFIG["num_train_epochs"]):
        logger.info(f"Epoch {epoch + 1}/{TRAINING_CONFIG['num_train_epochs']}")

        # 1. Distillation training
        logger.info("Starting distillation training...")
        pbar = tqdm(train_loader, desc="Distillation")

        for batch in pbar:
            if time.time() - start_time > distill_time_limit:
                logger.info("Time limit reached for distillation. Moving to RLAIF...")
                break

            metrics = distillation_trainer.train_step(batch)
            pbar.set_postfix({
                "total_loss": metrics["total_loss"],
                "distill_loss": metrics["distill_loss"],
                "task_loss": metrics["task_loss"]
            })

        # 2. RLAIF training
        logger.info("Starting RLAIF training...")
        pbar = tqdm(train_loader, desc="RLAIF")

        # Count steps explicitly: tqdm's pbar.n is still 0 inside the first
        # iteration (which would save a checkpoint at step 0) and is only
        # refreshed periodically afterwards.
        for step, batch in enumerate(pbar, start=1):
            if time.time() - start_time > total_time_limit:
                logger.info("Time limit reached. Stopping training...")
                # Break instead of returning so the final model is still saved.
                out_of_time = True
                break

            metrics = rlaif_trainer.train_step(batch)
            pbar.set_postfix({
                "policy_loss": metrics["policy_loss"],
                "entropy": metrics["entropy"]
            })

            if step % save_steps == 0:
                save_path = os.path.join(
                    TRAINING_CONFIG["output_dir"],
                    f"checkpoint-{epoch}-{step}"
                )
                student_model.save_pretrained(save_path)
                logger.info(f"Saved checkpoint to {save_path}")

                torch.cuda.empty_cache()
                gc.collect()

        if out_of_time:
            break

    # Save final model (and tokenizer, so the directory is loadable on its own)
    final_save_path = os.path.join(TRAINING_CONFIG["output_dir"], "final_model")
    student_model.save_pretrained(final_save_path)
    tokenizer.save_pretrained(final_save_path)
    logger.info(f"Saved final model to {final_save_path}")

if __name__ == "__main__":
    main()

Writing src/train.py


In [12]:
# 创建蒸馏训练器
%%writefile src/training/distillation_trainer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import Dict, List, Optional
from configs.model_config import DISTILL_CONFIG

class DistillationTrainer:
    def __init__(
        self,
        teacher_model: PreTrainedModel,
        student_model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: str = "cuda"
    ):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.tokenizer = tokenizer
        self.device = device

        # Move models to device
        self.teacher_model.to(device)
        self.student_model.to(device)

        # Set teacher model to eval mode
        self.teacher_model.eval()

        # Initialize optimizer
        self.optimizer = torch.optim.AdamW(
            self.student_model.parameters(),
            lr=5e-5
        )

    def compute_distillation_loss(
        self,
        teacher_outputs: Dict,
        student_outputs: Dict
    ) -> torch.Tensor:
        loss = 0.0

        # 1. Logits distillation
        teacher_logits = teacher_outputs.logits / DISTILL_CONFIG["temperature"]
        student_logits = student_outputs.logits / DISTILL_CONFIG["temperature"]

        teacher_probs = F.softmax(teacher_logits, dim=-1)
        student_probs = F.softmax(student_logits, dim=-1)

        logits_loss = F.kl_div(
            F.log_softmax(student_logits, dim=-1),
            teacher_probs,
            reduction="batchmean"
        ) * (DISTILL_CONFIG["temperature"] ** 2)

        loss += DISTILL_CONFIG["alpha"] * logits_loss

        # 2. Attention distillation
        if DISTILL_CONFIG["attention_distill"]:
            for layer_idx in DISTILL_CONFIG["distill_layers"]:
                teacher_attention = teacher_outputs.attentions[layer_idx]
                student_attention = student_outputs.attentions[layer_idx]

                attention_loss = F.mse_loss(
                    student_attention,
                    teacher_attention
                )
                loss += DISTILL_CONFIG["alpha"] * attention_loss

        # 3. Hidden states distillation
        if DISTILL_CONFIG["hidden_distill"]:
            for layer_idx in DISTILL_CONFIG["distill_layers"]:
                teacher_hidden = teacher_outputs.hidden_states[layer_idx]
                student_hidden = student_outputs.hidden_states[layer_idx]

                hidden_loss = F.mse_loss(
                    student_hidden,
                    teacher_hidden
                )
                loss += DISTILL_CONFIG["alpha"] * hidden_loss

        return loss

    def compute_task_loss(
        self,
        student_outputs: Dict,
        labels: torch.Tensor
    ) -> torch.Tensor:
        return F.cross_entropy(
            student_outputs.logits.view(-1, student_outputs.logits.size(-1)),
            labels.view(-1)
        )

    def train_step(
        self,
        batch: Dict[str, torch.Tensor]
    ) -> Dict[str, float]:
        self.student_model.train()

        # Move batch to device
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)

        # Get teacher outputs
        with torch.no_grad():
            teacher_outputs = self.teacher_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                output_attentions=True
            )

        # Get student outputs
        student_outputs = self.student_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            output_attentions=True
        )

        # Compute losses
        distill_loss = self.compute_distillation_loss(
            teacher_outputs,
            student_outputs
        )
        task_loss = self.compute_task_loss(student_outputs, labels)

        # Total loss
        total_loss = (
            DISTILL_CONFIG["alpha"] * distill_loss +
            DISTILL_CONFIG["beta"] * task_loss
        )

        # Update student model
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.student_model.parameters(),
            1.0
        )
        self.optimizer.step()

        return {
            "total_loss": total_loss.item(),
            "distill_loss": distill_loss.item(),
            "task_loss": task_loss.item()
        }

Writing src/training/distillation_trainer.py


In [13]:
%%writefile src/training/rlaif_trainer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import Dict, List, Optional
from configs.model_config import RLAIF_CONFIG

class RLAIFTrainer:
    """REINFORCE-style RLAIF fine-tuning against a frozen AI reward/reference model.

    Each ``train_step`` does three things:
      1. Samples ``RLAIF_CONFIG["num_rollouts"]`` responses per prompt from ``model``.
      2. Scores each response with ``reward_model`` as exp(mean token
         log-likelihood), a value in (0, 1]; empty responses score 0.
      3. Raises the log-likelihood of responses scoring above
         ``RLAIF_CONFIG["reward_threshold"]`` and lowers it for the rest. The
         update is regularised by a KL term towards ``reward_model`` and an
         entropy bonus.

    NOTE(review): ``clip_range`` and ``rollout_batch_size`` from RLAIF_CONFIG
    are not used (there is no PPO ratio clipping).
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        reward_model: PreTrainedModel,
        device: str = "cuda",
        learning_rate: float = 5e-5
    ):
        """
        Args:
            model: policy being trained.
            tokenizer: tokenizer for both models; its pad/eos ids drive padding.
            reward_model: frozen scorer, also used as the KL reference.
            device: device that batch tensors are moved to.
            learning_rate: AdamW learning rate for the policy.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.reward_model = reward_model
        self.device = device

        # Models loaded with a device_map or bitsandbytes quantization are
        # already placed, and transformers rejects `.to()` on 8-bit models.
        self._maybe_move(self.model, device)
        self._maybe_move(self.reward_model, device)

        # Set reward model to eval mode
        self.reward_model.eval()

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate
        )

    @staticmethod
    def _maybe_move(model, device: str) -> None:
        """Call ``model.to(device)`` unless the model was already placed at load time."""
        already_placed = (
            getattr(model, "hf_device_map", None) is not None
            or getattr(model, "is_loaded_in_8bit", False)
            or getattr(model, "is_loaded_in_4bit", False)
        )
        if not already_placed:
            model.to(device)

    def _pad_token_id(self) -> int:
        """Id used for padding: the pad token, else EOS, else 0."""
        for candidate in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id):
            if candidate is not None:
                return candidate
        return 0

    @staticmethod
    def _left_pad(input_ids, attention_mask, pad_id):
        """Re-pad prompts on the left; decoder-only generation needs left padding."""
        left_ids = torch.full_like(input_ids, pad_id)
        left_mask = torch.zeros_like(attention_mask)
        width = input_ids.size(1)
        for row, keep in enumerate(attention_mask.bool()):
            tokens = input_ids[row][keep]
            count = tokens.numel()
            if count:
                left_ids[row, width - count:] = tokens
                left_mask[row, width - count:] = 1
        return left_ids, left_mask

    @staticmethod
    def _sequence_masks(prompt_mask, generated_ids, pad_id):
        """Build masks for generate() output (prompt followed by response).

        Returns:
            full_mask: attention mask over prompt + response, same shape as
                ``generated_ids``.
            loss_mask: 1 where a next-token prediction targets a response token,
                in the shifted (N, T-1) coordinates of ``_log_probs``.
        """
        prompt_len = prompt_mask.size(1)
        rows = generated_ids.size(0)
        if rows != prompt_mask.size(0):
            # generate() repeats each prompt num_return_sequences times in a row.
            prompt_mask = prompt_mask.repeat_interleave(rows // prompt_mask.size(0), dim=0)
        response_mask = (generated_ids[:, prompt_len:] != pad_id).to(prompt_mask.dtype)
        full_mask = torch.cat([prompt_mask, response_mask], dim=1)
        loss_mask = torch.zeros_like(full_mask[:, 1:])
        loss_mask[:, prompt_len - 1:] = response_mask
        return full_mask, loss_mask

    @staticmethod
    def _log_probs(model, sequences, attention_mask):
        """Float32 log-probabilities of each next token, shape (N, T-1, vocab)."""
        logits = model(input_ids=sequences, attention_mask=attention_mask).logits
        return F.log_softmax(logits[:, :-1, :].to(sequences.device).float(), dim=-1)

    @staticmethod
    def _masked_average(values, mask):
        """Mean of ``values`` over positions where ``mask`` is 1 (whole tensor)."""
        weights = mask.to(values.dtype)
        return (values * weights).sum() / weights.sum().clamp(min=1.0)

    @staticmethod
    def _token_log_probs(log_probs, generated_ids):
        """Log-probability of each realised next token, shape (N, T-1)."""
        return log_probs.gather(-1, generated_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

    def _reward_from_ref(self, ref_log_probs, generated_ids, loss_mask):
        """exp(mean response-token log-likelihood) per sample; 0 for empty responses."""
        token_lp = self._token_log_probs(ref_log_probs, generated_ids)
        weights = loss_mask.to(token_lp.dtype)
        counts = weights.sum(dim=-1)
        mean_lp = (token_lp * weights).sum(dim=-1) / counts.clamp(min=1.0)
        reward = torch.exp(mean_lp)
        return torch.where(counts > 0, reward, torch.zeros_like(reward))

    def compute_reward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        generated_ids: torch.Tensor
    ) -> torch.Tensor:
        """Score sampled responses with the frozen reward model.

        Args:
            input_ids: the prompt ids passed to generate(); only its width
                (prompt length) matters here.
            attention_mask: prompt attention mask, same shape as ``input_ids``.
            generated_ids: generate() output (prompt + response).

        Returns:
            One reward per generated row, in (0, 1], or 0 for an empty response.
        """
        full_mask, loss_mask = self._sequence_masks(
            attention_mask, generated_ids, self._pad_token_id()
        )
        with torch.no_grad():
            ref_log_probs = self._log_probs(self.reward_model, generated_ids, full_mask)
        return self._reward_from_ref(ref_log_probs, generated_ids, loss_mask)

    def compute_kl_penalty(
        self,
        logits: torch.Tensor,
        ref_logits: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """kl_coef * mean per-position KL(ref || policy).

        Accepts logits or log-probabilities (log_softmax is idempotent). When
        ``mask`` is given, only positions where it is 1 are averaged.
        """
        log_policy = F.log_softmax(logits.float(), dim=-1)
        log_ref = F.log_softmax(ref_logits.to(log_policy.device).float(), dim=-1)
        per_position = F.kl_div(
            log_policy, log_ref, reduction="none", log_target=True
        ).sum(dim=-1)
        if mask is None:
            kl_div = per_position.mean()
        else:
            kl_div = self._masked_average(per_position, mask)
        return RLAIF_CONFIG["kl_coef"] * kl_div

    def _entropy(self, logits, mask=None):
        """Mean per-position entropy of softmax(logits), optionally masked."""
        log_probs = F.log_softmax(logits.float(), dim=-1)
        per_position = -(log_probs.exp() * log_probs).sum(dim=-1)
        if mask is None:
            return per_position.mean()
        return self._masked_average(per_position, mask)

    def compute_entropy_penalty(
        self,
        logits: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Entropy bonus expressed as a loss term: -entropy_coef * entropy."""
        return -RLAIF_CONFIG["entropy_coef"] * self._entropy(logits, mask)

    def train_step(
        self,
        batch: Dict[str, torch.Tensor]
    ) -> Dict[str, float]:
        """Sample, score and update the policy once.

        Args:
            batch: dict with "input_ids" and "attention_mask" prompt tensors
                ("labels" is not needed for RL and is ignored).

        Returns:
            dict with float "policy_loss", "kl_penalty", "entropy" (the mean
            response entropy), "entropy_penalty" and "reward" (mean reward).
        """
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        pad_id = self._pad_token_id()

        # 1. Sample rollouts: left-padded prompts, eval mode, no autograd graph.
        # max_new_tokens (not max_length) is used because prompts already fill
        # the configured max_length.
        prompt_ids, prompt_mask = self._left_pad(input_ids, attention_mask, pad_id)
        self.model.eval()
        with torch.no_grad():
            generated = self.model.generate(
                input_ids=prompt_ids,
                attention_mask=prompt_mask,
                max_new_tokens=RLAIF_CONFIG.get("max_new_tokens", 64),
                num_return_sequences=RLAIF_CONFIG["num_rollouts"],
                do_sample=True,
                temperature=RLAIF_CONFIG.get("temperature", 0.7),
                pad_token_id=pad_id,
                use_cache=True
            )
        full_mask, loss_mask = self._sequence_masks(prompt_mask, generated, pad_id)

        # 2. One frozen forward pass yields both the reward and the KL reference.
        with torch.no_grad():
            ref_log_probs = self._log_probs(self.reward_model, generated, full_mask)
            rewards = self._reward_from_ref(ref_log_probs, generated, loss_mask)

        # 3. Policy gradient on the sampled responses.
        self.model.train()
        policy_log_probs = self._log_probs(self.model, generated, full_mask)
        token_lp = self._token_log_probs(policy_log_probs, generated)
        weights = loss_mask.to(token_lp.dtype)
        seq_log_prob = (token_lp * weights).sum(dim=-1) / weights.sum(dim=-1).clamp(min=1.0)

        # The reward threshold acts as a baseline: above-threshold responses are
        # reinforced, below-threshold ones discouraged.
        advantages = (rewards - RLAIF_CONFIG["reward_threshold"]).detach()
        policy_loss = -(advantages * seq_log_prob).mean()

        kl_penalty = self.compute_kl_penalty(policy_log_probs, ref_log_probs, loss_mask)
        entropy = self._entropy(policy_log_probs, loss_mask)
        entropy_penalty = -RLAIF_CONFIG["entropy_coef"] * entropy

        # Total loss
        total_loss = policy_loss + kl_penalty + entropy_penalty

        # Update model
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            1.0
        )
        self.optimizer.step()

        return {
            "policy_loss": policy_loss.item(),
            "kl_penalty": kl_penalty.item(),
            "entropy": entropy.item(),
            "entropy_penalty": entropy_penalty.item(),
            "reward": rewards.mean().item()
        }

Writing src/training/rlaif_trainer.py


In [14]:
# 创建必要的__init__.py文件
!touch src/__init__.py
!touch src/training/__init__.py
!touch src/data/__init__.py

In [15]:
# Make the project root importable from *this notebook kernel* and smoke-test
# the modules written above. This only affects the kernel; `!python ...`
# launches a separate process with its own sys.path.
import os
import sys

PROJECT_ROOT = os.getcwd()
if PROJECT_ROOT not in sys.path:
    # Prepend so the local `src`/`configs` packages take precedence, and guard
    # so re-running the cell does not add duplicate entries.
    sys.path.insert(0, PROJECT_ROOT)

from src.data.data_processor import PubMedQADataset
from src.training.rlaif_trainer import RLAIFTrainer
from src.training.distillation_trainer import DistillationTrainer

In [22]:
# Inspect the project layout: recursive listing of the working directory.
!ls -R

.:
data  outputs  src

./data:
pubmed_qa

./data/pubmed_qa:
dataset_dict.json  train

./data/pubmed_qa/train:
data-00000-of-00001.arrow  dataset_info.json  state.json

./outputs:

./src:
configs  data  __init__.py  __pycache__  training  train.py

./src/configs:
model_config.py  __pycache__

./src/configs/__pycache__:
model_config.cpython-311.pyc

./src/data:
data_processor.py  __init__.py	__pycache__

./src/data/__pycache__:
data_processor.cpython-311.pyc	__init__.cpython-311.pyc

./src/__pycache__:
__init__.cpython-311.pyc

./src/training:
distillation_trainer.py  __init__.py  __pycache__  rlaif_trainer.py

./src/training/__pycache__:
distillation_trainer.cpython-311.pyc  rlaif_trainer.cpython-311.pyc
__init__.cpython-311.pyc


In [17]:
# Check the GPU model and its current memory usage before loading models.
!nvidia-smi

Sun Apr 20 03:01:08 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   32C    P0             44W /  400W |       5MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# 运行训练脚本
!python src/train.py

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
pytorch_model-00006-of-00030.bin:  29% 1.37G/4.66G [01:00<02:23, 22.9MB/s][A[A





pytorch_model-00003-of-00030.bin:  29% 1.45G/5.00G [01:00<02:26, 24.2MB/s][A[A[A[A[A[A



pytorch_model-00008-of-00030.bin:  29% 1.43G/5.00G [01:00<02:30, 23.8MB/s][A[A[A[A







pytorch_model-00001-of-00030.bin:  30% 1.37G/4.58G [00:59<02:13, 24.1MB/s][A[A[A[A[A[A[A[A


pytorch_model-00005-of-00030.bin:  31% 1.43G/4.66G [01:00<02:16, 23.8MB/s][A[A[A




pytorch_model-00004-of-00030.bin:  27% 1.35G/4.97G [01:00<02:36, 23.1MB/s][A[A[A[A[A
pytorch_model-00007-of-00030.bin:  30% 1.42G/4.66G [01:00<02:17, 23.5MB/s][A






pytorch_model-00002-of-00030.bin:  30% 1.39G/4.66G [01:00<02:16, 24.0MB/s][A[A[A[A[A[A[A





pytorch_model-00003-of-00030.bin:  29% 1.46G/5.00G [01:00<02:26, 24.3MB/s][A[A[A[A[A[A

pytorch_model-00006-of-00030.bin:  30% 1.38G/4.66G [01:00<02:23, 22.8MB/s][A[A



pytorch_model-00008-of-00030.bin:  29% 1.