## Cell 1: Imports and Setup

In [8]:
import tensorflow as tf
import numpy as np
import os
from sklearn.utils.class_weight import compute_class_weight
import yaml
import matplotlib.pyplot as plt
import datetime
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Enable mixed precision: register 'mixed_float16' as the global Keras dtype policy.
from tensorflow.keras import mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Set up multi-GPU strategy and report how many replicas it will keep in sync.
# NOTE(review): main() in Cell 6 builds its own policy and strategy again, so
# these module-level objects are presumably only used for this sanity print.
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Number of devices: 4


## Cell 2: Preprocessing Functions

In [19]:
def load_pannuke_fold(fold_path):
    """Load one fold's raw arrays from disk and print diagnostic statistics.

    The fold number is taken from the last whitespace-separated token of the
    fold directory name (e.g. ``'Fold 1'`` -> ``'1'``) and used to build the
    ``fold<N>`` sub-directory names.

    Args:
        fold_path: Path to a fold directory. A trailing path separator is
            tolerated.

    Returns:
        Tuple ``(images, masks, types)`` where ``images`` has been converted to
        float32 and divided by 255, and ``masks`` / ``types`` are returned
        exactly as stored on disk.
    """
    # normpath strips a trailing separator; without it basename() would be ''
    # and the [-1] lookup below would raise IndexError.
    fold_dir_name = os.path.basename(os.path.normpath(fold_path))
    fold_number = fold_dir_name.split()[-1]  # Extract the fold number
    fold_tag = f'fold{fold_number}'

    images = np.load(os.path.join(fold_path, 'images', fold_tag, 'images.npy'))
    masks = np.load(os.path.join(fold_path, 'masks', fold_tag, 'masks.npy'))
    # types.npy is read from the images directory, next to images.npy.
    types = np.load(os.path.join(fold_path, 'images', fold_tag, 'types.npy'))

    print(f"Raw images shape: {images.shape}")
    print(f"Raw masks shape: {masks.shape}")
    print(f"Raw types shape: {types.shape}")

    # Convert to float32 and rescale by 255.
    # assumes the stored images are 8-bit intensities in [0, 255] — TODO confirm
    images = images.astype(np.float32) / 255.0

    print(f"Processed images shape: {images.shape}")
    print(f"Processed masks shape: {masks.shape}")
    print(f"Images dtype: {images.dtype}")
    print(f"Masks dtype: {masks.dtype}")
    print(f"Images min and max: {np.min(images):.4f}, {np.max(images):.4f}")
    print(f"Masks min and max: {np.min(masks):.4f}, {np.max(masks):.4f}")
    print(f"Percentage of non-zero mask pixels: {(masks > 0).mean() * 100:.2f}%")

    return images, masks, types

def create_hv_maps(masks):
    """Build horizontal/vertical offset maps from labelled masks.

    Every distinct non-zero value inside a channel is treated as one object.
    Each object pixel receives its offset from that object's centroid
    (x offset in output channel 0, y offset in output channel 1); pixels whose
    mask value is 0 receive 0. Contributions from all channels are summed.

    The previous implementation used one centroid per channel and multiplied
    the offsets by the raw mask value, so offsets were measured from the
    centre of all objects combined and scaled by the label value.

    Args:
        masks: Array indexed as (image, row, column, channel).
            NOTE(review): every channel is processed, including any channel
            that encodes background rather than objects — confirm that is
            intended for the dataset in use.

    Returns:
        float32 array of shape (num_images, h, w, 2); offsets are divided by
        ``max(h, w)`` and clipped to [-1, 1].
    """
    masks = np.asarray(masks)
    num_images = masks.shape[0]
    h, w = masks.shape[1:3]
    hv_maps = np.zeros((num_images, h, w, 2), dtype=np.float32)

    # Flattened pixel coordinates in row-major order, matching channel.ravel().
    x_flat = np.tile(np.arange(w, dtype=np.float64), h)
    y_flat = np.repeat(np.arange(h, dtype=np.float64), w)

    for i in range(num_images):
        for c in range(masks.shape[-1]):
            channel = masks[i, ..., c]
            if not channel.any():
                continue

            # Map arbitrary label values to compact ids 0..k-1 for bincount.
            labels, compact_ids = np.unique(channel.ravel(), return_inverse=True)
            compact_ids = compact_ids.reshape(-1)

            # Per-label centroids; counts are >= 1 because every label in
            # `labels` occurs at least once.
            counts = np.bincount(compact_ids, minlength=labels.size)
            center_x = np.bincount(compact_ids, weights=x_flat, minlength=labels.size) / counts
            center_y = np.bincount(compact_ids, weights=y_flat, minlength=labels.size) / counts

            # Binary membership: label value 0 is "no object" and contributes nothing.
            is_object = (labels != 0)[compact_ids]

            hv_maps[i, ..., 0] += ((x_flat - center_x[compact_ids]) * is_object).reshape(h, w)
            hv_maps[i, ..., 1] += ((y_flat - center_y[compact_ids]) * is_object).reshape(h, w)

    # Normalize HV maps to [-1, 1]
    hv_maps = np.clip(hv_maps / np.maximum(h, w), -1, 1)
    return hv_maps

