In [11]:
import tensorflow as tf
import numpy as np

import os
import sys
# Add the current directory and its parent to the Python path
current_dir = os.path.dirname(os.path.abspath("__file__"))
parent_dir = os.path.dirname(current_dir)
sys.path.extend([current_dir, parent_dir])

from src.models.expert_he import create_he_expert
from src.models.vit import create_vit_model  # Make sure this import works
# Branch names in the order assumed when the model returns a sequence of
# outputs (same order as `expected_shapes` below).
BRANCH_NAMES = ('np_branch', 'hv_branch', 'nt_branch', 'tc_branch')

# Create dummy data
batch_size = 16
input_shape = (256, 256, 3)
num_classes = {'tc_branch': 19, 'nt_branch': 6}

dummy_input = np.random.rand(batch_size, *input_shape).astype(np.float32)

# Create model
model, encoder = create_he_expert(input_shape, num_classes)

# Compile model (with dummy loss and optimizer)
model.compile(optimizer='adam', loss='mse')

# Print model summary
model.summary()

# Try a forward pass. `outputs` is reset to None first so that a failed call
# skips the shape check instead of raising NameError (or silently checking a
# stale `outputs` left in the kernel by an earlier run).
outputs = None
outputs_by_branch = {}
try:
    outputs = model(dummy_input)
    # The model may return a dict keyed by branch name or a sequence; normalise
    # to a dict so both forms are printed and checked the same way.
    if isinstance(outputs, dict):
        outputs_by_branch = dict(outputs)
    else:
        outputs_by_branch = dict(zip(BRANCH_NAMES, outputs))
    print("\nForward pass successful!")
    print("Output shapes:")
    for name in BRANCH_NAMES:
        branch_output = outputs_by_branch.get(name)
        print(f"{name}: {None if branch_output is None else branch_output.shape}")
except Exception as e:
    outputs = None
    print(f"Error during forward pass: {str(e)}")

# Check if shapes match expected output
expected_shapes = [
    (batch_size, 256, 256, 1),  # np_branch
    (batch_size, 256, 256, 2),  # hv_branch
    (batch_size, 256, 256, num_classes['nt_branch']),  # nt_branch
    (batch_size, num_classes['tc_branch'])  # tc_branch
]

all_shapes_correct = outputs is not None
if outputs is None:
    print("\nSkipping shape check: the forward pass did not succeed.")
else:
    for i, (name, expected_shape) in enumerate(zip(BRANCH_NAMES, expected_shapes)):
        branch_output = outputs_by_branch.get(name)
        # A missing branch reports None instead of being silently skipped by zip.
        actual_shape = None if branch_output is None else tuple(branch_output.shape)
        if actual_shape != expected_shape:
            print(f"Shape mismatch in branch {i}: Expected {expected_shape}, got {actual_shape}")
            all_shapes_correct = False

    if all_shapes_correct:
        print("\nAll output shapes are correct!")
    else:
        print("\nSome output shapes are incorrect. Please check the model architecture.")

