In [None]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from tqdm import tqdm

# Data Loading and Preprocessing
def load_all_folds(base_path, folds=("Fold 1", "Fold 2", "Fold 3")):
    """Load the images, masks and tissue types of every fold under ``base_path``.

    Each fold directory ``<base_path>/<fold>`` must contain
    ``images/fold<N>/images.npy``, ``images/fold<N>/types.npy`` and
    ``masks/fold<N>/masks.npy``, where ``<N>`` is the trailing number of the
    fold name ("Fold 1" -> "1", "Fold 10" -> "10").

    Args:
        base_path: Root directory containing the fold directories.
        folds: Names of the fold directories to load, in order.

    Returns:
        Three lists ``(all_images, all_masks, all_types)`` holding one loaded
        array per fold, in the order of ``folds``.
    """
    all_images, all_masks, all_types = [], [], []

    for fold in folds:
        # Use the whole trailing digit run (so "Fold 10" -> "10"); fall back to
        # the last character, the previous behaviour, when there are no digits.
        fold_id = fold[len(fold.rstrip("0123456789")):] or fold[-1]
        fold_path = os.path.join(base_path, fold)
        image_dir = os.path.join(fold_path, "images", f"fold{fold_id}")
        mask_dir = os.path.join(fold_path, "masks", f"fold{fold_id}")

        all_images.append(np.load(os.path.join(image_dir, "images.npy")))
        all_masks.append(np.load(os.path.join(mask_dir, "masks.npy")))
        all_types.append(np.load(os.path.join(image_dir, "types.npy")))

    return all_images, all_masks, all_types

# Root directory containing the per-fold sub-directories read by load_all_folds.
base_path = "/rsrch5/home/plm/yshokrollahi/vitamin-p/vitamin-p/data/raw/H&E"
all_images, all_masks, all_types = load_all_folds(base_path=base_path)

def create_train_val_test_split(all_images, all_masks, all_types, val_fraction=0.1, seed=None):
    """Build one leave-one-fold-out train/val/test split per fold.

    For split ``k`` the test set is fold ``k``. The remaining folds are
    concatenated and shuffled. The first ``int(n * val_fraction)`` shuffled
    samples become the validation set and the rest the training set.

    Args:
        all_images, all_masks, all_types: Per-fold sequences of arrays with
            matching lengths and per-fold sample counts.
        val_fraction: Fraction of the pooled non-test samples used for
            validation. The default 0.1 gives ``n // 10`` as before.
        seed: ``None`` shuffles with NumPy's global RNG, which was the previous
            behaviour and honours ``np.random.seed``. Any other value shuffles
            with ``np.random.default_rng(seed)``, making the splits reproducible.

    Returns:
        A list with one dict per fold. Each dict has 'train', 'val' and 'test'
        entries, and each entry is a dict of 'images', 'masks' and 'types'.

    Raises:
        ValueError: if fewer than two folds are given or ``val_fraction`` is
            outside ``[0, 1)``.
    """
    num_folds = len(all_images)
    if num_folds < 2:
        raise ValueError(f"need at least two folds, got {num_folds}")
    if not 0 <= val_fraction < 1:
        raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")

    shuffle = np.random.shuffle if seed is None else np.random.default_rng(seed).shuffle

    splits = []
    for test_fold in range(num_folds):
        train_val_folds = [i for i in range(num_folds) if i != test_fold]

        train_val_images = np.concatenate([all_images[i] for i in train_val_folds])
        train_val_masks = np.concatenate([all_masks[i] for i in train_val_folds])
        train_val_types = np.concatenate([all_types[i] for i in train_val_folds])

        num_samples = len(train_val_images)
        num_val = int(num_samples * val_fraction)

        # One permutation is shared by images, masks and types so they stay aligned.
        indices = np.arange(num_samples)
        shuffle(indices)
        val_indices, train_indices = indices[:num_val], indices[num_val:]

        splits.append({
            'train': {
                'images': train_val_images[train_indices],
                'masks': train_val_masks[train_indices],
                'types': train_val_types[train_indices],
            },
            'val': {
                'images': train_val_images[val_indices],
                'masks': train_val_masks[val_indices],
                'types': train_val_types[val_indices],
            },
            'test': {
                'images': all_images[test_fold],
                'masks': all_masks[test_fold],
                'types': all_types[test_fold],
            },
        })

    return splits

# Leave-one-fold-out splits: data_splits[k] uses fold k as its test set.
data_splits = create_train_val_test_split(
    all_images=all_images,
    all_masks=all_masks,
    all_types=all_types,
)
print("done")


In [None]:
import torch
import numpy as np
import cv2
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import functional as TF
from scipy.ndimage import gaussian_filter, map_coordinates, distance_transform_edt
from skimage.segmentation import slic
from skimage.measure import regionprops

class CellSegmentationDataset(Dataset):
    """Dataset yielding image / mask / HV-map / label tuples for cell segmentation.

    Images are channels-last (HWC) arrays. Masks are channels-last arrays whose
    last channel is the background channel; every other channel corresponds to
    one cell type.

    Each item is a 6-tuple:
        ``(image, binary_mask, multi_class_mask, hv_map, tissue_type_onehot,
        global_cell_labels)``.
    Without transforms these are a (C, H, W) image tensor, a (1, H, W)
    foreground mask, a (num_cell_types, H, W) mask, a (2, H, W) horizontal /
    vertical map, a one-hot tissue vector, and a per-cell-type presence vector.

    Args:
        images: Indexable collection of HWC images.
        masks: Indexable collection of channels-last masks (last channel = background).
        types: Per-sample tissue-type labels.
        image_transform: Optional callable applied to the image. It receives an
            HWC float32 numpy array when ``augment`` is False and a CHW float32
            tensor when ``augment`` is True. Without it, the image is returned
            as a CHW float32 tensor.
        mask_transform: Optional callable applied separately to the binary
            mask, the multi-class mask and the HV map. It always receives a
            CHW float32 tensor.
        augment: Whether to apply random augmentation in ``__getitem__``.
        tissue_types: Optional explicit tissue-type vocabulary. Passing the same
            vocabulary to train/val/test datasets gives every split the same
            one-hot length and index order. Defaults to ``np.unique(types)``.
            Every value in ``types`` must appear in it.
    """

    def __init__(self, images, masks, types, image_transform=None, mask_transform=None,
                 augment=False, tissue_types=None):
        self.images = images
        self.masks = masks
        self.types = types
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.augment = augment

        # Tissue-type vocabulary (sorted by np.unique) and its index mapping.
        vocabulary = self.types if tissue_types is None else tissue_types
        self.unique_tissue_types = np.unique(vocabulary)
        self.tissue_type_to_idx = {t: i for i, t in enumerate(self.unique_tissue_types)}

        # Cell-type channels: every mask channel except the last (background).
        self.unique_cell_types = np.arange(self.masks[0].shape[-1] - 1)
        self.cell_type_to_idx = {t: i for i, t in enumerate(self.unique_cell_types)}

        print(f"Number of unique cell types: {len(self.unique_cell_types)}")
        print(f"Shape of first mask: {self.masks[0].shape}")

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _normalize_image(image):
        """Min-max scale ``image`` to [0, 1] as float32; constant images become all zeros."""
        low, high = image.min(), image.max()
        span = high - low
        if span == 0:
            # Avoid 0/0 -> NaN for blank or constant tiles.
            return np.zeros(image.shape, dtype=np.float32)
        return ((image - low) / span).astype(np.float32)

    def __getitem__(self, idx):
        image = self._normalize_image(self.images[idx])
        mask = self.masks[idx]
        tissue_type = self.types[idx]

        # Foreground = pixels whose background channel is 0.
        binary_mask = (mask[..., -1] == 0).astype(np.float32)

        # Cell-type channels scaled by their global max, then restricted to foreground.
        multi_class_mask = mask[..., :-1].astype(np.float32)
        multi_class_mask = np.divide(multi_class_mask, np.maximum(np.max(multi_class_mask), 1e-8))
        multi_class_mask = multi_class_mask * binary_mask[..., np.newaxis]

        if self.augment:
            # Returns CHW tensors, with the HV map regenerated from the augmented mask.
            image, binary_mask, multi_class_mask, hv_map = self.apply_augmentation(
                image, binary_mask, multi_class_mask)
        else:
            hv_map = torch.from_numpy(self.generate_hv_map(binary_mask)).permute(2, 0, 1)
            binary_mask = torch.from_numpy(binary_mask).unsqueeze(0)
            multi_class_mask = torch.from_numpy(multi_class_mask).permute(2, 0, 1)

        # Computed after augmentation so cropped-away cell types are not reported.
        global_cell_labels = (multi_class_mask > 0).flatten(1).any(dim=1).to(torch.float32)

        if self.image_transform is not None:
            image = self.image_transform(image)
        elif isinstance(image, np.ndarray):
            image = torch.from_numpy(image).permute(2, 0, 1)

        if self.mask_transform is not None:
            binary_mask = self.mask_transform(binary_mask)
            multi_class_mask = self.mask_transform(multi_class_mask)
            hv_map = self.mask_transform(hv_map)

        tissue_type_onehot = torch.zeros(len(self.unique_tissue_types))
        tissue_type_onehot[self.tissue_type_to_idx[tissue_type]] = 1

        return image, binary_mask, multi_class_mask, hv_map, tissue_type_onehot, global_cell_labels

    def generate_hv_map(self, binary_mask):
        """Return an (H, W, 2) float32 map of horizontal/vertical offsets per nucleus.

        Each 8-connected foreground component of ``binary_mask`` is treated as
        one nucleus, so touching nuclei merge into one component. Inside a
        component, channel 0 holds ``(col - centroid_col) / bbox_width`` and
        channel 1 holds ``(row - centroid_row) / bbox_height``. Background
        pixels are 0.
        """
        from scipy import ndimage

        # Same foreground test as the previous uint8 conversion (value * 255 truncates to > 0).
        foreground = (binary_mask * 255).astype(np.uint8) > 0
        # 3x3 structuring element = 8-connectivity, the cv2.connectedComponents default.
        label_img, _ = ndimage.label(foreground, structure=np.ones((3, 3), dtype=int))

        h_map = np.zeros_like(binary_mask, dtype=np.float32)
        v_map = np.zeros_like(binary_mask, dtype=np.float32)

        for region_id, bbox in enumerate(ndimage.find_objects(label_img), start=1):
            if bbox is None:
                continue
            rows, cols = np.nonzero(label_img[bbox] == region_id)
            rows = rows + bbox[0].start
            cols = cols + bbox[1].start
            height = bbox[0].stop - bbox[0].start
            width = bbox[1].stop - bbox[1].start

            h_map[rows, cols] = (cols - cols.mean()) / (width + 1e-5)
            v_map[rows, cols] = (rows - rows.mean()) / (height + 1e-5)

        return np.stack([h_map, v_map], axis=-1)

    def apply_augmentation(self, image, binary_mask, multi_class_mask, hv_map=None):
        """Randomly augment one sample and return CHW tensors at the input size.

        Args:
            image: HWC float32 numpy image.
            binary_mask: HW float32 numpy foreground mask.
            multi_class_mask: HWC float32 numpy cell-type mask.
            hv_map: Ignored; accepted so existing callers that pass it still work.
                The HV map is regenerated from the augmented binary mask. This
                keeps it consistent with every geometric transform: rotations
                and flips would otherwise need axis swaps and sign flips, and
                elastic warps and crops were never applied to it.

        Returns:
            ``(image, binary_mask, multi_class_mask, hv_map)`` as tensors shaped
            (C, H, W), (1, H, W), (num_cell_types, H, W) and (2, H, W).
        """
        nearest = TF.InterpolationMode.NEAREST

        image = torch.from_numpy(image).permute(2, 0, 1)
        binary_mask = torch.from_numpy(binary_mask).unsqueeze(0)
        multi_class_mask = torch.from_numpy(multi_class_mask).permute(2, 0, 1)

        original_size = image.shape[1:]

        # Random 90-degree rotation
        if torch.rand(1) < 0.5:
            k = torch.randint(1, 4, (1,)).item()
            image = torch.rot90(image, k, [1, 2])
            binary_mask = torch.rot90(binary_mask, k, [1, 2])
            multi_class_mask = torch.rot90(multi_class_mask, k, [1, 2])

        # Random horizontal flip
        if torch.rand(1) < 0.5:
            image = TF.hflip(image)
            binary_mask = TF.hflip(binary_mask)
            multi_class_mask = TF.hflip(multi_class_mask)

        # Random vertical flip
        if torch.rand(1) < 0.5:
            image = TF.vflip(image)
            binary_mask = TF.vflip(binary_mask)
            multi_class_mask = TF.vflip(multi_class_mask)

        # Random scaling (downscaling); masks use nearest so they stay label-valued.
        if torch.rand(1) < 0.5:
            scale_factor = torch.FloatTensor(1).uniform_(0.8, 1.0).item()
            new_size = [max(224, int(s * scale_factor)) for s in image.shape[1:]]
            image = TF.resize(image, new_size)
            binary_mask = TF.resize(binary_mask, new_size, interpolation=nearest)
            multi_class_mask = TF.resize(multi_class_mask, new_size, interpolation=nearest)

        # Elastic transformation
        if torch.rand(1) < 0.5:
            image = self.elastic_transform(image.permute(1, 2, 0).numpy())
            binary_mask = self.elastic_transform(binary_mask[0].numpy(), is_mask=True)
            multi_class_mask = self.elastic_transform(multi_class_mask.permute(1, 2, 0).numpy(), is_mask=True)
            image = torch.from_numpy(image).permute(2, 0, 1)
            binary_mask = torch.from_numpy(binary_mask).unsqueeze(0)
            multi_class_mask = torch.from_numpy(multi_class_mask).permute(2, 0, 1)

        # Ensure image is large enough for the 224x224 crop below
        if min(image.shape[1:]) < 224:
            scale_factor = 224 / min(image.shape[1:])
            new_size = [int(s * scale_factor) for s in image.shape[1:]]
            image = TF.resize(image, new_size)
            binary_mask = TF.resize(binary_mask, new_size, interpolation=nearest)
            multi_class_mask = TF.resize(multi_class_mask, new_size, interpolation=nearest)

        # Blurring (spatial axes only)
        if torch.rand(1) < 0.5:
            sigma = torch.FloatTensor(1).uniform_(0.1, 2.0).item()
            image = torch.from_numpy(gaussian_filter(image.numpy(), sigma=(0, sigma, sigma)))

        # Gaussian noise
        if torch.rand(1) < 0.5:
            noise = torch.randn_like(image) * 0.1
            image = torch.clamp(image + noise, 0, 1)

        # Color jittering
        if torch.rand(1) < 0.5:
            brightness_factor = torch.tensor(1.0).uniform_(0.8, 1.2).item()
            contrast_factor = torch.tensor(1.0).uniform_(0.8, 1.2).item()
            saturation_factor = torch.tensor(1.0).uniform_(0.8, 1.2).item()
            hue_factor = torch.tensor(1.0).uniform_(-0.1, 0.1).item()
            image = TF.adjust_brightness(image, brightness_factor)
            image = TF.adjust_contrast(image, contrast_factor)
            image = TF.adjust_saturation(image, saturation_factor)
            image = TF.adjust_hue(image, hue_factor)

        # SLIC superpixels
        if torch.rand(1) < 0.5:
            image = self.apply_slic(image)

        # Zoom blur
        if torch.rand(1) < 0.5:
            image = self.zoom_blur(image)

        # Random 224x224 crop (resized back below)
        if torch.rand(1) < 0.5:
            i, j, h, w = transforms.RandomCrop.get_params(image, output_size=(224, 224))
            image = TF.crop(image, i, j, h, w)
            binary_mask = TF.crop(binary_mask, i, j, h, w)
            multi_class_mask = TF.crop(multi_class_mask, i, j, h, w)

        # Resize back to original size
        image = TF.resize(image, original_size)
        binary_mask = TF.resize(binary_mask, original_size, interpolation=nearest)
        multi_class_mask = TF.resize(multi_class_mask, original_size, interpolation=nearest)

        # Regenerate the HV map so it matches the final mask geometry exactly.
        hv_map = torch.from_numpy(self.generate_hv_map(binary_mask[0].numpy())).permute(2, 0, 1)

        return image, binary_mask, multi_class_mask, hv_map

    def elastic_transform(self, image, alpha=1, sigma=0.1, alpha_affine=0.1, is_mask=False):
        """Elastic deformation of images as described in [Simard2003]_.

        First applies a small random affine warp: three control points are
        jittered by up to ``alpha_affine`` pixels. Then it resamples through a
        random displacement field, which is uniform noise in [-1, 1] smoothed
        by a Gaussian of ``sigma`` and scaled by ``alpha``. Masks
        (``is_mask=True``) use nearest-neighbour sampling with constant zero
        borders; images use linear sampling with reflected borders.

        Args:
            image: 2-D (H, W) or 3-D (H, W, C) numpy array.

        Raises:
            ValueError: if ``image`` is neither 2-D nor 3-D.
        """
        random_state = np.random.RandomState(None)

        if image.ndim == 2:
            shape = image.shape
        elif image.ndim == 3:
            shape = image.shape[:2]
        else:
            raise ValueError("Image must be 2D or 3D")

        # Random affine
        center_square = np.float32(shape) // 2
        square_size = min(shape) // 3
        pts1 = np.float32([center_square + square_size, [center_square[0]+square_size, center_square[1]-square_size], center_square - square_size])
        pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32)
        M = cv2.getAffineTransform(pts1, pts2)

        if is_mask:
            if image.ndim == 2:
                image = cv2.warpAffine(image, M, shape[::-1], borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_NEAREST)
            else:
                image = np.stack([cv2.warpAffine(image[:,:,i], M, shape[::-1], borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_NEAREST) for i in range(image.shape[2])], axis=2)
        else:
            if image.ndim == 2:
                image = cv2.warpAffine(image, M, shape[::-1], borderMode=cv2.BORDER_REFLECT_101)
            else:
                image = np.stack([cv2.warpAffine(image[:,:,i], M, shape[::-1], borderMode=cv2.BORDER_REFLECT_101) for i in range(image.shape[2])], axis=2)

        dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha
        dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha

        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1))

        if is_mask:
            if image.ndim == 2:
                return map_coordinates(image, indices, order=0, mode='constant').reshape(shape)
            else:
                return np.stack([map_coordinates(image[:,:,i], indices, order=0, mode='constant').reshape(shape) for i in range(image.shape[2])], axis=2)
        else:
            if image.ndim == 2:
                return map_coordinates(image, indices, order=1, mode='reflect').reshape(shape)
            else:
                return np.stack([map_coordinates(image[:,:,i], indices, order=1, mode='reflect').reshape(shape) for i in range(image.shape[2])], axis=2)

    def apply_slic(self, image):
        """Replace every SLIC superpixel of a CHW image tensor by its mean colour."""
        image_np = image.numpy().transpose(1, 2, 0)
        segments = slic(image_np, n_segments=100, compactness=10, sigma=1)
        out = np.zeros_like(image_np)
        for i in np.unique(segments):
            mask = segments == i
            out[mask] = np.mean(image_np[mask], axis=0)
        return torch.from_numpy(out.transpose(2, 0, 1))

    def zoom_blur(self, image, max_factor=1.2):
        """Average a CHW image with an upscaled (factor in [1, max_factor]) centre crop of itself."""
        c, h, w = image.shape
        zoom_factor = torch.FloatTensor(1).uniform_(1, max_factor).item()
        zh = int(np.round(h * zoom_factor))
        zw = int(np.round(w * zoom_factor))
        zoom_image = TF.resize(image, (zh, zw))
        zoom_image = TF.center_crop(zoom_image, (h, w))
        return (image + zoom_image) / 2

# Usage example
chosen_split = 2


def _to_chw_float_tensor(x):
    """Convert a numpy array to a float32 channels-first tensor; pass tensors through.

    Tensors (as returned by the dataset's augmentation path) are assumed to be
    channels-first already and are only cast to float. A 2-D numpy array gets
    a leading channel axis. A 3-D numpy array is treated as channels-last and
    permuted to channels-first. The previous mask transform instead
    unsqueezed multi-channel HWC masks into (1, H, W, C).
    """
    if not isinstance(x, torch.Tensor):
        x = torch.from_numpy(np.ascontiguousarray(x))
        x = x.unsqueeze(0) if x.ndim == 2 else x.permute(2, 0, 1)
    return x.float()


# Named function instead of lambdas so the transforms can be pickled by DataLoader workers.
image_transform = transforms.Compose([
    transforms.Lambda(_to_chw_float_tensor),
])

