# Utilities for loading the model and running predictions

# In [1]:
import os, math, numpy as np, torch
from torch import nn
from torchvision import models, transforms
from PIL import Image

# Run on the GPU when torch reports that CUDA is available, otherwise on CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Preprocessing applied to an image before it is fed to the network:
# resize to 224x224, convert to a tensor, then normalize each of the three
# channels with the mean/std values conventionally used with
# ImageNet-pretrained torchvision models.
_to_resnet = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

def load_xray_model(weights_path: str, num_classes=2):
    """Build a ResNet-18 classifier and load fine-tuned weights into it.

    The checkpoint at ``weights_path`` may be either a dict containing a
    ``'state_dict'`` entry (and optionally ``'classes'``) or a bare state
    dict.

    Args:
        weights_path: Path of the checkpoint file handed to ``torch.load``.
        num_classes: Output size of the replacement classification head.

    Returns:
        ``(model, classes)`` where ``model`` is in eval mode on ``DEVICE``
        and ``classes`` is the list stored in the checkpoint, or
        ``['NORMAL', 'PNEUMONIA']`` when the checkpoint has none.

    Warns:
        UserWarning: if the checkpoint does not cover every parameter of the
            model or contains keys the model does not have. Loading stays
            non-strict, but a partial load is no longer silent.
    """
    import warnings

    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    in_feats = model.fc.in_features
    model.fc = nn.Sequential(nn.Dropout(0.2), nn.Linear(in_feats, num_classes))

    # NOTE(security): torch.load unpickles the file; only pass checkpoints
    # that come from a trusted source.
    sd = torch.load(weights_path, map_location=DEVICE)
    if isinstance(sd, dict) and 'state_dict' in sd:
        state_dict = sd['state_dict']
        classes = sd.get('classes', ['NORMAL', 'PNEUMONIA'])
    else:
        state_dict = sd
        classes = ['NORMAL', 'PNEUMONIA']

    # strict=False is kept for backward compatibility, but the result is
    # inspected: a checkpoint that misses the head would otherwise leave a
    # randomly initialised classifier in place without any indication.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {weights_path!r} only partially matches the model: "
            f"missing keys={list(result.missing_keys)}, "
            f"unexpected keys={list(result.unexpected_keys)}",
            stacklevel=2,
        )

    model.eval().to(DEVICE)
    return model, classes

class TempCal:
    """Holder for a single scalar ``T``, stored as a float.

    NOTE(review): the name suggests a temperature-scaling calibration
    factor, but nothing in the visible code applies it to logits.
    """

    def __init__(self, T=1.0):
        self.T = float(T)

    @classmethod
    def load(cls, path: str):
        """Create an instance from a file saved with ``torch.save``.

        A missing file yields ``T = 1.0``. An existing file is expected to
        hold a mapping; its ``'T'`` entry is used, defaulting to ``1.0``.
        """
        if not os.path.exists(path):
            return cls(1.0)
        stored = torch.load(path, map_location='cpu')
        return cls(stored.get('T', 1.0))