def load_fold(fold_path):
    """Load one fold and derive the training targets for every model branch.

    Args:
        fold_path: Path to a fold directory whose name ends in the fold number
            (e.g. ``'.../Fold 1'``). A trailing path separator is tolerated.

    Returns:
        Tuple ``(images, binary_masks, hv_maps, masks, types_one_hot,
        unique_types, types_indices)``:
            images: float32 images divided by 255.
            binary_masks: float32 array, 1 where the channel-sum of ``masks``
                is positive.
            hv_maps: output of ``create_hv_maps(masks)``.
            masks: the mask array exactly as stored on disk.
            types_one_hot: one-hot encoding of ``types_indices``.
            unique_types: sorted unique labels found in THIS fold.
            types_indices: index of each sample's label in ``unique_types``.
    """
    # normpath strips a trailing separator; without it basename() would be ''
    # and the [-1] lookup would raise IndexError.
    fold_number = os.path.basename(os.path.normpath(fold_path)).split()[-1]
    fold_tag = f'fold{fold_number}'

    images = np.load(os.path.join(fold_path, 'images', fold_tag, 'images.npy'))
    masks = np.load(os.path.join(fold_path, 'masks', fold_tag, 'masks.npy'))
    types = np.load(os.path.join(fold_path, 'images', fold_tag, 'types.npy'))

    # Convert images to float32 and normalize to [0, 1]
    images = images.astype(np.float32) / 255.0

    # Create binary masks for NP branch
    # NOTE(review): if one of the mask channels marks background pixels with a
    # positive value, this sum is positive everywhere and the binary mask is
    # all ones — confirm against the dataset's channel layout.
    binary_masks = (masks.sum(axis=-1) > 0).astype(np.float32)

    # Create horizontal and vertical distance maps for HV branch
    hv_maps = create_hv_maps(masks)

    # Convert string labels to integer indices in one vectorised pass.
    # NOTE: the encoding is local to this fold; indices from different folds
    # are only comparable when the folds contain the same set of labels.
    unique_types, types_indices = np.unique(types, return_inverse=True)
    types_indices = types_indices.reshape(-1)

    # Convert types to one-hot encoded format
    num_classes = len(unique_types)
    types_one_hot = tf.keras.utils.to_categorical(types_indices, num_classes=num_classes)

    return images, binary_masks, hv_maps, masks, types_one_hot, unique_types, types_indices

def preprocess_pannuke_data(data_dir, fold, batch_size, augment_fn=None):
    """Load all three folds, compute class weights and build tf.data pipelines.

    Args:
        data_dir: Directory containing the 'Fold 1', 'Fold 2' and 'Fold 3'
            sub-directories.
        fold: Which split layout to use (1, 2 or 3). Every layout is a
            70 % / 15 % / 15 % train / validation / test partition of the
            concatenated folds, taken in storage order.
        batch_size: Batch size applied to every returned dataset.
        augment_fn: Optional callable
            ``(image, np_mask, hv_map, nt_mask, tc_label) -> 5-tuple``.
            It is applied to each TRAINING example with probability 0.5; the
            returned tc label is ignored.

    Returns:
        ``(train_dataset, val_dataset, test_dataset, unique_types,
        class_weight_dicts)`` where ``class_weight_dicts`` maps
        'np_branch' / 'nt_branch' / 'tc_branch' to ``{class_index: weight}``.

    Raises:
        ValueError: if ``fold`` is not 1, 2 or 3.
    """
    # Fail before the expensive loading step when the fold id is invalid.
    if fold not in (1, 2, 3):
        raise ValueError("Invalid fold number. Choose 1, 2, or 3.")

    all_images, all_binary_masks, all_hv_maps, all_masks, all_type_names = [], [], [], [], []
    for fold_name in ['Fold 1', 'Fold 2', 'Fold 3']:
        fold_path = os.path.join(data_dir, fold_name)
        images, binary_masks, hv_maps, masks, _, fold_unique_types, type_indices = load_fold(fold_path)
        all_images.append(images)
        all_binary_masks.append(binary_masks)
        all_hv_maps.append(hv_maps)
        all_masks.append(masks)
        # Recover the original labels so they can be encoded ONCE for all folds;
        # per-fold indices are only comparable if every fold has the same labels.
        all_type_names.append(np.asarray(fold_unique_types)[np.asarray(type_indices)])

    all_images = np.concatenate(all_images)
    all_binary_masks = np.concatenate(all_binary_masks)
    all_hv_maps = np.concatenate(all_hv_maps)
    all_masks = np.concatenate(all_masks)

    # Global label encoding shared by all folds (float32 one-hot rows).
    unique_types, all_type_indices = np.unique(np.concatenate(all_type_names), return_inverse=True)
    all_type_indices = all_type_indices.reshape(-1)
    all_types = np.eye(len(unique_types), dtype=np.float32)[all_type_indices]

    print(f"All images shape: {all_images.shape}")
    print(f"All binary masks shape: {all_binary_masks.shape}")
    print(f"All HV maps shape: {all_hv_maps.shape}")
    print(f"All masks shape: {all_masks.shape}")
    print(f"All types shape: {all_types.shape}")

    # Compute class weights for each branch
    class_weights = {
        'np_branch': compute_class_weight('balanced', classes=np.unique(all_binary_masks), y=all_binary_masks.flatten()),
        'nt_branch': compute_class_weight('balanced', classes=np.arange(all_masks.shape[-1]), y=np.argmax(all_masks, axis=-1).flatten()),
        'tc_branch': compute_class_weight('balanced', classes=np.arange(all_types.shape[-1]), y=all_type_indices)
    }

    # Convert class weights to dictionaries
    class_weight_dicts = {branch: dict(enumerate(weights)) for branch, weights in class_weights.items()}

    # Split data. The three ranges of every layout are disjoint; the previous
    # fold-3 layout drew validation from inside the training range.
    total_samples = len(all_images)
    cut_15 = int(0.15 * total_samples)
    cut_30 = int(0.3 * total_samples)
    cut_70 = int(0.7 * total_samples)
    cut_85 = int(0.85 * total_samples)
    if fold == 1:
        train_range, val_range, test_range = (0, cut_70), (cut_70, cut_85), (cut_85, total_samples)
    elif fold == 2:
        train_range, val_range, test_range = (cut_15, cut_85), (0, cut_15), (cut_85, total_samples)
    else:
        train_range, val_range, test_range = (cut_30, total_samples), (cut_15, cut_30), (0, cut_15)

    def take(bounds):
        # Slice every array with the same [start, stop) range.
        window = slice(*bounds)
        return (all_images[window], all_binary_masks[window], all_hv_maps[window],
                all_masks[window], all_types[window])

    train, val, test = take(train_range), take(val_range), take(test_range)

    # Create TensorFlow datasets
    print("Data loaded. Creating datasets...")

    def create_dataset(images, binary_masks, hv_maps, masks, types, training=False):
        def prepare_data(image, labels):
            # Bind every value before the conditional and build a fresh dict
            # afterwards instead of mutating `labels` inside the branch.
            np_mask = labels['np_branch']
            hv_map = labels['hv_branch']
            nt_mask = labels['nt_branch']
            tc_label = labels['tc_branch']
            _unused_tc = tc_label
            if tf.random.uniform(()) > 0.5:
                image, np_mask, hv_map, nt_mask, _unused_tc = augment_fn(
                    image, np_mask, hv_map, nt_mask, tc_label
                )
            return image, {
                'np_branch': np_mask,
                'hv_branch': hv_map,
                'nt_branch': nt_mask,
                'tc_branch': tc_label
            }

        dataset = tf.data.Dataset.from_tensor_slices((
            images,
            {
                'np_branch': binary_masks[..., np.newaxis],
                'hv_branch': hv_maps,
                'nt_branch': masks,
                'tc_branch': types
            }
        ))

        # Cache BEFORE any random transformation so each epoch draws fresh
        # augmentations instead of replaying the ones cached in epoch 1.
        dataset = dataset.cache()

        if training:
            # Only the training split is shuffled and augmented; validation and
            # test keep a deterministic order and unmodified samples.
            dataset = dataset.shuffle(buffer_size=1000)
            if augment_fn is not None:
                dataset = dataset.map(prepare_data, num_parallel_calls=tf.data.AUTOTUNE)

        return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    train_dataset = create_dataset(*train, training=True)
    val_dataset = create_dataset(*val)
    test_dataset = create_dataset(*test)

    print("Datasets created and optimized.")

    # Print shapes for debugging
    for images, labels in train_dataset.take(1):
        print("Sample data:")
        print(f"Images shape: {images.shape}, dtype: {images.dtype}")
        print(f"NP branch shape: {labels['np_branch'].shape}, dtype: {labels['np_branch'].dtype}")
        print(f"HV branch shape: {labels['hv_branch'].shape}, dtype: {labels['hv_branch'].dtype}")
        print(f"NT branch shape: {labels['nt_branch'].shape}, dtype: {labels['nt_branch'].dtype}")
        print(f"TC branch shape: {labels['tc_branch'].shape}, dtype: {labels['tc_branch'].dtype}")

    print("Data preprocessing complete.")
    # The previous return statement referenced the undefined name
    # `class_weight_dict` and raised NameError after all the work was done.
    return train_dataset, val_dataset, test_dataset, unique_types, class_weight_dicts


