In [None]:
# Environment Setup for Lightning.ai A100
!pip install -q "torch==2.6.0+cu124" torchvision torchaudio \
--extra-index-url https://download.pytorch.org/whl/cu124
!pip install -q "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install -q lightning lightning-ai wandb huggingface_hub protobuf==3.20.3

In [None]:
import datetime
import os
import random
import subprocess
import sys

# Unsloth patches transformers/peft on import, so import it before the other ML libraries.
from unsloth import FastLanguageModel

import lightning as L
import numpy as np
import torch
import wandb
from datasets import load_dataset
from huggingface_hub import HfApi
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies import FSDPStrategy
from torch.utils.data import DataLoader, random_split


def set_seed(seed=3407):
    """Seed every RNG this notebook relies on.

    Seeds PyTorch (CPU), NumPy's global RNG and Python's `random` module, plus all
    CUDA devices when CUDA is available.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = [torch.manual_seed, np.random.seed, random.seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


set_seed()

print("PyTorch: {} | CUDA: {}".format(torch.__version__, torch.version.cuda))

In [None]:
class MentalHealthModel(L.LightningModule):
    """LightningModule that fine-tunes a 4-bit Unsloth Granite model with LoRA adapters.

    Args:
        model_name: model id passed to ``FastLanguageModel.from_pretrained``.
        max_seq_length: maximum sequence length passed to Unsloth.
        lr: AdamW learning rate for the trainable parameters.
        weight_decay: AdamW weight decay.

    Attributes:
        model: the PEFT-wrapped model returned by ``FastLanguageModel.get_peft_model``.
        tokenizer: the tokenizer loaded together with the model.
        lr, weight_decay: optimizer settings used by ``configure_optimizers``.
    """

    def __init__(
        self,
        model_name="unsloth/granite-3.2-2b-instruct",
        max_seq_length=4096,
        lr=1e-5,
        weight_decay=0.0,
    ):
        super().__init__()
        self.lr = lr
        self.weight_decay = weight_decay
        # NOTE(review): device_map="auto" places the model itself. On a multi-GPU machine
        # this could spread layers across devices, which conflicts with Lightning's own
        # device placement. That is fine for a single-GPU trainer.
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=max_seq_length,
            load_in_4bit=True,
            device_map="auto",
        )
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=32,
            lora_alpha=16,
            target_modules=["q_proj", "v_proj"],
            use_gradient_checkpointing="unsloth",
            random_state=3407,
        )

    def forward(self, inputs, targets=None, attention_mask=None):
        """Run the wrapped model on token ids.

        Args:
            inputs: token id tensor passed as ``input_ids``.
            targets: optional label tensor passed as ``labels``. When it is given,
                the returned output carries a ``loss``.
            attention_mask: optional attention mask.

        Returns:
            The wrapped model's output object.
        """
        # Keyword arguments are required here: the second positional parameter of a
        # Hugging Face causal-LM forward is `attention_mask`, not `labels`.
        return self.model(input_ids=inputs, attention_mask=attention_mask, labels=targets)

    def _compute_loss(self, batch):
        """Forward a batch mapping as keyword arguments and return the scalar loss."""
        outputs = self.model(**batch)
        return outputs.loss if hasattr(outputs, "loss") else outputs[0]

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Build AdamW over the parameters that require gradients.

        After ``get_peft_model`` these are presumably only the LoRA adapters.
        """
        # Lightning needs this hook to fit. Without it no optimizer exists, and the
        # adapters can never be updated.
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=self.lr, weight_decay=self.weight_decay)