mask_transform = transforms.Compose([
    transforms.Lambda(_to_chw_float_tensor),
])

# Create datasets
train_dataset = CellSegmentationDataset(
    data_splits[chosen_split]['train']['images'],
    data_splits[chosen_split]['train']['masks'],
    data_splits[chosen_split]['train']['types'],
    image_transform=image_transform,
    mask_transform=mask_transform,
    augment=True  # Enable augmentation for training set
)

val_dataset = CellSegmentationDataset(
    data_splits[chosen_split]['val']['images'],
    data_splits[chosen_split]['val']['masks'],
    data_splits[chosen_split]['val']['types'],
    image_transform=image_transform,
    mask_transform=mask_transform,
    augment=False  # No augmentation for validation set
)

test_dataset = CellSegmentationDataset(
    data_splits[chosen_split]['test']['images'],
    data_splits[chosen_split]['test']['masks'],
    data_splits[chosen_split]['test']['types'],
    image_transform=image_transform,
    mask_transform=mask_transform,
    augment=False  # No augmentation for test set
)

# Create DataLoaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Check first item in train_dataset
print("\nChecking first item in train_dataset:")
image, binary_mask, multi_class_mask, hv_map, tissue_type_onehot, global_cell_labels = train_dataset[0]
print(f"Image shape: {image.shape}, dtype: {image.dtype}")
print(f"Binary mask shape: {binary_mask.shape}, dtype: {binary_mask.dtype}")
print(f"Multi-class mask shape: {multi_class_mask.shape}, dtype: {multi_class_mask.dtype}")
print(f"HV map shape: {hv_map.shape}, dtype: {hv_map.dtype}")
print(f"Tissue type one-hot encoding: {tissue_type_onehot}")
print(f"Global cell labels: {global_cell_labels}")
print(f"Image min: {image.min().item():.4f}, max: {image.max().item():.4f}")
print(f"Binary mask min: {binary_mask.min().item():.4f}, max: {binary_mask.max().item():.4f}")
print(f"Multi-class mask min: {multi_class_mask.min().item():.4f}, max: {multi_class_mask.max().item():.4f}")
print(f"HV map min: {hv_map.min().item():.4f}, max: {hv_map.max().item():.4f}")
print(f"Unique tissue types: {train_dataset.unique_tissue_types}")
print(f"Unique cell types: {train_dataset.unique_cell_types}")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_samples(dataset, num_samples=2, set_name=""):
    """Plot random samples of ``dataset`` in a grid of 6 panels per row.

    Panels: image, binary mask, colour-coded cell-type mask, horizontal and
    vertical HV components, and a text panel with the tissue-type index and
    global cell labels. Shapes, dtypes and value ranges are printed per sample.
    Samples that fail to load, or that are not 6-tuples, are reported and skipped.

    Args:
        dataset: Indexable dataset whose items are
            ``(image, binary_mask, multi_class_mask, hv_map, tissue_type_onehot,
            global_cell_labels)`` tensors.
        num_samples: Number of random samples (grid rows).
        set_name: Label used in titles and printed messages.
    """
    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 6, figsize=(30, 5*num_samples), squeeze=False)
    # pyplot.get_cmap: plt.cm.get_cmap was removed in Matplotlib 3.9.
    class_cmap = plt.get_cmap('tab10')

    for i in range(num_samples):
        idx = np.random.randint(len(dataset))

        # Get the sample (best effort: report and skip failures)
        try:
            sample = dataset[idx]
        except Exception as e:
            print(f"Error getting sample {idx} from {set_name} dataset: {str(e)}")
            continue

        # Unpack the sample
        if len(sample) == 6:
            image, binary_mask, multi_class_mask, hv_map, tissue_type_onehot, global_cell_labels = sample
        else:
            print(f"Unexpected number of items in sample: {len(sample)}")
            continue

        # Print shapes and dtypes
        print(f"\nSample {i + 1} from {set_name} dataset:")
        print(f"Image shape: {image.shape}, dtype: {image.dtype}")
        print(f"Binary mask shape: {binary_mask.shape}, dtype: {binary_mask.dtype}")
        print(f"Multi-class mask shape: {multi_class_mask.shape}, dtype: {multi_class_mask.dtype}")
        print(f"HV map shape: {hv_map.shape}, dtype: {hv_map.dtype}")
        print(f"Tissue type onehot shape: {tissue_type_onehot.shape}, dtype: {tissue_type_onehot.dtype}")
        print(f"Global cell labels shape: {global_cell_labels.shape}, dtype: {global_cell_labels.dtype}")

        num_cell_types = len(global_cell_labels)

        # Convert to numpy
        image = image.permute(1, 2, 0).numpy()
        binary_mask = binary_mask.squeeze().numpy()
        multi_class_mask = multi_class_mask.squeeze().numpy()
        hv_map = hv_map.squeeze().numpy()

        # Channels-last masks (e.g. (H, W, num_cell_types)) are moved to channels-first.
        if (multi_class_mask.ndim == 3 and multi_class_mask.shape[-1] == num_cell_types
                and multi_class_mask.shape[0] != num_cell_types):
            multi_class_mask = np.transpose(multi_class_mask, (2, 0, 1))

        if hv_map.ndim == 3 and hv_map.shape[-1] == 2 and hv_map.shape[0] != 2:
            hv_map = np.transpose(hv_map, (2, 0, 1))

        # Print min and max values
        print(f"Image min: {image.min():.4f}, max: {image.max():.4f}")
        print(f"Binary mask min: {binary_mask.min():.4f}, max: {binary_mask.max():.4f}")
        print(f"Multi-class mask min: {multi_class_mask.min():.4f}, max: {multi_class_mask.max():.4f}")
        print(f"HV map min: {hv_map.min():.4f}, max: {hv_map.max():.4f}")

        # Original Image
        axes[i, 0].imshow(image)
        axes[i, 0].set_title(f"{set_name} Sample {i+1}\nOriginal Image")
        axes[i, 0].axis('off')

        # Binary Segmentation
        axes[i, 1].imshow(binary_mask, cmap='gray')
        axes[i, 1].set_title("Binary Segmentation")
        axes[i, 1].axis('off')

        # Cell Classification: later classes overwrite earlier ones where they overlap.
        num_classes = multi_class_mask.shape[0]
        colors = class_cmap(np.linspace(0, 1, num_classes))
        cell_class_image = np.zeros((*multi_class_mask.shape[1:], 3))
        for class_idx in range(num_classes):
            cell_class_image[multi_class_mask[class_idx] > 0] = colors[class_idx][:3]

        axes[i, 2].imshow(cell_class_image)
        axes[i, 2].set_title("Cell Classification")
        axes[i, 2].axis('off')

        # HV Map - Horizontal component
        axes[i, 3].imshow(hv_map[0], cmap='coolwarm', vmin=-1, vmax=1)
        axes[i, 3].set_title("HV Map - Horizontal")
        axes[i, 3].axis('off')

        # HV Map - Vertical component
        axes[i, 4].imshow(hv_map[1], cmap='coolwarm', vmin=-1, vmax=1)
        axes[i, 4].set_title("HV Map - Vertical")
        axes[i, 4].axis('off')

        # Tissue Type and Global Cell Labels
        axes[i, 5].axis('off')
        tissue_type_idx = tissue_type_onehot.argmax().item()
        axes[i, 5].text(0.5, 0.9, f"Tissue Type: {tissue_type_idx}",
                        horizontalalignment='center', verticalalignment='center',
                        fontsize=10, fontweight='bold', transform=axes[i, 5].transAxes)

        axes[i, 5].text(0.5, 0.7, "Global Cell Labels:",
                        horizontalalignment='center', verticalalignment='center',
                        fontsize=10, fontweight='bold', transform=axes[i, 5].transAxes)

        for j, label in enumerate(global_cell_labels):
            axes[i, 5].text(0.5, 0.6 - j*0.1, f"Class {j}: {label.item():.0f}",
                            horizontalalignment='center', verticalalignment='center',
                            fontsize=10, transform=axes[i, 5].transAxes)

    plt.tight_layout()
    plt.show()

# Render two random samples from each of the train, validation and test datasets.
visualize_samples(dataset=train_dataset, set_name="Train", num_samples=2)
visualize_samples(dataset=val_dataset, set_name="Validation", num_samples=2)
visualize_samples(dataset=test_dataset, set_name="Test", num_samples=2)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy import ndimage

def find_boundaries(labeled_mask):
    """Return a boolean mask of pixels that touch a different, non-zero label.

    A pixel is a boundary pixel when at least one face-adjacent neighbour
    (4-connectivity in 2-D, the N-D analogue otherwise) carries a non-zero
    label different from its own. Label 0 is background. The result equals
    OR-ing ``binary_dilation(mask_l) & ~mask_l`` over all non-zero labels
    ``l``. It is computed with one shifted comparison per axis instead of one
    full-image dilation per label.

    Args:
        labeled_mask: Integer (or exactly comparable) label array of any rank.

    Returns:
        Boolean array with the same shape as ``labeled_mask``.
    """
    labeled_mask = np.asarray(labeled_mask)
    boundaries = np.zeros(labeled_mask.shape, dtype=bool)

    for axis in range(labeled_mask.ndim):
        size = labeled_mask.shape[axis]
        if size < 2:
            continue
        lower = [slice(None)] * labeled_mask.ndim
        upper = [slice(None)] * labeled_mask.ndim
        lower[axis] = slice(0, size - 1)
        upper[axis] = slice(1, size)
        lower, upper = tuple(lower), tuple(upper)

        first, second = labeled_mask[lower], labeled_mask[upper]
        differ = first != second
        # Each side of a differing pair is a boundary if the *other* side is a real cell.
        boundaries[lower] |= differ & (second != 0)
        boundaries[upper] |= differ & (first != 0)

    return boundaries

def visualize_samples(dataset, num_samples=2, set_name=""):
    """Plot random samples of ``dataset`` in a grid of 8 panels per row.

    Panels:
    - the image;
    - the binary mask;
    - the colour-coded cell-type mask;
    - the horizontal and vertical HV components;
    - the HV magnitude;
    - the class-region boundaries drawn over the image;
    - a text panel with the tissue-type index and global cell labels.

    Shapes, dtypes and value ranges are printed per sample. Samples that fail
    to load, or that are not 6-tuples, are reported and skipped.

    Args:
        dataset: Indexable dataset whose items are
            ``(image, binary_mask, multi_class_mask, hv_map, tissue_type_onehot,
            global_cell_labels)`` tensors.
        num_samples: Number of random samples (grid rows).
        set_name: Label used in titles and printed messages.
    """
    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 8, figsize=(40, 5*num_samples), squeeze=False)
    # pyplot.get_cmap: plt.cm.get_cmap was removed in Matplotlib 3.9.
    class_cmap = plt.get_cmap('tab10')

    for i in range(num_samples):
        idx = np.random.randint(len(dataset))

        # Get the sample (best effort: report and skip failures)
        try:
            sample = dataset[idx]
        except Exception as e:
            print(f"Error getting sample {idx} from {set_name} dataset: {str(e)}")
            continue

        # Unpack the sample
        if len(sample) == 6:
            image, binary_mask, multi_class_mask, hv_map, tissue_type_onehot, global_cell_labels = sample
        else:
            print(f"Unexpected number of items in sample: {len(sample)}")
            continue

        # Print shapes and dtypes
        print(f"\nSample {i + 1} from {set_name} dataset:")
        print(f"Image shape: {image.shape}, dtype: {image.dtype}")
        print(f"Binary mask shape: {binary_mask.shape}, dtype: {binary_mask.dtype}")
        print(f"Multi-class mask shape: {multi_class_mask.shape}, dtype: {multi_class_mask.dtype}")
        print(f"HV map shape: {hv_map.shape}, dtype: {hv_map.dtype}")
        print(f"Tissue type onehot shape: {tissue_type_onehot.shape}, dtype: {tissue_type_onehot.dtype}")
        print(f"Global cell labels shape: {global_cell_labels.shape}, dtype: {global_cell_labels.dtype}")

        num_cell_types = len(global_cell_labels)

        # Convert to numpy
        image = image.permute(1, 2, 0).numpy()
        binary_mask = binary_mask.squeeze().numpy()
        multi_class_mask = multi_class_mask.squeeze().numpy()
        hv_map = hv_map.squeeze().numpy()

        # Channels-last masks (e.g. (H, W, num_cell_types)) are moved to channels-first.
        if (multi_class_mask.ndim == 3 and multi_class_mask.shape[-1] == num_cell_types
                and multi_class_mask.shape[0] != num_cell_types):
            multi_class_mask = np.transpose(multi_class_mask, (2, 0, 1))

        if hv_map.ndim == 3 and hv_map.shape[-1] == 2 and hv_map.shape[0] != 2:
            hv_map = np.transpose(hv_map, (2, 0, 1))

        # Print min and max values
        print(f"Image min: {image.min():.4f}, max: {image.max():.4f}")
        print(f"Binary mask min: {binary_mask.min():.4f}, max: {binary_mask.max():.4f}")
        print(f"Multi-class mask min: {multi_class_mask.min():.4f}, max: {multi_class_mask.max():.4f}")
        print(f"HV map min: {hv_map.min():.4f}, max: {hv_map.max():.4f}")

        # Original Image
        axes[i, 0].imshow(image)
        axes[i, 0].set_title(f"{set_name} Sample {i+1}\nOriginal Image")
        axes[i, 0].axis('off')

        # Binary Segmentation
        axes[i, 1].imshow(binary_mask, cmap='gray')
        axes[i, 1].set_title("Binary Segmentation")
        axes[i, 1].axis('off')

        # Cell Classification: later classes overwrite earlier ones where they overlap.
        num_classes = multi_class_mask.shape[0]
        colors = class_cmap(np.linspace(0, 1, num_classes))
        cell_class_image = np.zeros((*multi_class_mask.shape[1:], 3))
        for class_idx in range(num_classes):
            cell_class_image[multi_class_mask[class_idx] > 0] = colors[class_idx][:3]

        axes[i, 2].imshow(cell_class_image)
        axes[i, 2].set_title("Cell Classification")
        axes[i, 2].axis('off')

        # HV Map - Horizontal component
        axes[i, 3].imshow(hv_map[0], cmap='coolwarm', vmin=-1, vmax=1)
        axes[i, 3].set_title("HV Map - Horizontal")
        axes[i, 3].axis('off')

        # HV Map - Vertical component
        axes[i, 4].imshow(hv_map[1], cmap='coolwarm', vmin=-1, vmax=1)
        axes[i, 4].set_title("HV Map - Vertical")
        axes[i, 4].axis('off')

        # Combined HV Map (per-pixel magnitude)
        combined_hv = np.sqrt(np.square(hv_map[0]) + np.square(hv_map[1]))
        axes[i, 5].imshow(combined_hv, cmap='viridis')
        axes[i, 5].set_title("Combined HV Map")
        axes[i, 5].axis('off')

        # Boundaries between cell-type regions (argmax class + 1; background -> 0)
        labeled_mask = np.argmax(multi_class_mask, axis=0) + 1
        labeled_mask[binary_mask == 0] = 0
        boundaries = find_boundaries(labeled_mask)

        # Mask out non-boundary pixels so only the boundaries are drawn and the
        # underlying image is not dimmed; explicit vmin/vmax keep them white.
        overlay = np.ma.masked_where(~boundaries, boundaries.astype(float))
        axes[i, 6].imshow(image)
        axes[i, 6].imshow(overlay, cmap='gray', vmin=0, vmax=1, alpha=0.8)
        axes[i, 6].set_title("Cell Boundaries")
        axes[i, 6].axis('off')

        # Tissue Type and Global Cell Labels
        axes[i, 7].axis('off')
        tissue_type_idx = tissue_type_onehot.argmax().item()
        axes[i, 7].text(0.5, 0.9, f"Tissue Type: {tissue_type_idx}",
                        horizontalalignment='center', verticalalignment='center',
                        fontsize=10, fontweight='bold', transform=axes[i, 7].transAxes)

        axes[i, 7].text(0.5, 0.7, "Global Cell Labels:",
                        horizontalalignment='center', verticalalignment='center',
                        fontsize=10, fontweight='bold', transform=axes[i, 7].transAxes)

        for j, label in enumerate(global_cell_labels):
            axes[i, 7].text(0.5, 0.6 - j*0.1, f"Class {j}: {label.item():.0f}",
                            horizontalalignment='center', verticalalignment='center',
                            fontsize=10, transform=axes[i, 7].transAxes)

    plt.tight_layout()
    plt.show()

# Render two random samples from each of the train, validation and test datasets.
visualize_samples(dataset=train_dataset, set_name="Train", num_samples=2)
visualize_samples(dataset=val_dataset, set_name="Validation", num_samples=2)
visualize_samples(dataset=test_dataset, set_name="Test", num_samples=2)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy import ndimage

def visualize_cell_boundaries(dataset, num_samples=1, set_name="Test"):
    """Show random samples side by side: the raw image and the image with region boundaries.

    Boundaries are computed by ``find_boundaries`` on a label map. In that map
    each foreground pixel gets its argmax cell-type index + 1 and background
    pixels get 0. The tissue type, present cell types, and shape/range
    information are printed for every sample.

    Args:
        dataset: Dataset exposing ``unique_tissue_types`` / ``unique_cell_types``
            whose items are ``(image, binary_mask, multi_class_mask, hv_map,
            tissue_type_onehot, global_cell_labels)`` tensors.
        num_samples: Number of random samples (grid rows).
        set_name: Label used in titles and printed messages.
    """
    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 2, figsize=(20, 10*num_samples), squeeze=False)
    num_cell_types = len(dataset.unique_cell_types)

    for i in range(num_samples):
        # Get a random sample
        idx = np.random.randint(len(dataset))
        image, binary_mask, multi_class_mask, hv_map, tissue_type_onehot, global_cell_labels = dataset[idx]

        # Convert tensors to numpy arrays
        image = image.permute(1, 2, 0).numpy()
        binary_mask = binary_mask.squeeze().numpy()
        multi_class_mask = multi_class_mask.numpy()

        # Bring the multi-class mask to (num_cell_types, H, W): drop a leading
        # singleton axis and move a channels-last layout to channels-first.
        if multi_class_mask.ndim == 4 and multi_class_mask.shape[0] == 1:
            multi_class_mask = multi_class_mask[0]
        if (multi_class_mask.ndim == 3 and multi_class_mask.shape[-1] == num_cell_types
                and multi_class_mask.shape[0] != num_cell_types):
            multi_class_mask = np.transpose(multi_class_mask, (2, 0, 1))

        # Label map: argmax cell type + 1 for foreground, 0 for background
        labeled_mask = np.argmax(multi_class_mask, axis=0) + 1
        labeled_mask[binary_mask == 0] = 0

        boundaries = find_boundaries(labeled_mask)

        # Safety net for unexpected extra axes
        if boundaries.ndim > 2:
            boundaries = np.max(boundaries, axis=2)

        # Display original image
        axes[i, 0].imshow(image)
        axes[i, 0].set_title(f"{set_name} Sample {i+1}: Original Image")
        axes[i, 0].axis('off')

        # Display image with boundaries; non-boundary pixels are masked so the
        # image is not dimmed, and vmin/vmax keep boundary pixels white.
        overlay = np.ma.masked_where(~boundaries, boundaries.astype(float))
        axes[i, 1].imshow(image)
        axes[i, 1].imshow(overlay, cmap='gray', vmin=0, vmax=1, alpha=0.8)
        axes[i, 1].set_title(f"{set_name} Sample {i+1}: Cell Boundaries")
        axes[i, 1].axis('off')

        # Print additional information
        tissue_type = dataset.unique_tissue_types[torch.argmax(tissue_type_onehot).item()]
        present_cell_types = [dataset.unique_cell_types[c] for c, present in enumerate(global_cell_labels) if present]

        print(f"\nSample {i + 1} from {set_name} dataset:")
        print(f"Tissue Type: {tissue_type}")
        print(f"Present Cell Types: {present_cell_types}")
        print(f"Image shape: {image.shape}, dtype: {image.dtype}")
        print(f"Boundaries shape: {boundaries.shape}, dtype: {boundaries.dtype}")
        print(f"Image min: {image.min():.4f}, max: {image.max():.4f}")
        print(f"Boundaries min: {boundaries.min():.4f}, max: {boundaries.max():.4f}")

    plt.tight_layout()
    plt.show()

