In [43]:
from PIL import Image
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same transform as training
img_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Load X-ray filter model
def load_xray_filter():
    model = models.resnet50(pretrained=False)
    model.conv1 = nn.Conv2d(1, 64, 7, 2, 3, bias=False)
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 500),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(500, 2)
    )

    model.load_state_dict(torch.load("xray_filter.pth", map_location=device))
    model.to(device)
    model.eval()
    return model

# Load pneumonia model
def load_pneumonia_model():
    model = models.resnet50(pretrained=False)
    model.conv1 = nn.Conv2d(1, 64, 7, 2, 3, bias=False)
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 500),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(500, 2)
    )
    model.load_state_dict(torch.load("pneumoina_model.pth", map_location=device))
    model.to(device)
    model.eval()
    return model

# Inference pipeline
def predict_image(image_path):
    img = Image.open(image_path).convert("L")  # Convert to grayscale
    img_tensor = img_transform(img).unsqueeze(0).to(device)

    # Stage 1: Run X-ray filter
    xray_filter = load_xray_filter()
    with torch.no_grad():
        pred = xray_filter(img_tensor)
        pred_label = torch.argmax(pred, dim=1).item()
        print(pred_label)
    if pred_label == 1:
        print("❌ This is not a valid chest X-ray image. Please upload a proper medical image.")
       
        return

    # Stage 2: Pneumonia prediction
    pneumonia_model = load_pneumonia_model()
    with torch.no_grad():
        output = pneumonia_model(img_tensor)
        label = torch.argmax(output, dim=1).item()

    classes = ["Normal", "Pneumonia"]
    print(f"✅ Prediction: {classes[label]}")


In [46]:
predict_image('C:\\Users\\Administrator\\Downloads\\zophie.png')



1
❌ This is not a valid chest X-ray image. Please upload a proper medical image.
