In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import numpy as np
from torchvision import transforms
import os
import torchvision.models as models
from torchvision.models import ResNet18_Weights
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import matplotlib.pyplot as plt
from sklearn.utils.class_weight import compute_class_weight
from monai.networks.nets import DenseNet121, UNet
from monai.networks.nets import DenseNet

In [32]:
image_channels = 1
image_size = (224, 224)
# image_size = (1024, 1024)

In [33]:
class ImprovedCNN(nn.Module):
    def __init__(self):
        super(ImprovedCNN, self).__init__()
        self.conv1 = nn.Conv2d(image_channels, 32, kernel_size=3, padding=1)  
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)

        self.flattened_size = 256 * (image_size[0] // 16) * (image_size[1] // 16)
        self.fc1 = nn.Linear(self.flattened_size, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool(torch.relu(self.bn3(self.conv3(x))))
        x = self.pool(torch.relu(self.conv4(x)))
        x = x.view(-1, self.flattened_size)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        # x = torch.sigmoid(self.fc2(x))
        return self.fc2(x)

class TransferLearningModel(nn.Module):
    def __init__(self):
        super(TransferLearningModel, self).__init__()
        # Load pre-trained ResNet18 with weights
        self.base_model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        
        # Modify the first convolutional layer to accept grayscale input
        self.base_model.conv1 = nn.Conv2d(image_channels, 64, kernel_size=3, padding=1) 
        
        # Replace the fully connected layer for binary classification
        num_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Linear(num_features, 1)  # Output 1 raw logit
    
    def forward(self, x):
        return self.base_model(x)


In [34]:
def get_monai_model():
    # Load MONAI's DenseNet121
    model = DenseNet121(
        spatial_dims=2,  # For 2D images
        in_channels=image_channels,   # RGB input
        out_channels=1   # Binary classification (logits)
    )
    return model

def get_densenet169():
    model = DenseNet(
        spatial_dims=2,
        in_channels=image_channels,
        out_channels=1,  # Binary classification
        block_config=(6, 12, 32, 32)  # DenseNet169 configuration
    )
    return model


In [35]:
from PIL import Image

class FullImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image as a NumPy array
        image = np.load(self.image_paths[idx])  # Shape: [224, 224]

        # # Load the image as a JPG
        # image = np.array(Image.open(self.image_paths[idx]))
        
        # Scale pixel values to [0, 255] if they are normalized
        if image.max() <= 1.0:  # Check if image is normalized
            image = (image * 255).astype(np.uint8)
        
        # Convert NumPy array to PIL Image
        image = Image.fromarray(image)
        
        # Apply transformations if provided
        if self.transform:
            image = self.transform(image)

        # # Convert to PyTorch tensor
        # image = torch.tensor(image, dtype=torch.float32)

        # Convert label to PyTorch tensor
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, label


In [36]:
# # Example file
# image_path = '../rr1_dataset/img_input/LSDS-2_Histology_M16-M17-M18_HE.jpg'

# img = Image.open(image_path)

# intermediate_size = (img.size[0] // 4, img.size[1] // 4)
# img_intermediate = img.resize(intermediate_size, Image.Resampling.NEAREST)
# img_resized = img_intermediate.resize(image_size, Image.Resampling.NEAREST)

# img_array = np.array(img_resized) / 255.0

# plt.imshow(img_array)
# plt.show()

In [37]:
# Define the LOOCV Function
def loocv_full_image_with_augmentation(image_paths, labels, num_epochs=100, learning_rate=0.0001, save_dir="../Models"):
    os.makedirs(save_dir, exist_ok=True)  # Ensure the save directory exists

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    true_labels = []
    image_pred_prob = []
    image_pred_label = []
    
    # Define augmentations for the training dataset - Grayscale
    train_transform = transforms.Compose([
        transforms.Resize(image_size),             # Resize all images
        # transforms.RandomHorizontalFlip(p=0.5),  # Flip horizontally with 50% probability
        # transforms.RandomVerticalFlip(p=0.5),    # Flip vertically with 50% probability
        transforms.RandomRotation(10),           # Rotate randomly within ±10 degrees
        transforms.RandomResizedCrop(size=(image_size[0], image_size[1]), scale=(0.9, 1.0)),
        transforms.ToTensor(),                   # Convert PIL Image to PyTorch tensor
        transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
    ])
    
    # Define transform for test dataset (no augmentation, only normalization) - Grayscale
    test_transform = transforms.Compose([
        transforms.Resize(image_size),             # Resize all images
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
    ])

    # # Define augmentations for the training dataset - RGB
    # train_transform = transforms.Compose([
    #     transforms.Resize(image_size),             # Resize all images
    #     transforms.RandomRotation(10),            # Rotate randomly within ±10 degrees
    #     transforms.RandomResizedCrop(size=(image_size), scale=(0.9, 1.0)),
    #     transforms.ToTensor(),                    # Convert PIL Image to PyTorch tensor
    #     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Normalize RGB values
    #                         std=[0.229, 0.224, 0.225])
    # ])

    # # Define transform for the test dataset (no augmentation, only resizing and normalization) - RGB
    # test_transform = transforms.Compose([
    #     transforms.Resize(image_size),            # Resize all images
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Normalize RGB values
    #                         std=[0.229, 0.224, 0.225])
    # ])

    for test_image_id in range(len(image_paths)):
        print(f"Processing LOOCV for test image {test_image_id + 1}/{len(image_paths)}")

        # Split dataset into training and test sets
        train_images = [image_paths[i] for i in range(len(image_paths)) if i != test_image_id]
        train_labels = [labels[i] for i in range(len(labels)) if i != test_image_id]
        test_image = image_paths[test_image_id]
        test_label = labels[test_image_id]

        # Create datasets and dataloaders
        train_dataset = FullImageDataset(train_images, train_labels, transform=train_transform)
        test_dataset = FullImageDataset([test_image], [test_label], transform=test_transform)
        train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

        # # Visualize augmented images
        # for i in range(1):
        #     augmented_image, _ = train_dataset[np.random.randint(10)]  # Get augmented image
        #     plt.imshow(augmented_image.squeeze().numpy().T)  # Visualize
        #     plt.title("Augmented Image")
        #     plt.show()

        # Initialize model, loss function, and optimizer
        model = ImprovedCNN().to(device)
        # model = TransferLearningModel().to(device)
        # model = get_monai_model().to(device)
        # model = get_densenet169().to(device)
        

        class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
        class_weights = torch.tensor(class_weights, dtype=torch.float)
        # print("class_weights", class_weights)

        # criterion = nn.BCELoss()
        criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # Training loop
        model.train()
        for epoch in range(num_epochs):
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device).unsqueeze(1)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

        # Testing loop
        model.eval()
        with torch.no_grad():
            for inputs, _ in test_loader:
                inputs = inputs.to(device)
                outputs = model(inputs).squeeze()
                probs = torch.sigmoid(outputs).cpu().numpy()  # Apply sigmoid for probabilities
                image_pred_prob.append(probs)
                pred_label = 1 if probs > 0.5 else 0
                image_pred_label.append(pred_label)
                true_labels.append(test_label)
                print('Test pred prob:', probs, 'True label:', test_label, "test_image", test_image)

    # Save the final trained model
    final_model_path = os.path.join(save_dir, "rr1_cnn_"+str(image_size[0])+".pth")
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")

    return true_labels, image_pred_prob

In [38]:
# image_dir = '../rr1_dataset/preprocessed_images/'

# label_dict = {
#     image_dir+'LSDS-2_Histology_M16-M17-M18_HE.npy': 0, #'Vivarium Control',
#     image_dir+'LSDS-2_Histology_M16-M17-M18_ORO.npy': 0, #'Vivarium Control',
#     image_dir+'LSDS-2_Histology_M19-M20_HE.npy': 0, #'Vivarium Control',
#     image_dir+'LSDS-2_Histology_M19-M20_ORO.npy': 0, #'Vivarium Control',
#     image_dir+'LSDS-2_Histology_M21-M22_HE.npy': 1, #'Space Flight',
#     image_dir+'LSDS-2_Histology_M21-M22_ORO.npy': 1, #'Space Flight',
#     image_dir+'LSDS-2_Histology_M25-M26-M27_HE.npy': 1, #'Space Flight',
#     image_dir+'LSDS-2_Histology_M25-M26-M27_ORO.npy': 1, #'Space Flight',
#     image_dir+'LSDS-2_Histology_M28-M30_HE.npy' : 1, #'Space Flight',
#     image_dir+'LSDS-2_Histology_M28-M30_ORO.npy' : 1, #'Space Flight',
#     image_dir+'LSDS-2_Histology_M31-M32_HE.npy': 0, #'Ground Control',
#     image_dir+'LSDS-2_Histology_M31-M32_ORO.npy': 0, #'Ground Control',
#     image_dir+'LSDS-2_Histology_M36-M37-M38_HE.npy': 0, #'Ground Control',
#     image_dir+'LSDS-2_Histology_M36-M37-M38_ORO.npy': 0, #'Ground Control',
#     image_dir+'LSDS-2_Histology_M39-M40_HE.npy': 0, #'Ground Control',
#     image_dir+'LSDS-2_Histology_M39-M40_ORO.npy': 0, #'Ground Control',
#     image_dir+'LSDS-2_Histology_M4-M5-M7_HE.npy': 0, #'Basal Control',
#     image_dir+'LSDS-2_Histology_M4-M5-M7_ORO.npy': 0, #'Basal Control',
#     image_dir+'LSDS-2_Histology_M8-M10_HE.npy': 0, #'Basal Control',
#     image_dir+'LSDS-2_Histology_M8-M10_ORO.npy': 0, #'Basal Control',
# }

image_dir = '../rr1_dataset/preprocessed_images/'

# label_dict = {
#     image_dir+'LSDS-2_Histology_M16-M17-M18_HE.npy': 0, #'Vivarium Control',
#     image_dir+'LSDS-2_Histology_M19-M20_HE.npy': 0, #'Vivarium Control',
#     image_dir+'LSDS-2_Histology_M21-M22_HE.npy': 1, #'Space Flight',
#     image_dir+'LSDS-2_Histology_M25-M26-M27_HE.npy': 1, #'Space Flight',
#     image_dir+'LSDS-2_Histology_M28-M30_HE.npy' : 1, #'Space Flight',
#     image_dir+'LSDS-2_Histology_M31-M32_HE.npy': 0, #'Ground Control',
#     image_dir+'LSDS-2_Histology_M36-M37-M38_HE.npy': 0, #'Ground Control',
#     image_dir+'LSDS-2_Histology_M39-M40_HE.npy': 0, #'Ground Control',
#     image_dir+'LSDS-2_Histology_M4-M5-M7_HE.npy': 0, #'Basal Control',
#     image_dir+'LSDS-2_Histology_M8-M10_HE.npy': 0, #'Basal Control',
# }

label_dict = {
    image_dir+'LSDS-2_Histology_M16-M17-M18-1_ORO.npy': 0, #'Vivarium Control',
    image_dir+'LSDS-2_Histology_M16-M17-M18-2_ORO.npy': 0, #'Vivarium Control',
    image_dir+'LSDS-2_Histology_M16-M17-M18-3_ORO.npy': 0, #'Vivarium Control',

    image_dir+'LSDS-2_Histology_M19-M20-1_ORO.npy': 0, #'Vivarium Control',
    image_dir+'LSDS-2_Histology_M19-M20-2_ORO.npy': 0, #'Vivarium Control',
    image_dir+'LSDS-2_Histology_M19-M20-3_ORO.npy': 0, #'Vivarium Control',

    image_dir+'LSDS-2_Histology_M21-M22-1_ORO.npy': 1, #'Space Flight',
    image_dir+'LSDS-2_Histology_M21-M22-2_ORO.npy': 1, #'Space Flight',
    image_dir+'LSDS-2_Histology_M21-M22-3_ORO.npy': 1, #'Space Flight',

    image_dir+'LSDS-2_Histology_M25-M26-M27-1_ORO.npy': 1, #'Space Flight',
    image_dir+'LSDS-2_Histology_M25-M26-M27-2_ORO.npy': 1, #'Space Flight',
    image_dir+'LSDS-2_Histology_M25-M26-M27-3_ORO.npy': 1, #'Space Flight',
    image_dir+'LSDS-2_Histology_M25-M26-M27-4_ORO.npy': 1, #'Space Flight',

    image_dir+'LSDS-2_Histology_M28-M30-1_ORO.npy' : 1, #'Space Flight',
    image_dir+'LSDS-2_Histology_M28-M30-2_ORO.npy' : 1, #'Space Flight',
    image_dir+'LSDS-2_Histology_M28-M30-3_ORO.npy' : 1, #'Space Flight',

    image_dir+'LSDS-2_Histology_M31-M32-1_ORO.npy': 0, #'Ground Control',
    image_dir+'LSDS-2_Histology_M31-M32-2_ORO.npy': 0, #'Ground Control',
    image_dir+'LSDS-2_Histology_M31-M32-3_ORO.npy': 0, #'Ground Control',

    image_dir+'LSDS-2_Histology_M36-M37-M38-1_ORO.npy': 0, #'Ground Control',
    image_dir+'LSDS-2_Histology_M36-M37-M38-2_ORO.npy': 0, #'Ground Control',
    image_dir+'LSDS-2_Histology_M36-M37-M38-3_ORO.npy': 0, #'Ground Control',
    image_dir+'LSDS-2_Histology_M36-M37-M38-4_ORO.npy': 0, #'Ground Control',

    image_dir+'LSDS-2_Histology_M39-M40-1_ORO.npy': 0, #'Ground Control',
    image_dir+'LSDS-2_Histology_M39-M40-2_ORO.npy': 0, #'Ground Control',
    image_dir+'LSDS-2_Histology_M39-M40-3_ORO.npy': 0, #'Ground Control',

    image_dir+'LSDS-2_Histology_M4-M5-M7-1_ORO.npy': 0, #'Basal Control',
    image_dir+'LSDS-2_Histology_M4-M5-M7-2_ORO.npy': 0, #'Basal Control',
    image_dir+'LSDS-2_Histology_M4-M5-M7-3_ORO.npy': 0, #'Basal Control',
    image_dir+'LSDS-2_Histology_M4-M5-M7-4_ORO.npy': 0, #'Basal Control',

    image_dir+'LSDS-2_Histology_M8-M10-1_ORO.npy': 0, #'Basal Control',
    image_dir+'LSDS-2_Histology_M8-M10-2_ORO.npy': 0, #'Basal Control',
    image_dir+'LSDS-2_Histology_M8-M10-3_ORO.npy': 0, #'Basal Control'
}

In [39]:
image_paths = list(label_dict.keys()) # Replace with actual .npy file paths
labels = list(label_dict.values())  # Replace with actual labels


In [40]:
true_labels, image_pred_prob = loocv_full_image_with_augmentation(image_paths, labels)

Processing LOOCV for test image 1/33
Test pred prob: 3.1904196e-07 True label: 0 test_image ../rr1_dataset/preprocessed_images/LSDS-2_Histology_M16-M17-M18-1_ORO.npy
Processing LOOCV for test image 2/33
Test pred prob: 0.98675597 True label: 0 test_image ../rr1_dataset/preprocessed_images/LSDS-2_Histology_M16-M17-M18-2_ORO.npy
Processing LOOCV for test image 3/33
Test pred prob: 6.9036435e-07 True label: 0 test_image ../rr1_dataset/preprocessed_images/LSDS-2_Histology_M16-M17-M18-3_ORO.npy
Processing LOOCV for test image 4/33
Test pred prob: 0.99981683 True label: 0 test_image ../rr1_dataset/preprocessed_images/LSDS-2_Histology_M19-M20-1_ORO.npy
Processing LOOCV for test image 5/33
Test pred prob: 3.0166817e-07 True label: 0 test_image ../rr1_dataset/preprocessed_images/LSDS-2_Histology_M19-M20-2_ORO.npy
Processing LOOCV for test image 6/33
Test pred prob: 1.5898648e-09 True label: 0 test_image ../rr1_dataset/preprocessed_images/LSDS-2_Histology_M19-M20-3_ORO.npy
Processing LOOCV for t

In [41]:
# Define the cutoff value
cutoff = 0.5
# Convert to 0 and 1 based on the cutoff
image_pred_label = (np.array(image_pred_prob) > cutoff).astype(int)
image_pred_label

array([0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0,
       0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0])

In [43]:
# Calculate metrics
accuracy = accuracy_score(true_labels, image_pred_label)
precision = precision_score(true_labels, image_pred_label, average='weighted')
recall = recall_score(true_labels, image_pred_label, average='weighted')
f1 = f1_score(true_labels, image_pred_label, average='weighted')
auc = roc_auc_score(true_labels, image_pred_prob, average='weighted')

print(f"LOOCV Image-Level Accuracy: {accuracy:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1-Score: {f1:.2f}")
print(f"AUC Score: {auc:.2f}")


LOOCV Image-Level Accuracy: 0.82
Precision: 0.83
Recall: 0.82
F1-Score: 0.82
AUC Score: 0.85


In [44]:
import pandas as pd
locv_results_df = pd.DataFrame({'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score', 'AUC'],
    'Value': [accuracy, precision, recall, f1, auc]})
locv_results_df

Unnamed: 0,Metric,Value
0,Accuracy,0.818182
1,Precision,0.832612
2,Recall,0.818182
3,F1 Score,0.822314
4,AUC,0.852174
