In [None]:
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

# Function to retrieve images by indices
def get_images_by_indices(loader, indices, max_per_class=10):
    """Collect up to `max_per_class` WT and TPM images whose dataset positions are in `indices`.

    A sample's dataset position is computed as ``batch_number * loader.batch_size + offset``,
    so `loader` must iterate the dataset in order (``shuffle=False``) with a fixed batch size.

    Args:
        loader: DataLoader yielding ``(inputs, labels)`` batches.
        indices: Iterable of integer dataset positions (e.g. the output of ``np.where``).
        max_per_class: Maximum number of images kept for each class.

    Returns:
        ``(wt_images, tpm_images, wt_labels, tpm_labels)``: lists of image tensors and the
        matching ``'WT'`` / ``'TPM'`` label strings, in dataset order.
    """
    # Label 0 is treated as WT and label 1 as TPM.
    # NOTE(review): ImageFolder numbers classes by sorted folder name -- confirm this
    # mapping against `dataset.class_to_idx` before trusting the titles.
    wanted = {int(idx) for idx in np.asarray(indices).ravel()}  # O(1) membership tests
    wt_images, tpm_images = [], []
    wt_labels, tpm_labels = [], []
    if not wanted or max_per_class <= 0:
        return wt_images, tpm_images, wt_labels, tpm_labels
    last_wanted = max(wanted)

    def both_full():
        return len(wt_images) >= max_per_class and len(tpm_images) >= max_per_class

    with torch.no_grad():
        for batch_num, (inputs, lbls) in enumerate(loader):
            start_idx = batch_num * loader.batch_size
            if start_idx > last_wanted:
                break  # every remaining sample lies past the largest requested position
            for offset in range(inputs.size(0)):
                if start_idx + offset not in wanted:
                    continue
                label = lbls[offset].item()
                if label == 0 and len(wt_images) < max_per_class:  # WT
                    wt_images.append(inputs[offset])
                    wt_labels.append('WT')
                elif label == 1 and len(tpm_images) < max_per_class:  # TPM
                    tpm_images.append(inputs[offset])
                    tpm_labels.append('TPM')
                if both_full():
                    break
            if both_full():
                break
    return wt_images, tpm_images, wt_labels, tpm_labels

# Helpers shared by the two display functions below.
def image_to_display_array(img):
    """Convert a (C, H, W) image tensor into an (H, W, C) numpy array min-max scaled to [0, 1].

    The scaling is per image (a display convenience, not an inverse of any dataset
    normalisation). A constant image maps to all zeros instead of NaNs.
    """
    pixels = img.detach().permute(1, 2, 0).cpu().numpy()
    pixels = pixels - pixels.min()
    peak = pixels.max()
    if peak > 0:  # guard against 0/0 on constant images
        pixels = pixels / peak
    return pixels