Model output shapes:
NP branch: (None, 256, 256, 1)
HV branch: (None, 256, 256, 2)
NT branch: (None, 256, 256, 6)
TC branch: (None, 19)
Model: "model_22"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_16 (InputLayer)       [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 model_21 (Functional)       (None, 256, 64)              735488    ['input_16[0][0]']            
                                                                                                  
 dense_150 (Dense)           (None, 256, 65536)           4259840   ['model_21[0][0]']            
                                                                                                  
 reshape_15 (Reshape)        (None, 256, 256, 256)    

In [25]:
import tensorflow as tf
import numpy as np
import os
import sys
from tensorflow.keras import mixed_precision

# Add the current directory and its parent to the Python path
current_dir = os.path.dirname(os.path.abspath("__file__"))
parent_dir = os.path.dirname(current_dir)
sys.path.extend([current_dir, parent_dir])

# Import your model creation functions
from src.models.expert_he import create_he_expert
from src.models.vit import create_vit_model

# Define loss functions here
def weighted_bce(class_weights):
    """Build a class-weighted binary cross-entropy loss.

    Args:
        class_weights: 1-D list/tensor of per-class weights. Two target
            layouts are supported:
              * a single-channel binary mask (last dim == 1) together with two
                weights, read as ``[weight for 0-pixels, weight for 1-pixels]``;
              * any other layout, where the weights are multiplied with
                ``y_true`` and summed over the last axis (e.g. one-hot targets
                whose last dim equals ``len(class_weights)``).

    Returns:
        ``loss(y_true, y_pred)`` returning a scalar: the mean over all elements
        of the BCE multiplied by each element's weight.
    """
    def loss(y_true, y_pred):
        # Compute in float32 even when the model runs under mixed_float16.
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
        # binary_crossentropy averages over the last axis.
        bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        weight_vec = tf.cast(class_weights, tf.float32)
        if y_true.shape[-1] == 1 and weight_vec.shape[-1] == 2:
            # Binary mask: broadcasting [w0, w1] against a 1-channel mask and
            # summing gives w0*y + w1*y, i.e. weight 0 for every background
            # pixel. Weight each pixel by its own class instead.
            mask = tf.squeeze(y_true, axis=-1)
            weights = weight_vec[0] * (1.0 - mask) + weight_vec[1] * mask
        else:
            weights = tf.reduce_sum(weight_vec * y_true, axis=-1)
        return tf.reduce_mean(bce * weights)
    return loss

def weighted_focal_loss(alpha, gamma):
    """Build a categorical focal loss with a single global ``alpha``.

    The per-element term is ``-alpha * y * (1 - p)**gamma * log(p)`` with ``p``
    clipped to [1e-7, 1 - 1e-7]. It is summed over the last (class) axis and
    then averaged over every remaining position.

    Args:
        alpha: scalar multiplier applied to every class.
        gamma: focusing exponent; larger values down-weight easy examples.

    Returns:
        ``loss(y_true, y_pred)`` returning a scalar tensor.
    """
    eps = 1e-7

    def loss(y_true, y_pred):
        target = tf.cast(y_true, tf.float32)
        prob = tf.clip_by_value(tf.cast(y_pred, tf.float32), eps, 1 - eps)
        modulating = tf.math.pow(1 - prob, gamma)
        per_class = -alpha * target * modulating * tf.math.log(prob)
        per_position = tf.reduce_sum(per_class, axis=-1)
        return tf.reduce_mean(per_position)

    return loss

def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss computed once over the whole batch (all axes summed).

    Args:
        y_true: target mask (cast to float32).
        y_pred: predicted probabilities (cast to float32).
        smooth: small constant added to numerator and denominator so that an
            empty target *and* empty prediction give loss 0 instead of NaN
            (0 / 0 in the unsmoothed formula).

    Returns:
        Scalar ``1 - (2*|y_true*y_pred| + smooth) / (|y_true| + |y_pred| + smooth)``.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    total = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)

# --- Runtime configuration ---------------------------------------------------
# mixed_float16: layers compute in float16 while variables stay float32.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Data-parallel replication; with no arguments MirroredStrategy uses every
# visible GPU.
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# --- Dummy inputs --------------------------------------------------------------
batch_size, input_shape = 32, (256, 256, 3)
num_classes = dict(tc_branch=19, nt_branch=6)
dummy_input = np.random.rand(batch_size, *input_shape).astype(np.float32)

# --- Dummy targets (random draws kept in the order np, hv, nt, tc) -------------
dummy_np = np.random.randint(0, 2, (batch_size, 256, 256, 1)).astype(np.float32)
dummy_hv = np.random.rand(batch_size, 256, 256, 2).astype(np.float32)
dummy_nt = tf.keras.utils.to_categorical(
    np.random.randint(0, num_classes['nt_branch'], (batch_size, 256, 256)).astype(np.float32),
    num_classes['nt_branch'],
)
dummy_tc = tf.keras.utils.to_categorical(
    np.random.randint(0, num_classes['tc_branch'], (batch_size,)).astype(np.float32),
    num_classes['tc_branch'],
)

dummy_labels = dict(
    np_branch=dummy_np,
    hv_branch=dummy_hv,
    nt_branch=dummy_nt,
    tc_branch=dummy_tc,
)

