In [10]:
import tensorflow as tf
import numpy as np

import os
import sys
# Put the working directory and its parent on the module search path so the
# project's `src` package can be imported from this cell.  A notebook defines
# no __file__, so the working directory stands in for the notebook's folder.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
sys.path += [current_dir, parent_dir]

from src.models.expert_he import create_he_expert
from src.models.vit import create_vit_model  # Make sure this import works
# Synthetic input: one batch of uniform random values in [0, 1), cast to
# float32, shaped (batch_size, height, width, channels).
batch_size = 16
input_shape = (256, 256, 3)
num_classes = dict(tc_branch=19, nt_branch=6)

dummy_input = np.random.rand(batch_size, *input_shape).astype(np.float32)

# Build the network under test; the factory also hands back its encoder.
model, encoder = create_he_expert(input_shape, num_classes)

# Loss and optimizer are placeholders -- the model is only compiled so it
# can be summarised and called, not trained.
model.compile(loss='mse', optimizer='adam')

# Show the layer-by-layer structure.
model.summary()

# Try a forward pass.  `outputs` stays None when the call fails, so the shape
# check further down can be skipped cleanly instead of dying with a NameError
# that would bury the real error message.
branch_names = ['np_branch', 'hv_branch', 'nt_branch', 'tc_branch']
outputs = None
try:
    outputs = model(dummy_input)
except Exception as e:
    print(f"Error during forward pass: {str(e)}")
else:
    print("\nForward pass successful!")
    print("Output shapes:")
    for name, output in zip(branch_names, outputs):
        print(f"{name}: {output.shape}")

# Expected shape of each output, in the same order as `branch_names`.
expected_shapes = [
    (batch_size, 256, 256, 1),  # np_branch
    (batch_size, 256, 256, 2),  # hv_branch
    (batch_size, 256, 256, num_classes['nt_branch']),  # nt_branch
    (batch_size, num_classes['tc_branch'])  # tc_branch
]

if outputs is None:
    # Nothing to compare: the forward pass already reported its error.
    all_shapes_correct = False
    print("\nSkipping shape check: the forward pass did not produce outputs.")
else:
    # zip() stops at the shorter sequence, so a missing or extra output would
    # otherwise go unnoticed; compare the counts explicitly first.
    all_shapes_correct = len(outputs) == len(expected_shapes)
    if not all_shapes_correct:
        print(f"Output count mismatch: Expected {len(expected_shapes)}, got {len(outputs)}")

    for i, (output, expected_shape) in enumerate(zip(outputs, expected_shapes)):
        if output.shape != expected_shape:
            print(f"Shape mismatch in branch {i}: Expected {expected_shape}, got {output.shape}")
            all_shapes_correct = False

    if all_shapes_correct:
        print("\nAll output shapes are correct!")
    else:
        print("\nSome output shapes are incorrect. Please check the model architecture.")

Model output shapes:
NP branch: (None, 256, 256, 1)
HV branch: (None, 256, 256, 2)
NT branch: (None, 256, 256, 6)
TC branch: (None, 19)
Model: "model_19"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_14 (InputLayer)       [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 model_18 (Functional)       (None, 256, 64)              735488    ['input_14[0][0]']            
                                                                                                  
 dense_131 (Dense)           (None, 256, 65536)           4259840   ['model_18[0][0]']            
                                                                                                  
 reshape_13 (Reshape)        (None, 256, 256, 256)    