In [7]:
!pip install --upgrade pip
!pip install natsort
!pip install albumentations
!pip install wandb
!pip install torchinfo
!pip install schema
!pip install torchmetrics
!pip install einops
!pip install timm
!pip install natsort
!pip install torchsummary
!pip install natsort

Defaulting to user installation because normal site-packages is not writeable
Collecting pip
  Downloading pip-24.3.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-24.3.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0mm
[?25hInstalling collected packages: pip
Successfully installed pip-24.3.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Collecting albumentations
  Downloading albumentations-1.4.21-py3-none-any.whl.metadata (31 kB)
Collecting pydantic>=2.7.0 (from albumentations)
  Downloading pydantic-2.9.2-py3-n

In [5]:
import os
import sys
from pathlib import Path

# Add CellViT to python path first
cellvit_path = "/rsrch5/home/plm/yshokrollahi/CellViT"
if cellvit_path not in sys.path:
    sys.path.append(cellvit_path)
    print(f"Added {cellvit_path} to Python path")

import yaml
import torch
from torch.utils.data import DataLoader
import albumentations as A
import matplotlib.pyplot as plt
import numpy as np

try:
    from cell_segmentation.datasets.pannuke import PanNukeDataset
    from cell_segmentation.datasets.tissuenet import TissueNetDataset
    print("Successfully imported CellViT modules")
except ImportError as e:
    print(f"Import error: {e}")
    raise

class MultiModalDatasetManager:
    def __init__(self):
        # Dataset paths
        self.tissuenet_path = Path("/rsrch5/home/plm/yshokrollahi/CellViT/configs/datasets/tissuenet")
        self.pannuke_path = Path("/rsrch5/home/plm/yshokrollahi/CellViT/configs/datasets/reassemble")
        
        # Config paths
        self.tissuenet_config_path = Path("/rsrch5/home/plm/yshokrollahi/CellViT/configs/examples/cell_segmentation/vitaminp-tissuenet.yaml")
        self.pannuke_config_path = Path("/rsrch5/home/plm/yshokrollahi/CellViT/configs/examples/cell_segmentation/pannuke-vitaminp.yaml")
        
        self._load_configs()

    def _load_configs(self):
        try:
            with open(self.tissuenet_config_path, 'r') as file:
                self.tissuenet_config = yaml.safe_load(file)
            with open(self.pannuke_config_path, 'r') as file:
                self.pannuke_config = yaml.safe_load(file)
            print("Successfully loaded configuration files")
        except Exception as e:
            print(f"Error loading configs: {e}")
            raise

    def get_transforms(self, transform_settings, input_shape=256):
        transforms = []
        
        if input_shape != 256:
            transforms.append(A.Resize(input_shape, input_shape))
        
        transforms.extend([
            A.RandomRotate90(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Downscale(scale_min=0.5, scale_max=1.0, p=0.15),
            A.Blur(blur_limit=9, p=0.2),
            A.GaussNoise(var_limit=50, p=0.25),
            A.ElasticTransform(p=0.2),
        ])
        
        if 'normalize' in transform_settings:
            transforms.append(A.Normalize(**transform_settings['normalize']))
        
        return A.Compose(transforms)

    def create_dataloader(self, dataset, config, is_train=False):
        if len(dataset) == 0:
            raise ValueError(f"Dataset is empty!")
        
        batch_size = config['training']['batch_size']
        return DataLoader(dataset, 
                         batch_size=batch_size,
                         shuffle=is_train,
                         num_workers=16,
                         pin_memory=True)

    def setup_tissuenet(self):
        print("\nSetting up TissueNet datasets...")
        
        # Create transforms
        train_transforms = self.get_transforms(self.tissuenet_config['transformations'])
        val_transforms = A.Compose([A.Normalize(**self.tissuenet_config['transformations']['normalize'])])
        
        # Initialize datasets directly using TissueNetDataset
        datasets = {
            'train': TissueNetDataset(
                dataset_path=self.tissuenet_path,
                split='train',
                transforms=train_transforms,
                stardist=False,
                regression=False,
                cache_dataset=False
            ),
            'val': TissueNetDataset(
                dataset_path=self.tissuenet_path,
                split='val',
                transforms=val_transforms,
                stardist=False,
                regression=False,
                cache_dataset=False
            ),
            'test': TissueNetDataset(
                dataset_path=self.tissuenet_path,
                split='test',
                transforms=val_transforms,
                stardist=False,
                regression=False,
                cache_dataset=False
            )
        }
        
        # Create dataloaders
        dataloaders = {}
        for split in datasets:
            dataloaders[split] = self.create_dataloader(
                datasets[split],
                self.tissuenet_config,
                is_train=(split == 'train')
            )
            print(f"Created {split} dataset with {len(datasets[split])} samples")
            
        return datasets, dataloaders

    def setup_pannuke(self):
        print("\nSetting up PanNuke datasets...")
        
        datasets = {
            'train': PanNukeDataset(
                dataset_path=self.pannuke_path,
                folds=[0, 1],  # Use first two folds for training
                transforms=None,
                stardist=False,
                regression=False,
                cache_dataset=False
            ),
            'val': PanNukeDataset(
                dataset_path=self.pannuke_path,
                folds=[2],  # Use last fold for validation
                transforms=None,
                stardist=False,
                regression=False,
                cache_dataset=False
            )
        }
        
        dataloaders = {}
        for split in datasets:
            dataloaders[split] = self.create_dataloader(
                datasets[split],
                self.pannuke_config,
                is_train=(split == 'train')
            )
            print(f"Created {split} dataset with {len(datasets[split])} samples")
            
        return datasets, dataloaders

    def verify_datasets(self):
        print("\nVerifying dataset paths and configurations...")
        for path, name in [(self.pannuke_path, "PanNuke"), 
                          (self.tissuenet_path, "TissueNet")]:
            print(f"\nChecking {name} dataset:")
            print(f"Path exists: {path.exists()}")
            print(f"dataset_config.yaml exists: {(path / 'dataset_config.yaml').exists()}")
            print(f"weight_config.yaml exists: {(path / 'weight_config.yaml').exists()}")

# Test the setup
if __name__ == "__main__":
    print("\nInitializing dataset manager...")
    try:
        data_manager = MultiModalDatasetManager()
        data_manager.verify_datasets()
        
        print("\nSetting up TissueNet...")
        tissuenet_datasets, tissuenet_loaders = data_manager.setup_tissuenet()
        
        print("\nSetting up PanNuke...")
        pannuke_datasets, pannuke_loaders = data_manager.setup_pannuke()
        
        print("\nSetup successful!")
    except Exception as e:
        print(f"\nError during setup: {e}")
        raise

Successfully imported CellViT modules

Initializing dataset manager...
Successfully loaded configuration files

Verifying dataset paths and configurations...

Checking PanNuke dataset:
Path exists: True
dataset_config.yaml exists: True
weight_config.yaml exists: True

Checking TissueNet dataset:
Path exists: True
dataset_config.yaml exists: False
weight_config.yaml exists: False

Setting up TissueNet...

Setting up TissueNet datasets...
Created train dataset with 10320 samples
Created val dataset with 3118 samples
Created test dataset with 1324 samples

Setting up PanNuke...

Setting up PanNuke datasets...
Created train dataset with 5179 samples
Created val dataset with 2722 samples

Setup successful!


In [12]:
# def visualize_tissuenet_sample(dataset, title="TissueNet Sample"):
#     # Get random sample
#     random_idx = np.random.randint(0, len(dataset))
#     img, masks, img_name = dataset[random_idx]
    
#     plt.figure(figsize=(20, 8))
    
#     # First row: Cell-related visualizations
#     plt.subplot(251)
#     plt.imshow(img[0].numpy(), cmap='gray')
#     plt.title('Channel 1')
#     plt.axis('off')
    
#     plt.subplot(252)
#     plt.imshow(img[1].numpy(), cmap='gray')
#     plt.title('Channel 2')
#     plt.axis('off')
    
#     plt.subplot(253)
#     plt.imshow(masks['cell_mask'], cmap='nipy_spectral')
#     plt.title('Cell Instances')
#     plt.axis('off')
    
#     plt.subplot(254)
#     plt.imshow(masks['cell_hv_map'][0], cmap='coolwarm')
#     plt.title('Cell HV (H)')
#     plt.axis('off')
    
#     plt.subplot(255)
#     plt.imshow(masks['cell_hv_map'][1], cmap='coolwarm')
#     plt.title('Cell HV (V)')
#     plt.axis('off')
    
#     # Second row: Combined view and nuclei-related visualizations
#     plt.subplot(256)
#     combined_img = np.stack([
#         np.zeros_like(img[0].numpy()),
#         img[1].numpy(),
#         img[0].numpy()
#     ], axis=2)
#     if combined_img.max() > 1:
#         combined_img = combined_img / combined_img.max()
#     plt.imshow(combined_img)
#     plt.title('Combined')
#     plt.axis('off')
    
#     plt.subplot(257)
#     plt.imshow(masks['nuclei_mask'], cmap='nipy_spectral')
#     plt.title('Nuclei Instances')
#     plt.axis('off')
    
#     plt.subplot(258)
#     plt.imshow(masks['nuclei_hv_map'][0], cmap='coolwarm')
#     plt.title('Nuclei HV (H)')
#     plt.axis('off')
    
#     plt.subplot(259)
#     plt.imshow(masks['nuclei_hv_map'][1], cmap='coolwarm')
#     plt.title('Nuclei HV (V)')
#     plt.axis('off')
    
#     plt.subplot(2,5,10)
#     nuclei_hv_magnitude = np.sqrt(masks['nuclei_hv_map'][0]**2 + masks['nuclei_hv_map'][1]**2)
#     plt.imshow(nuclei_hv_magnitude, cmap='viridis')
#     plt.title('Nuclei HV Magnitude')
#     plt.axis('off')
    
#     plt.suptitle(f"{title} - {img_name}")
#     plt.tight_layout()
#     plt.show()

# def visualize_pannuke_sample(dataset, title="PanNuke Sample"):
#     # Get random sample
#     random_idx = np.random.randint(0, len(dataset))
#     img, masks, tissue_type, img_name = dataset[random_idx]
    
#     plt.figure(figsize=(15, 5))
    
#     # Original image
#     plt.subplot(131)
#     plt.imshow(img.permute(1, 2, 0))
#     plt.title(f'Original Image\nTissue: {tissue_type}')
#     plt.axis('off')
    
#     # Instance map
#     plt.subplot(132)
#     plt.imshow(masks['instance_map'], cmap='nipy_spectral')
#     plt.title('Instance Map')
#     plt.axis('off')
    
#     # Nuclei type map
#     plt.subplot(133)
#     plt.imshow(masks['nuclei_type_map'], cmap='tab20')
#     plt.title('Nuclei Type Map')
#     plt.axis('off')
    
#     plt.suptitle(f"{title} - {img_name}")
#     plt.tight_layout()
#     plt.show()

# # Test visualization with the dataset manager
# if __name__ == "__main__":
#     print("\nInitializing dataset manager...")
#     data_manager = MultiModalDatasetManager()
    
#     # Setup datasets
#     tissuenet_datasets, tissuenet_loaders = data_manager.setup_tissuenet()
#     pannuke_datasets, pannuke_loaders = data_manager.setup_pannuke()
    
#     print("\nVisualizing TissueNet samples:")
#     for split in ['train', 'val', 'test']:
#         print(f"\nVisualizing {split} split:")
#         visualize_tissuenet_sample(tissuenet_datasets[split], f"TissueNet {split.capitalize()}")
    
#     print("\nVisualizing PanNuke samples:")
#     for split in ['train', 'val']:
#         print(f"\nVisualizing {split} split:")
#         visualize_pannuke_sample(pannuke_datasets[split], f"PanNuke {split.capitalize()}")

## model

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import swin_v2_b, Swin_V2_B_Weights
from typing import List, Tuple, Literal, OrderedDict
import numpy as np

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class Conv2DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(Conv2DBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x))) + 1e-5 * torch.sum(torch.pow(self.conv.weight, 2))