# Status messages confirming that this cell's definitions were (re)executed.
print("Updated preprocess_pannuke_data function")

print("Preprocessing functions defined")

Updated preprocess_pannuke_data function
Preprocessing functions defined


## Cell 3: Model Creation (ViT and Expert HE)

In [20]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def create_patches(images, patch_size):
    """Cut a batch of images into non-overlapping square patches.

    Patches are extracted with window and stride both equal to ``patch_size``
    and "VALID" padding, then every patch is flattened.

    Args:
        images: Batched image tensor accepted by ``tf.image.extract_patches``.
        patch_size: Side length of each square patch, in pixels.

    Returns:
        Tensor of shape (batch, patches_per_image, flattened_patch_size).
    """
    window = [1, patch_size, patch_size, 1]
    extracted = tf.image.extract_patches(
        images=images,
        sizes=window,
        strides=window,
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    # The flattened patch length is taken from the static shape, the batch
    # size from the dynamic shape.
    flattened_size = extracted.shape[-1]
    num_images = tf.shape(images)[0]
    return tf.reshape(extracted, [num_images, -1, flattened_size])

class PatchEncoder(layers.Layer):
    """Linear patch projection plus a learned position embedding.

    Args:
        num_patches: Number of patches per image; also the size of the
            position-embedding table.
        projection_dim: Width of the projected patch vectors.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept so that get_config() can rebuild the layer.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Project ``patch`` and add the embedding of positions 0..num_patches-1."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim,
        })
        return config

def create_vit_model(
    input_shape,
    patch_size,
    num_patches,
    projection_dim,
    num_transformer_layers,
    num_heads,
    mlp_head_units,
    dropout_rate,
    num_classes,
):
    """Assemble a Vision-Transformer classifier as a Keras functional model.

    Args:
        input_shape: Shape of one input image (without the batch axis).
        patch_size: Side length of the square patches.
        num_patches: Number of patches per image, passed to ``PatchEncoder``.
        projection_dim: Token width used throughout the transformer blocks.
        num_transformer_layers: Number of transformer blocks.
        num_heads: Attention heads per block (each with ``key_dim`` equal to
            ``projection_dim``).
        mlp_head_units: Sequence with at least two entries. Entry 0 is the
            hidden width of every block's MLP and of the first head layer;
            entry 1 is the width of the second head layer.
        dropout_rate: Dropout rate used in attention, MLPs and the head.
        num_classes: Size of the final softmax layer.

    Returns:
        ``keras.Model`` mapping an image batch to class probabilities.
    """
    model_input = layers.Input(shape=input_shape)
    raw_patches = layers.Lambda(lambda batch: create_patches(batch, patch_size))(model_input)
    tokens = PatchEncoder(num_patches, projection_dim)(raw_patches)

    for _block in range(num_transformer_layers):
        # Pre-norm self-attention with a residual connection.
        normed_tokens = layers.LayerNormalization(epsilon=1e-6)(tokens)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate
        )(normed_tokens, normed_tokens)
        after_attention = layers.Add()([attended, tokens])

        # Pre-norm two-layer MLP with a residual connection.
        hidden = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        hidden = layers.Dense(units=mlp_head_units[0], activation="gelu")(hidden)
        hidden = layers.Dropout(dropout_rate)(hidden)
        hidden = layers.Dense(units=projection_dim)(hidden)
        tokens = layers.Add()([hidden, after_attention])

    # Flatten all tokens into a single vector per image.
    pooled = layers.LayerNormalization(epsilon=1e-6)(tokens)
    pooled = layers.Flatten()(pooled)
    pooled = layers.Dropout(dropout_rate)(pooled)

    # Two-layer classification head.
    head = pooled
    for width in (mlp_head_units[0], mlp_head_units[1]):
        head = layers.Dense(width, activation="gelu")(head)
        head = layers.Dropout(dropout_rate)(head)

    class_probabilities = layers.Dense(num_classes, activation="softmax")(head)

    return keras.Model(inputs=model_input, outputs=class_probabilities)

def create_decoder_branch(inputs, num_filters, num_outputs, name):
    """Build a small residual convolutional head ending in a sigmoid 1x1 conv.

    Args:
        inputs: Feature-map tensor (channels-last).
        num_filters: Number of filters in each of the three 3x3 conv layers.
        num_outputs: Number of channels produced by the final 1x1 conv.
        name: Name given to the output layer.

    Returns:
        Output tensor of the final sigmoid-activated 1x1 convolution.
    """
    x = inputs
    for _ in range(3):  # Increase the number of layers
        x = layers.Conv2D(num_filters, 3, padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)

    # Add residual connection. Element-wise addition needs matching channel
    # counts, so project the shortcut with a 1x1 conv only when they differ;
    # when they already match the graph is the same as before.
    shortcut = inputs
    if inputs.shape[-1] != num_filters:
        shortcut = layers.Conv2D(num_filters, 1, padding='same')(inputs)
    x = layers.Add()([x, shortcut])

    outputs = layers.Conv2D(num_outputs, 1, activation='sigmoid', name=name)(x)
    return outputs

def create_he_expert(input_shape, num_classes, num_nuclei_types=6, patch_size=16):
    """Build the multi-branch H&E expert model and its encoder-only view.

    The ViT classifier output is expanded by a Dense layer to one value per
    pixel, reshaped to an image-sized feature map, passed through three 3x3
    convolutions and then fed to four heads:
    'np_branch' (1-channel sigmoid), 'hv_branch' (2-channel tanh),
    'nt_branch' (``num_nuclei_types``-channel softmax) and 'tc_branch'
    (global-average-pooled softmax over ``num_classes``).

    Args:
        input_shape: (height, width, channels) of one input image.
        num_classes: Number of tissue classes; used for both the ViT output
            and the 'tc_branch' head.
        num_nuclei_types: Channels of the 'nt_branch' head (default 6, the
            value that was previously hard-coded).
        patch_size: ViT patch side length (default 16, previously hard-coded).

    Returns:
        ``(full_model, encoder_model)``; ``encoder_model`` shares the input and
        ends at the ViT output.
    """
    input_shape = tuple(input_shape)
    height, width = input_shape[0], input_shape[1]

    vit_encoder = create_vit_model(
        input_shape=input_shape,
        patch_size=patch_size,
        # Count patches along each axis separately so non-square inputs work.
        num_patches=(height // patch_size) * (width // patch_size),
        projection_dim=64,
        num_transformer_layers=8,
        num_heads=4,
        mlp_head_units=[2048, 1024],
        dropout_rate=0.1,
        num_classes=num_classes,
    )

    inputs = layers.Input(shape=input_shape)
    encoder_output = vit_encoder(inputs)

    # Reshape encoder output to 2D
    x = layers.Dense(height * width, activation="relu")(encoder_output)
    x = layers.Reshape((height, width, -1))(x)

    # Decoder (adjust to maintain input dimensions)
    x = layers.Conv2D(256, 3, padding='same', activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
    x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)

    # Output heads are pinned to float32: this notebook enables the
    # 'mixed_float16' policy, and the TensorFlow mixed-precision guide
    # recommends float32 model outputs for numerically stable losses.
    # Three decoder branches
    np_branch = layers.Conv2D(1, 1, activation='sigmoid', dtype='float32', name="np_branch")(x)
    hv_branch = layers.Conv2D(2, 1, activation='tanh', dtype='float32', name="hv_branch")(x)
    nt_branch = layers.Conv2D(num_nuclei_types, 1, activation='softmax', dtype='float32', name="nt_branch")(x)

    # Tissue classification branch
    tc_branch = layers.GlobalAveragePooling2D()(x)
    tc_branch = layers.Dense(num_classes, activation='softmax', dtype='float32', name="tc_branch")(tc_branch)

    full_model = tf.keras.Model(inputs=inputs, outputs=[np_branch, hv_branch, nt_branch, tc_branch])
    encoder_model = tf.keras.Model(inputs=inputs, outputs=encoder_output)

    print("Model output shapes:")
    print(f"NP branch: {np_branch.shape}")
    print(f"HV branch: {hv_branch.shape}")
    print(f"NT branch: {nt_branch.shape}")
    print(f"TC branch: {tc_branch.shape}")

    return full_model, encoder_model

# Status message confirming that this cell's definitions were executed.
print("ViT and Expert HE model creation functions defined")

ViT and Expert HE model creation functions defined


## Cell 4: Loss Functions and Callbacks

In [21]:
import tensorflow as tf
import numpy as np

def weighted_bce(class_weights):
    """Create a class-weighted binary cross-entropy loss.

    Args:
        class_weights: Two weights ordered ``[weight_for_class_0,
            weight_for_class_1]``.

    Returns:
        ``loss(y_true, y_pred)`` returning the mean weighted cross-entropy
        over all elements.

    Raises:
        ValueError: if the number of weights is known and is not 2.
    """
    class_weights_tensor = tf.reshape(
        tf.cast(tf.convert_to_tensor(class_weights), tf.float32), [-1]
    )
    num_weights = class_weights_tensor.shape[0]
    if num_weights is not None and num_weights != 2:
        raise ValueError(
            f"weighted_bce expects exactly 2 class weights (negative, positive), got {num_weights}"
        )
    negative_weight = class_weights_tensor[0]
    positive_weight = class_weights_tensor[1]

    def loss(y_true, y_pred):
        # Cast inputs to float32
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
        per_element = -(
            y_true * tf.math.log(y_pred) + (1.0 - y_true) * tf.math.log(1.0 - y_pred)
        )
        # Select the weight by label. The previous expression
        # reduce_sum(class_weights * y_true) evaluated to (w0 + w1) * y_true,
        # giving every negative (y_true == 0) element a weight of zero.
        weights = y_true * positive_weight + (1.0 - y_true) * negative_weight
        return tf.reduce_mean(per_element * weights)
    return loss

def weighted_focal_loss(alpha, gamma):
    """Create a focal loss for targets laid out along the last axis.

    Args:
        alpha: Scaling factor applied to every term.
        gamma: Exponent of the ``(1 - p)`` modulating factor.

    Returns:
        ``loss(y_true, y_pred)``: terms are summed over the last axis and the
        result is averaged over the remaining axes.
    """
    def loss(y_true, y_pred):
        # Work in float32 regardless of the incoming dtypes.
        targets = tf.cast(y_true, tf.float32)
        probabilities = tf.clip_by_value(tf.cast(y_pred, tf.float32), 1e-7, 1 - 1e-7)

        modulating_factor = tf.math.pow(1 - probabilities, gamma)
        per_class = -alpha * targets * modulating_factor * tf.math.log(probabilities)
        per_position = tf.reduce_sum(per_class, axis=-1)
        return tf.reduce_mean(per_position)
    return loss

class ShapePrintingCallback(tf.keras.callbacks.Callback):
    """Debug callback that prints label and prediction shapes for each branch.

    Whenever a batch with index 0 begins, one batch is drawn from
    ``train_data``, the model is run on it in inference mode, and the shapes
    of the targets and predictions are printed side by side.
    """

    def __init__(self, train_data):
        super().__init__()
        self.train_data = train_data

    def on_batch_begin(self, batch, logs=None):
        if batch != 0:
            return

        print("\nChecking shapes on first batch:")
        features, targets = next(iter(self.train_data))
        predictions = self.model(features, training=False)
        # Predictions are indexed by position, targets by branch name.
        branch_names = ['np_branch', 'hv_branch', 'nt_branch', 'tc_branch']
        for position, output_name in enumerate(branch_names):
            print(f"{output_name} - True: {targets[output_name].shape}, Pred: {predictions[position].shape}")

class WarmUpLearningRateScheduler(tf.keras.callbacks.Callback):
    """Linearly ramp the optimizer's learning rate during the first batches.

    While ``batch_count <= warmup_batches`` the learning rate is set to
    ``batch_count * init_lr / warmup_batches`` at the start of each batch.
    The learning rate observed at the end of every batch is appended to
    ``learning_rates``.

    If the optimizer was built with a ``LearningRateSchedule`` its learning
    rate cannot be assigned; in that case the callback leaves the schedule in
    control, prints a one-time notice, and only records the schedule's value.

    Args:
        warmup_batches: Number of batches over which to ramp up. Values <= 0
            disable the ramp.
        init_lr: Learning rate reached at the end of the ramp.
        verbose: If > 0, print the learning rate each time it is set.
    """

    def __init__(self, warmup_batches, init_lr, verbose=0):
        super(WarmUpLearningRateScheduler, self).__init__()
        self.warmup_batches = warmup_batches
        self.init_lr = init_lr
        self.verbose = verbose
        self.batch_count = 0
        self.learning_rates = []
        self._schedule_notice_printed = False

    def _uses_schedule(self):
        """Return True when the optimizer's learning rate is a schedule object."""
        return isinstance(
            self.model.optimizer.lr,
            tf.keras.optimizers.schedules.LearningRateSchedule,
        )

    def on_batch_end(self, batch, logs=None):
        self.batch_count = self.batch_count + 1
        lr = self.model.optimizer.lr
        if self._uses_schedule():
            # Evaluate the schedule at the current step to get a number.
            lr = lr(self.model.optimizer.iterations)
        self.learning_rates.append(tf.keras.backend.get_value(lr))

    def on_batch_begin(self, batch, logs=None):
        # warmup_batches <= 0 would divide by zero below; treat it as "no warm-up".
        if self.warmup_batches <= 0 or self.batch_count > self.warmup_batches:
            return
        if self._uses_schedule():
            if not self._schedule_notice_printed:
                print('\nWarmUpLearningRateScheduler: optimizer uses a LearningRateSchedule; '
                      'warm-up is skipped.')
                self._schedule_notice_printed = True
            return
        lr = self.batch_count * self.init_lr / self.warmup_batches
        tf.keras.backend.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nBatch %05d: WarmUpLearningRateScheduler setting learning '
                  'rate to %s.' % (self.batch_count + 1, lr))

