In [25]:
from medmnist.dataset import (PathMNIST, ChestMNIST, DermaMNIST, OCTMNIST, PneumoniaMNIST, RetinaMNIST,
                                  BreastMNIST, BloodMNIST, TissueMNIST, OrganAMNIST, OrganCMNIST, OrganSMNIST,
                                  OrganMNIST3D, NoduleMNIST3D, AdrenalMNIST3D, FractureMNIST3D, VesselMNIST3D, SynapseMNIST3D)
from medmnist.dataset import MedMNIST

from torch import nn
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from pathlib import Path
from typing import Type

In [26]:
torch.set_float32_matmul_precision('medium')

In [27]:
datasets = (PathMNIST, ChestMNIST, DermaMNIST, OCTMNIST, PneumoniaMNIST, RetinaMNIST,
                                  BreastMNIST, BloodMNIST, TissueMNIST, OrganAMNIST, OrganCMNIST, OrganSMNIST,
                                  OrganMNIST3D, NoduleMNIST3D, AdrenalMNIST3D, FractureMNIST3D, VesselMNIST3D, SynapseMNIST3D)

In [28]:
from torchvision.models import resnet18, ResNet18_Weights, resnet34, resnet50, resnet101, resnet152, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights

In [30]:
import pickle
from sklearn.model_selection import StratifiedKFold
import numpy as np
class StratifiedSplitter:
    def __init__(self):
        pass

    def split(self, strat_array, n_splits):
        self.n_splits = n_splits
        self.strat_array = strat_array
        self.masks = {i: np.zeros(len(strat_array)).astype(bool) for i in range(n_splits)}
        skf = StratifiedKFold(n_splits=self.n_splits, random_state=42, shuffle=True)
        for i, (train_index, test_index) in enumerate(skf.split(self.strat_array, self.strat_array)):
            self.masks[i][train_index] = True

    def get_train(self, fold):
        return self.masks[fold]

    def get_test(self, fold):
        return ~self.masks[fold]

    @staticmethod
    def load(path):
        with open(path, 'rb') as file:
            loaded = pickle.load(file)
        cl = StratifiedSplitter()
        cl.masks = loaded
        # Note: n_splits and strat_array aren't saved/loaded, might be needed if reusing methods other than get_train/test
        return cl

    def save(self, path):
        with open(path, 'wb') as file:
            pickle.dump(self.masks, file)
    
class ResnetModel(nn.Module):
  def __init__(self, n_classes, size=18, activation='sigmoid'):
    super(ResnetModel, self).__init__()
    self.backbone = self._backbone(size)
    self.fc = nn.Linear(self.backbone.fc.in_features, n_classes)
    self.backbone.fc = nn.Identity()
    if activation == 'sigmoid':
      self.activation = nn.Sigmoid()
    elif activation == 'softmax':
      self.activation = nn.LogSoftmax(dim=1)
    else:
      raise ValueError(f"Activation {activation} is not supported")

  def _backbone(self, size):
    if size == 18:
      backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
    elif size == 34:
      backbone = resnet34(weights=ResNet34_Weights.DEFAULT)
    elif size == 50:
      backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
    elif size == 101:
      backbone = resnet101(weights=ResNet101_Weights.DEFAULT)
    elif size == 152:
      backbone = resnet152(weights=ResNet152_Weights.DEFAULT)
    else:
      raise ValueError("Invalid ResNet size. Choose from 18, 34, 50, 101, or 152.")
    return backbone
  
  def forward(self, x):
    x = self.backbone(x)
    x = self.fc(x)
    x = self.activation(x)
    return x

