In [None]:

import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import torch
from medmnist import OrganMNIST3D, NoduleMNIST3D, AdrenalMNIST3D, VesselMNIST3D
from utils.visualization import visualize_3d_sample
from utils.data_loader import get_medmnist_dataloaders
from config import ORGAN_CLASSES, ORGAN_TO_REGION

# Environment report: print the installed PyTorch version and whether
# torch can see a CUDA device, so the notebook output records the setup
# this run was executed with.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Load OrganMNIST3D dataset
# Loader settings are gathered in one place so they are easy to spot and tune.
organ_loader_kwargs = dict(dataset_name='organ', batch_size=32, num_workers=4)

# The helper returns the three split loaders plus the number of classes.
train_loader, val_loader, test_loader, num_classes = get_medmnist_dataloaders(
    **organ_loader_kwargs
)

# Summarise the label space and the size of each split.
print(f"Number of classes: {num_classes}")
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")

In [None]:
# Get a batch of images
# The iterator is deliberately left as a temporary: only the first batch is needed.
images, labels = next(iter(train_loader))

# Expected image layout: (batch, channels, depth, height, width) -- confirm against the printed shape.
print(f"Batch shape: {images.shape}")
print(f"Labels shape: {labels.shape}")

# Visualize first sample
sample_idx = 0
sample_img, sample_label = images[sample_idx], labels[sample_idx].item()

print(f"\nSample {sample_idx}:")
# Resolve the organ name once and reuse it for the label, region and plot title.
sample_organ = ORGAN_CLASSES[sample_label]
print(f"  Label: {sample_label} - {sample_organ}")
print(f"  Region: {ORGAN_TO_REGION[sample_organ]}")

visualize_3d_sample(sample_img, label=sample_organ)

In [None]:
# Analyze class distribution
# Pull the raw training labels from the dataset behind the loader and
# drop singleton dimensions so np.unique sees a flat array of class ids.
train_dataset = train_loader.dataset
all_labels = train_dataset.labels.squeeze()

# `unique` holds the class ids present; `counts` holds the matching sample counts.
unique, counts = np.unique(all_labels, return_counts=True)

# Plot class distribution
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(unique, counts)
ax.set_xlabel('Class')
ax.set_ylabel('Count')
ax.set_title('OrganMNIST3D Class Distribution (Training Set)')
# One tick per class id, labelled with the organ name.
ax.set_xticks(unique)
ax.set_xticklabels([ORGAN_CLASSES[i] for i in unique], rotation=45, ha='right')
fig.tight_layout()
plt.show()

# Text summary: absolute count and share of the training set per class.
percentages = counts / len(all_labels) * 100
print("\nClass distribution:")
for cls_idx, count, pct in zip(unique, counts, percentages):
    print(f"  {ORGAN_CLASSES[cls_idx]}: {count} ({pct:.1f}%)")

In [None]:
# Group by anatomical region
# Roll the per-class training counts up into per-region totals.
# `unique` and `counts` are the aligned arrays returned by
# np.unique(all_labels, return_counts=True): counts[i] belongs to unique[i].
# Iterating them together avoids searching `unique` for every class.
region_counts = {}
for cls_idx, count in zip(unique, counts):
    organ_name = ORGAN_CLASSES[cls_idx]
    region = ORGAN_TO_REGION[organ_name]
    # int() stores plain Python integers instead of NumPy scalars, so the
    # dict displays and serialises cleanly.
    region_counts[region] = region_counts.get(region, 0) + int(count)

# Plot region distribution
plt.figure(figsize=(10, 6))
# Dict insertion order follows the class-id order of `unique`, so bar
# order (and colour assignment) is deterministic for a given dataset.
regions = list(region_counts.keys())
region_values = list(region_counts.values())

colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
plt.bar(regions, region_values, color=colors)
plt.xlabel('Anatomical Region')
plt.ylabel('Count')
plt.title('Distribution by Anatomical Region')
plt.tight_layout()
plt.show()

# Text summary: absolute count and share of the training set per region.
print("\nRegion distribution:")
for region, count in region_counts.items():
    print(f"  {region}: {count} ({count/len(all_labels)*100:.1f}%)")

In [None]:
# Compute image statistics
# Statistics below describe a single freshly drawn training batch,
# not the whole dataset; the labels of that batch are discarded.
sample_images, _ = next(iter(train_loader))

print("Image statistics:")
print(f"  Shape: {sample_images.shape}")
# Reductions run over every element of the batch tensor.
print(f"  Min value: {torch.min(sample_images):.4f}")
print(f"  Max value: {torch.max(sample_images):.4f}")
print(f"  Mean: {torch.mean(sample_images):.4f}")
print(f"  Std: {torch.std(sample_images):.4f}")
print(f"  Data type: {sample_images.dtype}")

In [None]:
# Load other region-specific datasets
# Maps a short dataset key to (anatomical region, MedMNIST dataset class).
datasets_info = dict(
    nodule=('chest', NoduleMNIST3D),
    adrenal=('abdomen', AdrenalMNIST3D),
    vessel=('brain', VesselMNIST3D),
)

print("Additional MedMNIST3D Datasets:\n")
for dataset_name in datasets_info:
    region, dataset_class = datasets_info[dataset_name]
    # Instantiate the training split; download=True asks the dataset class to download the data.
    dataset = dataset_class(split='train', download=True)
    print(f"{dataset_name.upper()} ({region}):")
    print(f"  Samples: {len(dataset)}")
    # Number of distinct label values present in the training split.
    print(f"  Classes: {np.unique(dataset.labels).size}")
    # Shape of the image component of the first (image, label) sample.
    print(f"  Image shape: {dataset[0][0].shape}\n")