def predict_xray(model, img):
    """Classify a single X-ray image.

    Args:
        model: Callable network returning logits of shape ``(1, C)`` for a
            ``(1, 3, 224, 224)`` input.
        img: A file path (``str`` / ``os.PathLike``) or an already-open PIL
            image.

    Returns:
        ``(idx, conf, probs)``: index of the most probable class, its
        softmax probability as a float, and the full probability list.
    """
    if isinstance(img, (str, os.PathLike)):
        # Close the file handle deterministically; convert() returns an
        # independent, fully loaded copy.
        with Image.open(img) as opened:
            img = opened.convert('RGB')
    elif getattr(img, 'mode', 'RGB') != 'RGB':
        # X-rays are frequently stored as single-channel ('L') or RGBA
        # images; the 3-channel normalization below requires RGB, so convert
        # in-memory images too, not just the ones opened from a path.
        img = img.convert('RGB')

    x = _to_resnet(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
        idx = int(probs.argmax())
    return idx, float(probs[idx]), probs.tolist()

# Canned report text for a chest X-ray without abnormal findings. Built one
# line at a time; the empty entry produces the blank line that separates the
# findings from the impression.
NORMAL_REPORT = "\n".join([
    "Chest X-ray (PA):",
    "• Cardiomediastinal silhouette: Within normal limits.",
    "• Lungs: No focal consolidation identified. No pleural effusion.",
    "• Pneumothorax: Not seen.",
    "• Bones/soft tissues: Unremarkable.",
    "",
    "Impression: No acute cardiopulmonary abnormality.",
])


# X-RAY CLASSIFICATION

# In [2]:
def classify_xray(img_path):
    """
    Placeholder X-ray classifier.

    Accepts the path of an X-ray image and returns a label string. The
    image is not read yet: a fixed example label is returned until real
    inference (load model -> preprocess image -> predict -> decode label)
    is wired in.
    """
    # TODO: replace this stub with the real model call.
    placeholder_label = "Pneumonia (example)"
    return placeholder_label


# ======================
# REPORT SUMMARIZATION
# ======================

def summarize_report(report_text: str) -> str:
    """
    Placeholder report summarizer.

    Takes the report text and returns a summary string. The input is
    currently ignored and a fixed sentence is returned until a real
    summarizer is plugged in.
    """
    # TODO: replace this stub with the real summarization call.
    placeholder_summary = "This is a summary of the report."
    return placeholder_summary


# ======================
# RAG QUESTION ANSWERING
# ======================

def answer_query(query: str, report_text: str, xray_result: str) -> str:
    """
    Placeholder question-answering step.

    Intended to answer ``query`` from the report text and the X-ray
    classification. For now only ``query`` is used: it is echoed inside a
    fixed sentence, and ``report_text`` / ``xray_result`` are ignored.
    """
    # TODO: replace this stub with the real RAG pipeline
    # (build context from report + X-ray result, then query it).
    reply = f"Answer to '{query}' based on report and X-ray."
    return reply


# ======================
# TEXT UTILS
# ======================

def preprocess_text(text: str) -> str:
    """
    Return ``text`` with leading and trailing whitespace removed.

    Whitespace inside the text (including line breaks) is left untouched.
    """
    trimmed = text.strip()
    return trimmed


# ======================
# MASTER PIPELINE
# ======================

def run_pipeline(xray_img, report_text, query):
    """
    Run the three stages end to end and return their results as a tuple
    ``(xray_result, summary_result, answer_result)``:

    1. classify the X-ray image,
    2. clean and summarize the report text,
    3. answer the query from the cleaned report and the X-ray result.

    If any stage raises, all three slots of the returned tuple carry the
    same ``"Error: <message>"`` string instead of the exception propagating.
    """
    try:
        label = classify_xray(xray_img)

        report = preprocess_text(report_text)
        summary = summarize_report(report)

        reply = answer_query(query, report, label)
    except Exception as exc:
        # One message, repeated for every output slot.
        failure = f"Error: {str(exc)}"
        return failure, failure, failure

    return label, summary, reply


# In [3]:
import os
import torch
from torch import nn
from typing import Tuple, List, Optional
from PIL import Image
from torchvision import models, transforms

# Use the GPU when torch reports that CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Image preprocessing: resize to 224x224, convert to a tensor, then apply
# per-channel normalization with the mean/std values conventionally used
# with ImageNet-pretrained torchvision models.
XRAY_TFMS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

#  small helper: find inner torch module from wrappers 
def _extract_torch_module(m):
    """Return the underlying ``nn.Module`` of a model or model wrapper.

    If ``m`` is not itself an ``nn.Module``, a fixed list of common wrapper
    attribute names is probed for one. Parallel containers
    (``nn.DataParallel`` / ``DistributedDataParallel``) are then unwrapped
    so that the returned module's ``state_dict`` keys carry no ``module.``
    prefix and its class name reflects the real architecture.

    Raises:
        AttributeError: if no ``nn.Module`` can be found.
    """
    if isinstance(m, nn.Module):
        inner = m
    else:
        inner = None
        # common wrapper attribute names
        for name in ("model", "network", "net", "module", "backbone", "base_model"):
            candidate = getattr(m, name, None)
            if isinstance(candidate, nn.Module):
                inner = candidate
                break
        if inner is None:
            raise AttributeError("Could not find inner nn.Module inside the provided model/wrapper.")

    parallel_types = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    while isinstance(inner, parallel_types):
        inner = inner.module
    return inner

def export_xray_checkpoint(
    trained_model,
    class_names: List[str],
    export_path: str,
    arch_hint: Optional[str] = None,
):
    """Save a self-describing checkpoint for the X-ray classifier.

    The file holds a dict with the keys ``'state_dict'``, ``'classes'``,
    ``'arch'``, ``'num_classes'`` and ``'input_size'``. Works whether
    ``trained_model`` is a wrapper object or a plain ``nn.Module``.

    Args:
        trained_model: The model, or a wrapper exposing it under one of the
            attribute names known to ``_extract_torch_module``.
        class_names: Class labels in output-index order.
        export_path: Destination file. Missing parent directories are
            created; a bare filename is written to the current directory.
        arch_hint: Architecture name to store. When omitted, the lowercased
            class name of the module is stored instead.
            NOTE(review): a class name cannot distinguish variants that
            share one class (e.g. different ResNet depths), so pass
            ``arch_hint`` explicitly for anything other than the default.

    Returns:
        ``export_path``, unchanged.
    """
    torch_model = _extract_torch_module(trained_model)
    # infer arch if not provided
    arch = arch_hint or torch_model.__class__.__name__.lower()

    # The classification head lives in .fc (resnet) or .classifier
    # (densenet/efficientnet); .fc takes precedence when both exist.
    head = None
    for attr in ("fc", "classifier"):
        candidate = getattr(torch_model, attr, None)
        if isinstance(candidate, nn.Module):
            head = candidate
            break

    out_features = len(class_names)
    if head is not None:
        # A head such as Sequential(Dropout, Linear) has no out_features of
        # its own, so read it from the last Linear layer it contains.
        linear_layers = [layer for layer in head.modules() if isinstance(layer, nn.Linear)]
        if linear_layers:
            out_features = linear_layers[-1].out_features
        else:
            out_features = getattr(head, "out_features", len(class_names))

    payload = {
        "state_dict": torch_model.state_dict(),
        "classes": list(class_names),
        "arch": arch,
        "num_classes": int(out_features),
        "input_size": (224, 224),
    }

    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    export_dir = os.path.dirname(export_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)
    torch.save(payload, export_path)
    return export_path

# mapping from arch name -> constructor + head patch 
def _build_backbone(arch: str, num_classes: int) -> nn.Module:
    """Create an ImageNet-pretrained backbone with a fresh classification head.

    Args:
        arch: Architecture name, case-insensitive. Recognised values are
            ``resnet18`` / ``resnet-18`` / ``resnet``, ``resnet34`` /
            ``resnet-34``, ``resnet50`` / ``resnet-50`` and ``densenet121`` /
            ``densenet-121`` / ``densenet``. ``None`` or ``''`` means
            ``resnet18``.
        num_classes: Output size of the new head.

    Returns:
        The torchvision model with its head replaced by
        ``Sequential(Dropout(0.2), Linear(in_features, num_classes))``.

    Warns:
        UserWarning: if ``arch`` is not recognised. A ResNet-18 is still
            returned, as before, but the substitution is no longer silent.
    """
    import warnings

    # canonical name -> (constructor, pretrained weights, head attribute)
    specs = {
        "resnet18": (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1, "fc"),
        "resnet34": (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1, "fc"),
        "resnet50": (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V1, "fc"),
        "densenet121": (models.densenet121, models.DenseNet121_Weights.IMAGENET1K_V1, "classifier"),
    }
    aliases = {
        "resnet-18": "resnet18",
        "resnet": "resnet18",
        "resnet-34": "resnet34",
        "resnet-50": "resnet50",
        "densenet-121": "densenet121",
        "densenet": "densenet121",
    }

    requested = (arch or "resnet18").lower()
    canonical = aliases.get(requested, requested)
    if canonical not in specs:
        warnings.warn(
            f"Unknown architecture {arch!r}; falling back to resnet18.",
            stacklevel=2,
        )
        canonical = "resnet18"

    constructor, weights, head_attr = specs[canonical]
    m = constructor(weights=weights)
    in_feats = getattr(m, head_attr).in_features
    setattr(m, head_attr, nn.Sequential(nn.Dropout(0.2), nn.Linear(in_feats, num_classes)))
    return m

# loader used by main 
def load_xray_model(checkpoint_path: str) -> Tuple[nn.Module, List[str]]:
    """Load an X-ray classifier from a checkpoint file.

    Two checkpoint layouts are accepted:
      - a dict with ``'state_dict'`` and optionally ``'classes'``, ``'arch'``
        and ``'num_classes'``; the weights are loaded strictly;
      - a bare state dict (weights only), for which a ResNet-18 with the two
        classes ``NORMAL`` / ``PNEUMONIA`` is assumed and loading is
        non-strict.

    Returns:
        ``(model, classes)`` with the model in eval mode on ``DEVICE``.

    Warns:
        UserWarning: in the weights-only case, if the state dict does not
            match the assumed architecture key-for-key.
    """
    import warnings

    # NOTE(security): torch.load unpickles the file; only pass checkpoints
    # that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=DEVICE)

    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        classes = ckpt.get("classes", ["NORMAL", "PNEUMONIA"])
        num_classes = ckpt.get("num_classes", len(classes))
        arch = ckpt.get("arch", "resnet18")
        model = _build_backbone(arch, num_classes)
        model.load_state_dict(ckpt["state_dict"], strict=True)
    else:
        # weights-only: assume resnet18 with 2 classes
        classes = ["NORMAL", "PNEUMONIA"]
        model = _build_backbone("resnet18", len(classes))
        # Non-strict loading is kept, but a mismatch is reported: silently
        # skipped keys would leave parts of the network (possibly the whole
        # head) at their initial values.
        result = model.load_state_dict(ckpt, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {checkpoint_path!r} only partially matches resnet18: "
                f"missing keys={list(result.missing_keys)}, "
                f"unexpected keys={list(result.unexpected_keys)}",
                stacklevel=2,
            )

    model.eval().to(DEVICE)
    return model, classes

# prediction
def predict_xray(model: nn.Module, img) -> Tuple[int, float, list]:
    """Classify one X-ray image with horizontal-flip test-time averaging.

    Args:
        model: Network returning logits of shape ``(1, C)``.
        img: A file path (``str`` / ``os.PathLike``) or an already-open PIL
            image.

    Returns:
        ``(pred_idx, conf, probs_list)``: the softmax probabilities of the
        image and of its left-right mirror are averaged; ``pred_idx`` is the
        arg-max of that average and ``conf`` its probability.
    """
    if isinstance(img, (str, os.PathLike)):
        # Close the file handle deterministically; convert() returns an
        # independent, fully loaded copy.
        with Image.open(img) as opened:
            img = opened.convert("RGB")
    elif getattr(img, "mode", "RGB") != "RGB":
        # In-memory grayscale ('L') or RGBA images would otherwise reach the
        # 3-channel normalization with the wrong number of channels.
        img = img.convert("RGB")

    x1 = XRAY_TFMS(img).unsqueeze(0).to(DEVICE)
    x2 = XRAY_TFMS(img.transpose(Image.FLIP_LEFT_RIGHT)).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        p1 = torch.softmax(model(x1), dim=1)
        p2 = torch.softmax(model(x2), dim=1)
        probs = ((p1 + p2) / 2.0).squeeze(0).cpu().numpy()

    idx = int(probs.argmax())
    return idx, float(probs[idx]), probs.tolist()


# In [4]:
import os, torch
from torch import nn
from torchvision import models, transforms
from PIL import Image

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Absolute path to the project root, so artifacts resolve no matter which
# directory the app is started from. Set the MEDICALAI_ROOT environment
# variable to relocate the project; the default is the original location.
# The default is a raw string, so it uses a single backslash: inside r"..."
# a doubled backslash would be kept as two literal characters.
ROOT = os.environ.get("MEDICALAI_ROOT", r"D:\MedicalAI-Assistant")
ART_DIR = os.path.join(ROOT, 'artifacts')

def load_xray_model(weights_path: str, class_names=('NORMAL','PNEUMONIA')):
    """Build a ResNet-18 classifier and load fine-tuned weights into it.

    Args:
        weights_path: Checkpoint file. Either a dict with a ``'state_dict'``
            entry (and optionally ``'classes'``) or a bare state dict.
        class_names: Labels to use when the checkpoint stores none; their
            count also sizes the head when the checkpoint's own head weights
            cannot be found.

    Returns:
        ``(model, classes)`` with the model in eval mode on ``DEVICE`` and
        ``classes`` taken from the checkpoint when present, otherwise
        ``list(class_names)``.
    """
    # NOTE(security): torch.load unpickles the file; only pass checkpoints
    # that come from a trusted source.
    sd = torch.load(weights_path, map_location=DEVICE)
    if isinstance(sd, dict) and 'state_dict' in sd:
        state_dict = sd['state_dict']
        classes = sd.get('classes', list(class_names))
    else:
        # Bare state dict: accepted as well instead of failing with KeyError.
        state_dict = sd
        classes = list(class_names)

    # Size the head from the checkpoint's own classifier weights when they
    # are present, so the strict load below cannot fail on a class-count
    # mismatch; otherwise fall back to the number of supplied class names.
    head_weight = state_dict.get('fc.1.weight') if hasattr(state_dict, 'get') else None
    num_outputs = head_weight.shape[0] if head_weight is not None else len(class_names)

    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    num_feats = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(num_feats, num_outputs)
    )
    model.load_state_dict(state_dict)
    model.eval().to(DEVICE)
    return model, classes

# Image preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalize the three channels with the mean/std values conventionally used
# with ImageNet-pretrained torchvision models.
_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def predict_xray_with_roboflow(img_path):
    """Classify ``img_path`` through ``roboflow_classify``.

    ``roboflow_classify`` is not defined in the visible code and must be
    provided elsewhere. It is expected to return exactly three values; the
    first two are passed on as ``(pred, conf)`` and the third is discarded.
    """
    label, confidence, _ = roboflow_classify(img_path)
    return label, confidence
