In [1]:
import os
import argparse
from pathlib import Path
from typing import Optional, Callable, Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import models, transforms
from PIL import Image
from tqdm.notebook import tqdm
from latentmi import lmi, ksg
import pandas as pd


In [2]:
class CellCycleDataset(Dataset):
    """Dataset for cell cycle phase classification from multi-channel images."""
    
    def __init__(self, data_dir: str, transform: Optional[Callable] = None, 
                 max_samples: Optional[int] = None):
        """
        Args:
            data_dir: Path to CellCycle directory containing phase subdirectories
            transform: Optional transform to apply to images
            max_samples: Maximum number of samples to load (None = all)
        """
        self.data_dir = Path(data_dir)
        self.transform = transform
        
        # Define cell cycle phases
        self.phases = ['Anaphase', 'G1', 'G2', 'Metaphase', 'Prophase', 'S', 'Telophase']
        self.phase_to_idx = {phase: idx for idx, phase in enumerate(self.phases)}
        
        # Collect all samples
        print('Indexing dataset...')
        self.samples = []
        
        for phase in self.phases:
            phase_dir = self.data_dir / phase
            if not phase_dir.exists():
                continue
                
            # Find all unique cell IDs (by looking at Ch3 files)
            ch3_files = sorted(phase_dir.glob('*_Ch3.ome.jpg'))
            
            for ch3_file in ch3_files:
                # Extract base name (e.g., "49033" from "49033_Ch3.ome.jpg")
                base_name = ch3_file.stem.replace('_Ch3.ome', '')
                
                # Check if all 3 channels exist
                ch4_file = phase_dir / f"{base_name}_Ch4.ome.jpg"
                ch6_file = phase_dir / f"{base_name}_Ch6.ome.jpg"
                
                if ch4_file.exists() and ch6_file.exists():
                    self.samples.append({
                        'ch3': ch3_file,
                        'ch4': ch4_file,
                        'ch6': ch6_file,
                        'phase': phase,
                        'label': self.phase_to_idx[phase]
                    })
                    
                    if max_samples and len(self.samples) >= max_samples:
                        break
            
            if max_samples and len(self.samples) >= max_samples:
                break
        
        print(f'Found {len(self.samples)} valid cells')
        for phase in self.phases:
            count = sum(1 for s in self.samples if s['phase'] == phase)
            print(f'  {phase}: {count}')
    
    def __len__(self) -> int:
        return len(self.samples)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        sample = self.samples[idx]
        
        # Load all 3 channels
        ch3 = Image.open(sample['ch3']).convert('L')
        ch4 = Image.open(sample['ch4']).convert('L')
        ch6 = Image.open(sample['ch6']).convert('L')
        
        # Convert to tensors and stack
        ch3_t = transforms.ToTensor()(ch3)
        ch4_t = transforms.ToTensor()(ch4)
        ch6_t = transforms.ToTensor()(ch6)
        
        img = torch.cat([ch3_t, ch4_t, ch6_t], dim=0)  # Shape: [3, H, W]
        
        # Apply transforms
        if self.transform:
            img = self.transform(img)
        
        return img, sample['label']


def add_gauss(x: torch.Tensor, sigma: float = 0.1) -> torch.Tensor:
    """Add Gaussian noise to image."""
    return x + torch.randn_like(x) * sigma


def pixelate(x: torch.Tensor, scale: int = 4) -> torch.Tensor:
    """Pixelate image by downsampling and upsampling."""
    h, w = x.shape[-2:]
    small = F.avg_pool2d(x, kernel_size=scale, stride=scale)
    return F.interpolate(small, size=(h, w), mode='nearest')


def create_model(num_classes: int, device: torch.device) -> nn.Module:
    """Create MobileNetV3-small model adapted for 3-channel input."""
    model = models.mobilenet_v3_small(weights='DEFAULT')
    
    # Modify first conv layer to accept 3 channels (already 3, but re-init for clarity)
    model.features[0][0] = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
    
    # Modify classifier for num_classes
    model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
    
    return model.to(device)


