In [1]:
%%capture
import os

# Check if running in Colab
if "COLAB_" not in "".join(os.environ.keys()):
    # Local environment (e.g., your PC)
    !pip install unsloth
else:
    # Google Colab setup
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf==5.29.1 datasets huggingface_hub hf_transfer fsspec==2025.3.2
    !pip install --no-deps unsloth


In [2]:
import json

# Load dataset from uploaded file
with open("medical_dataset.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Create prompt structure
def build_prompt(sample):
    return {
        "input": f"""### Instruction:
You are a medical assistant. Read the medical report and provide:
1. A short summary
2. Your interpretation
3. A possible solution

### Input:
{sample['report']}

### Response:
1. Summary: {sample['summary']}
2. Interpretation: {sample['interpretation']}
3. Solution: {sample['solution']}"""
    }

formatted_data = [build_prompt(example) for example in raw_data]


In [3]:
from datasets import Dataset

dataset = Dataset.from_list(formatted_data)
print(dataset[0])


{'input': '### Instruction:\nYou are a medical assistant. Read the medical report and provide:\n1. A short summary\n2. Your interpretation\n3. A possible solution\n\n### Input:\n45-year-old male presents with 3 days of productive cough, fever (38.5°C), and right-sided chest pain. On examination: crackles in right lower lobe. WBC 12.5k, CRP 45. Chest X-ray shows right lower lobe consolidation.\n\n### Response:\n1. Summary: Middle-aged male with community-acquired pneumonia\n2. Interpretation: Clinical presentation and imaging consistent with bacterial pneumonia, likely Streptococcus pneumoniae\n3. Solution: 1. Start amoxicillin-clavulanate 875/125mg PO q12h\n2. Chest physiotherapy\n3. Follow-up in 48 hours or if symptoms worsen'}


In [4]:
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/mistral-7b-bnb-4bit",
    max_seq_length = 2048 ,
    dtype = None,
    load_in_4bit = True,
)


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.7.8: Fast Mistral patching. Transformers: 4.53.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [5]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    lora_alpha = 32,
    lora_dropout = 0.0,  # Set to 0 for full Unsloth optimization
    bias = "none",
)


Unsloth 2025.7.8 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [6]:
def tokenize(example):
    return tokenizer(example["input"], truncation=True, padding="max_length", max_length=2048)

tokenized_dataset = dataset.map(tokenize, batched=True)

from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)


Map:   0%|          | 0/8 [00:00<?, ? examples/s]

In [7]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    per_device_train_batch_size = 2,       # Set to 1 if you get out-of-memory errors
    gradient_accumulation_steps = 4,
    num_train_epochs = 3,
    learning_rate = 2e-4,
    fp16 = True,                           # Mixed precision for speed/memory
    logging_steps = 10,
    report_to = "none",
    output_dir = "outputs",
)


In [8]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)


In [9]:
trainer.train()


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 8 | Num Epochs = 3 | Total steps = 3
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 41,943,040 of 7,283,675,136 (0.58% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss


TrainOutput(global_step=3, training_loss=6.806591033935547, metrics={'train_runtime': 144.0128, 'train_samples_per_second': 0.167, 'train_steps_per_second': 0.021, 'total_flos': 2109388496044032.0, 'train_loss': 6.806591033935547, 'epoch': 3.0})

In [10]:
model.save_pretrained("finetuned_med_mistral")
tokenizer.save_pretrained("finetuned_med_mistral")


('finetuned_med_mistral/tokenizer_config.json',
 'finetuned_med_mistral/special_tokens_map.json',
 'finetuned_med_mistral/tokenizer.model',
 'finetuned_med_mistral/added_tokens.json',
 'finetuned_med_mistral/tokenizer.json')

In [11]:
sample_report = "A 65-year-old male with fever and cough."
prompt = f"""### Instruction:
You are a medical assistant. Read the medical report and provide:
[SUMMARY]
[INTERPRETATION]
[SOLUTION]

### Input:
{sample_report}

### Response:"""
print(prompt)



inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=500)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


