In [1]:
import os
from pathlib import Path
from typing import Optional, Callable, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
from tqdm.notebook import tqdm
from latentmi import lmi, ksg
import pandas as pd
import medmnist
from medmnist import INFO

In [2]:
def add_gauss(x: torch.Tensor, sigma: float = 0.1) -> torch.Tensor:
    """Return ``x`` perturbed by zero-mean Gaussian noise.

    Args:
        x: Input tensor; the noise is drawn with ``torch.randn_like(x)``, so it
            has the same shape as ``x``.
        sigma: Standard deviation of the added noise.

    Returns:
        A new tensor equal to ``x`` plus standard-normal noise scaled by ``sigma``.
    """
    noise = torch.randn_like(x)
    scaled_noise = noise * sigma
    return x + scaled_noise


def pixelate(x: torch.Tensor, scale: int = 4) -> torch.Tensor:
    """Pixelate an image by average-pooling it and upsampling back to its size.

    The image is average-pooled with a ``scale x scale`` window (stride
    ``scale``) and the coarse result is stretched back to the original
    height/width with nearest-neighbour interpolation. When the height and
    width are multiples of ``scale`` every tile ends up filled with its mean.

    Args:
        x: Image tensor of shape ``(N, C, H, W)``, or a single unbatched image
            of shape ``(C, H, W)``.
        scale: Side length of the pooling window (also used as the stride).

    Returns:
        Tensor with the same shape as ``x``.
    """
    # F.interpolate with a 2-element ``size`` needs a 4-D (N, C, H, W) input,
    # so temporarily add a batch axis for unbatched images.
    unbatched = x.dim() == 3
    if unbatched:
        x = x.unsqueeze(0)

    h, w = x.shape[-2:]
    small = F.avg_pool2d(x, kernel_size=scale, stride=scale)
    out = F.interpolate(small, size=(h, w), mode='nearest')

    return out.squeeze(0) if unbatched else out


def create_model(num_classes: int, device: torch.device) -> nn.Module:
    """Build an ImageNet-pretrained MobileNetV3-Small with a new output layer.

    Args:
        num_classes: Number of outputs of the replacement linear layer.
        device: Device the model is moved to before being returned.

    Returns:
        The model, with ``classifier[3]`` replaced by a freshly initialised
        ``nn.Linear`` of width ``num_classes``.
    """
    net = models.mobilenet_v3_small(weights='IMAGENET1K_V1')

    # The backbone is left untouched (it already takes 3-channel images);
    # only the layer at index 3 of the classifier is resized.
    previous_head = net.classifier[3]
    net.classifier[3] = nn.Linear(previous_head.in_features, num_classes)

    net = net.to(device)
    return net