def grid_rows(n_images, cols):
    """Return the number of rows needed to lay out `n_images` in `cols` columns (at least 1)."""
    return max(1, (n_images + cols - 1) // cols)


def _plot_grid(images, labels, cols, title, filename, overlays=None):
    """Draw `images` titled by `labels` in a `cols`-wide grid, optionally with heatmap overlays.

    When `overlays` is given, images and overlays are paired with ``zip`` (extra items in the
    longer sequence are not drawn). The figure is saved to `filename` when it is non-empty,
    then shown and closed.
    """
    rows = grid_rows(len(images), cols)
    # squeeze=False keeps `axs` 2-D even for a single row, so flat indexing always works.
    fig, axs = plt.subplots(rows, cols, figsize=(15, 3 * rows), squeeze=False)
    flat_axes = axs.ravel()
    cams = overlays if overlays is not None else [None] * len(images)
    for i, (img, cam) in enumerate(zip(images, cams)):
        ax = flat_axes[i]
        pixels = image_to_display_array(img)
        ax.imshow(pixels)
        if cam is not None:
            height, width = pixels.shape[:2]
            # Stretch the heatmap over the image's own extent. Without `extent`, a
            # low-resolution CAM resets the axis limits and the view zooms into a corner.
            ax.imshow(cam, cmap='jet', alpha=0.5, interpolation='bilinear',
                      extent=(-0.5, width - 0.5, height - 0.5, -0.5))
        ax.set_title(labels[i])
    for ax in flat_axes:  # hide ticks on used and unused cells alike
        ax.axis('off')
    fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    fig.subplots_adjust(top=0.9)  # leave room for the suptitle
    if filename:
        fig.savefig(filename, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures in non-inline backends


# Function to display and save images
def show_and_save_images(images, labels, cols=5, title="", filename=""):
    """Show `images` in a grid titled by `labels`, saving the figure when `filename` is set."""
    _plot_grid(images, labels, cols, title, filename)


# Function to display and save Grad-CAM images
def show_and_save_gradcam_images(images, gradcams, labels, cols=5, title="", filename=""):
    """Show `images` with their Grad-CAM heatmaps overlaid, saving the figure when `filename` is set.

    Each heatmap is stretched to cover its image, whatever the heatmap's resolution.
    """
    _plot_grid(images, labels, cols, title, filename, overlays=gradcams)

# ---------------------------------------------------------------------------
# Load the image dataset. Its order must match the rows of the PHATE embedding.
# ---------------------------------------------------------------------------
# NOTE(review): placeholder path -- replace with the ImageFolder root directory.
data_dir = '\\path\\to\\you'
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet mean/std -- presumably what the classifier was trained with; confirm.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
# shuffle=False is essential: images are matched to PHATE rows purely by position.
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=2)

# Load the PHATE-transformed features and labels (convert the earlier CSV exports to .npy first).
phate_transformed = np.load('phate_transformed_3d.npy')
labels = np.load('labels_pool4.npy')  # not used below in this cell
if len(phate_transformed) != len(dataset):
    raise ValueError(
        f"PHATE embedding has {len(phate_transformed)} rows but the dataset has "
        f"{len(dataset)} images; row i must correspond to dataset sample i."
    )


def _in_box(points, bounds):
    """Return a boolean mask of the rows of `points` lying strictly inside an axis-aligned box.

    `bounds` holds one ``(low, high)`` pair per column of `points`; both bounds are exclusive.
    """
    mask = np.ones(len(points), dtype=bool)
    for axis, (low, high) in enumerate(bounds):
        mask &= (points[:, axis] > low) & (points[:, axis] < high)
    return mask


# Region bounds in PHATE coordinates, one (low, high) pair per axis
# (example values -- adjust to the embedding).
# NOTE(review): the boxes overlap where the third coordinate is in (0.0, 0.02),
# so a point can belong to both regions.
TRANSITION_ZONE_BOUNDS = ((-0.02, 0.02), (-0.02, 0.02), (0.0, 0.06))
DENSE_CLUSTER_BOUNDS = ((-0.02, 0.02), (-0.02, 0.02), (-0.02, 0.02))

transition_zone = _in_box(phate_transformed, TRANSITION_ZONE_BOUNDS)
dense_cluster = _in_box(phate_transformed, DENSE_CLUSTER_BOUNDS)

# Dataset positions of the points in each region
transition_indices = np.where(transition_zone)[0]
dense_indices = np.where(dense_cluster)[0]

# Get images from the dense cluster and transition zone
wt_dense_images, tpm_dense_images, wt_dense_labels, tpm_dense_labels = get_images_by_indices(loader, dense_indices)
wt_trans_images, tpm_trans_images, wt_trans_labels, tpm_trans_labels = get_images_by_indices(loader, transition_indices)

# Combine images and labels for each region (WT first, then TPM)
dense_images = wt_dense_images + tpm_dense_images
dense_labels = wt_dense_labels + tpm_dense_labels
trans_images = wt_trans_images + tpm_trans_images
trans_labels = wt_trans_labels + tpm_trans_labels

# Display and save images from the dense cluster and transition zone
show_and_save_images(dense_images, dense_labels, cols=5, title="Dense Cluster Images", filename="dense_cluster_images.png")
show_and_save_images(trans_images, trans_labels, cols=5, title="Transition Zone Images", filename="transition_zone_images.png")

# ---------------------------------------------------------------------------
# Load the trained two-class ResNet-50
# ---------------------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load_state_dict (strict by default) overwrites every parameter and buffer,
# so there is no need to download ImageNet weights first.
model = models.resnet50(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # two output classes
model.load_state_dict(torch.load('resnet50_tpm_wt.pth', map_location=device))
model = model.to(device).eval()

# Define Grad-CAM
class GradCAM:
    """Capture a target layer's activations and their gradients for Grad-CAM.

    Hooks are registered on `target_layer` at construction time and stay active until
    `remove()` is called. Calling the instance runs a forward pass, back-propagates the
    chosen class scores, and returns ``(features, gradients)`` for the layer.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.features = None   # layer output from the most recent forward pass
        self.gradients = None  # d(score)/d(layer output) from the most recent backward pass
        self.forward_hook = target_layer.register_forward_hook(self.save_forward)
        self.backward_hook = target_layer.register_full_backward_hook(self.save_backward)

    def save_forward(self, module, input, output):
        """Forward hook: remember the layer's output."""
        self.features = output

    def save_backward(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer's output."""
        self.gradients = grad_output[0]

    def remove(self):
        """Detach both hooks from the target layer (safe to call more than once)."""
        self.forward_hook.remove()
        self.backward_hook.remove()

    def __call__(self, x, class_idx=None):
        """Run `x` through the model and back-propagate the target class scores.

        Args:
            x: Input batch for the model.
            class_idx: Class to explain -- an int, or one index per sample. Defaults to
                each sample's highest-scoring class.

        Returns:
            ``(features, gradients)`` captured at the target layer for the whole batch.
        """
        self.model.zero_grad()
        # Gradients are needed even when the caller is inside `torch.no_grad()`.
        with torch.enable_grad():
            output = self.model(x)
            if class_idx is None:
                target = output.argmax(dim=1)
            else:
                target = torch.as_tensor(class_idx, dtype=torch.long, device=output.device)
                target = target.reshape(-1).expand(output.size(0))
            # One score per sample, summed to a scalar. Samples' gradients stay separate
            # provided samples do not interact (e.g. BatchNorm in eval mode).
            score = output.gather(1, target.unsqueeze(1)).sum()
            score.backward()
        return self.features, self.gradients

def apply_gradcam(model, images, target_layer, device):
    """Compute a Grad-CAM heatmap for each image.

    Args:
        model: Classifier whose forward pass runs through `target_layer`.
        images: Iterable of (C, H, W) image tensors.
        target_layer: Module whose activations and gradients define the heatmap.
        device: Device the model lives on; each image is moved there.

    Returns:
        List of 2-D numpy arrays, one per image. Each is upsampled to the image's (H, W) and
        scaled to [0, 1]; it is all zeros when no location has positive evidence. Each
        heatmap explains the model's top-scoring class for that image.
    """
    gradcam = GradCAM(model, target_layer)
    cam_results = []
    try:
        for img in images:
            batch = img.unsqueeze(0).to(device)
            features, gradients = gradcam(batch)
            with torch.no_grad():
                # Channel weights = spatially averaged gradients.
                weights = F.adaptive_avg_pool2d(gradients, 1)
                cam = F.relu(torch.sum(weights * features, dim=1, keepdim=True))
                # Upsample from the layer's resolution to the input so the overlay lines up.
                cam = F.interpolate(cam, size=tuple(batch.shape[-2:]), mode='bilinear', align_corners=False)
            cam = cam[0, 0].cpu().numpy()
            peak = cam.max()
            if peak > 0:  # avoid 0/0 -> NaN when the heatmap has no positive values
                cam = cam / peak
            cam_results.append(cam)
    finally:
        # Otherwise every call would leave another pair of hooks on the shared layer.
        gradcam.forward_hook.remove()
        gradcam.backward_hook.remove()
    return cam_results

# Grad-CAM for both regions, taken at `model.layer4[2].conv3`; the heatmaps are computed
# first, then overlaid on their images and written to PNG files.
dense_gradcams = apply_gradcam(
    model, dense_images, model.layer4[2].conv3, device
)
trans_gradcams = apply_gradcam(
    model, trans_images, model.layer4[2].conv3, device
)

show_and_save_gradcam_images(
    dense_images,
    dense_gradcams,
    dense_labels,
    cols=5,
    title="Dense Cluster Grad-CAM",
    filename="dense_cluster_gradcam.png",
)
show_and_save_gradcam_images(
    trans_images,
    trans_gradcams,
    trans_labels,
    cols=5,
    title="Transition Zone Grad-CAM",
    filename="transition_zone_gradcam.png",
)