class Medmnist2DModel(pl.LightningModule):
  def __init__(self, model, loss):
    super(Medmnist2DModel, self).__init__()
    self.loss = loss
    self.model = model
  
    self.learning_rate = 3e-4
    self.wd = 1e-4

  def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = self.loss(y_hat, y.float())
    self.log("train_loss", loss, on_epoch=True, logger=True, on_step=True)
    return loss
 
  def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = self.loss(y_hat, y.float())
    self.log("validation_loss", loss, on_epoch=True, logger=True, on_step=True)
    return loss


  def configure_optimizers(self):
      opt = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.wd)
      out = opt
      # if self.onecycle:
      #     scheduler = OneCycleLR(opt, max_lr=self.learning_rate, total_steps=self.trainer.estimated_stepping_batches)
      #     lr_scheduler = {"scheduler": scheduler, "interval": "step"}
      #     out = {"optimizer": opt, "lr_scheduler": lr_scheduler}
      # elif self.reduce_on_plateau:
      #     scheduler = ReduceLROnPlateau(opt, "max", factor=0.67, patience=3, eps=0.0001)
      #     return [opt], [{"scheduler": scheduler, "interval": "epoch", "monitor": "val_averageprecision_epoch"}]
      # else:
      #     out = opt

      return out

class MedmnistDataModule(pl.LightningDataModule):
  def __init__(self, dataset_type: Type[MedMNIST], train_splt='train', batch_size=128, num_workers=16, image_size=224):
    super().__init__()
    
    if isinstance(train_splt, int):
      pass # stratified splitting
      self.splitter = 123
    else:
      self.splitter = None
    self.dataset_type = dataset_type 
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.image_size = image_size

  def prepare_data(self):
    pass 

  def setup(self, stage=None): # add test/predict stages that load only one dataset
    if self.splitter is not None:
      pass
    else:
      print("Setting up train dataset")
      self.train_dataset = self.dataset_type(split='train', transform=self.train_trainsforms, as_rgb=True, size=self.image_size)
      print("Setting up val dataset")
      self.val_dataset = self.dataset_type(split='val', transform=self.val_test_transforms, as_rgb=True, size=self.image_size)
      print("Setting up test dataset")
      self.test_dataset = self.dataset_type(split='test', transform=self.val_test_transforms, as_rgb=True, size=self.image_size)
    return 

  def train_dataloader(self):
    return DataLoader(self.train_dataset, batch_size = self.batch_size, shuffle=True, num_workers=self.num_workers)
  
  def val_dataloader(self):
    return DataLoader(self.val_dataset, batch_size = self.batch_size, shuffle=False, num_workers=self.num_workers)
  
  def test_dataloader(self):
    return DataLoader(self.test_dataset, batch_size = self.batch_size, shuffle=False, num_workers=self.num_workers)
   
  @property
  def train_trainsforms(self):
    data_transform = transforms.Compose(
      [
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5])
      ]
    )
    return data_transform

  @property
  def val_test_transforms(self):
    data_transform = transforms.Compose(
      [
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5])
      ]
    )
    return data_transform


In [10]:
dm = MedmnistDataModule(PathMNIST)

In [12]:
model = Medmnist2DModel(ResnetModel(9, 18), F.nll_loss)

In [15]:
trainer = pl.Trainer(max_epochs=20)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, dm)

Setting up train dataset
Setting up val dataset
Setting up test dataset


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type        | Params | Mode
---------------------------------------------
0 | model | ResnetModel | 11.2 M | eval
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.725    Total estimated model params size (MB)
0         Modules in train mode
71        Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

ERROR:tornado.general:SEND Error: Host unreachable


In [None]:
dl = dm.train_dataloader()

In [None]:
x, y = next(iter(dl))

In [None]:
x.shape, y.shape

(torch.Size([64, 3, 224, 224]), torch.Size([64, 14]))

In [None]:
y

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 

In [7]:
# Cell 1: Imports and Definitions

# MedMNIST Imports
from medmnist.dataset import (PathMNIST, ChestMNIST, DermaMNIST, OCTMNIST, PneumoniaMNIST, RetinaMNIST,
                              BreastMNIST, BloodMNIST, TissueMNIST, OrganAMNIST, OrganCMNIST, OrganSMNIST,
                              OrganMNIST3D, NoduleMNIST3D, AdrenalMNIST3D, FractureMNIST3D, VesselMNIST3D, SynapseMNIST3D)
from medmnist.dataset import MedMNIST
from medmnist import INFO # Import INFO to get dataset details

# PyTorch and PyTorch Lightning Imports
from torch import nn
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import torchvision.transforms as transforms
import torchmetrics # Import torchmetrics
from torch.utils.data import DataLoader

