<a href="https://colab.research.google.com/github/gyasifred/nutrikids-ai/blob/main/non_instruction_finetuning_peft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, I will fine-tune `unsloth/Llama-3.2-1B-Instruct-bnb-4bit` (non-instruction) using LoRA, a parameter-efficient fine-tuning method, to detect `Pediatric Malnutrition` from clinical notes.

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [2]:
!pip install datasets



Begin with utility functions to process dataset


In [3]:
%%writefile utils.py

from typing import List, Dict, Optional, Any
import pandas as pd
import random
import numpy as np
import torch
import os

class MalnutritionDataset:
    """Class to handle malnutrition dataset operations."""

    def __init__(self, data_path: str, note_col: str, label_col: str):
        """Initialize dataset from a CSV file.

        Args:
            data_path (str): Path to the CSV file containing the data
            note_col (str): Name of the text column in the CSV
            label_col (str): Name of the label column in the CSV
        """
        self.df = pd.read_csv(data_path)
        self.text = note_col
        self.label = label_col

    def prepare_training_data(self) -> List[Dict[str, str]]:
        """Prepare data in the format required for training.

        Args:
            None

        Returns:
            List of dictionaries with text and labels formatted for training
        """
        formatted_data = []
        for _, row in self.df.iterrows():
            # Generate prompt for each example
            note = row[self.text]

            formatted_data.append({
                "text": note,
                "labels": "yes" if str(row[self.label]).lower() in ["1", "yes", "true"] else "no"
            })

        return formatted_data

