In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model

# Example input shapes per timeframe, as (timesteps, features).
# All three branches use the same per-bar feature count; keep it in one place
# so the branches cannot silently disagree.
n_features = 12

trigger_input_shape = (672, n_features)    # e.g. 15-min bars, ~1 week
pattern_input_shape = (336, n_features)    # e.g. 1-hour bars, ~2 weeks
structure_input_shape = (180, n_features)  # e.g. daily bars, ~6 months

latent_dim = 64  # Size of each branch's compressed representation

# --- Encoder Branch ---
def build_encoder_branch(input_shape, latent_dim):
    """Build one encoder branch that compresses a sequence into a flat vector.

    Pipeline: Conv1D(32, k=5) -> MaxPooling1D(2) -> Conv1D(64, k=3) ->
    Keras ``Attention`` with the same tensor as query and value
    (self-attention) -> GlobalAveragePooling1D -> Dense(latent_dim, relu).

    Args:
        input_shape: Shape of one sample, passed straight to ``layers.Input``.
        latent_dim: Width of the final Dense encoding.

    Returns:
        ``(inp, encoded)``: the branch's ``Input`` tensor and its
        ``latent_dim``-wide encoding tensor.
    """
    seq_in = layers.Input(shape=input_shape)

    conv_stack = [
        layers.Conv1D(32, kernel_size=5, activation='relu', padding='same'),
        layers.MaxPooling1D(2),
        layers.Conv1D(64, kernel_size=3, activation='relu', padding='same'),
    ]
    h = seq_in
    for layer in conv_stack:
        h = layer(h)

    # Passing the same tensor as query and value makes this self-attention.
    h = layers.Attention()([h, h])
    pooled = layers.GlobalAveragePooling1D()(h)
    return seq_in, layers.Dense(latent_dim, activation='relu')(pooled)

# --- Decoder Branch ---
def build_decoder_branch(latent_dim, output_shape):
    """Build one decoder branch that maps a latent vector back to a sequence.

    Pipeline: Dense -> Reshape to (ceil(timesteps / 2), 64) ->
    Conv1DTranspose(32, stride 2) -> [crop one step if timesteps is odd] ->
    Conv1D(features, linear).

    Args:
        latent_dim: Length of the flat latent vector this branch consumes.
        output_shape: ``(timesteps, features, ...)``. Only the first two
            entries are used: the sequence length to reconstruct and the
            number of output channels.

    Returns:
        ``(inp, out)``: the branch's own ``Input`` tensor of shape
        ``(latent_dim,)`` and its output tensor of shape
        ``(timesteps, features)``.
    """
    timesteps, n_features = output_shape[0], output_shape[1]
    # The stride-2 transposed conv doubles the length, so start from half of
    # it. Round up so that odd lengths are not one step short.
    half_steps = (timesteps + 1) // 2

    inp = layers.Input(shape=(latent_dim,))
    x = layers.Dense(half_steps * 64, activation='relu')(inp)
    x = layers.Reshape((half_steps, 64))(x)
    x = layers.Conv1DTranspose(32, 3, strides=2, padding='same', activation='relu')(x)

    # With an odd timesteps the upsampled length is timesteps + 1, so trim the
    # trailing extra step. For even lengths no layer is added, which keeps the
    # graph identical to the original.
    excess = 2 * half_steps - timesteps
    if excess:
        x = layers.Cropping1D((0, excess))(x)

    x = layers.Conv1D(n_features, kernel_size=3, padding='same', activation='linear')(x)
    return inp, x

# Encoder branches: one per timeframe, each compressed to a latent_dim vector
tr_in, tr_enc = build_encoder_branch(trigger_input_shape, latent_dim)
pt_in, pt_enc = build_encoder_branch(pattern_input_shape, latent_dim)
st_in, st_enc = build_encoder_branch(structure_input_shape, latent_dim)

# Concatenate encoded branches → shared latent space
merged = layers.Concatenate()([tr_enc, pt_enc, st_enc])
shared_latent = layers.Dense(latent_dim * 2, activation='relu', name="latent_vector")(merged)

# Decoder branches from the shared latent.
# build_decoder_branch creates its own private Input per branch, so its output
# tensors are NOT connected to `decoder_input`. Building
# Model(inputs=decoder_input, outputs=[tr_dout, ...]) directly therefore fails
# as a disconnected graph. Wrap each branch in a sub-model and call all three
# on the shared decoder_input instead.
decoder_input = layers.Input(shape=(latent_dim * 2,))
tr_din, tr_dout = build_decoder_branch(latent_dim * 2, trigger_input_shape)
pt_din, pt_dout = build_decoder_branch(latent_dim * 2, pattern_input_shape)
st_din, st_dout = build_decoder_branch(latent_dim * 2, structure_input_shape)

tr_decoder = Model(inputs=tr_din, outputs=tr_dout, name="TriggerDecoder")
pt_decoder = Model(inputs=pt_din, outputs=pt_dout, name="PatternDecoder")
st_decoder = Model(inputs=st_din, outputs=st_dout, name="StructureDecoder")

# Models
encoder = Model(inputs=[tr_in, pt_in, st_in], outputs=shared_latent, name="Encoder")
decoder = Model(
    inputs=decoder_input,
    outputs=[
        tr_decoder(decoder_input),
        pt_decoder(decoder_input),
        st_decoder(decoder_input),
    ],
    name="Decoder",
)

# Full autoencoder: decoder applied to the encoder's latent output. Weights
# are shared with the standalone encoder/decoder models.
autoencoder_outputs = decoder(encoder.output)
autoencoder = Model(inputs=[tr_in, pt_in, st_in], outputs=autoencoder_outputs, name="MultiBranchAutoencoder")
# A single 'mse' loss string is applied to each of the three reconstruction outputs.
autoencoder.compile(optimizer='adam', loss='mse')

autoencoder.summary()