# Other Imports
from pathlib import Path
from typing import Type, Optional
import pickle
from sklearn.model_selection import StratifiedKFold
import numpy as np

# Model Imports (specific models)
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152

# Configuration
torch.set_float32_matmul_precision('medium')

# --- StratifiedSplitter Class (remains unchanged) ---
class StratifiedSplitter:
  def __init__(self):
    pass

  def split(self, strat_array, n_splits):
    self.n_splits = n_splits
    self.strat_array = strat_array
    self.masks = {i: np.zeros(len(strat_array)).astype(bool) for i in range(n_splits)}
    skf = StratifiedKFold(n_splits=self.n_splits, random_state=42, shuffle=True)
    for i, (train_index, test_index) in enumerate(skf.split(self.strat_array, self.strat_array)):
      self.masks[i][train_index] = True

  def get_train(self, fold):
    return self.masks[fold]

  def get_test(self, fold):
    return ~self.masks[fold]

  @staticmethod
  def load(path):
    with open(path, 'rb') as file:
      loaded = pickle.load(file)
    cl = StratifiedSplitter()
    cl.masks = loaded
    # Potential Bug: Need to set other attributes like n_splits?
    return cl # Return the instance

  def save(self, path):
    # Potential Bug: Should serialize self.masks, not self directly
    with open(path, 'wb') as file:
      pickle.dump(self.masks, file) # Corrected serialization

# --- ResnetModel Class (remains unchanged from previous version) ---
class ResnetModel(nn.Module):
  # Explicitly add activation type
  def __init__(self, n_classes, size=18, activation_type='logsoftmax', pretrained=False):
    super(ResnetModel, self).__init__()
    self.backbone = self._backbone(size, pretrained) # Pass pretrained flag
    self.fc = nn.Linear(self.backbone.fc.in_features, n_classes)
    self.backbone.fc = nn.Identity() # Remove original fc

    # Instantiate activation based on type
    if activation_type == 'sigmoid':
      self.activation = nn.Sigmoid()
      print("Using Sigmoid activation in model.")
    elif activation_type == 'logsoftmax':
      self.activation = nn.LogSoftmax(dim=1)
      print("Using LogSoftmax activation in model.")
    elif activation_type == 'identity':
        self.activation = nn.Identity()
        print("Using Identity activation (outputting logits).")
    # Add placeholder for your custom GEV activation later
    # elif activation_type == 'gev':
    #   self.activation = YourGEVActivation(...)
    else:
      raise ValueError(f"Activation type '{activation_type}' is not supported")

  def _backbone(self, size, pretrained=False):
    weights = None # Default to no pretrained weights
    if pretrained:
        print("Warning: Pretrained weights requested but may not be optimal without ImageNet normalization and augmentations.")
        pass # Keep weights=None for now based on user request

    if size == 18:
      backbone = resnet18(weights=weights)
    elif size == 34:
      backbone = resnet34(weights=weights)
    elif size == 50:
      backbone = resnet50(weights=weights)
    elif size == 101:
      backbone = resnet101(weights=weights)
    elif size == 152:
      backbone = resnet152(weights=weights)
    else:
      raise ValueError("Invalid ResNet size. Choose from 18, 34, 50, 101, or 152.")

    # Input channel adaptation might be needed if as_rgb=False and channels != 3
    # Since as_rgb=True is forced later, this is less critical now.
    # Example:
    # if n_channels != 3:
    #    print(f"Adapting ResNet conv1 for {n_channels} input channels.")
    #    backbone.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

    return backbone

  def forward(self, x):
    x = self.backbone(x)
    x = self.fc(x)
    x = self.activation(x) # Apply activation
    return x


