In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

In [43]:
import torch
import torch.nn as nn

# Correct architecture (must match training)
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier.

    Layout: two ``Conv2d -> ReLU -> MaxPool2d(2)`` stages (3 -> 16 -> 32
    channels), then ``Flatten`` and a two-layer fully connected head that
    outputs ``num_classes`` raw logits (no softmax is applied here).

    The first linear layer expects ``32 * 32 * 32`` input features. The 3x3
    convolutions use ``padding=1`` (spatial size preserved) and each pooling
    step halves height and width, so inputs must be 128x128 images.

    Args:
        num_classes: Number of output logits. Defaults to 3.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        conv_stages = [
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        # No adaptive pooling before Flatten: adding a layer here would shift
        # the Sequential indices of the linear layers and change the flattened
        # feature count, so existing state dicts would no longer load.
        classifier_head = [
            nn.Flatten(),
            nn.Linear(32 * 32 * 32, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        # Everything lives in one Sequential named ``net`` so parameter keys
        # are ``net.<index>.weight`` / ``net.<index>.bias``.
        self.net = nn.Sequential(*conv_stages, *classifier_head)

    def forward(self, x):
        """Return the logits produced by ``self.net`` for the batch ``x``."""
        logits = self.net(x)
        return logits

def _load_simple_cnn(checkpoint_path):
    """Build a SimpleCNN from a saved checkpoint and put it in eval mode.

    The checkpoint is read as a mapping with two entries:
    ``"num_classes"`` (passed to the ``SimpleCNN`` constructor) and
    ``"model_state_dict"`` (loaded into the new model).

    Args:
        checkpoint_path: Path of the ``.pth`` file to read.

    Returns:
        The loaded ``SimpleCNN`` instance, on CPU and in evaluation mode.
    """
    # NOTE(security): torch.load deserialises with pickle, which can execute
    # arbitrary code from a malicious file. Only load checkpoints from a
    # trusted source (consider weights_only=True where the installed PyTorch
    # supports it).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model = SimpleCNN(num_classes=checkpoint["num_classes"])
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model


# One model per imaging modality; each checkpoint carries its own class count.
# Load CT
ct_model = _load_simple_cnn("models/ct_model.pth")

# Load X-ray
xray_model = _load_simple_cnn("models/cnn_chestxray.pth")

# Load Ultrasound
ultrasound_model = _load_simple_cnn("models/ultrasound_model.pth")

print("✅ All models loaded and ready for inference")


✅ All models loaded and ready for inference


In [36]:
# Step 2: Image transforms
# ------------------------
# Preprocessing applied to each image before it is fed to a model:
# resize to 128x128, convert to a tensor, then normalise each of the three
# RGB channels with mean 0.5 and std 0.5 (maps ToTensor's [0, 1] values to
# [-1, 1]).
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # RGB mean/std
    ]
)

In [37]:
def select_model(modality):
    """Return the loaded model that corresponds to an imaging modality.

    Matching is case-insensitive and ignores surrounding whitespace as well
    as hyphens, underscores and inner spaces, so ``"CT"``, ``" xray "``,
    ``"X-ray"`` and ``"Ultrasound"`` are all accepted. (The error message
    asks for "X-ray", so that spelling has to be accepted too.)

    Args:
        modality: Modality name as a string.

    Returns:
        ``ct_model``, ``xray_model`` or ``ultrasound_model``.

    Raises:
        ValueError: If the normalised name is not ct / xray / ultrasound.
    """
    key = modality.strip().lower()
    for separator in ("-", "_", " "):
        key = key.replace(separator, "")

    # Plain branches (not a dict of models) so only the requested global is
    # looked up at call time.
    if key == "ct":
        return ct_model
    if key == "xray":
        return xray_model
    if key == "ultrasound":
        return ultrasound_model
    raise ValueError("Unknown modality. Please specify CT, X-ray, or Ultrasound.")

def predict(image_path, modality):
    """Classify one image file with the model for the given modality.

    Args:
        image_path: Path of the image file to open.
        modality: Modality name; resolved to a model by ``select_model``.

    Returns:
        A ``(class_index, confidence)`` tuple of Python scalars, where
        ``confidence`` is the highest softmax probability and
        ``class_index`` is the position of that probability.

    Raises:
        ValueError: If ``select_model`` does not recognise ``modality``.
    """
    # Resolve the model first so an unsupported modality fails before any
    # file I/O or preprocessing work is done.
    model = select_model(modality)

    # The context manager closes the source file; convert() returns a
    # converted copy, which stays usable afterwards.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        outputs = model(batch)
        probs = torch.softmax(outputs, dim=1)
        confidence, pred = torch.max(probs, 1)

    return pred.item(), confidence.item()

In [38]:
# Strip surrounding whitespace: a stray space or newline from typing/pasting
# would otherwise make the modality unrecognised or the path not found.
modality = input("Enter modality (ct / xray / ultrasound): ").strip()
image_path = input("Enter image path: ").strip()

In [39]:
# Run inference on the user-supplied image with the model for the chosen
# modality; `confidence` is the top softmax probability returned by predict().
pred, confidence = predict(image_path, modality)
# NOTE(review): only class index 1 is reported as "Anomaly". A model with more
# than two outputs (SimpleCNN defaults to num_classes=3) can return other
# indices, and all of them print "Normal" here. Confirm the class-index
# mapping used at training time.
print(f"Prediction: {'Anomaly' if pred == 1 else 'Normal'}")
print(f"Confidence: {confidence:.4f}")


Prediction: Normal
Confidence: 0.9928


In [45]:
# Quick smoke test: push one random 128x128 RGB batch through ct_model and
# report the output shape, confirming the layer dimensions line up.
import torch

dummy = torch.randn(1, 3, 128, 128)  # batch 1, RGB, 128x128
with torch.no_grad():  # shape check only, no gradients needed
    out = ct_model(dummy)
print('ct_model output shape:', out.shape)

ct_model output shape: torch.Size([1, 3])


In [46]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

def evaluate(model, dataloader, device="cpu"):
    """Run ``model`` over ``dataloader`` and print classification metrics.

    The model is moved to ``device`` and switched to eval mode. For each
    ``(imgs, labels)`` batch the predicted class is computed without
    gradients:

    * outputs that are 1-D or have a single column are treated as one logit
      per sample (sigmoid, threshold 0.5);
    * any other output is treated as per-class logits (softmax, argmax over
      dim 1).

    Labels with more than one column are treated as one-hot and reduced
    with argmax; labels shaped ``(N, 1)`` are flattened to ``(N,)`` so the
    returned label list is flat like the prediction list.

    Accuracy, the classification report and the confusion matrix are
    printed.

    Args:
        model: Module called as ``model(imgs)``.
        dataloader: Iterable yielding ``(imgs, labels)`` tensor pairs.
        device: Device the model and batches are moved to. Defaults to "cpu".

    Returns:
        ``(all_preds, all_labels)``: flat Python lists covering every sample
        seen, in iteration order.
    """
    model.to(device)
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)

            if labels.dim() > 1 and labels.size(1) > 1:
                # handle one-hot labels -> convert to class indices
                labels = torch.argmax(labels, dim=1)
            elif labels.dim() == 2 and labels.size(1) == 1:
                # column labels (N, 1) -> (N,), otherwise tolist() would
                # produce a list of one-element lists
                labels = labels.reshape(-1)

            # handle binary (single-logit) vs multiclass outputs
            if outputs.dim() == 1 or (outputs.dim() == 2 and outputs.size(1) == 1):
                probs = torch.sigmoid(outputs.reshape(-1))
                preds = (probs > 0.5).long()
            else:
                probs = torch.softmax(outputs, dim=1)
                preds = torch.argmax(probs, dim=1)

            # Tensor.tolist() already yields Python scalars; no NumPy
            # round-trip is needed.
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.tolist())

    print(f"Accuracy: {accuracy_score(all_labels, all_preds):.4f}")
    print("Classification report:")
    print(classification_report(all_labels, all_preds, digits=4))
    print("Confusion matrix:")
    print(confusion_matrix(all_labels, all_preds))

    return all_preds, all_labels
