In [4]:
from enhanced_vae_model import EnhancedVAE, init_model
from graphviz import Digraph

def visualize_vae(model):
    """Return a graphviz ``Digraph`` sketching the Enhanced VAE layout.

    Every layer label and shape below is hard-coded; the only value read from
    ``model`` is the latent size (``model.get_latent_dim()``).
    NOTE(review): the labels are not derived from the model's modules, so
    confirm they still match EnhancedVAE whenever its architecture changes.

    Args:
        model: object exposing ``get_latent_dim()``.

    Returns:
        graphviz.Digraph: the graph, not yet rendered, laid out top-to-bottom.
    """
    latent = model.get_latent_dim()

    graph = Digraph(comment='Enhanced VAE')
    graph.attr(rankdir='TB', size='12,12')

    graph.node('input', 'Input\n1x28x28')

    # Subgraphs named with a 'cluster_' prefix are drawn as a labelled box.
    encoder_layers = (
        ('conv1', 'Conv 32x14x14'),
        ('conv2', 'Conv 64x7x7'),
        ('conv3', 'Conv 128x4x4'),
        ('conv4', 'Conv 256x1x1'),
        ('fc1', 'FC 512'),
        ('fc21', f'FC {latent}\n(mu)'),
        ('fc22', f'FC {latent}\n(logvar)'),
    )
    with graph.subgraph(name='cluster_encoder') as encoder:
        encoder.attr(label='Encoder')
        for node_id, text in encoder_layers:
            encoder.node(node_id, text)

    graph.node('z', f'Latent Space\n{latent}')

    decoder_layers = (
        ('fc3', 'FC 512'),
        ('fc4', 'FC 256'),
        ('deconv1', 'ConvTranspose 128x4x4'),
        ('deconv2', 'ConvTranspose 64x7x7'),
        ('deconv3', 'ConvTranspose 32x14x14'),
        ('deconv4', 'ConvTranspose 1x28x28'),
    )
    with graph.subgraph(name='cluster_decoder') as decoder:
        decoder.attr(label='Decoder')
        for node_id, text in decoder_layers:
            decoder.node(node_id, text)

    graph.node('output', 'Output\n1x28x28')

    # Linear encoder trunk, then the mu/logvar fork that re-joins at z,
    # then the linear decoder path down to the output.
    encoder_chain = ['input', 'conv1', 'conv2', 'conv3', 'conv4', 'fc1']
    decoder_chain = ['z', 'fc3', 'fc4', 'deconv1', 'deconv2', 'deconv3',
                     'deconv4', 'output']
    links = [
        *zip(encoder_chain, encoder_chain[1:]),
        ('fc1', 'fc21'),
        ('fc1', 'fc22'),
        ('fc21', 'z'),
        ('fc22', 'z'),
        *zip(decoder_chain, decoder_chain[1:]),
    ]
    for tail, head in links:
        graph.edge(tail, head)

    return graph

# Initialize the model with latent_dim=10 on CPU; the second value returned
# by init_model is not needed for drawing the diagram, so it is discarded.
model, _ = init_model(latent_dim=10, device='cpu')

# Generate and save the graph. `render` returns the path of the file it
# actually wrote, so report that path rather than repeating the filename by
# hand (a hand-written copy could drift from the render call).
dot = visualize_vae(model)
output_path = dot.render("vae_model_architecture", format="png", cleanup=True)
print(f"Model visualization saved as '{output_path}'")

Model visualization saved as 'vae_model_architecture.png'
