In [1]:
import vesuvius
from vesuvius import Volume
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
class InkVolumeDataset(Dataset):
    def __init__(self, volume, labels, tile_size, depth):
        """
        volume: [D, H, W] - 3D volume of grayscale slices
        labels: [H, W] - 2D binary mask shared across depth
        tile_size: size of each 2D tile (height and width)
        depth: number of slices to stack per sample
        """
        self.volume = volume
        self.labels = labels
        self.tile_size = tile_size
        self.depth = depth
        self.D, self.H, self.W = volume.shape

        self.blocks = []
        for d in range(0, self.D - depth + 1, int(depth//2)):
            for y in range(0, self.H - tile_size + 1, tile_size):
                for x in range(0, self.W - tile_size + 1, tile_size):
                    label_tile = labels[y:y+tile_size, x:x+tile_size]
                    if label_tile.shape == (tile_size, tile_size):
                        self.blocks.append((d, y, x))

    def __len__(self):
        return len(self.blocks)

    def __getitem__(self, idx):
        d, y, x = self.blocks[idx]
        
        block = self.volume[d:d+self.depth, y:y+self.tile_size, x:x+self.tile_size]
        label_tile = self.labels[y:y+self.tile_size, x:x+self.tile_size]

        # Convert to tensor and ensure proper normalization
        # Don't divide by 255 again if already normalized
        block = torch.tensor(block, dtype=torch.float32)
        
        # Add channel dimension: [D, H, W] -> [1, D, H, W]
        block = block.unsqueeze(0)

        # Binary label: 1 if any ink present (more robust checking)
        has_ink = np.any(label_tile > 0.5)  # More robust than == 1.0
        label = torch.tensor([float(has_ink)], dtype=torch.float32)

        return block, label

In [3]:
class InkDetector(nn.Module):
    def __init__(self):
        super(InkDetector, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3),

            nn.Conv3d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool3d(1)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.Linear(128, 64),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [4]:
# Load data
segment_id = 20230827161847
segment = Volume(segment_id, normalize=True)

# Extract volume and labels
volume = segment[:64, 200:5600, 1000:4600]
labels = segment.inklabel[200:5600, 1000:4600] / 255.0

# Data setup
tile_size = 32
depth = 8
split_x = int(volume.shape[2] * 0.75)

train_volume = volume[:, :, :split_x]
train_labels = labels[:, :split_x]
valid_volume = volume[:, :, split_x:]
valid_labels = labels[:, split_x:]

# Create datasets
train_dataset = InkVolumeDataset(train_volume, train_labels, tile_size=tile_size, depth=depth)
valid_dataset = InkVolumeDataset(valid_volume, valid_labels, tile_size=tile_size, depth=depth)

# Check label distribution
all_labels = [int(label.item()) for _, label in train_dataset]
label_counts = Counter(all_labels)
print(f"Label distribution: {label_counts}")

# Calculate class weights for imbalanced data
pos_weight = None
if label_counts[0] > 0 and label_counts[1] > 0:
    pos_weight = torch.tensor([label_counts[0] / label_counts[1]])
    print(f"Using pos_weight: {pos_weight.item():.2f}")

# Create dataloaders
batch_size = 8  # Reduced batch size for debugging
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size)

Label distribution: Counter({0: 155910, 1: 55770})
Using pos_weight: 2.80


In [5]:
# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = InkDetector().to(device)

# Use weighted loss for imbalanced data
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device) if pos_weight is not None else None)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    train_loss, train_correct, train_total = 0.0, 0, 0

    for batch_images, batch_labels in train_loader:
        # batch_images already has shape [B, 1, D, H, W] from dataset
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device).view(-1, 1)

        optimizer.zero_grad()
        outputs = model(batch_images)

        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        
        # Fixed prediction logic
        predicted = (torch.sigmoid(outputs) > 0.5).float()  # Apply sigmoid first
        train_correct += (predicted == batch_labels).sum().item()
        train_total += batch_labels.size(0)

    # Validation
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device).view(-1, 1)
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            
            # Fixed prediction logic
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)

    print(
        f"Epoch {epoch+1}/{num_epochs} | "
        f"Train Acc: {train_correct/train_total:.4f} | Val Acc: {val_correct/val_total:.4f} | "
        f"Train Loss: {train_loss/len(train_loader):.4f} | Val Loss: {val_loss/len(valid_loader):.4f}"
    )
    torch.save(model.state_dict(), f'model_epoch_{epoch}.pth')


Epoch 1/20 | Train Acc: 0.4547 | Val Acc: 0.7708 | Train Loss: 1.0212 | Val Loss: 0.9777
Epoch 2/20 | Train Acc: 0.4941 | Val Acc: 0.7708 | Train Loss: 1.0211 | Val Loss: 0.9768


KeyboardInterrupt: 

