In [1]:
!pip install torch torchvision timm tensorflow pillow opencv-python numpy

Collecting torchvision
  Downloading torchvision-0.22.1-cp312-cp312-win_amd64.whl.metadata (6.1 kB)
Collecting timm
  Downloading timm-1.0.16-py3-none-any.whl.metadata (57 kB)
Collecting torch
  Downloading torch-2.7.1-cp312-cp312-win_amd64.whl.metadata (28 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Downloading torchvision-0.22.1-cp312-cp312-win_amd64.whl (1.7 MB)
   ---------------------------------------- 0.0/1.7 MB ? eta -:--:--
   ------------------ --------------------- 0.8/1.7 MB 6.7 MB/s eta 0:00:01
   ---------------------------------------- 1.7/1.7 MB 7.1 MB/s eta 0:00:00
Downloading torch-2.7.1-cp312-cp312-win_amd64.whl (216.1 MB)
   ---------------------------------------- 0.0/216.1 MB ? eta -:--:--
   ---------------------------------------- 1.6/216.1 MB 8.3 MB/s eta 0:00:26
    --------------------------------------- 3.9/216.1 MB 9.0 MB/s eta 0:00:24
    --------------------------------------- 5.0/216.1 MB 7.9 MB/


[notice] A new release of pip is available: 24.3.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
# CALL MODELS

import torch
from timm import create_model

# ✅ Load your ViT classifier from this path
vit_model_path = r"C:\Users\Devansh\Desktop\ProjectNew\Model_5_ViT\vit_model.pth"

# Must match the architecture the weights were trained with: timm
# 'vit_base_patch16_224' with a 2-class head. pretrained=False because every
# weight is loaded from the file below.
vit_model = create_model('vit_base_patch16_224', pretrained=False, num_classes=2)

# The file holds a state_dict, so weights_only=True is sufficient. It stops
# torch.load from unpickling arbitrary objects, which could execute code from a
# tampered file. It is only the default from torch 2.6 on, so it is stated here.
vit_model.load_state_dict(
    torch.load(vit_model_path, map_location="cpu", weights_only=True)
)
vit_model.eval()  # inference mode for dropout / normalization layers

import tensorflow as tf
from tensorflow.keras.models import load_model

# Location of the trained Keras lung-segmentation model (.keras archive).
segment_model_path = (
    r"C:\Users\Devansh\Desktop\ProjectNew\Model_5_ViT\best_model (1).keras"
)

# Custom metric functions (used during training)
def dice_coefficient(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient between two masks, pooled over every element.

    Both inputs are cast to float32 and flattened, so they only need the same
    number of elements. Returns the scalar tensor
    (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth).
    """
    flat_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    flat_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    overlap = tf.reduce_sum(flat_true * flat_pred)
    sum_true = tf.reduce_sum(flat_true)
    sum_pred = tf.reduce_sum(flat_pred)

    return (2. * overlap + smooth) / (sum_true + sum_pred + smooth)

def jaccard_index(y_true, y_pred, smooth=100):
    """Smoothed Jaccard index (IoU) between two masks, pooled over every element.

    Both inputs are cast to float32 and flattened, so they only need the same
    number of elements. Returns the scalar tensor
    (sum(t * p) + smooth) / (sum(t) + sum(p) - sum(t * p) + smooth).

    NOTE(review): smooth=100 is far larger than dice_coefficient's smooth=1;
    presumably it mirrors the training-time definition, so it is left as-is.
    """
    flat_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    flat_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    overlap = tf.reduce_sum(flat_true * flat_pred)
    union = tf.reduce_sum(flat_true) + tf.reduce_sum(flat_pred) - overlap

    return (overlap + smooth) / (union + smooth)

# custom_objects maps the metric names stored in the saved model back to the
# Python functions defined above, so Keras can deserialize it.
segment_model = load_model(
    segment_model_path,
    custom_objects=dict(
        dice_coefficient=dice_coefficient,
        jaccard_index=jaccard_index,
    ),
)

In [2]:
# PREPROCESS DATA

import warnings

import numpy as np
import cv2
from PIL import Image
from torchvision import transforms

# ViT input pipeline: resize the PIL image to 224x224, then convert it to a
# channels-first float tensor in [0, 1].
# NOTE(review): there is no transforms.Normalize step; confirm this matches the
# preprocessing used when the ViT was trained.
transform_vit = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

def _segment_lungs(image_np, seg_size=256, mask_threshold=0.5):
    """Predict a binary lung mask for ``image_np`` at the image's own resolution.

    Args:
        image_np: RGB image as a NumPy array of shape (H, W, 3).
        seg_size: side length the image is resized to before it is passed to
            ``segment_model`` as a (1, seg_size, seg_size, 3) float batch in [0, 1].
        mask_threshold: model outputs strictly above this count as lung.

    Returns:
        (H, W) uint8 array of 0s and 1s.
    """
    height, width = image_np.shape[:2]

    # Scale pixels to [0, 1] and add a batch axis.
    seg_batch = cv2.resize(image_np, (seg_size, seg_size)).astype(np.float32) / 255.0
    seg_batch = np.expand_dims(seg_batch, axis=0)

    prob_map = segment_model.predict(seg_batch, verbose=0)[0]
    if prob_map.ndim == 3:
        # Keep only the first output channel (assumed to be the lung
        # probability; TODO confirm for multi-channel models).
        prob_map = prob_map[:, :, 0]

    # cv2.resize takes dsize as (width, height), not (rows, cols).
    prob_map = cv2.resize(prob_map, (width, height))
    return (prob_map > mask_threshold).astype(np.uint8)


def _classify_masked(masked_image):
    """Run the ViT on an (H, W, 3) uint8 image and return the argmax class index."""
    pil_image = Image.fromarray(masked_image.astype(np.uint8, copy=False))
    input_tensor = transform_vit(pil_image).unsqueeze(0)  # (1, 3, 224, 224)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = vit_model(input_tensor)
    return torch.argmax(logits, dim=1).item()


def preprocess_and_predict(image_path, mask_threshold=0.5, seg_size=256):
    """Segment the lungs in an X-ray, black out everything else, and classify it.

    Args:
        image_path: path to any image file PIL can open; it is converted to RGB.
        mask_threshold: segmentation cut-off; pixels whose predicted value is
            above it are kept (default 0.5).
        seg_size: square input size used for the segmentation model
            (default 256).

    Returns:
        int: the ViT's argmax class index (0 or 1 given the 2-class head built
        in the model-loading cell).

    Raises:
        FileNotFoundError / PIL.UnidentifiedImageError: if the file is missing
            or is not a readable image.

    Warns:
        UserWarning: if the predicted lung mask is empty. The classifier then
            sees an all-black image and its output is not meaningful.
    """
    # The context manager closes the underlying file handle deterministically.
    with Image.open(image_path) as img:
        image_np = np.array(img.convert("RGB"))

    mask_binary = _segment_lungs(image_np, seg_size=seg_size, mask_threshold=mask_threshold)
    if not mask_binary.any():
        warnings.warn(
            f"Lung mask for {image_path!r} is empty; the classifier will see an "
            "all-black image, so its prediction is unreliable.",
            stacklevel=2,
        )

    # Zero every non-lung pixel; uint8 * {0, 1} stays within the uint8 range.
    masked_image = image_np * mask_binary[..., np.newaxis]
    return _classify_masked(masked_image)


In [5]:
# SINGLE IMAGE PREDICTION

# X-ray to classify; class index 1 is reported as TB, anything else as Normal.
img_path = r"C:\Users\Devansh\Desktop\ProjectNew\Test\tb0017.png"
prediction = preprocess_and_predict(img_path)

diagnosis = "TB" if prediction == 1 else "Normal"
print(f"Prediction: {diagnosis}")

Prediction: TB
