<a href="https://colab.research.google.com/github/gyasifred/NLP-Techniques/blob/main/ai_malnutrition_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install xformers
# -y skips the interactive "Proceed (Y/n)?" confirmation, which otherwise stalls the cell
!pip uninstall -y torchvision
!pip install torchvision


Collecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-on1fwl0t/unsloth_8fec937d9abc402bb43a422509e3d91c
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-on1fwl0t/unsloth_8fec937d9abc402bb43a422509e3d91c
  Resolved https://github.com/unslothai/unsloth.git to commit 646ad2f141a3a0721d1ec9449cf9454b5612a84a
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Found existing installation: torchvision 0.21.0
Uninstalling torchvision-0.21.0:
  Would remove:
    /usr/local/lib/python3.11/dist-packages/torchvision-0.21.0.dist-info/*
    /usr/local/lib/python3.11/dist-packages/torchvision.libs/libcudart.41118559.so.12
    /usr/local/lib/python3.11/dist-packages/tor

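The prompt built in the next cell embeds a single-data-point z-score table (mild −1 to −1.9, moderate −2 to −2.9, severe −3 or worse). When checking labels, or as a rule-based baseline to compare the fine-tuned model against, it can help to express those bands directly. The sketch below is a minimal illustration under stated assumptions: `classify_zscore` and `rule_based_label` are hypothetical helpers, "−3 or greater" is read as a magnitude (z ≤ −3), values falling between bands (e.g. −2.95) go to the less severe band, and only weight-for-height or BMI-for-age z scores are handled. MUAC in millimetres and length/height-for-age are not covered.

```python
from typing import Optional


def classify_zscore(z: float) -> Optional[str]:
    """Map a weight-for-height or BMI-for-age z score to a severity band.

    Assumes the table's "-3 or greater" means a magnitude of 3 or more (z <= -3).
    Returns None when z is above -1, i.e. no malnutrition by this single indicator.
    """
    if z <= -3:
        return "severe"
    if z <= -2:
        return "moderate"
    if z <= -1:
        return "mild"
    return None


def rule_based_label(z: float) -> str:
    """Hypothetical single-indicator baseline in the prompt's answer format."""
    return "malnutrition=yes" if classify_zscore(z) else "malnutrition=no"


# The interactive example later in this notebook reports a weight-for-height z score of -0.8
print(classify_zscore(-0.8), rule_based_label(-0.8))
print(classify_zscore(-2.4), rule_based_label(-2.4))
```

A rule like this ignores growth velocity and clinical context, which the prompt's prose criteria also rely on, so it is only useful as a sanity check rather than a replacement for the model.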
In [None]:
# Import unsloth first so it can patch transformers/trl before they are used
from unsloth import FastLanguageModel, is_bfloat16_supported

import torch
import pandas as pd
from datasets import Dataset
from transformers import BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer

#####################################
# Define quantization configuration
#####################################

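quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",               # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,          # also quantize the quantization constants
    llm_int8_enable_fp32_cpu_offload=True,   # keep any CPU-offloaded modules in fp32
    llm_int8_threshold=6.0,                  # only used by 8-bit (LLM.int8()) loading
    llm_int8_has_fp16_weight=True            # only used by 8-bit loading
)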

####################################
# Load Base Model
####################################

print("Loading base model and tokenizer...")
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Phi-3-mini-4k-instruct",
    load_in_4bit=True,
    dtype=torch.float16,
    quantization_config=quantization_config,
    device_map="auto",
    use_flash_attention_2=False,
    use_cache=False
)
#####################################
# Create PEFT/LoRA Model
#####################################

print("Creating PEFT/LoRA model...")
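model = FastLanguageModel.get_peft_model(
    model=base_model,
    r=8,                  # LoRA rank
    lora_alpha=32,        # scaling numerator; with rsLoRA the scale is alpha / sqrt(r)
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections only
    use_gradient_checkpointing=True,
    random_state=3407,
    use_rslora=True,      # rank-stabilized LoRA scaling
    loftq_config=None
)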

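# Gradient checkpointing is already requested via use_gradient_checkpointing above; these calls
# enable it explicitly and make the input embeddings require grad, so checkpointed activations
# can backpropagate into the LoRA layers while the base weights stay frozen.
model.gradient_checkpointing_enable()
if hasattr(model, 'enable_input_require_grads'):
    model.enable_input_require_grads()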



#######################################
# Define the unified prompt constructor
#######################################

def construct_prompt(patient_notes: str) -> str:
    """
    Construct the full prompt using the provided patient notes.
    Both training and inference will use this prompt format.
    """
    return f"""Read the patient's notes and determine if the patient is likely to have malnutrition:

Criteria list.

Weight is primarily affected during periods of acute undernutrition, whereas chronic undernutrition typically manifests as stunting. Severe acute undernutrition, experienced by children 6–60 months of age, is defined as a very low weight-for-height (less than −3 standard deviations [SD] [z scores] of the median WHO growth standards), by visible severe wasting (mid–upper arm circumference [MUAC] ≤115 mm), or by the presence of nutritional edema.

Chronic undernutrition or stunting is defined by WHO as having a height-for-age (or length-for-age) that is less than −2 SD (z score) of the median of the WHO international reference.

Growth is the primary outcome measure of nutritional status in children. Growth should be monitored at regular intervals throughout childhood and adolescence and should also be measured every time a child presents, in any healthcare setting, for preventive, acute, or chronic care. In children less than 36 months of age, measures of growth include length-for-age, weight-for-age, head circumference-for-age, and weight-for-length. In children ages 2–20 years, standing height-for-age, weight-for-age, and body mass index (BMI)-for-age are typically collected.

Mild malnutrition related to undernutrition is usually the result of an acute event, either due to economic circumstances or acute illness, and presents with unintentional weight loss or weight gain velocity less than expected. Moderate malnutrition related to undernutrition occurs due to undernutrition of a significant duration that results in weight-for-length/height values or BMI-for-age values that are below the normal range. Severe malnutrition related to undernutrition occurs as a result of prolonged undernutrition and is most frequently quantified by declines in rates of linear growth that result in stunting.

On initial presentation, a child may have only a single data point for use as a criterion for the identification and diagnosis of malnutrition related to undernutrition. When this is the case, use the z scores for weight-for-height/length, BMI-for-age, length/height-for-age, or MUAC criteria as stated in the table below:

Table.

Mild Malnutrition
Weight-for-height: −1 to −1.9 z score
BMI-for-age: −1 to −1.9 z score
Length/height-for-age: No Data
Mid–upper arm circumference: Greater than or equal to −1 to −1.9 z score

Moderate Malnutrition
Weight-for-height: −2 to −2.9 z score
BMI-for-age: −2 to −2.9 z score
Length/height-for-age: No Data
Mid–upper arm circumference: Greater than or equal to −2 to −2.9 z score

Severe Malnutrition
Weight-for-height: −3 or greater z score
BMI-for-age: −3 or greater z score
Length/height-for-age: −3 z score
Mid–upper arm circumference: Greater than or equal to −3 z score

Follow this format:

1) First, briefly explain your decision.
2) Then give your final answer in exactly this format: malnutrition=yes or malnutrition=no

{patient_notes}"""

#######################################
# Preparation for Training
#######################################

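class MalnutritionDataset:
    def __init__(self, data_path: str):
        """Initialize dataset from a CSV file with 'text' (patient notes) and 'label' columns."""
        self.df = pd.read_csv(data_path)

    def prepare_training_data(self) -> Dataset:
        """Prepare prompt + target examples in the single 'text' field SFTTrainer trains on."""
        formatted_data = []
        for _, row in self.df.iterrows():
            # Use the unified prompt format for each example
            prompt = construct_prompt(row["text"])
            # SFTTrainer only reads `dataset_text_field`, so the expected answer must be part of
            # the text; otherwise the model never sees the label. Assumes `label` is encoded as
            # yes/no, true/false, or 1/0.
            answer = "yes" if str(row["label"]).strip().lower() in {"1", "yes", "true"} else "no"
            # Append EOS (from the tokenizer loaded above) so the model learns where the answer ends
            formatted_data.append({"text": f"{prompt}\n\nmalnutrition={answer}{tokenizer.eos_token}"})
        return Dataset.from_list(formatted_data)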

#######################################
# Inference Functions
#######################################

def generate_response(model, tokenizer, patient_notes: str, device,
                      max_new_tokens: int = 200, temperature: float = 0.7, top_p: float = 0.9) -> str:
    """
    Generate a response from the model using the unified prompt format.
    Returns only the newly generated text, not the echoed prompt.
    """
    prompt = construct_prompt(patient_notes)
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # temperature/top_p are ignored under the default greedy decoding
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Drop the prompt tokens so only the model's continuation is decoded
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)


def interactive_prompt(model, tokenizer, device):
    """
    Provide an interactive loop for users to input patient notes and view the model's decision.
    """
    print("Interactive mode: type the patient notes (or type 'exit' to quit).")
    while True:
        patient_notes = input("Enter patient notes: ")
        if patient_notes.strip().lower() == "exit":
            print("Exiting interactive mode.")
            break
        response = generate_response(model, tokenizer, patient_notes, device)
        print("\nModel Response:")
        print(response)
        print("-" * 80)


# Create an instance of your dataset and prepare the data
data_path = "/content/malnutrition_cases.csv"
malnutrition_dataset = MalnutritionDataset(data_path)
train_dataset = malnutrition_dataset.prepare_training_data()

# --- SFTTrainer Setup ---
# Define maximum sequence length; adjust as needed
max_seq_length = 1024

# Define training arguments
training_args = TrainingArguments(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=5,
    max_steps=60,
    learning_rate=2e-4,
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    output_dir="outputs",
    report_to="none"
)

# Initialize the SFTTrainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=training_args,
)

# --- Training and Saving ---
# Run training
trainer.train()

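# Save the LoRA adapter weights and tokenizer after training.
# Writing under /content/drive requires Google Drive to be mounted first, e.g.
# from google.colab import drive; drive.mount("/content/drive")
trainer.save_model("/content/drive/MyDrive/malnutrition_model")
tokenizer.save_pretrained("/content/drive/MyDrive/malnutrition_model")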

print("Training completed and model saved.")


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Loading base model and tokenizer...
==((====))==  Unsloth 2025.2.4: Fast Mistral patching. Transformers: 4.48.2.
   \\   /|    GPU: NVIDIA L4. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Device supports bfloat16 but you selected float16. Will change to bfloat16.


Creating PEFT/LoRA model...


Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Unsloth 2025.2.4 patched 32 layers with 32 QKV layers, 32 O layers and 0 MLP layers.



==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 93 | Num Epochs = 6
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 60
 "-____-"     Number of trainable parameters = 6,291,456


Step,Training Loss
1,1.3715
2,1.3697
3,1.349
4,1.2945
5,1.206
6,1.0994
7,0.9878
8,0.8727
9,0.7569
10,0.6491


Training completed and model saved.
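The training log reports 6,291,456 trainable parameters. That figure can be reproduced by hand: LoRA adds a pair of matrices A (r × d_in) and B (d_out × r) to each targeted linear layer, so each one contributes r · (d_in + d_out) parameters. The sketch below assumes the Phi-3-mini shapes (hidden size 3072, 32 decoder layers, 32 query and 32 key/value heads of width 96) and assumes the Unsloth checkpoint exposes separate `q_proj`/`k_proj`/`v_proj` layers, as the `target_modules` list and the "32 QKV layers" log line suggest. `lora_param_count` is a hypothetical helper.

```python
# Assumed Phi-3-mini-4k shapes: hidden size 3072, 32 decoder layers,
# 32 key/value heads of width 96 (so K/V projections are also 3072 wide).
HIDDEN = 3072
NUM_LAYERS = 32
KV_DIM = 32 * 96

# (d_in, d_out) for each LoRA target module passed to get_peft_model
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
}


def lora_param_count(r: int) -> int:
    """Count LoRA parameters: A (r x d_in) + B (d_out x r) per targeted layer, per decoder layer."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in TARGET_SHAPES.values())
    return per_layer * NUM_LAYERS


print(lora_param_count(8))  # 6291456, matching the log
```

Neither `lora_alpha` nor `use_rslora` changes this count, since they only affect the scaling applied to the adapter output; doubling `r` doubles the trainable parameters.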

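The log also reports "Total batch size = 8" and "Num Epochs = 6" even though `TrainingArguments` sets `max_steps=60` rather than an epoch count. The sketch below reproduces that arithmetic, following our understanding of how the Hugging Face `Trainer` derives epochs when `max_steps` is set (optimizer steps per epoch = number of dataloader batches // gradient accumulation steps, then epochs rounded up). `training_schedule` is a hypothetical helper, and the multi-GPU case is only approximated.

```python
import math


def training_schedule(num_examples: int, per_device_batch: int,
                      grad_accum: int, max_steps: int, num_gpus: int = 1):
    """Return (effective batch size, optimizer steps per epoch, epochs needed to reach max_steps)."""
    effective_batch = per_device_batch * grad_accum * num_gpus
    # The last, partial batch is kept (no drop_last), hence the ceiling
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_gpus))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    epochs = math.ceil(max_steps / steps_per_epoch)
    return effective_batch, steps_per_epoch, epochs


# 93 examples, per_device_train_batch_size=2, gradient_accumulation_steps=4, max_steps=60
print(training_schedule(93, 2, 4, 60))  # (8, 11, 6)
```

With only 93 examples, 60 steps at an effective batch of 8 means each example is seen about five times, so watching for overfitting (for example with a held-out split) is worthwhile.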

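The evaluation mode in the next cell prints each generated response next to its reference but does not score them. Because the prompt asks for a final `malnutrition=yes` or `malnutrition=no`, the answer can be extracted with a regular expression and compared against the labels. A minimal sketch follows; `parse_prediction`, `normalize_label`, and `accuracy` are hypothetical helpers, the label encodings (yes/no, true/false, 1/0) are an assumption, and an unparseable response counts as wrong. The last match is taken because the prompt text itself contains "malnutrition=yes or malnutrition=no", so parsing is most reliable on the generated continuation only.

```python
import re
from typing import Iterable, Optional

# Matches the answer format requested by the prompt, e.g. "malnutrition=yes"
ANSWER_RE = re.compile(r"malnutrition\s*=\s*(yes|no)\b", re.IGNORECASE)


def parse_prediction(response: str) -> Optional[str]:
    """Return 'yes'/'no' from the last malnutrition=... in the response, or None if absent."""
    matches = ANSWER_RE.findall(response)
    return matches[-1].lower() if matches else None


def normalize_label(label) -> str:
    """Map an assumed yes/no, true/false, or 1/0 reference label to 'yes'/'no'."""
    return "yes" if str(label).strip().lower() in {"1", "yes", "true"} else "no"


def accuracy(responses: Iterable[str], references: Iterable) -> float:
    """Fraction of responses whose parsed answer matches the reference label."""
    pairs = list(zip(responses, references))
    if not pairs:
        return 0.0
    correct = sum(parse_prediction(r) == normalize_label(ref) for r, ref in pairs)
    return correct / len(pairs)


print(parse_prediction("Weight-for-height z score is -0.8.\nmalnutrition=no"))  # no
print(accuracy(["malnutrition=yes", "malnutrition=no", "unsure"], [1, "no", "yes"]))
```

In the evaluation loop, the `generated` strings could be collected into a list and passed to `accuracy` together with `references`.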
In [None]:
#######################################
# Example: Inference Setup
#######################################

# Load the saved model and tokenizer for inference.
# model_dir must match the directory passed to trainer.save_model() above.
model_dir = "/content/drive/MyDrive/malnutrition_model"
print(f"Loading model and tokenizer from {model_dir} ...")
model, tokenizer = FastLanguageModel.from_pretrained(model_dir)
# Enable native 2x faster inference
FastLanguageModel.for_inference(model)

# The 4-bit model is already placed on the GPU when loaded; `device` is only
# used to move the tokenized inputs inside generate_response().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

# Choose mode: interactive inference or evaluation on a CSV.
print("Select mode:\n  [I]nteractive prompt\n  [E]valuation on CSV dataset")
mode = input("Enter 'i' or 'e': ").strip().lower()

if mode == "i":
    interactive_prompt(model, tokenizer, device)
elif mode == "e":
    # Evaluation mode: expect a CSV with columns "text" (patient notes) and "labels" (reference answer).
    # Note that the training CSV uses "label" (singular) instead.
    csv_path = input("Enter the path to your evaluation CSV file: ").strip()
    df = pd.read_csv(csv_path)
    if 'text' not in df.columns or 'labels' not in df.columns:
        raise ValueError("CSV file must contain 'text' and 'labels' columns.")

    prompts = df["text"].tolist()
    references = df["labels"].tolist()

    print(f"Evaluating {len(prompts)} examples from {csv_path} ...\n")
    for i, (patient_notes, reference) in enumerate(zip(prompts, references), start=1):
        print(f"Example {i}:")
        print("Patient Notes:")
        print(patient_notes)
        generated = generate_response(model, tokenizer, patient_notes, device)
        print("\nGenerated Response:")
        print(generated)
        print("\nReference:")
        print(reference)
        print("=" * 80)
else:
    print("Invalid selection. Exiting.")

Loading model and tokenizer from /content/drive/MyDrive/malnutrition_model ...
==((====))==  Unsloth 2025.2.4: Fast Mistral patching. Transformers: 4.48.2.
   \\   /|    GPU: NVIDIA L4. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Select mode:
  [I]nteractive prompt
  [E]valuation on CSV dataset
Enter 'i' or 'e': i
Interactive mode: type the patient notes (or type 'exit' to quit).
Enter patient notes: 4-year-old female at routine checkup. Weight-for-height z-score: -0.8. BMI-for-age within normal range. Good appetite reported, following normal growth curve.

Model Response:
Read the patient's notes and determine if the patient is likely to have malnutrition:

Criteria list.

Weight is pr