# --- Medmnist2DModel Class (remains unchanged from previous version) ---
class Medmnist2DModel(pl.LightningModule):
  # Added num_classes and task_type for metrics
  def __init__(self, model: nn.Module, num_classes: int, task_type: str, loss_fn: nn.Module, learning_rate=3e-4, weight_decay=1e-4):
    super(Medmnist2DModel, self).__init__()
    self.save_hyperparameters(ignore=['model', 'loss_fn']) # Save LR, WD etc.

    self.model = model
    self.loss = loss_fn # Use the provided loss function

    # --- Initialize Metrics ---
    if task_type not in ['binary', 'multiclass', 'multilabel']:
        raise ValueError(f"Invalid task_type '{task_type}' for torchmetrics.")

    common_metric_args = {'task': task_type}
    if task_type != 'binary':
         common_metric_args['num_classes'] = num_classes

    self.val_auc = torchmetrics.AUROC(**common_metric_args)
    self.val_ap = torchmetrics.AveragePrecision(**common_metric_args)
    self.test_auc = torchmetrics.AUROC(**common_metric_args)
    self.test_ap = torchmetrics.AveragePrecision(**common_metric_args)

    self.validation_step_outputs = []
    self.test_step_outputs = []


  def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)

    if isinstance(self.loss, (nn.NLLLoss, nn.CrossEntropyLoss)):
        y = y.squeeze().long()
    elif isinstance(self.loss, (nn.BCEWithLogitsLoss, nn.BCELoss)):
         y = y.float()
    else:
        pass

    loss = self.loss(y_hat, y)
    self.log("train_loss", loss, on_epoch=True, logger=True, on_step=True, prog_bar=True)
    return loss

  def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat_activated = self.model(x)

    if isinstance(self.loss, (nn.NLLLoss, nn.CrossEntropyLoss)):
        y_true_loss = y.squeeze().long()
    elif isinstance(self.loss, (nn.BCEWithLogitsLoss, nn.BCELoss)):
         y_true_loss = y.float()
    else:
        y_true_loss = y

    loss = self.loss(y_hat_activated, y_true_loss)
    self.log("val_loss", loss, on_epoch=True, logger=True, on_step=False)

    y_true_metric = y.squeeze().int()

    if isinstance(self.loss, nn.NLLLoss):
        y_hat_probs = torch.exp(y_hat_activated)
    elif isinstance(self.loss, nn.CrossEntropyLoss):
        y_hat_probs = torch.softmax(y_hat_activated, dim=1)
    elif isinstance(self.loss, nn.BCEWithLogitsLoss):
         y_hat_probs = torch.sigmoid(y_hat_activated)
    elif isinstance(self.loss, nn.BCELoss):
         y_hat_probs = y_hat_activated
    else:
        y_hat_probs = y_hat_activated

    self.val_auc.update(y_hat_probs, y_true_metric)
    self.val_ap.update(y_hat_probs, y_true_metric)

    self.validation_step_outputs.append(loss)
    return loss

  def on_validation_epoch_end(self):
    auc = self.val_auc.compute()
    ap = self.val_ap.compute()
    self.log('val_auc_epoch', auc, prog_bar=True)
    self.log('val_ap_epoch', ap, prog_bar=True)
    self.val_auc.reset()
    self.val_ap.reset()
    self.validation_step_outputs.clear()

  # Add test_step and on_test_epoch_end similarly if needed

  def configure_optimizers(self):
      opt = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
      return opt