class GradientNormLogger(tf.keras.callbacks.Callback):
    """Record the global gradient norm and learning rate of every train step.

    On ``on_train_begin`` the model's ``train_step`` is replaced by a version
    that computes gradients explicitly, logs their global norm (and the
    current learning rate when it can be determined) and then applies them.
    The original ``train_step`` is restored in ``on_train_end``. Per-epoch
    averages are printed in ``on_epoch_end`` and the buffers are cleared.

    NOTE(review): the replacement step bypasses whatever the model's own
    ``train_step`` does beyond loss, gradients and metrics (for example sample
    weights) — confirm that is acceptable for the models it is used with.
    """

    def __init__(self):
        super().__init__()
        self.gradient_norms = []
        self.learning_rates = []

    def on_train_begin(self, logs=None):
        self.original_train_step = self.model.train_step

        def log_gradient_norm(norm):
            self.gradient_norms.append(float(norm.numpy()))
            return norm

        def log_learning_rate(lr):
            self.learning_rates.append(float(lr.numpy()))
            return lr

        @tf.function
        def train_step_with_gradient_logging(data):
            x, y = data
            optimizer = self.model.optimizer
            # Under a mixed-precision policy Keras wraps the optimizer in a
            # LossScaleOptimizer; a hand-written step must scale the loss and
            # unscale the gradients itself.
            uses_loss_scaling = isinstance(
                optimizer, tf.keras.mixed_precision.LossScaleOptimizer
            )

            with tf.GradientTape() as tape:
                y_pred = self.model(x, training=True)
                loss = self.model.compiled_loss(y, y_pred, regularization_losses=self.model.losses)
                scaled_loss = optimizer.get_scaled_loss(loss) if uses_loss_scaling else loss

            # Compute gradients
            gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
            if uses_loss_scaling:
                gradients = optimizer.get_unscaled_gradients(gradients)

            # Compute gradient norm
            global_norm = tf.linalg.global_norm(gradients)

            # Apply gradients
            optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

            # Log gradient norm
            tf.py_function(log_gradient_norm, [global_norm], Tout=tf.float32)
            tf.summary.scalar('gradient_norm', global_norm, step=optimizer.iterations)

            # Resolve the learning rate; it stays None when the optimizer
            # exposes neither attribute (previously this left `current_lr`
            # unbound and the summary call below raised).
            current_lr = None
            if hasattr(optimizer, 'lr'):
                current_lr = optimizer.lr
                if callable(current_lr):
                    current_lr = current_lr(optimizer.iterations)
            elif hasattr(optimizer, '_decayed_lr'):
                current_lr = optimizer._decayed_lr(tf.float32)

            if current_lr is not None:
                current_lr = tf.cast(current_lr, tf.float32)
                tf.py_function(log_learning_rate, [current_lr], Tout=tf.float32)
                # Log to TensorBoard
                tf.summary.scalar('learning_rate', current_lr, step=optimizer.iterations)

            # Update metrics
            self.model.compiled_metrics.update_state(y, y_pred)
            return {m.name: m.result() for m in self.model.metrics}

        self.model.train_step = train_step_with_gradient_logging

    def on_train_end(self, logs=None):
        self.model.train_step = self.original_train_step

    def on_epoch_end(self, epoch, logs=None):
        if self.gradient_norms:
            avg_gradient_norm = sum(self.gradient_norms) / len(self.gradient_norms)
            print(f"\nAverage Gradient Norm for Epoch {epoch + 1}: {avg_gradient_norm:.4f}")
        if self.learning_rates:
            avg_learning_rate = sum(self.learning_rates) / len(self.learning_rates)
            print(f"Average Learning Rate for Epoch {epoch + 1}: {avg_learning_rate:.6f}")
        self.gradient_norms = []
        self.learning_rates = []