class Deconv2DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super(Deconv2DBlock, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.deconv(x)))


class SwinEncoder(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        if pretrained:
            weights = Swin_V2_B_Weights.IMAGENET1K_V1
        else:
            weights = None
        self.swin = swin_v2_b(weights=weights)
        self.swin.head = nn.Identity()  

        # Corrected channel dimensions for extra processing layers
        self.extra_processing = nn.ModuleList([
            nn.Sequential(
                Conv2DBlock(256, 256),  # Changed from 1024 to match input channels
                SEBlock(256),
            ),
            nn.Sequential(
                Conv2DBlock(512, 512),
                SEBlock(512),
            ),
            nn.Sequential(
                Conv2DBlock(1024, 1024),
                SEBlock(1024),
            )
        ])

    def forward(self, x):
        features = []
        for i, layer in enumerate(self.swin.features):
            x = layer(x)
            if i in [2, 4, 6, 7]:  
                curr_x = x.permute(0, 3, 1, 2)  # Change from [B, H, W, C] to [B, C, H, W]
                if len(features) < len(self.extra_processing):
                    curr_x = self.extra_processing[len(features)](curr_x)
                features.append(curr_x)
        return features

class FeaturePyramidNetwork(nn.Module):
    def __init__(self, in_channels_list, out_channels):
        super(FeaturePyramidNetwork, self).__init__()
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        self.extra_blocks = nn.ModuleList()  # New extra processing blocks

        for in_channels in in_channels_list:
            inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
            layer_block_module = nn.Sequential(
                Conv2DBlock(out_channels, out_channels),
                Conv2DBlock(out_channels, out_channels),
                SEBlock(out_channels)
            )
            extra_block = nn.Sequential(
                Conv2DBlock(out_channels, out_channels),
                SEBlock(out_channels)
            )
            
            self.inner_blocks.append(inner_block_module)
            self.layer_blocks.append(layer_block_module)
            self.extra_blocks.append(extra_block)

    def forward(self, x):
        results = []
        
        last_inner = self.inner_blocks[-1](x[-1])
        last_inner = self.extra_blocks[-1](last_inner)  # Extra processing
        results.append(self.layer_blocks[-1](last_inner))

        for feature, inner_block, layer_block, extra_block in zip(
            x[:-1][::-1], 
            self.inner_blocks[:-1][::-1], 
            self.layer_blocks[:-1][::-1],
            self.extra_blocks[:-1][::-1]
        ):
            if last_inner.shape[-2:] != feature.shape[-2:]:
                inner_top_down = F.interpolate(last_inner, size=feature.shape[-2:], mode="nearest")
            else:
                inner_top_down = last_inner
                
            inner_lateral = inner_block(feature)
            last_inner = inner_lateral + inner_top_down
            last_inner = extra_block(last_inner)  # Extra processing
            results.insert(0, layer_block(last_inner))

        return results

class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = Conv2DBlock(F_g, F_int, kernel_size=1, padding=0)
        self.W_x = Conv2DBlock(F_l, F_int, kernel_size=1, padding=0)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)
        self.residual = Conv2DBlock(F_g, F_l, kernel_size=1, padding=0)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        
        if g1.shape[2:] != x1.shape[2:]:
            g1 = F.interpolate(g1, size=x1.shape[2:], mode='bilinear', align_corners=False)
        
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        out = x * psi
        out = self.residual(out)
        return out + x