# --- MedmnistDataModule modified to load native size ---
class MedmnistDataModule(pl.LightningDataModule):
  # Removed data_root from __init__
  def __init__(self, dataset_type: Type[MedMNIST],
               image_size: int, # Specify the desired MedMNIST native size
               train_splt='train',
               batch_size=128,
               num_workers=16,
               as_rgb=True,
               download=True, # Keep download flag
               splitter=None):
    super().__init__()

    if image_size not in dataset_type.available_sizes:
        raise ValueError(f"Size {image_size} not available for {dataset_type.flag}. Available sizes: {dataset_type.available_sizes}")

    self.dataset_type = dataset_type
    self.image_size_to_load = image_size
    self.train_splt_config = train_splt
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.as_rgb = as_rgb
    self.download = download # Still useful to control download behavior
    self.splitter = splitter
    # self.data_root is no longer needed here

    if not self.as_rgb:
        print("Warning: as_rgb=False. Standard ResNet expects 3 channels...")


  def prepare_data(self):
    # Download data for the specified size using the default root
    if self.download:
        print(f"Downloading {self.dataset_type.flag} size {self.image_size_to_load} (using default root ~/.medmnist)...")
        # Omit the root argument here
        _ = self.dataset_type(split='train', download=True, size=self.image_size_to_load)


  def setup(self, stage: Optional[str] = None):
    # Omit the root argument when creating datasets
    common_args = {
        'as_rgb': self.as_rgb,
        'size': self.image_size_to_load,
        'download': False # Already handled in prepare_data
        # root argument omitted here
    }

    if stage == 'fit' or stage is None:
        print(f"Setting up train/val datasets for {self.dataset_type.flag} (size {self.image_size_to_load})...")
        if self.splitter is not None and isinstance(self.train_splt_config, int):
             raise NotImplementedError("Stratified splitting not fully implemented yet.")
        else:
             print(f"Using predefined 'train' and 'val' splits.")
             self.train_dataset = self.dataset_type(split='train', transform=self.train_transforms, **common_args)
             self.val_dataset = self.dataset_type(split='val', transform=self.val_test_transforms, **common_args)

    if stage == 'test' or stage is None:
        print(f"Setting up test dataset for {self.dataset_type.flag} (size {self.image_size_to_load})...")
        self.test_dataset = self.dataset_type(split='test', transform=self.val_test_transforms, **common_args)

    print("Dataset setup finished.")


  def train_dataloader(self):
    return DataLoader(self.train_dataset, batch_size = self.batch_size, shuffle=True, num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, pin_memory=True)

  def val_dataloader(self):
    return DataLoader(self.val_dataset, batch_size = self.batch_size, shuffle=False, num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, pin_memory=True)

  def test_dataloader(self):
    return DataLoader(self.test_dataset, batch_size = self.batch_size, shuffle=False, num_workers=self.num_workers, persistent_workers=True if self.num_workers > 0 else False, pin_memory=True)

  @property
  def train_transforms(self):
    trans = []
    # --- No Resize needed as we load the correct size ---
    # --- No Augmentation Added Here (as requested) ---
    trans.append(transforms.ToTensor())
    # Normalize for 3 channels since as_rgb=True
    trans.append(transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]))
    return transforms.Compose(trans)

  @property
  def val_test_transforms(self):
    trans = []
    # --- No Resize needed ---
    trans.append(transforms.ToTensor())
    # Normalize for 3 channels
    trans.append(transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]))
    return transforms.Compose(trans)

print("Definitions loaded.")

Definitions loaded.


In [8]:
# Cell 2: Configuration and Execution

# --- Configuration ---
DATASET_CLASS = PathMNIST
IMAGE_SIZE = 224
MODEL_SIZE = 18
ACTIVATION = 'logsoftmax'
LOSS_FN = nn.NLLLoss()
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4
BATCH_SIZE = 64
EPOCHS = 10
NUM_WORKERS = 4

# --- Get Dataset Info ---
info = INFO[DATASET_CLASS.flag]
n_classes = len(info['label'])
task = info['task']

# Map MedMNIST task to torchmetrics task type
if task == 'multi-class' or task == 'ordinal-regression':
    metric_task_type = 'multiclass'
elif task == 'binary-class':
    metric_task_type = 'binary'
elif task == 'multi-label, binary-class':
    metric_task_type = 'multilabel'
    if isinstance(LOSS_FN, nn.NLLLoss):
         print("\n*** WARNING: NLLLoss is not typical for multi-label tasks. Consider BCEWithLogitsLoss and model activation='identity'. ***\n")
else:
    raise ValueError(f"Unsupported task type: {task}")


# --- Initialize Modules ---
print(f"Using Dataset: {DATASET_CLASS.flag}")
print(f"Task: {task}, Metric Task: {metric_task_type}, Num Classes: {n_classes}")
print(f"Model Size: {MODEL_SIZE}, Activation: {ACTIVATION}, Loss: {type(LOSS_FN).__name__}")
print(f"Image Load Size: {IMAGE_SIZE}") # Changed from Resize Target

# Check if the chosen size is available for the dataset before instantiating DataModule
if IMAGE_SIZE not in DATASET_CLASS.available_sizes:
     raise ValueError(f"Size {IMAGE_SIZE} is not available for {DATASET_CLASS.flag}. Available: {DATASET_CLASS.available_sizes}")