# Status message confirming that this cell's definitions were executed.
print("Loss functions and callbacks defined")

Loss functions and callbacks defined


## Cell 5: Metrics

In [22]:
def calculate_metrics(y_true, y_pred):
    """Compute element-wise binary accuracy, precision, recall and F1.

    Both inputs are flattened and thresholded at 0.5, so probabilities,
    one-hot arrays and positive-valued label maps are all reduced to 0/1
    before scoring.

    Args:
        y_true: Array-like of targets.
        y_pred: Array-like of predictions with the same number of elements.

    Returns:
        Dict with keys 'accuracy', 'precision', 'recall' and 'f1_score'.
    """
    # Binarise the targets as well: with average='binary' scikit-learn rejects
    # targets that contain more than two distinct values.
    y_true_binary = (np.asarray(y_true).ravel() > 0.5).astype(int)
    # Convert probabilities to binary predictions
    y_pred_binary = (np.asarray(y_pred).ravel() > 0.5).astype(int)

    accuracy = accuracy_score(y_true_binary, y_pred_binary)
    # zero_division=0 reports 0 instead of warning when a score is undefined
    # (e.g. no positive predictions).
    precision = precision_score(y_true_binary, y_pred_binary, average='binary', zero_division=0)
    recall = recall_score(y_true_binary, y_pred_binary, average='binary', zero_division=0)
    f1 = f1_score(y_true_binary, y_pred_binary, average='binary', zero_division=0)
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }

