<a href="https://colab.research.google.com/github/mzu-2410z/generic-doctor/blob/main/generic_doctor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Quick notes:
# - Make sure you've set Runtime > Change runtime type > GPU.
# - Have your Hugging Face token ready when the notebook asks.
# - If you hit OOM, switch to a smaller model (instructions later).


In [1]:
# Install required libs
!pip install -q transformers accelerate bitsandbytes peft trl datasets huggingface_hub
# make sure we have the latest versions (may take a minute)


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m423.1/423.1 kB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
# Login to Hugging Face (you'll be prompted to paste your token)
from huggingface_hub import notebook_login
notebook_login()

# Check GPU
import torch, os
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Current device name:", torch.cuda.get_device_name(0))


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

Torch version: 2.8.0+cu126
CUDA available: True
Device count: 1
Current device name: Tesla T4


In [4]:
# ----- CONFIG -----
# If GPU memory is limited, use a smaller model (gemma-2b).
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"  # recommended
# Fallback smaller model (if OOM): "google/gemma-2b" or "anon8231489123/gpt4-x-alpaca" etc.
FALLBACK_MODEL = "google/gemma-2b"

# Training options
NUM_SAMPLES = 2000        # how many total examples to use (adjustable)
MAX_LENGTH = 512
NUM_EPOCHS = 1
BATCH_SIZE = 1            # per device
GRAD_ACCUM = 4            # gradient accumulation to simulate larger batch
LEARNING_RATE = 2e-4

# Output
OUTPUT_DIR = "./doctor_model"

print("Config ok. Model:", MODEL_NAME)


Config ok. Model: mistralai/Mistral-7B-Instruct-v0.2


In [None]:
from datasets import load_dataset, Dataset, concatenate_datasets
import random
random.seed(42)

def sample_dataset(name, split, n):
    print(f"Loading {name}...")
    try:
        ds = load_dataset(name, split=split)
    except Exception as e:
        print("Failed to load", name, ":", e)
        return None
    # try to extract sensible columns
    # We'll look for common names: question/answer, query/response, utterances set, context/response
    def extract_example(item):
        # try direct fields
        for qk in ("question","query","prompt","patient_question","symptom"):
            if qk in item and item[qk]:
                for ak in ("answer","response","reply","doctor_answer","medical_answer"):
                    if ak in item and item[ak]:
                        return {"text": f"Patient: {item[qk].strip()}\nDoctor: {item[ak].strip()}"}
        # try QA pairs as lists or dialogs
        if "dialog" in item or "utterances" in item or "conversations" in item:
            # naive extraction: find alternating list and join last two
            cand = item.get("dialog") or item.get("utterances") or item.get("conversations")
            if isinstance(cand, list) and len(cand) >= 2:
                # assume last user/assistant pair
                u = cand[-2].get("text") if isinstance(cand[-2], dict) else str(cand[-2])
                a = cand[-1].get("text") if isinstance(cand[-1], dict) else str(cand[-1])
                return {"text": f"Patient: {u.strip()}\nDoctor: {a.strip()}"}
        # fallback: take any two string fields
        strs = [v for v in item.values() if isinstance(v, str) and len(v) > 5]
        if len(strs) >= 2:
            return {"text": f"Patient: {strs[0].strip()}\nDoctor: {strs[1].strip()}"}
        return None

    examples = []
    for i, it in enumerate(ds):
        ex = extract_example(it)
        if ex:
            examples.append(ex)
        if len(examples) >= n:
            break
    print(f"Extracted {len(examples)} usable examples from {name}")
    return Dataset.from_list(examples)

# Choose datasets and sample sizes
# The original datasets failed to load, using a fallback dataset related to the fallback model
d1 = sample_dataset("HuggingFaceH4/ultrachat_200k", "train_sft", NUM_SAMPLES)


datasets = []
if d1 is not None and len(d1) > 0:
    datasets.append(d1)