In [None]:
class MentalHealthDataModule(L.LightningDataModule):
    """Load prompt/response pairs from JSONL and build causal-LM training examples.

    Each record needs ``prompt`` and ``response`` text fields. Records are processed
    in four steps:

    1. Keep only records meeting the minimum character lengths.
    2. Optionally prefix each prompt with a system prompt.
    3. Tokenize as ``prompt + response + EOS``, with prompt and padding positions
       masked to -100 in ``labels`` so only the response is learned.
    4. Split roughly 85/10/5 into train/val/test with a fixed seed.

    Args:
        data_path: path to the JSONL dataset.
        batch_size: batch size for all dataloaders.
        min_prompt_len: minimum prompt length in characters.
        min_response_len: minimum response length in characters.
        system_prompt: optional text prepended verbatim to every prompt.
        tokenizer: tokenizer to use. If None, one is loaded in ``setup``.
        max_length: total token length of prompt + response after truncation and padding.
        num_workers: DataLoader worker processes.
        model_name: model id used only when ``tokenizer`` is None.
    """

    def __init__(
        self,
        data_path="merged_mental_health_dataset.jsonl",
        batch_size=8,
        min_prompt_len=10,
        min_response_len=20,
        system_prompt=None,
        tokenizer=None,
        max_length=1024,
        num_workers=4,
        model_name="unsloth/granite-3.2-2b-instruct",
    ):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.min_prompt_len = min_prompt_len
        self.min_response_len = min_response_len
        self.system_prompt = system_prompt or ""
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers
        self.model_name = model_name
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Warm the Hugging Face datasets cache for the JSONL file.

        Per the Lightning docs this hook runs on only one process per node, so it must
        not assign state. The tokenizer is resolved in ``setup``, which runs everywhere.
        """
        load_dataset("json", data_files=self.data_path, split="train")

    def _ensure_tokenizer(self):
        """Return the tokenizer, loading it through Unsloth if none was supplied."""
        if self.tokenizer is None:
            # Fallback only: this loads the full 4-bit model just to keep its tokenizer.
            # Pass `tokenizer=` (e.g. the LightningModule's tokenizer) to avoid it.
            _, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.model_name,
                max_seq_length=4096,
                load_in_4bit=True,
                device_map="auto",
            )
        return self.tokenizer

    def _is_valid(self, example):
        """True when both fields are present and meet the minimum character lengths."""
        prompt, response = example.get("prompt"), example.get("response")
        return bool(
            prompt
            and response
            and len(prompt) >= self.min_prompt_len
            and len(response) >= self.min_response_len
        )

    def _add_system_prompt(self, example):
        """Prefix the prompt with the configured system prompt."""
        return {"prompt": self.system_prompt + example["prompt"]}

    def _tokenize(self, example):
        """Build fixed-length ``input_ids``/``attention_mask``/``labels`` for one record.

        The model sees ``prompt + response + EOS``. ``labels`` is the same sequence with
        prompt and padding positions set to -100, the index Hugging Face causal-LM
        losses ignore. The model shifts labels internally, so they must align
        position-for-position with ``input_ids``.
        """
        tok = self.tokenizer
        prompt_ids = tok(example["prompt"], add_special_tokens=True)["input_ids"]
        response_ids = tok(example["response"], add_special_tokens=False)["input_ids"]
        if tok.eos_token_id is not None:
            # Teach the model where a reply ends.
            response_ids = response_ids + [tok.eos_token_id]

        input_ids = (prompt_ids + response_ids)[: self.max_length]
        labels = ([-100] * len(prompt_ids) + response_ids)[: self.max_length]

        pad_id = tok.pad_token_id
        if pad_id is None:
            pad_id = tok.eos_token_id if tok.eos_token_id is not None else 0
        pad_len = self.max_length - len(input_ids)
        return {
            "input_ids": input_ids + [pad_id] * pad_len,
            "attention_mask": [1] * len(input_ids) + [0] * pad_len,
            "labels": labels + [-100] * pad_len,
        }

    def setup(self, stage=None):
        """Filter, tokenize and split the dataset.

        This runs once. Later stages (e.g. ``test`` after ``fit``) reuse the result.
        """
        if self.train_dataset is not None:
            return
        self._ensure_tokenizer()

        dataset = load_dataset("json", data_files=self.data_path, split="train")
        dataset = dataset.filter(self._is_valid)
        if self.system_prompt:
            dataset = dataset.map(self._add_system_prompt)
        # Drop the raw text columns so each batch holds only tensor model inputs. The
        # training step unpacks the batch as keyword arguments to the model.
        dataset = dataset.map(self._tokenize, remove_columns=dataset.column_names)
        # Records whose response was truncated away entirely carry no supervised tokens.
        dataset = dataset.filter(lambda ex: any(t != -100 for t in ex["labels"]))
        # Return tensors so the default collate stacks them into (batch, max_length).
        dataset.set_format("torch")

        n = len(dataset)
        train_size = int(0.85 * n)
        val_size = int(0.10 * n)
        test_size = n - train_size - val_size
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            dataset,
            [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(3407),
        )

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader with this module's batch size and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        # Reshuffle every epoch: random_split fixes which records are in each split,
        # but not the order in which they are visited.
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)

In [None]:
# Single-GPU training; FSDP is deliberately not used:
#  * FSDPStrategy's `mixed_precision` expects a torch `MixedPrecision` config, and
#    `activation_checkpointing` expects module types, so `True` is invalid for both.
#  * Sharding across a single device (`devices=1`) gains nothing. Unsloth's patched
#    4-bit model is also not designed for FSDP sharding.
# bf16 autocast comes from `precision="bf16-mixed"`. Gradient checkpointing is
# requested on the model via `use_gradient_checkpointing="unsloth"`.
strategy = "auto"

# Best-k checkpoints ranked by val_loss. These are saved only once per epoch, after
# validation, so each saved checkpoint is ranked by the metric for its own weights.
checkpoint_callback = ModelCheckpoint(
    dirpath="./checkpoints",
    filename="granite-mentalhealth-{epoch}-{step}",
    save_top_k=3,
    monitor="val_loss",
    mode="min",
    save_last=True,
)
# Crash-recovery snapshot every 50 training steps. It has no monitor (with
# save_top_k=1 only the latest is kept) because val_loss is refreshed only when
# validation runs, and it would be stale at arbitrary steps.
periodic_checkpoint_callback = ModelCheckpoint(
    dirpath="./checkpoints",
    filename="granite-mentalhealth-periodic-{epoch}-{step}",
    every_n_train_steps=50,
    save_top_k=1,
)
early_stop_callback = EarlyStopping(monitor="val_loss", patience=3, mode="min")

wandb_logger = WandbLogger(project="kimble", log_model=True, resume="allow")

trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    strategy=strategy,
    max_epochs=4,
    precision="bf16-mixed",
    callbacks=[checkpoint_callback, periodic_checkpoint_callback, early_stop_callback],
    enable_progress_bar=True,
    enable_model_summary=True,
    logger=wandb_logger,
)

In [None]:
def _run(cmd):
    """Echo and run a command given as an argument list.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero. Unlike a `!cmd`
            shell escape, a failed step therefore stops the pipeline.
    """
    print("$", " ".join(cmd))
    subprocess.run(cmd, check=True)


def _export_gguf(
    model,
    merged_dir,
    f16_path,
    quant_path,
    quant_type="Q4_K_M",
    llama_cpp_dir="llama.cpp",
):
    """Merge the LoRA weights into the base model and export a quantized GGUF file.

    Args:
        model: the trained ``MentalHealthModel`` (uses ``.model`` and ``.tokenizer``).
        merged_dir: directory that receives the merged 16-bit Hugging Face model.
        f16_path: intermediate f16 GGUF output path.
        quant_path: final quantized GGUF output path.
        quant_type: llama.cpp quantization type name.
        llama_cpp_dir: where llama.cpp is cloned and built (reused if present).
    """
    # llama.cpp's converter needs full model weights; an adapter-only folder is not enough.
    model.model.save_pretrained_merged(merged_dir, model.tokenizer, save_method="merged_16bit")

    if not os.path.isdir(llama_cpp_dir):
        _run(["git", "clone", "--depth", "1", "https://github.com/ggerganov/llama.cpp", llama_cpp_dir])

    build_dir = os.path.join(llama_cpp_dir, "build")
    quantize_bin = os.path.join(build_dir, "bin", "llama-quantize")
    if not os.path.exists(quantize_bin):
        # The quantizer is a C++ tool and must be built. LLAMA_CURL=OFF avoids needing libcurl headers.
        _run(["cmake", "-S", llama_cpp_dir, "-B", build_dir, "-DLLAMA_CURL=OFF"])
        _run(["cmake", "--build", build_dir, "--config", "Release", "--target", "llama-quantize", "-j"])

    _run([
        sys.executable,
        os.path.join(llama_cpp_dir, "convert_hf_to_gguf.py"),
        merged_dir,
        "--outtype", "f16",
        "--outfile", f16_path,
    ])
    _run([quantize_bin, f16_path, quant_path, quant_type])


def train(
    hf_repo_id="oneblackmage/granite-mentalhealth",
    gguf_path="granite-3.2-mentalhealth-Q4_K_M.gguf",
):
    """Fine-tune, evaluate, export and publish the mental-health Granite model.

    Steps:
        1. Fit and test with the notebook-level ``trainer``.
        2. Save the adapter and tokenizer to ``./final_model``.
        3. Export a quantized GGUF via llama.cpp.
        4. Log both to W&B.
        5. Upload ``./final_model`` to the Hugging Face Hub.

    The W&B run is always finished, even if a step fails.

    Args:
        hf_repo_id: Hugging Face Hub model repo to upload to (created if missing).
        gguf_path: output path of the quantized GGUF file.
    """
    run_name = f"mentalhealth-{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}"
    system_prompt = "You are a helpful mental health assistant. "
    dm = MentalHealthDataModule(system_prompt=system_prompt)
    model = MentalHealthModel()
    # Share the model's tokenizer so data and model use the same vocabulary.
    dm.tokenizer = model.tokenizer

    # Log the settings actually in use rather than hand-copied values.
    wandb.init(
        project="kimble",
        name=run_name,
        config={
            "model": "granite-3.2-2b-instruct",
            "strategy": type(trainer.strategy).__name__,
            "precision": str(trainer.precision),
            "batch_size": dm.batch_size,
            "lr": getattr(model, "lr", 1e-5),
        },
        resume="allow",
    )
    try:
        trainer.fit(model, dm)
        trainer.test(model, datamodule=dm)

        # `model.model` is the PEFT model, so this saves the adapter plus its config.
        final_dir = "./final_model"
        model.model.save_pretrained(final_dir)
        model.tokenizer.save_pretrained(final_dir)
        wandb.save(f"{final_dir}/*")

        _export_gguf(
            model,
            merged_dir="./final_model_merged",
            f16_path="granite-3.2-mentalhealth-f16.gguf",
            quant_path=gguf_path,
        )
        wandb.log_artifact(gguf_path, type="quantized_model")

        api = HfApi()
        # upload_folder fails if the repo does not exist yet.
        api.create_repo(repo_id=hf_repo_id, repo_type="model", exist_ok=True)
        api.upload_folder(
            folder_path=final_dir,
            repo_id=hf_repo_id,
            repo_type="model",
            commit_message="Upload trained model and tokenizer",
        )
    finally:
        wandb.finish()


if __name__ == "__main__":
    train()