def train_model(model: nn.Module, train_loader: DataLoader, noise_fn: Optional[Callable],
                device: torch.device, n_epochs: int = 50, lr: float = 1e-3) -> None:
    """Train ``model`` in place with Adam and cross-entropy loss.

    After every epoch the mean batch loss and the training accuracy are printed.

    Args:
        model: Network to optimise. It is switched to training mode and left
            in that mode on return.
        train_loader: Yields ``(x, y)`` batches. ``y`` holds integer class
            labels, either flat ``(B,)`` or with a trailing singleton axis
            ``(B, 1)``.
        noise_fn: Optional corruption applied to each input batch after it has
            been moved to ``device``; ``None`` trains on unmodified inputs.
        device: Device the batches are moved to.
        n_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in tqdm(range(n_epochs), desc="Epochs", leave=True):
        epoch_loss = 0.0
        correct = 0
        total = 0
        n_batches = 0

        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}", leave=False):
            x = x.to(device)

            # Drop only the trailing label axis. A bare squeeze() would also
            # remove the batch axis when a batch holds a single sample, which
            # breaks both the loss call and y.size(0) below.
            if y.dim() > 1:
                y = y.squeeze(-1)
            y = y.long().to(device)

            # Apply noise if specified
            if noise_fn is not None:
                x = noise_fn(x)

            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()

        # Counting batches (instead of len(train_loader)) gives a clear error
        # for an empty loader rather than a bare ZeroDivisionError.
        if n_batches == 0:
            raise ValueError("train_loader yielded no batches; nothing to train on")

        avg_loss = epoch_loss / n_batches
        accuracy = 100. * correct / total
        print(f"Epoch {epoch+1}/{n_epochs} | Loss: {avg_loss:.4f} | Acc: {accuracy:.2f}%")


def compute_mi(model: nn.Module, data_loader: DataLoader, noise_fn: Optional[Callable],
               device: torch.device, num_classes: int) -> Tuple[float, float, dict]:
    """Estimate mutual information between model outputs and true labels.

    The model is run in eval mode over ``data_loader``. The output of
    ``model.classifier[0]`` is captured with a temporary forward hook and used
    as the representation; the arg-max of the model output is used as the
    predicted label.

    Args:
        model: Network exposing a ``classifier`` sequence; left in eval mode.
        data_loader: Yields ``(x, y)`` batches. ``y`` holds integer class
            labels, either flat ``(B,)`` or with a trailing singleton axis
            ``(B, 1)``.
        noise_fn: Optional corruption applied to each input batch; ``None``
            evaluates on unmodified inputs.
        device: Device the inputs are moved to.
        num_classes: Number of classes, used for one-hot encoding and for the
            one-vs-all loop.

    Returns:
        mi_score: ``np.nanmean`` of the values returned by ``lmi.estimate`` for
            (representations, one-hot labels).
        discrete_mi: ``ksg.midd`` between true and predicted labels.
        ova_mis: Dictionary mapping ``"Class_{i}"`` to the one-vs-all
            ``ksg.midd`` value computed on a class-balanced resample, or
            ``np.nan`` when class ``i`` (or its complement) has no samples.
    """
    model.eval()

    representations = []
    targets = []
    logits = []

    # The hook stashes the most recent output of the hooked module.
    features = {}

    def hook(module, inputs, output):
        features['last_layer'] = output.detach()

    # Hook the first module of the classifier head. Keep the handle so the
    # hook can be removed afterwards; otherwise every call would leave another
    # hook attached to the model.
    handle = model.classifier[0].register_forward_hook(hook)

    try:
        with torch.no_grad():
            for x, y in data_loader:
                x = x.to(device)

                # Drop only the trailing label axis; a bare squeeze() would
                # yield a 0-d tensor for a single-sample batch and break
                # torch.cat below.
                if y.dim() > 1:
                    y = y.squeeze(-1)
                y = y.long()

                # Apply noise if specified
                if noise_fn is not None:
                    x = noise_fn(x)

                y_hat = model(x)
                representations.append(features['last_layer'].cpu())
                targets.append(y)
                logits.append(y_hat.cpu())
    finally:
        handle.remove()

    # Convert to numpy arrays
    X = torch.cat(representations).numpy()
    Y = torch.cat(targets).numpy()
    Y_onehot = np.eye(num_classes)[Y]
    Y_logits = torch.cat(logits).numpy()
    Y_hats = np.argmax(Y_logits, axis=1)

    print(f"Representations shape: {X.shape}, Labels shape: {Y.shape}, Predictions shape: {Y_hats.shape}")

    # Calculate MI using latentmi
    # NOTE(review): pmis are presumably per-sample MI estimates - confirm against latentmi docs.
    pmis, _, _ = lmi.estimate(X, Y_onehot, validation_split=0.3, batch_size=512,
                             epochs=50, quiet=False)
    mi_score = np.nanmean(pmis)

    print(f"MI score (representation): {mi_score:.4f}")

    discrete_mi = ksg.midd(Y, Y_hats)
    print(f"MI score (discrete): {discrete_mi:.4f}")

    # Compute one-vs-all MIs with balanced resampling
    print("\nComputing one-vs-all MIs...")
    ova_mis = {}

    for class_idx in range(num_classes):
        # Create binary labels (1 for current class, 0 for all others)
        Y_binary = (Y == class_idx).astype(int)
        Y_hats_binary = (Y_hats == class_idx).astype(int)

        # Count samples in each binary class
        n_positive = np.sum(Y_binary)
        n_negative = len(Y_binary) - n_positive

        # Balanced resampling: take min of the two classes and sample equally
        n_samples = min(n_positive, n_negative)

        if n_samples == 0:
            print(f"  Class_{class_idx}: skipped (no samples)")
            ova_mis[f"Class_{class_idx}"] = np.nan
            continue

        positive_indices = np.where(Y_binary == 1)[0]
        negative_indices = np.where(Y_binary == 0)[0]

        # A local legacy generator seeded with 42 draws the same numbers as
        # np.random.seed(42) followed by np.random.* calls, but without
        # resetting the process-wide NumPy RNG on every class.
        rng = np.random.RandomState(42)
        sampled_positive = rng.choice(positive_indices, size=n_samples, replace=True)
        sampled_negative = rng.choice(negative_indices, size=n_samples, replace=True)

        # Combine and shuffle
        balanced_indices = np.concatenate([sampled_positive, sampled_negative])
        rng.shuffle(balanced_indices)

        # Get balanced labels and predictions
        Y_binary_balanced = Y_binary[balanced_indices]
        Y_hats_binary_balanced = Y_hats_binary[balanced_indices]

        # Compute MI for this one-vs-all binary classification
        ova_mi = ksg.midd(Y_binary_balanced.reshape(-1, 1), Y_hats_binary_balanced.reshape(-1, 1))
        ova_mis[f"Class_{class_idx}"] = ova_mi

        print(f"  Class_{class_idx} (one-vs-all, balanced {n_samples} per class): {ova_mi:.4f} bits")

    return mi_score, discrete_mi, ova_mis


def run_experiment(dataset_name: str, output_dir: str, noise_fn: Optional[Callable],
                   tag: str, device: torch.device, n_epochs: int = 50,
                   size: int = 224, download: bool = True) -> Tuple[float, float, dict]:
    """Train and evaluate one model under a single noise condition.

    Loads the train and test splits of a MedMNIST dataset, trains a fresh
    model with ``noise_fn`` applied to the inputs, saves its weights to
    ``{output_dir}/model_{tag}.pt`` and computes the mutual-information
    metrics on the test split with the same ``noise_fn``.

    Args:
        dataset_name: Key into ``medmnist.INFO`` (e.g. ``'tissuemnist'``).
        output_dir: Directory for the saved weights; created if missing.
        noise_fn: Optional corruption applied to input batches during both
            training and evaluation; ``None`` for clean inputs.
        tag: Label used in log output and in the weights file name.
        device: Device used for training and evaluation.
        n_epochs: Number of training epochs.
        size: Image side length passed to the dataset and the resize transform.
        download: Whether the dataset class may download the data.

    Returns:
        The ``(mi_score, discrete_mi, ova_mis)`` tuple produced by ``compute_mi``.
    """
    print(f"\n{'='*60}")
    print(f"Starting experiment: {tag}")
    print(f"{'='*60}")

    # Create the output directory before any expensive work, so an unusable
    # path fails here instead of after the whole training run.
    os.makedirs(output_dir, exist_ok=True)
    model_path = os.path.join(output_dir, f"model_{tag}.pt")

    # Get dataset info
    info = INFO[dataset_name]
    num_classes = len(info['label'])
    DataClass = getattr(medmnist, info['python_class'])

    print(f"Dataset: {dataset_name}")
    print(f"Number of classes: {num_classes}")
    print(f"Labels: {info['label']}")

    # Prepare transforms - convert grayscale to RGB by repeating channel
    tfm = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),  # Convert 1-channel to 3-channel
    ])

    # Load datasets
    train_ds = DataClass(split='train', transform=tfm, download=download, size=size)
    test_ds = DataClass(split='test', transform=tfm, download=download, size=size)

    train_dl = DataLoader(train_ds, batch_size=512, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=4)

    print(f'Dataset split: {len(train_ds)} train, {len(test_ds)} test')

    # Create and train model
    model = create_model(num_classes, device)

    train_model(model, train_dl, noise_fn, device, n_epochs=n_epochs)

    # Save model
    torch.save(model.state_dict(), model_path)
    print(f"Saved model to {model_path}")

    # Compute MI
    mi_score, discrete_mi, ova_mis = compute_mi(model, test_dl, noise_fn, device, num_classes)

    return mi_score, discrete_mi, ova_mis

In [3]:
import traceback

# --- Configuration -----------------------------------------------------------
dataset_name = 'tissuemnist'
output_dir = 'tissuemnist_models'
epochs = 10
size = 224
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Each entry is (noise function or None, tag). The default-argument trick
# (s=s / sc=sc) binds the current loop value into each lambda.
experiments = [
    (None, "clean"),
] + [
    (lambda x, s=s: add_gauss(x, s), f"gauss_{s:.6f}")
    for s in np.logspace(-3, 1, 10)
] + [
    (lambda x, sc=sc: pixelate(x, sc), f"pix_{sc}x")
    for sc in [2, 4, 7, 8, 14, 16, 28, 32, 56, 112, 224]
]

# The results file lives in output_dir; create the directory here so writing
# it works even if every experiment fails before saving a model.
os.makedirs(output_dir, exist_ok=True)
results_path = os.path.join(output_dir, 'results.csv')

# Run all experiments
results = {}
for noise_fn, tag in experiments:
    try:
        mi_score, discrete_mi, ova_mis = run_experiment(
            dataset_name,
            output_dir,
            noise_fn,
            tag,
            device,
            n_epochs=epochs,
            size=size,
            download=True
        )
        results[tag] = {
            'mi_score': mi_score,
            'discrete_mi': discrete_mi,
            **{f'ova_mi_{k}': v for k, v in ova_mis.items()}
        }
    except Exception as e:
        # Record the failure and move on, so one bad condition does not abort the sweep.
        print(f"Error in experiment {tag}: {e}")
        traceback.print_exc()
        results[tag] = {'mi_score': np.nan, 'discrete_mi': np.nan}

    # Checkpoint after every experiment: a long sweep that is interrupted
    # (e.g. KeyboardInterrupt) keeps the results finished so far.
    pd.DataFrame([
        {'experiment': done_tag, **scores}
        for done_tag, scores in results.items()
    ]).to_csv(results_path, index=False)

# Save results
results_df = pd.DataFrame([
    {'experiment': tag, **scores}
    for tag, scores in results.items()
])
results_df.to_csv(results_path, index=False)
print(f"\nResults saved to {results_path}")
print(results_df)

Using device: cuda

Starting experiment: clean
Dataset: tissuemnist
Number of classes: 8
Labels: {'0': 'Collecting Duct, Connecting Tubule', '1': 'Distal Convoluted Tubule', '2': 'Glomerular endothelial cells', '3': 'Interstitial endothelial cells', '4': 'Leukocytes', '5': 'Podocytes', '6': 'Proximal Tubule Segments', '7': 'Thick Ascending Limb'}




Dataset split: 165466 train, 47280 test


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/324 [00:00<?, ?it/s]

Epoch 1/10 | Loss: 0.9024 | Acc: 67.24%


Epoch 2/10:   0%|          | 0/324 [00:00<?, ?it/s]

Epoch 2/10 | Loss: 0.7276 | Acc: 73.79%


Epoch 3/10:   0%|          | 0/324 [00:00<?, ?it/s]

Epoch 3/10 | Loss: 0.6651 | Acc: 76.08%


Epoch 4/10:   0%|          | 0/324 [00:00<?, ?it/s]

Epoch 4/10 | Loss: 0.6131 | Acc: 78.03%


Epoch 5/10:   0%|          | 0/324 [00:00<?, ?it/s]

Epoch 5/10 | Loss: 0.5691 | Acc: 79.54%


Epoch 6/10:   0%|          | 0/324 [00:00<?, ?it/s]

Epoch 6/10 | Loss: 0.5215 | Acc: 81.23%


Epoch 7/10:   0%|          | 0/324 [00:00<?, ?it/s]

Epoch 7/10 | Loss: 0.4742 | Acc: 83.00%


Epoch 8/10:   0%|          | 0/324 [00:00<?, ?it/s]

Epoch 8/10 | Loss: 0.4245 | Acc: 84.66%


Epoch 9/10:   0%|          | 0/324 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:

import matplotlib.pyplot as plt

# Split the results by noise family, using the tag prefix
gauss_results = {k: v for k, v in results.items() if k.startswith('gauss_')}
pix_results = {k: v for k, v in results.items() if k.startswith('pix_')}

# Draw the figure when either family has results. Each panel is filled
# independently, so pixelation results are plotted even without Gaussian ones.
if gauss_results or pix_results:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    if gauss_results:
        # Tags look like "gauss_0.001000"; the part after the underscore is sigma.
        gauss_sigmas = [float(k.split('_')[1]) for k in gauss_results.keys()]
        gauss_mis = [v['mi_score'] for v in gauss_results.values()]
        gauss_discrete_mis = [v['discrete_mi'] for v in gauss_results.values()]

        ax1.semilogx(gauss_sigmas, gauss_mis, 'o-', label='Representation MI')
        ax1.semilogx(gauss_sigmas, gauss_discrete_mis, 's-', label='Discrete MI')
        ax1.set_xlabel('Gaussian Noise Sigma')
        ax1.set_ylabel('Mutual Information (bits)')
        ax1.set_title('MI vs Gaussian Noise')
        ax1.legend()
        ax1.grid(True)

    if pix_results:
        # Tags look like "pix_8x"; strip the trailing "x" to get the scale.
        pix_scales = [int(k.split('_')[1].rstrip('x')) for k in pix_results.keys()]
        pix_mis = [v['mi_score'] for v in pix_results.values()]
        pix_discrete_mis = [v['discrete_mi'] for v in pix_results.values()]

        ax2.semilogx(pix_scales, pix_mis, 'o-', label='Representation MI')
        ax2.semilogx(pix_scales, pix_discrete_mis, 's-', label='Discrete MI')
        ax2.set_xlabel('Pixelation Scale')
        ax2.set_ylabel('Mutual Information (bits)')
        ax2.set_title('MI vs Pixelation')
        ax2.legend()
        ax2.grid(True)

    plt.tight_layout()
    # Make sure the target directory exists before saving the figure.
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, 'mi_scaling.png'), dpi=150)
    plt.show()