In [9]:
def validate_stitched_volumes_depth_blocks(model, train_volume, train_labels, valid_volume, valid_labels, 
                                          tile_size=16, depth=8, device='cuda'):
    """
    Validate model on both train and validation volumes across multiple depth blocks, 
    then stitch them together for visualization. Creates separate visualizations for each depth block.
    Assumes train_volume and valid_volume were split horizontally (along width axis).
    
    Number of depth blocks is automatically calculated based on volume depth and depth parameter
    to match the dataset creation logic (non-overlapping blocks).
    """
    
    def process_volume_depth_block(volume, labels, volume_name, depth_start, depth_end):
        """Helper function to process a single volume at a specific depth range"""
        model.eval()
        D, H, W = volume.shape
        
        prediction_map = np.zeros((H, W), dtype=np.float32)
        count_map = np.zeros((H, W), dtype=np.float32)
        
        # Create list of all tile coordinates
        tile_coords = []
        for y in range(0, H - tile_size + 1, tile_size):
            for x in range(0, W - tile_size + 1, tile_size):
                tile_coords.append((y, x))
        
        with torch.no_grad():
            # Process tiles with tqdm progress bar
            for y, x in tqdm(tile_coords, desc=f"Processing {volume_name} volume (depth {depth_start}-{depth_end-1})"):
                # Extract block from the specified depth range
                block = volume[depth_start:depth_end, y:y+tile_size, x:x+tile_size]
                
                if block.shape == (depth, tile_size, tile_size):
                    block_tensor = torch.from_numpy(block).float().unsqueeze(0).unsqueeze(0).to(device)
                    logits = model(block_tensor)
                    pred = torch.sigmoid(logits).item()
                    
                    prediction_map[y:y+tile_size, x:x+tile_size] += pred
                    count_map[y:y+tile_size, x:x+tile_size] += 1
        
        # Normalize predictions
        prediction_map = np.divide(prediction_map, count_map, where=count_map>0)
        return prediction_map
    
    # Calculate number of depth blocks to match dataset creation logic
    D = train_volume.shape[0]
    num_depth_blocks = (D - depth + 1) // depth
    
    print(f"Volume depth: {D}, Block depth: {depth}, Number of blocks: {num_depth_blocks}")
    
    # Store all results
    all_results = []
    
    # Process each depth block (matching dataset creation logic)
    for block_idx in range(num_depth_blocks):
        depth_start = block_idx * depth
        depth_end = depth_start + depth
        
        print(f"\n=== Processing Depth Block {block_idx + 1}/{num_depth_blocks} (slices {depth_start}-{depth_end-1}) ===")
        
        # Process both volumes for this depth block
        train_predictions = process_volume_depth_block(train_volume, train_labels, "training", depth_start, depth_end)
        valid_predictions = process_volume_depth_block(valid_volume, valid_labels, "validation", depth_start, depth_end)
        
        # Stitch everything back together horizontally
        # Use the middle slice of the current depth block for visualization
        middle_slice_idx = depth_start + depth // 2
        full_volume_slice = np.concatenate([train_volume[middle_slice_idx], valid_volume[middle_slice_idx]], axis=1)
        full_labels = np.concatenate([train_labels, valid_labels], axis=1)
        full_predictions = np.concatenate([train_predictions, valid_predictions], axis=1)
        
        # Create visualization for this depth block
        plt.figure(figsize=(24, 6))
        
        # Original slice
        plt.subplot(1, 4, 1)
        plt.imshow(full_volume_slice, cmap='gray')
        plt.title(f'Full Volume (Slice {middle_slice_idx})\nDepth Block {block_idx + 1} ({depth_start}-{depth_end-1})\nTrain | Valid')
        plt.axvline(x=train_volume.shape[2]-0.5, color='red', linestyle='--', linewidth=2, alpha=0.7)
        plt.axis('off')
        
        # Ground truth
        plt.subplot(1, 4, 2)
        plt.imshow(full_labels, cmap='binary')
        plt.title(f'Ground Truth Labels\nDepth Block {block_idx + 1}\nTrain | Valid')
        plt.axvline(x=train_labels.shape[1]-0.5, color='red', linestyle='--', linewidth=2, alpha=0.7)
        plt.axis('off')
        
        # Predictions
        plt.subplot(1, 4, 3)
        img = plt.imshow(full_predictions, cmap='pink', vmin=0, vmax=1)
        plt.colorbar(img, fraction=0.046, pad=0.04)
        plt.title(f'Model Predictions\nDepth Block {block_idx + 1}\nTrain | Valid')
        plt.axvline(x=train_predictions.shape[1]-0.5, color='red', linestyle='--', linewidth=2, alpha=0.7)
        plt.axis('off')
        
        # Overlay
        plt.subplot(1, 4, 4)
        plt.imshow(full_predictions, cmap='inferno', vmin=0, vmax=1)
        
        # Create overlay for ground truth
        label_overlay = np.zeros((*full_labels.shape, 4))  # RGBA
        label_overlay[full_labels > 0.5] = [1, 1, 1, 0.1]  # White with transparency
        plt.imshow(label_overlay)
        
        plt.title(f'Predictions + Ground Truth\nDepth Block {block_idx + 1}\nTrain | Valid\n(White = True Labels)')
        plt.axvline(x=train_predictions.shape[1]-0.5, color='red', linestyle='--', linewidth=2, alpha=0.7)
        plt.axis('off')
        
        plt.tight_layout()
        plt.show()
        
        # Store results
        result = {
            'depth_block': block_idx + 1,
            'depth_range': (depth_start, depth_end-1),
            'train_predictions': train_predictions,
            'valid_predictions': valid_predictions,
            'full_predictions': full_predictions,
            'full_labels': full_labels,
            'full_volume_slice': full_volume_slice
        }
        all_results.append(result)
    
    # Print overall summary
    print(f"\n=== OVERALL SUMMARY ===")
    print(f"Processed {num_depth_blocks} depth blocks of {depth} slices each")
    print(f"Ground truth ink pixels: {(full_labels > 0.5).sum()} pixels")
    print(f"Train region width: {train_volume.shape[2]}")
    print(f"Valid region width: {valid_volume.shape[2]}")
    
    return all_results

In [None]:
results = validate_stitched_volumes_depth_blocks(model, train_volume, train_labels, valid_volume, valid_labels)