def train_model(model: nn.Module, train_loader: DataLoader, noise_fn: Optional[Callable],
                device: torch.device, n_epochs: int = 50, lr: float = 1e-3) -> None:
    """Train the model."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    model.train()
    for epoch in tqdm(range(n_epochs), leave=True):
        epoch_loss = 0.0
        for x, y in tqdm(train_loader, leave=False):
            x, y = x.to(device), y.to(device)
            
            # Apply noise if specified
            if noise_fn:
                x = noise_fn(x)
            
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{n_epochs} | Loss: {avg_loss:.4f}")


def compute_mi(model: nn.Module, data_loader: DataLoader, noise_fn: Optional[Callable],
               device: torch.device, num_classes: int = 7) -> Tuple[float, float, dict]:
    """Compute mutual information between representations and labels.
    
    Returns:
        mi_score: Continuous MI from representations
        discrete_mi: Overall discrete MI
        ova_mis: Dictionary mapping class names to one-vs-all MIs
    """
    model.eval()
    
    representations = []
    targets = []
    y_hats = []
    
    # Create hook to extract features
    features = {}
    def get_features(name):
        def hook(model, input, output):
            features[name] = output.detach()
        return hook
    
    # Register hook on last layer before classifier
    model.classifier[0].register_forward_hook(get_features('last_layer'))
    
    with torch.no_grad():
        for x, y in data_loader:
            x = x.to(device)
            
            # Apply noise if specified
            if noise_fn:
                x = noise_fn(x)
            
            y_hat = model(x)
            representations.append(features['last_layer'].cpu())
            targets.append(y)
            y_hats.append(y_hat.cpu())
    
    # Convert to numpy arrays
    X = torch.cat(representations).numpy()
    Y = torch.cat(targets).numpy()
    # convert Y to one-hot encoding
    Y_onehot = np.eye(num_classes)[Y]
    print(f"Labels one-hot shape: {Y_onehot.shape}")
    Y_hats_onehot = torch.cat(y_hats).numpy()
    print(f"Predictions one-hot shape: {Y_hats_onehot.shape}")
    # convert to non-onehot
    Y_hats = np.argmax(Y_hats_onehot, axis=1)
    print(f"Predictions non one-hot shape: {Y_hats.shape}")
    
    print(f"Representations shape: {X.shape}, Labels shape: {Y.shape}, Predictions shape: {Y_hats.shape}")
    
    # Calculate MI using latentmi
    pmis, _, _ = lmi.estimate(X, Y_onehot, validation_split=0.3, batch_size=512, 
                             epochs=50, quiet=False)
    mi_score = np.nanmean(pmis)
    
    print(f"MI score (representation): {mi_score:.4f}")

    discrete_mi = ksg.midd(Y, Y_hats)
    print(f"MI score (discrete): {discrete_mi:.4f}")
    
    # Compute one-vs-all MIs with balanced resampling
    print("\nComputing one-vs-all MIs...")
    phase_names = ['Anaphase', 'G1', 'G2', 'Metaphase', 'Prophase', 'S', 'Telophase']
    ova_mis = {}
    
    for class_idx in range(num_classes):
        class_name = phase_names[class_idx] if class_idx < len(phase_names) else f"Class_{class_idx}"
        
        # Create binary labels (1 for current class, 0 for all others)
        Y_binary = (Y == class_idx).astype(int)
        Y_hats_binary = (Y_hats == class_idx).astype(int)
        
        # Count samples in each binary class
        n_positive = np.sum(Y_binary)
        n_negative = len(Y_binary) - n_positive
        
        # Balanced resampling: take min of the two classes and sample equally
        n_samples = min(n_positive, n_negative)
        
        positive_indices = np.where(Y_binary == 1)[0]
        negative_indices = np.where(Y_binary == 0)[0]
        
        # Randomly sample n_samples from each class
        np.random.seed(42)  # For reproducibility
        sampled_positive = np.random.choice(positive_indices, size=n_samples, replace=True)
        sampled_negative = np.random.choice(negative_indices, size=n_samples, replace=True)
        
        # Combine and shuffle
        balanced_indices = np.concatenate([sampled_positive, sampled_negative])
        np.random.shuffle(balanced_indices)
        
        # Get balanced labels and predictions
        Y_binary_balanced = Y_binary[balanced_indices]
        Y_hats_binary_balanced = Y_hats_binary[balanced_indices]
        
        # Compute MI for this one-vs-all binary classification
        ova_mi = ksg.midd(Y_binary_balanced, Y_hats_binary_balanced)
        ova_mis[class_name] = ova_mi
        
        print(f"  {class_name} (one-vs-all, balanced {n_samples} per class): {ova_mi:.4f} bits")
    
    return mi_score, discrete_mi, ova_mis

In [3]:
# ==================== Main Experiment ====================
def run_experiment(data_dir: str, output_dir: str, noise_fn: Optional[Callable],
                   tag: str, device: torch.device, n_epochs: int = 50,
                   max_samples: Optional[int] = None) -> Tuple[float, float, dict]:
    """Run a single noise experiment."""
    print(f"\n{'='*60}")
    print(f"Starting experiment: {tag}")
    print(f"{'='*60}")
    
    # Prepare data
    tfm = transforms.Compose([
        transforms.Resize((224, 224)),
    ])
    
    dataset = CellCycleDataset(data_dir, transform=tfm, max_samples=max_samples)
    
    # Split dataset
    n_total = len(dataset)
    n_train = int(n_total * 0.5)
    train_inds = np.random.choice(n_total, size=n_train, replace=False)
    test_inds = np.setdiff1d(np.arange(n_total), train_inds)
    
    train_ds = Subset(dataset, train_inds)
    test_ds = Subset(dataset, test_inds)
    
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=1)
    test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=1)
    
    print(f'Dataset split: {n_train} train, {len(test_inds)} test')
    
    # Create and train model
    num_classes = len(dataset.phases)
    model = create_model(num_classes, device)
    
    train_model(model, train_dl, noise_fn, device, n_epochs=n_epochs)
    
    # Save model
    os.makedirs(output_dir, exist_ok=True)
    model_path = os.path.join(output_dir, f"model_{tag}.pt")
    torch.save(model.state_dict(), model_path)
    print(f"Saved model to {model_path}")
    
    # Compute MI
    mi_score, discrete_mi, ova_mis = compute_mi(model, test_dl, noise_fn, device, num_classes)
    
    return mi_score, discrete_mi, ova_mis

In [4]:
data_dir = 'data/cellcycle/CellCycle'
output_dir = 'cellcycle_models'
epochs = 10
max_samples = 1000000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')    
print(f"Using device: {device}")

experiments = [
    (None, "clean"),
] + [
    (lambda x, s=s: add_gauss(x, s), f"gauss_{s:.6f}")
    for s in np.logspace(-3, 1, 10)
] + [
    (lambda x, sc=sc: pixelate(x, sc), f"pix_{sc}x")
    for sc in [2,4,7,8,14,16,28,32,56,112,224]
]

# Run all experiments
results = {}
for noise_fn, tag in experiments:
    try:
        mi_score, discrete_mi, ova_mis = run_experiment(
            data_dir, 
            output_dir, 
            noise_fn, 
            tag, 
            device,
            n_epochs=epochs,
            max_samples=max_samples
        )
        results[tag] = {
            'mi_score': mi_score,
            'discrete_mi': discrete_mi,
            **{f'ova_mi_{k}': v for k, v in ova_mis.items()}
        }
    except Exception as e:
        print(f"Error in experiment {tag}: {e}")
        results[tag] = {'mi_score': np.nan, 'discrete_mi': np.nan}

# Save results
results_df = pd.DataFrame([
    {'experiment': tag, **scores}
    for tag, scores in results.items()
])
results_path = os.path.join(output_dir, 'results.csv')
results_df.to_csv(results_path, index=False)
print(f"\nResults saved to {results_path}")
print(results_df)

Using device: cuda

Starting experiment: clean
Indexing dataset...
Found 32266 valid cells
  Anaphase: 15
  G1: 14333
  G2: 8601
  Metaphase: 68
  Prophase: 606
  S: 8616
  Telophase: 27
Dataset split: 16133 train, 16133 test


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 1/10 | Loss: 0.6158


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 2/10 | Loss: 0.4559


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 3/10 | Loss: 0.4113


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 4/10 | Loss: 0.3686


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 5/10 | Loss: 0.3410


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 6/10 | Loss: 0.3203


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 7/10 | Loss: 0.2787


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 8/10 | Loss: 0.2619


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 9/10 | Loss: 0.2546


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 10/10 | Loss: 0.2146
Saved model to cellcycle_models/model_clean.pt
Labels one-hot shape: (16133, 7)
Predictions one-hot shape: (16133, 7)
Predictions non one-hot shape: (16133,)
Representations shape: (16133, 1024), Labels shape: (16133,), Predictions shape: (16133,)
epoch 49 (of max 50) ðŸŒ»ðŸŒ»ðŸŒ»ðŸŒ»ðŸŒ»ðŸŒ»ðŸŒ»ðŸŒ»ðŸŒ»MI score (representation): 1.0562
MI score (discrete): 0.7814

Computing one-vs-all MIs...
  Anaphase (one-vs-all, balanced 8 per class): 0.2190 bits
  G1 (one-vs-all, balanced 7065 per class): 0.5258 bits
  G2 (one-vs-all, balanced 4355 per class): 0.2733 bits
  Metaphase (one-vs-all, balanced 34 per class): 0.0456 bits
  Prophase (one-vs-all, balanced 327 per class): 0.4420 bits
  S (one-vs-all, balanced 4332 per class): 0.1270 bits
  Telophase (one-vs-all, balanced 12 per class): 1.0000 bits

Starting experiment: gauss_0.001000
Indexing dataset...
Found 32266 valid cells
  Anaphase: 15
  G1: 14333
  G2: 8601
  Metaphase: 68
  Prophase: 606
  S: 8616
  Telop

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 1/10 | Loss: 0.6160


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 2/10 | Loss: 0.4355


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 3/10 | Loss: 0.3880


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 4/10 | Loss: 0.3437


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 5/10 | Loss: 0.3339


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch 6/10 | Loss: 0.3272


  0%|          | 0/127 [00:00<?, ?it/s]

KeyboardInterrupt: 