class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ASPP, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, padding=6, dilation=6, bias=False)
        self.conv3 = nn.Conv2d(in_channels, out_channels, 3, padding=12, dilation=12, bias=False)
        self.conv4 = nn.Conv2d(in_channels, out_channels, 3, padding=18, dilation=18, bias=False)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv5 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.conv_out = nn.Conv2d(5 * out_channels, out_channels, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(self.pool(x))
        feat5 = nn.functional.interpolate(feat5, size=x.shape[2:], mode='bilinear', align_corners=False)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
        out = self.conv_out(out)
        out = self.bn(out)
        return self.relu(out)
class GlobalContextBlock(nn.Module):
    def __init__(self, inplanes, planes, pooling_type='att'):
        super(GlobalContextBlock, self).__init__()
        self.inplanes = inplanes
        self.planes = planes
        self.pooling_type = pooling_type

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.conv_in = nn.Conv2d(inplanes, planes, kernel_size=1)
        self.conv_out = nn.Conv2d(planes, inplanes, kernel_size=1)
        
        # Replace BatchNorm with LayerNorm
        self.ln_in = nn.LayerNorm([planes, 1, 1])
        self.ln_out = nn.LayerNorm([inplanes, 1, 1])

        self.relu = nn.ReLU(inplace=True)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            input_x = input_x.view(batch, channel, height * width)
            input_x = input_x.unsqueeze(1)
            context_mask = self.conv_mask(x)
            context_mask = context_mask.view(batch, 1, height * width)
            context_mask = self.softmax(context_mask)
            context_mask = context_mask.unsqueeze(-1)
            context = torch.matmul(input_x, context_mask)
            context = context.view(batch, channel, 1, 1)
        else:
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        context = self.spatial_pool(x)
        out = self.conv_in(context)
        out = self.ln_in(out)  # Use LayerNorm instead of BatchNorm
        out = self.relu(out)
        out = self.conv_out(out)
        out = self.ln_out(out)  # Use LayerNorm instead of BatchNorm
        
        return x * out.expand_as(x)

class ImprovedDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate=0.2):
        super(ImprovedDecoder, self).__init__()
        self.aspp = ASPP(in_channels, 256)
        self.dropout = nn.Dropout(dropout_rate)
        
        # Enhanced conv blocks with residual connections
        self.conv_blocks = nn.ModuleList([
            nn.Sequential(
                self._make_dense_block(256 + in_channels, 128),
                SEBlock(128),
                self.dropout
            ),
            nn.Sequential(
                self._make_dense_block(128 + in_channels, 64),
                SEBlock(64),
                self.dropout
            ),
            nn.Sequential(
                self._make_dense_block(64 + in_channels, 32),
                SEBlock(32),
                self.dropout
            ),
        ])
        
        # Additional processing path
        self.extra_processing = nn.ModuleList([
            Conv2DBlock(128, 128),
            Conv2DBlock(64, 64),
            Conv2DBlock(32, 32),
        ])
        
        self.final_conv = nn.Sequential(
            Conv2DBlock(32, 32),
            nn.Conv2d(32, out_channels, kernel_size=1)
        )
        self.final_upsample = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)

    def _make_dense_block(self, in_ch, out_ch):
        return nn.Sequential(
            Conv2DBlock(in_ch, out_ch),
            Conv2DBlock(out_ch, out_ch),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, features):
        x = self.aspp(features[-1])
        x = self.dropout(x)
        
        for i, (feature, conv, extra) in enumerate(zip(
            features[-2::-1], 
            self.conv_blocks, 
            self.extra_processing
        )):
            x = F.interpolate(x, size=feature.shape[2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, feature], dim=1)
            x = conv(x)
            x = x + extra(x)  # Residual connection with extra processing
        
        x = self.final_conv(x)
        x = self.final_upsample(x)
        return x