# Status message confirming that this cell's definition was executed.
print("Metrics function defined")

Metrics function defined


## Cell 6: Main Function

In [25]:
import tensorflow as tf
import numpy as np
import os
import yaml
from tensorflow.keras import mixed_precision
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Load configuration
def load_config(config_path):
    """Read a YAML file and return its parsed content.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's content.
    """
    with open(config_path, 'r') as handle:
        parsed = yaml.safe_load(handle)
    return parsed

# Data augmentation function
def augment_data(image, np_mask, hv_map, nt_mask, tc_label):
    """Randomly flip one sample and jitter the image's brightness and contrast.

    Each flip is applied with probability 0.5 to the image and to all three
    spatial targets. After a left-right flip the first ``hv_map`` channel is
    negated; after an up-down flip the second channel is negated. Brightness
    and contrast jitter touch the image only, which is then clipped to [0, 1].

    Args:
        image: Image tensor.
        np_mask: Target for the 'np_branch' output.
        hv_map: Two-channel target for the 'hv_branch' output.
        nt_mask: Target for the 'nt_branch' output.
        tc_label: Target for the 'tc_branch' output; returned unchanged.

    Returns:
        ``(image, np_mask, hv_map, nt_mask, tc_label)`` after augmentation.
    """
    # Random flip left-right
    mirror_left_right = tf.random.uniform(()) > 0.5
    if mirror_left_right:
        hv_map = tf.image.flip_left_right(hv_map)
        # Horizontal offsets change sign when the x axis is mirrored.
        hv_map = tf.stack([tf.negative(hv_map[..., 0]), hv_map[..., 1]], axis=-1)
        nt_mask = tf.image.flip_left_right(nt_mask)
        np_mask = tf.image.flip_left_right(np_mask)
        image = tf.image.flip_left_right(image)

    # Random flip up-down
    mirror_up_down = tf.random.uniform(()) > 0.5
    if mirror_up_down:
        hv_map = tf.image.flip_up_down(hv_map)
        # Vertical offsets change sign when the y axis is mirrored.
        hv_map = tf.stack([hv_map[..., 0], tf.negative(hv_map[..., 1])], axis=-1)
        nt_mask = tf.image.flip_up_down(nt_mask)
        np_mask = tf.image.flip_up_down(np_mask)
        image = tf.image.flip_up_down(image)

    # Photometric jitter (image only): brightness first, then contrast.
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)

    # Ensure image values are still in [0, 1]
    image = tf.clip_by_value(image, 0, 1)

    return image, np_mask, hv_map, nt_mask, tc_label

