In [43]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import numpy as np
from torchvision import transforms
import os
import torchvision.models as models
from torchvision.models import ResNet18_Weights
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import matplotlib.pyplot as plt
from monai.networks.nets import DenseNet121
from sklearn.utils.class_weight import compute_class_weight

In [None]:
image_channels = 1
image_size = (224, 224)

In [44]:
class CNN_Scratch(nn.Module):
    def __init__(self):
        super(CNN_Scratch, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)  # Downsampling

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)  # Downsampling

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)  # Downsampling

        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(2, 2)  # Downsampling

        self.global_avg_pool = nn.AdaptiveAvgPool2d((2, 2))  # Ensure same feature size as DenseNet
        self.dropout = nn.Dropout(0.25)

        self.fc1 = nn.Linear(256 * 2 * 2, 128)  # Adjusted for the new feature size
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        x = self.pool4(torch.relu(self.conv4(x)))

        x = self.global_avg_pool(x)  # Adaptive pooling ensures fixed feature size
        x = torch.flatten(x, start_dim=1)  # Flatten correctly
        # print(f"Flattened feature size: {x.shape}")  # Debug print

        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.fc2(x)
        return x
    

class TransferLearningImageClassifier(nn.Module):
    """ Image Classification Model using MONAI DenseNet121 """

    def __init__(self, num_classes=1):
        super(TransferLearningImageClassifier, self).__init__()
        self.model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_classes)
        
        # Modify the pooling layer to ensure fixed feature extraction
        self.model.features[-1] = nn.AdaptiveAvgPool2d((2, 2))  # Replace last pooling layer

    def forward(self, x):
        return self.model(x)


# class TransferLearningModel(nn.Module):
#     def __init__(self):
#         super(TransferLearningModel, self).__init__()
#         # Load pre-trained ResNet18 with weights
#         self.base_model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        
#         # Modify the first convolutional layer to accept grayscale input
#         self.base_model.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1) 
        
#         # Replace the fully connected layer for binary classification
#         num_features = self.base_model.fc.in_features
#         self.base_model.fc = nn.Sequential(
#             nn.Linear(num_features, 128),
#             nn.ReLU(),
#             nn.Dropout(0.5),
#             nn.Linear(128, 1),
#             nn.Sigmoid()  # Binary classification
#         )
    
#     def forward(self, x):
#         return self.base_model(x)


In [45]:
from PIL import Image

class FullImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image as a NumPy array
        image = np.load(self.image_paths[idx])  # Shape: [224, 224]
        
        # Scale pixel values to [0, 255] if they are normalized
        if image.max() <= 1.0:  # Check if image is normalized
            image = (image * 255).astype(np.uint8)
        
        # Convert NumPy array to PIL Image
        image = Image.fromarray(image)
        
        # Apply transformations if provided
        if self.transform:
            image = self.transform(image)

        # # Convert to PyTorch tensor
        # image = torch.tensor(image, dtype=torch.float32)

        # Convert label to PyTorch tensor
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, label


In [None]:
# Define the LOOCV Function
def loocv_full_image_with_augmentation(image_paths, labels, num_epochs=500, learning_rate=0.0001, save_dir="../Models"):
    os.makedirs(save_dir, exist_ok=True)  # Ensure the save directory exists

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    true_labels = []
    image_pred_prob = []
    image_pred_label = []

    # Define augmentations for the training dataset
    train_transform = transforms.Compose([
        transforms.Resize(image_size),             # Resize all images
        transforms.RandomRotation(10),           # Rotate randomly within ±10 degrees
        # transforms.RandomResizedCrop(size=image_size, scale=(0.8, 1.0)),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
        transforms.ToTensor(),                   # Convert PIL Image to PyTorch tensor
        transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
    ])
    
    # Define transform for test dataset (no augmentation, only normalization)
    test_transform = transforms.Compose([
        transforms.Resize(image_size),             # Resize all images
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
    ])

    for test_image_id in range(len(image_paths)):
        print(f"Processing LOOCV for test image {test_image_id + 1}/{len(image_paths)}")

        # Split dataset into training and test sets
        train_images = [image_paths[i] for i in range(len(image_paths)) if i != test_image_id]
        train_labels = [labels[i] for i in range(len(labels)) if i != test_image_id]
        test_image = image_paths[test_image_id]
        test_label = labels[test_image_id]

        # Create datasets and dataloaders
        train_dataset = FullImageDataset(train_images, train_labels, transform=train_transform)
        test_dataset = FullImageDataset([test_image], [test_label], transform=test_transform)
        train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

        # # Visualize augmented images
        # for i in range(2):
        #     augmented_image, _ = train_dataset[i]  # Get augmented image
        #     plt.imshow(augmented_image.squeeze().numpy(), cmap='gray')  # Visualize in grayscale
        #     plt.title("Augmented Image")
        #     plt.show()

        # Initialize model, loss function, and optimizer
        model = CNN_Scratch().to(device)
        # model = TransferLearningImageClassifier().to(device)
        
        class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
        class_weights = torch.tensor(class_weights, dtype=torch.float)
        # print("class_weights", class_weights)

        # criterion = nn.BCELoss()
        criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # Training loop
        model.train()
        for epoch in range(num_epochs):
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device).unsqueeze(1)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

        # Testing loop
        model.eval()
        with torch.no_grad():
            for inputs, _ in test_loader:
                inputs = inputs.to(device)
                outputs = model(inputs).squeeze().cpu().numpy()
                image_pred_prob.append(outputs)
                pred_label = 1 if outputs > 0.5 else 0
                image_pred_label.append(pred_label)
                true_labels.append(test_label)

    # # Save the final trained model
    # final_model_path = os.path.join(save_dir, "rr3_cnn_23022025.pth")
    # torch.save({
    #     'model_state_dict': model.state_dict(),
    #     'optimizer_state_dict': optimizer.state_dict()
    # }, final_model_path)
    # print(f"Final model saved to {final_model_path}")

    return true_labels, image_pred_prob