In [8]:
class UnifiedImprovedCellSwin(nn.Module):
    def __init__(self, num_nuclei_classes=6, input_channels=3):
        super(UnifiedImprovedCellSwin, self).__init__()
        self.encoder = SwinEncoder(pretrained=True)
        
        # Modify first conv layer for flexible input channels
        original_weight = self.encoder.swin.features[0][0].weight
        self.encoder.swin.features[0][0] = nn.Conv2d(input_channels, 128, kernel_size=(4, 4), stride=(4, 4))
        if input_channels <= 3:  # For 2 or 3 channel input
            with torch.no_grad():
                self.encoder.swin.features[0][0].weight = nn.Parameter(
                    original_weight[:, :input_channels, :, :].clone()
                )
        
        # Feature Pyramid Network
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=[256, 512, 1024, 1024],
            out_channels=256
        )
        
        # Multiple decoders for different tasks
        self.cell_binary_decoder = ImprovedDecoder(256, 2)  # [background, cell]
        self.nuclei_binary_decoder = ImprovedDecoder(256, 2)  # [background, nucleus]
        self.cell_hv_decoder = ImprovedDecoder(256, 2)  # Horizontal/Vertical maps for cells
        self.nuclei_hv_decoder = ImprovedDecoder(256, 2)  # Horizontal/Vertical maps for nuclei
        self.nuclei_type_decoder = ImprovedDecoder(256, num_nuclei_classes)  # Nuclei classification
        
        # Global context for feature enhancement
        self.global_context = nn.Sequential(
            GlobalContextBlock(1024, 256),
            SEBlock(1024)
        )
        
        # Tissue classifier
        self.tissue_classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1024, 512),
            nn.LayerNorm(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_nuclei_classes)
        )
        
        self.num_nuclei_classes = num_nuclei_classes

    def forward(self, x):
        # Encoder features
        features = self.encoder(x)
        
        # FPN features
        fpn_features = self.fpn(features)
        
        # Global context
        global_feature = self.global_context(features[-1])
        
        # Task-specific outputs
        cell_binary = self.cell_binary_decoder(fpn_features)
        nuclei_binary = self.nuclei_binary_decoder(fpn_features)
        cell_hv = self.cell_hv_decoder(fpn_features)
        nuclei_hv = self.nuclei_hv_decoder(fpn_features)
        nuclei_types = self.nuclei_type_decoder(fpn_features)
        tissue_types = self.tissue_classifier(global_feature)

        out_dict = {
            # Cell-level outputs
            "cell_binary_map": cell_binary,  # [B, 2, H, W]
            "cell_hv_map": cell_hv,  # [B, 2, H, W]
            
            # Nuclei-level outputs
            "nuclei_binary_map": nuclei_binary,  # [B, 2, H, W]
            "nuclei_hv_map": nuclei_hv,  # [B, 2, H, W]
            "nuclei_type_map": nuclei_types,  # [B, num_nuclei_classes, H, W]
            
            # Tissue-level output
            "tissue_types": tissue_types  # [B, num_nuclei_classes]
        }

        return out_dict
    
    def calculate_instance_map(self, predictions, magnification=40):
        """Post-processing to get instance segmentation maps"""
        # Implementation as in your original code
        predictions_ = predictions.copy()
        predictions_["nuclei_type_map"] = predictions_["nuclei_type_map"].permute(0, 2, 3, 1)
        predictions_["nuclei_binary_map"] = predictions_["nuclei_binary_map"].permute(0, 2, 3, 1)
        predictions_["nuclei_hv_map"] = predictions_["nuclei_hv_map"].permute(0, 2, 3, 1)

        instance_preds = []
        type_preds = []

        # Process each image in batch
        for i in range(predictions_["nuclei_binary_map"].shape[0]):
            pred_map = np.concatenate([
                torch.argmax(predictions_["nuclei_type_map"], dim=-1)[i].detach().cpu()[..., None],
                torch.argmax(predictions_["nuclei_binary_map"], dim=-1)[i].detach().cpu()[..., None],
                predictions_["nuclei_hv_map"][i].detach().cpu(),
            ], axis=-1)
            
            # You'll need to implement DetectionCellPostProcessor
            instance_pred = self.post_process_cell_segmentation(pred_map)
            instance_preds.append(instance_pred[0])
            type_preds.append(instance_pred[1])

        return torch.Tensor(np.stack(instance_preds)), type_preds

    def freeze_encoder(self):
        """Freeze encoder parameters"""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self):
        """Unfreeze encoder parameters"""
        for param in self.encoder.parameters():
            param.requires_grad = True