# Metrics calculation
def calculate_metrics(y_true, y_pred):
    """Compute element-wise binary accuracy, precision, recall and F1.

    Both inputs are flattened and thresholded at 0.5, so probabilities,
    one-hot arrays and positive-valued label maps are all reduced to 0/1
    before scoring.

    Args:
        y_true: Array-like of targets.
        y_pred: Array-like of predictions with the same number of elements.

    Returns:
        Dict with keys 'accuracy', 'precision', 'recall' and 'f1_score'.
    """
    # Binarise the targets as well: with average='binary' scikit-learn rejects
    # targets that contain more than two distinct values.
    y_true_binary = (np.asarray(y_true).ravel() > 0.5).astype(int)
    y_pred_binary = (np.asarray(y_pred).ravel() > 0.5).astype(int)

    accuracy = accuracy_score(y_true_binary, y_pred_binary)
    # zero_division=0 reports 0 instead of warning when a score is undefined
    # (e.g. no positive predictions).
    precision = precision_score(y_true_binary, y_pred_binary, average='binary', zero_division=0)
    recall = recall_score(y_true_binary, y_pred_binary, average='binary', zero_division=0)
    f1 = f1_score(y_true_binary, y_pred_binary, average='binary', zero_division=0)
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }

def _plot_training_history(history):
    """Fallback plot: one panel per logged metric, training vs. validation curve."""
    curves = history.history
    metric_names = [name for name in curves if not name.startswith('val_')]
    if not metric_names:
        return
    fig, axes = plt.subplots(len(metric_names), 1, figsize=(8, 3 * len(metric_names)), squeeze=False)
    for ax, name in zip(axes[:, 0], metric_names):
        ax.plot(curves[name], label=f'train {name}')
        val_name = f'val_{name}'
        if val_name in curves:
            ax.plot(curves[val_name], label=val_name)
        ax.set_xlabel('epoch')
        ax.set_ylabel(name)
        ax.legend()
    fig.tight_layout()
    plt.show()


