# In [None]:
import os
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import sys

main_dir = "/media/ist/Drive2/MANSOOR/Neuroimaging-Project/Breast_Cancer_Classification_Project"
# Root of the breast_cancer_detection project that holds the trained checkpoint.
pretrained_model_dir = f"{main_dir}/WSI_Breast_Cancer_Classification/breast_cancer_detection"
# Run directory of the trained BCDensenet model (config.json + model_best.pth).
model_path_dir = f'{pretrained_model_dir}/saved/models/BCDensenet/0224_034642/'
# model_path_dir already ends with "/", so join with os.path.join to avoid a
# doubled path separator in the derived file paths.
model_config = os.path.join(model_path_dir, "config.json")
model_path = os.path.join(model_path_dir, "model_best.pth")

# Make the project's custom modules importable. Reuse pretrained_model_dir rather
# than repeating the literal, and skip the append when the entry is already
# present so re-running this notebook cell does not keep growing sys.path.
if pretrained_model_dir not in sys.path:
    sys.path.append(pretrained_model_dir)

# NOTE(review): with pretrained_model_dir itself on sys.path, this import only
# resolves if that directory (or another sys.path entry) provides a
# `breast_cancer_detection` module/package exposing `parse_config` — confirm
# against the actual project layout.
from breast_cancer_detection import parse_config

# NOTE(review): `config` is not used by any code visible in this notebook.
config = parse_config(model_config)


def load_pretrained_model(model_path, model_factory=None):
    """Build the classifier, load saved weights onto the CPU and switch to eval mode.

    Args:
        model_path: Path to a file written by ``torch.save``. Either a bare
            ``state_dict`` or a dict wrapping it under the ``"state_dict"`` key
            (alongside other entries such as epoch/optimizer state) is accepted.
        model_factory: Optional zero-argument callable that returns a freshly
            constructed ``torch.nn.Module``. Defaults to ``CustomModel``.

    Returns:
        The model with its weights loaded, in evaluation mode.
    """
    if model_factory is None:
        # NOTE(review): CustomModel is not defined or imported anywhere visible in
        # this notebook; it must be in scope before this default path is used.
        model_factory = CustomModel
    model = model_factory()
    # torch.load unpickles the file, so only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    # Training checkpoints are often a dict carrying extra metadata with the weights
    # stored under "state_dict"; load_state_dict needs the weights mapping only.
    # (nn.Module cannot register a parameter named "state_dict", so a bare state
    # dict never contains that key.)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    model.load_state_dict(checkpoint)
    model.eval()
    return model

def preprocess_image(image_path):
    """Load an image file and turn it into a normalised single-image batch.

    The image is converted to 3-channel RGB, resized to 224x224, converted to a
    tensor and normalised with the standard ImageNet per-channel mean/std.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.
    """
    # NOTE(review): the 224x224 input size is assumed here, not read from the
    # model's config — confirm it matches what the checkpoint was trained on.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # PNG tiles can be RGBA, palette or greyscale, but Normalize above has
    # statistics for exactly three channels, so force RGB. convert() returns a
    # fully loaded copy, so it stays usable after the file is closed by `with`.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    return transform(rgb_image).unsqueeze(0)  # prepend the batch dimension

def predict(model, image_path):
    """Run ``model`` on one image file and return the top-scoring class index.

    The image is prepared with ``preprocess_image``; inference runs with
    autograd disabled. The caller is expected to have put ``model`` in eval
    mode already (``load_pretrained_model`` does this).

    Returns:
        The index (a Python ``int``) of the largest entry along dim 1 of the
        model output.
    """
    batch = preprocess_image(image_path)
    with torch.no_grad():
        scores = model(batch)
        best_class = torch.max(scores, dim=1).indices
    return best_class.item()


tiles_dir = f"{main_dir}/test_tiles"
# NOTE(review): save_model_path is not used by the code visible in this cell.
save_model_path = f"{main_dir}/WSI_Breast_Cancer_Classification/Model_Weights"


# Load the model
model = load_pretrained_model(model_path)

# Directory of PNG tiles to classify (built from tiles_dir instead of repeating the path).
image_directory = f"{tiles_dir}/p_2_test"
# os.listdir returns entries in arbitrary order; sort so results print in a
# reproducible order. The extension check is case-insensitive (.png / .PNG).
image_files = [
    os.path.join(image_directory, name)
    for name in sorted(os.listdir(image_directory))
    if name.lower().endswith('.png')
]
if not image_files:
    print(f"No .png images found in {image_directory}")

# Predict labels for each image (maps full image path -> class index)
predictions = {img: predict(model, img) for img in image_files}

# Print the predictions; class index 0 is reported as Benign, anything else as Malignant.
for img_path, label in predictions.items():
    label_name = "Benign" if label == 0 else "Malignant"
    print(f'Image: {img_path}, Predicted Label: {label_name}')


# In [None]:
##### NOTE: no trained TransMIL weights are available yet — loading 'trans_mil_model.pth' below will fail until that file exists #####

import torch
from torchvision import transforms
from PIL import Image
from TransMIL import * 
from TransMIL.models import TransMIL
from TransMIL import MyOptimizer, MyLoss
from TransMIL.MyOptimizer import *


# Build the TransMIL model and load previously trained weights for inference.
# NOTE(review): the constructor is called with no arguments, and classify_patch
# below feeds it a single 224x224 image tensor. TransMIL is a multiple-instance
# (bag-level) model and may instead require constructor arguments and a bag of
# patch features as input — confirm against the TransMIL package.
model = TransMIL()
# map_location='cpu' (as in the first cell) lets a checkpoint saved from a GPU run
# load on a CPU-only machine; load_state_dict copies the values into the model's
# existing parameters. torch.load unpickles the file: only load trusted checkpoints.
model.load_state_dict(torch.load('trans_mil_model.pth', map_location=torch.device('cpu')))
model.eval()

# Image preprocessing: fixed 224x224 resize, tensor conversion, and normalisation
# with the standard ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize if necessary
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def classify_patch(image_path):
    """Classify one image patch with the module-level ``model`` and ``transform``.

    The image is converted to RGB, preprocessed with ``transform`` and given a
    leading batch axis of size 1 before inference with autograd disabled.

    Returns:
        The index (a Python ``int``) of the largest entry along dim 1 of the
        model output.
    """
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert('RGB')
    batch = transform(rgb_image)[None]  # indexing with None prepends a size-1 batch axis
    with torch.no_grad():
        scores = model(batch)
        winner = torch.max(scores, 1)[1]
    return winner.item()


# Paths for this cell, redefined here so the cell does not depend on the first one.
main_dir = "/media/ist/Drive2/MANSOOR/Neuroimaging-Project/Breast_Cancer_Classification_Project"

tiles_dir = f"{main_dir}/test_tiles"
# NOTE(review): save_model_path is not used by the code visible in this cell.
save_model_path = f"{main_dir}/WSI_Breast_Cancer_Classification/Model_Weights"

# Example usage: classify a single tile and print its label
# (class index 0 is reported as Benign, anything else as Malignant).
image_path = f'{tiles_dir}/n_14_test/SUB_n_14_tile_(6_38).png'
prediction = classify_patch(image_path)
print("Predicted class:", "Benign" if prediction == 0 else "Malignant")