In [9]:
# For MIF (2-channel input)
mif_model = UnifiedImprovedCellSwin(input_channels=2)
mif_input = torch.randn(1, 2, 256, 256)
mif_output = mif_model(mif_input)

# For H&E (3-channel input)
he_model = UnifiedImprovedCellSwin(input_channels=3)
he_input = torch.randn(1, 3, 256, 256)
he_output = he_model(he_input)

# Check outputs
for key, value in mif_output.items():
    print(f"{key}: {value.shape}")

Downloading: "https://download.pytorch.org/models/swin_v2_b-781e5279.pth" to /home/yshokrollahi/.cache/torch/hub/checkpoints/swin_v2_b-781e5279.pth
100%|██████████| 336M/336M [00:06<00:00, 51.9MB/s] 


cell_binary_map: torch.Size([1, 2, 256, 256])
cell_hv_map: torch.Size([1, 2, 256, 256])
nuclei_binary_map: torch.Size([1, 2, 256, 256])
nuclei_hv_map: torch.Size([1, 2, 256, 256])
nuclei_type_map: torch.Size([1, 6, 256, 256])
tissue_types: torch.Size([1, 6])


In [10]:
class MultiModalGatingNetwork(nn.Module):
    def __init__(self, feature_dim=128):
        super().__init__()
        
        # Feature extractors for each modality
        self.mif_feature_extractor = nn.Sequential(
            Conv2DBlock(2, 64),  # MIF has 2 channels
            SEBlock(64),
            Conv2DBlock(64, feature_dim),
            SEBlock(feature_dim)
        )
        
        self.he_feature_extractor = nn.Sequential(
            Conv2DBlock(3, 64),  # H&E has 3 channels
            SEBlock(64),
            Conv2DBlock(64, feature_dim),
            SEBlock(feature_dim)
        )
        
        # Global context for better feature representation
        self.global_context = GlobalContextBlock(feature_dim, feature_dim//2)
        
        # Dynamic gating mechanism
        self.gate_generator = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(feature_dim, feature_dim//2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(feature_dim//2, 2),  # 2 gates for 2 modalities
            nn.Softmax(dim=1)
        )
        
    def forward(self, x, modality_type):
        if modality_type == 'mif':
            features = self.mif_feature_extractor(x)
        else:  # he
            features = self.he_feature_extractor(x)
            
        # Apply global context
        features = self.global_context(features)
        
        # Generate gating weights
        gates = self.gate_generator(features)
        
        return features, gates

class MultiModalExpertNetwork(nn.Module):
    def __init__(self, num_nuclei_classes=6):
        super().__init__()
        
        # Gating network
        self.gating_network = MultiModalGatingNetwork()
        
        # Expert networks for each modality
        self.mif_expert = UnifiedImprovedCellSwin(num_nuclei_classes=num_nuclei_classes, 
                                                 input_channels=2)  # MIF expert
        self.he_expert = UnifiedImprovedCellSwin(num_nuclei_classes=num_nuclei_classes, 
                                                input_channels=3)   # H&E expert
        
        # Output fusion layers
        self.fusion_layers = nn.ModuleDict({
            'binary_fusion': nn.Sequential(
                Conv2DBlock(4, 2),  # Combine binary maps
                SEBlock(2)
            ),
            'hv_fusion': nn.Sequential(
                Conv2DBlock(4, 2),  # Combine HV maps
                SEBlock(2)
            ),
            'type_fusion': nn.Sequential(
                Conv2DBlock(num_nuclei_classes * 2, num_nuclei_classes),  # Combine type predictions
                SEBlock(num_nuclei_classes)
            )
        })

    def forward(self, x, modality_type):
        # Get gating weights
        features, gates = self.gating_network(x, modality_type)
        
        if modality_type == 'mif':
            # MIF expert prediction
            mif_output = self.mif_expert(x)
            expert_output = mif_output
        else:
            # H&E expert prediction
            he_output = self.he_expert(x)
            expert_output = he_output

        # Apply gating weights
        gated_output = {}
        for key, value in expert_output.items():
            if isinstance(value, torch.Tensor):
                if key == 'tissue_types':
                    gated_output[key] = value  # Don't apply gating to tissue types
                else:
                    # Apply respective gate weight
                    gate_weight = gates[:, 0] if modality_type == 'mif' else gates[:, 1]
                    gated_output[key] = value * gate_weight.view(-1, 1, 1, 1)

        # Add gates to output for monitoring
        gated_output['gates'] = gates
        
        return gated_output

    def freeze_experts(self):
        """Freeze both expert networks"""
        for param in self.mif_expert.parameters():
            param.requires_grad = False
        for param in self.he_expert.parameters():
            param.requires_grad = False
    
    def unfreeze_experts(self):
        """Unfreeze both expert networks"""
        for param in self.mif_expert.parameters():
            param.requires_grad = True
        for param in self.he_expert.parameters():
            param.requires_grad = True

# Test the multi-modal network
if __name__ == "__main__":
    # Initialize model
    model = MultiModalExpertNetwork().cuda()
    
    # Test MIF input
    mif_input = torch.randn(1, 2, 256, 256).cuda()
    print("\nTesting MIF input:")
    mif_output = model(mif_input, 'mif')
    for key, value in mif_output.items():
        if isinstance(value, torch.Tensor):
            print(f"{key}: {value.shape}")
    
    # Test H&E input
    he_input = torch.randn(1, 3, 256, 256).cuda()
    print("\nTesting H&E input:")
    he_output = model(he_input, 'he')
    for key, value in he_output.items():
        if isinstance(value, torch.Tensor):
            print(f"{key}: {value.shape}")




Testing MIF input:
cell_binary_map: torch.Size([1, 2, 256, 256])
cell_hv_map: torch.Size([1, 2, 256, 256])
nuclei_binary_map: torch.Size([1, 2, 256, 256])
nuclei_hv_map: torch.Size([1, 2, 256, 256])
nuclei_type_map: torch.Size([1, 6, 256, 256])
tissue_types: torch.Size([1, 6])
gates: torch.Size([1, 2])

Testing H&E input:
cell_binary_map: torch.Size([1, 2, 256, 256])
cell_hv_map: torch.Size([1, 2, 256, 256])
nuclei_binary_map: torch.Size([1, 2, 256, 256])
nuclei_hv_map: torch.Size([1, 2, 256, 256])
nuclei_type_map: torch.Size([1, 6, 256, 256])
tissue_types: torch.Size([1, 6])
gates: torch.Size([1, 2])


In [15]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import numpy as np
from pathlib import Path

class MultiModalLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.binary_loss = nn.CrossEntropyLoss()
        self.hv_loss = nn.MSELoss()
        self.type_loss = nn.CrossEntropyLoss()
        self.tissue_loss = nn.CrossEntropyLoss()
        self.gate_regularization = 0.1

    def forward(self, predictions, targets, modality_type):
        total_loss = 0
        losses = {}
        
        # Process only available targets for each modality
        if modality_type == 'mif':
            if 'cell_binary_map' in predictions and 'cell_binary_map' in targets:
                losses['cell_binary_loss'] = self.binary_loss(
                    predictions['cell_binary_map'],
                    targets['cell_binary_map'].long()  # Ensure target is long type
                )
                total_loss += losses['cell_binary_loss']
            
            if 'cell_hv_map' in predictions and 'cell_hv_map' in targets:
                losses['cell_hv_loss'] = self.hv_loss(
                    predictions['cell_hv_map'],
                    targets['cell_hv_map']
                )
                total_loss += losses['cell_hv_loss']
        
        # Common losses for both modalities
        if 'nuclei_binary_map' in predictions and 'nuclei_binary_map' in targets:
            losses['nuclei_binary_loss'] = self.binary_loss(
                predictions['nuclei_binary_map'],
                targets['nuclei_binary_map'].long()
            )
            total_loss += losses['nuclei_binary_loss']
        
        if 'nuclei_hv_map' in predictions and 'nuclei_hv_map' in targets:
            losses['nuclei_hv_loss'] = self.hv_loss(
                predictions['nuclei_hv_map'],
                targets['nuclei_hv_map']
            )
            total_loss += losses['nuclei_hv_loss']
        
        # H&E specific losses
        if modality_type == 'he':
            if 'nuclei_type_map' in predictions and 'nuclei_type_map' in targets:
                losses['type_loss'] = self.type_loss(
                    predictions['nuclei_type_map'],
                    targets['nuclei_type_map'].long()
                )
                total_loss += losses['type_loss']
            
            if 'tissue_types' in predictions and 'tissue_types' in targets:
                losses['tissue_loss'] = self.tissue_loss(
                    predictions['tissue_types'],
                    targets['tissue_types'].long()
                )
                total_loss += losses['tissue_loss']
        
        if 'gates' in predictions:
            gates = predictions['gates']
            losses['gate_loss'] = self.gate_regularization * torch.mean(
                (gates - 0.5) ** 2
            )
            total_loss += losses['gate_loss']
        
        return total_loss, losses

class CombinedDataLoader:
    def __init__(self, mif_loader, he_loader):
        self.mif_loader = mif_loader
        self.he_loader = he_loader
        self.mif_iterator = iter(self.mif_loader)
        self.he_iterator = iter(self.he_loader)
        self.length = min(len(self.mif_loader), len(self.he_loader))

    def __len__(self):
        return self.length

    def __iter__(self):
        self.mif_iterator = iter(self.mif_loader)
        self.he_iterator = iter(self.he_loader)
        return self

    def __next__(self):
        try:
            mif_batch = next(self.mif_iterator)  # List of length 3
        except StopIteration:
            self.mif_iterator = iter(self.mif_loader)
            mif_batch = next(self.mif_iterator)

        try:
            he_batch = next(self.he_iterator)  # List of length 4
        except StopIteration:
            self.he_iterator = iter(self.he_loader)
            he_batch = next(self.he_iterator)

        return {
            'mif': {
                'image': mif_batch[0],          # Input image
                'cell_mask': mif_batch[1],      # Cell mask
                'nuclei_mask': mif_batch[2],    # Nuclei mask
            },
            'he': {
                'image': he_batch[0],           # Input image
                'nuclei_mask': he_batch[1],     # Nuclei mask
                'type_map': he_batch[2],        # Type map
                'tissue_type': he_batch[3]      # Tissue type
            }
        }

def train_epoch(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    
    for batch_idx, batch in enumerate(train_loader):
        # Process MIF data
        mif_data = batch['mif']['image'].to(device)
        mif_targets = {
            'cell_binary_map': batch['mif']['cell_mask'].to(device),
            'nuclei_binary_map': batch['mif']['nuclei_mask'].to(device)
        }
        
        # Process H&E data
        he_data = batch['he']['image'].to(device)
        he_targets = {
            'nuclei_binary_map': batch['he']['nuclei_mask'].to(device),
            'nuclei_type_map': batch['he']['type_map'].to(device),
            'tissue_types': batch['he']['tissue_type'].to(device)
        }
        
        optimizer.zero_grad()
        
        # Forward passes
        mif_output = model(mif_data, 'mif')
        he_output = model(he_data, 'he')
        
        # Calculate losses
        mif_loss, mif_losses = criterion(mif_output, mif_targets, 'mif')
        he_loss, he_losses = criterion(he_output, he_targets, 'he')
        
        loss = mif_loss + he_loss
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 10 == 0:
            print(f'Batch [{batch_idx}/{len(train_loader)}], '
                  f'Loss: {loss.item():.4f}, '
                  f'MIF Loss: {mif_loss.item():.4f}, '
                  f'H&E Loss: {he_loss.item():.4f}')
    
    return total_loss / len(train_loader)

def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in val_loader:
            # Process MIF data
            mif_data = batch['mif']['image'].to(device)
            mif_targets = {
                'cell_binary_map': batch['mif']['cell_mask'].to(device),
                'nuclei_binary_map': batch['mif']['nuclei_mask'].to(device)
            }
            
            # Process H&E data
            he_data = batch['he']['image'].to(device)
            he_targets = {
                'nuclei_binary_map': batch['he']['nuclei_mask'].to(device),
                'nuclei_type_map': batch['he']['type_map'].to(device),
                'tissue_types': batch['he']['tissue_type'].to(device)
            }
            
            mif_output = model(mif_data, 'mif')
            he_output = model(he_data, 'he')
            
            mif_loss, _ = criterion(mif_output, mif_targets, 'mif')
            he_loss, _ = criterion(he_output, he_targets, 'he')
            
            total_loss += (mif_loss + he_loss).item()
    
    return total_loss / len(val_loader)

def main():
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize dataset manager
    data_manager = MultiModalDatasetManager()
    
    # Setup datasets and dataloaders
    tissuenet_datasets, tissuenet_loaders = data_manager.setup_tissuenet()
    pannuke_datasets, pannuke_loaders = data_manager.setup_pannuke()
    
    # Create combined loaders
    train_loader = CombinedDataLoader(
        tissuenet_loaders['train'],
        pannuke_loaders['train']
    )
    val_loader = CombinedDataLoader(
        tissuenet_loaders['val'],
        pannuke_loaders['val']
    )
    
    # Initialize model
    model = MultiModalExpertNetwork().to(device)
    
    # Setup training
    criterion = MultiModalLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    
    num_epochs = 100
    best_val_loss = float('inf')
    save_dir = Path('./checkpoints')
    save_dir.mkdir(exist_ok=True)
    
    # Training loop
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss = validate(model, val_loader, criterion, device)
        
        scheduler.step(val_loss)
        
        print(f'Training Loss: {train_loss:.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        
        # Save checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint_path = save_dir / 'best_model.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
            }, checkpoint_path)
            print(f'Saved best model to {checkpoint_path}')
        
        # Early stopping
        if optimizer.param_groups[0]['lr'] < 1e-6:
            print('Learning rate too small. Stopping training.')
            break

if __name__ == "__main__":
    main()

Using device: cuda
Successfully loaded configuration files

Setting up TissueNet datasets...
Created train dataset with 10320 samples
Created val dataset with 3118 samples
Created test dataset with 1324 samples

Setting up PanNuke datasets...
Created train dataset with 5179 samples
Created val dataset with 2722 samples





Epoch 1/100


AttributeError: 'dict' object has no attribute 'to'

In [14]:
# Debug dataset outputs
def inspect_dataloaders(tissuenet_loader, pannuke_loader):
    print("\nInspecting TissueNet (MIF) batch:")
    mif_batch = next(iter(tissuenet_loader))
    print("MIF batch type:", type(mif_batch))
    if isinstance(mif_batch, (list, tuple)):
        print("MIF batch length:", len(mif_batch))
        print("MIF first element type:", type(mif_batch[0]))
        print("MIF first element shape:", mif_batch[0].shape if torch.is_tensor(mif_batch[0]) else "Not a tensor")
    elif isinstance(mif_batch, dict):
        print("MIF batch keys:", mif_batch.keys())
        for k, v in mif_batch.items():
            print(f"Key: {k}, Type: {type(v)}, Shape: {v.shape if torch.is_tensor(v) else 'Not a tensor'}")
    
    print("\nInspecting PanNuke (H&E) batch:")
    he_batch = next(iter(pannuke_loader))
    print("H&E batch type:", type(he_batch))
    if isinstance(he_batch, (list, tuple)):
        print("H&E batch length:", len(he_batch))
        print("H&E first element type:", type(he_batch[0]))
        print("H&E first element shape:", he_batch[0].shape if torch.is_tensor(he_batch[0]) else "Not a tensor")
    elif isinstance(he_batch, dict):
        print("H&E batch keys:", he_batch.keys())
        for k, v in he_batch.items():
            print(f"Key: {k}, Type: {type(v)}, Shape: {v.shape if torch.is_tensor(v) else 'Not a tensor'}")

# In your main function:
data_manager = MultiModalDatasetManager()
tissuenet_datasets, tissuenet_loaders = data_manager.setup_tissuenet()
pannuke_datasets, pannuke_loaders = data_manager.setup_pannuke()

inspect_dataloaders(tissuenet_loaders['train'], pannuke_loaders['train'])

Successfully loaded configuration files

Setting up TissueNet datasets...
Created train dataset with 10320 samples
Created val dataset with 3118 samples
Created test dataset with 1324 samples

Setting up PanNuke datasets...
Created train dataset with 5179 samples
Created val dataset with 2722 samples

Inspecting TissueNet (MIF) batch:
MIF batch type: <class 'list'>
MIF batch length: 3
MIF first element type: <class 'torch.Tensor'>
MIF first element shape: torch.Size([16, 2, 256, 256])

Inspecting PanNuke (H&E) batch:
H&E batch type: <class 'list'>
H&E batch length: 4
H&E first element type: <class 'torch.Tensor'>
H&E first element shape: torch.Size([16, 3, 256, 256])
