# MultiStain CycleGAN Training for Histopathology Stain Normalization

This notebook implements a complete pipeline to train a MultiStain-CycleGAN model for stain normalization using unaligned H5 datasets. The model learns a mapping that translates histopathology images from a source domain (images from a given medical center) to a target domain (the staining style used in the test data). Normalizing stains in this way reduces the variability in stain appearance, which can otherwise degrade downstream tasks such as tumor classification.

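What does this notebook do? It clones the MultiStain-CycleGAN normalization repository, installs the dominate package, defines a dataset class for unaligned source and target H5 files, defines a CycleGAN training loop, and launches training with options close to those used in the paper.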

In [1]:
# !git clone https://github.com/DBO-DKFZ/multistain_cyclegan_normalization.git

Cloning into 'multistain_cyclegan_normalization'...
remote: Enumerating objects: 136, done.
remote: Counting objects: 100% (136/136), done.
remote: Compressing objects: 100% (105/105), done.
remote: Total 136 (delta 27), reused 123 (delta 19), pack-reused 0 (from 0)
Receiving objects: 100% (136/136), 3.42 MiB | 26.34 MiB/s, done.
Resolving deltas: 100% (27/27), done.


In [1]:
!pip install dominate

Collecting dominate
  Downloading dominate-2.9.1-py2.py3-none-any.whl.metadata (13 kB)
Downloading dominate-2.9.1-py2.py3-none-any.whl (29 kB)
Installing collected packages: dominate
Successfully installed dominate-2.9.1


In [None]:
%cd multistain_cyclegan_normalization

In [None]:
# %cd ..

In [37]:
# !mkdir checkps

We define a custom dataset class (H5Dataset) that loads images from H5 files representing two domains: the source (which combines training and validation sets) and the target (test set).
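As a rough illustration of how one item is assembled, the sketch below mimics the pairing and preprocessing in H5Dataset.__getitem__ using in-memory NumPy arrays instead of H5 files. The arrays and the make_pair helper are hypothetical stand-ins, and the sketch assumes pixels are stored as floats in [0, 1], as the [-1, 1] conversion in the class implies.

```python
import numpy as np


def make_pair(source_imgs, target_imgs, idx):
    """Return one (A, B) pair the way H5Dataset.__getitem__ does.

    source_imgs / target_imgs are hypothetical in-memory stand-ins
    for the 'img' datasets stored in the H5 files.
    """
    # Unaligned pairing: each domain is indexed modulo its own length,
    # so the smaller domain simply cycles.
    src = source_imgs[idx % len(source_imgs)]
    tgt = target_imgs[idx % len(target_imgs)]

    def prepare(img):
        # Channel-last (H, W, 3) -> channel-first (3, H, W)
        if img.ndim == 3 and img.shape[-1] == 3:
            img = np.transpose(img, (2, 0, 1))
        # Map [0, 1] to [-1, 1] (assumes float pixels in [0, 1])
        return img.astype(np.float32) * 2.0 - 1.0

    return prepare(src), prepare(tgt)


# Five source tiles and two target tiles: the dataset length is max(5, 2).
source = [np.full((4, 4, 3), i / 10, dtype=np.float32) for i in range(5)]
target = [np.full((4, 4, 3), 0.5 + i / 10, dtype=np.float32) for i in range(2)]
n_items = max(len(source), len(target))

for idx in range(n_items):
    a, b = make_pair(source, target, idx)
    print(idx, a.shape, round(float(a[0, 0, 0]), 2), round(float(b[0, 0, 0]), 2))
```

Because each domain is indexed modulo its own length, a given index is always paired with the same target tile; shuffle=True in the DataLoader changes the order of the pairs, not the pairs themselves.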

In [None]:
import random

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

# This dataset class loads images from H5 files.
# It supports unaligned datasets: the source (domain A) and target (domain B) images are loaded separately.
class H5Dataset(Dataset):
    def __init__(self, source_paths, target_paths, transform=None, max_source=None, max_target=None, seed=42, domain_id=None):
        super().__init__()
        # Make sure source_paths and target_paths are lists
        self.source_paths = source_paths if isinstance(source_paths, list) else [source_paths]
        self.target_paths = target_paths if isinstance(target_paths, list) else [target_paths]
        self.transform = transform
        self.seed = seed
        self.domain_id = domain_id  # If you want to load only data from a specific domain, you can set this
        random.seed(self.seed)
        # Gather (file path, key) pairs from the source H5 files. Labels are balanced
        # unless a domain ID is given, in which case only that domain's keys are kept.
        self.source_keys = self._gather_keys(self.source_paths, max_source, balance=True, domain=self.domain_id)
        # Gather keys from the target H5 files (e.g., test set) without balancing labels.
        self.target_keys = self._gather_keys(self.target_paths, max_target, balance=False)
        self.len_source = len(self.source_keys)
        self.len_target = len(self.target_keys)

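    # Internal function to collect (file path, key) pairs from each H5 file.
    # If a domain is given, only keys whose metadata matches it are kept; this takes
    # precedence over balancing, and max_items is then ignored.
    # Otherwise, if "balance" is True, keys are split by label (0 and 1) and, when
    # max_items is set, an equal number is drawn from each label.
    # Note that max_items is applied per file, not across all files.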
    def _gather_keys(self, paths, max_items, balance=False, domain=None):
        collected = []
        for fpath in paths:
            with h5py.File(fpath, 'r') as file:
                keys = list(file.keys())
                if domain is not None:
                    # Filter keys by checking if the metadata value equals the domain ID.
                    filtered = [k for k in keys if int(np.array(file[k]['metadata'])[0]) == domain]
                    selected = filtered
                elif balance:
                    # Balance the dataset by labels (0 and 1)
                    keys_by_label = {0: [], 1: []}
                    for k in keys:
                        try:
                            lb = int(np.array(file[k]['label']))
                            keys_by_label[lb].append(k)
                        except KeyError:
                            continue
                    if max_items is not None:
                        # Shuffle lists and choose an equal number from each label.
                        random.shuffle(keys_by_label[0])
                        random.shuffle(keys_by_label[1])
                        min_count = min(max_items // 2, len(keys_by_label[0]), len(keys_by_label[1]))
                        selected = keys_by_label[0][:min_count] + keys_by_label[1][:min_count]
                    else:
                        selected = keys_by_label[0] + keys_by_label[1]
                else:
                    # If no balancing is needed, sample randomly if max_items is provided.
                    selected = keys if max_items is None else random.sample(keys, min(max_items, len(keys)))
                # For each selected key, store a tuple of (file path, key)
                collected.extend([(fpath, k) for k in selected])
        random.shuffle(collected)
        return collected

    # The total length of the dataset is the maximum of source and target sizes.
    def __len__(self):
        return max(len(self.source_keys), len(self.target_keys))

    # This method returns a dictionary with images from both domains and their paths.
    # The modulo operations allow cycling if one domain has fewer images than the other.
    def __getitem__(self, idx):
        idx_src = idx % self.len_source
        idx_tgt = idx % self.len_target
        src_path, src_key = self.source_keys[idx_src]
        tgt_path, tgt_key = self.target_keys[idx_tgt]
        with h5py.File(src_path, 'r') as fs, h5py.File(tgt_path, 'r') as ft:
            img_src = torch.tensor(fs[src_key]['img'][()])
            img_tgt = torch.tensor(ft[tgt_key]['img'][()])
        # If the image is in channel-last format, we transpose it to channel-first.
        if img_src.ndim == 3 and img_src.shape[-1] == 3:
            img_src = img_src.permute(2, 0, 1)
        if img_tgt.ndim == 3 and img_tgt.shape[-1] == 3:
            img_tgt = img_tgt.permute(2, 0, 1)
        # Convert pixel values from [0, 1] to [-1, 1] (assumes the H5 files store float images in [0, 1])
        img_src = img_src.float() * 2.0 - 1.0
        img_tgt = img_tgt.float() * 2.0 - 1.0
        # Apply any provided transformations (e.g. resizing, cropping)
        if self.transform:
            img_src = self.transform(img_src)
            img_tgt = self.transform(img_tgt)
        # Return a dictionary containing the images and a string representation of their paths
        return {'A': img_src, 'B': img_tgt,
                'A_paths': f"{src_path}:{src_key}",
                'B_paths': f"{tgt_path}:{tgt_key}"}
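The label-balancing branch of _gather_keys can be exercised on its own. The sketch below reproduces its logic on a hypothetical {key: label} dictionary rather than an H5 file. Unlike the class, it uses a local random.Random(seed) instead of seeding the global generator; that is a change made for illustration, not the notebook's behavior.

```python
import random


def balanced_sample(labels_by_key, max_items, seed=42):
    """Sketch of the balance=True branch of H5Dataset._gather_keys.

    labels_by_key is a hypothetical {key: label} dict standing in for
    the 'label' entries of one H5 file.
    """
    rng = random.Random(seed)  # local RNG instead of the global random.seed
    keys_by_label = {0: [], 1: []}
    for key, label in labels_by_key.items():
        if label in keys_by_label:  # the class skips other labels via KeyError
            keys_by_label[label].append(key)
    if max_items is None:
        return keys_by_label[0] + keys_by_label[1]
    rng.shuffle(keys_by_label[0])
    rng.shuffle(keys_by_label[1])
    # Same number from each class, capped by max_items // 2
    # and by the size of the smaller class.
    n = min(max_items // 2, len(keys_by_label[0]), len(keys_by_label[1]))
    return keys_by_label[0][:n] + keys_by_label[1][:n]


labels = {f"tile_{i}": (1 if i < 3 else 0) for i in range(10)}  # 7 negatives, 3 positives
picked = balanced_sample(labels, max_items=10)
print(len(picked), sum(labels[k] for k in picked))
```

With 7 negatives, 3 positives and max_items=10, only 6 keys come back: the smaller class caps the total, so max_items acts as an upper bound rather than a target.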


The cyclegan_training function below instantiates a MultiStain-CycleGAN model and sets it up using the parameters provided by an options object.

The training loop updates the learning rate at the start of each epoch, then iterates over the dataset batches, optimizing the model parameters on each one. It prints loss metrics (and saves sample images) periodically and saves checkpoints every save_epoch_freq epochs.
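With lr_policy='linear', the upstream pytorch-CycleGAN-and-pix2pix code, which this repository appears to build on (an assumption), keeps the learning rate constant for n_epochs and then decays it linearly over n_epochs_decay. The sketch below reproduces that multiplier with this notebook's defaults; linear_lr_multiplier is a hypothetical name for the scheduler's lambda rule.

```python
def linear_lr_multiplier(epoch, epoch_count=1, n_epochs=20, n_epochs_decay=2):
    """Multiplier applied to the base lr by a 'linear' policy.

    Assumes the rule used by pytorch-CycleGAN-and-pix2pix's
    get_scheduler(); `epoch` is the scheduler's own 0-based counter.
    """
    return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)


base_lr = 0.0002  # default lr_G / lr_D in the options of this notebook
for e in range(18, 23):
    print(e, round(linear_lr_multiplier(e), 4), base_lr * linear_lr_multiplier(e))
```

Which training epoch each multiplier applies to depends on when update_learning_rate() steps the scheduler. With only two decay epochs, the rate falls in coarse steps (1, 2/3, 1/3, 0) rather than gradually.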

The objective is for the model to learn how to normalize the stains, converting images from the source domain to match the visual appearance of the target domain.
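This objective is usually expressed as a weighted sum of loss terms, with the lambda_A, lambda_B and lambda_identity options setting the weights. The sketch below assumes the standard CycleGAN weighting from the pytorch-CycleGAN-and-pix2pix code together with gan_mode='lsgan'. The generator_objective and lsgan_loss helpers are hypothetical, and whether the MultiStain variant adds further terms is not shown in this notebook.

```python
import numpy as np


def lsgan_loss(d_out, target_is_real):
    """Least-squares GAN loss: mean squared error against 1 (real) or 0 (fake)."""
    target = 1.0 if target_is_real else 0.0
    return float(np.mean((np.asarray(d_out, dtype=np.float64) - target) ** 2))


def generator_objective(d_a_on_fake_b, d_b_on_fake_a,
                        l1_cycle_a, l1_cycle_b, l1_idt_a, l1_idt_b,
                        lambda_a=10.0, lambda_b=10.0, lambda_identity=0.5):
    """Weighted generator loss in the style of the standard CycleGAN code.

    The l1_* arguments are raw (unweighted) mean absolute errors.
    """
    losses = {
        "G_A": lsgan_loss(d_a_on_fake_b, True),  # G_A wants D_A to output "real"
        "G_B": lsgan_loss(d_b_on_fake_a, True),  # G_B wants D_B to output "real"
        "cycle_A": l1_cycle_a * lambda_a,        # |G_B(G_A(A)) - A|
        "cycle_B": l1_cycle_b * lambda_b,        # |G_A(G_B(B)) - B|
        "idt_A": l1_idt_a * lambda_b * lambda_identity,  # |G_A(B) - B|
        "idt_B": l1_idt_b * lambda_a * lambda_identity,  # |G_B(A) - A|
    }
    losses["G"] = sum(losses.values())  # total loss minimized by both generators
    return losses


patch_scores = np.full((2, 1, 10, 10), 0.5)  # hypothetical PatchGAN output maps
losses = generator_objective(patch_scores, patch_scores, 0.25, 0.25, 0.2, 0.2)
for name, value in losses.items():
    print(f"{name}: {value:.4f}")
```

In the upstream code, the cycle_* and idt_* values printed during training already include these weights, so the identity terms are scaled by 5 (10 × 0.5) relative to the raw L1 error; whether this repository logs them the same way is an assumption.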

In [None]:
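# =======================================================================
# Training function for CycleGAN
# =======================================================================
import time

from torch.utils.data import DataLoader

# Assumed module paths, following the pytorch-CycleGAN-and-pix2pix layout
# (models/, util/) that this repository appears to build on.
from models.multistain_cyclegan_model import MultiStainCycleGANModel
from util.visualizer import Visualizer


def cyclegan_training(options):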
    # Print the options to verify our configuration.
    print("Training options:", options)
    
    # Create the dataset from source and target H5 files.
    # The source paths are the training and validation data; the target path is the test data.
    # No transformation is applied. The size limits and domain filter come from the options.
    train_data = H5Dataset(
        source_paths=[options.train_path, options.val_path],
        target_paths=options.test_path,
        transform=None,
        max_source=options.max_items_A,
        max_target=options.max_items_B,
        domain_id=options.domain,
    )
    
    # Create a DataLoader to shuffle and batch the dataset.
    data_loader = DataLoader(train_data, batch_size=options.batch_size, shuffle=True)
    print(f"Total images: {len(train_data)}")
    
    # Instantiate the CycleGAN model and set it up with our options.
    cyclegan_model = MultiStainCycleGANModel(options)
    cyclegan_model.setup(options)
    
    # Create a visualizer object to help display training results.
    vis = Visualizer(options)
    
    total_iters = 0  # This will count the total number of images processed.
    
    # Loop over epochs; the total number of epochs includes epochs with constant learning rate and decay.
    for ep in range(options.epoch_count, options.n_epochs + options.n_epochs_decay + 1):
        # Update the learning rate as per the scheduler.
        cyclegan_model.update_learning_rate()
        
        start_time = time.time()  # Record the start time for epoch timing.
        ep_iter = 0  # Initialize an epoch-specific iteration counter.
        print(f"\n-- Epoch {ep} --")
        
        # Reset visualizer to clear any previous output.
        vis.reset()
        cyclegan_model.isTrain = True  # Ensure model is in training mode.
        
        # Iterate over batches provided by the DataLoader.
        for i, batch in enumerate(data_loader):
            total_iters += options.batch_size
            ep_iter += options.batch_size
            
            # Set the current batch as input for the model.
            cyclegan_model.set_input(batch)
            # Run one forward and backward pass, update model parameters.
            cyclegan_model.optimize_parameters()
            
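            # Print losses and save visuals whenever total_iters is a multiple of print_freq.
            # total_iters counts images, not batches: with batch_size=32 and print_freq=100
            # this happens every 800 images.
            if total_iters % options.print_freq == 0: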
                # Get current losses as a dictionary.
                losses = cyclegan_model.get_current_losses()
                # Format the loss dictionary into a string for printing.
                loss_message = ", ".join([f"{k}: {v:.4f}" for k, v in losses.items()])
                print(f"[Epoch {ep} | Iter {ep_iter}] {loss_message}")
                # Get current visuals (images) from the model and display them.
                current_visuals = cyclegan_model.get_current_visuals()
                vis.display_current_results(current_visuals, ep, save_result=True)
        
        # Save the model checkpoints every save_epoch_freq epochs.
        if ep % options.save_epoch_freq == 0:
            print(f"Saving checkpoint for epoch {ep}")
            cyclegan_model.save_networks(ep)
        
        print(f"Epoch {ep} finished in {time.time() - start_time:.2f}s")
    
    # Return the trained CycleGAN model.
    return cyclegan_model



Overwriting models/base_model.py


Start the training with the options below. They are largely the same as those used in the paper.
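The parse() method relies on parse_known_args() because a Jupyter kernel is launched with its own flags (typically -f followed by a connection-file path), which parse_args() would reject as unrecognized. A minimal illustration with a reduced parser and a hypothetical argv:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--n_epochs', type=int, default=20)

# Hypothetical argv as seen inside a Jupyter kernel, plus one real override.
argv = ['-f', '/tmp/kernel-1234.json', '--n_epochs', '5']

# Known options are parsed; everything else is returned untouched.
opt, extras = parser.parse_known_args(argv)
print(opt.batch_size, opt.n_epochs, extras)
```

Inside the notebook, Options().parse() reads sys.argv instead of an explicit list, so overriding an option there means editing sys.argv or the defaults before calling parse().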

In [None]:
if __name__ == "__main__":
    import argparse

    # This Options class is used to parse all the command-line arguments that we need for training.
    # It combines basic options and training-specific options.
    class Options():
        def __init__(self):
            # We initialize an ArgumentParser with nice formatting for default values.
            self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            self.initialized = False

        def initialize(self):
           
            self.parser.add_argument('--dataroot', default='../', help='(unused in our simplified version)')
            self.parser.add_argument('--name', type=str, default='gen', help='experiment name (this folder will store checkpoints)')
            self.parser.add_argument('--gpu_ids', type=int, default=0, help='GPU ID to use; -1 for CPU')
            self.parser.add_argument('--model', type=str, default='multistain_cyclegan', help='model name')
            self.parser.add_argument('--direction', type=str, default='AtoB', help='mapping direction: AtoB or BtoA')
            self.parser.add_argument('--batch_size', type=int, default=32, help='size of each mini-batch')
            self.parser.add_argument('--input_nc', type=int, default=3, help='number of channels in input images')
            self.parser.add_argument('--output_nc', type=int, default=3, help='number of channels in output images')
            self.parser.add_argument('--ngf', type=int, default=64, help='number of filters in the last conv layer of G')
            self.parser.add_argument('--ndf', type=int, default=64, help='number of filters in the first conv layer of D')
            self.parser.add_argument('--netG', type=str, default='resnet_9blocks', help='architecture of the generator')
            self.parser.add_argument('--netD', type=str, default='basic', help='architecture of the discriminator')
            self.parser.add_argument('--norm', type=str, default='instance', help='normalization type: instance or batch')
            self.parser.add_argument('--no_dropout', action='store_true', help='disable dropout in the generator (if true)')
            self.parser.add_argument('--init_type', type=str, default='normal', help='weight initialization method')
            self.parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for initialization')
            self.parser.add_argument('--dataset_mode', type=str, default='unaligned', help='type of dataset (must be unaligned)')
            self.parser.add_argument('--color_augment', action='store_true', help='enable color jitter augmentation')
            self.parser.add_argument('--brightness', type=float, default=0.0, help='brightness value for augmentation')
            self.parser.add_argument('--contrast', type=float, default=0.0, help='contrast value for augmentation')
            self.parser.add_argument('--saturation', type=float, default=0.0, help='saturation value for augmentation')
            self.parser.add_argument('--hue', type=float, default=0.0, help='hue value for augmentation')
            self.parser.add_argument('--gan_mode', type=str, default='lsgan', help='type of GAN loss (lsgan, vanilla, or wgangp)')
            self.parser.add_argument('--pool_size', type=int, default=50, help='size of image buffer for generated images')
            self.parser.add_argument('--D_thresh', action='store_true', help='use threshold for updating the discriminator')
            self.parser.add_argument('--D_thresh_value', type=float, default=0.5, help='threshold value for discriminator update')
            self.parser.add_argument('--n_layers_D', type=int, default=3, help='number of layers in the PatchGAN discriminator')
            self.parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate scheduler type')
            self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='number of iterations between lr decays')
            self.parser.add_argument('--display_id', type=int, default=0, help='window ID for visdom display')
            self.parser.add_argument('--display_winsize', type=int, default=256, help='window size for visdom display')
            self.parser.add_argument('--display_port', type=int, default=8097, help='port number for visdom server')
            self.parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server address')
            self.parser.add_argument('--display_env', type=str, default='main', help='environment name for visdom display')
            self.parser.add_argument('--display_ncols', type=int, default=4, help='number of images per row in visdom')
            self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results as HTML')
            self.parser.add_argument('--checkpoints_dir', type=str, default='../checkps', help='directory where models are saved')

            # Training-specific options
            self.parser.add_argument('--epoch_count', type=int, default=1, help='starting epoch number')
            self.parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with constant learning rate')
            self.parser.add_argument('--n_epochs_decay', type=int, default=2, help='number of epochs with decaying learning rate')
            self.parser.add_argument('--lr_G', type=float, default=0.0002, help='learning rate for the generator')
            self.parser.add_argument('--lr_D', type=float, default=0.0002, help='learning rate for the discriminator')
            self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum parameter for Adam optimizer')
            self.parser.add_argument('--netD_opt', type=str, default='adam', help='optimizer for discriminator (adam or sgd)')
            self.parser.add_argument('--print_freq', type=int, default=100, help='frequency (in iterations) to print losses')
            self.parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency (in epochs) for saving checkpoints')
            self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle consistency loss A')
            self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle consistency loss B')
            self.parser.add_argument('--lambda_identity', type=float, default=0.5, help='weight for identity loss')
            self.parser.add_argument('--max_items_A', type=int, default=None, help='maximum number of images to load per source file (domain A: train and val)')
            self.parser.add_argument('--max_items_B', type=int, default=None, help='maximum number of images to load per target file (domain B: test)')
            self.parser.add_argument('--train_path', type=str, default='../data/train.h5', help='path to the train h5 file')
            self.parser.add_argument('--val_path', type=str, default='../data/val.h5', help='path to the validation h5 file')
            self.parser.add_argument('--test_path', type=str, default='../data/test.h5', help='path to the test h5 file')
            self.parser.add_argument('--domain', type=int, default=None, help='ID of the selected source domain')
            self.initialized = True

        def parse(self):
            if not self.initialized:
                self.initialize()
            # We use parse_known_args() to ignore extraneous arguments injected by Jupyter
            opt, _ = self.parser.parse_known_args()
            opt.isTrain = True  # Force training mode
            return opt

    # Instantiate options and print them for debugging
    opt = Options().parse()
    import pprint
    pprint.pprint(vars(opt))

    # Now run the CycleGAN training function with our parsed options.
    trained_cyclegan = cyclegan_training(opt)


[INFO] Options: Namespace(dataroot='../', name='experiment_name', gpu_ids=0, model='multistain_cyclegan', direction='AtoB', batch_size=32, input_nc=3, output_nc=3, ngf=64, ndf=64, netG='resnet_9blocks', netD='basic', norm='instance', no_dropout=False, init_type='normal', init_gain=0.02, dataset_mode='unaligned', color_augment=False, brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0, gan_mode='lsgan', pool_size=50, D_thresh=False, D_thresh_value=0.5, n_layers_D=3, lr_policy='linear', lr_decay_iters=50, display_id=0, display_winsize=256, display_port=8097, display_server='http://localhost', display_env='main', display_ncols=4, no_html=False, checkpoints_dir='../checkps', epoch_count=1, n_epochs=20, n_epochs_decay=2, lr_G=0.0002, lr_D=0.0002, beta1=0.5, netD_opt='adam', print_freq=100, save_epoch_freq=1, lambda_A=10.0, lambda_B=10.0, lambda_identity=0.5, max_items_A=None, max_items_B=None, train_path='../../input/mva-dlmi-2025-histopathology-ood-classification/train.h5', val_path='../



[Epoch 1 | Iter 800] D_A: 0.2908, G_A: 0.4988, cycle_A: 4.3898, idt_A: 1.5857, D_B: 0.5103, G_B: 0.9404, cycle_B: 3.3824, idt_B: 2.4091
[Epoch 1 | Iter 1600] D_A: 0.3303, G_A: 0.2183, cycle_A: 3.0030, idt_A: 1.1791, D_B: 0.2435, G_B: 0.3233, cycle_B: 2.4801, idt_B: 1.5730
[Epoch 1 | Iter 2400] D_A: 0.2052, G_A: 0.4415, cycle_A: 2.5390, idt_A: 1.1865, D_B: 0.2504, G_B: 0.2231, cycle_B: 2.3528, idt_B: 1.2443
[Epoch 1 | Iter 3200] D_A: 0.3074, G_A: 0.2012, cycle_A: 2.8466, idt_A: 1.1350, D_B: 0.2072, G_B: 0.3637, cycle_B: 2.4266, idt_B: 1.2769
[Epoch 1 | Iter 4000] D_A: 0.3023, G_A: 0.9212, cycle_A: 2.2348, idt_A: 1.0720, D_B: 0.3061, G_B: 0.8874, cycle_B: 2.2868, idt_B: 1.1111
[Epoch 1 | Iter 4800] D_A: 0.2422, G_A: 0.2007, cycle_A: 2.4792, idt_A: 0.9183, D_B: 0.3738, G_B: 1.1714, cycle_B: 1.8843, idt_B: 1.1759
[Epoch 1 | Iter 5600] D_A: 0.2040, G_A: 0.2926, cycle_A: 2.4087, idt_A: 0.9410, D_B: 0.1649, G_B: 0.4188, cycle_B: 2.0227, idt_B: 1.2148
[Epoch 1 | Iter 6400] D_A: 0.1870, G_A: 0.