def find_boundaries(labeled_mask):
    """Return a boolean mask of pixels that touch a different, non-zero label.

    A pixel is a boundary pixel when at least one face-adjacent neighbour
    (4-connectivity in 2-D, the N-D analogue otherwise) carries a non-zero
    label different from its own. Label 0 is background. The result equals
    OR-ing ``binary_dilation(mask_l) & ~mask_l`` over all non-zero labels
    ``l``. It is computed with one shifted comparison per axis instead of one
    full-image dilation per label.

    Args:
        labeled_mask: Integer (or exactly comparable) label array of any rank.

    Returns:
        Boolean array with the same shape as ``labeled_mask``.
    """
    labeled_mask = np.asarray(labeled_mask)
    boundaries = np.zeros(labeled_mask.shape, dtype=bool)

    for axis in range(labeled_mask.ndim):
        size = labeled_mask.shape[axis]
        if size < 2:
            continue
        lower = [slice(None)] * labeled_mask.ndim
        upper = [slice(None)] * labeled_mask.ndim
        lower[axis] = slice(0, size - 1)
        upper[axis] = slice(1, size)
        lower, upper = tuple(lower), tuple(upper)

        first, second = labeled_mask[lower], labeled_mask[upper]
        differ = first != second
        # Each side of a differing pair is a boundary if the *other* side is a real cell.
        boundaries[lower] |= differ & (second != 0)
        boundaries[upper] |= differ & (first != 0)

    return boundaries

# Render two random test-set samples with their boundaries overlaid.
visualize_cell_boundaries(dataset=test_dataset, num_samples=2)

def find_boundaries(labeled_mask):
    """Return a boolean map of pixels lying just outside each labelled region.

    A pixel is marked when one of its face-adjacent neighbours (the
    connectivity-1 structuring element that ``scipy.ndimage.binary_dilation``
    uses by default) carries a non-zero label different from the pixel's own
    label.  This equals OR-ing ``dilate(mask == L) & ~(mask == L)`` over every
    non-zero label ``L``, but needs only one shifted comparison per axis
    instead of one full-image dilation per label, and needs no SciPy.

    Args:
        labeled_mask: Integer label array of any rank; 0 is background.

    Returns:
        Boolean array with the same shape as ``labeled_mask``.
    """
    labels = np.asarray(labeled_mask)
    boundaries = np.zeros(labels.shape, dtype=bool)
    for axis in range(labels.ndim):
        if labels.shape[axis] < 2:
            continue  # no neighbours along a length-0/1 axis
        lead = [slice(None)] * labels.ndim
        trail = [slice(None)] * labels.ndim
        lead[axis] = slice(None, -1)   # pixel p
        trail[axis] = slice(1, None)   # its neighbour p + 1 along `axis`
        lead, trail = tuple(lead), tuple(trail)
        here, nxt = labels[lead], labels[trail]
        differs = here != nxt
        # p touches a foreground region (its forward neighbour) it is not part of
        boundaries[lead] |= differs & (nxt != 0)
        # p + 1 touches a foreground region (its backward neighbour) it is not part of
        boundaries[trail] |= differs & (here != 0)
    return boundaries

# Visualize samples from the test dataset.
# NOTE(review): this repeats the identical call a few lines above (after
# find_boundaries is re-defined with the same body), so the visualization is
# rendered a second time — likely a copy/paste leftover.
visualize_cell_boundaries(test_dataset, num_samples=2)

## Model Training

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import swin_t, Swin_T_Weights