with strategy.scope():
    # Everything that creates variables (model weights, optimizer state,
    # compiled metrics) is built inside the strategy scope.
    model, encoder = create_he_expert(input_shape, num_classes)

    # Adam with element-wise gradient clipping, wrapped for loss scaling
    # under the mixed_float16 policy.
    optimizer = mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam(learning_rate=0.001, clipvalue=0.5)
    )

    # Placeholder weights (1.0 each) for the NP branch's weighted BCE.
    np_weights = tf.constant([1.0, 1.0], dtype=tf.float32)

    loss_functions = dict(
        # NP branch: weighted BCE plus Dice on the same mask.
        np_branch=lambda y_true, y_pred: (
            weighted_bce(np_weights)(y_true, y_pred) + dice_loss(y_true, y_pred)
        ),
        hv_branch=tf.keras.losses.MeanSquaredError(),
        nt_branch=weighted_focal_loss(alpha=0.25, gamma=2.0),
        tc_branch=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    )
    loss_weights = dict(np_branch=0.90, hv_branch=1.0, nt_branch=1.0, tc_branch=1.0)

    model.compile(
        optimizer=optimizer,
        loss=loss_functions,
        loss_weights=loss_weights,
        metrics=dict(
            np_branch=[tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=0.5)],
            hv_branch=[tf.keras.metrics.MeanAbsoluteError()],
            nt_branch=[tf.keras.metrics.CategoricalAccuracy()],
            tc_branch=[tf.keras.metrics.CategoricalAccuracy()],
        ),
    )

    # Keep the loss configuration on the model so the custom train/eval steps
    # can look it up by branch name.
    model.loss_functions, model.loss_weights = loss_functions, loss_weights

# Print model summary
model.summary()

# Branch names in the order assumed when the model returns a sequence of
# outputs (same order as `expected_shapes` below).
BRANCH_NAMES = ('np_branch', 'hv_branch', 'nt_branch', 'tc_branch')

# Try a forward pass. `outputs` is reset to None first: otherwise a failed call
# would leave the previous cell's `outputs` (batch size 16) in place and the
# check below would compare stale tensors, or raise NameError on a fresh kernel.
outputs = None
outputs_by_branch = {}
try:
    outputs = model(dummy_input)
    # Accept both dict-style and sequence-style model outputs.
    if isinstance(outputs, dict):
        outputs_by_branch = dict(outputs)
    else:
        outputs_by_branch = dict(zip(BRANCH_NAMES, outputs))
    print("\nForward pass successful!")
    print("Output shapes:")
    for name in BRANCH_NAMES:
        branch_output = outputs_by_branch.get(name)
        print(f"{name}: {None if branch_output is None else branch_output.shape}")
except Exception as e:
    outputs = None
    print(f"Error during forward pass: {str(e)}")

# Check if shapes match expected output
expected_shapes = [
    (batch_size, 256, 256, 1),  # np_branch
    (batch_size, 256, 256, 2),  # hv_branch
    (batch_size, 256, 256, num_classes['nt_branch']),  # nt_branch
    (batch_size, num_classes['tc_branch'])  # tc_branch
]
all_shapes_correct = outputs is not None
if outputs is None:
    print("\nSkipping shape check: the forward pass did not succeed.")
else:
    for i, (name, expected_shape) in enumerate(zip(BRANCH_NAMES, expected_shapes)):
        branch_output = outputs_by_branch.get(name)
        # A missing branch reports None instead of being silently skipped by zip.
        actual_shape = None if branch_output is None else tuple(branch_output.shape)
        if actual_shape != expected_shape:
            print(f"Shape mismatch in branch {i}: Expected {expected_shape}, got {actual_shape}")
            all_shapes_correct = False
    if all_shapes_correct:
        print("\nAll output shapes are correct!")
    else:
        print("\nSome output shapes are incorrect. Please check the model architecture.")

# Branch names in the order assumed when the model returns a sequence of
# outputs rather than a dict.
BRANCH_NAMES = ('np_branch', 'hv_branch', 'nt_branch', 'tc_branch')

# Per-branch metrics for the custom loop. `model.metrics` is a flat list, so
# `model.metrics[branch]` raises TypeError. Dedicated metric objects mirroring
# the `compile` configuration are created here instead. They are created under
# the strategy scope so they can be updated from inside `strategy.run`.
with strategy.scope():
    branch_metrics = {
        'np_branch': [tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=0.5)],
        'hv_branch': [tf.keras.metrics.MeanAbsoluteError()],
        'nt_branch': [tf.keras.metrics.CategoricalAccuracy()],
        'tc_branch': [tf.keras.metrics.CategoricalAccuracy()],
    }


def _predictions_by_branch(predictions):
    """Return model outputs as a dict keyed by branch name (dict or sequence input)."""
    if isinstance(predictions, dict):
        return predictions
    return dict(zip(BRANCH_NAMES, predictions))