def main(dry_run=True):
    """Run the full pipeline: load configs, build data and model, train, evaluate, save.

    Configuration is read from ``../configs/{data,model,training}_config.yaml``
    relative to the current working directory. Keys used:
    ``data_config['he_data_dir']``, ``data_config['fold']``,
    ``model_config['input_shape']``, ``model_config['learning_rate']``,
    ``training_config['early_stopping_patience']``,
    ``training_config['model_checkpoint_path']``, ``training_config['epochs']``,
    ``training_config['final_model_path']`` and
    ``training_config['encoder_model_path']``.

    Args:
        dry_run: When True, fit on 5 training batches and 2 validation batches
            for 2 epochs with only the shape-printing callback. When False,
            train for ``training_config['epochs']`` with the full callback set.

    Raises:
        Exception: any error raised during training, evaluation or saving is
            printed and re-raised.
    """
    # The notebook has no __file__, so configs are resolved against the
    # working directory (what abspath('__file__') evaluated to).
    notebook_dir = os.getcwd()

    # Construct paths to config files
    config_dir = os.path.join(notebook_dir, '..', 'configs')
    data_config_path = os.path.join(config_dir, 'data_config.yaml')
    model_config_path = os.path.join(config_dir, 'model_config.yaml')
    training_config_path = os.path.join(config_dir, 'training_config.yaml')

    # Load configurations
    data_config = load_config(data_config_path)
    model_config = load_config(model_config_path)
    training_config = load_config(training_config_path)

    # Enable mixed precision
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)

    # Set up multi-GPU strategy
    strategy = tf.distribute.MirroredStrategy()
    print(f"Number of devices: {strategy.num_replicas_in_sync}")

    # With Model.fit under a distribution strategy the dataset is batched with
    # the GLOBAL batch size and each batch is split across replicas. Batching
    # with the per-replica size (as before) shrank the effective batch.
    global_batch_size = 32  # Adjust as needed
    model_config['batch_size'] = global_batch_size
    print(f"Global batch size: {global_batch_size} "
          f"({global_batch_size // strategy.num_replicas_in_sync} per replica)")

    # Preprocess data
    train_dataset, val_dataset, test_dataset, unique_types, class_weight_dict = preprocess_pannuke_data(
        data_config['he_data_dir'],
        data_config['fold'],
        model_config['batch_size'],
        augment_data
    )
    # The returned datasets already end in cache/prefetch stages. Caching them
    # again here, after shuffling and augmentation, would freeze the first
    # epoch's order and augmentations, so no further wrapping is done.

    print(f"Unique tissue types: {unique_types}")

    with strategy.scope():
        # Create model
        model, encoder = create_he_expert(tuple(model_config['input_shape']), len(unique_types))
        model.summary()

        # Define learning rate schedule
        initial_learning_rate = model_config['learning_rate']
        lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate,
            first_decay_steps=1000,
            t_mul=2.0,
            m_mul=0.9,
            alpha=0.1
        )

        # Define optimizer with gradient clipping
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, clipvalue=0.5)

        # Create weighted loss functions. Only the np_branch loss consumes
        # class weights; class_weight_dict also holds 'nt_branch' and
        # 'tc_branch' weights, which are not used by the losses below.
        np_weights = tf.constant(list(class_weight_dict['np_branch'].values()), dtype=tf.float32)

        # Compile model with custom losses and class weights
        losses = {
            'np_branch': weighted_bce(np_weights),
            'hv_branch': tf.keras.losses.MeanSquaredError(),
            'nt_branch': weighted_focal_loss(alpha=0.25, gamma=2.0),
            'tc_branch': tf.keras.losses.CategoricalCrossentropy(from_logits=False)
        }

        loss_weights = {
            'np_branch': 1.0,
            'hv_branch': 0.5,
            'nt_branch': 0.1,
            'tc_branch': 1.0
        }

        model.compile(
            optimizer=optimizer,
            loss=losses,
            loss_weights=loss_weights,
            metrics={
                'np_branch': [tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=0.5, dtype=tf.float32)],
                'hv_branch': [tf.keras.metrics.MeanAbsoluteError(dtype=tf.float32)],
                'nt_branch': [tf.keras.metrics.CategoricalAccuracy(dtype=tf.float32)],
                'tc_branch': [tf.keras.metrics.CategoricalAccuracy(dtype=tf.float32)]
            }
        )

    # Define callbacks
    callbacks = [ShapePrintingCallback(train_dataset)]

    if not dry_run:
        # Add additional callbacks for full training.
        # ReduceLROnPlateau and WarmUpLearningRateScheduler are intentionally
        # not included: both assign the optimizer's learning rate, which is not
        # possible while the optimizer is driven by the CosineDecayRestarts
        # schedule defined above.
        callbacks.extend([
            tf.keras.callbacks.EarlyStopping(patience=training_config['early_stopping_patience']),
            tf.keras.callbacks.ModelCheckpoint(
                filepath=training_config['model_checkpoint_path'],
                save_best_only=True
            ),
            tf.keras.callbacks.TensorBoard(log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")),
            GradientNormLogger()
        ])

    # Train model
    try:
        if dry_run:
            print("Starting dry run...")
            history = model.fit(
                train_dataset.take(5),  # Only take 5 batches
                epochs=2,  # Run for 2 epochs
                validation_data=val_dataset.take(2),  # Only take 2 batches for validation
                callbacks=callbacks
            )
            print("Dry run completed successfully!")
        else:
            print("Starting full training...")
            history = model.fit(
                train_dataset,
                epochs=training_config['epochs'],
                validation_data=val_dataset,
                callbacks=callbacks
            )
            print("Full training completed!")

        # Plot training history. Use a notebook-level plot_training_history if
        # one has been defined; otherwise fall back to the helper above.
        plot_fn = globals().get('plot_training_history', _plot_training_history)
        plot_fn(history)

        # Evaluate model
        print("Evaluating model on test set...")
        test_results = model.evaluate(test_dataset, verbose=1)
        print(f"Test Results: {test_results}")

        # Calculate and print metrics.
        # Labels and predictions are collected from the SAME batch so they stay
        # aligned even if the dataset yields batches in a different order on
        # each pass. Values are binarised immediately (labels > 0,
        # predictions > 0.5) and stored as uint8 to keep memory usage low.
        print("Calculating detailed metrics...")
        branch_names = ['np_branch', 'nt_branch', 'tc_branch']
        true_chunks = {name: [] for name in branch_names}
        pred_chunks = {name: [] for name in branch_names}
        for batch_images, batch_labels in test_dataset:
            # The model returns a list of outputs; map it to branch names.
            batch_outputs = dict(zip(model.output_names, model(batch_images, training=False)))
            for name in branch_names:
                true_chunks[name].append((batch_labels[name].numpy() > 0).astype(np.uint8).ravel())
                pred_chunks[name].append((batch_outputs[name].numpy() > 0.5).astype(np.uint8).ravel())

        for branch in branch_names:
            if not true_chunks[branch]:
                print(f"No test samples available for {branch}; skipping metrics.")
                continue
            metrics = calculate_metrics(
                np.concatenate(true_chunks[branch]),
                np.concatenate(pred_chunks[branch])
            )
            print(f"Metrics for {branch}:")
            for metric, value in metrics.items():
                print(f"{metric}: {value}")

        # Save the model
        print("Saving model...")
        model.save(training_config['final_model_path'])

        # Save the encoder separately
        encoder.save(training_config['encoder_model_path'])

        print("Model and encoder saved successfully.")

    except Exception as e:
        print(f"An error occurred during training: {str(e)}")
        raise

    print("Pipeline completed successfully.")


## Cell 7: Run the Main Function

In [None]:
# Run the pipeline. With dry_run=True, main() fits on only a handful of
# batches for 2 epochs as a smoke test of the whole path.
main(dry_run=True)  # Set to False for full training

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Number of devices: 4
All images shape: (7901, 256, 256, 3)
All binary masks shape: (7901, 256, 256)
All HV maps shape: (7901, 256, 256, 2)
All masks shape: (7901, 256, 256, 6)
All types shape: (7901, 19)