In [47]:
image_dir = '../rr3_dataset/preprocessed_images/'

label_dict = {
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_158_A1.npy': 1, #'Space Flight'
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_158_B1.npy': 1, #'Space Flight',
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_158_C1.npy': 1, #'Space Flight',
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_158_D1.npy': 1, #'Space Flight',
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_159_A1.npy': 0, #'Ground Control',
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_159_B1.npy': 0, #'Ground Control',
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_159_C1.npy': 1, #'Space Flight',
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_159_D1.npy': 1, #'Space Flight',
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_304_A1.npy': 0, #'Ground Control',
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_304_B1.npy': 0, #'Ground Control',
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_304_C1.npy': 0, #'Ground Control',
    image_dir+'GLDS-352_SpatialTranscriptomics_NASA-RR3_Sample_304_D1.npy': 0, #'Ground Control',
}


In [48]:
image_paths = list(label_dict.keys()) # Replace with actual .npy file paths
labels = list(label_dict.values())  # Replace with actual labels


In [49]:
true_labels, image_pred_prob = loocv_full_image_with_augmentation(image_paths, labels)

Processing LOOCV for test image 1/12
Processing LOOCV for test image 2/12
Processing LOOCV for test image 3/12
Processing LOOCV for test image 4/12
Processing LOOCV for test image 5/12
Processing LOOCV for test image 6/12
Processing LOOCV for test image 7/12
Processing LOOCV for test image 8/12
Processing LOOCV for test image 9/12
Processing LOOCV for test image 10/12
Processing LOOCV for test image 11/12
Processing LOOCV for test image 12/12


In [50]:
true_labels, image_pred_prob

([1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
 [array(9.913676, dtype=float32),
  array(8.717853, dtype=float32),
  array(7.8208942, dtype=float32),
  array(7.0151954, dtype=float32),
  array(-4.285076, dtype=float32),
  array(1.2626834, dtype=float32),
  array(6.670873, dtype=float32),
  array(4.144833, dtype=float32),
  array(-10.979938, dtype=float32),
  array(-13.445078, dtype=float32),
  array(-7.666894, dtype=float32),
  array(-7.5089974, dtype=float32)])

In [51]:
# Define the cutoff value
cutoff = 0.5
# Convert to 0 and 1 based on the cutoff
image_pred_label = (np.array(image_pred_prob) > cutoff).astype(int)
image_pred_label

array([1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0])

In [52]:
# Calculate metrics
accuracy = accuracy_score(true_labels, image_pred_label)
precision = precision_score(true_labels, image_pred_label)
recall = recall_score(true_labels, image_pred_label)
f1 = f1_score(true_labels, image_pred_label)
auc = roc_auc_score(true_labels, image_pred_prob)

print(f"LOOCV Image-Level Accuracy: {accuracy:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1-Score: {f1:.2f}")
print(f"AUC Score: {auc:.2f}")

# LOOCV Image-Level Accuracy: 0.92
# Precision: 0.86
# Recall: 1.00
# F1-Score: 0.92
# AUC Score: 1.00

LOOCV Image-Level Accuracy: 0.92
Precision: 0.86
Recall: 1.00
F1-Score: 0.92
AUC Score: 1.00