dm = MedmnistDataModule(
    dataset_type=DATASET_CLASS,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    as_rgb=True,
    download=True # Control download still useful
)

# Instantiate the model (no pretraining)
model_core = ResnetModel(n_classes=n_classes, size=MODEL_SIZE, activation_type=ACTIVATION, pretrained=False)

# Instantiate the LightningModule
lightning_model = Medmnist2DModel(
    model=model_core,
    num_classes=n_classes,
    task_type=metric_task_type,
    loss_fn=LOSS_FN,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

# --- Initialize Trainer ---
# Using default TensorBoardLogger. Logs stored in ./lightning_logs/
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator='auto',
    devices=1,
    logger=True, # TensorBoardLogger logs to ./lightning_logs
    callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=10)],
    # Add other callbacks like ModelCheckpoint if needed
    # callbacks=[pl.callbacks.ModelCheckpoint(monitor='val_auc_epoch', mode='max')]
)

# --- Run Training ---
print(f"Starting training for {EPOCHS} epochs...")
trainer.fit(lightning_model, dm)

# --- Optional: Run Testing ---
# print("Starting testing...")
# trainer.test(lightning_model, datamodule=dm)

print(f"Run finished. Check TensorBoard logs in ./lightning_logs")

# To view logs: tensorboard --logdir ./lightning_logs

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Using Dataset: pathmnist
Task: multi-class, Metric Task: multiclass, Num Classes: 9
Model Size: 18, Activation: logsoftmax, Loss: NLLLoss
Image Load Size: 224
Using LogSoftmax activation in model.
Starting training for 10 epochs...
Downloading pathmnist size 224 (using default root ~/.medmnist)...
Setting up train/val datasets for pathmnist (size 224)...
Using predefined 'train' and 'val' splits.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Dataset setup finished.



  | Name     | Type                       | Params | Mode 
----------------------------------------------------------------
0 | model    | ResnetModel                | 11.2 M | train
1 | loss     | NLLLoss                    | 0      | train
2 | val_auc  | MulticlassAUROC            | 0      | train
3 | val_ap   | MulticlassAveragePrecision | 0      | train
4 | test_auc | MulticlassAUROC            | 0      | train
5 | test_ap  | MulticlassAveragePrecision | 0      | train
----------------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.725    Total estimated model params size (MB)
76        Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 1407/1407 [00:45<00:00, 30.76it/s, v_num=3, train_loss_step=0.0463, val_auc_epoch=1.000, val_ap_epoch=0.999, train_loss_epoch=0.0335]  

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 1407/1407 [00:45<00:00, 30.61it/s, v_num=3, train_loss_step=0.0463, val_auc_epoch=1.000, val_ap_epoch=0.999, train_loss_epoch=0.0335]
Run finished. Check TensorBoard logs in ./lightning_logs


In [14]:
from medmnist import ChestMNIST
dataset = ChestMNIST(split="train", download=False, size=224)


In [22]:
import pandas as pd
import numpy as np
pd.Series(dataset.labels.sum(axis=1)).value_counts()


0    42405
1    21602
2     9970
3     3378
4      829
5      218
6       49
7       14
9        2
8        1
Name: count, dtype: int64

In [23]:
targets = dataset.labels
stratify_targets = np.array([''.join(map(str, row)) for row in targets])

In [24]:
stratify_targets

array(['00000000000000', '00000000000000', '00000000000000', ...,
       '00010100000010', '00000000000000', '00000000000000'],
      shape=(78468,), dtype='<U14')

In [20]:
from torch.utils.data import Subset
import torch

In [9]:
indices = torch.tensor([0,3,4,5])

In [12]:
sb = Subset(dataset, indices)

In [13]:
sb[0]

(<PIL.Image.Image image mode=RGB size=224x224>, array([0]))

In [36]:
sb.labels

AttributeError: 'Subset' object has no attribute 'labels'

In [31]:
splitter = StratifiedSplitter()
splitter.split(strat_array=stratify_targets, n_splits=5)



In [34]:
splitter.get_test(0)

array([False, False, False, ..., False,  True, False], shape=(78468,))