class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Each channel is global-average-pooled to a scalar, the resulting channel
    vector goes through a bias-free bottleneck MLP
    (``channel -> channel // reduction -> channel``, ReLU then sigmoid), and
    the input is rescaled channel-wise by those weights in ``(0, 1)``.
    Expects a 4-D ``[B, C, H, W]`` input.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        # Attribute names and the Sequential layout define the state_dict
        # keys (fc.0.weight, fc.2.weight) and must stay stable for checkpoints.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)

class DecoderBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages followed by an SE gate.

    ``padding=1`` keeps the spatial size; channels go from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Submodule names are part of the checkpoint format -- keep them.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SEBlock(out_channels)

    def forward(self, x):
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            x = self.relu(norm(conv(x)))
        return self.se(x)

class SwinEncoder(nn.Module):
    """Torchvision Swin-T backbone that returns intermediate feature maps.

    Args:
        pretrained: Load torchvision's ``IMAGENET1K_V1`` weights when True.
        feature_indices: Indices into ``self.swin.features`` whose outputs are
            collected, returned in ascending layer order.  Defaults to
            ``(2, 4, 6)``, the previously hard-coded choice.

    ``forward`` returns a list with one tensor per collected index.  The
    decoder in this file permutes them with ``permute(0, 3, 1, 2)``, i.e. it
    treats them as channels-last ``[B, H, W, C]``.
    """

    def __init__(self, pretrained=True, feature_indices=(2, 4, 6)):
        super().__init__()
        weights = Swin_T_Weights.IMAGENET1K_V1 if pretrained else None
        self.swin = swin_t(weights=weights)
        self.swin.head = nn.Identity()  # the classifier head is never used here
        # Plain tuple attribute: not part of the state_dict, so existing
        # checkpoints load unchanged.
        self.feature_indices = tuple(sorted(set(feature_indices)))
        # NOTE(review): in torchvision's swin_t, features[2]/[4]/[6] appear to
        # be the patch-merging layers (taken *before* each stage's transformer
        # blocks); [3]/[5]/[7] should have the same channel counts and spatial
        # sizes.  Confirm which was intended before changing the default.

    def forward(self, x):
        features = []
        if not self.feature_indices:
            return features
        last = self.feature_indices[-1]
        for i, layer in enumerate(self.swin.features):
            x = layer(x)
            if i in self.feature_indices:
                features.append(x)
            if i == last:
                # Outputs of deeper layers are never returned; skip computing them.
                break
        return features

class AttentionBlock(nn.Module):
    """Additive attention gate (Attention U-Net style).

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels by a 1x1 conv + BatchNorm, summed, passed through ReLU,
    and mapped to a single-channel sigmoid map that rescales ``x``.  The two
    projections are added, so their shapes must be broadcast-compatible (the
    decoder in this file resizes the skip to ``g``'s spatial size first).
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def conv1x1_bn(in_ch, out_ch):
            # Conv is created before BN, matching the original init order.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
            ]

        # Names W_g / W_x / psi / relu are part of the saved state_dict.
        self.W_g = nn.Sequential(*conv1x1_bn(F_g, F_int))
        self.W_x = nn.Sequential(*conv1x1_bn(F_l, F_int))
        self.psi = nn.Sequential(*conv1x1_bn(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention_map = self.psi(combined)
        return x * attention_map

class MultiHeadDecoder(nn.Module):
    """Attention-gated decoder with five prediction heads.

    Consumes the channels-last feature list produced by ``SwinEncoder``
    (channel counts fixed by the layer sizes below: ``features[-1]`` 768,
    ``features[-2]`` 384, ``features[-3]`` and ``features[0]`` 192 -- with a
    three-element list these last two are the same tensor) and returns, in
    order:

    * ``cell_seg_out``    -- 1-channel binary cell-segmentation logits,
    * ``cell_class_out``  -- ``num_cell_classes`` per-pixel class logits,
    * ``tc_out``          -- ``num_tissue_classes`` tissue logits (avg-pooled),
    * ``global_cell_out`` -- ``num_cell_classes`` image-level logits (max-pooled),
    * ``hv_out``          -- 2-channel horizontal/vertical map regression.

    Dense outputs are upsampled by 2 five times relative to ``features[-1]``
    (x32 overall).
    """

    def __init__(self, num_cell_classes, num_tissue_classes):
        super().__init__()
        # Decoder trunk; each input width is "previous output + gated skip".
        self.decoder1 = DecoderBlock(768, 384)
        self.decoder2 = DecoderBlock(384 + 384, 192)
        self.decoder3 = DecoderBlock(192 + 192, 96)
        self.decoder4 = DecoderBlock(96 + 192, 48)

        # Dense 1x1 heads on the 48-channel decoder output.
        self.cell_seg_conv = nn.Conv2d(48, 1, kernel_size=1)
        self.cell_class_conv = nn.Conv2d(48, num_cell_classes, kernel_size=1)

        # Tissue classification from average-pooled deepest features.
        self.tc_pool = nn.AdaptiveAvgPool2d(1)
        self.tc_fc = nn.Sequential(
            nn.Linear(768, 512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_tissue_classes),
        )

        # Image-level cell-type logits from max-pooled deepest features.
        self.global_cell_pool = nn.AdaptiveMaxPool2d(1)
        self.global_cell_fc = nn.Sequential(
            nn.Linear(768, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_cell_classes),
        )

        # Two channels: horizontal and vertical maps.
        self.hv_conv = nn.Conv2d(48, 2, kernel_size=1)

        # Attention gates for the three skip connections.
        self.attention1 = AttentionBlock(F_g=384, F_l=384, F_int=192)
        self.attention2 = AttentionBlock(F_g=192, F_l=192, F_int=96)
        self.attention3 = AttentionBlock(F_g=96, F_l=192, F_int=48)

    @staticmethod
    def _to_nchw(t):
        # Encoder features arrive as [B, H, W, C]; convs want [B, C, H, W].
        return t.permute(0, 3, 1, 2)

    @staticmethod
    def _up2(t):
        return F.interpolate(t, scale_factor=2, mode='bilinear', align_corners=False)

    def _fuse(self, x, skip, gate):
        """Resize ``skip`` to ``x``'s grid, attention-gate it by ``x``, concat."""
        skip = F.interpolate(self._to_nchw(skip), size=x.shape[2:], mode='bilinear', align_corners=False)
        return torch.cat([x, gate(x, skip)], dim=1)

    def forward(self, features):
        deepest = self._to_nchw(features[-1])

        x = self._up2(self.decoder1(deepest))
        x = self._up2(self.decoder2(self._fuse(x, features[-2], self.attention1)))
        x = self._up2(self.decoder3(self._fuse(x, features[-3], self.attention2)))
        x = self._up2(self.decoder4(self._fuse(x, features[0], self.attention3)))

        cell_seg_out = self._up2(self.cell_seg_conv(x))
        cell_class_out = self._up2(self.cell_class_conv(x))

        tc_out = self.tc_fc(self.tc_pool(deepest).flatten(1))
        global_cell_out = self.global_cell_fc(self.global_cell_pool(deepest).flatten(1))

        hv_out = self._up2(self.hv_conv(x))

        return cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out

class ModifiedCellSwin(nn.Module):
    """Swin-T encoder + multi-head decoder for cell segmentation and typing.

    Args:
        num_cell_classes: Channels of the per-pixel and image-level cell heads.
        num_tissue_classes: Outputs of the tissue-classification head.
        seg_threshold: Sigmoid probability above which a pixel counts as
            "cell" when masking the per-pixel class logits.
        pretrained: Forwarded to ``SwinEncoder``.  Defaults to True as before;
            pass False to skip downloading ImageNet weights when a checkpoint
            is loaded immediately afterwards.

    ``forward`` returns ``(cell_seg_out, cell_class_out_masked, tc_out,
    global_cell_out, hv_out)``.
    """

    def __init__(self, num_cell_classes, num_tissue_classes, seg_threshold=0.5, pretrained=True):
        super().__init__()
        self.encoder = SwinEncoder(pretrained=pretrained)
        self.decoder = MultiHeadDecoder(num_cell_classes, num_tissue_classes)
        self.seg_threshold = seg_threshold

    def forward(self, x):
        features = self.encoder(x)
        cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out = self.decoder(features)

        # Hard (thresholded, non-differentiable) mask: class logits are set to
        # 0 wherever the segmentation head predicts background, so no gradient
        # reaches the class head at those pixels.
        cell_seg_mask = (torch.sigmoid(cell_seg_out) > self.seg_threshold).float()
        cell_class_out_masked = cell_class_out * cell_seg_mask

        return cell_seg_out, cell_class_out_masked, tc_out, global_cell_out, hv_out

# Initialize the model
num_cell_classes = 5  # 5 cell types
num_tissue_classes = len(train_dataset.unique_tissue_types)  # Should be 19 based on your data
model = ModifiedCellSwin(num_cell_classes, num_tissue_classes).float()
print("Modified CellSwin model with binary segmentation, classification, HV maps, and global cell classification defined.")

# Print model structure
print(model)

# Verify output shapes.  The smoke test runs in eval mode without autograd so
# the random input neither builds a graph nor updates BatchNorm running
# statistics; the model is put back into training mode afterwards.
sample_input = torch.randn(1, 3, 256, 256)  # Batch size 1, 3 channels, 256x256 image
model.eval()
with torch.no_grad():
    cell_seg_out, cell_class_out_masked, tc_out, global_cell_out, hv_out = model(sample_input)
model.train()
print(f"cell_seg_out shape: {cell_seg_out.shape}")
print(f"cell_class_out_masked shape: {cell_class_out_masked.shape}")
print(f"tc_out shape: {tc_out.shape}")
print(f"global_cell_out shape: {global_cell_out.shape}")
print(f"hv_out shape: {hv_out.shape}")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss on sigmoid probabilities, pooled over all elements.

    ``Tversky = (TP + 1e-5) / (TP + alpha * FP + beta * FN + 1e-5)`` and the
    loss is ``(1 - Tversky) ** gamma``.  In this implementation ``alpha``
    weights false positives and ``beta`` false negatives.

    NOTE(review): the Focal Tversky paper (Abraham & Khan, 2019) appears to
    weight false negatives with alpha and to use the exponent ``1 / gamma``;
    confirm the variant used here is intended.
    """

    def __init__(self, alpha=0.7, beta=0.3, gamma=4/3):
        super(FocalTverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the scalar loss.

        Args:
            inputs: Logits of any shape.
            targets: Same number of elements as ``inputs`` with values in
                {0, 1}; float, int or bool tensors are accepted.
        """
        probs = torch.sigmoid(inputs).reshape(-1)
        # reshape (not view) also works for non-contiguous tensors.
        targets = targets.reshape(-1)
        if not targets.is_floating_point():
            # `1 - targets` is not defined for bool tensors.
            targets = targets.to(probs.dtype)

        TP = (probs * targets).sum()
        FP = ((1 - targets) * probs).sum()
        FN = (targets * (1 - probs)).sum()

        Tversky = (TP + 1e-5) / (TP + self.alpha * FP + self.beta * FN + 1e-5)
        FocalTversky = (1 - Tversky) ** self.gamma

        return FocalTversky

class MultiTaskLoss(nn.Module):
    """Unweighted sum of the per-head losses used to train ``ModifiedCellSwin``.

    Components:
      * cell_seg_loss     -- FocalTverskyLoss on the binary segmentation logits
      * cell_class_loss   -- pixel-wise CrossEntropy vs. argmax of the one-hot
                             class target
      * tissue_class_loss -- CrossEntropy vs. argmax of the one-hot tissue target
      * global_cell_loss  -- BCEWithLogits on the image-level cell labels
      * hv_loss           -- MSE between predicted and target HV maps
      * hv_grad_loss      -- MSE between forward-difference gradients of both

    ``num_classes`` and ``num_tissue_types`` are accepted for API
    compatibility but are not used.
    """

    def __init__(self, num_classes, num_tissue_types):
        super(MultiTaskLoss, self).__init__()
        self.focal_tversky_loss = FocalTverskyLoss()
        # NOTE(review): class targets come from argmax(), which never yields
        # -100, so ignore_index has no effect; pixels whose one-hot row is all
        # zeros are scored as class 0.  Confirm that is intended.
        self.cell_class_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.tissue_class_loss = nn.CrossEntropyLoss()
        self.global_cell_loss = nn.BCEWithLogitsLoss()
        self.hv_loss = nn.MSELoss()

    def forward(self, cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out,
                cell_seg_target, cell_class_target, tissue_type_target, global_cell_target, hv_target):
        """Return ``(total_loss, {component_name: float})``."""
        # 5-D dense targets are read as [B, 1, H, W, C] (per the squeeze and
        # permute below) and brought to the model's [B, C, H, W] layout.
        if cell_class_target.dim() == 5:
            cell_class_target = cell_class_target.squeeze(1).permute(0, 3, 1, 2)
        if hv_target.dim() == 5:
            hv_target = hv_target.squeeze(1).permute(0, 3, 1, 2)

        # Cell segmentation loss (Focal Tversky)
        cell_seg_loss = self.focal_tversky_loss(cell_seg_out, cell_seg_target)

        # Cell classification loss: flatten pixels to rows of class scores.
        class_logits = cell_class_out.permute(0, 2, 3, 1).reshape(-1, cell_class_out.size(1))
        class_targets = cell_class_target.permute(0, 2, 3, 1).reshape(-1, cell_class_target.size(1))
        cell_class_loss = self.cell_class_loss(class_logits, class_targets.argmax(dim=1))

        # Tissue classification loss
        tissue_class_loss = self.tissue_class_loss(tc_out, tissue_type_target.argmax(dim=1))

        # Global cell classification loss; BCEWithLogitsLoss requires a
        # floating-point target, so bool/int label tensors are cast.
        if not global_cell_target.is_floating_point():
            global_cell_target = global_cell_target.to(global_cell_out.dtype)
        global_cell_loss = self.global_cell_loss(global_cell_out, global_cell_target)

        # HV branch: horizontal and vertical distance map loss
        hv_loss = self.hv_loss(hv_out, hv_target)

        # Same MSE on the spatial gradients of the HV maps
        hv_out_grad = self.compute_gradients(hv_out)
        hv_target_grad = self.compute_gradients(hv_target)
        hv_grad_loss = self.hv_loss(hv_out_grad, hv_target_grad)

        total_loss = cell_seg_loss + cell_class_loss + tissue_class_loss + global_cell_loss + hv_loss + hv_grad_loss

        return total_loss, {
            'cell_seg_loss': cell_seg_loss.item(),
            'cell_class_loss': cell_class_loss.item(),
            'tissue_class_loss': tissue_class_loss.item(),
            'global_cell_loss': global_cell_loss.item(),
            'hv_loss': hv_loss.item(),
            'hv_grad_loss': hv_grad_loss.item()
        }

    def compute_gradients(self, tensor):
        """Forward differences along width (dx) and height (dy) of a 4-D tensor.

        Each difference is replicate-padded back to the input size; the two
        are concatenated on dim 1 (so the output has twice the channels).
        """
        dx = tensor[:, :, :, 1:] - tensor[:, :, :, :-1]
        dy = tensor[:, :, 1:, :] - tensor[:, :, :-1, :]

        dx = F.pad(dx, (0, 1, 0, 0), mode='replicate')
        dy = F.pad(dy, (0, 0, 0, 1), mode='replicate')

        return torch.cat([dx, dy], dim=1)

# Usage example.  Note: train() below constructs its own MultiTaskLoss, so this
# module-level `criterion` is not the one used by the training loop.
criterion = MultiTaskLoss(num_classes=5, num_tissue_types=19)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

def _batch_to_device(batch, device):
    """Move every tensor of a dataloader batch to ``device``, preserving order."""
    return tuple(tensor.to(device) for tensor in batch)


def train(model, train_loader, val_loader, num_epochs=10, batch_size=16, learning_rate=1e-4,
          save_path="improved_cellswin_model.pth"):
    """Train ``model`` with ``MultiTaskLoss`` and save its final ``state_dict``.

    Each batch must unpack to ``(images, binary_masks, multi_class_masks,
    hv_maps, tissue_types, global_cell_labels)``.  Uses Adam, gradient-norm
    clipping at 1.0, and ReduceLROnPlateau on the mean validation loss.

    Args:
        model: Network whose parameters already live on the target device.
        train_loader, val_loader: Non-empty iterables with ``len()``.
        num_epochs: Number of epochs.
        batch_size: Unused; batching is decided by the loaders.  Kept for
            backward compatibility.
        learning_rate: Initial Adam learning rate.
        save_path: File the final ``state_dict`` is written to.

    Raises:
        ValueError: If either loader has length 0.  This is checked up front
            instead of failing with ZeroDivisionError after an epoch.
        RuntimeError: Re-raised from a failing training batch after printing it.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    device = next(model.parameters()).device

    criterion = MultiTaskLoss(
        num_classes=5,
        num_tissue_types=19
    )
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases (the
    # notebook silences warnings); consider logging LR changes manually.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    loss_names = ('cell_seg_loss', 'cell_class_loss', 'tissue_class_loss',
                  'global_cell_loss', 'hv_loss', 'hv_grad_loss')

    for epoch in range(num_epochs):
        # ---- training ----
        model.train()
        total_loss = 0.0
        loss_dict_sum = dict.fromkeys(loss_names, 0)

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            (images, binary_masks, multi_class_masks, hv_maps,
             tissue_types, global_cell_labels) = _batch_to_device(batch, device)

            optimizer.zero_grad()

            try:
                cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out = model(images)

                loss, loss_dict = criterion(cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out,
                                            binary_masks, multi_class_masks, tissue_types, global_cell_labels, hv_maps)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                total_loss += loss.item()
                for k, v in loss_dict.items():
                    loss_dict_sum[k] += v

            except RuntimeError as e:
                print(f"Error in training batch: {e}")
                raise  # bare raise keeps the original traceback untouched

        avg_train_loss = total_loss / len(train_loader)
        avg_loss_dict = {k: v / len(train_loader) for k, v in loss_dict_sum.items()}

        # ---- validation ----
        model.eval()
        val_loss = 0.0
        val_loss_dict_sum = dict.fromkeys(loss_names, 0)
        with torch.no_grad():
            for batch in val_loader:
                (images, binary_masks, multi_class_masks, hv_maps,
                 tissue_types, global_cell_labels) = _batch_to_device(batch, device)

                cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out = model(images)

                loss, loss_dict = criterion(cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out,
                                            binary_masks, multi_class_masks, tissue_types, global_cell_labels, hv_maps)

                val_loss += loss.item()
                for k, v in loss_dict.items():
                    val_loss_dict_sum[k] += v

        avg_val_loss = val_loss / len(val_loader)
        avg_val_loss_dict = {k: v / len(val_loader) for k, v in val_loss_dict_sum.items()}

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Training Loss: {avg_train_loss:.4f}")
        for k, v in avg_loss_dict.items():
            print(f"  {k}: {v:.4f}")
        print(f"Validation Loss: {avg_val_loss:.4f}")
        for k, v in avg_val_loss_dict.items():
            print(f"  {k}: {v:.4f}")

        scheduler.step(avg_val_loss)

    print("\nTraining completed!")
    torch.save(model.state_dict(), save_path)
    print("Model saved successfully!")

# Usage: choose the GPU when one is visible, otherwise the CPU, and build a
# fresh model there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
model = ModifiedCellSwin(num_cell_classes=5, num_tissue_classes=19)
model = model.to(device)

# train(model, train_loader, val_loader, num_epochs=60, batch_size=16, learning_rate=1e-4)

In [None]:
# Continue training.  Warnings are silenced only for the duration of this
# call, not process-wide, so later cells (e.g. evaluation) still surface them.
import warnings

epochs = 100
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    train(model, train_loader, val_loader, num_epochs=epochs, batch_size=16, learning_rate=1e-4)

## Evaluations

In [None]:
import torch

# Select the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Rebuild the architecture on the device, then restore the weights that
# train() saved under this file name.
model = ModifiedCellSwin(num_cell_classes=5, num_tissue_classes=19)
model = model.to(device)
model.load_state_dict(torch.load("improved_cellswin_model.pth", map_location=device))

# Inference mode: BatchNorm uses running statistics and Dropout is disabled.
model.eval()

print("Model loaded successfully!")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from scipy import ndimage as ndi
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_score, recall_score

# Assuming you have already defined and loaded your ModifiedCellSwin model
def postprocess_hv(hv_map, cell_prob, prob_thresh=0.5, hv_thresh=0.2):
    """Turn HV maps and a cell-probability map into an instance label image.

    Foreground is ``cell_prob > prob_thresh`` minus pixels where the HV
    gradient magnitude exceeds ``hv_thresh``.  Local maxima of ``cell_prob``
    (``min_distance=3``) in that mask seed a watershed on ``-cell_prob``.

    Args:
        hv_map: Tensor ``[2, H, W]``; channel 0 is presumably the horizontal
            map and channel 1 the vertical map.
        cell_prob: Sigmoid probabilities, ``[H, W]`` or ``[1, H, W]`` tensor.
        prob_thresh: Foreground threshold and minimum peak height.
        hv_thresh: HV-gradient magnitude above which pixels are treated as
            borders between instances.

    Returns:
        Integer label array (0 = background, 1..N = instances), squeezed.
    """
    hv_map = hv_map.cpu().numpy()
    cell_prob = cell_prob.cpu().numpy()
    # Work on a 2-D map.  With a leading length-1 axis, peak_local_max's
    # default border exclusion (min_distance pixels on every axis) wipes out
    # that whole axis, so no peak is ever found and the result is blank.
    if cell_prob.ndim == 3 and cell_prob.shape[0] == 1:
        cell_prob = cell_prob[0]

    # Absolute finite differences of the HV maps.
    # NOTE(review): grad_h differentiates channel 0 along rows and grad_v
    # channel 1 along columns; HoVer-Net-style post-processing usually
    # differentiates the horizontal map along x (columns).  Confirm against
    # how the HV targets are generated.
    grad_h = np.abs(hv_map[0, 1:, :] - hv_map[0, :-1, :])
    grad_v = np.abs(hv_map[1, :, 1:] - hv_map[1, :, :-1])

    # Pad gradients back to the original size
    grad_h = np.pad(grad_h, ((1, 0), (0, 0)), mode='constant')
    grad_v = np.pad(grad_v, ((0, 0), (1, 0)), mode='constant')

    grad_combined = np.maximum(grad_h, grad_v)

    # Foreground that is not on a strong HV edge
    cell_mask = cell_prob > prob_thresh
    grad_mask = grad_combined > hv_thresh
    mask = cell_mask & (~grad_mask)

    # Find peaks (cell centers)
    peaks = peak_local_max(cell_prob, min_distance=3, threshold_abs=prob_thresh, labels=mask)

    # If no peaks are found, return a blank segmentation
    if len(peaks) == 0:
        return np.zeros_like(cell_prob, dtype=np.int32).squeeze()

    # One marker label per peak
    markers = np.zeros_like(cell_prob, dtype=np.int32)
    markers[tuple(peaks.T)] = np.arange(1, len(peaks) + 1)

    # Flood from the peaks downhill in probability, restricted to the mask
    segmentation = watershed(-cell_prob, markers, mask=mask)

    return segmentation.squeeze()

def predict_and_postprocess(model, image):
    """Run ``model`` on one image (no batch axis) and post-process the result.

    Returns ``(segmentation, cell_class_out, tc_out, global_cell_out)``: the
    label image from ``postprocess_hv`` plus the three remaining heads with
    their batch axis of size 1 squeezed away.
    """
    batch = image.unsqueeze(0)
    with torch.no_grad():
        seg_logits, class_logits, tissue_logits, global_logits, hv_pred = model(batch)

    probabilities = seg_logits.sigmoid()
    instances = postprocess_hv(hv_pred[0], probabilities[0])

    return instances, class_logits.squeeze(0), tissue_logits.squeeze(0), global_logits.squeeze(0)

def calculate_metrics(pred_mask, true_mask, average='weighted'):
    """Pixel-wise F1, precision and recall of foreground vs. background.

    Both masks are binarised (``> 0``) before scoring.  In this notebook
    ``pred_mask`` is an instance label image (0 = background, 1..N =
    instances) while ``true_mask`` is a binary mask.  Comparing raw labels
    would count every pixel of instances 2, 3, ... as misclassified.

    Args:
        pred_mask: Predicted label or binary mask (array-like).
        true_mask: Ground-truth mask with the same number of elements.
        average: Forwarded to scikit-learn; defaults to the previous
            ``'weighted'``.

    Returns:
        ``(f1, precision, recall)``; ``zero_division=1`` as before.
    """
    pred_flat = (np.asarray(pred_mask) > 0).ravel()
    true_flat = (np.asarray(true_mask) > 0).ravel()

    f1 = f1_score(true_flat, pred_flat, average=average, zero_division=1)
    precision = precision_score(true_flat, pred_flat, average=average, zero_division=1)
    recall = recall_score(true_flat, pred_flat, average=average, zero_division=1)

    return f1, precision, recall

def evaluate_model(model, test_loader, device):
    """Return mean ``(f1, precision, recall)`` over every image in ``test_loader``.

    Images are predicted one at a time through ``predict_and_postprocess`` and
    scored against the matching binary mask with ``calculate_metrics``; every
    image contributes equally to the mean.  An empty loader raises
    ``ZeroDivisionError``.
    """
    model.eval()
    running = [0, 0, 0]  # summed f1, precision, recall
    seen = 0

    with torch.no_grad():
        for images, binary_masks, _, _, _, _ in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)

            for idx, single_image in enumerate(images):
                segmentation, _, _, _ = predict_and_postprocess(model, single_image)
                # Ground truth for this image, singleton axes removed.
                true_mask = binary_masks[idx].squeeze().cpu().numpy()

                f1, precision, recall = calculate_metrics(segmentation, true_mask)
                running = [total + score for total, score in zip(running, (f1, precision, recall))]
                seen += 1

    avg_f1, avg_precision, avg_recall = (total / seen for total in running)
    return avg_f1, avg_precision, avg_recall

# Visualization function
def visualize_results(image, true_mask, pred_mask, save_path=None):
    """Draw the image, ground-truth mask and predicted mask side by side.

    ``image`` is a ``[C, H, W]`` tensor.  The figure is written to
    ``save_path`` when one is given and shown otherwise; it is closed
    afterwards in both cases.
    """
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Label images are displayed as int32.
    if pred_mask.dtype != np.int32:
        pred_mask = pred_mask.astype(np.int32)

    panels = (
        (image.permute(1, 2, 0).cpu().numpy(), None, 'Original Image'),
        (true_mask, 'nipy_spectral', 'True Mask'),
        (pred_mask, 'nipy_spectral', 'Predicted Mask'),
    )
    for ax, (data, cmap, title) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()

    plt.close(fig)  # release the figure's memory

# Main execution
if __name__ == "__main__":
    # Set the device (GPU when available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the model (assuming you have this defined).
    # NOTE(review): ModifiedCellSwin builds SwinEncoder() with its default
    # pretrained=True, so ImageNet weights are fetched here even though
    # load_state_dict below overwrites them.
    model = ModifiedCellSwin(num_cell_classes=5, num_tissue_classes=19).to(device)

    # Load the saved state dict (the file train() writes at the end)
    model.load_state_dict(torch.load("improved_cellswin_model.pth", map_location=device))

    # Set the model to evaluation mode
    model.eval()
    print("Model loaded successfully!")

    # Run evaluation: per-image metrics averaged over the whole test set
    avg_f1, avg_precision, avg_recall = evaluate_model(model, test_loader, device)

    print(f"Average F1 Score: {avg_f1:.4f}")
    print(f"Average Precision: {avg_precision:.4f}")
    print(f"Average Recall: {avg_recall:.4f}")

    # Visualize some results (optional): the first image of each of the first
    # five test batches, saved to result_<i>.png
    for i, (images, binary_masks, _, _, _, _) in enumerate(test_loader):
        if i >= 5:  # Visualize first 5 images
            break
        image = images[0].to(device)
        true_mask = binary_masks[0].squeeze().cpu().numpy()
        pred_mask, _, _, _ = predict_and_postprocess(model, image)

        print(f"Pred mask shape: {pred_mask.shape}")
        print(f"Unique values in pred mask: {np.unique(pred_mask)}")

        visualize_results(image.cpu(), true_mask, pred_mask, f'result_{i}.png')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy import ndimage as ndi
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2

from scipy import ndimage as ndi
from skimage.segmentation import watershed
from skimage.feature import peak_local_max

def postprocess_hv(hv_map, cell_prob, prob_thresh=0.5, hv_thresh=0.1):
    """Split thresholded cell probabilities into instances with a watershed.

    Seeds are local maxima (3x3 footprint) of the Euclidean distance transform
    of ``cell_prob > prob_thresh``.  Adjacent maxima are merged into a single
    marker.  The watershed floods ``1 - cell_prob`` inside that mask.

    ``hv_map`` and ``hv_thresh`` are accepted so the signature matches the
    HV-based variant, but no HV information is used here.  The HV-gradient
    computation this function used to perform was never read, so it has been
    dropped.

    Args:
        hv_map: Unused.
        cell_prob: Tensor of sigmoid probabilities; squeezed to 2-D.
        prob_thresh: Foreground threshold.
        hv_thresh: Unused.

    Returns:
        Integer label image (0 = background, 1..N = instances).
    """
    cell_prob = cell_prob.cpu().numpy().squeeze()

    # Binary foreground mask from the cell probability
    cell_mask = cell_prob > prob_thresh

    # Seeds: local maxima of the distance to the background
    distance = ndi.distance_transform_edt(cell_mask)
    coordinates = peak_local_max(distance, footprint=np.ones((3, 3)), labels=cell_mask)
    local_maxi = np.zeros_like(cell_mask, dtype=bool)
    local_maxi[tuple(coordinates.T)] = True
    markers, _ = ndi.label(local_maxi)

    # Flood the inverted probability (high probability = low basin)
    segmentation = watershed(1 - cell_prob, markers, mask=cell_mask)

    return segmentation

def predict_and_postprocess(model, image):
    """Run the model on one image, print output diagnostics, post-process.

    Returns ``(segmentation, cell_class_out, tc_out, global_cell_out,
    cell_prob, hv_out)``.  Tensors have their batch axis of size 1 removed,
    and ``segmentation`` is the label image from ``postprocess_hv``.
    """
    batched = image.unsqueeze(0)
    with torch.no_grad():
        outputs = model(batched)
    seg_logits, class_logits, tissue_logits, global_logits, hv_pred = outputs

    # Sanity check of raw output shapes and value ranges
    print(f"Cell seg out shape: {seg_logits.shape}")
    print(f"Cell seg out min/max: {seg_logits.min().item():.4f}/{seg_logits.max().item():.4f}")
    print(f"HV out shape: {hv_pred.shape}")
    print(f"HV out min/max: {hv_pred.min().item():.4f}/{hv_pred.max().item():.4f}")

    probabilities = seg_logits.sigmoid()
    prob_single = probabilities[0]
    hv_single = hv_pred[0]
    segmentation = postprocess_hv(hv_single, prob_single)

    return (segmentation, class_logits.squeeze(0), tissue_logits.squeeze(0),
            global_logits.squeeze(0), prob_single, hv_single)

def visualize_results(image, true_mask, segmentation, hv_map, cell_prob, cell_class_out, multi_class_mask):
    """Show a 2x4 grid comparing ground truth with model predictions.

    Top row: image, GT binary mask, predicted instances, cell probability.
    Bottom row: horizontal and vertical maps, GT cell classes, and predicted
    cell classes (drawn only where ``segmentation > 0``).

    Args:
        image: ``[C, H, W]`` tensor.
        true_mask: 2-D array.
        segmentation: 2-D instance label array (0 = background).
        hv_map: Array whose ``[0]`` / ``[1]`` are the horizontal / vertical maps.
        cell_prob: Tensor squeezable to ``[H, W]``.
        cell_class_out: ``[num_classes, H, W]`` tensor of class scores.
        multi_class_mask: ``[H, W, num_classes]`` one-hot array, or a 2-D
            array that is shown as-is.
    """
    fig, axs = plt.subplots(2, 4, figsize=(20, 10))

    # One palette shared by the GT and predicted class panels.  It is built up
    # front so it exists for 2-D masks too, and it has an entry for every
    # predicted class index.  plt.get_cmap is used because plt.cm.get_cmap
    # (matplotlib.cm.get_cmap) was deprecated in Matplotlib 3.7 and removed
    # in 3.9.
    num_pred_classes = cell_class_out.shape[0]
    num_gt_classes = multi_class_mask.shape[-1] if multi_class_mask.ndim == 3 else 0
    colors = plt.get_cmap('tab10')(np.linspace(0, 1, max(num_gt_classes, num_pred_classes)))

    # Original Image
    axs[0, 0].imshow(image.permute(1, 2, 0).cpu().numpy())
    axs[0, 0].set_title('Original Image')
    axs[0, 0].axis('off')

    # Ground Truth Binary Segmentation
    axs[0, 1].imshow(true_mask, cmap='gray')
    axs[0, 1].set_title('GT Binary Segmentation')
    axs[0, 1].axis('off')

    # Predicted Instance Segmentation
    axs[0, 2].imshow(segmentation, cmap='nipy_spectral')
    axs[0, 2].set_title('Predicted Instance Segmentation')
    axs[0, 2].axis('off')

    # Cell Probability
    axs[0, 3].imshow(cell_prob.squeeze().cpu().numpy(), cmap='viridis')
    axs[0, 3].set_title('Cell Probability')
    axs[0, 3].axis('off')

    # HV Map
    axs[1, 0].imshow(hv_map[0], cmap='coolwarm')
    axs[1, 0].set_title('Horizontal Map')
    axs[1, 0].axis('off')

    axs[1, 1].imshow(hv_map[1], cmap='coolwarm')
    axs[1, 1].set_title('Vertical Map')
    axs[1, 1].axis('off')

    # Ground Truth Cell Classification: colour pixels by their set channel(s)
    if multi_class_mask.ndim == 3:
        gt_cell_class_image = np.zeros((*multi_class_mask.shape[:2], 3))
        for class_idx in range(num_gt_classes):
            class_mask = multi_class_mask[..., class_idx] > 0
            gt_cell_class_image[class_mask] = colors[class_idx][:3]
    else:
        gt_cell_class_image = multi_class_mask

    axs[1, 2].imshow(gt_cell_class_image)
    axs[1, 2].set_title('Ground Truth\nCell Classification')
    axs[1, 2].axis('off')

    # Predicted Cell Classification: argmax class inside predicted instances
    pred_cell_class_image = np.zeros((*cell_class_out.shape[1:], 3))
    cell_class_out_argmax = np.argmax(cell_class_out.cpu().numpy(), axis=0)

    for class_idx in range(num_pred_classes):
        class_mask = (cell_class_out_argmax == class_idx) & (segmentation > 0)
        pred_cell_class_image[class_mask] = colors[class_idx][:3]

    pred_cell_class_image[segmentation == 0] = [0, 0, 0]

    axs[1, 3].imshow(pred_cell_class_image)
    axs[1, 3].set_title('Predicted\nCell Classification')
    axs[1, 3].axis('off')

    plt.tight_layout()
    plt.show()

# Main execution
if __name__ == "__main__":
    # Set the device (GPU when available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the model (assuming you have this defined)
    model = ModifiedCellSwin(num_cell_classes=5, num_tissue_classes=19).to(device)

    # Load the saved state dict (the file train() writes at the end)
    model.load_state_dict(torch.load("improved_cellswin_model.pth", map_location=device))

    # Set the model to evaluation mode
    model.eval()
    print("Model loaded successfully!")

    # Process and visualize 5 samples: the first image of each of the first
    # five test batches.
    for i, (images, binary_masks, multi_class_masks, hv_maps, tissue_types, global_cell_labels) in enumerate(test_loader):
        if i >= 5:
            break

        image = images[0].to(device)
        true_mask = binary_masks[0].squeeze().cpu().numpy()
        # NOTE(review): true_hv_map is computed but never used below.
        true_hv_map = hv_maps[0].squeeze().cpu().numpy()
        multi_class_mask = multi_class_masks[0].squeeze().cpu().numpy()

        # 6-output predict_and_postprocess defined in this cell (it also
        # prints raw output shapes and ranges).
        segmentation, cell_class_out, tc_out, global_cell_out, cell_prob, hv_out = predict_and_postprocess(model, image)

        print(f"\nSample {i+1}:")
        print(f"True mask shape: {true_mask.shape}")
        print(f"Segmentation shape: {segmentation.shape}")
        print(f"Unique values in true mask: {np.unique(true_mask)}")
        print(f"Unique values in segmentation: {np.unique(segmentation)}")

        visualize_results(image.cpu(), true_mask, segmentation, hv_out.cpu().numpy(), cell_prob, cell_class_out, multi_class_mask)

        # Print additional information
        print(f"Binary Mask - Unique values: GT {np.unique(true_mask)}, Pred {np.unique(segmentation > 0)}")
        print(f"Multi-class Mask - Unique values: GT {np.unique(np.argmax(multi_class_mask, axis=-1) if multi_class_mask.ndim == 3 else multi_class_mask)}, Pred {np.unique(cell_class_out.argmax(dim=0).cpu().numpy()[segmentation > 0])}")

        # Pixel counts per class (not cell counts).  GT: argmax over the
        # one-hot channel axis, so an all-zero (background) row counts as
        # class 0.  Pred: argmax of the class outputs, restricted to pixels
        # inside a predicted instance.
        num_classes = cell_class_out.shape[0]
        for j in range(num_classes):
            gt_count = np.sum(np.argmax(multi_class_mask, axis=-1) == j) if multi_class_mask.ndim == 3 else np.sum(multi_class_mask == j)
            pred_count = np.sum((cell_class_out.argmax(dim=0).cpu().numpy() == j) & (segmentation > 0))
            print(f"Class {j} (GT | Pred): {gt_count} | {pred_count}")

        # Class indices via argmax.
        # NOTE(review): global cell labels are trained with BCEWithLogitsLoss
        # (multi-label) elsewhere in this notebook, so argmax reports only the
        # first present cell type.
        print(f"Tissue Types - GT {tissue_types[0].argmax().item()}, Pred {tc_out.argmax().item()}")
        print(f"Global Cell Labels - GT {global_cell_labels[0].argmax().item()}, Pred {global_cell_out.argmax().item()}")

        # Print prediction confidence: summary statistics of each class
        # channel of the model's masked class outputs (logits, not
        # probabilities).
        for j in range(num_classes):
            class_confidence = cell_class_out[j].cpu().numpy()
            print(f"Class {j} Confidence:")
            print(f"  Min: {class_confidence.min():.4f}")
            print(f"  Max: {class_confidence.max():.4f}")
            print(f"  Mean: {class_confidence.mean():.4f}")
            print(f"  Std: {class_confidence.std():.4f}")

    print("\nVisualization completed.")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2

def visualize_results(model, test_loader, device, num_samples=5):
    """Plot a 7-panel comparison of model outputs vs. ground truth for test batches.

    For each of the first ``num_samples`` batches of ``test_loader`` only the
    first item of the batch is drawn, one figure row per batch: image, GT binary
    mask, thresholded predicted mask, GT cell classes, predicted cell classes,
    HV-map magnitude, and predicted classes overlaid with HV-derived boundaries.
    Array shapes and per-sample diagnostics (unique values, per-class pixel
    counts, tissue / global-cell labels, per-class output statistics) are
    printed while each sample is processed.

    Args:
        model: network returning ``(cell_seg_out, cell_class_out, tc_out,
            global_cell_out, hv_out)`` for a batch of images.
        test_loader: iterable of 6-tuples ``(images, binary_masks,
            multi_class_masks, hv_maps, tissue_types, global_cell_labels)``.
        device: device the image batch is moved to before the forward pass.
        num_samples: number of batches (figure rows) to visualize.
    """
    model.eval()

    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 7, figsize=(35, 5 * num_samples), squeeze=False)
    rows_used = 0

    with torch.no_grad():
        for i, (images, binary_masks, multi_class_masks, hv_maps, tissue_types, global_cell_labels) in enumerate(test_loader):
            if i >= num_samples:
                break
            rows_used = i + 1

            images = images.to(device)

            # Forward pass
            cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out = model(images)

            # Only the first item of each batch is visualized.
            image = images[0].cpu().permute(1, 2, 0).numpy()
            binary_mask = binary_masks[0, 0].cpu().numpy()
            multi_class_mask = multi_class_masks[0].squeeze().cpu().numpy()
            cell_seg_out_img = cell_seg_out[0, 0].cpu().numpy()
            cell_class_out_img = cell_class_out[0].cpu().numpy()
            hv_out_img = hv_out[0].cpu().numpy()

            # Normalize image if necessary
            if image.max() > 1:
                image = image / 255.0

            # Print shapes for debugging
            print(f"Sample {i} shapes:")
            print(f"image: {image.shape}")
            print(f"binary_mask: {binary_mask.shape}")
            print(f"multi_class_mask: {multi_class_mask.shape}")
            print(f"cell_seg_out_img: {cell_seg_out_img.shape}")
            print(f"cell_class_out_img: {cell_class_out_img.shape}")
            print(f"hv_out_img: {hv_out_img.shape}")

            # NOTE(review): 0.5 assumes cell_seg_out holds probabilities, not logits — confirm.
            binary_pred = cell_seg_out_img > 0.5

            # Original Image
            axes[i, 0].imshow(image)
            axes[i, 0].set_title(f"Sample {i+1}\nOriginal Image")
            axes[i, 0].axis('off')

            # Ground Truth Binary Segmentation
            axes[i, 1].imshow(binary_mask, cmap='gray')
            axes[i, 1].set_title("GT Binary Segmentation")
            axes[i, 1].axis('off')

            # Predicted Binary Segmentation
            axes[i, 2].imshow(binary_pred, cmap='gray')
            axes[i, 2].set_title("Predicted Binary Segmentation")
            axes[i, 2].axis('off')

            # The GT mask is indexed channel-last (H, W, C); the prediction channel-first (C, H, W).
            num_classes = multi_class_mask.shape[-1]
            num_pred_classes = cell_class_out_img.shape[0]
            num_colors = max(num_classes, num_pred_classes)
            # plt.cm.get_cmap was removed in Matplotlib 3.9; use the colormap registry.
            colors = plt.colormaps['tab10'](np.linspace(0, 1, num_colors))

            # Ground Truth Cell Classification
            gt_cell_class_image = np.zeros((*multi_class_mask.shape[:2], 3))
            for class_idx in range(num_classes):
                class_mask = multi_class_mask[..., class_idx] > 0
                gt_cell_class_image[class_mask] = colors[class_idx][:3]

            axes[i, 3].imshow(gt_cell_class_image)
            axes[i, 3].set_title("Ground Truth\nCell Classification")
            axes[i, 3].axis('off')

            # Predicted Cell Classification (argmax class, shown only on predicted foreground)
            pred_cell_class_image = np.zeros((*cell_class_out_img.shape[1:], 3))
            cell_class_out_argmax = np.argmax(cell_class_out_img, axis=0)
            for class_idx in range(num_pred_classes):
                class_mask = (cell_class_out_argmax == class_idx) & binary_pred
                pred_cell_class_image[class_mask] = colors[class_idx][:3]
            pred_cell_class_image[~binary_pred] = [0, 0, 0]

            axes[i, 4].imshow(pred_cell_class_image)
            axes[i, 4].set_title("Predicted\nCell Classification")
            axes[i, 4].axis('off')

            # HV Map Visualization
            hv_magnitude = np.sqrt(np.sum(hv_out_img**2, axis=0))
            axes[i, 5].imshow(hv_magnitude, cmap='viridis')
            axes[i, 5].set_title("HV Map Magnitude")
            axes[i, 5].axis('off')

            # Cell Boundaries using HV Map
            hv_grad_x = cv2.Sobel(hv_out_img[0], cv2.CV_64F, 1, 0, ksize=3)
            hv_grad_y = cv2.Sobel(hv_out_img[1], cv2.CV_64F, 0, 1, ksize=3)
            hv_edges = np.sqrt(hv_grad_x**2 + hv_grad_y**2)

            # Min-max normalize; eps avoids 0/0 (NaN) when the gradient map is constant.
            hv_edges = (hv_edges - hv_edges.min()) / (hv_edges.max() - hv_edges.min() + 1e-8)

            # Keep only the strongest 2% of edge responses.
            boundary_threshold = np.percentile(hv_edges, 98)  # Adjust this percentile as needed
            cell_boundaries = (hv_edges > boundary_threshold).astype(np.uint8)

            # Opening removes speckle (and any line thinner than the 3x3 kernel);
            # dilation then thickens what remains for visibility.
            kernel = np.ones((3, 3), np.uint8)
            cell_boundaries = cv2.morphologyEx(cell_boundaries, cv2.MORPH_OPEN, kernel, iterations=1)
            cell_boundaries = cv2.dilate(cell_boundaries, kernel, iterations=1)

            # Combine predicted cell classification with white boundary overlay
            combined_image = pred_cell_class_image.copy()
            combined_image[cell_boundaries > 0] = [1, 1, 1]

            axes[i, 6].imshow(combined_image)
            axes[i, 6].set_title("Predicted Classification\nwith Cell Boundaries")
            axes[i, 6].axis('off')

            # Tissue Type
            predicted_tissue = tc_out[0].argmax().item()
            axes[i, 6].text(0.5, -0.1, f"Predicted Tissue Type: {predicted_tissue}",
                            horizontalalignment='center', verticalalignment='center',
                            fontsize=10, fontweight='bold', transform=axes[i, 6].transAxes)

            # Per-sample diagnostics, labelled with the same number as the figure row title.
            print(f"\nSample {i+1}:")
            print(f"Binary Mask - Unique values: GT {np.unique(binary_mask)}, Pred {np.unique(binary_pred)}")
            gt_class_map = np.argmax(multi_class_mask, axis=-1)
            print(f"Multi-class Mask - Unique values: GT {np.unique(gt_class_map)}, Pred {np.unique(cell_class_out_argmax[binary_pred])}")

            # Pixel counts per class. NOTE(review): GT pixels with no active channel
            # argmax to 0 and are counted as class 0 — confirm the mask layout.
            for j in range(num_colors):
                gt_count = np.sum(gt_class_map == j)
                pred_count = np.sum((cell_class_out_argmax == j) & binary_pred)
                print(f"Class {j} (GT | Pred): {gt_count} | {pred_count}")

            print(f"Tissue Types - GT {tissue_types[0].argmax().item()}, Pred {predicted_tissue}")
            print(f"Global Cell Labels - GT {global_cell_labels[0].argmax().item()}, Pred {global_cell_out[0].argmax().item()}")

            # Print prediction confidence (statistics of each predicted class channel)
            for j in range(num_pred_classes):
                class_confidence = cell_class_out_img[j]
                print(f"Class {j} Confidence:")
                print(f"  Min: {class_confidence.min():.4f}")
                print(f"  Max: {class_confidence.max():.4f}")
                print(f"  Mean: {class_confidence.mean():.4f}")
                print(f"  Std: {class_confidence.std():.4f}")

    # Hide rows left empty when the loader yields fewer than num_samples batches.
    for ax in axes[rows_used:].ravel():
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Run the 7-panel visualization on (up to) the first 5 test batches.
visualize_results(model=model, test_loader=test_loader, device=device, num_samples=5)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2

def visualize_results(model, test_loader, device, num_samples=5):
    """Plot image, GT / predicted segmentation and HV-derived boundaries per test batch.

    For each of the first ``num_samples`` batches only the first item is drawn,
    in four panels: the image, the GT binary mask, the thresholded prediction,
    and the image with boundaries from ``get_cell_boundaries`` painted red.
    The unique values of the GT and predicted masks are printed per sample.

    Args:
        model: network returning a 5-tuple whose first element is the
            segmentation output and whose last element is the HV map.
        test_loader: iterable of 6-tuples whose first two items are the image
            batch and the binary-mask batch.
        device: device the image batch is moved to before the forward pass.
        num_samples: number of batches (figure rows) to visualize.
    """
    model.eval()

    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 4, figsize=(20, 5 * num_samples), squeeze=False)
    rows_used = 0

    with torch.no_grad():
        for i, (images, binary_masks, _, _, _, _) in enumerate(test_loader):
            if i >= num_samples:
                break
            rows_used = i + 1

            images = images.to(device)

            # Forward pass
            cell_seg_out, _, _, _, hv_out = model(images)

            # Only the first item of each batch is visualized.
            image = images[0].cpu().permute(1, 2, 0).numpy()
            binary_mask = binary_masks[0, 0].cpu().numpy()
            cell_seg_out_img = cell_seg_out[0, 0].cpu().numpy()
            hv_out_img = hv_out[0].cpu().numpy()

            # Normalize image if necessary
            if image.max() > 1:
                image = image / 255.0

            # NOTE(review): 0.5 assumes cell_seg_out holds probabilities, not logits — confirm.
            binary_pred = cell_seg_out_img > 0.5

            # Original Image
            axes[i, 0].imshow(image)
            axes[i, 0].set_title(f"Sample {i+1}\nOriginal Image")
            axes[i, 0].axis('off')

            # Ground Truth Binary Segmentation
            axes[i, 1].imshow(binary_mask, cmap='gray')
            axes[i, 1].set_title("GT Binary Segmentation")
            axes[i, 1].axis('off')

            # Predicted Binary Segmentation
            axes[i, 2].imshow(binary_pred, cmap='gray')
            axes[i, 2].set_title("Predicted Binary Segmentation")
            axes[i, 2].axis('off')

            # Predicted Segmentation with Cell Boundaries using HV maps
            cell_boundaries = get_cell_boundaries(cell_seg_out_img, hv_out_img)
            image_with_boundaries = image.copy()
            image_with_boundaries[cell_boundaries > 0] = [1, 0, 0]  # Red color for boundaries

            axes[i, 3].imshow(image_with_boundaries)
            axes[i, 3].set_title("Predicted Segmentation\nwith Cell Boundaries")
            axes[i, 3].axis('off')

            # Per-sample diagnostics, labelled with the same number as the figure row title.
            print(f"\nSample {i+1}:")
            print(f"Binary Mask - Unique values: GT {np.unique(binary_mask)}, Pred {np.unique(binary_pred)}")

    # Hide rows left empty when the loader yields fewer than num_samples batches.
    for ax in axes[rows_used:].ravel():
        ax.axis('off')

    plt.tight_layout()
    plt.show()

def get_cell_boundaries(cell_seg_out, hv_map):
    """Return a {0, 1} uint8 mask of predicted cell boundaries.

    Combines the outline of the thresholded segmentation (Canny) with strong
    gradients of the horizontal/vertical (HV) maps, then thickens the result
    slightly for display.

    Args:
        cell_seg_out: 2-D array of segmentation scores; pixels > 0.5 are
            foreground (assumes probabilities, not logits — TODO confirm).
        hv_map: array whose ``[0]`` and ``[1]`` are the horizontal and vertical
            maps, each with the same spatial size as ``cell_seg_out``.

    Returns:
        uint8 array shaped like ``cell_seg_out`` with 1 on boundary pixels.
    """
    # Threshold the cell segmentation output
    binary_mask = (cell_seg_out > 0.5).astype(np.uint8)

    # Outline of the foreground; Canny returns 0/255, so rescale to 0/1.
    edges = (cv2.Canny(binary_mask, 0.5, 1) > 0).astype(np.uint8)

    # Calculate gradients of HV map
    h_map, v_map = hv_map[0], hv_map[1]
    grad_h = cv2.Sobel(h_map, cv2.CV_64F, 1, 0, ksize=3)
    grad_v = cv2.Sobel(v_map, cv2.CV_64F, 0, 1, ksize=3)
    hv_edges = np.sqrt(grad_h**2 + grad_v**2)

    # Min-max normalize; eps avoids 0/0 (NaN) when the gradient map is constant.
    hv_edges = (hv_edges - hv_edges.min()) / (hv_edges.max() - hv_edges.min() + 1e-8)

    # Threshold HV edge map to get potential cell separations
    hv_threshold = 0.5  # Higher threshold keeps only the stronger edges
    hv_boundaries = (hv_edges > hv_threshold).astype(np.uint8)

    kernel = np.ones((3, 3), np.uint8)
    # De-speckle only the HV response: a 3x3 opening erases any 1-px-wide line,
    # so applying it after combining would wipe out every Canny edge.
    hv_boundaries = cv2.morphologyEx(hv_boundaries, cv2.MORPH_OPEN, kernel, iterations=1)

    # Combine binary mask edges with HV boundaries, then thicken for visibility.
    cell_boundaries = np.maximum(edges, hv_boundaries)
    cell_boundaries = cv2.dilate(cell_boundaries, kernel, iterations=1)

    return cell_boundaries

# Run the 4-panel boundary visualization on (up to) the first 5 test batches.
visualize_results(model=model, test_loader=test_loader, device=device, num_samples=5)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import random

def visualize_results(model, test_loader, device, num_samples=5):
    """Plot image / GT mask / predicted mask for randomly chosen test batches.

    ``min(num_samples, number of batches)`` distinct batch indices are drawn
    with ``random.sample``; figure row ``r`` shows the first item of batch
    ``selected_indices[r]``. The unique values of the GT and thresholded
    predicted masks are printed for each selected sample.

    When ``test_loader`` supports ``len()`` only the selected batches are held
    in memory; loaders without a length are materialized into a list first.

    Args:
        model: network returning a 5-tuple whose first element is the
            segmentation output.
        test_loader: iterable of 6-tuples whose first two items are the image
            batch and the binary-mask batch.
        device: device the image batch is moved to before the forward pass.
        num_samples: number of batches (figure rows) to visualize.
    """
    model.eval()

    try:
        total_samples = len(test_loader)
    except TypeError:
        # No len(): materialize once so batches can be addressed by index.
        test_loader = list(test_loader)
        total_samples = len(test_loader)

    # Randomly select batch indices and remember which figure row each one goes to.
    selected_indices = random.sample(range(total_samples), min(num_samples, total_samples))
    row_of = {idx: row for row, idx in enumerate(selected_indices)}
    last_index = max(selected_indices, default=-1)
    plotted_rows = set()

    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    with torch.no_grad():
        for idx, batch in enumerate(test_loader):
            if idx > last_index:
                break  # every selected batch has been reached
            i = row_of.get(idx)
            if i is None:
                continue
            images, binary_masks, _, _, _, _ = batch
            plotted_rows.add(i)

            images = images.to(device)

            # Forward pass
            cell_seg_out, _, _, _, _ = model(images)

            # Only the first item of the batch is visualized.
            image = images[0].cpu().permute(1, 2, 0).numpy()
            binary_mask = binary_masks[0, 0].cpu().numpy()
            cell_seg_out_img = cell_seg_out[0, 0].cpu().numpy()

            # Normalize image if necessary
            if image.max() > 1:
                image = image / 255.0

            # NOTE(review): 0.5 assumes cell_seg_out holds probabilities, not logits — confirm.
            binary_pred = cell_seg_out_img > 0.5

            # Original Image
            axes[i, 0].imshow(image)
            axes[i, 0].set_title(f"Sample {idx}\nOriginal Image")
            axes[i, 0].axis('off')

            # Ground Truth Binary Segmentation
            axes[i, 1].imshow(binary_mask, cmap='gray')
            axes[i, 1].set_title("GT Binary Segmentation")
            axes[i, 1].axis('off')

            # Predicted Binary Segmentation
            axes[i, 2].imshow(binary_pred, cmap='gray')
            axes[i, 2].set_title("Predicted Binary Segmentation")
            axes[i, 2].axis('off')

            # Per-sample diagnostics
            print(f"\nSample {idx}:")
            print(f"Binary Mask - Unique values: GT {np.unique(binary_mask)}, Pred {np.unique(binary_pred)}")

    # Hide rows that received no sample (fewer batches than num_samples).
    for row in range(num_samples):
        if row not in plotted_rows:
            for ax in axes[row]:
                ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize 5 randomly selected test batches.
visualize_results(model=model, test_loader=test_loader, device=device, num_samples=5)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2

def visualize_results(model, test_loader, device, num_samples=5):
    """Plot a 7-panel comparison (HV panels restricted to predicted cells) per test batch.

    For each of the first ``num_samples`` batches only the first item is drawn:
    image, GT / predicted binary masks, GT / predicted cell classes, HV-map
    magnitude on predicted foreground, and predicted classes with boundaries
    taken from the gradient of that masked magnitude. Per-sample diagnostics
    (unique values, per-class pixel counts, tissue / global-cell labels and
    per-class output statistics) are printed as each sample is processed.

    Args:
        model: network returning ``(cell_seg_out, cell_class_out, tc_out,
            global_cell_out, hv_out)`` for a batch of images.
        test_loader: iterable of 6-tuples ``(images, binary_masks,
            multi_class_masks, hv_maps, tissue_types, global_cell_labels)``.
        device: device the image batch is moved to before the forward pass.
        num_samples: number of batches (figure rows) to visualize.
    """
    model.eval()

    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 7, figsize=(35, 5 * num_samples), squeeze=False)
    rows_used = 0

    with torch.no_grad():
        for i, (images, binary_masks, multi_class_masks, hv_maps, tissue_types, global_cell_labels) in enumerate(test_loader):
            if i >= num_samples:
                break
            rows_used = i + 1

            images = images.to(device)

            # Forward pass
            cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out = model(images)

            # Only the first item of each batch is visualized.
            image = images[0].cpu().permute(1, 2, 0).numpy()
            binary_mask = binary_masks[0, 0].cpu().numpy()
            multi_class_mask = multi_class_masks[0].squeeze().cpu().numpy()
            cell_seg_out_img = cell_seg_out[0, 0].cpu().numpy()
            cell_class_out_img = cell_class_out[0].cpu().numpy()
            hv_out_img = hv_out[0].cpu().numpy()

            # Normalize image if necessary
            if image.max() > 1:
                image = image / 255.0

            # NOTE(review): 0.5 assumes cell_seg_out holds probabilities, not logits — confirm.
            binary_pred = cell_seg_out_img > 0.5

            # Original Image
            axes[i, 0].imshow(image)
            axes[i, 0].set_title(f"Sample {i+1}\nOriginal Image")
            axes[i, 0].axis('off')

            # Ground Truth Binary Segmentation
            axes[i, 1].imshow(binary_mask, cmap='gray')
            axes[i, 1].set_title("GT Binary Segmentation")
            axes[i, 1].axis('off')

            # Predicted Binary Segmentation
            axes[i, 2].imshow(binary_pred, cmap='gray')
            axes[i, 2].set_title("Predicted Binary Segmentation")
            axes[i, 2].axis('off')

            # The GT mask is indexed channel-last (H, W, C); the prediction channel-first (C, H, W).
            num_classes = multi_class_mask.shape[-1]
            num_pred_classes = cell_class_out_img.shape[0]
            num_colors = max(num_classes, num_pred_classes)
            colors = plt.colormaps['tab10'](np.linspace(0, 1, num_colors))

            # Ground Truth Cell Classification
            gt_cell_class_image = np.zeros((*multi_class_mask.shape[:2], 3))
            for class_idx in range(num_classes):
                class_mask = multi_class_mask[..., class_idx] > 0
                gt_cell_class_image[class_mask] = colors[class_idx][:3]

            axes[i, 3].imshow(gt_cell_class_image)
            axes[i, 3].set_title("Ground Truth\nCell Classification")
            axes[i, 3].axis('off')

            # Predicted Cell Classification (argmax class, shown only on predicted foreground)
            pred_cell_class_image = np.zeros((*cell_class_out_img.shape[1:], 3))
            cell_class_out_argmax = np.argmax(cell_class_out_img, axis=0)
            for class_idx in range(num_pred_classes):
                class_mask = (cell_class_out_argmax == class_idx) & binary_pred
                pred_cell_class_image[class_mask] = colors[class_idx][:3]
            pred_cell_class_image[~binary_pred] = [0, 0, 0]

            axes[i, 4].imshow(pred_cell_class_image)
            axes[i, 4].set_title("Predicted\nCell Classification")
            axes[i, 4].axis('off')

            # HV Map Visualization (only for detected cells)
            hv_magnitude = np.sqrt(np.sum(hv_out_img**2, axis=0))
            hv_magnitude_masked = hv_magnitude * binary_pred
            axes[i, 5].imshow(hv_magnitude_masked, cmap='viridis')
            axes[i, 5].set_title("HV Map Magnitude\n(Detected Cells Only)")
            axes[i, 5].axis('off')

            # Cell boundaries: gradient of the masked HV magnitude, normalized to [0, 1].
            gx = cv2.Sobel(hv_magnitude_masked, cv2.CV_32F, 1, 0, ksize=3)
            gy = cv2.Sobel(hv_magnitude_masked, cv2.CV_32F, 0, 1, ksize=3)
            grad_mag = np.sqrt(gx**2 + gy**2)
            grad_mag = (grad_mag - grad_mag.min()) / (grad_mag.max() - grad_mag.min() + 1e-8)

            threshold = 0.1  # Adjust this value to control boundary detection sensitivity
            boundary_mask = (grad_mag > threshold).astype(np.uint8)

            # White boundaries over the predicted classes; non-cell areas stay black.
            combined_image = pred_cell_class_image.copy()
            combined_image[boundary_mask > 0] = [1, 1, 1]
            combined_image[~binary_pred] = [0, 0, 0]

            axes[i, 6].imshow(combined_image)
            axes[i, 6].set_title("Predicted Classification\nwith Cell Boundaries")
            axes[i, 6].axis('off')

            # Tissue Type prediction
            predicted_tissue = tc_out[0].argmax().item()
            axes[i, 6].text(0.5, -0.1, f"Predicted Tissue Type: {predicted_tissue}",
                            horizontalalignment='center', verticalalignment='center',
                            fontsize=10, fontweight='bold', transform=axes[i, 6].transAxes)

            # Per-sample diagnostics, labelled with the same number as the figure row title.
            print(f"\nSample {i+1}:")
            print(f"Binary Mask - Unique values: GT {np.unique(binary_mask)}, Pred {np.unique(binary_pred)}")
            gt_class_map = np.argmax(multi_class_mask, axis=-1)
            print(f"Multi-class Mask - Unique values: GT {np.unique(gt_class_map)}, Pred {np.unique(cell_class_out_argmax[binary_pred])}")

            # Pixel counts per class. NOTE(review): GT pixels with no active channel
            # argmax to 0 and are counted as class 0 — confirm the mask layout.
            for j in range(num_colors):
                gt_count = np.sum(gt_class_map == j)
                pred_count = np.sum((cell_class_out_argmax == j) & binary_pred)
                print(f"Class {j} (GT | Pred): {gt_count} | {pred_count}")

            print(f"Tissue Types - GT {tissue_types[0].argmax().item()}, Pred {predicted_tissue}")
            print(f"Global Cell Labels - GT {global_cell_labels[0].argmax().item()}, Pred {global_cell_out[0].argmax().item()}")

            # Print prediction confidence (statistics of each predicted class channel)
            for j in range(num_pred_classes):
                class_confidence = cell_class_out_img[j]
                print(f"Class {j} Confidence:")
                print(f"  Min: {class_confidence.min():.4f}")
                print(f"  Max: {class_confidence.max():.4f}")
                print(f"  Mean: {class_confidence.mean():.4f}")
                print(f"  Std: {class_confidence.std():.4f}")

    # Hide rows left empty when the loader yields fewer than num_samples batches.
    for ax in axes[rows_used:].ravel():
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize (up to) the first 30 test batches.
visualize_results(model=model, test_loader=test_loader, device=device, num_samples=30)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2

def visualize_results(model, test_loader, device, num_samples=5):
    """Plot an 8-panel comparison of model outputs vs. ground truth per test batch.

    For each of the first ``num_samples`` batches only the first item is drawn:
    image, GT / predicted binary masks, GT / predicted cell classes, HV-map
    magnitude on predicted foreground, predicted classes with HV-gradient
    boundaries, and the image with a class-coloured ring around each
    predicted class region (tissue prediction written underneath).

    Args:
        model: network returning ``(cell_seg_out, cell_class_out, tc_out,
            global_cell_out, hv_out)`` for a batch of images.
        test_loader: iterable of 6-tuples ``(images, binary_masks,
            multi_class_masks, hv_maps, tissue_types, global_cell_labels)``.
        device: device the image batch is moved to before the forward pass.
        num_samples: number of batches (figure rows) to visualize.
    """
    model.eval()

    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 8, figsize=(40, 5 * num_samples), squeeze=False)
    rows_used = 0

    with torch.no_grad():
        for i, (images, binary_masks, multi_class_masks, hv_maps, tissue_types, global_cell_labels) in enumerate(test_loader):
            if i >= num_samples:
                break
            rows_used = i + 1

            images = images.to(device)

            # Forward pass
            cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out = model(images)

            # Only the first item of each batch is visualized.
            image = images[0].cpu().permute(1, 2, 0).numpy()
            binary_mask = binary_masks[0, 0].cpu().numpy()
            multi_class_mask = multi_class_masks[0].squeeze().cpu().numpy()
            cell_seg_out_img = cell_seg_out[0, 0].cpu().numpy()
            cell_class_out_img = cell_class_out[0].cpu().numpy()
            hv_out_img = hv_out[0].cpu().numpy()

            # Normalize image if necessary
            if image.max() > 1:
                image = image / 255.0

            # NOTE(review): 0.5 assumes cell_seg_out holds probabilities, not logits — confirm.
            binary_pred = cell_seg_out_img > 0.5

            # Original Image
            axes[i, 0].imshow(image)
            axes[i, 0].set_title(f"Sample {i+1}\nOriginal Image")
            axes[i, 0].axis('off')

            # Ground Truth Binary Segmentation
            axes[i, 1].imshow(binary_mask, cmap='gray')
            axes[i, 1].set_title("GT Binary Segmentation")
            axes[i, 1].axis('off')

            # Predicted Binary Segmentation
            axes[i, 2].imshow(binary_pred, cmap='gray')
            axes[i, 2].set_title("Predicted Binary Segmentation")
            axes[i, 2].axis('off')

            # The GT mask is indexed channel-last (H, W, C); the prediction channel-first (C, H, W).
            # Colours cover both so no predicted class is silently dropped.
            num_classes = multi_class_mask.shape[-1]
            num_pred_classes = cell_class_out_img.shape[0]
            colors = plt.colormaps['tab10'](np.linspace(0, 1, max(num_classes, num_pred_classes)))

            # Ground Truth Cell Classification
            gt_cell_class_image = np.zeros((*multi_class_mask.shape[:2], 3))
            for class_idx in range(num_classes):
                class_mask = multi_class_mask[..., class_idx] > 0
                gt_cell_class_image[class_mask] = colors[class_idx][:3]

            axes[i, 3].imshow(gt_cell_class_image)
            axes[i, 3].set_title("Ground Truth\nCell Classification")
            axes[i, 3].axis('off')

            # Predicted Cell Classification (argmax class, shown only on predicted foreground)
            pred_cell_class_image = np.zeros((*cell_class_out_img.shape[1:], 3))
            cell_class_out_argmax = np.argmax(cell_class_out_img, axis=0)
            for class_idx in range(num_pred_classes):
                class_mask = (cell_class_out_argmax == class_idx) & binary_pred
                pred_cell_class_image[class_mask] = colors[class_idx][:3]
            pred_cell_class_image[~binary_pred] = [0, 0, 0]

            axes[i, 4].imshow(pred_cell_class_image)
            axes[i, 4].set_title("Predicted\nCell Classification")
            axes[i, 4].axis('off')

            # HV Map Visualization (only for detected cells)
            hv_magnitude = np.sqrt(np.sum(hv_out_img**2, axis=0))
            hv_magnitude_masked = hv_magnitude * binary_pred
            axes[i, 5].imshow(hv_magnitude_masked, cmap='viridis')
            axes[i, 5].set_title("HV Map Magnitude\n(Detected Cells Only)")
            axes[i, 5].axis('off')

            # Cell boundaries: gradient of the masked HV magnitude, normalized to [0, 1].
            gx = cv2.Sobel(hv_magnitude_masked, cv2.CV_32F, 1, 0, ksize=3)
            gy = cv2.Sobel(hv_magnitude_masked, cv2.CV_32F, 0, 1, ksize=3)
            grad_mag = np.sqrt(gx**2 + gy**2)
            grad_mag = (grad_mag - grad_mag.min()) / (grad_mag.max() - grad_mag.min() + 1e-8)

            threshold = 0.1  # Adjust this value to control boundary detection sensitivity
            boundary_mask = (grad_mag > threshold).astype(np.uint8)

            # White boundaries over the predicted classes; non-cell areas stay black.
            combined_image = pred_cell_class_image.copy()
            combined_image[boundary_mask > 0] = [1, 1, 1]
            combined_image[~binary_pred] = [0, 0, 0]

            axes[i, 6].imshow(combined_image)
            axes[i, 6].set_title("Predicted Classification\nwith Cell Boundaries")
            axes[i, 6].axis('off')

            # Original image with a coloured ring (5x5 dilation minus region) around each predicted class.
            original_with_boundaries = image.copy()
            ring_kernel = np.ones((5, 5), np.uint8)
            for class_idx in range(num_pred_classes):
                class_mask = ((cell_class_out_argmax == class_idx) & binary_pred).astype(np.uint8)
                class_boundary = cv2.dilate(class_mask, ring_kernel, iterations=1) - class_mask
                original_with_boundaries[class_boundary > 0] = colors[class_idx][:3]

            axes[i, 7].imshow(original_with_boundaries)
            axes[i, 7].set_title("Original Image with\nColored Cell Boundaries")
            axes[i, 7].axis('off')

            # Tissue Type prediction
            predicted_tissue = tc_out[0].argmax().item()
            axes[i, 7].text(0.5, -0.1, f"Predicted Tissue Type: {predicted_tissue}",
                            horizontalalignment='center', verticalalignment='center',
                            fontsize=10, fontweight='bold', transform=axes[i, 7].transAxes)

    # Hide rows left empty when the loader yields fewer than num_samples batches.
    for ax in axes[rows_used:].ravel():
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize (up to) the first 50 test batches with the 8-panel layout.
visualize_results(model=model, test_loader=test_loader, device=device, num_samples=50)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2

def visualize_results(model, test_loader, device, num_samples=5):
    """Plot an 8-panel comparison of model outputs vs. ground truth per test batch.

    For each of the first ``num_samples`` batches only the first item is drawn:
    image, GT / predicted binary masks, GT / predicted cell classes, HV-map
    magnitude on predicted foreground, predicted classes with HV-gradient
    boundaries, and the image with a class-coloured ring around each
    predicted class region (tissue prediction written underneath).

    Args:
        model: network returning ``(cell_seg_out, cell_class_out, tc_out,
            global_cell_out, hv_out)`` for a batch of images.
        test_loader: iterable of 6-tuples ``(images, binary_masks,
            multi_class_masks, hv_maps, tissue_types, global_cell_labels)``.
        device: device the image batch is moved to before the forward pass.
        num_samples: number of batches (figure rows) to visualize.
    """
    model.eval()

    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 8, figsize=(40, 5 * num_samples), squeeze=False)
    rows_used = 0

    with torch.no_grad():
        for i, (images, binary_masks, multi_class_masks, hv_maps, tissue_types, global_cell_labels) in enumerate(test_loader):
            if i >= num_samples:
                break
            rows_used = i + 1

            images = images.to(device)

            # Forward pass
            cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out = model(images)

            # Only the first item of each batch is visualized.
            image = images[0].cpu().permute(1, 2, 0).numpy()
            binary_mask = binary_masks[0, 0].cpu().numpy()
            multi_class_mask = multi_class_masks[0].squeeze().cpu().numpy()
            cell_seg_out_img = cell_seg_out[0, 0].cpu().numpy()
            cell_class_out_img = cell_class_out[0].cpu().numpy()
            hv_out_img = hv_out[0].cpu().numpy()

            # Normalize image if necessary
            if image.max() > 1:
                image = image / 255.0

            # NOTE(review): 0.5 assumes cell_seg_out holds probabilities, not logits — confirm.
            binary_pred = cell_seg_out_img > 0.5

            # Original Image
            axes[i, 0].imshow(image)
            axes[i, 0].set_title(f"Sample {i+1}\nOriginal Image")
            axes[i, 0].axis('off')

            # Ground Truth Binary Segmentation
            axes[i, 1].imshow(binary_mask, cmap='gray')
            axes[i, 1].set_title("GT Binary Segmentation")
            axes[i, 1].axis('off')

            # Predicted Binary Segmentation
            axes[i, 2].imshow(binary_pred, cmap='gray')
            axes[i, 2].set_title("Predicted Binary Segmentation")
            axes[i, 2].axis('off')

            # The GT mask is indexed channel-last (H, W, C); the prediction channel-first (C, H, W).
            # Colours cover both so no predicted class is silently dropped.
            num_classes = multi_class_mask.shape[-1]
            num_pred_classes = cell_class_out_img.shape[0]
            colors = plt.colormaps['tab10'](np.linspace(0, 1, max(num_classes, num_pred_classes)))

            # Ground Truth Cell Classification
            gt_cell_class_image = np.zeros((*multi_class_mask.shape[:2], 3))
            for class_idx in range(num_classes):
                class_mask = multi_class_mask[..., class_idx] > 0
                gt_cell_class_image[class_mask] = colors[class_idx][:3]

            axes[i, 3].imshow(gt_cell_class_image)
            axes[i, 3].set_title("Ground Truth\nCell Classification")
            axes[i, 3].axis('off')

            # Predicted Cell Classification (argmax class, shown only on predicted foreground)
            pred_cell_class_image = np.zeros((*cell_class_out_img.shape[1:], 3))
            cell_class_out_argmax = np.argmax(cell_class_out_img, axis=0)
            for class_idx in range(num_pred_classes):
                class_mask = (cell_class_out_argmax == class_idx) & binary_pred
                pred_cell_class_image[class_mask] = colors[class_idx][:3]
            pred_cell_class_image[~binary_pred] = [0, 0, 0]

            axes[i, 4].imshow(pred_cell_class_image)
            axes[i, 4].set_title("Predicted\nCell Classification")
            axes[i, 4].axis('off')

            # HV Map Visualization (only for detected cells)
            hv_magnitude = np.sqrt(np.sum(hv_out_img**2, axis=0))
            hv_magnitude_masked = hv_magnitude * binary_pred
            axes[i, 5].imshow(hv_magnitude_masked, cmap='viridis')
            axes[i, 5].set_title("HV Map Magnitude\n(Detected Cells Only)")
            axes[i, 5].axis('off')

            # Cell boundaries: gradient of the masked HV magnitude, normalized to [0, 1].
            gx = cv2.Sobel(hv_magnitude_masked, cv2.CV_32F, 1, 0, ksize=3)
            gy = cv2.Sobel(hv_magnitude_masked, cv2.CV_32F, 0, 1, ksize=3)
            grad_mag = np.sqrt(gx**2 + gy**2)
            grad_mag = (grad_mag - grad_mag.min()) / (grad_mag.max() - grad_mag.min() + 1e-8)

            threshold = 0.1  # Adjust this value to control boundary detection sensitivity
            boundary_mask = (grad_mag > threshold).astype(np.uint8)

            # White boundaries over the predicted classes; non-cell areas stay black.
            combined_image = pred_cell_class_image.copy()
            combined_image[boundary_mask > 0] = [1, 1, 1]
            combined_image[~binary_pred] = [0, 0, 0]

            axes[i, 6].imshow(combined_image)
            axes[i, 6].set_title("Predicted Classification\nwith Cell Boundaries")
            axes[i, 6].axis('off')

            # Original image with a coloured ring (5x5 dilation minus region) around each predicted class.
            original_with_boundaries = image.copy()
            ring_kernel = np.ones((5, 5), np.uint8)
            for class_idx in range(num_pred_classes):
                class_mask = ((cell_class_out_argmax == class_idx) & binary_pred).astype(np.uint8)
                class_boundary = cv2.dilate(class_mask, ring_kernel, iterations=1) - class_mask
                original_with_boundaries[class_boundary > 0] = colors[class_idx][:3]

            axes[i, 7].imshow(original_with_boundaries)
            axes[i, 7].set_title("Original Image with\nColored Cell Boundaries")
            axes[i, 7].axis('off')

            # Tissue Type prediction
            predicted_tissue = tc_out[0].argmax().item()
            axes[i, 7].text(0.5, -0.1, f"Predicted Tissue Type: {predicted_tissue}",
                            horizontalalignment='center', verticalalignment='center',
                            fontsize=10, fontweight='bold', transform=axes[i, 7].transAxes)

    # Hide rows left empty when the loader yields fewer than num_samples batches.
    for ax in axes[rows_used:].ravel():
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize (up to) the first 50 test batches with the 8-panel layout.
visualize_results(model=model, test_loader=test_loader, device=device, num_samples=50)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2

def visualize_results(model, test_loader, device, num_samples=5):
    """Plot predictions for the first image of each of the first ``num_samples`` test batches.

    Each figure row has eight panels: original image, ground-truth binary mask,
    predicted binary mask, ground-truth cell classes, predicted cell classes,
    HV-map magnitude on predicted cells, predicted classes with cell boundaries,
    and the original image with class-coloured boundaries annotated with the
    predicted tissue-type index.

    Args:
        model: network whose forward pass returns
            ``(cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out)``.
        test_loader: iterable of ``(images, binary_masks, multi_class_masks,
            hv_maps, tissue_types, global_cell_labels)`` batches.
        device: device the images are moved to before the forward pass.
        num_samples: number of batches (figure rows) to draw.
    """
    model.eval()

    fig, axes = plt.subplots(num_samples, 8, figsize=(40, 5*num_samples))
    if num_samples == 1:
        axes = axes.reshape(1, -1)

    with torch.no_grad():
        for i, (images, binary_masks, multi_class_masks, hv_maps, tissue_types, global_cell_labels) in enumerate(test_loader):
            if i >= num_samples:
                break

            images = images.to(device)

            # Forward pass
            cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out = model(images)

            # Move tensors to CPU and convert to numpy for visualization
            image = images[0].cpu().permute(1, 2, 0).numpy()
            binary_mask = binary_masks[0, 0].cpu().numpy()
            multi_class_mask = multi_class_masks[0].squeeze().cpu().numpy()
            cell_seg_out_img = cell_seg_out[0, 0].cpu().numpy()
            cell_class_out_img = cell_class_out[0].cpu().numpy()
            hv_out_img = hv_out[0].cpu().numpy()

            # Normalize image if necessary
            if image.max() > 1:
                image = image / 255.0

            # Quantities shared by the classification and boundary panels. They were
            # previously referenced below without being defined (NameError).
            # NOTE(review): the raw segmentation output is thresholded at 0.5 here, while
            # the metrics code in this notebook applies a sigmoid first; confirm whether
            # the model emits logits.
            binary_pred = cell_seg_out_img > 0.5
            num_classes = multi_class_mask.shape[-1]  # class axis of the GT mask assumed last
            colors = plt.colormaps['tab10'](np.linspace(0, 1, num_classes))
            cell_class_out_argmax = np.argmax(cell_class_out_img, axis=0)

            # Original Image
            axes[i, 0].imshow(image)
            axes[i, 0].set_title(f"Sample {i+1}\nOriginal Image")
            axes[i, 0].axis('off')

            # Ground Truth Binary Segmentation
            axes[i, 1].imshow(binary_mask, cmap='gray')
            axes[i, 1].set_title("GT Binary Segmentation")
            axes[i, 1].axis('off')

            # Predicted Binary Segmentation
            axes[i, 2].imshow(binary_pred, cmap='gray')
            axes[i, 2].set_title("Predicted Binary Segmentation")
            axes[i, 2].axis('off')

            # Ground Truth Cell Classification
            gt_cell_class_image = np.zeros((*multi_class_mask.shape[:2], 3))
            for class_idx in range(num_classes):
                gt_cell_class_image[multi_class_mask[..., class_idx] > 0] = colors[class_idx][:3]
            axes[i, 3].imshow(gt_cell_class_image)
            axes[i, 3].set_title("Ground Truth\nCell Classification")
            axes[i, 3].axis('off')

            # Predicted Cell Classification (only on pixels predicted as cells)
            pred_cell_class_image = np.zeros((*cell_class_out_argmax.shape, 3))
            for class_idx in range(num_classes):
                class_mask = (cell_class_out_argmax == class_idx) & binary_pred
                pred_cell_class_image[class_mask] = colors[class_idx][:3]
            axes[i, 4].imshow(pred_cell_class_image)
            axes[i, 4].set_title("Predicted\nCell Classification")
            axes[i, 4].axis('off')

            # HV map magnitude, restricted to predicted cells
            hv_magnitude = np.sqrt(np.sum(hv_out_img**2, axis=0))
            axes[i, 5].imshow(hv_magnitude * binary_pred, cmap='viridis')
            axes[i, 5].set_title("HV Map Magnitude\n(Detected Cells Only)")
            axes[i, 5].axis('off')

            # Predicted classes with white cell boundaries. The boundary is the predicted
            # mask minus its 3x3 erosion, so it lies on (not outside) the cell pixels.
            binary_pred_u8 = binary_pred.astype(np.uint8)
            cell_boundary = binary_pred_u8 - cv2.erode(binary_pred_u8, np.ones((3, 3), np.uint8), iterations=1)
            combined_image = pred_cell_class_image.copy()
            combined_image[cell_boundary > 0] = [1, 1, 1]
            axes[i, 6].imshow(combined_image)
            axes[i, 6].set_title("Predicted Classification\nwith Cell Boundaries")
            axes[i, 6].axis('off')

            # Original image with colored cell boundaries overlay
            original_with_boundaries = image.copy()
            for class_idx in range(num_classes):
                class_mask = (cell_class_out_argmax == class_idx) & binary_pred
                color = colors[class_idx][:3]

                # Ring just outside each class region: 5x5 dilation minus the region itself
                class_boundary = cv2.dilate(class_mask.astype(np.uint8), np.ones((5,5), np.uint8), iterations=1) - class_mask.astype(np.uint8)

                # Apply colored boundary to the original image
                original_with_boundaries[class_boundary > 0] = color

            axes[i, 7].imshow(original_with_boundaries)
            axes[i, 7].set_title("Original Image with\nColored Cell Boundaries")
            axes[i, 7].axis('off')

            # Tissue Type prediction
            predicted_tissue = tc_out[0].argmax().item()
            axes[i, 7].text(0.5, -0.1, f"Predicted Tissue Type: {predicted_tissue}", 
                            horizontalalignment='center', verticalalignment='center',
                            fontsize=10, fontweight='bold', transform=axes[i, 7].transAxes)

    plt.tight_layout()
    plt.show()

# Usage: one figure row per test batch (first image of each), for the first 50 batches.
# NOTE(review): figsize is 5 inches per row, so 50 rows yields a 250-inch-tall figure.
visualize_results(model, test_loader, device, num_samples=50)

## Metrics

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
from skimage import measure
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

def pair_coordinates(true_centroids, pred_centroids, pairing_radius):
    """Match ground-truth and predicted centroids one-to-one within ``pairing_radius``.

    The Hungarian algorithm is run on the Euclidean distance matrix, then
    assignments farther apart than ``pairing_radius`` are discarded.

    Args:
        true_centroids: (N, >=2) array; only the first two columns are used.
        pred_centroids: (M, >=2) array; only the first two columns are used.
        pairing_radius: maximum distance for an assignment to count as a match.

    Returns:
        paired: (K, 2) int array of ``(true_idx, pred_idx)`` matches.
        unpaired_true: sorted indices of ground-truth centroids without a match.
        unpaired_pred: sorted indices of predicted centroids without a match.
    """
    true_centroids = np.asarray(true_centroids, dtype=float)
    pred_centroids = np.asarray(pred_centroids, dtype=float)
    num_true, num_pred = len(true_centroids), len(pred_centroids)

    if num_true == 0 or num_pred == 0:
        return np.empty((0, 2), dtype=int), np.arange(num_true), np.arange(num_pred)

    true_xy = true_centroids[:, :2]
    pred_xy = pred_centroids[:, :2]

    distances = np.linalg.norm(true_xy[:, None] - pred_xy[None, :], axis=-1)
    true_idx, pred_idx = linear_sum_assignment(distances)

    within_radius = distances[true_idx, pred_idx] <= pairing_radius
    paired = np.stack([true_idx[within_radius], pred_idx[within_radius]], axis=1)

    # When N != M the assignment only covers min(N, M) rows/columns. Every index
    # not matched within the radius, including those left out of the assignment
    # entirely, is unpaired.
    unpaired_true = np.setdiff1d(np.arange(num_true), paired[:, 0])
    unpaired_pred = np.setdiff1d(np.arange(num_pred), paired[:, 1])

    return paired, unpaired_true, unpaired_pred

def calculate_instance_map(binary_mask):
    """Label the connected foreground components of a binary mask.

    Any non-zero value counts as foreground. ``measure.label`` casts non-bool
    input to integers. Without binarizing, a soft/float mask value such as 0.5
    would be truncated to background, and touching pixels with different
    non-zero values would become separate components.

    Args:
        binary_mask: array-like mask; non-zero means foreground.

    Returns:
        Integer label array of the same shape (0 = background, 1..K = components).
    """
    return measure.label(np.asarray(binary_mask) > 0)

def get_centroids(instance_map):
    """Return the (row, col) centroid of every labelled region in ``instance_map``.

    Args:
        instance_map: integer label image (0 = background). A singleton leading
            axis, e.g. a (1, H, W) map, is accepted.

    Returns:
        (K, 2) float array of centroids in label order; (0, 2) if there are no regions.
    """
    props = measure.regionprops(np.asarray(instance_map))
    if not props:
        return np.empty((0, 2))

    centroids = np.array([prop.centroid for prop in props], dtype=float)
    # regionprops reports one coordinate per image axis. For a (1, H, W) map that
    # is (0, row, col), so keep the LAST two axes rather than the first two
    # (identical for 2-D input).
    return centroids[:, -2:]

def cell_detection_scores(paired_true, paired_pred, unpaired_true, unpaired_pred):
    """Compute detection F1, precision and recall from matched/unmatched index lists.

    True positives are the matched ground-truth entries, false positives the
    unmatched predictions, and false negatives the unmatched ground-truth entries.
    Each ratio falls back to 0 when its denominator is zero. ``paired_pred`` is
    accepted for symmetry but not used.

    Returns:
        ``(f1, precision, recall)``.
    """
    def _safe_ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    true_pos = len(paired_true)
    false_pos = len(unpaired_pred)
    false_neg = len(unpaired_true)

    precision = _safe_ratio(true_pos, true_pos + false_pos)
    recall = _safe_ratio(true_pos, true_pos + false_neg)
    f1 = _safe_ratio(2 * (precision * recall), precision + recall)

    return f1, precision, recall

def calculate_metrics_across_all_test_sets(model, test_loader, device, pairing_radius=12):
    """Evaluate detection and classification quality of ``model`` over ``test_loader``.

    Computed metrics:
      * cell detection F1/precision/recall, from matching connected-component
        centroids of GT and predicted (sigmoid > 0.5) masks within
        ``pairing_radius`` pixels;
      * pixel-wise cell classification, weighted and per class;
      * tissue classification (weighted);
      * global cell-label classification (weighted, sigmoid > 0.5).

    ``sklearn.metrics.precision_recall_fscore_support`` returns
    ``(precision, recall, fbeta, support)``; results are unpacked in that order.

    Args:
        model: network returning ``(cell_seg_out, cell_class_out, tc_out,
            global_cell_out, hv_out)``.
        test_loader: iterable of 6-tuples ``(images, binary_masks,
            multi_class_masks, hv_maps, tissue_types, global_cell_labels)``.
        device: device tensors are moved to.
        pairing_radius: maximum centroid distance for a detection match.

    Returns:
        dict with 'segmentation', 'cell_classification', 'tissue_classification'
        and 'global_cell_classification' (each ``{'f1', 'precision', 'recall'}``),
        plus 'per_class_cell_metrics': the raw ``(precision, recall, f1, support)``
        arrays from sklearn.
    """
    model.eval()

    paired_all_global = []
    unpaired_true_all_global = []
    unpaired_pred_all_global = []

    all_true_cell_classes = []
    all_pred_cell_classes = []
    all_true_tissue_types = []
    all_pred_tissue_types = []
    all_true_global_cell_labels = []
    all_pred_global_cell_labels = []

    def _drop_channel_axis(mask):
        # Per-sample masks are expected to look like (1, H, W). Dropping the singleton
        # channel axis makes region centroids (row, col) instead of (0, row).
        if mask.ndim == 3 and mask.shape[0] == 1:
            return mask[0]
        return mask

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            images, binary_masks, multi_class_masks, hv_maps, tissue_types, global_cell_labels = batch
            images = images.to(device)
            binary_masks = binary_masks.to(device)
            multi_class_masks = multi_class_masks.to(device)
            tissue_types = tissue_types.to(device)
            global_cell_labels = global_cell_labels.to(device)

            cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out = model(images)

            # Cell detection: match connected-component centroids image by image.
            pred_masks = (torch.sigmoid(cell_seg_out) > 0.5).float()

            for true_mask, pred_mask in zip(binary_masks.cpu().numpy(), pred_masks.cpu().numpy()):
                true_instance_map = calculate_instance_map(_drop_channel_axis(true_mask))
                pred_instance_map = calculate_instance_map(_drop_channel_axis(pred_mask))

                true_centroids = get_centroids(true_instance_map)
                pred_centroids = get_centroids(pred_instance_map)

                paired, unpaired_true, unpaired_pred = pair_coordinates(true_centroids, pred_centroids, pairing_radius)

                paired_all_global.extend(paired.flatten())
                unpaired_true_all_global.extend(unpaired_true.flatten())
                unpaired_pred_all_global.extend(unpaired_pred.flatten())

            # Pixel-wise cell classification (GT class axis is last, prediction axis is 1)
            true_cell_classes = multi_class_masks.squeeze(1).argmax(dim=-1).cpu().numpy()
            pred_cell_classes = cell_class_out.argmax(dim=1).cpu().numpy()
            all_true_cell_classes.append(true_cell_classes)
            all_pred_cell_classes.append(pred_cell_classes)

            # Tissue classification
            true_tissue_types = tissue_types.argmax(dim=1).cpu().numpy()
            pred_tissue_types = tc_out.argmax(dim=1).cpu().numpy()
            all_true_tissue_types.extend(true_tissue_types)
            all_pred_tissue_types.extend(pred_tissue_types)

            # Global (image-level, multi-label) cell classification
            true_global_cell_labels = global_cell_labels.cpu().numpy()
            pred_global_cell_labels = (torch.sigmoid(global_cell_out) > 0.5).float().cpu().numpy()
            all_true_global_cell_labels.extend(true_global_cell_labels)
            all_pred_global_cell_labels.extend(pred_global_cell_labels)

    # Flatten all pixel-wise cell classification results once
    y_true_cells = np.concatenate(all_true_cell_classes).ravel()
    y_pred_cells = np.concatenate(all_pred_cell_classes).ravel()

    # Detection metrics: only the counts of paired/unpaired entries matter
    paired_all = np.array(paired_all_global).reshape(-1, 2)
    unpaired_true_all = np.array(unpaired_true_all_global)
    unpaired_pred_all = np.array(unpaired_pred_all_global)

    seg_f1, seg_precision, seg_recall = cell_detection_scores(
        paired_true=paired_all[:, 0],
        paired_pred=paired_all[:, 1],
        unpaired_true=unpaired_true_all,
        unpaired_pred=unpaired_pred_all
    )

    # sklearn order is (precision, recall, fbeta, support)
    cell_precision, cell_recall, cell_f1, _ = precision_recall_fscore_support(
        y_true_cells, y_pred_cells, average='weighted', zero_division=0
    )

    # Per-class metrics; with labels=None sklearn reports the sorted labels present
    # in either array, so keep those labels for printing.
    cell_class_metrics = precision_recall_fscore_support(
        y_true_cells, y_pred_cells, average=None, zero_division=0
    )
    cell_class_labels = np.unique(np.concatenate((y_true_cells, y_pred_cells)))

    tissue_precision, tissue_recall, tissue_f1, _ = precision_recall_fscore_support(
        all_true_tissue_types, all_pred_tissue_types, average='weighted', zero_division=0
    )

    global_cell_precision, global_cell_recall, global_cell_f1, _ = precision_recall_fscore_support(
        all_true_global_cell_labels, all_pred_global_cell_labels, average='weighted', zero_division=0
    )

    print("Metrics across all test sets:")
    print(f"Segmentation - F1: {seg_f1:.4f}, Precision: {seg_precision:.4f}, Recall: {seg_recall:.4f}")
    print(f"Cell Classification - F1: {cell_f1:.4f}, Precision: {cell_precision:.4f}, Recall: {cell_recall:.4f}")
    print(f"Tissue Classification - F1: {tissue_f1:.4f}, Precision: {tissue_precision:.4f}, Recall: {tissue_recall:.4f}")
    print(f"Global Cell Classification - F1: {global_cell_f1:.4f}, Precision: {global_cell_precision:.4f}, Recall: {global_cell_recall:.4f}")

    print("\nPer-class Cell Classification Metrics:")
    per_class_precision, per_class_recall, per_class_f1 = cell_class_metrics[:3]
    for label, precision, recall, f1 in zip(cell_class_labels, per_class_precision, per_class_recall, per_class_f1):
        print(f"Class {label} - F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

    return {
        'segmentation': {'f1': seg_f1, 'precision': seg_precision, 'recall': seg_recall},
        'cell_classification': {'f1': cell_f1, 'precision': cell_precision, 'recall': cell_recall},
        'tissue_classification': {'f1': tissue_f1, 'precision': tissue_precision, 'recall': tissue_recall},
        'global_cell_classification': {'f1': global_cell_f1, 'precision': global_cell_precision, 'recall': global_cell_recall},
        'per_class_cell_metrics': cell_class_metrics
    }

# Usage: rebuild the model, load the saved weights and compute metrics on the test set.
# ModifiedCellSwin, device and test_loader are expected from other notebook cells.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model = ModifiedCellSwin(num_cell_classes=5, num_tissue_classes=19).to(device)
model.load_state_dict(torch.load("improved_cellswin_model.pth", map_location=device))
model.eval()

metrics = calculate_metrics_across_all_test_sets(model, test_loader, device)

In [None]:
def visualize_results(model, test_loader, device, num_samples=5):
    """Plot predictions, including HV-based instances, for the first image of each batch.

    One figure row is drawn for each of the first ``num_samples`` batches. The eight
    panels are: original image, GT binary mask, predicted binary mask, GT cell
    classes, predicted cell classes, HV-map magnitude on predicted cells,
    predicted instance segmentation (``postprocess_hv``), and the original image
    with instance boundaries, annotated with the predicted tissue-type index.

    Args:
        model: network whose forward pass returns
            ``(cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out)``.
        test_loader: iterable of ``(images, binary_masks, multi_class_masks,
            hv_maps, tissue_types, global_cell_labels)`` batches.
        device: device the images are moved to before the forward pass.
        num_samples: number of batches (figure rows) to draw.
    """
    model.eval()

    fig, axes = plt.subplots(num_samples, 8, figsize=(40, 5*num_samples))
    if num_samples == 1:
        axes = axes.reshape(1, -1)

    with torch.no_grad():
        for i, (images, binary_masks, multi_class_masks, hv_maps, tissue_types, global_cell_labels) in enumerate(test_loader):
            if i >= num_samples:
                break

            images = images.to(device)

            # Forward pass
            cell_seg_out, cell_class_out, tc_out, global_cell_out, hv_out = model(images)

            # Move tensors to CPU and convert to numpy for visualization
            image = images[0].cpu().permute(1, 2, 0).numpy()
            binary_mask = binary_masks[0, 0].cpu().numpy()
            multi_class_mask = multi_class_masks[0].squeeze().cpu().numpy()
            cell_class_out_img = cell_class_out[0].cpu().numpy()
            hv_out_img = hv_out[0].cpu().numpy()

            # Foreground probability. The binary panel and the instance post-processing
            # both threshold this same probability, so they agree (previously the binary
            # panel thresholded the raw output instead).
            cell_prob = torch.sigmoid(cell_seg_out[0])
            binary_pred = cell_prob[0].cpu().numpy() > 0.5

            # Normalize image if necessary
            if image.max() > 1:
                image = image / 255.0

            # Original Image
            axes[i, 0].imshow(image)
            axes[i, 0].set_title(f"Sample {i+1}\nOriginal Image")
            axes[i, 0].axis('off')

            # Ground Truth Binary Segmentation
            axes[i, 1].imshow(binary_mask, cmap='gray')
            axes[i, 1].set_title("GT Binary Segmentation")
            axes[i, 1].axis('off')

            # Predicted Binary Segmentation
            axes[i, 2].imshow(binary_pred, cmap='gray')
            axes[i, 2].set_title("Predicted Binary Segmentation")
            axes[i, 2].axis('off')

            # Ground Truth Cell Classification (class axis of the GT mask assumed last)
            num_classes = multi_class_mask.shape[-1]
            colors = plt.colormaps['tab10'](np.linspace(0, 1, num_classes))

            gt_cell_class_image = np.zeros((*multi_class_mask.shape[:2], 3))
            for class_idx in range(num_classes):
                gt_cell_class_image[multi_class_mask[..., class_idx] > 0] = colors[class_idx][:3]

            axes[i, 3].imshow(gt_cell_class_image)
            axes[i, 3].set_title("Ground Truth\nCell Classification")
            axes[i, 3].axis('off')

            # Predicted Cell Classification, restricted to predicted cell pixels
            cell_class_out_argmax = np.argmax(cell_class_out_img, axis=0)
            pred_cell_class_image = np.zeros((*cell_class_out_argmax.shape, 3))
            for class_idx in range(num_classes):
                class_mask = (cell_class_out_argmax == class_idx) & binary_pred
                pred_cell_class_image[class_mask] = colors[class_idx][:3]

            axes[i, 4].imshow(pred_cell_class_image)
            axes[i, 4].set_title("Predicted\nCell Classification")
            axes[i, 4].axis('off')

            # HV Map Visualization (only for detected cells)
            hv_magnitude = np.sqrt(np.sum(hv_out_img**2, axis=0))
            axes[i, 5].imshow(hv_magnitude * binary_pred, cmap='viridis')
            axes[i, 5].set_title("HV Map Magnitude\n(Detected Cells Only)")
            axes[i, 5].axis('off')

            # Predicted Instance Segmentation
            segmentation = postprocess_hv(hv_out[0], cell_prob)

            # One colour per non-background label. The label -> colour map is built
            # explicitly because labels need not be 1..n, so indexing the palette by
            # label value could overrun it.
            instance_labels = [label for label in np.unique(segmentation) if label != 0]
            instance_colors = plt.colormaps['tab20'](np.linspace(0, 1, len(instance_labels)))
            label_to_color = {label: instance_colors[k, :3] for k, label in enumerate(instance_labels)}

            instance_seg_image = np.zeros((*segmentation.shape, 3))
            for label, color in label_to_color.items():
                instance_seg_image[segmentation == label] = color

            axes[i, 6].imshow(instance_seg_image)
            axes[i, 6].set_title("Predicted Instance\nSegmentation")
            axes[i, 6].axis('off')

            # Original image with a 1-pixel ring (3x3 dilation minus mask) around each instance
            overlay_image = image.copy()
            for label, color in label_to_color.items():
                mask = (segmentation == label).astype(np.uint8)
                boundary = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=1) - mask
                overlay_image[boundary > 0] = color

            axes[i, 7].imshow(overlay_image)
            axes[i, 7].set_title("Original Image with\nInstance Segmentation")
            axes[i, 7].axis('off')

            # Tissue Type prediction
            predicted_tissue = tc_out[0].argmax().item()
            axes[i, 7].text(0.5, -0.1, f"Predicted Tissue Type: {predicted_tissue}", 
                            horizontalalignment='center', verticalalignment='center',
                            fontsize=10, fontweight='bold', transform=axes[i, 7].transAxes)

    plt.tight_layout()
    plt.show()

# Usage: one figure row for each of the first 10 test batches.
# NOTE(review): visualize_results calls postprocess_hv, whose visible definition is in a
# later cell of this file; make sure that cell has been run first.
visualize_results(model, test_loader, device, num_samples=10)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from scipy import ndimage as ndi
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_score, recall_score
import matplotlib.pyplot as plt


def postprocess_hv(hv_map, cell_prob, prob_thresh=0.5, hv_thresh=0.2, min_distance=3):
    """Split predicted foreground into labelled cell instances.

    Pixels where the HV map changes sharply (absolute first difference above
    ``hv_thresh``) are removed from the foreground (``cell_prob > prob_thresh``).
    Local maxima of ``cell_prob`` inside the remaining region seed a watershed on
    ``-cell_prob``, restricted to that region.

    Args:
        hv_map: tensor indexed as (channel, row, col) with at least two channels.
        cell_prob: foreground-probability tensor; it must broadcast against a
            (row, col) map (callers in this notebook appear to pass (1, H, W)).
        prob_thresh: foreground probability threshold.
        hv_thresh: HV-difference threshold above which a pixel counts as a boundary.
        min_distance: minimum separation between seed peaks.

    Returns:
        Integer label image (0 = background) with singleton axes squeezed out.

    NOTE(review): channel 0 is differenced along rows and channel 1 along columns.
    If channel 0 is the horizontal map, HoVer-Net-style post-processing would
    difference it along columns instead; confirm the channel convention.
    """
    hv = hv_map.cpu().numpy()
    prob = cell_prob.cpu().numpy()

    # Absolute first differences, zero-padded on the leading edge to keep (H, W).
    row_steps = np.pad(np.abs(hv[0, 1:, :] - hv[0, :-1, :]), ((1, 0), (0, 0)), mode='constant')
    col_steps = np.pad(np.abs(hv[1, :, 1:] - hv[1, :, :-1]), ((0, 0), (1, 0)), mode='constant')
    on_edge = np.maximum(row_steps, col_steps) > hv_thresh

    region = (prob > prob_thresh) & ~on_edge

    # Seeds: probability peaks inside the region (absolute threshold is half of prob_thresh)
    seeds = peak_local_max(prob,
                           labels=region,
                           min_distance=min_distance,
                           threshold_abs=prob_thresh * 0.5,
                           exclude_border=False)
    if len(seeds) == 0:
        return np.zeros_like(prob, dtype=np.int32).squeeze()

    markers = np.zeros_like(prob, dtype=np.int32)
    markers[tuple(seeds.T)] = np.arange(1, len(seeds) + 1)

    return watershed(-prob, markers, mask=region).squeeze()

def predict_and_postprocess(model, image):
    """Run ``model`` on one un-batched image and post-process its outputs.

    A batch axis is added before the forward pass (run under ``torch.no_grad``).
    The segmentation output is passed through a sigmoid and, together with the
    HV output, handed to ``postprocess_hv``.

    Args:
        model: network returning ``(cell_seg_out, cell_class_out, tc_out,
            global_cell_out, hv_out)``.
        image: a single input tensor without a batch axis.

    Returns:
        ``(segmentation, cell_class_out, tc_out, global_cell_out)``, where
        segmentation is the label image from ``postprocess_hv`` and the other
        outputs have their batch axis removed.
    """
    with torch.no_grad():
        outputs = model(image.unsqueeze(0))
    seg_out, class_out, tissue_out, global_out, hv_pred = outputs

    foreground_prob = torch.sigmoid(seg_out)
    segmentation = postprocess_hv(hv_pred[0], foreground_prob[0])

    return (
        segmentation,
        class_out.squeeze(0),
        tissue_out.squeeze(0),
        global_out.squeeze(0),
    )


def calculate_metrics(pred_mask, true_mask):
    """Pixel-wise F1, precision and recall of a predicted mask against ground truth.

    Both masks are binarized (non-zero = foreground) before scoring. In
    ``evaluate_model``, ``pred_mask`` is the instance label image produced by
    ``postprocess_hv``. Comparing those raw instance ids against a 0/1 ground
    truth would count every pixel of instance 2, 3, ... as a misclassification.

    Scores are sklearn 'weighted' averages over the background and foreground
    classes. An undefined score (zero division) is reported as 1.

    Returns:
        ``(f1, precision, recall)``.
    """
    pred_flat = np.asarray(pred_mask).ravel() > 0
    true_flat = np.asarray(true_mask).ravel() > 0

    f1 = f1_score(true_flat, pred_flat, average='weighted', zero_division=1)
    precision = precision_score(true_flat, pred_flat, average='weighted', zero_division=1)
    recall = recall_score(true_flat, pred_flat, average='weighted', zero_division=1)

    return f1, precision, recall

def evaluate_model(model, test_loader, device):
    """Average per-image segmentation F1/precision/recall over ``test_loader``.

    Each image is run through ``predict_and_postprocess`` one at a time and
    scored against its ``binary_masks`` entry with ``calculate_metrics``.
    Per-image scores are averaged with equal weight.

    Args:
        model: the network to evaluate (switched to eval mode).
        test_loader: iterable of 6-tuples whose first two items are the images
            and the ground-truth binary masks.
        device: device images are moved to.

    Returns:
        ``(avg_f1, avg_precision, avg_recall)``.

    Raises:
        ValueError: if ``test_loader`` yields no images, since the averages are undefined.
    """
    model.eval()
    total_f1 = total_precision = total_recall = 0.0
    num_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            images, binary_masks, _, _, _, _ = batch
            images = images.to(device)

            for i in range(images.size(0)):
                segmentation, _, _, _ = predict_and_postprocess(model, images[i])

                # Assuming binary_masks contains the ground truth segmentation
                true_mask = binary_masks[i].squeeze().cpu().numpy()

                f1, precision, recall = calculate_metrics(segmentation, true_mask)

                total_f1 += f1
                total_precision += precision
                total_recall += recall
                num_samples += 1

    if num_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot average metrics")

    avg_f1 = total_f1 / num_samples
    avg_precision = total_precision / num_samples
    avg_recall = total_recall / num_samples

    return avg_f1, avg_precision, avg_recall

def visualize_results(image, true_mask, pred_mask, save_path=None):
    """Show an image beside its ground-truth and predicted masks.

    Args:
        image: tensor indexed as (channel, row, col); permuted to
            (row, col, channel) for display.
        true_mask: ground-truth mask, drawn with the 'nipy_spectral' colormap.
        pred_mask: numpy mask; cast to int32 for display if it is not already.
        save_path: when truthy, the figure is saved there instead of being shown.
    """
    fig, panels = plt.subplots(1, 3, figsize=(15, 5))

    display_image = image.permute(1, 2, 0).cpu().numpy()
    if pred_mask.dtype != np.int32:
        pred_mask = pred_mask.astype(np.int32)

    layout = (
        (display_image, 'Original Image', {}),
        (true_mask, 'True Mask', {'cmap': 'nipy_spectral'}),
        (pred_mask, 'Predicted Mask', {'cmap': 'nipy_spectral'}),
    )
    for ax, (content, title, imshow_kwargs) in zip(panels, layout):
        ax.imshow(content, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path)

    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

# Setup and main execution: choose a device, rebuild the model, load its weights,
# then report per-image segmentation metrics averaged over the test set.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize the model (ModifiedCellSwin is not defined in this section; it must come from another cell)
model = ModifiedCellSwin(num_cell_classes=5, num_tissue_classes=19).to(device)

# Load the saved state dict
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model.load_state_dict(torch.load("improved_cellswin_model.pth", map_location=device))

# Set the model to evaluation mode
model.eval()
print("Model loaded successfully!")

# Assuming you have a test_loader defined
# test_loader = ...

# Run evaluation
avg_f1, avg_precision, avg_recall = evaluate_model(model, test_loader, device)

print(f"Average F1 Score: {avg_f1:.4f}")
print(f"Average Precision: {avg_precision:.4f}")
print(f"Average Recall: {avg_recall:.4f}")

# Visualize some results (optional), e.g. with visualize_results(image, true_mask, pred_mask)



In [None]:
# Cell 1: Imports and Setup
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Cell 2: Define helper functions
def calculate_class_metrics(predictions, ground_truth, num_classes):
    """Per-class, one-vs-rest pixel classification metrics.

    Args:
        predictions: rank-4 tensor with the class scores on axis 1; the
            arg-max over that axis is each pixel's predicted label.
        ground_truth: tensor whose ``squeeze(1)`` presumably has the class axis
            last (e.g. one-hot, channels-last; confirm against the dataset); the
            arg-max gives each pixel's true label.
        num_classes: number of classes.

    Returns:
        ``{'class_k': {'accuracy', 'precision', 'recall', 'f1_score'}}`` for
        k in ``range(num_classes)``. Undefined precision/recall/F1 are reported as 0.
    """
    flat_scores = predictions.permute(0, 2, 3, 1).reshape(-1, num_classes)
    y_pred = torch.argmax(flat_scores, dim=1).cpu().numpy()

    flat_truth = ground_truth.squeeze(1).reshape(-1, num_classes)
    y_true = torch.argmax(flat_truth, dim=1).cpu().numpy()

    scorers = {
        'accuracy': lambda t, p: accuracy_score(t, p),
        'precision': lambda t, p: precision_score(t, p, zero_division=0),
        'recall': lambda t, p: recall_score(t, p, zero_division=0),
        'f1_score': lambda t, p: f1_score(t, p, zero_division=0),
    }

    per_class = {}
    for k in range(num_classes):
        is_true_k = y_true == k
        is_pred_k = y_pred == k
        per_class[f'class_{k}'] = {name: score(is_true_k, is_pred_k) for name, score in scorers.items()}

    return per_class

def analyze_class_distribution(predictions, ground_truth, num_classes):
    """Count how many pixels carry each class label in ground truth and in predictions.

    Labels are the arg-max over the class axis. For ``predictions`` that is
    axis 1, moved last before flattening. For ``ground_truth`` it is the last
    axis after ``squeeze(1)`` (presumably one-hot, channels-last; confirm
    against the dataset).

    Returns:
        ``(gt_distribution, pred_distribution)``: dicts mapping ``'class_k'`` to
        the pixel count for k in ``range(num_classes)``.
    """
    def _argmax_labels(scores_last_axis):
        flat = scores_last_axis.reshape(-1, num_classes)
        return torch.argmax(flat, dim=1).cpu().numpy()

    pred_labels = _argmax_labels(predictions.permute(0, 2, 3, 1))
    true_labels = _argmax_labels(ground_truth.squeeze(1))

    # Labels come from an arg-max over num_classes entries, so they lie in [0, num_classes).
    gt_counts = np.bincount(true_labels, minlength=num_classes)
    pred_counts = np.bincount(pred_labels, minlength=num_classes)

    keys = [f'class_{k}' for k in range(num_classes)]
    return dict(zip(keys, gt_counts)), dict(zip(keys, pred_counts))

def plot_class_metrics(metrics):
    """Draw a grouped bar chart of accuracy, precision, recall and F1 for each class.

    Args:
        metrics: mapping from class name to a dict with 'accuracy', 'precision',
            'recall' and 'f1_score' entries (the format built by
            ``calculate_class_metrics``).
    """
    class_names = list(metrics.keys())
    # (metric key, legend label, bar offset in units of the bar width)
    series = [
        ('accuracy', 'Accuracy', -1.5),
        ('precision', 'Precision', -0.5),
        ('recall', 'Recall', 0.5),
        ('f1_score', 'F1-score', 1.5),
    ]
    heights = {key: [entry[key] for entry in metrics.values()] for key, _, _ in series}

    positions = np.arange(len(class_names))
    bar_width = 0.2

    fig, ax = plt.subplots(figsize=(12, 6))
    for key, legend_label, offset in series:
        ax.bar(positions + offset * bar_width, heights[key], bar_width, label=legend_label)

    ax.set_ylabel('Scores')
    ax.set_title('Class-wise Metrics')
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names)
    ax.legend()

    plt.tight_layout()
    plt.show()

def plot_class_distribution(gt_distribution, pred_distribution):
    """Draw ground-truth and predicted per-class pixel counts side by side.

    Args:
        gt_distribution: mapping from class name to ground-truth count; its key
            order defines the x-axis.
        pred_distribution: mapping from class name to predicted count, plotted
            in its own value order.
    """
    labels = list(gt_distribution.keys())
    positions = np.arange(len(labels))
    half = 0.35 / 2

    fig, ax = plt.subplots(figsize=(12, 6))
    for values, offset, legend_label in (
        (list(gt_distribution.values()), -half, 'Ground Truth'),
        (list(pred_distribution.values()), half, 'Predicted'),
    ):
        ax.bar(positions + offset, values, 0.35, label=legend_label)

    ax.set_ylabel('Count')
    ax.set_title('Class Distribution: Ground Truth vs Predicted')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.legend()

    plt.tight_layout()
    plt.show()

# Cell 3: Model and Data Setup
# Assuming your model is already defined and loaded
model = model.to(device)
model.eval()

# Assuming test_loader is already defined
num_classes = 5  # Number of cell classes; must match the class axis of the model output and GT masks

# Cell 4: Collect predictions from the first 3 batches only (a quick check)
# NOTE(review): predictions are kept on `device` until concatenation.
all_predictions = []
all_ground_truths = []

with torch.no_grad():
    # NOTE(review): elsewhere in this notebook the loader yields (images, binary_masks,
    # multi_class_masks, hv_maps, tissue_types, global_cell_labels). So `cell_classes`
    # receives the multi-class masks, `tissue_classes` the HV maps and `global_cell_count`
    # the tissue types. Only `cell_classes` is used below.
    for i, (images, binary_masks, cell_classes, tissue_classes, global_cell_count, _) in enumerate(test_loader):
        if i >= 3:  # Process only the first 3 batches
            break
        
        images = images.to(device)
        cell_classes = cell_classes.to(device)
        
        model_output = model(images)
        class_predictions = model_output[1]  # Second output is the per-pixel cell classification (cell_class_out)
        
        all_predictions.append(class_predictions)
        all_ground_truths.append(cell_classes)
        
        # Progress/debug output (one line per batch, although the message says "sample")
        print(f"Processed sample {i+1}")
        print(f"Class predictions shape: {class_predictions.shape}")
        print(f"Ground truth shape: {cell_classes.shape}")

# Cell 5: Calculate metrics and analyze class distribution
predictions = torch.cat(all_predictions, dim=0)
ground_truth = torch.cat(all_ground_truths, dim=0)

metrics = calculate_class_metrics(predictions, ground_truth, num_classes)
gt_distribution, pred_distribution = analyze_class_distribution(predictions, ground_truth, num_classes)

print("\nGround Truth Class Distribution:")
for class_idx, count in gt_distribution.items():
    print(f"{class_idx}: {count}")

print("\nPredicted Class Distribution:")
for class_idx, count in pred_distribution.items():
    print(f"{class_idx}: {count}")

print("\nClass-wise Metrics:")
for class_idx, class_metrics in metrics.items():
    print(f"\nMetrics for {class_idx}:")
    for metric_name, metric_value in class_metrics.items():
        print(f"  {metric_name}: {metric_value:.4f}")

# Cell 6: Visualize results
plot_class_metrics(metrics)
plot_class_distribution(gt_distribution, pred_distribution)

In [None]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def calculate_class_metrics(predictions, ground_truth, num_classes):
    """Per-class (one-vs-rest) and overall pixel classification metrics.

    Args:
        predictions: rank-4 tensor with class scores on axis 1; the arg-max over
            that axis is the predicted label.
        ground_truth: tensor whose ``squeeze(1)`` presumably has the class axis
            last (confirm against the dataset); its arg-max is the true label.
        num_classes: number of classes.

    Returns:
        dict with ``'class_k'`` entries for k in ``range(num_classes)`` plus an
        ``'overall'`` entry, each holding 'accuracy', 'precision', 'recall' and
        'f1_score'. Overall precision/recall/F1 are sklearn 'weighted' averages.
        Undefined scores are reported as 0.
    """
    def _labels(scores):
        return torch.argmax(scores, dim=1).cpu().numpy()

    y_pred = _labels(predictions.permute(0, 2, 3, 1).reshape(-1, num_classes))
    y_true = _labels(ground_truth.squeeze(1).reshape(-1, num_classes))

    metrics = {}
    for k in range(num_classes):
        truth_k, pred_k = y_true == k, y_pred == k
        metrics[f'class_{k}'] = dict(
            accuracy=accuracy_score(truth_k, pred_k),
            precision=precision_score(truth_k, pred_k, zero_division=0),
            recall=recall_score(truth_k, pred_k, zero_division=0),
            f1_score=f1_score(truth_k, pred_k, zero_division=0),
        )

    weighted = dict(average='weighted', zero_division=0)
    metrics['overall'] = dict(
        accuracy=accuracy_score(y_true, y_pred),
        precision=precision_score(y_true, y_pred, **weighted),
        recall=recall_score(y_true, y_pred, **weighted),
        f1_score=f1_score(y_true, y_pred, **weighted),
    )

    return metrics

# Assuming your model is already defined and loaded
model = model.to(device)
model.eval()

# Assuming test_loader is already defined
num_classes = 5  # Number of cell classes; must match the class axis of the model output and GT masks

all_predictions = []
all_ground_truths = []

total_batches = len(test_loader)
print(f"Processing test set ({total_batches} batches):")

# NOTE(review): every batch's predictions and ground truths stay on `device` until the
# torch.cat below; on a large test set, moving them with .cpu() first avoids GPU OOM.
with torch.no_grad():
    # NOTE(review): elsewhere in this notebook the loader yields (images, binary_masks,
    # multi_class_masks, hv_maps, tissue_types, global_cell_labels). So `tissue_classes`
    # and `global_cell_count` actually receive the HV maps and tissue types; only
    # `cell_classes` (the multi-class masks) is used.
    for i, (images, binary_masks, cell_classes, tissue_classes, global_cell_count, _) in enumerate(test_loader):
        images = images.to(device)
        cell_classes = cell_classes.to(device)
        
        model_output = model(images)
        class_predictions = model_output[1]  # Second output is the per-pixel cell classification (cell_class_out)
        
        all_predictions.append(class_predictions)
        all_ground_truths.append(cell_classes)
        
        # Overwrite the same console line every 10 batches and on the final batch
        if (i + 1) % 10 == 0 or (i + 1) == total_batches:
            print(f"Processed {i + 1}/{total_batches} batches", end='\r')

print("\nCalculating metrics...")

predictions = torch.cat(all_predictions, dim=0)
ground_truth = torch.cat(all_ground_truths, dim=0)

metrics = calculate_class_metrics(predictions, ground_truth, num_classes)

# Print results
print("\nClass-wise Metrics:")
for class_idx, class_metrics in metrics.items():
    if class_idx == 'overall':
        print("\nOverall Metrics:")
    else:
        print(f"\nMetrics for {class_idx}:")
    for metric_name, metric_value in class_metrics.items():
        print(f"  {metric_name}: {metric_value:.4f}")

# Print summary in the requested format (tab-separated table, one row per class plus 'overall')
print("\nSummary:")
print("Class\tAccuracy\tPrecision\tRecall\t\tF1-score")
for class_idx, class_metrics in metrics.items():
    print(f"{class_idx}\t{class_metrics['accuracy']:.4f}\t\t{class_metrics['precision']:.4f}\t\t{class_metrics['recall']:.4f}\t\t{class_metrics['f1_score']:.4f}")