def _branch_loss(loss_fn, y_true, y_pred):
    """Evaluate one branch's loss and reduce it to a float32 scalar mean."""
    if isinstance(loss_fn, tf.keras.losses.Loss):
        # Keras Loss objects with the default AUTO reduction raise a ValueError
        # when called inside strategy.run in a custom loop; `call` returns the
        # unreduced values, which are averaged below instead.
        values = loss_fn.call(y_true, y_pred)
    else:
        values = loss_fn(y_true, y_pred)
    return tf.reduce_mean(tf.cast(values, tf.float32))


def _replica_loss(labels, predictions):
    """Weighted multi-branch loss for one replica.

    Returns ``(loss, predictions_by_branch)``. Each branch loss is a scalar mean
    over this replica's examples, so `tf.nn.compute_average_loss` (which rejects
    rank-0 input) is not used. Dividing by the replica count makes the SUM
    all-reduce of gradients, and the SUM reduction of the reported loss, equal
    to the mean over replicas.
    """
    preds = _predictions_by_branch(predictions)
    weighted = [
        model.loss_weights[branch] * _branch_loss(
            model.loss_functions[branch],
            labels[branch],
            tf.cast(preds[branch], tf.float32),  # fp16 outputs -> fp32 loss math
        )
        for branch in BRANCH_NAMES
    ]
    return tf.add_n(weighted) / strategy.num_replicas_in_sync, preds


def _update_metrics(labels, preds):
    """Accumulate every branch metric on this replica's batch."""
    for branch in BRANCH_NAMES:
        y_pred = tf.cast(preds[branch], tf.float32)
        for metric in branch_metrics[branch]:
            metric.update_state(labels[branch], y_pred)


def _reset_metrics():
    """Clear all branch metrics (e.g. between training and evaluation)."""
    for metric_list in branch_metrics.values():
        for metric in metric_list:
            metric.reset_state()


def _print_metrics(title):
    """Print the current result of every branch metric under `title`."""
    print(title)
    for branch in BRANCH_NAMES:
        for metric in branch_metrics[branch]:
            print(f"{branch} - {metric.name}: {metric.result().numpy()}")


@tf.function
def train_step(inputs, labels):
    """Run one distributed training step; returns the mean weighted loss."""
    def step_fn(inputs, labels):
        optimizer = model.optimizer
        # With a LossScaleOptimizer, a custom loop must scale the loss and
        # unscale the gradients itself to get fp16 underflow protection.
        use_loss_scale = isinstance(optimizer, mixed_precision.LossScaleOptimizer)
        with tf.GradientTape() as tape:
            predictions = model(inputs, training=True)
            loss, preds = _replica_loss(labels, predictions)
            scaled_loss = optimizer.get_scaled_loss(loss) if use_loss_scale else loss

        gradients = tape.gradient(scaled_loss, model.trainable_variables)
        if use_loss_scale:
            gradients = optimizer.get_unscaled_gradients(gradients)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        _update_metrics(labels, preds)
        return loss

    per_replica_losses = strategy.run(step_fn, args=(inputs, labels))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


@tf.function
def distributed_evaluate_step(inputs, labels):
    """Run one distributed evaluation step; returns the mean weighted loss."""
    def eval_step_fn(inputs, labels):
        predictions = model(inputs, training=False)
        loss, preds = _replica_loss(labels, predictions)
        _update_metrics(labels, preds)
        return loss

    per_replica_losses = strategy.run(eval_step_fn, args=(inputs, labels))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


# Try a training step
try:
    loss = train_step(dummy_input, dummy_labels)
    print(f"\nTraining step successful! Loss: {loss.numpy()}")
    _print_metrics("Training metrics:")
except Exception as e:
    print(f"Error during training step: {str(e)}")

# Try evaluation on fresh metric state, so the reported evaluation metrics do
# not include the training batch accumulated above.
try:
    _reset_metrics()
    eval_loss = distributed_evaluate_step(dummy_input, dummy_labels)
    print("\nEvaluation successful!")
    print(f"Evaluation loss: {eval_loss.numpy()}")
    _print_metrics("Evaluation metrics:")
except Exception as e:
    print(f"Error during evaluation: {str(e)}")

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Number of devices: 4
Model output shapes:
NP branch: (None, 256, 256, 1)
HV branch: (None, 256, 256, 2)
NT branch: (None, 256, 256, 6)
TC branch: (None, 19)
Model: "model_37"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_26 (InputLayer)       [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 model_36 (Functional)       (None, 256, 64)              735488    ['input_26[0][0]']            
                                                                                     