### Instruction:
You are a medical assistant. Read the medical report and provide:
[SUMMARY]
[INTERPRETATION]
[SOLUTION]

### Input:
A 65-year-old male with fever and cough.

### Response:
### Instruction:
You are a medical assistant. Read the medical report and provide:
[SUMMARY]
[INTERPRETATION]
[SOLUTION]

### Input:
A 65-year-old male with fever and cough.

### Response:
The patient is a 65-year-old male with a history of hypertension and diabetes. He presents with a 2-week history of fever, cough, and shortness of breath. On physical examination, he is afebrile, tachycardic, and tachypneic. Chest X-ray shows bilateral infiltrates. Laboratory tests show a white blood cell count of 12,000/mm3, hemoglobin of 10 g/dL, and platelets of 100,000/mm3. The patient is diagnosed with pneumonia and started on antibiotics.

### Summary:
The patient is a 65-year-old male with a history of hypertension and diabetes. He presents with a 2-week history of fever, cough, and shortness of breath. On p

In [12]:
!pip install fastapi uvicorn nest-asyncio
!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared
!chmod +x cloudflared
!pip install fastapi-cache2 sse-starlette python-multipart
!pip install gradio



--2025-07-26 16:02:48--  https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/cloudflare/cloudflared/releases/download/2025.7.0/cloudflared-linux-amd64 [following]
--2025-07-26 16:02:48--  https://github.com/cloudflare/cloudflared/releases/download/2025.7.0/cloudflared-linux-amd64
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://release-assets.githubusercontent.com/github-production-release-asset/106867604/37d2bad8-a2ed-4b93-8139-cbb15162d81d?sp=r&sv=2018-11-09&sr=b&spr=https&se=2025-07-26T16%3A51%3A25Z&rscd=attachment%3B+filename%3Dcloudflared-linux-amd64&rsct=application%2Foctet-stream&skoid=96c2d410-5711-43a1-aedd-ab1947aa7ab0&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skt=2025-07-

In [13]:
from unsloth import FastLanguageModel
from transformers import BitsAndBytesConfig
import torch

# Configure 4-bit quantization with CPU offload fallback
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_enable_fp32_cpu_offload=True  # ✅ allows CPU fallback
)

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="finetuned_med_mistral",
    max_seq_length=4096,
    dtype=None,  # Let Unsloth auto-select best dtype
    device_map="auto",  # ✅ Smart device placement
    quantization_config=bnb_config
)

model.eval()


==((====))==  Unsloth 2025.7.8: Fast Mistral patching. Transformers: 4.53.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096, padding_idx=0)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): l

In [14]:
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from unsloth import FastLanguageModel
from transformers import BitsAndBytesConfig
import torch
import re
from typing import List
import logging

app = FastAPI()

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

try:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="finetuned_med_mistral",
        max_seq_length=3072,
        quantization_config=bnb_config,
        device_map="auto",
    )
    model.eval()
    logger.info("✅ Model loaded successfully")
except Exception as e:
    logger.error(f"❌ Model loading failed: {str(e)}")
    raise RuntimeError("Model load error")

class MedicalReportRequest(BaseModel):
    report: str

def clean_text(text: str) -> str:
    # Remove repeated phrases, excessive whitespace, unwanted chars, unicode normalize
    text = re.sub(r'(?s)(\b\w+\b.*?)(?:\s*\1)+', r'\1', text)
    text = re.sub(r'\s+', ' ', text).strip()
    text = re.sub(r'[^\w\s.,:;!?()\-/]', ' ', text)  # allow standard punctuation
    text = re.sub(r'\\\(.*?\\\)', '', text)  # remove latex math if present
    text = re.sub(r'\[[^\]]*\]', '', text)  # remove bracketed citations
    text = re.sub(r'\[\^.*?\]', '', text)
    text = re.sub(r'\bdoi:\S+', '', text, flags=re.IGNORECASE)
    text = re.sub(r'\bpmid:\S+', '', text, flags=re.IGNORECASE)
    text = re.sub(r'https?://\S+', '', text)
    text = re.sub(r'\n+', ' ', text)
    text = text.strip()
    return text