def set_seed(seed: int = 42):
    """Set random seed for reproducibility.

    Args:
        seed (int): Random seed value
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


def is_bfloat16_supported():
    """Check if bfloat16 is supported by the current device.

    Returns:
        bool: True if bfloat16 is supported, False otherwise
    """
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

Writing utils.py


In [4]:
%%writefile train.csv
DEID,txt,label
1001,"5-year-old male presents with persistent weight loss over the past 6 months. Parents report decreased appetite and fatigue. BMI below 5th percentile. Noted muscle wasting and thin hair. Plan: Nutrition consult, CBC, and metabolic panel.",1
1002,"8-year-old female seen for routine well-child visit. No concerns reported. Weight and height within normal range for age. No history of feeding difficulties. Diet includes a variety of fruits, vegetables, and proteins.",0
1003,"3-year-old male with history of premature birth presents with failure to gain weight. Mother reports frequent diarrhea and poor appetite. Weight-for-age below expected. Plan: Nutritional supplementation and gastroenterology referral.",1
1004,"10-year-old female with no significant medical history. Growth chart shows consistent progression along the 50th percentile. No concerns for malnutrition. Active and well-nourished per caregiver report.",0
1005,"6-year-old male with recurrent hospitalizations for pneumonia. Noted to be underweight for age with visible rib prominence. Parents report limited food intake due to poor appetite. Plan: High-calorie diet and vitamin supplementation.",1
1006,"7-year-old female with no acute complaints. Normal weight and height progression. Balanced diet per family report. No signs of malnutrition on exam.",0
1007,"2-year-old male with history of chronic diarrhea and weight stagnation. Physical exam reveals thin extremities and mild edema. Laboratory findings suggest micronutrient deficiencies. Referred to pediatric nutritionist.",1
1008,"4-year-old female here for routine check-up. No concerns from parents. Growth chart remains stable at 60th percentile. No history of feeding issues.",0
1009,"9-year-old male with history of cystic fibrosis. Reports decreased appetite and weight loss over past 3 months. BMI below 3rd percentile. Plan: Increase caloric intake and assess pancreatic enzyme supplementation.",1
1010,"11-year-old female with no history of chronic illness. Reports regular physical activity and a well-balanced diet. Weight and height appropriate for age.",0
1011,"18-month-old male with failure to thrive. Parents report limited intake of solid foods, frequent irritability, and delayed developmental milestones. Exam reveals mild muscle wasting and abdominal distension.",1
1012,"5-year-old female with no concerns. Good appetite reported. Physical exam normal. Weight and height tracking along expected percentiles.",0
1013,"4-year-old male with history of congenital heart disease. Reports poor weight gain despite high-calorie formula. Thin extremities noted. Plan: Referral to cardiology and nutrition team.",1
1014,"6-year-old female with normal weight progression. No reported feeding difficulties. Parents report diverse diet and adequate intake.",0
1015,"3-year-old female with chronic malnutrition, stunted growth, and low muscle tone. History of recurrent infections and poor dietary intake. Plan: Multidisciplinary intervention.",1
1016,"12-year-old male with no history of weight or nutritional concerns. Active in sports and maintains healthy diet. Growth consistent with expected pattern.",0
1017,"5-year-old female with history of neglect, presenting with severe malnutrition. Exam shows wasting, dry skin, and brittle hair. Plan: Hospital admission for nutritional rehabilitation.",1
1018,"7-year-old male with no significant medical history. Reports normal appetite and energy levels. Growth chart stable.",0
1019,"10-month-old infant with significant weight faltering. Parents report difficulty with feeding and frequent vomiting. Noted to be below the 3rd percentile for weight. Referral to gastroenterology and dietitian made.",1
1020,"8-year-old female with no concerns regarding nutrition. Normal development, active lifestyle, and well-balanced diet reported.",0
1021,"2-year-old male with chronic undernutrition. Noted to have delayed milestones and reduced muscle mass. Parents report financial difficulties impacting food availability.",1
1022,"5-year-old female thriving with adequate caloric intake. Growth curve follows 70th percentile. No signs of malnutrition.",0
1023,"9-year-old male presents with persistent fatigue and poor growth. Weight below 5th percentile. Exam shows signs of micronutrient deficiencies. Plan: Dietary interventions and lab work-up.",1
1024,"6-year-old female with normal appetite and growth. No medical history of malnutrition or weight concerns.",0
1025,"11-month-old male with severe malnutrition and irritability. Weight-for-length below expected. Signs of rickets noted. Hospital admission recommended.",1
1026,"4-year-old female with no weight or nutrition concerns. Active and meeting developmental milestones. No abnormalities noted on exam.",0
1027,"2-year-old male with history of low birth weight. Struggles with weight gain despite formula supplementation. Plan: Further dietary evaluation.",1
1028,"7-year-old female with no medical concerns. Growth follows 50th percentile. Regular diet with no feeding issues reported.",0
1029,"8-month-old infant with significant weight loss. Parents report difficulty introducing solid foods. Noted to have thin arms and legs with reduced fat stores. Plan: Close follow-up with pediatrician and dietitian.",1
1030,"10-year-old male seen for sports physical. No concerns reported. Active and well-nourished with normal growth.",0
1031,"4-year-old female with recent hospitalization for severe malnutrition. Presented with muscle wasting and lethargy. Now on high-calorie diet and improving.",1
1032,"6-year-old male with normal appetite. Parents report balanced diet. No growth concerns.",0
1033,"18-month-old male with kwashiorkor-like symptoms, including edema and skin changes. History of protein-deficient diet. Plan: Urgent nutritional intervention.",1
1034,"5-year-old female seen for check-up. No nutritional concerns. Growth curve steady.",0
1035,"3-year-old male with history of recurrent infections and poor weight gain. Physical exam reveals prominent ribs and thinning hair. Plan: Nutritional and immunology evaluation.",1
1036,"12-year-old female reports good appetite and physical activity. Normal weight and height for age. No concerns noted.",0
1037,"9-month-old male with severe undernutrition. Weight-for-age below 1st percentile. Signs of developmental delay. Plan: Hospitalization for nutritional support.",1
1038,"7-year-old male with no history of dietary concerns. Regular meals, adequate weight gain, and normal growth.",0
1039,"5-year-old male presents with stunted growth. Parents report selective eating habits and poor weight gain. Plan: Nutritional counseling.",1
1040,"10-year-old female with healthy dietary habits. No signs of malnutrition. Growth tracking within normal range.",0
1041,"2-year-old male with chronic undernutrition and vitamin deficiencies. Parents report difficulty affording nutritious food. Plan: Social services referral.",1
1042,"6-year-old female with no reported concerns. Eating well and maintaining weight. Growth chart normal.",0
1043,"3-year-old male with visible signs of wasting. Mother reports inadequate food intake due to ongoing illness. Plan: Nutritional rehabilitation and medical work-up.",1
1044,"11-year-old female reports regular eating habits and no health concerns. No weight loss or nutritional deficiencies noted.",0
1045,"8-month-old female with severe weight loss and lethargy. Physical exam concerning for malnutrition. Admitted for nutritional support.",1
1046,"9-year-old male with no issues related to nutrition. Growing appropriately and active in sports.",0
1047,"14-month-old female with moderate acute malnutrition. Weight-for-age below 3rd percentile. Plan: Nutritional supplementation.",1
1048,"7-year-old male seen for well-child visit. No signs of undernutrition. Weight stable at 55th percentile.",0
1049,"2-year-old male with history of recurrent illness and weight faltering. Exam shows thin extremities and delayed growth. Plan: Dietary assessment and follow-up.",1
1050,"6-year-old female eating a varied diet with no signs of malnutrition. Normal growth and development noted.",0

Writing train.csv


In [5]:
%%writefile valid.csv
DEID,txt,label
1051,"3-year-old female presents with poor weight gain. Parents report frequent vomiting and refusal to eat solid foods. BMI below 3rd percentile. Plan: Dietitian referral and GI workup.",1
1052,"7-year-old male seen for annual check-up. No feeding issues reported. Growth chart stable at 60th percentile. No concerns for malnutrition.",0
1053,"5-year-old male with history of congenital heart defect. Struggling with weight gain despite fortified diet. Mild muscle wasting noted. Plan: High-calorie dietary modifications.",1
1054,"10-year-old female with no medical concerns. Growth parameters within expected range. Well-balanced diet reported by parents.",0
1055,"2-year-old male with chronic diarrhea and failure to thrive. Weight below expected for age. Exam reveals thin extremities and dry skin. Plan: Nutrition and GI consult.",1
1056,"8-year-old female with no significant weight changes. Normal dietary intake reported. No signs of nutritional deficiencies.",0
1057,"4-year-old male with severe malnutrition, presenting with fatigue and stunted growth. Parents report limited access to food. Plan: Multidisciplinary intervention.",1
1058,"6-year-old female with no medical concerns. Healthy eating habits and appropriate weight for age. No nutritional deficiencies observed.",0
1059,"9-month-old male with difficulty gaining weight. Parents report poor appetite and recurrent infections. Growth chart shows weight-for-age below 1st percentile.",1
1060,"12-year-old female with stable growth along the 50th percentile. No feeding issues or concerns reported.",0
1061,"18-month-old male with severe weight loss and lethargy. Parents report difficulty introducing solid foods. Noted to be below the 3rd percentile for weight. Plan: Hospital admission for nutritional support.",1
1062,"5-year-old female with normal weight progression. Parents report good appetite. No concerns for malnutrition.",0
1063,"6-year-old male with recurrent hospitalizations for pneumonia. Noted to be underweight for age with visible rib prominence. Parents report limited food intake due to poor appetite. Plan: High-calorie diet and vitamin supplementation.",1
1064,"8-year-old male with no history of chronic illness. Reports regular physical activity and a well-balanced diet. Weight and height appropriate for age.",0
1065,"3-year-old female with chronic malnutrition, stunted growth, and low muscle tone. History of recurrent infections and poor dietary intake. Plan: Multidisciplinary intervention.",1
1066,"12-year-old male with no history of weight or nutritional concerns. Active in sports and maintains healthy diet. Growth consistent with expected pattern.",0
1067,"5-year-old female with history of neglect, presenting with severe malnutrition. Exam shows wasting, dry skin, and brittle hair. Plan: Hospital admission for nutritional rehabilitation.",1
1068,"7-year-old male with no significant medical history. Reports normal appetite and energy levels. Growth chart stable.",0
1069,"10-month-old infant with significant weight faltering. Parents report difficulty with feeding and frequent vomiting. Noted to be below the 3rd percentile for weight. Referral to gastroenterology and dietitian made.",1
1070,"8-year-old female with no concerns regarding nutrition. Normal development, active lifestyle, and well-balanced diet reported.",0
1071,"2-year-old male with chronic undernutrition. Noted to have delayed milestones and reduced muscle mass. Parents report financial difficulties impacting food availability.",1
1072,"5-year-old female thriving with adequate caloric intake. Growth curve follows 70th percentile. No signs of malnutrition.",0
1073,"9-year-old male presents with persistent fatigue and poor growth. Weight below 5th percentile. Exam shows signs of micronutrient deficiencies. Plan: Dietary interventions and lab work-up.",1
1074,"6-year-old female with normal appetite and growth. No medical history of malnutrition or weight concerns.",0
1075,"11-month-old male with severe malnutrition and irritability. Weight-for-length below expected. Signs of rickets noted. Hospital admission recommended.",1
1076,"4-year-old female with no weight or nutrition concerns. Active and meeting developmental milestones. No abnormalities noted on exam.",0
1077,"2-year-old male with history of low birth weight. Struggles with weight gain despite formula supplementation. Plan: Further dietary evaluation.",1
1078,"7-year-old female with no medical concerns. Growth follows 50th percentile. Regular diet with no feeding issues reported.",0
1079,"8-month-old infant with significant weight loss. Parents report difficulty introducing solid foods. Noted to have thin arms and legs with reduced fat stores. Plan: Close follow-up with pediatrician and dietitian.",1
1080,"10-year-old male seen for sports physical. No concerns reported. Active and well-nourished with normal growth.",0
1081,"4-year-old female with recent hospitalization for severe malnutrition. Presented with muscle wasting and lethargy. Now on high-calorie diet and improving.",1
1082,"6-year-old male with normal appetite. Parents report balanced diet. No growth concerns.",0
1083,"18-month-old male with kwashiorkor-like symptoms, including edema and skin changes. History of protein-deficient diet. Plan: Urgent nutritional intervention.",1
1084,"5-year-old female seen for check-up. No nutritional concerns. Growth curve steady.",0
1085,"3-year-old male with history of recurrent infections and poor weight gain. Physical exam reveals prominent ribs and thinning hair. Plan: Nutritional and immunology evaluation.",1
1086,"12-year-old female reports good appetite and physical activity. Normal weight and height for age. No concerns noted.",0
1087,"9-month-old male with severe undernutrition. Weight-for-age below 1st percentile. Signs of developmental delay. Plan: Hospitalization for nutritional support.",1
1088,"7-year-old male with no history of dietary concerns. Regular meals, adequate weight gain, and normal growth.",0
1089,"5-year-old male presents with stunted growth. Parents report selective eating habits and poor weight gain. Plan: Nutritional counseling.",1
1090,"10-year-old female with healthy dietary habits. No signs of malnutrition. Growth tracking within normal range.",0
1091,"2-year-old male with chronic undernutrition and vitamin deficiencies. Parents report difficulty affording nutritious food. Plan: Social services referral.",1
1092,"6-year-old female with no reported concerns. Eating well and maintaining weight. Growth chart normal.",0
1093,"3-year-old male with visible signs of wasting. Mother reports inadequate food intake due to ongoing illness. Plan: Nutritional rehabilitation and medical work-up.",1
1094,"11-year-old female reports regular eating habits and no health concerns. No weight loss or nutritional deficiencies noted.",0
1095,"8-month-old female with severe weight loss and lethargy. Physical exam concerning for malnutrition. Admitted for nutritional support.",1
1096,"9-year-old male with no issues related to nutrition. Growing appropriately and active in sports.",0
1097,"14-month-old female with moderate acute malnutrition. Weight-for-age below 3rd percentile. Plan: Nutritional supplementation.",1
1098,"7-year-old male seen for well-child visit. No signs of undernutrition. Weight stable at 55th percentile.",0
1099,"2-year-old male with history of recurrent illness and weight faltering. Exam shows thin extremities and delayed growth. Plan: Dietary assessment and follow-up.",1
1100,"6-year-old female eating a varied diet with no signs of malnutrition. Normal growth and development noted.",0

Writing valid.csv


Check out our dataset

In [6]:
from utils import MalnutritionDataset

malnutrition_dataset = MalnutritionDataset(
    data_path="train.csv",
    note_col="txt",
    label_col="label"
)
train_data = malnutrition_dataset.prepare_training_data()

In [7]:
train_data[:2]

[{'text': '5-year-old male presents with persistent weight loss over the past 6 months. Parents report decreased appetite and fatigue. BMI below 5th percentile. Noted muscle wasting and thin hair. Plan: Nutrition consult, CBC, and metabolic panel.',
  'labels': 'yes'},
 {'text': '8-year-old female seen for routine well-child visit. No concerns reported. Weight and height within normal range for age. No history of feeding difficulties. Diet includes a variety of fruits, vegetables, and proteins.',
  'labels': 'no'}]

In [8]:
import torch
from datasets import Dataset
from transformers import BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
from unsloth import FastLanguageModel
import pandas as pd
import json
from utils import (
    MalnutritionDataset,
    is_bfloat16_supported,
    set_seed
)


Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastLanguageModel


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [9]:
def get_quantization_config(config):
    """Define quantization configuration for the model based on arguments.

    Args:
        args: Command line arguments

    Returns:
        BitsAndBytesConfig: Quantization configuration
    """
    # Determine if we should use 8-bit or 4-bit quantization (but not both)
    if config['load_in_8bit']:
        return BitsAndBytesConfig(
            load_in_8bit=True,
            load_in_4bit=False,
            llm_int8_enable_fp32_cpu_offload=True
        )
    elif config['load_in_4bit']:
        # Determine compute dtype based on available hardware and args
        if config['force_bf16'] and is_bfloat16_supported():
            compute_dtype = torch.bfloat16
        else:
            compute_dtype = torch.float16

        return BitsAndBytesConfig(
            load_in_4bit=True,
            load_in_8bit=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            llm_int8_enable_fp32_cpu_offload=True,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=True
        )
    else:
        # No quantization
        return None


def determine_model_precision(config):
    """Determine appropriate precision settings for model training.

    Args:
        args: Command line arguments

    Returns:
        tuple: (fp16, bf16) boolean flags
    """
    if config['force_fp16']:
        return True, False

    if config['force_bf16']:
        if is_bfloat16_supported():
            return False, True
        else:
            print("Warning: BF16 requested but not supported by hardware. Falling back to FP16.")
            return True, False

    # Auto-detect best precision
    if is_bfloat16_supported():
        return False, True
    else:
        return True, False


def load_model_and_tokenizer(config, quantization_config):
    """Load base model and tokenizer with appropriate settings."""
    print(f"Loading base model and tokenizer: {config['model_name']}")

    # Determine precision based on hardware and user preferences
    fp16, bf16 = determine_model_precision(config)
    dtype = torch.bfloat16 if bf16 else torch.float16

    try:
        # Ensure we're not using both 4-bit and 8-bit
        load_in_4bit = config['load_in_4bit'] and not config['load_in_8bit']
        load_in_8bit = config['load_in_8bit'] and not config['load_in_4bit']

        print(f"Loading model with settings: precision={'bf16' if bf16 else 'fp16'}, "
              f"load_in_4bit={load_in_4bit}, load_in_8bit={load_in_8bit}")

        # Set attention implementation based on flash attention flag
        attn_implementation = "flash_attention_2" if config['use_flash_attention'] else "eager"

        # Create kwargs for model loading
        model_kwargs = {
            "model_name": config['model_name'],
            "dtype": dtype,
            "device_map": "auto",
            "attn_implementation": attn_implementation,
        }

        # If quantization_config is provided, use it
        if quantization_config is not None:
            model_kwargs["quantization_config"] = quantization_config
        else:
            # Otherwise use the direct parameters
            model_kwargs["load_in_4bit"] = load_in_4bit
            model_kwargs["load_in_8bit"] = load_in_8bit

        # Load the model with the appropriate parameters
        base_model, tokenizer = FastLanguageModel.from_pretrained(**model_kwargs)

        print("Model and tokenizer loaded successfully.")
        return base_model, tokenizer, fp16, bf16
    except Exception as e:
        print(f"Error loading model: {e}")
        raise


def get_target_modules(args, model_name):
    """Determine appropriate target modules for LoRA based on model architecture.

    Args:
        args: Command line arguments
        model_name: Name of the model

    Returns:
        list: List of target module names
    """
    # If user specified target modules, use those
    if args['target_modules']:
        return args['target_modules'].split(',')

    # Default target modules based on model architecture
    model_name_lower = model_name.lower()

    if any(name in model_name_lower for name in ["llama", "mistral", "mixtral"]):
        return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    elif "phi" in model_name_lower:
        return ["q_proj", "k_proj", "v_proj", "o_proj"]
    elif "qwen" in model_name_lower:
        return ["q_proj", "k_proj", "v_proj", "o_proj", "w1", "w2"]
    elif "deepseek" in model_name_lower:
        return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    # Default fallback
    return ["q_proj", "k_proj", "v_proj", "o_proj"]

In [12]:
def create_peft_model(base_model, config):
    """Create PEFT/LoRA model for fine-tuning with appropriate settings.

    Args:
        base_model: The base language model
        args: Command line arguments

    Returns:
        model: The PEFT model ready for training
    """
    print("Creating PEFT/LoRA model...")

    # Get appropriate target modules for this model architecture
    target_modules = get_target_modules(config, config['model_name'])
    print(f"Using target modules: {target_modules}")

    model = FastLanguageModel.get_peft_model(
        model=base_model,
        r=config['lora_r'],
        lora_alpha=config['lora_alpha'],
        lora_dropout=0,
        target_modules=target_modules,
        use_gradient_checkpointing=True,
        random_state=config['seed'],
        use_rslora=True,
        loftq_config=None
    )

    # Enable gradient checkpointing for efficient training
    model.gradient_checkpointing_enable()
    if hasattr(model, 'enable_input_require_grads'):
        model.enable_input_require_grads()

    return model


def get_sft_config(config, fp16, bf16):
    """Configure SFT training arguments.

    Args:
        args: Command line arguments
        fp16: Whether to use FP16 precision
        bf16: Whether to use BF16 precision

    Returns:
        SFTConfig: Configuration for SFT training
    """
    config_kwargs = {
        "per_device_train_batch_size": config['batch_size'],
        "gradient_accumulation_steps": config['gradient_accumulation'],
        "warmup_steps": 5,
        # "max_steps": args.max_steps,
        "learning_rate": config['learning_rate'],
        "fp16": fp16,
        "bf16": bf16,
        "logging_steps": 1,
        "optim": "adamw_8bit",
        "weight_decay": 0.01,
        "lr_scheduler_type": "linear",
        "seed": config['seed'],
        "output_dir": config['output_dir'],
        "report_to": config['report_to'],
        "save_strategy": "steps",
        "save_steps": 10,
        "max_seq_length": config['max_seq_length'],
        "dataset_num_proc": 4,
        "packing": False,
        "num_train_epochs": config['epochs']
    }

    # Add evaluation parameters only if validation data is provided
    if config['val_data'] is not None:
        config_kwargs.update({
            "eval_strategy": "steps",
            "eval_steps": 10,
            "load_best_model_at_end": True,
            "metric_for_best_model": "eval_loss",
        })

    print(f"Training with precision: fp16={fp16}, bf16={bf16}")
    return SFTConfig(**config_kwargs)


def prepare_datasets(train_data_path, val_data_path, tokenizer, note_col, label_col, max_seq_length):
    """Prepare training and validation datasets.

    Args:
        train_data_path (str): Path to training data CSV
        val_data_path (str): Path to validation data CSV
        prompt_builder: MalnutritionPromptBuilder instance
        tokenizer: Tokenizer for the model
        note_col (str): Name of the text column
        label_col (str): Name of the label column
        max_seq_length (int): Maximum sequence length for tokenization

    Returns:
        Tuple: (train_dataset, eval_dataset)
    """
    print("Preparing datasets...")

    # Load and prepare training data
    train_data = MalnutritionDataset(train_data_path, note_col, label_col)
    train_formatted = train_data.prepare_training_data()

    # Pre-tokenize the data to ensure consistent format
    def tokenize_function(examples):
        # Make sure 'text' field exists and is not empty
        if not examples.get('text'):
            return {"input_ids": [], "attention_mask": []}

        # Tokenize the examples
        tokenized = tokenizer(
            examples['text'],
            truncation=True,
            padding="max_length",
            max_length=max_seq_length,
            return_tensors=None,
        )
        # Add labels for supervised fine-tuning
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    # Convert to Dataset and tokenize
    train_dataset = Dataset.from_pandas(pd.DataFrame(train_formatted))
    train_tokenized = train_dataset.map(
        tokenize_function,
        batched=False,
        remove_columns=["text"] if "text" in train_dataset.column_names else [],
    )

    # Handle validation data if provided
    eval_tokenized = None
    if val_data_path is not None:
        val_data = MalnutritionDataset(val_data_path, note_col, label_col)
        val_formatted = val_data.prepare_training_data()
        eval_dataset = Dataset.from_pandas(pd.DataFrame(val_formatted))
        eval_tokenized = eval_dataset.map(
            tokenize_function,
            batched=False,
            remove_columns=["text"] if "text" in eval_dataset.column_names else [],
        )
        print(f"Prepared {len(train_tokenized)} training examples and {len(eval_tokenized)} validation examples")
    else:
        print(f"Prepared {len(train_tokenized)} training examples")

    return train_tokenized, eval_tokenized




config = {
  "model_name": "unsloth/Llama-3.2-1B-bnb-4bit",
  "train_data": "/content/train.csv",
  "val_data":"/content/valid.csv",
  "text_column": "txt",
  "label_column": "label",
  # Output arguments
  "output_dir": "./llm",
  "model_output": "./llm_models",

  # Training arguments
  "batch_size": 2,
  "gradient_accumulation": 4,
  "learning_rate": 2e-4,
  "max_seq_length": 2048,
  "epochs": 10,

  # LoRA parameters
  "lora_r": 8,
  "lora_alpha": 32,
  "target_modules": None,

  # Precision arguments
  "force_fp16": False,
  "force_bf16": False,

  # Miscellaneous
  "seed": 42,
  "use_flash_attention": False,
  "report_to": "none",
  "load_in_8bit": False,
  "load_in_4bit": True
}

# Ensure mutual exclusivity of precision settings
if config["load_in_8bit"]:
    config["load_in_4bit"] = False

# Set seed for reproducibility
set_seed(config['seed'])

# Create output directories
os.makedirs(config['output_dir'], exist_ok=True)
os.makedirs(config['model_output'], exist_ok=True)

# Save configuration
with open(os.path.join(config['output_dir'], "config.json"), "w") as f:
    json.dump(config, f, indent=4)

# Get quantization config
quantization_config = get_quantization_config(config)

# Load model and tokenizer with precision detection
base_model, tokenizer, fp16, bf16 = load_model_and_tokenizer(
    config, quantization_config
)

# Load and prepare datasets with tokenization
train_dataset, eval_dataset = prepare_datasets(
    config['train_data'],
    config['val_data'],
    tokenizer,
    config['text_column'],
    config['label_column'],
    config['max_seq_length']
)

# Create PEFT/LoRA model
model = create_peft_model(base_model, config)

# Get SFT config with correct precision settings
sft_config = get_sft_config(config, fp16, bf16)

# Initialize SFT trainer
trainer_kwargs = {
    "model": model,
    "processing_class": tokenizer,
    "train_dataset": train_dataset,
    "args": sft_config,
}

if eval_dataset is not None:
    trainer_kwargs["eval_dataset"] = eval_dataset

trainer = SFTTrainer(**trainer_kwargs)

import os
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

# Train the model
print(f"Starting training with {len(train_dataset)} examples for {config['epochs']} epoch(s)...")
trainer.train()

# Save the final model
print(f"Training completed. Saving model to {config['model_output']}")
trainer.save_model(config['model_output'])

print("Fine-tuning complete!")

Loading base model and tokenizer: unsloth/Llama-3.2-1B-bnb-4bit
Loading model with settings: precision=fp16, load_in_4bit=True, load_in_8bit=False
==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.50.0.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Model and tokenizer loaded successfully.
Preparing datasets...


Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Prepared 50 training examples and 50 validation examples
Creating PEFT/LoRA model...
Using target modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
Training with precision: fp16=True, bf16=False


Truncating train dataset (num_proc=4):   0%|          | 0/50 [00:00<?, ? examples/s]

Truncating eval dataset (num_proc=4):   0%|          | 0/50 [00:00<?, ? examples/s]

Starting training with 50 examples for 10 epoch(s)...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 50 | Num Epochs = 10 | Total steps = 60
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 5,636,096/1,000,000,000 (0.56% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,Validation Loss


Unsloth: Not an error, but LlamaForCausalLM does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient


OutOfMemoryError: CUDA out of memory. Tried to allocate 7.88 GiB. GPU 0 has a total capacity of 14.74 GiB of which 7.74 GiB is free. Process 11128 has 7.00 GiB memory in use. Of the allocated memory 6.82 GiB is allocated by PyTorch, and 26.93 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)