if len(datasets) == 0:
    raise SystemExit("No dataset could be loaded/extracted. Check dataset names or change MODEL_NAME to fallback and try again.")

if len(datasets) == 1:
    full = datasets[0]
else:
    full = concatenate_datasets(datasets)

# Shuffle & split
full = full.shuffle(seed=42)
train_test = full.train_test_split(test_size=0.1)
train_ds = train_test["train"]
eval_ds = train_test["test"]
print("Train size:", len(train_ds), "Eval size:", len(eval_ds))

In [10]:
from transformers import AutoTokenizer

model_name = MODEL_NAME
print("Loading tokenizer for", model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token # Add this line to set the pad token

def tokenize_fn(batch):
    enc = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=MAX_LENGTH)
    enc["labels"] = enc["input_ids"].copy()
    return enc

train_ds_tok = train_ds.map(tokenize_fn, batched=True, remove_columns=train_ds.column_names)
eval_ds_tok = eval_ds.map(tokenize_fn, batched=True, remove_columns=eval_ds.column_names)

print("Tokenized. Example keys:", train_ds_tok.column_names)

Loading tokenizer for mistralai/Mistral-7B-Instruct-v0.2


Map:   0%|          | 0/1800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Tokenized. Example keys: ['input_ids', 'attention_mask', 'labels']


In [11]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

print("Configuring 4-bit quantization and loading model...")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
except Exception as e:
    print("4-bit load failed:", e)
    print("Trying fallback model:", FALLBACK_MODEL)
    model_name = FALLBACK_MODEL
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)

print("Model loaded:", model_name)

# Apply LoRA
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj","v_proj","k_proj","o_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
print("LoRA wrapped model ready.")


Configuring 4-bit quantization and loading model...


config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

Model loaded: mistralai/Mistral-7B-Instruct-v0.2
LoRA wrapped model ready.


In [14]:
from transformers import TrainingArguments
from trl import SFTTrainer

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=NUM_EPOCHS,
    logging_steps=20,
    fp16=True,
    save_strategy="epoch",
    save_total_limit=2,
    learning_rate=LEARNING_RATE,
    remove_unused_columns=False
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds_tok,
    eval_dataset=eval_ds_tok,
)

print("Starting training...")
trainer.train()
print("Training finished.")

Truncating train dataset:   0%|          | 0/1800 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/200 [00:00<?, ? examples/s]

Starting training...


  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mpahos82434[0m ([33mpahos82434-lahore-college-for-women-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
20,4.1817
40,0.9503
60,0.9314
80,0.9409
100,0.9941
120,0.8678
140,0.9521
160,0.8929
180,0.9014
200,0.8738


Training finished.


In [15]:
# Save the PEFT adapter (smaller than full model)
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Saved to", OUTPUT_DIR)

# Quick interactive test
def gen(prompt, max_new_tokens=200):
    inputs = tokenizer(prompt, return_tensors="pt").to(next(model.parameters()).device)
    out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, top_p=0.95, temperature=0.7)
    return tokenizer.decode(out[0], skip_special_tokens=True)

prompt = "Patient: I have had a sore throat and a mild fever for two days. What should I do?\nDoctor:"
print(gen(prompt))


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Saved to ./doctor_model
Patient: I have had a sore throat and a mild fever for two days. What should I do?
Doctor: Based on the information provided, the Patient has had a sore throat and a mild fever for two days.
If the fever is higher than 38 degrees Celsius, the Patient should seek medical advice. If the fever is lower than this, they can try to treat it at home. The Patient should drink plenty of water to stay hydrated and avoid caffeine and alcohol, as these can worsen the symptoms.
The Patient should also avoid speaking or coughing as much as possible, as this can worsen the pain. They can try sucking on a lollipop or eating ice cubes to help soothe the throat.
If the Patient is experiencing severe symptoms, they should seek medical advice immediately.
The doctor's response indicates that the Patient's fever is not high enough to warrant medical attention, and they can try to treat the symptoms at home by staying hydrated, avoiding caffeine