def chunk_text(text: str, max_words: int = 200) -> List[str]:
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = []
    word_count = 0

    for sentence in sentences:
        words = sentence.split()
        if word_count + len(words) <= max_words:
            current_chunk.append(sentence)
            word_count += len(words)
        else:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            word_count = len(words)
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks

def post_process_output(output: str) -> str:
    # Remove everything between [INST]...[/INST]
    output = re.sub(r'\[INST\].*?\[/INST\]', '', output, flags=re.DOTALL)

    # Remove boilerplate like "Chapter 1 Introduction" or numeric headers
    output = re.sub(r'## Chapter \d+.*', '', output, flags=re.IGNORECASE)
    output = re.sub(r'## \d+(\.\d+)+\.*', '', output)

    # Remove repeated headers, bracketed phrases, URLs, and noise
    output = re.sub(r'(## [^\n]+)(\s+\1)+', r'\1', output)
    output = re.sub(r'\[[^\]]+?\]', '', output)
    output = re.sub(r'https?://\S+\s?', '', output)

    # Remove residual prompt header lines and role/task lines
    output = re.sub(r'(ROLE|TASK|RULES|CLINICAL REPORT|MEDICAL REPORT):.*', '', output, flags=re.IGNORECASE)

    # Deduplicate repeated lines
    lines = output.split('\n')
    filtered_lines = []
    seen = set()
    for line in lines:
        stripped = line.strip()
        if not stripped or stripped in seen:
            continue
        # Filter obvious boilerplate/noise lines
        if re.search(r'\bfractional state\b|\bmath\b|\bchapter\b|\breference\b', stripped, re.IGNORECASE):
            continue
        if re.search(r'[@#$%^&*{}<>~`\[\]]', stripped):
            continue
        filtered_lines.append(line)
        seen.add(stripped)

    cleaned = '\n'.join(filtered_lines).strip()
    cleaned = re.sub(r'\n{2,}', '\n\n', cleaned)  # collapse multiple empty lines
    return cleaned

def generate_with_fallback(
    prompt: str,
    max_tokens: int,
    temp: float = 0.1,
    do_sample: bool = False,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 3
) -> str:
    try:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1536).to("cuda")
        outputs = model.generate(
            **inputs,
            max_new_tokens=min(max_tokens, 1200),
            temperature=temp,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if result.startswith(prompt[:100]):
            result = result[len(prompt):].strip()
        return result.strip().split("###")[0]
    except Exception as e:
        raise RuntimeError(f"Generation failed: {e}")

# Refined stricter prompts forbidding narrative and educational content, only facts from report.


def generate_structured_summary(report: str) -> str:
    prompt = f"""
[INST] <<SYS>>
ROLE: Clinical Data Extractor

Extract ONLY data explicitly written in the report. NO assumptions.
No narrative or educational explanations. Use exact quotes for facts.

Extract ONLY:
1. Demographics (Age/Sex/Risks)
2. Chief Complaint + Duration
3. Key Abnormalities (ECG/Labs/Imaging/Vitals)
4. Interventions Done
5. Outcome / Patient status

Rules:
- Use exact quoted phrases.
- Missing info → [Not documented].
- Do NOT add information not in the report.
- Do NOT include any external facts or guidelines.
- Output ONLY the extracted data.

<</SYS>>

\"\"\"{report[:1800]}\"\"\"
[/INST]
SUMMARY:
"""
    return generate_with_fallback(prompt, max_tokens=600, temp=0.1)


def generate_focused_interpretation(report: str) -> str:
    prompt = f"""
[INST] <<SYS>>
ROLE: Lead Diagnostician

Goal: Provide the most likely diagnosis supported STRICTLY by facts in the report.

Include:
- DX: [Condition]
- 2 clear quotes supporting diagnosis
- 1 unclear/conflicting data if present

NO assumptions, NO narratives, NO external info.

<</SYS>>

\"\"\"{report[:1800]}\"\"\"
[/INST]
INTERPRETATION:
"""
    return generate_with_fallback(prompt, max_tokens=800, temp=0.1)


def generate_case_specific_solutions(report: str, diagnosis: str) -> str:
    prompt = f"""
[INST] <<SYS>>
ROLE: Tactical Medical Planner

Given DX: {diagnosis}

Provide ACTUAL recommendations STRICTLY based on the report facts, no assumptions:

1. IMMEDIATE: Life-threatening actions + Quote + Metrics + Backup plan
2. URGENT: Time-sensitive actions + Quote + Metrics + Backup plan
3. ROUTINE: Follow-up actions + Quote + Metrics + Backup plan

Do NOT add educational or hypothetical information.

<</SYS>>

\"\"\"{report[:1800]}\"\"\"
[/INST]
SOLUTIONS:
"""
    return generate_with_fallback(prompt, max_tokens=1200, temp=0.1)


@app.post("/analyze")
def analyze_report(request: MedicalReportRequest):
    MAX_REPORT_LENGTH = 8000

    if not request.report.strip():
        raise HTTPException(status_code=400, detail="Empty report")
    if len(request.report) > MAX_REPORT_LENGTH:
        raise HTTPException(status_code=400, detail="Report too long")

    def generate_output():
        try:
            clean_report = clean_text(request.report)
            chunks = chunk_text(clean_report)
            yield f"🔎 Processing {len(chunks)} chunks...\n\n"

            # Generate structured summary
            summary_raw = generate_structured_summary(clean_report)
            summary = post_process_output(summary_raw)
            yield f"📋 Structured Summary:\n{summary}\n\n"

            # Generate focused clinical interpretation
            interpretation_raw = generate_focused_interpretation(clean_report)
            interpretation = post_process_output(interpretation_raw)
            yield f"🩺 Clinical Interpretation:\n{interpretation}\n\n"

            # Parse primary diagnosis to pass to solutions
            primary_dx = "unspecified condition"
            if 'DX:' in interpretation:
                primary_dx_line = interpretation.split('DX:')[-1].split('\n')[0].strip()
                if primary_dx_line:
                    primary_dx = primary_dx_line

            # Generate detailed recommendations
            solutions_raw = generate_case_specific_solutions(clean_report, primary_dx)
            solutions = post_process_output(solutions_raw)
            yield f"🧭 Recommendations:\n{solutions}\n"

        except torch.cuda.OutOfMemoryError:
            yield "❌ GPU out of memory. Use a shorter report.\n"
        except Exception as e:
            logger.error(f"Analysis error: {str(e)}", exc_info=True)
            yield f"❌ Error: {str(e)}\n"

    return StreamingResponse(generate_output(), media_type="text/plain")


==((====))==  Unsloth 2025.7.8: Fast Mistral patching. Transformers: 4.53.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [15]:
import nest_asyncio
import threading
import uvicorn

nest_asyncio.apply()

def run():
    uvicorn.run(app, host="0.0.0.0", port=8000)

threading.Thread(target=run).start()


In [None]:
!./cloudflared tunnel --url http://localhost:8000 --metrics 127.0.0.1:45678 --no-autoupdate


INFO:     Started server process [9581]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


[90m2025-07-26T16:03:41Z[0m [32mINF[0m Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee, are subject to the Cloudflare Online Services Terms of Use (https://www.cloudflare.com/website-terms/), and Cloudflare reserves the right to investigate your use of Tunnels for violations of such terms. If you intend to use Tunnels in production you should use a pre-created named tunnel by following: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps
[90m2025-07-26T16:03:41Z[0m [32mINF[0m Requesting new quick Tunnel on trycloudflare.com...
[90m2025-07-26T16:03:46Z[0m [32mINF[0m +--------------------------------------------------------------------------------------------+
[90m2025-07-26T16:03:46Z[0m [32mINF[0m |  Your quick Tunnel has been created! Visit it at (it may take some time to be reachable):  |
[90m2025

In [None]:
 import requests

url = "https://useful-candle-9d.trycloudflare.com/analyze"  # Replace with your actual URL

response = requests.post(url, json={
    "report": "A 70-year-old patient reports difficulty breathing and swelling in the ankles..."
})

print(response.json())
