# Recurrent block Test

In [1]:
import unittest
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, TimeDistributed, Bidirectional, GRU, LayerNormalization, Masking
from tensorflow.keras.initializers import he_uniform
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    GRU,
    Bidirectional,
    LayerNormalization,
    TimeDistributed,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import time


In [2]:
class RecurrentBlockPT(nn.Module):
    # __init__ remains correct.
    def __init__(self, input_features: int, latent_dim: int, bidirectional_merge: str = "concat"):
        super().__init__()
        self.latent_dim = latent_dim
        if bidirectional_merge != "concat":
            warnings.warn("Bidirectional merge mode defaulting to 'concat'.")
        self.conv1d = nn.Conv1d(in_channels=input_features, out_channels=2 * latent_dim, kernel_size=5, padding="same", bias=False)
        self.gru1 = nn.GRU(input_size=2 * latent_dim, hidden_size=2 * latent_dim, num_layers=1, batch_first=True, bidirectional=True)
        self.norm1 = nn.LayerNorm(4 * latent_dim, eps=1e-3)
        self.gru2 = nn.GRU(input_size=4 * latent_dim, hidden_size=latent_dim, num_layers=1, batch_first=True, bidirectional=True)
        self.norm2 = nn.LayerNorm(2 * latent_dim, eps=1e-3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, N, _ = x.shape
        
        # Stage 1: Convolution
        conv_in = x.reshape(B * T, N, -1).permute(0, 2, 1) # B*T as TF TimeDistributed analogon
        conv_out = F.relu(self.conv1d(conv_in))
        gru1_in = conv_out.permute(0, 2, 1) # Shape: (B*T, N, F_conv)
        
        # --- Prepare for packing ---
        # Calculate mask from convolution output (TF GRU layer with masking analogon)
        mask = (torch.abs(gru1_in).sum(dim=-1) > 0) # Shape: (B*T, N)
        lengths = mask.sum(dim=1).cpu()
        valid_indices = torch.where(lengths > 0)[0]
        
        # --- Stage 2: First GRU with packing ---
        gru1_out_full = torch.zeros(B * T, N, 4 * self.latent_dim, device=x.device, dtype=x.dtype) # allocate data
        if len(valid_indices) > 0:
            valid_lengths = lengths[valid_indices]
            # Apply GRU whilst ignoring masked data
            packed_input = pack_padded_sequence(
                gru1_in[valid_indices], valid_lengths, batch_first=True, enforce_sorted=False
            )
            packed_output, _ = self.gru1(packed_input)
            unpacked_output, _ = pad_packed_sequence(
                packed_output, batch_first=True, total_length=N
            )
            gru1_out_full[valid_indices] = unpacked_output

        # Stage 3: First LayerNorm
        norm1_in = gru1_out_full.reshape(B, T, N, -1)
        norm1_out = self.norm1(norm1_in)

        # --- Stage 4: Second GRU with packing ---
        gru2_in = norm1_out.reshape(B * T, N, -1)
        gru2_h_n_full = torch.zeros(2, B * T, self.latent_dim, device=x.device, dtype=x.dtype)
        if len(valid_indices) > 0:
            valid_lengths = lengths[valid_indices] # Use the same lengths
            packed_input_2 = pack_padded_sequence(
                gru2_in[valid_indices], valid_lengths, batch_first=True, enforce_sorted=False
            )
            _, h_n_2 = self.gru2(packed_input_2)
            gru2_h_n_full[:, valid_indices, :] = h_n_2
        gru2_final_state = gru2_h_n_full.permute(1, 0, 2).reshape(B * T, -1)

        # Stage 5: Second LayerNorm
        norm2_out = self.norm2(gru2_final_state)
        
        final_output = norm2_out.reshape(B, T, -1)
        return final_output

In [3]:
def get_recurrent_block(
    x: tf.Tensor, latent_dim: int, gru_unroll: bool, bidirectional_merge: str
):
    """Build a recurrent embedding block, using a 1D convolution followed by two bidirectional GRU layers.

    Args:
        x (tf.Tensor): Input tensor.
        latent_dim (int): Number of dimensions of the output tensor.
        gru_unroll (bool): whether to unroll the GRU layers. Defaults to False.
        bidirectional_merge (str): how to merge the forward and backward GRU layers. Defaults to "concat".

    Returns:
        tf.keras.models.Model object with the specified architecture.

    """
    encoder = TimeDistributed(
        tf.keras.layers.Conv1D(
            filters=2 * latent_dim,
            kernel_size=5,
            strides=1,  # Increased strides yield shorter sequences
            padding="same",
            activation="relu",
            kernel_initializer=he_uniform(),
            use_bias=False,
        )
    )(x)
    encoder = tf.keras.layers.Masking(mask_value=0.0)(encoder)
    encoder = TimeDistributed(
        Bidirectional(
            GRU(
                2 * latent_dim,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=True,
                unroll=gru_unroll,
                use_bias=True,
            ),
            merge_mode=bidirectional_merge,
        )
    )(encoder)
    encoder = LayerNormalization()(encoder)
    encoder = TimeDistributed(
        Bidirectional(
            GRU(
                latent_dim,
                activation="tanh",
                recurrent_activation="sigmoid",
                return_sequences=False,
                unroll=gru_unroll,
                use_bias=True,
            ),
            merge_mode=bidirectional_merge,
        )
    )(encoder)
    encoder = LayerNormalization()(encoder)
    

    return tf.keras.models.Model(x, encoder)

In [4]:
def transfer_recurrent_block_weights(tf_model, pt_model):
    """Transfers weights for the full recurrent block with GRU gate permutation."""
    conv_td, _, gru1_td, norm1, gru2_td, norm2 = tf_model.layers[1:]


    def permute_gru_weights(keras_weights):
        W_ih, W_hh, B = keras_weights
        W_ih_z, W_ih_r, W_ih_n = np.split(W_ih, 3, axis=1)
        W_hh_z, W_hh_r, W_hh_n = np.split(W_hh, 3, axis=1)
        W_ih_pt = np.concatenate([W_ih_r, W_ih_z, W_ih_n], axis=1)
        W_hh_pt = np.concatenate([W_hh_r, W_hh_z, W_hh_n], axis=1)
        B_ih, B_hh = B
        B_ih_z, B_ih_r, B_ih_n = np.split(B_ih, 3)
        B_hh_z, B_hh_r, B_hh_n = np.split(B_hh, 3)
        B_ih_pt = np.concatenate([B_ih_r, B_ih_z, B_ih_n])
        B_hh_pt = np.concatenate([B_hh_r, B_hh_z, B_hh_n])
        return W_ih_pt.T, W_hh_pt.T, B_ih_pt, B_hh_pt

    pt_model.conv1d.weight.data = torch.from_numpy(conv_td.layer.get_weights()[0]).permute(2, 1, 0)
    
    W_ih_f1, W_hh_f1, B_ih_f1, B_hh_f1 = permute_gru_weights(gru1_td.layer.forward_layer.get_weights())
    pt_model.gru1.weight_ih_l0.data = torch.from_numpy(W_ih_f1); pt_model.gru1.weight_hh_l0.data = torch.from_numpy(W_hh_f1); pt_model.gru1.bias_ih_l0.data = torch.from_numpy(B_ih_f1); pt_model.gru1.bias_hh_l0.data = torch.from_numpy(B_hh_f1)
    
    W_ih_b1, W_hh_b1, B_ih_b1, B_hh_b1 = permute_gru_weights(gru1_td.layer.backward_layer.get_weights())
    pt_model.gru1.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b1); pt_model.gru1.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b1); pt_model.gru1.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b1); pt_model.gru1.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b1)

    pt_model.norm1.weight.data = torch.from_numpy(norm1.get_weights()[0]); pt_model.norm1.bias.data = torch.from_numpy(norm1.get_weights()[1])

    W_ih_f2, W_hh_f2, B_ih_f2, B_hh_f2 = permute_gru_weights(gru2_td.layer.forward_layer.get_weights())
    pt_model.gru2.weight_ih_l0.data = torch.from_numpy(W_ih_f2); pt_model.gru2.weight_hh_l0.data = torch.from_numpy(W_hh_f2); pt_model.gru2.bias_ih_l0.data = torch.from_numpy(B_ih_f2); pt_model.gru2.bias_hh_l0.data = torch.from_numpy(B_hh_f2)
    
    W_ih_b2, W_hh_b2, B_ih_b2, B_hh_b2 = permute_gru_weights(gru2_td.layer.backward_layer.get_weights())
    pt_model.gru2.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b2); pt_model.gru2.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b2); pt_model.gru2.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b2); pt_model.gru2.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b2)
    
    pt_model.norm2.weight.data = torch.from_numpy(norm2.get_weights()[0]); pt_model.norm2.bias.data = torch.from_numpy(norm2.get_weights()[1])

In [5]:
class TestRecurrentBlockTranslation(unittest.TestCase):
    def setUp(self):
        """Set up the full models and transfer weights."""
        tf.keras.backend.clear_session()
        self.latent_dim = 8
        self.input_shape = (10, 6, 3) # (T, N, F)
        
        self.tf_model = get_recurrent_block(
            tf.keras.Input(shape=self.input_shape), self.latent_dim, False, "concat"
        )
        self.pt_model = RecurrentBlockPT(self.input_shape[-1], self.latent_dim)
        self.pt_model.eval()

        transfer_recurrent_block_weights(self.tf_model, self.pt_model)
        
        # Create test data WITH MASKING
        self.np_input = np.random.rand(4, *self.input_shape).astype(np.float32)
        # Mask the last two "nodes" for the first sample in the batch
        self.np_input[0, :, :, :] = 0.0

    def test_final_forward_pass_with_masking(self):
        """Test the full block with the pack_padded_sequence masking method."""
        tf_start=time.time()
        tf_output = self.tf_model(self.np_input, training=False)
        tf_end=time.time()
        tf_output_np = tf_output.numpy()
        

        pt_input_tensor = torch.from_numpy(self.np_input)
        with torch.no_grad():
            pt_start = time.time()
            pt_output = self.pt_model(pt_input_tensor)
            pt_end=time.time()
        pt_output_np = pt_output.cpu().numpy()
        print("Tensorflow execution time: " + str(tf_end-tf_start))
        print("Pytorch execution time: " + str(pt_end-pt_start))

        np.testing.assert_allclose(tf_output_np, pt_output_np, rtol=1e-5, atol=1e-5)
        print("✅ Full `RecurrentBlockPT` translation test PASSED!")
        
#To run in Jupyter:
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.TestLoader().loadTestsFromTestCase(TestRecurrentBlockTranslation)
runner.run(suite)

test_final_forward_pass_with_masking (__main__.TestRecurrentBlockTranslation)
Test the full block with the pack_padded_sequence masking method. ... ok

----------------------------------------------------------------------
Ran 1 test in 1.421s

OK


Tensorflow execution time: 0.07317900657653809
Pytorch execution time: 0.024010419845581055
✅ Full `RecurrentBlockPT` translation test PASSED!


<unittest.runner.TextTestResult run=1 errors=0 failures=0>

# ProbabilisticDecoder Test

In [1]:
import unittest
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import layers as tfpl
from tensorflow_probability import distributions as tfd
from tensorflow_probability.python.bijectors import scale as tfb
import torch
import torch.nn as nn
from torch.distributions import Distribution, TransformedDistribution
from torch.distributions.transforms import AffineTransform
import time

In [2]:
class ProbabilisticDecoder(tf.keras.layers.Layer):
    """Map the reconstruction output of a given decoder to a multivariate normal distribution."""

    def __init__(self, output_data_shape, **kwargs):
        """Initialize the probabilistic decoder."""
        super().__init__(**kwargs)
        self.time_distributer = tf.keras.layers.Dense(
            tfpl.IndependentNormal.params_size(output_data_shape) // 2
        )
        self.probabilistic_decoding = tfpl.DistributionLambda(
            make_distribution_fn=lambda decoded: tfd.Masked(
                tfd.Independent(
                    tfd.Normal(
                        loc=decoded[0], scale=tf.ones_like(decoded[0]),
                        validate_args=False, allow_nan_stats=False,
                    ),
                    reinterpreted_batch_ndims=1,
                ),
                validity_mask=decoded[1],
            ),
            convert_to_tensor_fn="mean",
        )
        self.scaled_probabilistic_decoding = tfpl.DistributionLambda(
            make_distribution_fn=lambda decoded: tfd.Masked(
                tfd.TransformedDistribution(
                    decoded[0], # base distribution
                    tfb.Scale(tf.cast(tf.expand_dims(decoded[1], axis=2), tf.float32)), # bijector
                    name="vae_reconstruction",
                ),
                validity_mask=decoded[1],
            ),
            convert_to_tensor_fn="mean",
        )

    def call(self, inputs):
        hidden, validity_mask = inputs
        loc_params = tf.keras.layers.TimeDistributed(self.time_distributer)(hidden)
        prob_decoded = self.probabilistic_decoding([loc_params, validity_mask])
        scaled_prob_decoded = self.scaled_probabilistic_decoding(
            [prob_decoded, validity_mask]
        )
        return scaled_prob_decoded

In [3]:
# FIX: Create a subclass that knows how to compute the mean for an Affine transform.
class AffineTransformedDistribution(TransformedDistribution):
    """
    A specific TransformedDistribution for Affine transforms that implements .mean.
    """
    def __init__(self, base_distribution, transform):
        super().__init__(base_distribution, transform)

    @property
    def mean(self):
        """
        Computes the mean of the transformed distribution.
        E[loc + scale * X] = loc + scale * E[X]
        """
        # The transform itself is callable and applies the affine transformation.
        return self.transforms[0](self.base_dist.mean)

class ProbabilisticDecoderPT(nn.Module):
    """
    PyTorch translation of the ProbabilisticDecoder, including scaling transform.
    """
    def __init__(self, hidden_dim: int, data_dim: int):
        super().__init__()
        self.loc_projection = nn.Linear(in_features=hidden_dim, out_features=data_dim)

    def forward(self, hidden: torch.Tensor, validity_mask: torch.Tensor) -> AffineTransformedDistribution:
        B, T, D = hidden.shape
        # Reconstruct mean locations
        loc_params = self.loc_projection(hidden.view(B * T, -1)).reshape(B, T, -1)

        # Define Gaussian distributions with means (init: var=1)
        scale_params = torch.ones_like(loc_params)
        base_dist = torch.distributions.Normal(loc=loc_params, scale=scale_params)

        # Multivariate Gaussian distributions for feature vector
        independent_dist = torch.distributions.Independent(base_dist, 1)
        
        # Define transform to map masked values to 0 (y = 0 + 0 * x) and unmasked-values to themselves (y = 0 + 1.0 * x)
        scale_transform = validity_mask.unsqueeze(-1).to(hidden.dtype)
        transform = AffineTransform(loc=0, scale=scale_transform)
        
        # Returns a custom class instead of the generic one as "mean" functionality otherwise would be missing.
        final_dist = AffineTransformedDistribution(independent_dist, transform)
        return final_dist

In [4]:
def transfer_probabilistic_decoder_weights(tf_model, pt_model):
    dense_layer = tf_model.time_distributer
    W, b = dense_layer.get_weights()
    pt_model.loc_projection.weight.data = torch.from_numpy(W.T)
    pt_model.loc_projection.bias.data = torch.from_numpy(b)

In [5]:
class TestProbabilisticDecoderFinal(unittest.TestCase):
    def setUp(self):
        tf.keras.backend.clear_session()
        self.batch_size, self.time_steps, self.hidden_dim, self.data_dim = 4, 10, 32, 5

        # Create TF model
        self.tf_model = ProbabilisticDecoder(output_data_shape=(self.data_dim,))
        
        # Create PyTorch model
        self.pt_model = ProbabilisticDecoderPT(hidden_dim=self.hidden_dim, data_dim=self.data_dim)
        self.pt_model.eval()

        # --- THE FIX: Zero out the input tensor based on the mask ---
        np_hidden_original = np.random.rand(self.batch_size, self.time_steps, self.hidden_dim).astype(np.float32)
        
        # Create a float mask (1.0/0.0)
        self.np_float_mask = np.ones((self.batch_size, self.time_steps), dtype=np.float32)
        self.np_float_mask[0, -1] = 0.0 # Mask last step of first item
        self.np_float_mask[1, 5:] = 0.0 # Mask multiple steps of second item

        # Apply the mask to the input data itself before feeding it to the models
        self.np_hidden_masked = np_hidden_original * self.np_float_mask[:, :, np.newaxis]

        # TF needs a boolean mask for tfd.Masked
        self.np_bool_mask = self.np_float_mask.astype(bool)

        # Build the TF model by calling it once with the masked input
        self.tf_model([tf.constant(self.np_hidden_masked), tf.constant(self.np_bool_mask)])
        
        # Transfer weights
        transfer_probabilistic_decoder_weights(self.tf_model, self.pt_model)
        print("✅ Weights transferred successfully for final test.")

    def test_final_forward_pass(self):
        """Tests that the .mean() of the final transformed distributions are identical."""
        # --- TensorFlow ---
        # Pass the zeroed-out hidden data and the boolean mask
        tf_start=time.time()
        tf_dist = self.tf_model([self.np_hidden_masked, self.np_bool_mask])
        tf_mean_np = tf_dist.mean().numpy()
        tf_end=time.time()


        # --- PyTorch ---
        pt_hidden_tensor = torch.from_numpy(self.np_hidden_masked)
        pt_mask_tensor = torch.from_numpy(self.np_float_mask)
        with torch.no_grad():
            pt_start=time.time()
            pt_dist = self.pt_model(pt_hidden_tensor, pt_mask_tensor)
        pt_mean_np = pt_dist.mean.cpu().numpy()
        pt_end=time.time()

        # --- Verification ---
        np.testing.assert_allclose(tf_mean_np, pt_mean_np, rtol=1e-6, atol=1e-6)
        
        # Check a masked-out part is zero
        self.assertTrue(np.all(pt_mean_np[0, -1, :] == 0.0))
        # Check an un-masked part is not zero
        self.assertFalse(np.all(pt_mean_np[0, 0, :] == 0.0))

        print("Tensorflow execution time: " + str(tf_end-tf_start))
        print("Pytorch execution time: " + str(pt_end-pt_start))
        
        print("\n✅ `ProbabilisticDecoderPT` FINAL translation test PASSED!")

# To run in Jupyter or a script:
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.TestLoader().loadTestsFromTestCase(TestProbabilisticDecoderFinal)
runner.run(suite)

test_final_forward_pass (__main__.TestProbabilisticDecoderFinal)
Tests that the .mean() of the final transformed distributions are identical. ... 

✅ Weights transferred successfully for final test.


ok

----------------------------------------------------------------------
Ran 1 test in 0.437s

OK


Tensorflow execution time: 0.03962516784667969
Pytorch execution time: 0.2196040153503418

✅ `ProbabilisticDecoderPT` FINAL translation test PASSED!


<unittest.runner.TextTestResult run=1 errors=0 failures=0>

# Recurrent Decoder Test

In [6]:
import unittest
import numpy as np
import tcn
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.mixture import GaussianMixture
from spektral.layers import CensNetConv
from tensorflow.keras import Input, Model
from tensorflow.keras.initializers import he_uniform
from tensorflow.keras.layers import (
    GRU,
    Bidirectional,
    Dense,
    LayerNormalization,
    RepeatVector,
    TimeDistributed,
)
from tensorflow.keras.optimizers import Nadam
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import time
import deepof.model_utils
import deepof.clustering.model_utils_new
from deepof.clustering.censNetConv_pt import CensNetConvPT
import deepof.utils
from deepof.data_loading import get_dt
import warnings
from deepof.clustering.model_utils_new import ProbabilisticDecoderPT
from torch.distributions import Distribution, TransformedDistribution
from torch.distributions.transforms import AffineTransform

In [7]:
def get_recurrent_decoder(
    input_shape: tuple,
    latent_dim: int,
    gru_unroll: bool = False,
    bidirectional_merge: str = "concat",
):
    """Return a recurrent neural decoder.

    Builds a deep neural network capable of decoding the structured latent space generated by one of the compatible
    classes into a sequence of motion tracking instances, either reconstructing the original
    input, or generating new data from given clusters.

    Args:
        input_shape (tuple): shape of the input data
        latent_dim (int): dimensionality of the latent space
        gru_unroll (bool): whether to unroll the GRU layers. Defaults to False.
        bidirectional_merge (str): how to merge the forward and backward GRU layers. Defaults to "concat".

    Returns:
        keras.Model: a keras model that can be trained to decode the latent space into a series of motion tracking
        sequences.

    """
    # Define and instantiate generator
    g = Input(shape=latent_dim)  # Decoder input, shaped as the latent space
    x = Input(shape=input_shape)  # Encoder input, used to generate an output mask
    validity_mask = tf.math.logical_not(tf.reduce_all(x == 0.0, axis=2))

    generator = RepeatVector(input_shape[0])(g)
    generator = Bidirectional(
        GRU(
            latent_dim,
            activation="tanh",
            recurrent_activation="sigmoid",
            return_sequences=True,
            unroll=gru_unroll,
            use_bias=True,
        ),
        merge_mode=bidirectional_merge,
    )(generator, mask=validity_mask)
    generator = LayerNormalization()(generator)
    generator = Bidirectional(
        GRU(
            2 * latent_dim,
            activation="tanh",
            recurrent_activation="sigmoid",
            return_sequences=True,
            unroll=gru_unroll,
            use_bias=True,
        ),
        merge_mode=bidirectional_merge,
    )(generator)
    generator = LayerNormalization()(generator)
    generator = tf.keras.layers.Conv1D(
        filters=2 * latent_dim,
        kernel_size=5,
        strides=1,
        padding="same",
        activation="relu",
        kernel_initializer=he_uniform(),
        use_bias=False,
    )(generator)
    generator = LayerNormalization()(generator)

    x_decoded = deepof.model_utils.ProbabilisticDecoder(input_shape)(
        [generator, validity_mask]
    )

    return Model([g, x], x_decoded, name="recurrent_decoder")


In [8]:
class RecurrentDecoderPT(nn.Module):
    """
    A full PyTorch implementation of the recurrent decoder.
    """
    def __init__(self, output_shape: tuple, latent_dim: int, bidirectional_merge: str = "concat"):
        super().__init__()
        self.latent_dim = latent_dim
        self.output_shape = output_shape
        if bidirectional_merge != "concat":
            warnings.warn("Bidirectional merge mode is fixed to 'concat' to correspond with original TensorFlow implementation.")

        # First Bi-GRU layer
        self.gru1 = nn.GRU(
            input_size=latent_dim,
            hidden_size=latent_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        self.norm1 = nn.LayerNorm(2 * latent_dim, eps=1e-3)

        # Second Bi-GRU layer
        self.gru2 = nn.GRU(
            input_size=2 * latent_dim, # Input from first Bi-GRU
            hidden_size=2 * latent_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        self.norm2 = nn.LayerNorm(4 * latent_dim, eps=1e-3) # Output of second Bi-GRU is 2 * (2*latent_dim)

        # Convolutional Layer
        self.conv1d = nn.Conv1d(
            in_channels=4 * latent_dim, # Input from second norm layer
            out_channels=2 * latent_dim,
            kernel_size=5,
            padding="same",
            bias=False
        )
        self.norm3 = nn.LayerNorm(2 * latent_dim, eps=1e-3) # Output of Conv1D

        # Probabilistic Layer 
        self.prob_decoder = ProbabilisticDecoderPT(
            hidden_dim=2 * latent_dim, # Input from third norm layer
            data_dim=output_shape[1]
        )

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> TransformedDistribution:
        B, T, _ = x.shape

        # 1. Create the validity mask and sequence lengths from input 'x'
        validity_mask = ~torch.all(x == 0.0, dim=2)
        lengths = validity_mask.sum(dim=1).cpu().to(torch.int64)
        valid_indices = torch.where(lengths > 0)[0]

        # 2. Emulate RepeatVector
        generator = g.unsqueeze(1).expand(-1, T, -1)

        # 3. First Bi-GRU with masking
        gru1_out_full = torch.zeros(B, T, 2 * self.latent_dim, device=g.device, dtype=g.dtype)
        if len(valid_indices) > 0:
            # Apply GRU whilst ignoring masked data
            packed_input_1 = pack_padded_sequence(
                generator[valid_indices], lengths[valid_indices], batch_first=True, enforce_sorted=False
            )
            packed_output_1, _ = self.gru1(packed_input_1)
            unpacked_output_1, _ = pad_packed_sequence(
                packed_output_1, batch_first=True, total_length=T
            )
            gru1_out_full[valid_indices] = unpacked_output_1
        norm1_out = self.norm1(gru1_out_full)

        # 4. Second Bi-GRU with masking (reusing the same mask/lengths)
        gru2_out_full = torch.zeros(B, T, 4 * self.latent_dim, device=g.device, dtype=g.dtype)
        if len(valid_indices) > 0:
            packed_input_2 = pack_padded_sequence(
                norm1_out[valid_indices], lengths[valid_indices], batch_first=True, enforce_sorted=False
            )
            packed_output_2, _ = self.gru2(packed_input_2)
            unpacked_output_2, _ = pad_packed_sequence(
                packed_output_2, batch_first=True, total_length=T
            )
            gru2_out_full[valid_indices] = unpacked_output_2
        norm2_out = self.norm2(gru2_out_full)

        # 5. Convolution Block
        # Conv1d expects (B, C, T), so we permute
        conv_in = norm2_out.permute(0, 2, 1)
        conv_out = F.relu(self.conv1d(conv_in))
        # Permute back to (B, T, C) for LayerNorm
        norm3_in = conv_out.permute(0, 2, 1)
        norm3_out = self.norm3(norm3_in)

        # 6. Final Probabilistic Decoder
        final_dist = self.prob_decoder(norm3_out, validity_mask)

        return final_dist

In [9]:
# Helper function from the provided example to handle gate order differences
def permute_gru_weights(keras_weights):
    """Permutes GRU weights from Keras (z, r, n) to PyTorch (r, z, n) format."""
    W_ih, W_hh, B = keras_weights
    # Keras gate order: z, r, n (update, reset, new/candidate)
    W_ih_z, W_ih_r, W_ih_n = np.split(W_ih, 3, axis=1)
    W_hh_z, W_hh_r, W_hh_n = np.split(W_hh, 3, axis=1)

    # PyTorch gate order: r, z, n (reset, update, new/candidate)
    W_ih_pt = np.concatenate([W_ih_r, W_ih_z, W_ih_n], axis=1)
    W_hh_pt = np.concatenate([W_hh_r, W_hh_z, W_hh_n], axis=1)

    # Keras has two bias vectors (input-hidden and recurrent), which are concatenated in B
    B_ih, B_hh = B
    B_ih_z, B_ih_r, B_ih_n = np.split(B_ih, 3)
    B_hh_z, B_hh_r, B_hh_n = np.split(B_hh, 3)

    B_ih_pt = np.concatenate([B_ih_r, B_ih_z, B_ih_n])
    B_hh_pt = np.concatenate([B_hh_r, B_hh_z, B_hh_n])

    return W_ih_pt.T, W_hh_pt.T, B_ih_pt, B_hh_pt
    
def transfer_recurrent_decoder_weights(tf_model, pt_model):
    """
    Transfers weights for the full recurrent decoder model.
    """
    # Find layers by type to avoid index issues
    bidi_layers = [l for l in tf_model.layers if isinstance(l, Bidirectional)]
    norm_layers = [l for l in tf_model.layers if isinstance(l, LayerNormalization)]
    conv_layers = [l for l in tf_model.layers if isinstance(l, tf.keras.layers.Conv1D)]
    prob_dec_layer = next(l for l in tf_model.layers if isinstance(l, deepof.model_utils.ProbabilisticDecoder))

    # --- GRU 1 and Norm 1 ---
    W_ih_f1, W_hh_f1, B_ih_f1, B_hh_f1 = permute_gru_weights(bidi_layers[0].forward_layer.get_weights())
    pt_model.gru1.weight_ih_l0.data = torch.from_numpy(W_ih_f1); pt_model.gru1.weight_hh_l0.data = torch.from_numpy(W_hh_f1)
    pt_model.gru1.bias_ih_l0.data = torch.from_numpy(B_ih_f1); pt_model.gru1.bias_hh_l0.data = torch.from_numpy(B_hh_f1)
    W_ih_b1, W_hh_b1, B_ih_b1, B_hh_b1 = permute_gru_weights(bidi_layers[0].backward_layer.get_weights())
    pt_model.gru1.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b1); pt_model.gru1.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b1)
    pt_model.gru1.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b1); pt_model.gru1.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b1)
    pt_model.norm1.weight.data = torch.from_numpy(norm_layers[0].get_weights()[0]); pt_model.norm1.bias.data = torch.from_numpy(norm_layers[0].get_weights()[1])

    # --- GRU 2 and Norm 2 ---
    W_ih_f2, W_hh_f2, B_ih_f2, B_hh_f2 = permute_gru_weights(bidi_layers[1].forward_layer.get_weights())
    pt_model.gru2.weight_ih_l0.data = torch.from_numpy(W_ih_f2); pt_model.gru2.weight_hh_l0.data = torch.from_numpy(W_hh_f2)
    pt_model.gru2.bias_ih_l0.data = torch.from_numpy(B_ih_f2); pt_model.gru2.bias_hh_l0.data = torch.from_numpy(B_hh_f2)
    W_ih_b2, W_hh_b2, B_ih_b2, B_hh_b2 = permute_gru_weights(bidi_layers[1].backward_layer.get_weights())
    pt_model.gru2.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b2); pt_model.gru2.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b2)
    pt_model.gru2.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b2); pt_model.gru2.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b2)
    pt_model.norm2.weight.data = torch.from_numpy(norm_layers[1].get_weights()[0]); pt_model.norm2.bias.data = torch.from_numpy(norm_layers[1].get_weights()[1])

    # --- Conv1D and Norm 3 ---
    # TF Conv1D weights: (kernel_w, kernel_h, in_c, out_c) -> (5, 1, 4*ld, 2*ld)
    # PT Conv1d weights: (out_c, in_c, kernel_w)
    conv_weights_tf = conv_layers[0].get_weights()[0]
    pt_model.conv1d.weight.data = torch.from_numpy(conv_weights_tf).squeeze(1).permute(2, 1, 0)
    pt_model.norm3.weight.data = torch.from_numpy(norm_layers[2].get_weights()[0]); pt_model.norm3.bias.data = torch.from_numpy(norm_layers[2].get_weights()[1])

    # --- Probabilistic Decoder ---
    # TF Dense weights: (in_features, out_features)
    # PT Linear weights: (out_features, in_features)
    prob_dec_weights, prob_dec_bias = prob_dec_layer.time_distributer.get_weights()
    pt_model.prob_decoder.loc_projection.weight.data = torch.from_numpy(prob_dec_weights.T)
    pt_model.prob_decoder.loc_projection.bias.data = torch.from_numpy(prob_dec_bias)

In [10]:
class TestRecurrentDecoderTranslation(unittest.TestCase):
    def setUp(self):
        """Set up the full models and transfer weights."""
        tf.keras.backend.clear_session()
        # Make epsilon consistent between TF and PT LayerNorm
        tf.keras.backend.set_epsilon(1e-3)

        self.latent_dim = 16
        self.input_shape = (15, 8)  # (T, Features)
        self.batch_size = 4

        # Instantiate the original full TensorFlow model
        self.tf_model = get_recurrent_decoder(
            input_shape=self.input_shape,
            latent_dim=self.latent_dim,
            bidirectional_merge="concat"
        )

        # Instantiate the full PyTorch model
        self.pt_model = RecurrentDecoderPT(
            output_shape=self.input_shape,
            latent_dim=self.latent_dim
        )
        self.pt_model.eval()

        # Transfer all weights
        transfer_recurrent_decoder_weights(self.tf_model, self.pt_model)

        # Create test data WITH MASKING
        self.np_latent_input = np.random.rand(self.batch_size, self.latent_dim).astype(np.float32)
        self.np_sequence_input = np.random.rand(self.batch_size, *self.input_shape).astype(np.float32)
        # Mask some steps for sample 0
        self.np_sequence_input[0, -3:, :] = 0.0
        # Mask all steps for sample 1
        self.np_sequence_input[1, :, :] = 0.0

    def test_full_forward_pass_with_masking(self):
        """Test the full decoder translation against the original TF model."""
        # TensorFlow execution
        tf_start = time.time()
        tf_output_dist = self.tf_model([self.np_latent_input, self.np_sequence_input], training=False)
        # CORRECTED LINE: Call .mean() on the distribution object first
        tf_output_np = tf_output_dist.mean().numpy()
        tf_end = time.time()


        # PyTorch execution
        pt_latent_tensor = torch.from_numpy(self.np_latent_input)
        pt_sequence_tensor = torch.from_numpy(self.np_sequence_input)
        with torch.no_grad():
            pt_start = time.time()
            pt_dist = self.pt_model(pt_latent_tensor, pt_sequence_tensor)
            # Use the .mean property to get the tensor output
            pt_output = pt_dist.mean
        pt_output_np = pt_output.cpu().numpy()
        pt_end = time.time()

        print("Tensorflow execution time: " + str(tf_end-tf_start))
        print("Pytorch execution time: " + str(pt_end-pt_start))

        # Compare the final tensor outputs
        np.testing.assert_allclose(tf_output_np, pt_output_np, rtol=1e-5, atol=1e-4)
        print("✅ Full `RecurrentDecoderPT` translation test PASSED!")

# To run in a Python script or Jupyter notebook:
if __name__ == '__main__':
    # Add deepof and other necessary imports from the original problem description
    # Then run the test suite
    runner = unittest.TextTestRunner(verbosity=2)
    suite = unittest.TestLoader().loadTestsFromTestCase(TestRecurrentDecoderTranslation)
    runner.run(suite)

test_full_forward_pass_with_masking (__main__.TestRecurrentDecoderTranslation)
Test the full decoder translation against the original TF model. ... ok

----------------------------------------------------------------------
Ran 1 test in 1.539s

OK


Tensorflow execution time: 0.11825919151306152
Pytorch execution time: 0.022947072982788086
✅ Full `RecurrentDecoderPT` translation test PASSED!


# Recurrent Encoder Test

In [11]:
import unittest
import numpy as np
import tcn
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.mixture import GaussianMixture
from spektral.layers import CensNetConv
from tensorflow.keras import Input, Model
from tensorflow.keras.initializers import he_uniform
from tensorflow.keras.layers import (
    GRU,
    Bidirectional,
    Dense,
    LayerNormalization,
    RepeatVector,
    TimeDistributed,
)
from tensorflow.keras.optimizers import Nadam
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import time
import deepof.model_utils
import deepof.clustering.model_utils_new
from deepof.clustering.censNetConv_pt import CensNetConvPT
import deepof.utils
from deepof.data_loading import get_dt
import warnings
from deepof.clustering.model_utils_new import ProbabilisticDecoderPT, RecurrentBlockPT
from torch.distributions import Distribution, TransformedDistribution
from torch.distributions.transforms import AffineTransform

In [12]:
def get_recurrent_encoder(
    input_shape: tuple,
    edge_feature_shape: tuple,
    adjacency_matrix: np.ndarray,
    latent_dim: int,
    use_gnn: bool = True,
    gru_unroll: bool = False,
    bidirectional_merge: str = "concat",
    interaction_regularization: float = 0.0,
):
    """Return a deep recurrent neural encoder.

     Builds a neural network capable of encoding the motion tracking instances into a vector ready to be fed to
    one of the provided structured latent spaces.

    Args:
        input_shape (tuple): shape of the node features for the input data. Should be time x nodes x features.
        edge_feature_shape (tuple): shape of the adjacency matrix to use in the graph attention layers. Should be time x edges x features.
        adjacency_matrix (np.ndarray): adjacency matrix for the mice connectivity graph. Shape should be nodes x nodes.
        latent_dim (int): dimension of the latent space.
        use_gnn (bool): If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
        gru_unroll (bool): whether to unroll the GRU layers. Defaults to False.
        bidirectional_merge (str): how to merge the forward and backward GRU layers. Defaults to "concat".
        interaction_regularization (float): Regularization parameter for the interaction features.

    Returns:
        keras.Model: a keras model that can be trained to encode motion tracking instances into a vector.

    """
    # Define feature and adjacency inputs
    x = Input(shape=input_shape)
    a = Input(shape=edge_feature_shape)

    if use_gnn:
        x_reshaped = tf.transpose(
            tf.reshape(
                tf.transpose(x),
                [
                    -1,
                    adjacency_matrix.shape[-1],
                    x.shape[1],
                    input_shape[-1] // adjacency_matrix.shape[-1],
                ][::-1],
            )
        )
        a_reshaped = tf.transpose(
            tf.reshape(
                tf.transpose(a),
                [
                    -1,
                    edge_feature_shape[-1],
                    a.shape[1],
                    1,
                ][::-1],
            )
        )


    else:
        x_flat = tf.reshape(x, [-1, input_shape[0], input_shape[1] * input_shape[2]])
        x_reshaped = tf.expand_dims(x_flat, axis=1)

    # Instantiate temporal RNN block
    encoder = deepof.clustering.model_utils_new.get_recurrent_block(
        x_reshaped, latent_dim, gru_unroll, bidirectional_merge
    )(x_reshaped)


    # Instantiate spatial graph block
    if use_gnn:

        # Embed edge features too
        a_encoder = deepof.clustering.model_utils_new.get_recurrent_block(
            a_reshaped, latent_dim, gru_unroll, bidirectional_merge
        )(a_reshaped)
    
        spatial_block = CensNetConv(
            node_channels=latent_dim,
            edge_channels=latent_dim,
            activation="relu",
            node_regularizer=tf.keras.regularizers.l1(interaction_regularization),
        )

        # Process adjacency matrix
        laplacian, edge_laplacian, incidence = spatial_block.preprocess(
            adjacency_matrix
        )

        # Get and concatenate node and edge embeddings
        x_nodes, x_edges = spatial_block(
            [encoder, (laplacian, edge_laplacian, incidence), a_encoder], mask=None
        )
        

        x_nodes = tf.reshape(
            x_nodes,
            [-1, adjacency_matrix.shape[-1] * latent_dim],
        )

        x_edges = tf.reshape(
            x_edges,
            [-1, edge_feature_shape[-1] * latent_dim],
        )

        encoder = tf.concat([x_nodes, x_edges], axis=-1)

    else:
        encoder = tf.squeeze(encoder, axis=1)

    encoder_output = tf.keras.layers.Dense(latent_dim, kernel_initializer="he_uniform")(
        encoder
    )
    
    return Model([x, a], encoder_output, name="recurrent_encoder")

In [13]:
class RecurrentEncoderPT(nn.Module):
    def __init__(
        self,
        input_shape: tuple,
        edge_feature_shape: tuple,
        adjacency_matrix: np.ndarray,
        latent_dim: int,
        use_gnn: bool = True,
        interaction_regularization: float = 0.0,
    ):
        super().__init__()
        self.use_gnn = use_gnn
        self.num_nodes = adjacency_matrix.shape[0]
        self.latent_dim = latent_dim

        if self.use_gnn:
            # Node path initialization 
            node_feat_per_animal = input_shape[-1] // self.num_nodes
            self.node_recurrent_block = RecurrentBlockPT(
                input_features=node_feat_per_animal, latent_dim=latent_dim
            )

            # Edge path initialization 
            self.edge_recurrent_block = RecurrentBlockPT(
                input_features=1, latent_dim=latent_dim
            )

            self.spatial_gnn_block = CensNetConvPT(
                node_channels=latent_dim,
                edge_channels=latent_dim,
            )
            lap, edge_lap, inc = self.spatial_gnn_block.preprocess(torch.tensor(adjacency_matrix))
            self.register_buffer("laplacian", lap.float())
            self.register_buffer("edge_laplacian", edge_lap.float())
            self.register_buffer("incidence", inc.float())
            
            self.num_edges = edge_feature_shape[1]
            final_dense_in = (self.num_nodes * latent_dim) + (self.num_edges * latent_dim)
            self.final_dense = nn.Linear(final_dense_in, latent_dim)

        else: # Non-GNN path 
            in_features = input_shape[1] * input_shape[2]
            self.recurrent_block = RecurrentBlockPT(
                input_features=in_features, latent_dim=latent_dim
            )
            self.final_dense = nn.Linear(latent_dim, latent_dim)

    def forward(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        B, T, N_nodes_total, F_nodes_total = x.shape
        _, _, E_edges_total, F_edges_total = a.shape

        if self.use_gnn:
            # --- Attempt to replicate the exact TensorFlow reshape logic ---
            
            # 1. Node Path
            F_per_node = F_nodes_total // self.num_nodes
            x_t = x.permute(3, 2, 1, 0)
            target_shape_x = (F_per_node, T, self.num_nodes, -1)
            x_reshaped_t = x_t.reshape(target_shape_x)
            x_reshaped = x_reshaped_t.permute(3, 2, 1, 0)
            
            # 2. Edge Path
            a_t = a.permute(3, 2, 1, 0)
            target_shape_a = (1, T, F_edges_total, -1)
            a_reshaped_t = a_t.reshape(target_shape_a)
            a_reshaped = a_reshaped_t.permute(3, 2, 1, 0)

            # 3. Pass through Recurrent Blocks
            node_output = self.node_recurrent_block(x_reshaped)           
            edge_output = self.edge_recurrent_block(a_reshaped)
            
            # 4. GNN and Final Layers
            adj_tuple = (self.laplacian, self.edge_laplacian, self.incidence)
            x_nodes, x_edges = self.spatial_gnn_block(
                [node_output, adj_tuple, edge_output]
            )
            x_nodes=F.relu(x_nodes)
            x_edges=F.relu(x_edges)
            
            b_prime = x_nodes.shape[0]
            x_nodes_flat = x_nodes.view(b_prime, -1)
            x_edges_flat = x_edges.view(b_prime, -1)
            encoder = torch.cat([x_nodes_flat, x_edges_flat], dim=-1)
            

        else: # Non-GNN path 
            x_reshaped = x.view(B, T, N_nodes_total * F_nodes_total).unsqueeze(1)
            encoder = self.recurrent_block(x_reshaped).squeeze(1)

        return self.final_dense(encoder)

In [14]:
def transfer_recurrent_block_weights(tf_model, pt_model):
    """Transfers weights for the full recurrent block with GRU gate permutation."""
    conv_td, _, gru1_td, norm1, gru2_td, norm2 = tf_model.layers[1:]


    def permute_gru_weights(keras_weights):
        W_ih, W_hh, B = keras_weights
        W_ih_z, W_ih_r, W_ih_n = np.split(W_ih, 3, axis=1)
        W_hh_z, W_hh_r, W_hh_n = np.split(W_hh, 3, axis=1)
        W_ih_pt = np.concatenate([W_ih_r, W_ih_z, W_ih_n], axis=1)
        W_hh_pt = np.concatenate([W_hh_r, W_hh_z, W_hh_n], axis=1)
        B_ih, B_hh = B
        B_ih_z, B_ih_r, B_ih_n = np.split(B_ih, 3)
        B_hh_z, B_hh_r, B_hh_n = np.split(B_hh, 3)
        B_ih_pt = np.concatenate([B_ih_r, B_ih_z, B_ih_n])
        B_hh_pt = np.concatenate([B_hh_r, B_hh_z, B_hh_n])
        return W_ih_pt.T, W_hh_pt.T, B_ih_pt, B_hh_pt

    pt_model.conv1d.weight.data = torch.from_numpy(conv_td.layer.get_weights()[0]).permute(2, 1, 0)
    
    W_ih_f1, W_hh_f1, B_ih_f1, B_hh_f1 = permute_gru_weights(gru1_td.layer.forward_layer.get_weights())
    pt_model.gru1.weight_ih_l0.data = torch.from_numpy(W_ih_f1); pt_model.gru1.weight_hh_l0.data = torch.from_numpy(W_hh_f1); pt_model.gru1.bias_ih_l0.data = torch.from_numpy(B_ih_f1); pt_model.gru1.bias_hh_l0.data = torch.from_numpy(B_hh_f1)
    
    W_ih_b1, W_hh_b1, B_ih_b1, B_hh_b1 = permute_gru_weights(gru1_td.layer.backward_layer.get_weights())
    pt_model.gru1.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b1); pt_model.gru1.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b1); pt_model.gru1.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b1); pt_model.gru1.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b1)

    pt_model.norm1.weight.data = torch.from_numpy(norm1.get_weights()[0]); pt_model.norm1.bias.data = torch.from_numpy(norm1.get_weights()[1])

    W_ih_f2, W_hh_f2, B_ih_f2, B_hh_f2 = permute_gru_weights(gru2_td.layer.forward_layer.get_weights())
    pt_model.gru2.weight_ih_l0.data = torch.from_numpy(W_ih_f2); pt_model.gru2.weight_hh_l0.data = torch.from_numpy(W_hh_f2); pt_model.gru2.bias_ih_l0.data = torch.from_numpy(B_ih_f2); pt_model.gru2.bias_hh_l0.data = torch.from_numpy(B_hh_f2)
    
    W_ih_b2, W_hh_b2, B_ih_b2, B_hh_b2 = permute_gru_weights(gru2_td.layer.backward_layer.get_weights())
    pt_model.gru2.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b2); pt_model.gru2.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b2); pt_model.gru2.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b2); pt_model.gru2.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b2)
    
    pt_model.norm2.weight.data = torch.from_numpy(norm2.get_weights()[0]); pt_model.norm2.bias.data = torch.from_numpy(norm2.get_weights()[1])

    
def transfer_censnet_weights(tf_layer, pt_layer):
    """
    Transfers all six weights from a Spektral CensNetConv layer to the
    corresponding CensNetConvPT layer.
    """
    # Get all weights from the TensorFlow layer. The order is determined by
    # the layer's build order in Spektral's source code.
    tf_weights = tf_layer.get_weights()

    # Unpack all six weights.
    # Order: kernel_node, bias_node, kernel_edge, bias_edge, projector_node, projector_edge
    kn_tf, bn_tf, ke_tf, be_tf, pn_tf, pe_tf = tf_weights

    # 1. & 2. Transfer Node Kernel and Bias
    # Keras Dense kernel is (in_features, out_features)
    pt_layer.node_kernel.data = torch.from_numpy(kn_tf)
    pt_layer.edge_kernel.data = torch.from_numpy(bn_tf)

    # 3. & 4. Transfer Edge Kernel and Bias
    # Same transposition logic applies.
    pt_layer.node_weights.data = torch.from_numpy(ke_tf)
    pt_layer.edge_weights.data = torch.from_numpy(be_tf)

    # 5. Transfer Node Projector Weights (P_n)
    # These are [in_features, 1], which matches, so no transpose needed.
    pt_layer.node_bias.data = torch.from_numpy(pn_tf)

    # 6. Transfer Edge Projector Weights (P_e)
    # These are [in_features, 1], which matches, so no transpose needed.
    pt_layer.edge_bias.data = torch.from_numpy(pe_tf)
    

def transfer_recurrent_encoder_weights(tf_model, pt_model):
    """
    Transfers weights for the full recurrent encoder, finding layers
    by their default names and types to avoid modifying original code.
    """
    # The final dense layer is consistently the last one in the model's layer list.
    final_dense_tf = tf_model.layers[-1]
    final_dense_pt = pt_model.final_dense
    w, b = final_dense_tf.get_weights()
    final_dense_pt.weight.data = torch.from_numpy(w.T)
    final_dense_pt.bias.data = torch.from_numpy(b)

    if pt_model.use_gnn:
        # Keras automatically names nested models 'model', 'model_1', etc., by order of creation.
        # Node recurrent block is created first.
        node_recurrent_model = tf_model.get_layer("model")
        # Edge recurrent block is created second.
        edge_recurrent_model = tf_model.get_layer("model_1")
        # Find the CensNetConv layer by its class type.
        gnn_layer = next(l for l in tf_model.layers if isinstance(l, CensNetConv))

        transfer_recurrent_block_weights(node_recurrent_model, pt_model.node_recurrent_block)
        transfer_recurrent_block_weights(edge_recurrent_model, pt_model.edge_recurrent_block)
        transfer_censnet_weights(gnn_layer, pt_model.spatial_gnn_block)
    else: # Not using GNN
        # There is only one nested model, which Keras names 'model'.
        recurrent_model = tf_model.get_layer("model")
        transfer_recurrent_block_weights(recurrent_model, pt_model.recurrent_block)

In [15]:
class TestRecurrentEncoderTranslation(unittest.TestCase):
    def setUp(self):
        """Set up parameters and create random data that matches model assumptions."""
        tf.keras.backend.clear_session()
        self.latent_dim = 8
        
        # Inits
        self.b, self.t, self.n, self.f = 2, 10, 3, 12  # Batch, Time, Nodes, Features
        self.e, self.f_edge = 3, 3  # Edges, Edge Features

        self.input_shape = (self.t, self.n, self.f)
        self.edge_shape = (self.t, self.e, self.f_edge)
        self.adj_matrix = np.ones((self.n, self.n)) - np.eye(self.n)

        # Create random input data
        self.x_np = np.random.rand(self.b, self.t, self.n, self.f).astype(np.float32)
        self.a_np = np.random.rand(self.b, self.t, self.e, self.f_edge).astype(np.float32)
        
    def test_forward_pass_gnn(self):
        """Test the GNN-enabled path of the encoder."""
        # Build TF and PT models
        tf_model_gnn = get_recurrent_encoder(
            self.input_shape, self.edge_shape, self.adj_matrix, self.latent_dim, use_gnn=True
        )
        pt_model_gnn = RecurrentEncoderPT(
            self.input_shape, self.edge_shape, self.adj_matrix, self.latent_dim, use_gnn=True
        )
        pt_model_gnn.eval()

        # Run a single "dummy" forward pass on the PyTorch model.
        with torch.no_grad():
            pt_model_gnn(torch.from_numpy(self.x_np), torch.from_numpy(self.a_np))

        # Now that the weights have been initialized, we can transfer the TF values into them.
        transfer_recurrent_encoder_weights(tf_model_gnn, pt_model_gnn)

        # Execute and compare the outputs
        tf_start = time.time()
        tf_output = tf_model_gnn([self.x_np, self.a_np], training=False).numpy()
        tf_end = time.time()
        pt_start = time.time()
        with torch.no_grad():
            pt_output = pt_model_gnn(torch.from_numpy(self.x_np), torch.from_numpy(self.a_np)).detach().numpy()
        pt_end = time.time()

        print("Tensorflow execution time: " + str(tf_end-tf_start))
        print("Pytorch execution time: " + str(pt_end-pt_start))

        np.testing.assert_allclose(tf_output, pt_output, rtol=1e-5, atol=1e-4)
        print("✅ `RecurrentEncoderPT` (GNN path) translation test PASSED!")

    def test_forward_pass_no_gnn(self):
        """Test the non-GNN path of the encoder."""
        # Build TF and PT models
        tf_model_no_gnn = get_recurrent_encoder(
            self.input_shape, self.edge_shape, self.adj_matrix, self.latent_dim, use_gnn=False
        )
        pt_model_no_gnn = RecurrentEncoderPT(
            self.input_shape, self.edge_shape, self.adj_matrix, self.latent_dim, use_gnn=False
        )
        pt_model_no_gnn.eval()

        # Transfer weights
        transfer_recurrent_encoder_weights(tf_model_no_gnn, pt_model_no_gnn)

        # Execute and compare
        tf_output = tf_model_no_gnn([self.x_np, self.a_np], training=False).numpy()
        pt_output = pt_model_no_gnn(torch.from_numpy(self.x_np), torch.from_numpy(self.a_np)).detach().numpy()

        np.testing.assert_allclose(tf_output, pt_output, rtol=1e-5, atol=1e-5)
        print("✅ `RecurrentEncoderPT` (non-GNN path) translation test PASSED!")

# To run:
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.TestLoader().loadTestsFromTestCase(TestRecurrentEncoderTranslation)
runner.run(suite)

test_forward_pass_gnn (__main__.TestRecurrentEncoderTranslation)
Test the GNN-enabled path of the encoder. ... The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.
ok
test_forward_pass_no_gnn (__main__.TestRecurrentEncoderTranslation)
Test the non-GNN path of the encoder. ... 

Tensorflow execution time: 0.09568667411804199
Pytorch execution time: 0.003713369369506836
✅ `RecurrentEncoderPT` (GNN path) translation test PASSED!


ok

----------------------------------------------------------------------
Ran 2 tests in 8.027s

OK


✅ `RecurrentEncoderPT` (non-GNN path) translation test PASSED!


<unittest.runner.TextTestResult run=2 errors=0 failures=0>

# Gaussian Mixture Latent

In [16]:
import unittest
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense
import tensorflow_probability as tfp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from typing import List, Tuple, Dict
import time
import deepof.model_utils

tfd = tfp.distributions

In [17]:
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Nadam
import tensorflow_probability as tfp
from tensorflow.keras.layers import Layer # Assuming ClusterControl inherits from this
from typing import List

# These are placeholders for the external utilities used in the original model
# to make the class definition self-contained and runnable.
class ClusterControl(Layer):
    """Placeholder for the custom deepof.model_utils.ClusterControl layer."""
    def __init__(self, batch_size, n_components, encoding_dim, k, **kwargs):
        super().__init__(**kwargs)
    def call(self, inputs: List[tf.Tensor]) -> tf.Tensor:
        # The layer is pass-through for the latent vector
        return inputs[0]

def compute_kmeans_loss(latent_means: tf.Tensor, weight: float, batch_size: int) -> tf.Tensor:
    """Placeholder for the custom deepof.model_utils.compute_kmeans_loss function."""
    gram_matrix = (tf.transpose(latent_means) @ latent_means) / tf.cast(batch_size, tf.float32)
    s = tf.linalg.svd(gram_matrix, compute_uv=False)
    s = tf.sqrt(tf.maximum(s, 1e-9))
    return weight * tf.reduce_mean(s)

# TensorFlow Probability layers
tfpl = tfp.layers
tfd = tfp.distributions


class GaussianMixtureLatent(tf.keras.models.Model):
    """Gaussian Mixture probabilistic latent space model.

    Used to represent the embedding of motion tracking data in a mixture of Gaussians
    with a provided number of components, with means, covariances and weights.
    Implementation based on VaDE (https://arxiv.org/abs/1611.05148)
    and VaDE-SC (https://openreview.net/forum?id=RQ428ZptQfU).

    """

    def __init__(
        self,
        input_shape: tuple,
        n_components: int,
        latent_dim: int,
        batch_size: int,
        kl_warmup: int = 5,
        kl_annealing_mode: str = "linear",
        mc_kl: int = 100,
        mmd_warmup: int = 15,
        mmd_annealing_mode: str = "linear",
        kmeans_loss: float = 0.0,
        reg_cluster_variance: bool = False,
        **kwargs,
    ):
        """Initialize the Gaussian Mixture Latent layer.

        Args:
            input_shape (tuple): shape of the input data
            n_components (int): number of components in the Gaussian mixture.
            latent_dim (int): dimensionality of the latent space.
            batch_size (int): batch size for training.
            kl_warmup (int): number of epochs to warm up the KL divergence.
            kl_annealing_mode (str): mode to use for annealing the KL divergence. Must be one of "linear" and "sigmoid".
            mc_kl (int): number of Monte Carlo samples to use for computing the KL divergence.
            mmd_warmup (int): number of epochs to warm up the MMD.
            mmd_annealing_mode (str): mode to use for annealing the MMD. Must be one of "linear" and "sigmoid".
            kmeans_loss (float): weight of the Gram matrix regularization loss.
            reg_cluster_variance (bool): whether to penalize uneven cluster variances in the latent space.
            **kwargs: keyword arguments passed to the parent class

        """
        super(GaussianMixtureLatent, self).__init__(**kwargs)
        self.seq_shape = input_shape[0] 
        self.n_components = n_components
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.kl_warmup = kl_warmup
        self.kl_annealing_mode = kl_annealing_mode
        self.mc_kl = mc_kl
        self.mmd_warmup = mmd_warmup
        self.mmd_annealing_mode = mmd_annealing_mode
        self.kmeans = kmeans_loss
        self.optimizer = Nadam(learning_rate=1e-3, clipvalue=0.75)
        self.reg_cluster_variance = reg_cluster_variance
        self.pretrain = tf.Variable(0.0, name="pretrain", trainable=False)

        # Initialize GM parameters
        self.c_mu = tf.Variable(
            tf.initializers.GlorotNormal()(shape=[self.n_components, self.latent_dim]),
            name="mu_c",
        )
        self.log_c_sigma = tf.Variable(
            tf.initializers.GlorotNormal()([self.n_components, self.latent_dim]),
            name="log_sigma_c",
        )

        # Initialize the Gaussian Mixture prior with the specified number of components
        self.prior = tf.constant(tf.ones([self.n_components]) * (1 / self.n_components))

        # Initialize layers
        self.z_gauss_mean = Dense(
            tfpl.IndependentNormal.params_size(self.latent_dim) // 2,
            name="cluster_means",
            activation="linear",
            kernel_initializer="glorot_uniform",
            activity_regularizer=None,
        )
        self.z_gauss_var = Dense(
            tfpl.IndependentNormal.params_size(self.latent_dim) // 2,
            name="cluster_variances",
            activation="softplus",
            kernel_initializer="glorot_uniform",
            activity_regularizer=tf.keras.regularizers.l1(0.1),
        )

        self.cluster_control_layer = deepof.model_utils.ClusterControl(
            batch_size=self.batch_size,
            n_components=self.n_components,
            encoding_dim=self.latent_dim,
            k=self.n_components,
        )

        # control KL weight
        self.kl_warm_up_iters = tf.cast(
            self.kl_warmup * (self.seq_shape // self.batch_size), tf.int64
        )
        self._kl_weight = tf.Variable(
            1.0, trainable=False, dtype=tf.float32, name="kl_weight"
        )

    def call(self, inputs, training=False, epsilon=None, return_all_outputs_for_testing=False): # pragma: no cover
        """Compute the output of the layer."""
        z_gauss_mean = self.z_gauss_mean(inputs)
        z_gauss_var = self.z_gauss_var(inputs)

        if epsilon is not None:
            # Use deterministic reparameterization for testing
            z_sample = z_gauss_mean + tf.math.sqrt(tf.math.exp(z_gauss_var)) * epsilon
        else:
            # Original stochastic sampling for production
            z_dist = tfd.MultivariateNormalDiag(
                loc=z_gauss_mean, scale_diag=tf.math.sqrt(tf.math.exp(z_gauss_var))
            )
            z_sample = tf.squeeze(z_dist.sample())

        # Compute embedding probabilities given each cluster
        p_z_c = tf.stack(
            [
                tfd.MultivariateNormalDiag(
                    loc=self.c_mu[i, :],
                    scale_diag=tf.math.exp(self.log_c_sigma)[i, :],
                ).log_prob((z_sample if training else z_gauss_mean))
                + 1e-6
                for i in range(self.n_components)
            ],
            axis=-1,
        )

        # Update prior
        prior = self.prior

        # Compute cluster probabilitie given embedding
        z_cat = tf.math.log(prior + 1e-6) + p_z_c
        z_cat = tf.nn.log_softmax(z_cat, axis=-1)
        z_cat = tf.math.exp(z_cat)

        # Add clustering loss
        loss_clustering = -tf.reduce_sum(
            tf.multiply(z_cat, tf.math.softmax(p_z_c, axis=-1)), axis=-1
        ) * (1.0 - tf.cast(self.pretrain, tf.float32))
        loss_prior = -tf.math.reduce_sum(
            tf.math.xlogy(z_cat, 1e-6 + prior), axis=-1
        ) * (1.0 - tf.cast(self.pretrain, tf.float32))

        #self.add_metric(loss_clustering, name="clustering_loss", aggregation="mean")
        #self.add_metric(loss_prior, name="prior_loss", aggregation="mean")

        # Update KL weight based on the current iteration
        if self.kl_warm_up_iters > 0:
            if self.kl_annealing_mode in ["linear", "sigmoid"]:
                self._kl_weight = tf.cast(
                    tf.keras.backend.min(
                        [self.optimizer.iterations / self.kl_warm_up_iters, 1.0]
                    ),
                    tf.float32,
                )
                if self.kl_annealing_mode == "sigmoid":
                    self._kl_weight = tf.math.sigmoid(
                        (2 * self._kl_weight - 1)
                        / (self._kl_weight - self._kl_weight**2)
                    )
            else:
                raise NotImplementedError(
                    "annealing_mode must be one of 'linear' and 'sigmoid'"
                )
        else:
            self._kl_weight = tf.cast(1.0, tf.float32)

        loss_variational_1 = -1 / 2 * tf.reduce_sum(z_gauss_var + 1, axis=-1)
        loss_variational_2 = tf.math.reduce_sum(
            tf.math.xlogy(z_cat, 1e-6 + z_cat), axis=-1
        )
        kl = loss_variational_1 + loss_variational_2 * (
            1.0 - tf.cast(self.pretrain, tf.float32)
        )
        kl_batch = self._kl_weight * kl

        #self.add_metric(self._kl_weight, aggregation="mean", name="kl_weight")
        #self.add_metric(kl, aggregation="mean", name="kl_divergence")

        #self.add_loss(tf.math.reduce_mean(loss_clustering))
        #self.add_loss(tf.math.reduce_mean(loss_prior))
        #self.add_loss(tf.math.reduce_mean(kl_batch))


        # Calculate metrics for potential return
        hard_groups = tf.math.argmax(z_cat, axis=1)
        max_groups = tf.reduce_max(z_cat, axis=1)
        n_populated = tf.cast(tf.shape(tf.unique(tf.reshape(hard_groups, [-1]))[0])[0], tf.float32)
        confidence = tf.reduce_mean(max_groups)

        z = z_sample if training else z_gauss_mean

        if self.n_components > 1:
            z = self.cluster_control_layer([z, z_cat])

        k_loss = 0.0
        if self.kmeans:
            k_loss = deepof.model_utils.compute_kmeans_loss(z, weight=self.kmeans, batch_size=self.batch_size)
            #self.add_loss(k_loss)
            #self.add_metric(k_loss, name="kmeans_loss")

        # MODIFIED: Add a switch for the return value
        if return_all_outputs_for_testing:
            # In test mode, return all computed values for direct comparison
            return z, z_cat, n_populated, confidence, k_loss
        else:
            # In production mode, use side effects (add_loss/add_metric) and return the original signature
            loss_clustering = -tf.reduce_sum(tf.multiply(z_cat, tf.math.softmax(p_z_c, axis=-1)), axis=-1) * (1.0 - tf.cast(self.pretrain, tf.float32))
            loss_prior = -tf.reduce_sum(tf.math.xlogy(z_cat, 1e-6 + self.prior), axis=-1) * (1.0 - tf.cast(self.pretrain, tf.float32))
            self.add_metric(loss_clustering, name="clustering_loss", aggregation="mean")
            self.add_metric(loss_prior, name="prior_loss", aggregation="mean")

            self.add_metric(self._kl_weight, aggregation="mean", name="kl_weight")
            self.add_metric(kl, aggregation="mean", name="kl_divergence")

            self.add_loss(tf.math.reduce_mean(loss_clustering))
            self.add_loss(tf.math.reduce_mean(loss_prior))
            self.add_loss(tf.math.reduce_mean(kl_batch))

            if self.kmeans:
                self.add_loss(k_loss)
                self.add_metric(k_loss, name="kmeans_loss")

            # ... all other add_loss and add_metric calls from the original ...
            return z, z_cat

In [18]:
from typing import Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

class ClusterControlPT(nn.Module):
    """
    Calculates clustering metrics. This is a pass-through layer for the main
    latent vector `z`, returning it unmodified alongside a dictionary of metrics.
    """
    def __init__(self):
        super().__init__()

    def forward(
        self, z: torch.Tensor, z_cat: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Calculates metrics and passes the latent vector `z` through.

        Args:
            z: The latent vector (batch_size, latent_dim).
            z_cat: Cluster probabilities (batch_size, n_components).

        Returns:
            A tuple containing the unmodified `z` and a dictionary of metrics.
        """
        confidence, hard_groups = torch.max(z_cat, dim=1)
        
        # Calculate the number of unique clusters populated in the batch
        num_populated = torch.unique(hard_groups).numel()
        
        metrics = {
            "number_of_populated_clusters": torch.tensor(
                float(num_populated), device=z.device
            ),
            "confidence_in_selected_cluster": torch.mean(confidence),
        }
        
        return z, metrics

def compute_kmeans_loss_pt(latent_means: torch.Tensor, weight: float) -> torch.Tensor:
    """
    Computes a loss based on the singular values of the Gram matrix of the
    latent vectors, encouraging orthogonality.

    Args:
        latent_means: The latent vectors from the model (batch_size, latent_dim).
        weight: The weight to apply to this loss component.

    Returns:
        The calculated scalar loss tensor.
    """
    batch_size = float(latent_means.shape[0])
    gram_matrix = (latent_means.T @ latent_means) / batch_size
    
    # Compute singular values, which are the square roots of the eigenvalues for a symmetric matrix
    singular_values = torch.linalg.svdvals(gram_matrix)
    
    # Clamp to avoid NaN gradients from sqrt(0)
    penalization = torch.sqrt(torch.clamp(singular_values, min=1e-9))
    
    return weight * torch.mean(penalization)


class GaussianMixtureLatentPT(nn.Module):
    """
    PyTorch implementation of the Gaussian Mixture probabilistic latent space model.
    It embeds data into a latent space and models that space as a mixture of Gaussians.
    """
    def __init__(
        self,
        input_dim: int,
        n_components: int,
        latent_dim: int,
        kmeans: float,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.n_components = n_components
        self.latent_dim = latent_dim
        self.kmeans_weight = kmeans

        # --- Trainable Parameters for the GMM components ---
        self.gmm_means = nn.Parameter(torch.empty(n_components, latent_dim))
        self.gmm_log_vars = nn.Parameter(torch.empty(n_components, latent_dim))
        nn.init.xavier_normal_(self.gmm_means)
        nn.init.xavier_normal_(self.gmm_log_vars)

        # --- Encoder Layers to produce the latent distribution ---
        self.encoder_mean = nn.Linear(self.input_dim, self.latent_dim)
        self.encoder_log_var = nn.Linear(self.input_dim, self.latent_dim)

        # --- Non-trainable Buffers ---
        self.register_buffer('prior', torch.ones(n_components) / n_components)
        self.register_buffer('pretrain', torch.tensor(0.0))
        
        # --- Helper Layers ---
        self.cluster_control = ClusterControlPT()

    def _encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encodes the input into mean and log-variance of the latent distribution."""
        z_mean = self.encoder_mean(x)
        z_log_var = self.encoder_log_var(x) # Note: softplus is applied in the forward pass
        return z_mean, z_log_var

    def _reparameterize(
        self, mean: torch.Tensor, var: torch.Tensor, epsilon: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Performs reparameterization.
        MODIFIED to exactly replicate the original TF model's non-standard scale calculation.
        """
        # Original TF logic: scale = sqrt(exp(variance))
        # The 'var' input here is the direct output of the softplus activation.
        scale = torch.sqrt(torch.exp(var))
        
        if epsilon is None:
            epsilon = torch.randn_like(scale)
        return mean + scale * epsilon

    def _calculate_posterior(self, z: torch.Tensor) -> torch.Tensor:
        """Calculates the posterior probability p(c|z) for each sample."""
        # MODIFIED: The GMM parameters from TF are log-std-dev, not log-variance.
        # So we just exponentiate them to get the scale.
        gmm_scale = torch.exp(self.gmm_log_vars)

        gmm_dist = Normal(
            loc=self.gmm_means.unsqueeze(0),
            scale=gmm_scale.unsqueeze(0)
        )
        log_p_z_given_c = gmm_dist.log_prob(z.unsqueeze(1)).sum(dim=-1)
        
        log_p_c_given_z = torch.log(self.prior + 1e-9) + log_p_z_given_c
        
        return F.softmax(log_p_c_given_z, dim=-1)

    def forward(
        self, x: torch.Tensor, epsilon: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        
        z_mean, z_var_raw = self._encode(x)
        z_var = F.softplus(z_var_raw) # Apply activation

        # Pass z_var directly, not z_log_var
        z_sample = self._reparameterize(z_mean, z_var, epsilon)
        # ... rest of the method is the same ...
        z_for_downstream = z_sample if self.training else z_mean
        z_cat = self._calculate_posterior(z_for_downstream)
        z_final, metrics = self.cluster_control(z_for_downstream, z_cat)
        kmeans_loss = torch.tensor(0.0, device=x.device)
        if self.kmeans_weight > 0:
            kmeans_loss = compute_kmeans_loss_pt(z_final, weight=self.kmeans_weight)
        return (z_final, z_cat, metrics["number_of_populated_clusters"], metrics["confidence_in_selected_cluster"], kmeans_loss)

In [19]:
def transfer_gmm_weights(tf_model, pt_model: GaussianMixtureLatentPT):
    """
    Transfers weights from the final TF model to the refactored PT model,
    using the updated attribute names.
    """
    # --- Transfer GMM component parameters ---
    # OLD: pt_model.c_mu
    pt_model.gmm_means.data = torch.from_numpy(tf_model.c_mu.numpy())
    # OLD: pt_model.log_c_sigma
    pt_model.gmm_log_vars.data = torch.from_numpy(tf_model.log_c_sigma.numpy())

    # --- Transfer Encoder layer parameters ---
    tf_mean_weights = tf_model.z_gauss_mean.get_weights()
    # OLD: pt_model.z_gauss_mean
    pt_model.encoder_mean.weight.data = torch.from_numpy(tf_mean_weights[0].T)
    pt_model.encoder_mean.bias.data = torch.from_numpy(tf_mean_weights[1])
    
    tf_var_weights = tf_model.z_gauss_var.get_weights()
    # OLD: pt_model.z_gauss_var
    pt_model.encoder_log_var.weight.data = torch.from_numpy(tf_var_weights[0].T)
    pt_model.encoder_log_var.bias.data = torch.from_numpy(tf_var_weights[1])

In [20]:
class TestGMMFinalSimplified(unittest.TestCase):
    def setUp(self):
        self.input_dim, self.latent_dim, self.n_components, self.batch_size = 64, 16, 5, 4
        self.seq_shape = self.batch_size * 100
        self.kmeans_weight = 0.1

        tf.keras.backend.clear_session()
        # Instantiate the *actual* final TF model
        self.tf_model = GaussianMixtureLatent(
            input_shape=(self.seq_shape, self.input_dim),
            n_components=self.n_components,
            latent_dim=self.latent_dim,
            batch_size=self.batch_size,
            kmeans_loss=self.kmeans_weight
        )
        # Build the model using the test-mode signature
        self.tf_model(
            tf.zeros((1, self.input_dim)), 
            epsilon=tf.zeros((1, self.latent_dim)),
            return_all_outputs_for_testing=True
        )

        # PyTorch model setup remains the same
        self.pt_model = GaussianMixtureLatentPT(
            input_dim=self.input_dim, n_components=self.n_components,
            latent_dim=self.latent_dim, kmeans=self.kmeans_weight
        )
        
        transfer_gmm_weights(self.tf_model, self.pt_model)
        
        self.np_input = np.random.rand(self.batch_size, self.input_dim).astype(np.float32)
        seed = 42
        np.random.seed(seed)
        epsilon_np = np.random.randn(self.batch_size, self.latent_dim).astype(np.float32)
        self.epsilon_tf = tf.convert_to_tensor(epsilon_np)
        self.epsilon_pt = torch.from_numpy(epsilon_np)

    def run_comparison_test(self, training_mode: bool):
        mode_str = "TRAINING" if training_mode else "EVALUATION"
        print(f"\n--- Testing final integration in {mode_str} mode ---")
        
        self.pt_model.train(training_mode)

        tf_start = time.time()
        # Call the TF model with test flags enabled
        tf_z, tf_z_cat, tf_n_pop, tf_conf, tf_kmeans = self.tf_model(
            self.np_input, 
            training=training_mode, 
            epsilon=self.epsilon_tf, 
            return_all_outputs_for_testing=True
        )
        tf_end = time.time()
        
        pt_start = time.time()
        # PyTorch call remains the same
        with torch.no_grad():
            pt_z, pt_z_cat, pt_n_pop, pt_conf, pt_kmeans = self.pt_model(
                torch.from_numpy(self.np_input), epsilon=self.epsilon_pt
            )
        pt_end = time.time()
        
        print("Tensorflow execution time: " + str(tf_end-tf_start))
        print("Pytorch execution time: " + str(pt_end-pt_start))
        
        print("Comparing all outputs...")
        np.testing.assert_allclose(tf_z.numpy(), pt_z.numpy(), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(tf_z_cat.numpy(), pt_z_cat.numpy(), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(tf_n_pop.numpy(), pt_n_pop.numpy(), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(tf_conf.numpy(), pt_conf.numpy(), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(tf_kmeans.numpy(), pt_kmeans.numpy(), rtol=1e-5, atol=1e-5)
        print(f"✅ Final integration in {mode_str} mode PASSED!")

    def test_final_pass_train_mode(self):
        self.run_comparison_test(training_mode=True)
    
    def test_final_pass_eval_mode(self):
        self.run_comparison_test(training_mode=False)


# Run the test
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.TestLoader().loadTestsFromTestCase(TestGMMFinalSimplified)
runner.run(suite)

test_final_pass_eval_mode (__main__.TestGMMFinalSimplified) ... ok
test_final_pass_train_mode (__main__.TestGMMFinalSimplified) ... 


--- Testing final integration in EVALUATION mode ---
Tensorflow execution time: 0.053191423416137695
Pytorch execution time: 0.006725311279296875
Comparing all outputs...
✅ Final integration in EVALUATION mode PASSED!

--- Testing final integration in TRAINING mode ---


ok

----------------------------------------------------------------------
Ran 2 tests in 0.625s

OK


Tensorflow execution time: 0.05488920211791992
Pytorch execution time: 0.0
Comparing all outputs...
✅ Final integration in TRAINING mode PASSED!


<unittest.runner.TextTestResult run=2 errors=0 failures=0>

# Get Vade

In [21]:
import unittest
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense
import tensorflow_probability as tfp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from typing import List, Tuple, Dict, Callable
import time
import deepof.model_utils
from deepof.model_utils import ClusterControl, compute_kmeans_loss, CensNetConv, ProbabilisticDecoder
from deepof.models import get_recurrent_encoder, get_recurrent_decoder, GaussianMixtureLatent, get_TCN_encoder, get_TCN_decoder, get_transformer_encoder, get_transformer_decoder
from deepof.clustering.models_new import RecurrentEncoderPT, RecurrentDecoderPT, GaussianMixtureLatentPT
from tensorflow.keras.layers import (
    GRU,
    Bidirectional,
    Dense,
    LayerNormalization,
    RepeatVector,
    TimeDistributed,
)

tfd = tfp.distributions

In [22]:
def get_vade(
    input_shape: tuple,
    edge_feature_shape: tuple,
    adjacency_matrix: np.ndarray,
    latent_dim: int,
    use_gnn: bool,
    n_components: int,
    batch_size: int = 64,
    kl_warmup: int = 15,
    kl_annealing_mode: str = "sigmoid",
    mc_kl: int = 100,
    kmeans_loss: float = 1.0,
    reg_cluster_variance: bool = False,
    encoder_type: str = "recurrent",
    interaction_regularization: float = 0.0,
):
    """Build a Gaussian mixture variational autoencoder (VaDE) model, adapted to the DeepOF setting.

    Args:
        input_shape (tuple): shape of the input data.
        edge_feature_shape (tuple): shape of the edge feature matrix used for graph representations.
        adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph to use.
        latent_dim (int): dimensionality of the latent space.
        use_gnn (bool): If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
        n_components (int): number of components in the Gaussian mixture.
        batch_size (int): batch size for training.
        kl_warmup (int): Number of iterations during which to warm up the KL divergence.
        kl_annealing_mode (str): mode to use for annealing the KL divergence. Must be one of "linear" and "sigmoid".
        mc_kl (int): number of Monte Carlo samples to use for computing the KL divergence.
        kmeans_loss (float): weight of the Gram matrix loss as described in deepof.model_utils.compute_kmeans_loss.
        reg_cluster_variance (bool): whether to penalize uneven cluster variances in the latent space.
        encoder_type (str): type of encoder to use. Can be set to "recurrent" (default), "TCN", or "transformer".
        interaction_regularization (float): weight of the interaction regularization term.

    Returns:
        encoder (tf.keras.Model): connected encoder of the VQ-VAE model. Outputs a vector of shape (latent_dim,).
        decoder (tf.keras.Model): connected decoder of the VQ-VAE model.
        grouper (tf.keras.Model): deep clustering branch of the VQ-VAE model. Outputs a vector of shape (n_components,) for each training instance, corresponding to the soft counts for each cluster.
        vade (tf.keras.Model): complete VaDE model

    """
    if encoder_type == "recurrent":
        encoder = get_recurrent_encoder(
            input_shape=input_shape[1:],
            adjacency_matrix=adjacency_matrix,
            edge_feature_shape=edge_feature_shape[1:],
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )
        decoder = get_recurrent_decoder(
            input_shape=input_shape[1:], latent_dim=latent_dim
        )

    elif encoder_type == "TCN":
        encoder = get_TCN_encoder(
            input_shape=input_shape[1:],
            adjacency_matrix=adjacency_matrix,
            edge_feature_shape=edge_feature_shape[1:],
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )
        decoder = get_TCN_decoder(input_shape=input_shape[1:], latent_dim=latent_dim)

    elif encoder_type == "transformer":
        encoder = get_transformer_encoder(
            input_shape[1:],
            edge_feature_shape=edge_feature_shape[1:],
            adjacency_matrix=adjacency_matrix,
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )
        decoder = get_transformer_decoder(input_shape[1:], latent_dim=latent_dim)

    latent_space = GaussianMixtureLatent(
        input_shape=input_shape[0],
        n_components=n_components,
        latent_dim=latent_dim,
        batch_size=batch_size,
        kl_warmup=kl_warmup,
        kl_annealing_mode=kl_annealing_mode,
        mc_kl=mc_kl,
        kmeans_loss=kmeans_loss,
        reg_cluster_variance=reg_cluster_variance,
        name="gaussian_mixture_latent",
    )

    # Connect encoder and latent space
    inputs = Input(input_shape[1:])
    a = tf.keras.layers.Input(edge_feature_shape[1:], name="encoder_edge_features")
    encoder_outputs = encoder([inputs, a])
    latent, categorical = latent_space(encoder_outputs)
    embedding = tf.keras.Model([inputs, a], latent, name="encoder")
    grouper = tf.keras.Model([inputs, a], categorical, name="grouper")

    # Connect decoder
    vade_outputs = decoder([embedding.outputs, inputs])

    # Instantiate fully connected model
    vade = tf.keras.Model(embedding.inputs, vade_outputs, name="VaDE")

    return embedding, decoder, grouper, vade


In [23]:
import torch
import torch.nn as nn
import numpy as np
from typing import Tuple

# Assume the following translated blocks are imported and available:
# from deepof.clustering.models_new import (
#     RecurrentEncoderPT, RecurrentDecoderPT, GaussianMixtureLatentPT
# )
# And their corresponding TensorFlow versions and weight transfer functions are also available.

class VaDEPT(nn.Module):
    """
    A self-contained PyTorch implementation of the VaDE model.

    This class encapsulates the entire VaDE architecture, including the encoder,
    the Gaussian mixture latent space, and the decoder. It is instantiated with
    all necessary configuration parameters, building its sub-modules internally.
    This provides a clean, single-object interface for the model.
    """
    def __init__(
        self,
        input_shape: tuple,
        edge_feature_shape: tuple,
        adjacency_matrix: np.ndarray,
        latent_dim: int,
        n_components: int,
        use_gnn: bool = True,
        kmeans_loss: float = 1.0,
        interaction_regularization: float = 0.0,
    ):
        """
        Initializes and builds the VaDE model and its components.

        Args:
            input_shape (tuple): Shape of the input node features (Time, Nodes, Features_per_node).
            edge_feature_shape (tuple): Shape of the edge features (Time, Edges, Features_per_edge).
            adjacency_matrix (np.ndarray): Adjacency matrix of the connectivity graph.
            latent_dim (int): Dimensionality of the latent space.
            n_components (int): Number of components in the Gaussian mixture.
            use_gnn (bool): If True, use the GNN-based encoder.
            kmeans_loss (float): Weight of the k-means style loss in the latent space.
            interaction_regularization (float): Regularization for GNN interaction features.
        """
        super().__init__()
        
        # Store key dimensions for internal use (e.g., reshaping in forward pass)
        time_steps, n_nodes, n_features_per_node = input_shape
        self.input_n_nodes = n_nodes
        self.input_n_features_per_node = n_features_per_node

        # 1. Instantiate Encoder
        self.encoder = RecurrentEncoderPT(
            input_shape=input_shape,
            edge_feature_shape=edge_feature_shape,
            adjacency_matrix=adjacency_matrix,
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )

        # 2. Instantiate Latent Space
        self.latent_space = GaussianMixtureLatentPT(
            input_dim=latent_dim,
            n_components=n_components,
            latent_dim=latent_dim,
            kmeans=kmeans_loss,
        )

        # 3. Instantiate Decoder
        decoder_output_features = n_nodes * n_features_per_node
        self.decoder = RecurrentDecoderPT(
            output_shape=(time_steps, decoder_output_features),
            latent_dim=latent_dim,
        )

    def forward(
        self, x: torch.Tensor, a: torch.Tensor
    ) -> Tuple[torch.distributions.Distribution, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Defines the full forward pass for the VaDE model (training and evaluation).

        Args:
            x (torch.Tensor): Input node features tensor (B, T, N, F_node).
            a (torch.Tensor): Input edge features tensor (B, T, E, F_edge).

        Returns:
            A tuple containing:
            - reconstruction_dist (torch.distributions.Distribution): The output distribution from the decoder.
            - latent (torch.Tensor): The sampled latent representation from the GMM space.
            - categorical (torch.Tensor): The cluster probabilities (soft assignments).
            - kmeans_loss (torch.Tensor): The k-means regularization loss from the latent space.
        """
        # 1. Encode the input to get the pre-latent representation
        encoder_output = self.encoder(x, a)
        
        # 2. Pass through GMM latent space
        latent, categorical, _, _, kmeans_loss, gmm_params = self.latent_space(encoder_output)
        
        # 3. Decode the latent sample back to the original data space
        # Reshape x to (B, T, N*F) for the decoder's masking logic
        B, T, _, _ = x.shape
        x_for_decoder = x.view(B, T, self.input_n_nodes * self.input_n_features_per_node)
        
        reconstruction_dist = self.decoder(latent, x_for_decoder)
        
        return reconstruction_dist, latent, categorical, kmeans_loss

    @torch.no_grad()
    def embed(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """
        Inference-only method to get the latent embedding. Equivalent to the 'embedding' Keras model.

        Args:
            x (torch.Tensor): Input node features tensor.
            a (torch.Tensor): Input edge features tensor.

        Returns:
            torch.Tensor: The latent representation `z`.
        """
        encoder_output = self.encoder(x, a)
        latent, _, _, _, _, _ = self.latent_space(encoder_output)
        return latent

    @torch.no_grad()
    def group(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """
        Inference-only method to get cluster probabilities. Equivalent to the 'grouper' Keras model.

        Args:
            x (torch.Tensor): Input node features tensor.
            a (torch.Tensor): Input edge features tensor.

        Returns:
            torch.Tensor: The soft cluster assignments (categorical probabilities).
        """
        encoder_output = self.encoder(x, a)
        _, categorical, _, _, _, _ = self.latent_space(encoder_output)
        return categorical

In [24]:
def transfer_recurrent_block_weights(tf_model, pt_model):
    """Transfers weights for the full recurrent block with GRU gate permutation."""
    conv_td, _, gru1_td, norm1, gru2_td, norm2 = tf_model.layers[1:]


    def permute_gru_weights(keras_weights):
        W_ih, W_hh, B = keras_weights
        W_ih_z, W_ih_r, W_ih_n = np.split(W_ih, 3, axis=1)
        W_hh_z, W_hh_r, W_hh_n = np.split(W_hh, 3, axis=1)
        W_ih_pt = np.concatenate([W_ih_r, W_ih_z, W_ih_n], axis=1)
        W_hh_pt = np.concatenate([W_hh_r, W_hh_z, W_hh_n], axis=1)
        B_ih, B_hh = B
        B_ih_z, B_ih_r, B_ih_n = np.split(B_ih, 3)
        B_hh_z, B_hh_r, B_hh_n = np.split(B_hh, 3)
        B_ih_pt = np.concatenate([B_ih_r, B_ih_z, B_ih_n])
        B_hh_pt = np.concatenate([B_hh_r, B_hh_z, B_hh_n])
        return W_ih_pt.T, W_hh_pt.T, B_ih_pt, B_hh_pt

    pt_model.conv1d.weight.data = torch.from_numpy(conv_td.layer.get_weights()[0]).permute(2, 1, 0)
    
    W_ih_f1, W_hh_f1, B_ih_f1, B_hh_f1 = permute_gru_weights(gru1_td.layer.forward_layer.get_weights())
    pt_model.gru1.weight_ih_l0.data = torch.from_numpy(W_ih_f1); pt_model.gru1.weight_hh_l0.data = torch.from_numpy(W_hh_f1); pt_model.gru1.bias_ih_l0.data = torch.from_numpy(B_ih_f1); pt_model.gru1.bias_hh_l0.data = torch.from_numpy(B_hh_f1)
    
    W_ih_b1, W_hh_b1, B_ih_b1, B_hh_b1 = permute_gru_weights(gru1_td.layer.backward_layer.get_weights())
    pt_model.gru1.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b1); pt_model.gru1.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b1); pt_model.gru1.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b1); pt_model.gru1.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b1)

    pt_model.norm1.weight.data = torch.from_numpy(norm1.get_weights()[0]); pt_model.norm1.bias.data = torch.from_numpy(norm1.get_weights()[1])

    W_ih_f2, W_hh_f2, B_ih_f2, B_hh_f2 = permute_gru_weights(gru2_td.layer.forward_layer.get_weights())
    pt_model.gru2.weight_ih_l0.data = torch.from_numpy(W_ih_f2); pt_model.gru2.weight_hh_l0.data = torch.from_numpy(W_hh_f2); pt_model.gru2.bias_ih_l0.data = torch.from_numpy(B_ih_f2); pt_model.gru2.bias_hh_l0.data = torch.from_numpy(B_hh_f2)
    
    W_ih_b2, W_hh_b2, B_ih_b2, B_hh_b2 = permute_gru_weights(gru2_td.layer.backward_layer.get_weights())
    pt_model.gru2.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b2); pt_model.gru2.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b2); pt_model.gru2.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b2); pt_model.gru2.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b2)
    
    pt_model.norm2.weight.data = torch.from_numpy(norm2.get_weights()[0]); pt_model.norm2.bias.data = torch.from_numpy(norm2.get_weights()[1])

    
def transfer_censnet_weights(tf_layer, pt_layer):
    """
    Transfers all six weights from a Spektral CensNetConv layer to the
    corresponding CensNetConvPT layer.
    """
    # Get all weights from the TensorFlow layer. The order is determined by
    # the layer's build order in Spektral's source code.
    tf_weights = tf_layer.get_weights()

    # Unpack all six weights.
    # Order: kernel_node, bias_node, kernel_edge, bias_edge, projector_node, projector_edge
    kn_tf, bn_tf, ke_tf, be_tf, pn_tf, pe_tf = tf_weights

    # Build weights on first pass
    if pt_layer.node_kernel is None:
        # Move parameters to the same device as input tensors
        pt_layer._build(kn_tf.T.shape, bn_tf.T.shape)
        #pt_layer.to(kn_tf.device)

    # 1. & 2. Transfer Node Kernel and Bias
    # Keras Dense kernel is (in_features, out_features)
    pt_layer.node_kernel.data = torch.from_numpy(kn_tf)
    pt_layer.edge_kernel.data = torch.from_numpy(bn_tf)

    # 3. & 4. Transfer Edge Kernel and Bias
    # Same transposition logic applies.
    pt_layer.node_weights.data = torch.from_numpy(ke_tf)
    pt_layer.edge_weights.data = torch.from_numpy(be_tf)

    # 5. Transfer Node Projector Weights (P_n)
    # These are [in_features, 1], which matches, so no transpose needed.
    pt_layer.node_bias.data = torch.from_numpy(pn_tf)

    # 6. Transfer Edge Projector Weights (P_e)
    # These are [in_features, 1], which matches, so no transpose needed.
    pt_layer.edge_bias.data = torch.from_numpy(pe_tf)
    

def transfer_recurrent_encoder_weights(tf_model, pt_model):
    """
    Transfers weights for the full recurrent encoder, finding layers
    by their default names and types to avoid modifying original code.
    """
    # The final dense layer is consistently the last one in the model's layer list.
    final_dense_tf = tf_model.layers[-1]
    final_dense_pt = pt_model.final_dense
    w, b = final_dense_tf.get_weights()
    final_dense_pt.weight.data = torch.from_numpy(w.T)
    final_dense_pt.bias.data = torch.from_numpy(b)

    if pt_model.use_gnn:
        # Keras automatically names nested models 'model', 'model_1', etc., by order of creation.
        # Node recurrent block is created first.
        node_recurrent_model = tf_model.get_layer("model")
        # Edge recurrent block is created second.
        edge_recurrent_model = tf_model.get_layer("model_1")
        # Find the CensNetConv layer by its class type.
        gnn_layer = next(l for l in tf_model.layers if isinstance(l, CensNetConv))

        transfer_recurrent_block_weights(node_recurrent_model, pt_model.node_recurrent_block)
        transfer_recurrent_block_weights(edge_recurrent_model, pt_model.edge_recurrent_block)
        transfer_censnet_weights(gnn_layer, pt_model.spatial_gnn_block)
    else: # Not using GNN
        # There is only one nested model, which Keras names 'model'.
        recurrent_model = tf_model.get_layer("model")
        transfer_recurrent_block_weights(recurrent_model, pt_model.recurrent_block)

In [25]:
def transfer_gmm_weights(tf_model, pt_model: GaussianMixtureLatentPT):
    """
    Transfers weights from the final TF model to the refactored PT model,
    using the updated attribute names.
    """
    # --- Transfer GMM component parameters ---
    # OLD: pt_model.c_mu
    pt_model.gmm_means.data = torch.from_numpy(tf_model.c_mu.numpy())
    # OLD: pt_model.log_c_sigma
    pt_model.gmm_log_vars.data = torch.from_numpy(tf_model.log_c_sigma.numpy())

    # --- Transfer Encoder layer parameters ---
    tf_mean_weights = tf_model.z_gauss_mean.get_weights()
    # OLD: pt_model.z_gauss_mean
    pt_model.encoder_mean.weight.data = torch.from_numpy(tf_mean_weights[0].T)
    pt_model.encoder_mean.bias.data = torch.from_numpy(tf_mean_weights[1])
    
    tf_var_weights = tf_model.z_gauss_var.get_weights()
    # OLD: pt_model.z_gauss_var
    pt_model.encoder_log_var.weight.data = torch.from_numpy(tf_var_weights[0].T)
    pt_model.encoder_log_var.bias.data = torch.from_numpy(tf_var_weights[1])

In [26]:
# Helper function from the provided example to handle gate order differences
def permute_gru_weights(keras_weights):
    """Permutes GRU weights from Keras (z, r, n) to PyTorch (r, z, n) format."""
    W_ih, W_hh, B = keras_weights
    # Keras gate order: z, r, n (update, reset, new/candidate)
    W_ih_z, W_ih_r, W_ih_n = np.split(W_ih, 3, axis=1)
    W_hh_z, W_hh_r, W_hh_n = np.split(W_hh, 3, axis=1)

    # PyTorch gate order: r, z, n (reset, update, new/candidate)
    W_ih_pt = np.concatenate([W_ih_r, W_ih_z, W_ih_n], axis=1)
    W_hh_pt = np.concatenate([W_hh_r, W_hh_z, W_hh_n], axis=1)

    # Keras has two bias vectors (input-hidden and recurrent), which are concatenated in B
    B_ih, B_hh = B
    B_ih_z, B_ih_r, B_ih_n = np.split(B_ih, 3)
    B_hh_z, B_hh_r, B_hh_n = np.split(B_hh, 3)

    B_ih_pt = np.concatenate([B_ih_r, B_ih_z, B_ih_n])
    B_hh_pt = np.concatenate([B_hh_r, B_hh_z, B_hh_n])

    return W_ih_pt.T, W_hh_pt.T, B_ih_pt, B_hh_pt
    
def transfer_recurrent_decoder_weights(tf_model, pt_model):
    """
    Transfers weights for the full recurrent decoder model.
    """
    # Find layers by type to avoid index issues
    bidi_layers = [l for l in tf_model.layers if isinstance(l, Bidirectional)]
    norm_layers = [l for l in tf_model.layers if isinstance(l, LayerNormalization)]
    conv_layers = [l for l in tf_model.layers if isinstance(l, tf.keras.layers.Conv1D)]
    prob_dec_layer = next(l for l in tf_model.layers if isinstance(l, deepof.model_utils.ProbabilisticDecoder))

    # --- GRU 1 and Norm 1 ---
    W_ih_f1, W_hh_f1, B_ih_f1, B_hh_f1 = permute_gru_weights(bidi_layers[0].forward_layer.get_weights())
    pt_model.gru1.weight_ih_l0.data = torch.from_numpy(W_ih_f1); pt_model.gru1.weight_hh_l0.data = torch.from_numpy(W_hh_f1)
    pt_model.gru1.bias_ih_l0.data = torch.from_numpy(B_ih_f1); pt_model.gru1.bias_hh_l0.data = torch.from_numpy(B_hh_f1)
    W_ih_b1, W_hh_b1, B_ih_b1, B_hh_b1 = permute_gru_weights(bidi_layers[0].backward_layer.get_weights())
    pt_model.gru1.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b1); pt_model.gru1.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b1)
    pt_model.gru1.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b1); pt_model.gru1.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b1)
    pt_model.norm1.weight.data = torch.from_numpy(norm_layers[0].get_weights()[0]); pt_model.norm1.bias.data = torch.from_numpy(norm_layers[0].get_weights()[1])

    # --- GRU 2 and Norm 2 ---
    W_ih_f2, W_hh_f2, B_ih_f2, B_hh_f2 = permute_gru_weights(bidi_layers[1].forward_layer.get_weights())
    pt_model.gru2.weight_ih_l0.data = torch.from_numpy(W_ih_f2); pt_model.gru2.weight_hh_l0.data = torch.from_numpy(W_hh_f2)
    pt_model.gru2.bias_ih_l0.data = torch.from_numpy(B_ih_f2); pt_model.gru2.bias_hh_l0.data = torch.from_numpy(B_hh_f2)
    W_ih_b2, W_hh_b2, B_ih_b2, B_hh_b2 = permute_gru_weights(bidi_layers[1].backward_layer.get_weights())
    pt_model.gru2.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b2); pt_model.gru2.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b2)
    pt_model.gru2.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b2); pt_model.gru2.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b2)
    pt_model.norm2.weight.data = torch.from_numpy(norm_layers[1].get_weights()[0]); pt_model.norm2.bias.data = torch.from_numpy(norm_layers[1].get_weights()[1])

    # --- Conv1D and Norm 3 ---
    # TF Conv1D weights: (kernel_w, kernel_h, in_c, out_c) -> (5, 1, 4*ld, 2*ld)
    # PT Conv1d weights: (out_c, in_c, kernel_w)
    conv_weights_tf = conv_layers[0].get_weights()[0]
    pt_model.conv1d.weight.data = torch.from_numpy(conv_weights_tf).squeeze(1).permute(2, 1, 0)
    pt_model.norm3.weight.data = torch.from_numpy(norm_layers[2].get_weights()[0]); pt_model.norm3.bias.data = torch.from_numpy(norm_layers[2].get_weights()[1])

    # --- Probabilistic Decoder ---
    # TF Dense weights: (in_features, out_features)
    # PT Linear weights: (out_features, in_features)
    prob_dec_weights, prob_dec_bias = prob_dec_layer.time_distributer.get_weights()
    pt_model.prob_decoder.loc_projection.weight.data = torch.from_numpy(prob_dec_weights.T)
    pt_model.prob_decoder.loc_projection.bias.data = torch.from_numpy(prob_dec_bias)

In [27]:
# Imports and Mocks from the previous response are assumed to be present
import unittest
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
import time
import deepof.clustering.models_new
# End of Mocks


def transfer_vade_class_weights(tf_vade_model, tf_decoder_model, pt_vade_model: VaDEPT):
    """
    Transfers weights from a full TensorFlow VaDE model to the self-contained PyTorch VaDEPT class.
    """
    print("Transferring weights for all VaDE components...")
    
    # 1. Get the inner Keras models/layers by name from the complete TF model
    tf_encoder_inner = tf_vade_model.get_layer("recurrent_encoder")
    tf_latent_layer = tf_vade_model.get_layer("gaussian_mixture_latent")
    
    # 2. Use the specialized weight transfer functions, passing the PT sub-modules
    print("  -> Transferring Encoder weights...")
    transfer_recurrent_encoder_weights(tf_encoder_inner, pt_vade_model.encoder)
    print("  -> Transferring GMM Latent weights...")
    transfer_gmm_weights(tf_latent_layer, pt_vade_model.latent_space)
    print("  -> Transferring Decoder weights...")
    transfer_recurrent_decoder_weights(tf_decoder_model, pt_vade_model.decoder)
    
    print("Weight transfer complete.")


class TestVaDETranslation(unittest.TestCase):
    def setUp(self):
        """Set up parameters, models, and data for testing."""
        tf.keras.backend.clear_session()
        tf.keras.backend.set_epsilon(1e-3)

        # --- 1. Define Fundamental Dimensions ---
        self.batch_size = 128
        self.window_length = 25
        self.num_nodes = 11
        # In your example, total features (n=33) / num_nodes (11) = 3
        self.features_per_node = 33
        self.num_edges = 11
        self.features_per_edge = 111 # Assuming 1 feature per edge

        # --- 2. Define Model Parameters ---
        self.latent_dim = 6
        self.n_components = 10
        self.kmeans_loss = 1.0
        self.use_gnn = False

        # --- 3. Create Adjacency Matrix ---
        m = np.zeros((self.num_nodes, self.num_nodes))
        ui = np.triu_indices(self.num_nodes)
        num_possible_edges = len(ui[0])
        c = np.random.choice(num_possible_edges, min(self.num_edges, num_possible_edges), replace=False)
        m[ui[0][c], ui[1][c]] = 1
        m += m.T # Make symmetric
        self.adj_matrix = m

        # --- 4. Create Framework-Specific Shapes for Model Instantiation ---
        
        # TensorFlow expects (batch, time, total_features)
        self.input_shape_tf = (self.batch_size, self.window_length, self.num_nodes * self.features_per_node)
        self.edge_feature_shape_tf = (self.batch_size, self.window_length, self.num_edges * self.features_per_edge)
        
        # PyTorch VaDEPT expects (time, nodes, features_per_node) for a SINGLE sample
        self.input_shape_pt = (self.window_length, self.num_nodes, self.features_per_node)
        self.edge_feature_shape_pt = (self.window_length, self.num_edges, self.features_per_edge)

        # --- 5. Instantiate Models ---
        self.tf_embedding, self.tf_decoder, self.tf_grouper, self.tf_vade = get_vade(
            input_shape=self.input_shape_tf,
            edge_feature_shape=self.edge_feature_shape_tf,
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            use_gnn=self.use_gnn,
            n_components=self.n_components,
            batch_size=self.batch_size,
            kmeans_loss=self.kmeans_loss
        )
        
        self.pt_vade = VaDEPT(
            input_shape=self.input_shape_pt,
            edge_feature_shape=self.edge_feature_shape_pt,
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            n_components=self.n_components,
            use_gnn=self.use_gnn,
            kmeans_loss=self.kmeans_loss
        )
        self.pt_vade.eval()

        # --- 6. Prepare Data Tensors for Each Framework ---
        np.random.seed(42)
        # The "canonical" data is 4D, as expected by the new PyTorch models
        self.x_np_4d = np.random.rand(
            self.batch_size, self.window_length, self.num_nodes, self.features_per_node
        ).astype(np.float32)
        self.a_np_4d = np.random.rand(
            self.batch_size, self.window_length, self.num_edges, self.features_per_edge
        ).astype(np.float32)

        # Create the 3D version for the legacy TensorFlow model by reshaping
        self.x_np_tf = self.x_np_4d.reshape(self.input_shape_tf)
        self.a_np_tf = self.a_np_4d.reshape(self.edge_feature_shape_tf)
        
        # --- 7. Transfer Weights ---
        transfer_vade_class_weights(self.tf_vade, self.tf_decoder, self.pt_vade)

    def test_full_model_and_parts(self):
        """Test the forward pass and helper methods of the VaDEPT class."""
        print("\n--- Testing Self-Contained VaDEPT Class Translation ---")
        
        # --- TensorFlow Execution (with its required 3D input) ---
        tf_start = time.time()
        tf_rec_dist = self.tf_vade([self.x_np_tf, self.a_np_tf], training=False)
        tf_rec_mean = tf_rec_dist.mean().numpy()
        tf_lat_out = self.tf_embedding([self.x_np_tf, self.a_np_tf], training=False).numpy()
        tf_cat_out = self.tf_grouper([self.x_np_tf, self.a_np_tf], training=False).numpy()
        tf_end = time.time()
        
        # --- PyTorch Execution (with its required 4D input) ---
        x_pt = torch.from_numpy(self.x_np_4d)
        a_pt = torch.from_numpy(self.a_np_4d)
        
        pt_start = time.time()
        with torch.no_grad():
            pt_rec_dist, _, _, _ = self.pt_vade(x_pt, a_pt)
            pt_rec_mean = pt_rec_dist.mean.numpy() 
            pt_lat_out = self.pt_vade.embed(x_pt, a_pt).numpy()
            pt_cat_out = self.pt_vade.group(x_pt, a_pt).numpy()
        pt_end = time.time()

        print(f"TensorFlow execution time: {tf_end - tf_start:.6f}s")
        print(f"PyTorch execution time: {pt_end - pt_start:.6f}s")
        
        # --- Assertions ---
        print("\nComparing latent space embeddings (from .embed() vs 'embedding' model)...")
        # Both outputs should be (batch_size, latent_dim), so (128, 6)
        np.testing.assert_allclose(tf_lat_out, pt_lat_out, rtol=1e-5, atol=1e-4)
        print("✅ Latent embeddings match.")

        print("Comparing categorical probabilities (from .group() vs 'grouper' model)...")
        # Both outputs should be (batch_size, n_components), so (128, 10)
        np.testing.assert_allclose(tf_cat_out, pt_cat_out, rtol=1e-5, atol=1e-5)
        print("✅ Categorical probabilities match.")
        
        print("Comparing final reconstruction means (from forward() vs 'vade' model)...")
        # Both outputs should be (batch_size, time_steps, total_features), so (128, 25, 33)
        np.testing.assert_allclose(tf_rec_mean, pt_rec_mean, rtol=1e-5, atol=1e-4)
        print("✅ Reconstructions match.")

        print("\n✅ Self-contained VaDEPT class translation test PASSED!")

# To run the test
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.TestLoader().loadTestsFromTestCase(TestVaDETranslation)
runner.run(suite)

test_full_model_and_parts (__main__.TestVaDETranslation)
Test the forward pass and helper methods of the VaDEPT class. ... 

Transferring weights for all VaDE components...
  -> Transferring Encoder weights...
  -> Transferring GMM Latent weights...
  -> Transferring Decoder weights...
Weight transfer complete.

--- Testing Self-Contained VaDEPT Class Translation ---


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
ERROR

ERROR: test_full_model_and_parts (__main__.TestVaDETranslation)
Test the forward pass and helper methods of the VaDEPT class.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\Petron\AppData\Local\Temp\ipykernel_16880\2925459415.py", line 130, in test_full_model_and_parts
    pt_rec_dist, _, _, _ = self.pt_vade(x_pt, a_pt)
  File "c:\Users\Petron\Desktop\Python_Projects\Deepof\dof\lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "c:\Users\Petron\Desktop\Python_Projects\Deepof\dof\lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\Petron\AppData\Local\Temp\ipykernel_16880\3492692028.py", line 98, in forward
    latent, categorical, _, _, kmeans

<unittest.runner.TextTestResult run=1 errors=1 failures=0>

# Full Model

In [28]:
import unittest
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense
import tensorflow_probability as tfp
from tensorflow.keras.optimizers import Nadam
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from typing import List, Tuple, Dict, Callable
import time
import deepof.model_utils
from spektral.layers import CensNetConv
from deepof.model_utils import ClusterControl, compute_kmeans_loss, ProbabilisticDecoder
import deepof.models
from deepof.models import get_recurrent_encoder, get_recurrent_decoder, GaussianMixtureLatent, get_TCN_encoder, get_TCN_decoder, get_transformer_encoder, get_transformer_decoder
from deepof.clustering.models_new import RecurrentEncoderPT, RecurrentDecoderPT, GaussianMixtureLatentPT
from tensorflow.keras.layers import (
    GRU,
    Bidirectional,
    Dense,
    LayerNormalization,
    RepeatVector,
    TimeDistributed,
)

from deepof.data_loading import get_dt

tfd = tfp.distributions

In [29]:
class VaDE(tf.keras.models.Model):
    """Gaussian Mixture Variational Autoencoder for pose motif elucidation."""

    def __init__(
        self,
        input_shape: tuple,
        edge_feature_shape: tuple,
        adjacency_matrix: np.ndarray = None,
        latent_dim: int = 8,
        use_gnn: bool = True,
        n_components: int = 15,
        batch_size: int = 64,
        kl_annealing_mode: str = "linear",
        kl_warmup_epochs: int = 15,
        montecarlo_kl: int = 100,
        kmeans_loss: float = 1.0,
        reg_cat_clusters: float = 1.0,
        reg_cluster_variance: bool = False,
        encoder_type: str = "recurrent",
        interaction_regularization: float = 0.0,
        **kwargs,
    ):
        """Init a VaDE model.

        Args:
            input_shape (tuple): Shape of the input to the full model.
            edge_feature_shape (tuple): shape of the edge feature matrix used for graph representations.
            adjacency_matrix (np.ndarray): adjacency matrix of the connectivity graph to use.
            batch_size (int): Batch size for training.
            latent_dim (int): Dimensionality of the latent space.
            use_gnn (bool): If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
            kl_annealing_mode (str): Annealing mode for KL annealing. Can be one of 'linear' and 'sigmoid'.
            kl_warmup_epochs (int): Number of epochs to warmup KL annealing.
            montecarlo_kl (int): Number of Monte Carlo samples for KL divergence.
            n_components (int): Number of mixture components in the latent space.
            kmeans_loss (float): weight of the gram matrix regularization loss.
            reg_cat_clusters (bool): whether to use the penalized uneven cluster membership in the latent space, by minimizing the KL divergence between cluster membership and a uniform categorical distribution.
            reg_cluster_variance (bool): whether to penalize uneven cluster variances in the latent space.
            encoder_type (str): type of encoder to use. Can be set to "recurrent" (default), "TCN", or "transformer".
            interaction_regularization (float): Regularization parameter for the interaction features.
            **kwargs: Additional keyword arguments.

        """
        super(VaDE, self).__init__(**kwargs)
        self.seq_shape = input_shape
        self.edge_feature_shape = edge_feature_shape
        self.adjacency_matrix = adjacency_matrix
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.use_gnn = use_gnn
        self.kl_annealing_mode = kl_annealing_mode
        self.kl_warmup = kl_warmup_epochs
        self.mc_kl = montecarlo_kl
        self.n_components = n_components
        self.optimizer = Nadam(learning_rate=1e-3, clipvalue=0.75)
        self.kmeans = kmeans_loss
        self.reg_cat_clusters = reg_cat_clusters
        self.reg_cluster_variance = reg_cluster_variance
        self.encoder_type = encoder_type
        self.interaction_regularization = interaction_regularization

        # Define VaDE model
        self.encoder, self.decoder, self.grouper, self.vade = deepof.models.get_vade(
            input_shape=self.seq_shape,
            edge_feature_shape=self.edge_feature_shape,
            adjacency_matrix=self.adjacency_matrix,
            n_components=self.n_components,
            latent_dim=self.latent_dim,
            use_gnn=use_gnn,
            batch_size=self.batch_size,
            kl_warmup=self.kl_warmup,
            kl_annealing_mode=self.kl_annealing_mode,
            mc_kl=self.mc_kl,
            kmeans_loss=self.kmeans,
            reg_cluster_variance=self.reg_cluster_variance,
            encoder_type=self.encoder_type,
            interaction_regularization=self.interaction_regularization,
        )

        # Propagate the optimizer to all relevant sub-models, to enable metric annealing
        self.vade.optimizer = self.optimizer
        self.vade.get_layer("gaussian_mixture_latent").optimizer = self.optimizer

        # Define metrics to track

        # Track all loss function components
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.val_total_loss_tracker = tf.keras.metrics.Mean(name="val_total_loss")

        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.val_reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="val_reconstruction_loss"
        )

        if self.reg_cat_clusters:
            self.cat_cluster_loss_tracker = tf.keras.metrics.Mean(
                name="cat_cluster_loss"
            )
            self.val_cat_cluster_loss_tracker = tf.keras.metrics.Mean(
                name="val_cat_cluster_loss"
            )

    @property
    def metrics(self):  # pragma: no cover
        """Initializes tracked metrics of VaDE model."""
        metrics = [
            self.total_loss_tracker,
            self.val_total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.val_reconstruction_loss_tracker,
        ]

        if self.reg_cat_clusters:
            metrics += [
                self.cat_cluster_loss_tracker,
                self.val_cat_cluster_loss_tracker,
            ]

        return metrics

    @property
    def get_gmm_params(self):
        """Return the GMM parameters of the model."""
        # Get GMM parameters
        return {
            "means": self.grouper.get_layer("gaussian_mixture_latent").c_mu,
            "sigmas": tf.math.exp(
                self.grouper.get_layer("gaussian_mixture_latent").log_c_sigma
            ),
            "weights": tf.math.softmax(
                self.grouper.get_layer("gaussian_mixture_latent").prior
            ),
        }

    def set_pretrain_mode(self, switch):
        """Set the pretrain mode of the model."""
        self.grouper.get_layer("gaussian_mixture_latent").pretrain.assign(switch)

    def pretrain(
        self,
        data,
        embed_x,
        embed_a,
        epochs=10,
        samples=10000,
        gmm_initialize=True,
        **kwargs,
    ):
        """Run a GMM directed pretraining of the encoder, to minimize the likelihood of getting stuck in a local minimum."""
        # Turn on pretrain mode
        self.set_pretrain_mode(1.0)

        # pre-train
        self.fit(
            data,
            epochs=epochs,
            **kwargs,
        )


        # Turn off pretrain mode
        self.set_pretrain_mode(0.0)

        if gmm_initialize:

            with tf.device("CPU"):
                # Get embedding samples
                em_x=get_dt(embed_x, 'embed_x')
                em_a=get_dt(embed_a, 'embed_a')

                emb_idx = np.random.choice(range(em_x.shape[0]), samples)

                # map to latent
                z = self.encoder([em_x[emb_idx], em_a[emb_idx]])
                
                del em_x
                del em_a
                del emb_idx

                # fit GMM
                gmm = deepof.models.GaussianMixture(
                    n_components=self.n_components,
                    covariance_type="diag",
                    reg_covar=1e-04,
                    **kwargs,
                ).fit(z)
                # get GMM parameters
                mu = gmm.means_
                sigma2 = gmm.covariances_

            # initialize mixture components
            self.grouper.get_layer("gaussian_mixture_latent").c_mu.assign(
                tf.convert_to_tensor(value=mu, dtype=tf.float32)
            )
            self.grouper.get_layer("gaussian_mixture_latent").log_c_sigma.assign(
                tf.math.log(
                    tf.math.sqrt(tf.convert_to_tensor(value=sigma2, dtype=tf.float32))
                )
            )

    @tf.function
    def call(self, inputs, **kwargs):
        """Call the VaDE model."""
        return self.vade(inputs, **kwargs)

    def train_step(self, data):  # pragma: no cover
        """Perform a training step."""
        # Unpack data, repacking labels into a generator
        x, a, y = data
        if not isinstance(y, tuple):
            y = [y]
        y = (labels for labels in y)

        with tf.GradientTape() as tape:

            # Get outputs from the full model
            outputs = self.vade([x, a], training=True)

            # Get rid of the attention scores that the transformer decoder outputs
            if self.encoder_type == "transformer":
                outputs = outputs[0]

            if isinstance(outputs, list):
                reconstructions = outputs[0]
            else:
                reconstructions = outputs

            # Regularize embeddings
            # groups = self.grouper(x, training=True)

            # Compute losses
            seq_inputs = next(y)
            total_loss = sum(self.vade.losses)

            # Add a regularization term to the soft_counts, to prevent the embedding layer from
            # collapsing into a few clusters.
            if self.reg_cat_clusters:

                soft_counts = self.grouper([x, a], training=True)
                soft_counts_regulrization = (
                    self.reg_cat_clusters
                    * deepof.model_utils.cluster_frequencies_regularizer(
                        soft_counts=soft_counts, k=self.n_components
                    )
                )
                total_loss += soft_counts_regulrization

            # Compute reconstruction loss
            reconstruction_loss = -tf.reduce_mean(reconstructions.log_prob(seq_inputs))
            total_loss += reconstruction_loss

        # Backpropagation
        grads = tape.gradient(total_loss, self.vade.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vade.trainable_variables))

        # Track losses
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)

        # Log results (coupled with TensorBoard)
        log_dict = {
            "total_loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
        }

        if self.reg_cat_clusters:
            self.cat_cluster_loss_tracker.update_state(soft_counts_regulrization)
            log_dict["cat_cluster_loss"] = self.cat_cluster_loss_tracker.result()

        # Log to TensorBoard, both explicitly and implicitly (within model) tracked metrics
        return {**log_dict, **{met.name: met.result() for met in self.vade.metrics}}

    # noinspection PyUnboundLocalVariable
    @tf.function
    def test_step(self, data):  # pragma: no cover
        """Performs a test step."""
        # Unpack data, repacking labels into a generator
        x, a, y = data
        if not isinstance(y, tuple):
            y = [y]
        y = (labels for labels in y)

        # Get outputs from the full model
        outputs = self.vade([x, a], training=False)

        # Get rid of the attention scores that the transformer decoder outputs
        if self.encoder_type == "transformer":
            outputs = outputs[0]

        if isinstance(outputs, list):
            reconstructions = outputs[0]
        else:
            reconstructions = outputs

        # Compute losses
        seq_inputs = next(y)
        total_loss = sum(self.vade.losses)

        # Add a regularization term to the soft_counts, to prevent the embedding layer from
        # collapsing into a few clusters.
        if self.reg_cat_clusters:
            soft_counts = self.grouper([x, a], training=False)
            soft_counts_regulrization = (
                self.reg_cat_clusters
                * deepof.model_utils.cluster_frequencies_regularizer(
                    soft_counts=soft_counts, k=self.n_components
                )
            )
            total_loss += soft_counts_regulrization

        # Compute reconstruction loss
        reconstruction_loss = -tf.reduce_mean(reconstructions.log_prob(seq_inputs))
        total_loss += reconstruction_loss

        # Track losses
        self.val_total_loss_tracker.update_state(total_loss)
        self.val_reconstruction_loss_tracker.update_state(reconstruction_loss)

        # Log results (coupled with TensorBoard)
        log_dict = {
            "total_loss": self.val_total_loss_tracker.result(),
            "reconstruction_loss": self.val_reconstruction_loss_tracker.result(),
        }

        if self.reg_cat_clusters:
            self.val_cat_cluster_loss_tracker.update_state(soft_counts_regulrization)
            log_dict["cat_cluster_loss"] = self.val_cat_cluster_loss_tracker.result()

        return {**log_dict, **{met.name: met.result() for met in self.vade.metrics}}


In [30]:
class VaDEPT(nn.Module):
    """
    A self-contained PyTorch implementation of the VaDE model.

    This class encapsulates the entire VaDE architecture, including the encoder,
    the Gaussian mixture latent space, and the decoder. It is instantiated with
    all necessary configuration parameters, building its sub-modules internally.
    This provides a clean, single-object interface for the model.
    """
    def __init__(
        self,
        input_shape: tuple,
        edge_feature_shape: tuple,
        adjacency_matrix: np.ndarray,
        latent_dim: int,
        n_components: int,
        use_gnn: bool = True,
        kmeans_loss: float = 1.0,
        interaction_regularization: float = 0.0,
    ):
        """
        Initializes and builds the VaDE model and its components.

        Args:
            input_shape (tuple): Shape of the input node features (Time, Nodes, Features_per_node).
            edge_feature_shape (tuple): Shape of the edge features (Time, Edges, Features_per_edge).
            adjacency_matrix (np.ndarray): Adjacency matrix of the connectivity graph.
            latent_dim (int): Dimensionality of the latent space.
            n_components (int): Number of components in the Gaussian mixture.
            use_gnn (bool): If True, use the GNN-based encoder.
            kmeans_loss (float): Weight of the k-means style loss in the latent space.
            interaction_regularization (float): Regularization for GNN interaction features.
        """
        super().__init__()
        
        # Store key dimensions for internal use (e.g., reshaping in forward pass)
        time_steps, n_nodes, n_features_per_node = input_shape
        self.input_n_nodes = n_nodes
        self.input_n_features_per_node = n_features_per_node

        # 1. Instantiate Encoder
        self.encoder = RecurrentEncoderPT(
            input_shape=input_shape,
            edge_feature_shape=edge_feature_shape,
            adjacency_matrix=adjacency_matrix,
            latent_dim=latent_dim,
            use_gnn=use_gnn,
            interaction_regularization=interaction_regularization,
        )

        # 2. Instantiate Latent Space
        self.latent_space = GaussianMixtureLatentPT(
            input_dim=latent_dim,
            n_components=n_components,
            latent_dim=latent_dim,
            kmeans=kmeans_loss,
        )

        # 3. Instantiate Decoder
        decoder_output_features = n_nodes * n_features_per_node
        self.decoder = RecurrentDecoderPT(
            output_shape=(time_steps, decoder_output_features),
            latent_dim=latent_dim,
        )

    def forward(
        self, x: torch.Tensor, a: torch.Tensor
    ) -> Tuple[torch.distributions.Distribution, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Defines the full forward pass for the VaDE model (training and evaluation).

        Args:
            x (torch.Tensor): Input node features tensor (B, T, N, F_node).
            a (torch.Tensor): Input edge features tensor (B, T, E, F_edge).

        Returns:
            A tuple containing:
            - reconstruction_dist (torch.distributions.Distribution): The output distribution from the decoder.
            - latent (torch.Tensor): The sampled latent representation from the GMM space.
            - categorical (torch.Tensor): The cluster probabilities (soft assignments).
            - kmeans_loss (torch.Tensor): The k-means regularization loss from the latent space.
        """
        # 1. Encode the input to get the pre-latent representation
        encoder_output = self.encoder(x, a)
        
        # 2. Pass through GMM latent space
        latent, categorical, _, _, kmeans_loss, _ = self.latent_space(encoder_output)
        
        # 3. Decode the latent sample back to the original data space
        # Reshape x to (B, T, N*F) for the decoder's masking logic
        B, T, _, _ = x.shape
        x_for_decoder = x.view(B, T, self.input_n_nodes * self.input_n_features_per_node)
        
        reconstruction_dist = self.decoder(latent, x_for_decoder)
        
        return reconstruction_dist, latent, categorical, kmeans_loss
    

    def get_gmm_params(self) -> dict:
        """Returns the GMM parameters from the latent space."""
        # This is the PyTorch equivalent of the TF property
        with torch.no_grad():
            means = self.latent_space.gmm_means
            # The latent space stores log-variances, convert to std-dev
            stds = torch.exp(0.5 * self.latent_space.gmm_log_vars)
            # Prior is already softmaxed if needed, or just probabilities
            weights = self.latent_space.prior
        return {"means": means, "stds": stds, "weights": weights}


    def set_pretrain_mode(self, pretrain_on: bool):
        """Sets the pretrain flag in the latent space."""
        # In TF it was a float (0.0/1.0), here a boolean is cleaner
        self.latent_space.pretrain.fill_(1.0 if pretrain_on else 0.0)


    def initialize_gmm_from_data(self, data_loader, n_samples=10000):
        """
        Runs the autoencoder part of the model over the data to get embeddings,
        then fits a scikit-learn GMM to initialize the latent space.
        """
        print("Initializing GMM from data embeddings...")
        self.eval() # Set model to evaluation mode
        
        # 1. Gather embeddings from the autoencoder
        all_embeddings = []
        samples_gathered = 0
        with torch.no_grad():
            for x, a in data_loader:
                # Assuming x,a are on the correct device
                embeddings = self.encoder(x, a)
                all_embeddings.append(embeddings.cpu())
                samples_gathered += embeddings.size(0)
                if samples_gathered >= n_samples:
                    break
        
        all_embeddings = torch.cat(all_embeddings, dim=0).numpy()
        if all_embeddings.shape[0] > n_samples:
            all_embeddings = all_embeddings[:n_samples]

        # 2. Fit a scikit-learn GMM
        from sklearn.mixture import GaussianMixture
        print(f"Fitting scikit-learn GMM on {all_embeddings.shape[0]} samples...")
        gmm = GaussianMixture(
            n_components=self.latent_space.n_components,
            covariance_type="diag",
            reg_covar=1e-04,
        ).fit(all_embeddings)

        # 3. Assign the learned parameters to the model's latent space
        print("Assigning learned GMM parameters to the model.")
        self.latent_space.gmm_means.data = torch.from_numpy(gmm.means_).float()
        # Convert covariance (variance) to log-variance for the model
        self.latent_space.gmm_log_vars.data = torch.from_numpy(np.log(gmm.covariances_)).float()

In [31]:
def transfer_recurrent_block_weights(tf_model, pt_model):
    """Transfers weights for the full recurrent block with GRU gate permutation."""
    conv_td, _, gru1_td, norm1, gru2_td, norm2 = tf_model.layers[1:]


    def permute_gru_weights(keras_weights):
        W_ih, W_hh, B = keras_weights
        W_ih_z, W_ih_r, W_ih_n = np.split(W_ih, 3, axis=1)
        W_hh_z, W_hh_r, W_hh_n = np.split(W_hh, 3, axis=1)
        W_ih_pt = np.concatenate([W_ih_r, W_ih_z, W_ih_n], axis=1)
        W_hh_pt = np.concatenate([W_hh_r, W_hh_z, W_hh_n], axis=1)
        B_ih, B_hh = B
        B_ih_z, B_ih_r, B_ih_n = np.split(B_ih, 3)
        B_hh_z, B_hh_r, B_hh_n = np.split(B_hh, 3)
        B_ih_pt = np.concatenate([B_ih_r, B_ih_z, B_ih_n])
        B_hh_pt = np.concatenate([B_hh_r, B_hh_z, B_hh_n])
        return W_ih_pt.T, W_hh_pt.T, B_ih_pt, B_hh_pt

    pt_model.conv1d.weight.data = torch.from_numpy(conv_td.layer.get_weights()[0]).permute(2, 1, 0)
    
    W_ih_f1, W_hh_f1, B_ih_f1, B_hh_f1 = permute_gru_weights(gru1_td.layer.forward_layer.get_weights())
    pt_model.gru1.weight_ih_l0.data = torch.from_numpy(W_ih_f1); pt_model.gru1.weight_hh_l0.data = torch.from_numpy(W_hh_f1); pt_model.gru1.bias_ih_l0.data = torch.from_numpy(B_ih_f1); pt_model.gru1.bias_hh_l0.data = torch.from_numpy(B_hh_f1)
    
    W_ih_b1, W_hh_b1, B_ih_b1, B_hh_b1 = permute_gru_weights(gru1_td.layer.backward_layer.get_weights())
    pt_model.gru1.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b1); pt_model.gru1.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b1); pt_model.gru1.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b1); pt_model.gru1.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b1)

    pt_model.norm1.weight.data = torch.from_numpy(norm1.get_weights()[0]); pt_model.norm1.bias.data = torch.from_numpy(norm1.get_weights()[1])

    W_ih_f2, W_hh_f2, B_ih_f2, B_hh_f2 = permute_gru_weights(gru2_td.layer.forward_layer.get_weights())
    pt_model.gru2.weight_ih_l0.data = torch.from_numpy(W_ih_f2); pt_model.gru2.weight_hh_l0.data = torch.from_numpy(W_hh_f2); pt_model.gru2.bias_ih_l0.data = torch.from_numpy(B_ih_f2); pt_model.gru2.bias_hh_l0.data = torch.from_numpy(B_hh_f2)
    
    W_ih_b2, W_hh_b2, B_ih_b2, B_hh_b2 = permute_gru_weights(gru2_td.layer.backward_layer.get_weights())
    pt_model.gru2.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b2); pt_model.gru2.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b2); pt_model.gru2.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b2); pt_model.gru2.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b2)
    
    pt_model.norm2.weight.data = torch.from_numpy(norm2.get_weights()[0]); pt_model.norm2.bias.data = torch.from_numpy(norm2.get_weights()[1])

    
def transfer_censnet_weights(tf_layer, pt_layer):
    """
    Transfers all six weights from a Spektral CensNetConv layer to the
    corresponding CensNetConvPT layer.
    """
    # Get all weights from the TensorFlow layer. The order is determined by
    # the layer's build order in Spektral's source code.
    tf_weights = tf_layer.get_weights()

    # Unpack all six weights.
    # Order: kernel_node, bias_node, kernel_edge, bias_edge, projector_node, projector_edge
    kn_tf, bn_tf, ke_tf, be_tf, pn_tf, pe_tf = tf_weights

    # Build weights on first pass
    if pt_layer.node_kernel is None:
        # Move parameters to the same device as input tensors
        pt_layer._build(kn_tf.T.shape, bn_tf.T.shape)
        #pt_layer.to(kn_tf.device)

    # 1. & 2. Transfer Node Kernel and Bias
    # Keras Dense kernel is (in_features, out_features)
    pt_layer.node_kernel.data = torch.from_numpy(kn_tf)
    pt_layer.edge_kernel.data = torch.from_numpy(bn_tf)

    # 3. & 4. Transfer Edge Kernel and Bias
    # Same transposition logic applies.
    pt_layer.node_weights.data = torch.from_numpy(ke_tf)
    pt_layer.edge_weights.data = torch.from_numpy(be_tf)

    # 5. Transfer Node Projector Weights (P_n)
    # These are [in_features, 1], which matches, so no transpose needed.
    pt_layer.node_bias.data = torch.from_numpy(pn_tf)

    # 6. Transfer Edge Projector Weights (P_e)
    # These are [in_features, 1], which matches, so no transpose needed.
    pt_layer.edge_bias.data = torch.from_numpy(pe_tf)
    

def transfer_recurrent_encoder_weights(tf_model, pt_model):
    """
    Transfers weights for the full recurrent encoder, finding layers
    by their default names and types to avoid modifying original code.
    """
    # The final dense layer is consistently the last one in the model's layer list.
    final_dense_tf = tf_model.layers[-1]
    final_dense_pt = pt_model.final_dense
    w, b = final_dense_tf.get_weights()
    final_dense_pt.weight.data = torch.from_numpy(w.T)
    final_dense_pt.bias.data = torch.from_numpy(b)

    if pt_model.use_gnn:
        # Keras automatically names nested models 'model', 'model_1', etc., by order of creation.
        # Node recurrent block is created first.
        node_recurrent_model = tf_model.get_layer("model")
        # Edge recurrent block is created second.
        edge_recurrent_model = tf_model.get_layer("model_1")
        # Find the CensNetConv layer by its class type.
        gnn_layer = next(l for l in tf_model.layers if isinstance(l, CensNetConv))

        transfer_recurrent_block_weights(node_recurrent_model, pt_model.node_recurrent_block)
        transfer_recurrent_block_weights(edge_recurrent_model, pt_model.edge_recurrent_block)
        transfer_censnet_weights(gnn_layer, pt_model.spatial_gnn_block)
    else: # Not using GNN
        # There is only one nested model, which Keras names 'model'.
        recurrent_model = tf_model.get_layer("model")
        transfer_recurrent_block_weights(recurrent_model, pt_model.recurrent_block)

In [32]:
def transfer_gmm_weights(tf_model, pt_model: GaussianMixtureLatentPT):
    """
    Transfers weights from the final TF model to the refactored PT model,
    using the updated attribute names.
    """
    # --- Transfer GMM component parameters ---
    # OLD: pt_model.c_mu
    pt_model.gmm_means.data = torch.from_numpy(tf_model.c_mu.numpy())
    # OLD: pt_model.log_c_sigma
    pt_model.gmm_log_vars.data = torch.from_numpy(tf_model.log_c_sigma.numpy())

    # --- Transfer Encoder layer parameters ---
    tf_mean_weights = tf_model.z_gauss_mean.get_weights()
    # OLD: pt_model.z_gauss_mean
    pt_model.encoder_mean.weight.data = torch.from_numpy(tf_mean_weights[0].T)
    pt_model.encoder_mean.bias.data = torch.from_numpy(tf_mean_weights[1])
    
    tf_var_weights = tf_model.z_gauss_var.get_weights()
    # OLD: pt_model.z_gauss_var
    pt_model.encoder_log_var.weight.data = torch.from_numpy(tf_var_weights[0].T)
    pt_model.encoder_log_var.bias.data = torch.from_numpy(tf_var_weights[1])

In [33]:
# Helper function from the provided example to handle gate order differences
def permute_gru_weights(keras_weights):
    """Permutes GRU weights from Keras (z, r, n) to PyTorch (r, z, n) format."""
    W_ih, W_hh, B = keras_weights
    # Keras gate order: z, r, n (update, reset, new/candidate)
    W_ih_z, W_ih_r, W_ih_n = np.split(W_ih, 3, axis=1)
    W_hh_z, W_hh_r, W_hh_n = np.split(W_hh, 3, axis=1)

    # PyTorch gate order: r, z, n (reset, update, new/candidate)
    W_ih_pt = np.concatenate([W_ih_r, W_ih_z, W_ih_n], axis=1)
    W_hh_pt = np.concatenate([W_hh_r, W_hh_z, W_hh_n], axis=1)

    # Keras has two bias vectors (input-hidden and recurrent), which are concatenated in B
    B_ih, B_hh = B
    B_ih_z, B_ih_r, B_ih_n = np.split(B_ih, 3)
    B_hh_z, B_hh_r, B_hh_n = np.split(B_hh, 3)

    B_ih_pt = np.concatenate([B_ih_r, B_ih_z, B_ih_n])
    B_hh_pt = np.concatenate([B_hh_r, B_hh_z, B_hh_n])

    return W_ih_pt.T, W_hh_pt.T, B_ih_pt, B_hh_pt
    
def transfer_recurrent_decoder_weights(tf_model, pt_model):
    """
    Transfers weights for the full recurrent decoder model.
    """
    # Find layers by type to avoid index issues
    bidi_layers = [l for l in tf_model.layers if isinstance(l, Bidirectional)]
    norm_layers = [l for l in tf_model.layers if isinstance(l, LayerNormalization)]
    conv_layers = [l for l in tf_model.layers if isinstance(l, tf.keras.layers.Conv1D)]
    prob_dec_layer = next(l for l in tf_model.layers if isinstance(l, deepof.model_utils.ProbabilisticDecoder))

    # --- GRU 1 and Norm 1 ---
    W_ih_f1, W_hh_f1, B_ih_f1, B_hh_f1 = permute_gru_weights(bidi_layers[0].forward_layer.get_weights())
    pt_model.gru1.weight_ih_l0.data = torch.from_numpy(W_ih_f1); pt_model.gru1.weight_hh_l0.data = torch.from_numpy(W_hh_f1)
    pt_model.gru1.bias_ih_l0.data = torch.from_numpy(B_ih_f1); pt_model.gru1.bias_hh_l0.data = torch.from_numpy(B_hh_f1)
    W_ih_b1, W_hh_b1, B_ih_b1, B_hh_b1 = permute_gru_weights(bidi_layers[0].backward_layer.get_weights())
    pt_model.gru1.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b1); pt_model.gru1.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b1)
    pt_model.gru1.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b1); pt_model.gru1.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b1)
    pt_model.norm1.weight.data = torch.from_numpy(norm_layers[0].get_weights()[0]); pt_model.norm1.bias.data = torch.from_numpy(norm_layers[0].get_weights()[1])

    # --- GRU 2 and Norm 2 ---
    W_ih_f2, W_hh_f2, B_ih_f2, B_hh_f2 = permute_gru_weights(bidi_layers[1].forward_layer.get_weights())
    pt_model.gru2.weight_ih_l0.data = torch.from_numpy(W_ih_f2); pt_model.gru2.weight_hh_l0.data = torch.from_numpy(W_hh_f2)
    pt_model.gru2.bias_ih_l0.data = torch.from_numpy(B_ih_f2); pt_model.gru2.bias_hh_l0.data = torch.from_numpy(B_hh_f2)
    W_ih_b2, W_hh_b2, B_ih_b2, B_hh_b2 = permute_gru_weights(bidi_layers[1].backward_layer.get_weights())
    pt_model.gru2.weight_ih_l0_reverse.data = torch.from_numpy(W_ih_b2); pt_model.gru2.weight_hh_l0_reverse.data = torch.from_numpy(W_hh_b2)
    pt_model.gru2.bias_ih_l0_reverse.data = torch.from_numpy(B_ih_b2); pt_model.gru2.bias_hh_l0_reverse.data = torch.from_numpy(B_hh_b2)
    pt_model.norm2.weight.data = torch.from_numpy(norm_layers[1].get_weights()[0]); pt_model.norm2.bias.data = torch.from_numpy(norm_layers[1].get_weights()[1])

    # --- Conv1D and Norm 3 ---
    # TF Conv1D weights: (kernel_w, kernel_h, in_c, out_c) -> (5, 1, 4*ld, 2*ld)
    # PT Conv1d weights: (out_c, in_c, kernel_w)
    conv_weights_tf = conv_layers[0].get_weights()[0]
    pt_model.conv1d.weight.data = torch.from_numpy(conv_weights_tf).squeeze(1).permute(2, 1, 0)
    pt_model.norm3.weight.data = torch.from_numpy(norm_layers[2].get_weights()[0]); pt_model.norm3.bias.data = torch.from_numpy(norm_layers[2].get_weights()[1])

    # --- Probabilistic Decoder ---
    # TF Dense weights: (in_features, out_features)
    # PT Linear weights: (out_features, in_features)
    prob_dec_weights, prob_dec_bias = prob_dec_layer.time_distributer.get_weights()
    pt_model.prob_decoder.loc_projection.weight.data = torch.from_numpy(prob_dec_weights.T)
    pt_model.prob_decoder.loc_projection.bias.data = torch.from_numpy(prob_dec_bias)

In [34]:
def transfer_vade_class_weights(tf_vade_model, tf_decoder_model, pt_vade_model: VaDEPT):
    """
    Transfers weights from a full TensorFlow VaDE model to the self-contained PyTorch VaDEPT class.
    """
    print("Transferring weights for all VaDE components...")
    
    # 1. Get the inner Keras models/layers by name from the complete TF model
    tf_encoder_inner = tf_vade_model.get_layer("recurrent_encoder")
    tf_latent_layer = tf_vade_model.get_layer("gaussian_mixture_latent")
    
    # 2. Use the specialized weight transfer functions, passing the PT sub-modules
    print("  -> Transferring Encoder weights...")
    transfer_recurrent_encoder_weights(tf_encoder_inner, pt_vade_model.encoder)
    print("  -> Transferring GMM Latent weights...")
    transfer_gmm_weights(tf_latent_layer, pt_vade_model.latent_space)
    print("  -> Transferring Decoder weights...")
    transfer_recurrent_decoder_weights(tf_decoder_model, pt_vade_model.decoder)
    
    print("Weight transfer complete.")


class TestVaDETranslation(unittest.TestCase):
    def setUp(self):
        """Set up parameters, models, and data for testing."""
        tf.keras.backend.clear_session()
        tf.keras.backend.set_epsilon(1e-3)

        # --- 1. Define Fundamental Dimensions ---
        self.batch_size = 128
        self.window_length = 25
        self.num_nodes = 11
        # In your example, total features (n=33) / num_nodes (11) = 3
        self.features_per_node = 3
        self.num_edges = 11
        self.features_per_edge = 1 # Assuming 1 feature per edge

        # --- 2. Define Model Parameters ---
        self.latent_dim = 6
        self.n_components = 10
        self.kmeans_loss = 1.0
        self.use_gnn = False

        # --- 3. Create Adjacency Matrix ---
        m = np.zeros((self.num_nodes, self.num_nodes))
        ui = np.triu_indices(self.num_nodes)
        num_possible_edges = len(ui[0])
        c = np.random.choice(num_possible_edges, min(self.num_edges, num_possible_edges), replace=False)
        m[ui[0][c], ui[1][c]] = 1
        m += m.T # Make symmetric
        self.adj_matrix = m

        # --- 4. Create Framework-Specific Shapes for Model Instantiation ---
        
        # TensorFlow expects (batch, time, total_features)
        self.input_shape_tf = (self.batch_size, self.window_length, self.num_nodes * self.features_per_node)
        self.edge_feature_shape_tf = (self.batch_size, self.window_length, self.num_edges * self.features_per_edge)
        
        # PyTorch VaDEPT expects (time, nodes, features_per_node) for a SINGLE sample
        self.input_shape_pt = (self.window_length, self.num_nodes, self.features_per_node)
        self.edge_feature_shape_pt = (self.window_length, self.num_edges, self.features_per_edge)

        # --- 5. Instantiate Models ---
        tf_model = VaDE(
            input_shape=self.input_shape_tf,
            edge_feature_shape=self.edge_feature_shape_tf,
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            use_gnn=self.use_gnn,
            n_components=self.n_components,
            batch_size=self.batch_size,
            kmeans_loss=self.kmeans_loss
        )
        self.tf_decoder = tf_model.decoder
        self.tf_vade = tf_model.vade
        self.tf_embedding = tf_model.encoder
        self.tf_grouper = tf_model.grouper
        
        self.pt_vade = VaDEPT(
            input_shape=self.input_shape_pt,
            edge_feature_shape=self.edge_feature_shape_pt,
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            n_components=self.n_components,
            use_gnn=self.use_gnn,
            kmeans_loss=self.kmeans_loss
        )
        self.pt_vade.eval()

        # --- 6. Prepare Data Tensors for Each Framework ---
        np.random.seed(42)
        # The "canonical" data is 4D, as expected by the new PyTorch models
        self.x_np_4d = np.random.rand(
            self.batch_size, self.window_length, self.num_nodes, self.features_per_node
        ).astype(np.float32)
        self.a_np_4d = np.random.rand(
            self.batch_size, self.window_length, self.num_edges, self.features_per_edge
        ).astype(np.float32)

        # Create the 3D version for the legacy TensorFlow model by reshaping
        self.x_np_tf = self.x_np_4d.reshape(self.input_shape_tf)
        self.a_np_tf = self.a_np_4d.reshape(self.edge_feature_shape_tf)
        
        # --- 7. Transfer Weights ---
        transfer_vade_class_weights(self.tf_vade, self.tf_decoder, self.pt_vade)

    def test_full_model_and_parts(self):
        """Test the forward pass and helper methods of the VaDEPT class."""
        print("\n--- Testing Self-Contained VaDEPT Class Translation ---")
        
        # --- TensorFlow Execution (with its required 3D input) ---
        tf_start = time.time()
        tf_rec_dist = self.tf_vade([self.x_np_tf, self.a_np_tf], training=False)
        tf_rec_mean = tf_rec_dist.mean().numpy()
        tf_lat_out = self.tf_embedding([self.x_np_tf, self.a_np_tf], training=False).numpy()
        tf_cat_out = self.tf_grouper([self.x_np_tf, self.a_np_tf], training=False).numpy()
        tf_end = time.time()
        
        # --- PyTorch Execution (with its required 4D input) ---
        x_pt = torch.from_numpy(self.x_np_4d)
        a_pt = torch.from_numpy(self.a_np_4d)
        
        pt_start = time.time()
        with torch.no_grad():
            pt_rec_dist, _, _, _ = self.pt_vade(x_pt, a_pt)
            pt_rec_mean = pt_rec_dist.mean.numpy() 
            pt_lat_out = self.pt_vade.embed(x_pt, a_pt).numpy()
            pt_cat_out = self.pt_vade.group(x_pt, a_pt).numpy()
        pt_end = time.time()

        print(f"TensorFlow execution time: {tf_end - tf_start:.6f}s")
        print(f"PyTorch execution time: {pt_end - pt_start:.6f}s")
        
        # --- Assertions ---
        print("\nComparing latent space embeddings (from .embed() vs 'embedding' model)...")
        # Both outputs should be (batch_size, latent_dim), so (128, 6)
        np.testing.assert_allclose(tf_lat_out, pt_lat_out, rtol=1e-5, atol=1e-4)
        print("✅ Latent embeddings match.")

        print("Comparing categorical probabilities (from .group() vs 'grouper' model)...")
        # Both outputs should be (batch_size, n_components), so (128, 10)
        np.testing.assert_allclose(tf_cat_out, pt_cat_out, rtol=1e-5, atol=1e-5)
        print("✅ Categorical probabilities match.")
        
        print("Comparing final reconstruction means (from forward() vs 'vade' model)...")
        # Both outputs should be (batch_size, time_steps, total_features), so (128, 25, 33)
        np.testing.assert_allclose(tf_rec_mean, pt_rec_mean, rtol=1e-5, atol=1e-4)
        print("✅ Reconstructions match.")

        print("\n✅ Self-contained VaDEPT class translation test PASSED!")

# To run the test
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.TestLoader().loadTestsFromTestCase(TestVaDETranslation)
runner.run(suite)

test_full_model_and_parts (__main__.TestVaDETranslation)
Test the forward pass and helper methods of the VaDEPT class. ... 

Transferring weights for all VaDE components...
  -> Transferring Encoder weights...
  -> Transferring GMM Latent weights...
  -> Transferring Decoder weights...
Weight transfer complete.

--- Testing Self-Contained VaDEPT Class Translation ---


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
ERROR

ERROR: test_full_model_and_parts (__main__.TestVaDETranslation)
Test the forward pass and helper methods of the VaDEPT class.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\Petron\AppData\Local\Temp\ipykernel_16880\1105766262.py", line 124, in test_full_model_and_parts
    pt_rec_dist, _, _, _ = self.pt_vade(x_pt, a_pt)
  File "c:\Users\Petron\Desktop\Python_Projects\Deepof\dof\lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "c:\Users\Petron\Desktop\Python_Projects\Deepof\dof\lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\Petron\AppData\Local\Temp\ipykernel_16880\2086274447.py", line 87, in forward
    latent, categorical, _, _, kmeans

<unittest.runner.TextTestResult run=1 errors=1 failures=0>

# TCN encoder test

In [35]:
from typing import Iterable, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from spektral.layers import CensNetConv
from deepof.clustering.censNetConv_pt import CensNetConvPT
from tensorflow.keras.layers import Input, TimeDistributed, Bidirectional, GRU, LayerNormalization, Masking
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    GRU,
    Bidirectional,
    Dense,
    LayerNormalization,
    RepeatVector,
    TimeDistributed,
)
import numpy as np
import torch
import torch.nn as nn
import tensorflow as tf
import tcn


In [36]:
def get_TCN_encoder(
    input_shape: tuple,
    edge_feature_shape: tuple,
    adjacency_matrix: np.ndarray,
    latent_dim: int,
    use_gnn: bool = True,
    conv_filters: int = 32,
    kernel_size: int = 4,
    conv_stacks: int = 2,
    conv_dilations: tuple = (1, 2, 4, 8),
    padding: str = "causal",
    use_skip_connections: bool = True,
    dropout_rate: int = 0,
    activation: str = "relu",
    interaction_regularization: float = 0.0,
):
    """Return a Temporal Convolutional Network (TCN) encoder.

    Builds a neural network that can be used to encode motion tracking instances into a
    vector. Each layer contains a residual block with a convolutional layer and a skip connection. See the following
    paper for more details: https://arxiv.org/pdf/1803.01271.pdf

    Args:
        input_shape: shape of the input data
        edge_feature_shape (tuple): shape of the adjacency matrix to use in the graph attention layers. Should be time x edges x features.
        adjacency_matrix (np.ndarray): adjacency matrix for the mice connectivity graph. Shape should be nodes x nodes.
        latent_dim: dimensionality of the latent space
        use_gnn (bool): If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
        conv_filters: number of filters in the TCN layers
        kernel_size: size of the convolutional kernels
        conv_stacks: number of TCN layers
        conv_dilations: list of dilation factors for each TCN layer
        padding: padding mode for the TCN layers
        use_skip_connections: whether to use skip connections between TCN layers
        dropout_rate: dropout rate for the TCN layers
        activation: activation function for the TCN layers
        interaction_regularization (float): Regularization parameter for the interaction features

    Returns:
        keras.Model: a keras model that can be trained to encode a sequence of motion tracking instances into a latent
        space using temporal convolutional networks.

    """
    # Define feature and adjacency inputs
    x = Input(shape=input_shape)
    a = Input(shape=edge_feature_shape)

    if use_gnn:
        x_reshaped = tf.transpose(
            tf.reshape(
                tf.transpose(x),
                [
                    -1,
                    adjacency_matrix.shape[-1],
                    x.shape[1],
                    input_shape[-1] // adjacency_matrix.shape[-1],
                ][::-1],
            )
        )
        a_reshaped = tf.transpose(
            tf.reshape(
                tf.transpose(a),
                [
                    -1,
                    edge_feature_shape[-1],
                    a.shape[1],
                    1,
                ][::-1],
            )
        )

    else:
        x_reshaped = tf.expand_dims(x, axis=1)

    encoder = TimeDistributed(
        tcn.TCN(
            conv_filters,
            kernel_size,
            conv_stacks,
            conv_dilations,
            padding,
            use_skip_connections,
            dropout_rate,
            return_sequences=False,
            activation=activation,
            kernel_initializer="random_normal",
            use_batch_norm=True,
        )
    )(x_reshaped)

    # Instantiate spatial graph block
    if use_gnn:

        # Embed edge features too
        a_encoder = TimeDistributed(
            tcn.TCN(
                conv_filters,
                kernel_size,
                conv_stacks,
                conv_dilations,
                padding,
                use_skip_connections,
                dropout_rate,
                return_sequences=False,
                activation=activation,
                kernel_initializer="random_normal",
                use_batch_norm=True,
            )
        )(a_reshaped)

        spatial_block = CensNetConv(
            node_channels=latent_dim,
            edge_channels=latent_dim,
            activation="relu",
            node_regularizer=tf.keras.regularizers.l1(interaction_regularization),
        )

        # Process adjacency matrix
        laplacian, edge_laplacian, incidence = spatial_block.preprocess(
            adjacency_matrix
        )

        # Get and concatenate node and edge embeddings
        x_nodes, x_edges = spatial_block(
            [encoder, (laplacian, edge_laplacian, incidence), a_encoder], mask=None
        )

        x_nodes = tf.reshape(
            x_nodes,
            [-1, adjacency_matrix.shape[-1] * latent_dim],
        )

        x_edges = tf.reshape(
            x_edges,
            [-1, edge_feature_shape[-1] * latent_dim],
        )

        encoder = tf.concat([x_nodes, x_edges], axis=-1)

    else:
        encoder = tf.squeeze(encoder, axis=1)

    encoder = tf.keras.layers.Dense(2 * latent_dim, activation="relu")(encoder)
    encoder = tf.keras.layers.BatchNormalization()(encoder)
    encoder = Dense(latent_dim, activation="relu")(encoder)
    encoder = tf.keras.layers.BatchNormalization()(encoder)
    encoder = tf.keras.layers.Dense(latent_dim)(encoder)

    return Model([x, a], encoder, name="TCN_encoder")

In [37]:
from typing import Iterable, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepof.clustering.censNetConv_pt import CensNetConvPT


def _act(name: str) -> nn.Module:
    name = (name or "relu").lower()
    if name == "relu":
        return nn.ReLU()
    if name == "gelu":
        return nn.GELU()
    if name == "tanh":
        return nn.Tanh()
    if name == "leaky_relu":
        return nn.LeakyReLU(0.2)
    if name in {"linear", "identity", "none"}:
        return nn.Identity()
    raise ValueError(f"Unsupported activation: {name}")


class TemporalBlockPT(nn.Module):
    """
    Residual TCN block compatible with keras-tcn:
      - Conv1d -> BN(eps=1e-3) -> Act -> Drop
      - Conv1d -> BN(eps=1e-3) -> Act -> Drop
      - Residual add (with 1x1 projection if channels differ) -> Act
    Returns:
      out: post-residual activation
      skip: post-second-conv activation (summed across blocks when skip connections are used)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int,
        padding: str = "causal",
        dropout_rate: float = 0.0,
        activation: str = "relu",
        use_batch_norm: bool = True,
        conv_init_std: float = 0.05,
    ):
        super().__init__()
        assert padding in {"causal", "same"}
        self.dilation = int(dilation)
        self.kernel_size = int(kernel_size)
        self.padding_mode = padding
        self.act = _act(activation)
        self.use_batch_norm = use_batch_norm

        pad = lambda: ((self.kernel_size - 1) * self.dilation) // 2 if padding == "same" else 0

        self.conv1 = nn.Conv1d(in_channels, out_channels, self.kernel_size, dilation=self.dilation, padding=pad(), bias=True)
        self.bn1 = nn.BatchNorm1d(out_channels, eps=1e-3) if use_batch_norm else nn.Identity()
        self.drop1 = nn.Dropout(float(dropout_rate)) if dropout_rate else nn.Identity()

        self.conv2 = nn.Conv1d(out_channels, out_channels, self.kernel_size, dilation=self.dilation, padding=pad(), bias=True)
        self.bn2 = nn.BatchNorm1d(out_channels, eps=1e-3) if use_batch_norm else nn.Identity()
        self.drop2 = nn.Dropout(float(dropout_rate)) if dropout_rate else nn.Identity()

        # 1x1 residual projection if channels differ
        self.downsample = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=True) if in_channels != out_channels else None

        # Init similar to keras random_normal
        nn.init.normal_(self.conv1.weight, mean=0.0, std=conv_init_std); nn.init.zeros_(self.conv1.bias)
        nn.init.normal_(self.conv2.weight, mean=0.0, std=conv_init_std); nn.init.zeros_(self.conv2.bias)
        if self.downsample is not None:
            nn.init.normal_(self.downsample.weight, mean=0.0, std=conv_init_std); nn.init.zeros_(self.downsample.bias)

    def _causal_pad(self, x: torch.Tensor) -> torch.Tensor:
        pad = (self.kernel_size - 1) * self.dilation
        return F.pad(x, (pad, 0))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: (B, C_in, T)
        y = self._causal_pad(x) if self.padding_mode == "causal" else x
        y = self.drop1(self.act(self.bn1(self.conv1(y))))

        y = self._causal_pad(y) if self.padding_mode == "causal" else y
        y = self.drop2(self.act(self.bn2(self.conv2(y))))

        skip = y  # per-block skip is the post-second-activation output

        res = x if self.downsample is None else self.downsample(x)
        out = self.act(y + res)
        return out, skip  # both (B, C_out, T)


class TCN1DPT(nn.Module):
    """
    Temporal Convolutional Network over sequences (B, T, C_in).
    - When use_skip_connections=True: sum per-block skip outputs, then apply a final activation.
    - Otherwise: use the last block’s residual output.
    - return_sequences=False: returns last timestep features (B, C_out).
    """
    def __init__(
        self,
        in_channels: int,
        conv_filters: int = 32,
        kernel_size: int = 4,
        conv_stacks: int = 2,
        conv_dilations: Iterable[int] = (1, 2, 4, 8),
        padding: str = "causal",
        use_skip_connections: bool = True,
        dropout_rate: float = 0.0,
        activation: str = "relu",
        use_batch_norm: bool = True,
        return_sequences: bool = False,
    ):
        super().__init__()
        self.use_skip_connections = use_skip_connections
        self.return_sequences = return_sequences
        self.final_act = _act(activation)

        blocks = []
        c_in = in_channels
        for _ in range(int(conv_stacks)):
            for d in tuple(conv_dilations):
                blocks.append(
                    TemporalBlockPT(
                        in_channels=c_in,
                        out_channels=conv_filters,
                        kernel_size=kernel_size,
                        dilation=int(d),
                        padding=padding,
                        dropout_rate=dropout_rate,
                        activation=activation,
                        use_batch_norm=use_batch_norm,
                    )
                )
                c_in = conv_filters
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C_in) -> Conv1d expects (B, C_in, T)
        y = x.transpose(1, 2)
        skip_sum, last_out = None, None

        for blk in self.blocks:
            y, skip = blk(y)
            last_out = y
            if self.use_skip_connections:
                skip_sum = skip if skip_sum is None else (skip_sum + skip)

        out = skip_sum if self.use_skip_connections else last_out  # (B, C, T)
        out = self.final_act(out)
        return out.transpose(1, 2) if self.return_sequences else out[:, :, -1]


class TCNEncoderPT(nn.Module):
    """
    PyTorch port of the TF get_TCN_encoder with matching behavior:
      - Inputs:
          x: (B, W, N, NF)   node features
          a: (B, W, E, EF)   edge features
      - use_gnn=True:
          TimeDistributed(TCN) over nodes/edges -> (B, N, C) and (B, E, C)
          CensNetConvPT([node, (lap, edge_lap, inc), edge]) -> (B, N, latent), (B, E, latent)
          Flatten and MLP head
      - use_gnn=False:
          Flatten nodes+features -> TCN -> MLP head

      Parity details:
        - keras-tcn-compatible skip semantics and activation placement
        - BN eps=1e-3 everywhere
        - 'causal' and 'same' paddings supported
    """
    def __init__(
        self,
        input_shape: Tuple[int, int, int],        # (W, N, NF)
        edge_feature_shape: Tuple[int, int, int], # (W, E, EF)
        adjacency_matrix: np.ndarray,
        latent_dim: int,
        use_gnn: bool = True,
        conv_filters: int = 32,
        kernel_size: int = 4,
        conv_stacks: int = 2,
        conv_dilations: Iterable[int] = (1, 2, 4, 8),
        padding: str = "causal",
        use_skip_connections: bool = True,
        dropout_rate: float = 0.0,
        activation: str = "relu",
        interaction_regularization: float = 0.0,  # not used explicitly in PT
        use_batch_norm: bool = True,
    ):
        super().__init__()
        self.use_gnn = use_gnn
        self.latent_dim = int(latent_dim)
        self.conv_filters = int(conv_filters)

        W, N, F_node = input_shape
        _, E, F_edge = edge_feature_shape
        assert adjacency_matrix.shape[0] == N == adjacency_matrix.shape[1], "Adjacency must be NxN and match input nodes."

        tcn_cfg = dict(
            conv_filters=conv_filters,
            kernel_size=kernel_size,
            conv_stacks=conv_stacks,
            conv_dilations=tuple(conv_dilations),
            padding=padding,
            use_skip_connections=use_skip_connections,
            dropout_rate=float(dropout_rate),
            activation=activation,
            use_batch_norm=use_batch_norm,
            return_sequences=False,
        )

        if use_gnn:
            # Per-node and per-edge TCNs
            self.node_tcn = TCN1DPT(in_channels=F_node, **tcn_cfg)
            self.edge_tcn = TCN1DPT(in_channels=F_edge, **tcn_cfg)

            # Graph block and buffers
            self.spatial_gnn_block = CensNetConvPT(node_channels=latent_dim, edge_channels=latent_dim, activation="relu")
            lap, edge_lap, inc = self.spatial_gnn_block.preprocess(torch.tensor(adjacency_matrix))
            self.register_buffer("laplacian", lap.float())
            self.register_buffer("edge_laplacian", edge_lap.float())
            self.register_buffer("incidence", inc.float())

            final_in = (N * latent_dim) + (E * latent_dim)
        else:
            # Single TCN over flattened node features
            self.flat_tcn = TCN1DPT(in_channels=N * F_node, **tcn_cfg)
            final_in = conv_filters

        # Head MLP: Dense(2*latent) -> BN -> Dense(latent) -> BN -> Dense(latent)
        self.head = nn.Sequential(
            nn.Linear(final_in, 2 * latent_dim),
            nn.ReLU(),
            nn.BatchNorm1d(2 * latent_dim, eps=1e-3),
            nn.Linear(2 * latent_dim, latent_dim),
            nn.ReLU(),
            nn.BatchNorm1d(latent_dim, eps=1e-3),
            nn.Linear(latent_dim, latent_dim),
        )
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """
        x: (B, W, N, NF)  a: (B, W, E, EF)  -> returns (B, latent_dim)
        """
        B, W, N, F_node = x.shape
        _, _, E, F_edge = a.shape

        if self.use_gnn:
            # Nodes: TF-style reshape pipeline to match memory layout exactly
            x_3d = x.view(B, W, N * F_node)          # (B, W, N*F)
            x_t = x_3d.permute(2, 1, 0)              # (N*F, W, B)
            x_reshaped_t = x_t.reshape(F_node, W, N, B)
            x_nodes = x_reshaped_t.permute(3, 2, 1, 0)  # (B, N, W, F)

            node_in = x_nodes.reshape(B * N, W, F_node)
            node_out = self.node_tcn(node_in).view(B, N, self.conv_filters)  # (B, N, C)

            # Edges: TF-style reshape pipeline to match memory layout exactly
            a_3d = a.view(B, W, E * F_edge)          # (B, W, E*F_edge)
            a_t = a_3d.permute(2, 1, 0)              # (E*F_edge, W, B)
            a_reshaped_t = a_t.reshape(F_edge, W, E, B)
            a_edges = a_reshaped_t.permute(3, 2, 1, 0)  # (B, E, W, F_edge)

            edge_in = a_edges.reshape(B * E, W, F_edge)
            edge_out = self.edge_tcn(edge_in).view(B, E, self.conv_filters)  # (B, E, C)

            # Graph block
            adj_tuple = (self.laplacian, self.edge_laplacian, self.incidence)
            x_nodes_g, x_edges_g = self.spatial_gnn_block([node_out, adj_tuple, edge_out])
            x_nodes_g = F.relu(x_nodes_g)
            x_edges_g = F.relu(x_edges_g)

            enc = torch.cat([x_nodes_g.reshape(B, -1), x_edges_g.reshape(B, -1)], dim=-1)
        else:
            # Non-GNN unchanged
            x_flat = x.view(B, W, N * F_node)        # (B, W, N*NF)
            enc = self.flat_tcn(x_flat)              # (B, C)

        return self.head(enc)

In [38]:
import numpy as np
import torch
import torch.nn as nn
import tensorflow as tf
import tcn as tcn_pkg

def _tf_conv1d_to_torch(w_keras: np.ndarray) -> torch.Tensor:
    # TF Conv1D [K, Cin, Cout] -> PT Conv1d [Cout, Cin, K]
    return torch.from_numpy(np.transpose(w_keras, (2, 1, 0)))

def _load_bn_tf_to_pt(pt_bn: nn.BatchNorm1d, tf_bn: tf.keras.layers.BatchNormalization):
    gamma, beta, moving_mean, moving_var = tf_bn.get_weights()
    pt_bn.weight.data = torch.from_numpy(gamma)
    pt_bn.bias.data = torch.from_numpy(beta)
    pt_bn.running_mean.data = torch.from_numpy(moving_mean)
    pt_bn.running_var.data = torch.from_numpy(moving_var)

def _kernel_size_1(conv: tf.keras.layers.Conv1D) -> bool:
    ks = conv.kernel_size
    ks = ks[0] if isinstance(ks, tuple) else ks
    return ks == 1

def _collect_tcn_sublayers(tf_tcn_layer: tf.keras.layers.Layer):
    # Use submodules as before; we’ll still verify counts and assign conservatively
    convs = [m for m in tf_tcn_layer.submodules if isinstance(m, tf.keras.layers.Conv1D)]
    bns = [m for m in tf_tcn_layer.submodules if isinstance(m, tf.keras.layers.BatchNormalization)]
    return convs, bns

def transfer_td_tcn_weights(tf_td_tcn: tf.keras.layers.TimeDistributed, pt_tcn) -> None:
    """
    Transfer weights from TF TimeDistributed(tcn.TCN) into PT TCN1DPT:
      - Map per-block [conv1, conv2] (in order) and their BN layers
      - Map the single residual 1x1 projection (matching_conv1D), if present
      - No skip 1x1 convs (your model has none)
    """
    assert isinstance(tf_td_tcn, tf.keras.layers.TimeDistributed)
    assert isinstance(tf_td_tcn.layer, tcn_pkg.TCN)
    tf_tcn = tf_td_tcn.layer

    convs, bns = _collect_tcn_sublayers(tf_tcn)
    block_convs = [c for c in convs if not _kernel_size_1(c)]   # conv1D_0 / conv1D_1 pairs per block
    proj_1x1 = [c for c in convs if _kernel_size_1(c)]          # matching_conv1D (0 or 1 in your build)

    num_blocks = len(pt_tcn.blocks)
    assert len(block_convs) == 2 * num_blocks, f"Conv count mismatch: TF block convs={len(block_convs)}, PT blocks={num_blocks}"

    use_bn = isinstance(pt_tcn.blocks[0].bn1, nn.BatchNorm1d)
    if use_bn:
        assert len(bns) >= 2 * num_blocks, f"BN count mismatch: TF BNs={len(bns)}, expected >= {2 * num_blocks}"

    # Load per-block convs and BN stats
    for i, blk in enumerate(pt_tcn.blocks):
        k1, b1 = block_convs[2 * i].get_weights()
        blk.conv1.weight.data = _tf_conv1d_to_torch(k1)
        blk.conv1.bias.data = torch.from_numpy(b1)

        k2, b2 = block_convs[2 * i + 1].get_weights()
        blk.conv2.weight.data = _tf_conv1d_to_torch(k2)
        blk.conv2.bias.data = torch.from_numpy(b2)

        if use_bn:
            _load_bn_tf_to_pt(blk.bn1, bns[2 * i])
            _load_bn_tf_to_pt(blk.bn2, bns[2 * i + 1])

    # Residual projection for the first block if needed
    proj_idx = 0
    for blk in pt_tcn.blocks:
        if isinstance(getattr(blk, "downsample", None), nn.Conv1d):
            rk, rb = proj_1x1[proj_idx].get_weights()
            blk.downsample.weight.data = _tf_conv1d_to_torch(rk)
            blk.downsample.bias.data = torch.from_numpy(rb)
            proj_idx += 1


# ---------- MLP head transfer ----------

def transfer_head_mlp(tf_model, pt_model_head: nn.Sequential):
    """
    Transfer the final MLP head:
      Dense(2*latent, relu) -> BN -> Dense(latent, relu) -> BN -> Dense(latent)
    from TF model to PT head (Linear, BN, Linear, BN, Linear).
    """
    # Extract the final [Dense, BN, Dense, BN, Dense] from TF model
    tail = [l for l in tf_model.layers if isinstance(l, (tf.keras.layers.Dense, tf.keras.layers.BatchNormalization))]
    d1, bn1, d2, bn2, d3 = tail[-5:]

    # PT head layout: [Linear, ReLU, BN, Linear, ReLU, BN, Linear]
    lin1: nn.Linear = pt_model_head[0]
    bn1_pt: nn.BatchNorm1d = pt_model_head[2]
    lin2: nn.Linear = pt_model_head[3]
    bn2_pt: nn.BatchNorm1d = pt_model_head[5]
    lin3: nn.Linear = pt_model_head[6]

    # Dense 1
    w, b = d1.get_weights()
    lin1.weight.data = torch.from_numpy(w.T)
    lin1.bias.data = torch.from_numpy(b)
    # BN 1
    _load_bn_tf_to_pt(bn1_pt, bn1)
    # Dense 2
    w, b = d2.get_weights()
    lin2.weight.data = torch.from_numpy(w.T)
    lin2.bias.data = torch.from_numpy(b)
    # BN 2
    _load_bn_tf_to_pt(bn2_pt, bn2)
    # Dense 3
    w, b = d3.get_weights()
    lin3.weight.data = torch.from_numpy(w.T)
    lin3.bias.data = torch.from_numpy(b)


# ---------- High-level: TCN encoder transfer ----------

def transfer_tcn_encoder_weights(tf_model, pt_model, use_gnn: bool):
    """
    Transfers weights for the full TCN encoder.
      - Node and edge TimeDistributed(TCN) blocks
      - CensNetConv (if use_gnn)
      - Final MLP head
    """
    # 1) Final head
    transfer_head_mlp(tf_model, pt_model.head)

    # 2) TimeDistributed(TCN) blocks
    td_layers = [l for l in tf_model.layers if isinstance(l, tf.keras.layers.TimeDistributed) and isinstance(l.layer, tcn.TCN)]
    if use_gnn:
        assert len(td_layers) >= 2, "Expected two TimeDistributed(TCN) layers (node and edge) for use_gnn=True"
        # Heuristically: first TD is nodes, second is edges (matches build order)
        node_td = td_layers[0]
        edge_td = td_layers[1]
        transfer_td_tcn_weights(node_td, pt_model.node_tcn)
        transfer_td_tcn_weights(edge_td, pt_model.edge_tcn)

        # 3) CensNetConv
        gnn_layer = next(l for l in tf_model.layers if isinstance(l, CensNetConv))
        transfer_censnet_weights(gnn_layer, pt_model.spatial_gnn_block)

    else:
        # Non-GNN: single TD(TCN); TF input_shape should be (T, N*F_node)
        assert len(td_layers) >= 1, "Expected one TimeDistributed(TCN) layer for use_gnn=False"
        transfer_td_tcn_weights(td_layers[0], pt_model.flat_tcn)

def transfer_censnet_weights(tf_layer, pt_layer):
    """
    Transfers all six weights from a Spektral CensNetConv layer to the
    corresponding CensNetConvPT layer.
    """
    # Get all weights from the TensorFlow layer. The order is determined by
    # the layer's build order in Spektral's source code.
    tf_weights = tf_layer.get_weights()

    # Unpack all six weights.
    # Order: kernel_node, bias_node, kernel_edge, bias_edge, projector_node, projector_edge
    kn_tf, bn_tf, ke_tf, be_tf, pn_tf, pe_tf = tf_weights

    # Build weights on first pass
    if pt_layer.node_kernel is None:
        # Move parameters to the same device as input tensors
        pt_layer._build(kn_tf.T.shape, bn_tf.T.shape)
        #pt_layer.to(kn_tf.device)

    # 1. & 2. Transfer Node Kernel and Bias
    # Keras Dense kernel is (in_features, out_features)
    pt_layer.node_kernel.data = torch.from_numpy(kn_tf)
    pt_layer.edge_kernel.data = torch.from_numpy(bn_tf)

    # 3. & 4. Transfer Edge Kernel and Bias
    # Same transposition logic applies.
    pt_layer.node_weights.data = torch.from_numpy(ke_tf)
    pt_layer.edge_weights.data = torch.from_numpy(be_tf)

    # 5. Transfer Node Projector Weights (P_n)
    # These are [in_features, 1], which matches, so no transpose needed.
    pt_layer.node_bias.data = torch.from_numpy(pn_tf)

    # 6. Transfer Edge Projector Weights (P_e)
    # These are [in_features, 1], which matches, so no transpose needed.
    pt_layer.edge_bias.data = torch.from_numpy(pe_tf)

In [39]:
import unittest, time
import numpy as np
import torch

def count_undirected_edges(adj: np.ndarray) -> int:
    # Count upper-triangular non-zero entries (undirected edges)
    return int(np.count_nonzero(np.triu(adj, 1)))

class TestTCNEncoderTranslation(unittest.TestCase):
    def setUp(self):
        tf.keras.backend.clear_session()

        # Fundamental dims (use your conventions)
        self.R = 2048                 # number of rows (not used for model build)
        self.W = 25                   # window length
        self.N = 11                   # nodes
        self.NF = 3                   # features per node
        self.EF = 1                   # features per edge (TF expects 1 for the reshape quirk)
        self.latent_dim = 6
        self.use_gnn = True

        # Batch used for parity test
        self.B = 128

        # Make an adjacency whose undirected edge count E matches the edge axis we'll use
        # Example: pick a sparse symmetric adjacency with E edges
        # If you already have an adjacency, just set self.adj_matrix = your_matrix and let E = count_undirected_edges(A)
        rng = np.random.default_rng(0)
        A = np.zeros((self.N, self.N), dtype=np.float32)
        # randomly pick E edges; here we choose E = 11 (as in your typical config)
        target_E = 11
        iu = np.triu_indices(self.N, 1)
        idx = rng.choice(len(iu[0]), size=target_E, replace=False)
        A[iu[0][idx], iu[1][idx]] = 1.0
        A = A + A.T
        self.adj_matrix = A
        self.E = count_undirected_edges(self.adj_matrix)  # should be target_E

        # TF input shapes (flattened)
        self.tf_input_shape_gnn = (self.W, self.N * self.NF)  # (W, NNF)
        self.tf_edge_shape      = (self.W, self.E * self.EF)  # (W, EEF) -> with EF=1, equals (W, E)

        # PT input shapes (split)
        self.pt_input_shape = (self.W, self.N, self.NF)       # (W, N, NF)
        self.pt_edge_shape  = (self.W, self.E, self.EF)       # (W, E, EF)

        # Random inputs
        self.x_pt = rng.normal(size=(self.B, self.W, self.N, self.NF)).astype(np.float32)
        self.a_pt = rng.normal(size=(self.B, self.W, self.E, self.EF)).astype(np.float32)
        # Flatten for TF model
        self.x_tf = self.x_pt.reshape(self.B, self.W, self.N * self.NF)
        self.a_tf = self.a_pt.reshape(self.B, self.W, self.E * self.EF)  # with EF=1, (B, W, E)

        # Common TCN params
        self.conv_filters = 32
        self.kernel_size = 3
        self.conv_stacks = 2
        self.conv_dilations = (1, 2)
        self.padding = "causal"
        self.use_skip = True
        self.dropout = 0.0
        self.activation = "relu"

    def test_forward_pass_gnn(self):
        # Build TF and PT models
        tf_model = get_TCN_encoder(
            input_shape=self.tf_input_shape_gnn,   # (W, NNF)
            edge_feature_shape=self.tf_edge_shape, # (W, EEF) -> E when EF=1
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            use_gnn=True,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            conv_stacks=self.conv_stacks,
            conv_dilations=self.conv_dilations,
            padding=self.padding,
            use_skip_connections=self.use_skip,
            dropout_rate=self.dropout,
            activation=self.activation,
        )
        pt_model = TCNEncoderPT(
            input_shape=self.pt_input_shape,       # (W, N, NF)
            edge_feature_shape=self.pt_edge_shape, # (W, E, EF)
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            use_gnn=True,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            conv_stacks=self.conv_stacks,
            conv_dilations=self.conv_dilations,
            padding=self.padding,
            use_skip_connections=self.use_skip,
            dropout_rate=self.dropout,
            activation=self.activation,
        )
        pt_model.eval()

        # Warm-up PT graph (optional)
        with torch.no_grad():
            _ = pt_model(torch.from_numpy(self.x_pt), torch.from_numpy(self.a_pt))

        # Transfer weights TF -> PT
        transfer_tcn_encoder_weights(tf_model, pt_model, use_gnn=True)

        # Compare outputs (TF expects flattened a)
        t0 = time.time()
        y_tf = tf_model([self.x_tf, self.a_tf], training=False).numpy()
        t1 = time.time()
        with torch.no_grad():
            y_pt = pt_model(torch.from_numpy(self.x_pt), torch.from_numpy(self.a_pt)).cpu().numpy()
        t2 = time.time()
        print("GNN TF time:", t1 - t0, "PT time:", t2 - t1)
        np.testing.assert_allclose(y_tf, y_pt, rtol=1e-5, atol=2e-4)
        print("✅ TCNEncoderPT (GNN path) parity PASSED")

    def test_forward_pass_no_gnn(self):
        # Build TF and PT models (TF expects flattened x, a still provided but unused)
        tf_model = get_TCN_encoder(
            input_shape=self.tf_input_shape_gnn,   # (W, NNF) in your pipeline
            edge_feature_shape=self.tf_edge_shape, # (W, EEF)
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            use_gnn=False,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            conv_stacks=self.conv_stacks,
            conv_dilations=self.conv_dilations,
            padding=self.padding,
            use_skip_connections=self.use_skip,
            dropout_rate=self.dropout,
            activation=self.activation,
        )
        pt_model = TCNEncoderPT(
            input_shape=self.pt_input_shape,       # (W, N, NF)
            edge_feature_shape=self.pt_edge_shape, # (W, E, EF)
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            use_gnn=False,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            conv_stacks=self.conv_stacks,
            conv_dilations=self.conv_dilations,
            padding=self.padding,
            use_skip_connections=self.use_skip,
            dropout_rate=self.dropout,
            activation=self.activation,
        )
        pt_model.eval()

        # Transfer weights TF -> PT
        transfer_tcn_encoder_weights(tf_model, pt_model, use_gnn=False)

        # Compare outputs (TF expects flattened x, a flattened to EEF)
        y_tf = tf_model([self.x_tf, self.a_tf], training=False).numpy()
        with torch.no_grad():
            y_pt = pt_model(torch.from_numpy(self.x_pt), torch.from_numpy(self.a_pt)).cpu().numpy()

        np.testing.assert_allclose(y_tf, y_pt, rtol=1e-5, atol=2e-4)
        print("✅ TCNEncoderPT (non-GNN path) parity PASSED")

# Run
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.TestLoader().loadTestsFromTestCase(TestTCNEncoderTranslation)
runner.run(suite)

test_forward_pass_gnn (__main__.TestTCNEncoderTranslation) ... The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.
ok
test_forward_pass_no_gnn (__main__.TestTCNEncoderTranslation) ... 

GNN TF time: 0.5147807598114014 PT time: 0.0169522762298584
✅ TCNEncoderPT (GNN path) parity PASSED


ok

----------------------------------------------------------------------
Ran 2 tests in 1.261s

OK


✅ TCNEncoderPT (non-GNN path) parity PASSED


<unittest.runner.TextTestResult run=2 errors=0 failures=0>

In [40]:
# TCN-only parity diagnostic (non-GNN), fully executable with your provided dims

import numpy as np
import torch
import torch.nn as nn
import tensorflow as tf
import tcn as tcn_pkg

# Assumes get_TCN_encoder and TCNEncoderPT are already defined/imported in your notebook.

# -------------------------
# 1) Your provided settings
# -------------------------
batch_size = 128
window_length = 25
num_nodes = 11
features_per_node = 3
num_edges = 11
features_per_edge = 1

latent_dim = 6
use_gnn = False

# Adjacency matrix
m = np.zeros((num_nodes, num_nodes), dtype=np.float32)
ui = np.triu_indices(num_nodes)
num_possible_edges = len(ui[0])
c = np.random.choice(num_possible_edges, min(num_edges, num_possible_edges), replace=False)
m[ui[0][c], ui[1][c]] = 1
m = (m + m.T).astype(np.float32)  # symmetric
adj_matrix = m

# Input shapes
tf_input_shape = (window_length, num_nodes * features_per_node)     # TF non-GNN expects flattened features
pt_input_shape = (window_length, num_nodes, features_per_node)      # PT expects (T, N, F_node)
edge_shape = (window_length, num_edges, features_per_edge)

# Random inputs
rng = np.random.default_rng(0)
x_np = rng.normal(size=(batch_size, window_length, num_nodes, features_per_node)).astype(np.float32)
a_np = rng.normal(size=(batch_size, window_length, num_edges, features_per_edge)).astype(np.float32)

# ------------------------------------
# 2) Build TF and PT (non-GNN) models
# ------------------------------------
tf_model = get_TCN_encoder(
    input_shape=tf_input_shape,
    edge_feature_shape=edge_shape,
    adjacency_matrix=adj_matrix,
    latent_dim=latent_dim,
    use_gnn=False,
    conv_filters=32,
    kernel_size=3,
    conv_stacks=2,
    conv_dilations=(1, 2),
    padding="causal",
    use_skip_connections=True,
    dropout_rate=0.0,
    activation="relu",
)

pt_model = TCNEncoderPT(
    input_shape=pt_input_shape,
    edge_feature_shape=edge_shape,
    adjacency_matrix=adj_matrix,
    latent_dim=latent_dim,
    use_gnn=False,
    conv_filters=32,
    kernel_size=3,
    conv_stacks=2,
    conv_dilations=(1, 2),
    padding="causal",
    use_skip_connections=True,
    dropout_rate=0.0,
    activation="relu",
)

pt_model.eval()

# Ensure PT BN eps=1e-3 inside TCN (BN eps mismatch is a common source of diffs)
for blk in pt_model.flat_tcn.blocks:
    if isinstance(blk.bn1, nn.BatchNorm1d):
        blk.bn1.eps = 1e-3
    if isinstance(blk.bn2, nn.BatchNorm1d):
        blk.bn2.eps = 1e-3

# ------------------------------------------------
# 3) Helpers: extract TD(TCN) and transfer weights
# ------------------------------------------------
def get_first_td_tcn(tf_model):
    for l in tf_model.layers:
        if isinstance(l, tf.keras.layers.TimeDistributed) and isinstance(l.layer, tcn_pkg.TCN):
            return l
    raise RuntimeError("No TimeDistributed(TCN) found in TF model.")

def _tf_conv1d_to_torch(w_keras: np.ndarray) -> torch.Tensor:
    # TF/Keras Conv1D: [kernel, in, out] -> PT Conv1d: [out, in, kernel]
    return torch.from_numpy(np.transpose(w_keras, (2, 1, 0)))

def _load_bn_tf_to_pt(pt_bn: nn.BatchNorm1d, tf_bn: tf.keras.layers.BatchNormalization):
    gamma, beta, moving_mean, moving_var = tf_bn.get_weights()
    pt_bn.weight.data = torch.from_numpy(gamma)
    pt_bn.bias.data = torch.from_numpy(beta)
    pt_bn.running_mean.data = torch.from_numpy(moving_mean)
    pt_bn.running_var.data = torch.from_numpy(moving_var)

def _kernel_size_1(conv: tf.keras.layers.Conv1D) -> bool:
    ks = conv.kernel_size
    ks = ks[0] if isinstance(ks, tuple) else ks
    return ks == 1

def _collect_tcn_sublayers(tf_tcn_layer: tf.keras.layers.Layer):
    convs = [m for m in tf_tcn_layer.submodules if isinstance(m, tf.keras.layers.Conv1D)]
    bns = [m for m in tf_tcn_layer.submodules if isinstance(m, tf.keras.layers.BatchNormalization)]
    return convs, bns

def transfer_td_tcn_weights(tf_td_tcn: tf.keras.layers.TimeDistributed, pt_tcn) -> None:
    assert isinstance(tf_td_tcn, tf.keras.layers.TimeDistributed)
    assert isinstance(tf_td_tcn.layer, tcn_pkg.TCN)
    tf_tcn = tf_td_tcn.layer

    convs, bns = _collect_tcn_sublayers(tf_tcn)
    block_convs = [c for c in convs if not _kernel_size_1(c)]
    resid_convs = [c for c in convs if _kernel_size_1(c)]  # includes residual 1x1 (and possibly skip 1x1s)

    num_blocks = len(pt_tcn.blocks)
    assert len(block_convs) == 2 * num_blocks, f"Conv count mismatch: TF block convs={len(block_convs)}, PT blocks={num_blocks}"

    # Map conv1/conv2 + BN1/BN2
    use_bn = isinstance(pt_tcn.blocks[0].bn1, nn.BatchNorm1d)
    if use_bn:
        assert len(bns) >= 2 * num_blocks, f"BN count mismatch: TF BNs={len(bns)}, expected >= {2 * num_blocks}"

    for i, blk in enumerate(pt_tcn.blocks):
        k1, b1 = block_convs[2 * i].get_weights()
        blk.conv1.weight.data = _tf_conv1d_to_torch(k1)
        blk.conv1.bias.data = torch.from_numpy(b1)

        k2, b2 = block_convs[2 * i + 1].get_weights()
        blk.conv2.weight.data = _tf_conv1d_to_torch(k2)
        blk.conv2.bias.data = torch.from_numpy(b2)

        if use_bn:
            _load_bn_tf_to_pt(blk.bn1, bns[2 * i])
            _load_bn_tf_to_pt(blk.bn2, bns[2 * i + 1])

    # Residual 1x1 projection: only if PT block has it (assumes attribute name 'downsample' as you set)
    resid_idx = 0
    for blk in pt_tcn.blocks:
        if isinstance(getattr(blk, "downsample", None), nn.Conv1d):
            rk, rb = resid_convs[resid_idx].get_weights()
            blk.downsample.weight.data = _tf_conv1d_to_torch(rk)
            blk.downsample.bias.data = torch.from_numpy(rb)
            resid_idx += 1

# ------------------------------------------------------
# 4) Compare TCN-only outputs (TF TD(TCN) vs PT flat_tcn)
# ------------------------------------------------------
# Build a TF submodel that outputs the TimeDistributed(TCN) output only
td = get_first_td_tcn(tf_model)
tf_tcn_sub = tf.keras.Model(tf_model.inputs, td.output)  # -> (B, 1, conv_filters)

# Transfer only the TCN weights TF -> PT (no head)
transfer_td_tcn_weights(td, pt_model.flat_tcn)

# Prepare inputs
x_tf = x_np.reshape(batch_size, window_length, num_nodes * features_per_node)  # flattened for TF
x_pt = torch.from_numpy(x_np)

# Run
tf_out = tf_tcn_sub([x_tf, a_np], training=False).numpy()   # (B, 1, C)
tf_out = np.squeeze(tf_out, axis=1)                         # (B, C)

with torch.no_grad():
    pt_out = pt_model.flat_tcn(x_pt.view(batch_size, window_length, num_nodes * features_per_node)).cpu().numpy()  # (B, C)

# Report basic stats
abs_diff = np.abs(tf_out - pt_out)
print("TCN-only shapes -> TF:", tf_out.shape, "PT:", pt_out.shape)
print("TCN-only mean abs diff:", abs_diff.mean())
print("TCN-only max abs diff:", abs_diff.max())

TCN-only shapes -> TF: (128, 32) PT: (128, 32)
TCN-only mean abs diff: 4.136673e-08
TCN-only max abs diff: 2.9802322e-07


In [41]:
def inspect_tf_tcn(tf_model):
    import tcn as tcn_pkg
    from tensorflow.keras.layers import TimeDistributed, Conv1D, BatchNormalization

    td = None
    for l in tf_model.layers:
        if isinstance(l, TimeDistributed) and isinstance(l.layer, tcn_pkg.TCN):
            td = l; break
    assert td is not None, "No TimeDistributed(TCN) found."

    tf_tcn = td.layer
    convs = [m for m in tf_tcn.submodules if isinstance(m, Conv1D)]
    bns   = [m for m in tf_tcn.submodules if isinstance(m, BatchNormalization)]

    print("Convs:")
    for i, c in enumerate(convs):
        ks = c.kernel_size[0] if isinstance(c.kernel_size, tuple) else c.kernel_size
        print(f"  {i:2d}: name={c.name}, kernel_size={ks}, filters={c.filters}")
    print("BNs:")
    for i, b in enumerate(bns):
        print(f"  {i:2d}: name={b.name}, epsilon={b.epsilon}, momentum={b.momentum}")

inspect_tf_tcn(tf_model)

Convs:
   0: name=conv1D_0, kernel_size=3, filters=32
   1: name=conv1D_1, kernel_size=3, filters=32
   2: name=matching_conv1D, kernel_size=1, filters=32
   3: name=conv1D_0, kernel_size=3, filters=32
   4: name=conv1D_1, kernel_size=3, filters=32
   5: name=conv1D_0, kernel_size=3, filters=32
   6: name=conv1D_1, kernel_size=3, filters=32
   7: name=conv1D_0, kernel_size=3, filters=32
   8: name=conv1D_1, kernel_size=3, filters=32
BNs:
   0: name=batch_normalization, epsilon=0.001, momentum=0.99
   1: name=batch_normalization_1, epsilon=0.001, momentum=0.99
   2: name=batch_normalization_2, epsilon=0.001, momentum=0.99
   3: name=batch_normalization_3, epsilon=0.001, momentum=0.99
   4: name=batch_normalization_4, epsilon=0.001, momentum=0.99
   5: name=batch_normalization_5, epsilon=0.001, momentum=0.99
   6: name=batch_normalization_6, epsilon=0.001, momentum=0.99
   7: name=batch_normalization_7, epsilon=0.001, momentum=0.99


# TCN decoder test

In [1]:
from typing import Iterable, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# TensorFlow + keras-tcn for building the reference TF decoder and grabbing weights
import tensorflow as tf
import tcn as tcn_pkg

# The TF probabilistic decoder layer type (for weight transfer)
# If your TF class lives elsewhere, adjust this import accordingly.
import deepof
from deepof.model_utils import ProbabilisticDecoder as ProbabilisticDecoderTF

# The PT probabilistic decoder (you said this already exists)
# Adjust import path if needed.
from deepof.clustering.model_utils_new import ProbabilisticDecoderPT

from tensorflow.keras.layers import Input, TimeDistributed, Bidirectional, GRU, LayerNormalization, Masking
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    GRU,
    Bidirectional,
    Dense,
    LayerNormalization,
    RepeatVector,
    TimeDistributed,
)
import tensorflow as tf
import tcn


In [2]:
def get_TCN_decoder(
    input_shape: tuple,
    latent_dim: int,
    conv_filters: int = 64,
    kernel_size: int = 4,
    conv_stacks: int = 1,
    conv_dilations: tuple = (8, 4, 2, 1),
    padding: str = "causal",
    use_skip_connections: bool = True,
    dropout_rate: int = 0,
    activation: str = "relu",
):
    """Return a Temporal Convolutional Network (TCN) decoder.

    Builds a neural network that can be used to decode a latent space into a sequence of
    motion tracking instances. Each layer contains a residual block with a convolutional layer and a skip connection. See
    the following paper for more details: https://arxiv.org/pdf/1803.01271.pdf,

    Args:
        input_shape: shape of the input data
        latent_dim: dimensionality of the latent space
        conv_filters: number of filters in the TCN layers
        kernel_size: size of the convolutional kernels
        conv_stacks: number of TCN layers
        conv_dilations: list of dilation factors for each TCN layer
        padding: padding mode for the TCN layers
        use_skip_connections: whether to use skip connections between TCN layers
        dropout_rate: dropout rate for the TCN layers
        activation: activation function for the TCN layers

    Returns:
        keras.Model: a keras model that can be trained to decode a latent space into a sequence of motion tracking
        instances using temporal convolutional networks.

    """
    # Define and instantiate generator
    g = Input(shape=latent_dim)  # Decoder input, shaped as the latent space
    x = Input(shape=input_shape)  # Encoder input, used to generate an output mask
    validity_mask = tf.math.logical_not(tf.reduce_all(x == 0.0, axis=2))

    generator = tf.keras.layers.Dense(latent_dim)(g)
    generator = tf.keras.layers.BatchNormalization()(generator)
    generator = tf.keras.layers.Dense(2 * latent_dim, activation="relu")(generator)
    generator = tf.keras.layers.BatchNormalization()(generator)
    generator = tf.keras.layers.Dense(4 * latent_dim, activation="relu")(generator)
    generator = tf.keras.layers.BatchNormalization()(generator)
    generator = tf.keras.layers.RepeatVector(input_shape[0])(generator)

    generator = tcn.TCN(
        conv_filters,
        kernel_size,
        conv_stacks,
        conv_dilations,
        padding,
        use_skip_connections,
        dropout_rate,
        return_sequences=True,
        activation=activation,
        kernel_initializer="random_normal",
        use_batch_norm=True,
    )(generator)

    x_decoded = deepof.model_utils.ProbabilisticDecoder(input_shape)(
        [generator, validity_mask]
    )

    return Model([g, x], x_decoded, name="TCN_decoder")

In [3]:
from deepof.clustering.models_new import TCN1DPT, TemporalBlockPT, _act



class TCNDecoderPT(nn.Module):
    """
    PyTorch port of TF get_TCN_decoder:
      - g: (B, latent_dim)
      - x: (B, W, NNF) or (B, W, N, NF) for mask computation
      Pipeline:
        Dense(latent) -> BN ->
        Dense(2*latent, relu) -> BN ->
        Dense(4*latent, relu) -> BN ->
        RepeatVector(W) ->
        TCN(return_sequences=True) ->
        ProbabilisticDecoderPT(hidden_dim=conv_filters, data_dim=NNF)
      Returns: a distribution whose .mean is (B, W, NNF)
    """
    def __init__(
        self,
        input_shape: Tuple[int, int],   # (W, NNF)
        latent_dim: int,
        conv_filters: int = 64,
        kernel_size: int = 4,
        conv_stacks: int = 1,
        conv_dilations: Iterable[int] = (8, 4, 2, 1),
        padding: str = "causal",
        use_skip_connections: bool = True,
        dropout_rate: float = 0.0,
        activation: str = "relu",
        use_batch_norm: bool = True,
    ):
        super().__init__()
        self.W, self.data_dim = int(input_shape[0]), int(input_shape[1])
        self.latent_dim = int(latent_dim)

        # Front MLP: Dense -> BN -> Dense(relu) -> BN -> Dense(relu) -> BN
        self.fc0 = nn.Linear(latent_dim, latent_dim)
        self.bn0 = nn.BatchNorm1d(latent_dim, eps=1e-3)

        self.fc1 = nn.Linear(latent_dim, 2 * latent_dim)
        self.act1 = _act(activation)
        self.bn1 = nn.BatchNorm1d(2 * latent_dim, eps=1e-3)

        self.fc2 = nn.Linear(2 * latent_dim, 4 * latent_dim)
        self.act2 = _act(activation)
        self.bn2 = nn.BatchNorm1d(4 * latent_dim, eps=1e-3)

        # TCN over repeated latent sequence
        self.tcn = TCN1DPT(
            in_channels=4 * latent_dim,
            conv_filters=conv_filters,
            kernel_size=kernel_size,
            conv_stacks=conv_stacks,
            conv_dilations=conv_dilations,
            padding=padding,
            use_skip_connections=use_skip_connections,
            dropout_rate=float(dropout_rate),
            activation=activation,
            use_batch_norm=use_batch_norm,
            return_sequences=True,
        )

        # Probabilistic reconstruction head
        self.prob_decoder = ProbabilisticDecoderPT(hidden_dim=conv_filters, data_dim=self.data_dim)

        # Init linear layers (BN stats copied by transfer)
        for m in [self.fc0, self.fc1, self.fc2]:
            nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)

    def forward(self, g: torch.Tensor, x: torch.Tensor):
        """
        g: (B, latent_dim)
        x: (B, W, NNF) or (B, W, N, NF)  -> used only to compute validity mask
        returns: distribution with .mean of shape (B, W, NNF)
        """
        B = g.shape[0]
        # Build validity mask as in TF: logical_not(reduce_all(x == 0, axis=2))
        if x.dim() == 4:
            # (B, W, N, NF) -> (B, W, N*NF)
            x_flat = x.view(x.size(0), x.size(1), -1)
        else:
            x_flat = x
        validity_mask = ~torch.all(x_flat == 0.0, dim=-1)  # (B, W), bool

        # Generator MLP
        z = self.bn0(self.fc0(g))
        z = self.bn1(self.act1(self.fc1(z)))
        z = self.bn2(self.act2(self.fc2(z)))

        # Repeat across time (RepeatVector)
        z_rep = z.unsqueeze(1).repeat(1, self.W, 1)  # (B, W, 4*latent)

        # Temporal block
        hidden_seq = self.tcn(z_rep)  # (B, W, conv_filters)

        # Probabilistic reconstruction
        return self.prob_decoder(hidden_seq, validity_mask)

In [4]:
def _tf_conv1d_to_torch(w_keras: np.ndarray) -> torch.Tensor:
    # TF Conv1D: [K, Cin, Cout] -> PT Conv1d: [Cout, Cin, K]
    return torch.from_numpy(np.transpose(w_keras, (2, 1, 0)))


def _load_bn_tf_to_pt(pt_bn: nn.BatchNorm1d, tf_bn: tf.keras.layers.BatchNormalization):
    gamma, beta, moving_mean, moving_var = tf_bn.get_weights()
    pt_bn.weight.data = torch.from_numpy(gamma)
    pt_bn.bias.data = torch.from_numpy(beta)
    pt_bn.running_mean.data = torch.from_numpy(moving_mean)
    pt_bn.running_var.data = torch.from_numpy(moving_var)


def _collect_tcn_sublayers(tf_tcn_layer: tf.keras.layers.Layer):
    from tensorflow.keras.layers import Conv1D, BatchNormalization
    convs = [m for m in tf_tcn_layer.submodules if isinstance(m, Conv1D)]
    bns = [m for m in tf_tcn_layer.submodules if isinstance(m, BatchNormalization)]
    return convs, bns


def _is_1x1(conv: tf.keras.layers.Conv1D) -> bool:
    ks = conv.kernel_size
    ks = ks[0] if isinstance(ks, tuple) else ks
    return ks == 1


def transfer_decoder_front_mlp_weights(tf_model: tf.keras.Model, pt_model: TCNDecoderPT):
    """
    Map the three Dense + three BN layers before RepeatVector in order.
    TF sequence: Dense(latent) -> BN -> Dense(2*latent, relu) -> BN -> Dense(4*latent, relu) -> BN
    PT: fc0/bn0, fc1/bn1, fc2/bn2 (with relu applied before BN for fc1,fc2).
    """
    from tensorflow.keras.layers import Dense, BatchNormalization
    denses = [l for l in tf_model.layers if isinstance(l, Dense)]
    bns = [l for l in tf_model.layers if isinstance(l, BatchNormalization)]

    # Use the first three Dense and first three BN layers encountered (decoder-local)
    d0, d1, d2 = denses[0], denses[1], denses[2]
    bn0, bn1, bn2 = bns[0], bns[1], bns[2]

    # Dense 0
    w, b = d0.get_weights()
    pt_model.fc0.weight.data = torch.from_numpy(w.T)
    pt_model.fc0.bias.data = torch.from_numpy(b)
    _load_bn_tf_to_pt(pt_model.bn0, bn0)

    # Dense 1
    w, b = d1.get_weights()
    pt_model.fc1.weight.data = torch.from_numpy(w.T)
    pt_model.fc1.bias.data = torch.from_numpy(b)
    _load_bn_tf_to_pt(pt_model.bn1, bn1)

    # Dense 2
    w, b = d2.get_weights()
    pt_model.fc2.weight.data = torch.from_numpy(w.T)
    pt_model.fc2.bias.data = torch.from_numpy(b)
    _load_bn_tf_to_pt(pt_model.bn2, bn2)


def transfer_tcn_weights(tf_tcn_layer: tcn_pkg.TCN, pt_tcn: TCN1DPT):
    """
    Transfer weights from a Keras tcn.TCN layer to our TCN1DPT (no TimeDistributed).
    - Maps per-block conv1/conv2 + BN
    - Maps residual 1x1 projection(s) if present
    """
    convs, bns = _collect_tcn_sublayers(tf_tcn_layer)
    block_convs = [c for c in convs if not _is_1x1(c)]
    conv1x1 = [c for c in convs if _is_1x1(c)]
    num_blocks = len(pt_tcn.blocks)

    assert len(block_convs) == 2 * num_blocks, f"Conv count mismatch: TF block convs={len(block_convs)}, PT blocks={num_blocks}"

    use_bn = isinstance(pt_tcn.blocks[0].bn1, nn.BatchNorm1d)
    if use_bn:
        assert len(bns) >= 2 * num_blocks, f"BN count mismatch: TF BNs={len(bns)}, expected >= {2 * num_blocks}"

    # Load block convs and BN params
    for i, blk in enumerate(pt_tcn.blocks):
        k1, b1 = block_convs[2 * i].get_weights()
        blk.conv1.weight.data = _tf_conv1d_to_torch(k1)
        blk.conv1.bias.data = torch.from_numpy(b1)

        k2, b2 = block_convs[2 * i + 1].get_weights()
        blk.conv2.weight.data = _tf_conv1d_to_torch(k2)
        blk.conv2.bias.data = torch.from_numpy(b2)

        if use_bn:
            _load_bn_tf_to_pt(blk.bn1, bns[2 * i])
            _load_bn_tf_to_pt(blk.bn2, bns[2 * i + 1])

    # Residual projections if any (first blocks typically)
    proj_idx = 0
    for blk in pt_tcn.blocks:
        if isinstance(getattr(blk, "downsample", None), nn.Conv1d):
            rk, rb = conv1x1[proj_idx].get_weights()
            blk.downsample.weight.data = _tf_conv1d_to_torch(rk)
            blk.downsample.bias.data = torch.from_numpy(rb)
            proj_idx += 1


def transfer_prob_decoder_weights(tf_prob_layer: ProbabilisticDecoderTF, pt_prob: ProbabilisticDecoderPT):
    """
    Copy the final projection Dense from TF ProbabilisticDecoder to PT ProbabilisticDecoderPT.loc_projection.
    Assumes TF layer exposes a single Dense with [hidden_dim, data_dim] kernel and bias at get_weights()[:2].
    """
    w, b = tf_prob_layer.get_weights()[:2]
    pt_prob.loc_projection.weight.data = torch.from_numpy(w.T)
    pt_prob.loc_projection.bias.data = torch.from_numpy(b)


def transfer_tcn_decoder_weights(tf_model: tf.keras.Model, pt_model: TCNDecoderPT):
    """
    Orchestrates the full TF -> PT transfer for the decoder:
      - Front MLP (3x Dense + 3x BN)
      - TCN
      - Probabilistic head projection
    """
    # 1) Front MLP
    transfer_decoder_front_mlp_weights(tf_model, pt_model)

    # 2) TCN
    tf_tcn_layer = next(l for l in tf_model.layers if isinstance(l, tcn_pkg.TCN))
    transfer_tcn_weights(tf_tcn_layer, pt_model.tcn)

    # 3) Probabilistic head
    tf_prob_layer = next(l for l in tf_model.layers if isinstance(l, ProbabilisticDecoderTF))
    transfer_prob_decoder_weights(tf_prob_layer, pt_model.prob_decoder)

In [5]:
import unittest
import time

# Assume get_TCN_decoder is available in your environment (as per your note)


class TestTCNDecoderPT(unittest.TestCase):
    def setUp(self):
        tf.keras.backend.clear_session()

        # Shapes (use your conventions)
        self.B = 32
        self.W = 25
        self.N = 11
        self.NF = 3
        self.NNF = self.N * self.NF
        self.latent_dim = 6

        # Decoder params (keep small for speed; mirror TF defaults where relevant)
        self.conv_filters = 32
        self.kernel_size = 4
        self.conv_stacks = 1
        self.conv_dilations = (8, 4, 2, 1)
        self.padding = "causal"
        self.use_skip = True
        self.dropout = 0.0
        self.activation = "relu"

        # Random inputs
        rng = np.random.default_rng(0)
        self.g_np = rng.normal(size=(self.B, self.latent_dim)).astype(np.float32)
        self.x_np = rng.normal(size=(self.B, self.W, self.NNF)).astype(np.float32)

        # Inject some zero windows to exercise the mask path
        mask_rows = rng.integers(0, self.B, size=self.B // 8)
        mask_ts = rng.integers(0, self.W, size=self.B // 8)
        self.x_np[mask_rows, mask_ts, :] = 0.0

    def test_tcn_decoder_full_parity(self):
        # Build TF and PT models
        tf_model = get_TCN_decoder(
            input_shape=(self.W, self.NNF),
            latent_dim=self.latent_dim,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            conv_stacks=self.conv_stacks,
            conv_dilations=self.conv_dilations,
            padding=self.padding,
            use_skip_connections=self.use_skip,
            dropout_rate=self.dropout,
            activation=self.activation,
        )
        pt_model = TCNDecoderPT(
            input_shape=(self.W, self.NNF),
            latent_dim=self.latent_dim,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            conv_stacks=self.conv_stacks,
            conv_dilations=self.conv_dilations,
            padding=self.padding,
            use_skip_connections=self.use_skip,
            dropout_rate=self.dropout,
            activation=self.activation,
        )
        pt_model.eval()

        # Transfer weights
        transfer_tcn_decoder_weights(tf_model, pt_model)

        # Compare outputs
        t0 = time.time()
        y_tf = tf_model([self.g_np, self.x_np], training=False).mean().numpy() # (B, W, NNF)
        t1 = time.time()

        with torch.no_grad():
            dist_pt = pt_model(torch.from_numpy(self.g_np), torch.from_numpy(self.x_np))
            y_pt = dist_pt.mean.detach().cpu().numpy()  # (B, W, NNF)
        t2 = time.time()

        print("Decoder TF time:", t1 - t0, "PT time:", t2 - t1)
        np.testing.assert_allclose(y_tf, y_pt, rtol=1e-5, atol=2e-4)
        print("✅ TCNDecoderPT end-to-end parity PASSED")

    def test_tcn_only_parity(self):
        """
        Optional: compare TCN outputs (before probabilistic head).
        """
        # Build TF/PT decoders
        tf_model = get_TCN_decoder(
            input_shape=(self.W, self.NNF),
            latent_dim=self.latent_dim,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            conv_stacks=self.conv_stacks,
            conv_dilations=self.conv_dilations,
            padding=self.padding,
            use_skip_connections=self.use_skip,
            dropout_rate=self.dropout,
            activation=self.activation,
        )
        pt_model = TCNDecoderPT(
            input_shape=(self.W, self.NNF),
            latent_dim=self.latent_dim,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            conv_stacks=self.conv_stacks,
            conv_dilations=self.conv_dilations,
            padding=self.padding,
            use_skip_connections=self.use_skip,
            dropout_rate=self.dropout,
            activation=self.activation,
        )
        pt_model.eval()

        # Transfer weights (front MLP + TCN only; prob head not needed for this test)
        transfer_decoder_front_mlp_weights(tf_model, pt_model)
        tf_tcn_layer = next(l for l in tf_model.layers if isinstance(l, tcn_pkg.TCN))
        transfer_tcn_weights(tf_tcn_layer, pt_model.tcn)

        # Build TF submodel up to TCN output
        g_in, x_in = tf_model.inputs
        tf_tcn_sub = tf.keras.Model([g_in, x_in], tf_tcn_layer.output)  # (B, W, conv_filters)

        # Compute TF TCN output
        y_tf_tcn = tf_tcn_sub([self.g_np, self.x_np], training=False).numpy()

        # Compute PT TCN output
        with torch.no_grad():
            g_pt = torch.from_numpy(self.g_np)
            # Front MLP
            z = pt_model.bn0(pt_model.fc0(g_pt))
            z = pt_model.bn1(pt_model.act1(pt_model.fc1(z)))
            z = pt_model.bn2(pt_model.act2(pt_model.fc2(z)))
            # Repeat and TCN
            z_rep = z.unsqueeze(1).repeat(1, self.W, 1)
            y_pt_tcn = pt_model.tcn(z_rep).detach().cpu().numpy()

        np.testing.assert_allclose(y_tf_tcn, y_pt_tcn, rtol=1e-5, atol=2e-4)
        print("✅ TCN-only (pre-prob head) parity PASSED")


# Run tests
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.TestLoader().loadTestsFromTestCase(TestTCNDecoderPT)
runner.run(suite)

test_tcn_decoder_full_parity (__main__.TestTCNDecoderPT) ... ok
test_tcn_only_parity (__main__.TestTCNDecoderPT)
Optional: compare TCN outputs (before probabilistic head). ... 

Decoder TF time: 0.09670162200927734 PT time: 0.011658430099487305
✅ TCNDecoderPT end-to-end parity PASSED


ok

----------------------------------------------------------------------
Ran 2 tests in 0.956s

OK


✅ TCN-only (pre-prob head) parity PASSED


<unittest.runner.TextTestResult run=2 errors=0 failures=0>

# Transformer encoder

In [1]:
# 1) IMPORTS

from typing import Iterable, Tuple, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepof
from spektral.layers import CensNetConv



# TensorFlow + keras for the reference model and weights
import tensorflow as tf
from tensorflow.keras.layers import Input, TimeDistributed, Bidirectional, GRU, LayerNormalization, Masking
from tensorflow.keras.initializers import he_uniform
from tensorflow.keras.models import Model

# Your TF entry-point we will test against (provided by you)
# from your_module import get_transformer_encoder  # assumed available in the notebook

# We will scan for these types inside the TF Transformer layer
from tensorflow.keras.layers import (
    TimeDistributed,
    Conv1D,
    Dense,
    LayerNormalization,
    MultiHeadAttention,
    BatchNormalization,
)

# Your PT graph block (same API as used for TCN encoder)
from deepof.clustering.censNetConv_pt import CensNetConvPT

In [2]:


def get_transformer_encoder(
    input_shape: tuple,
    edge_feature_shape: tuple,
    adjacency_matrix: np.ndarray,
    latent_dim: int,
    use_gnn: bool = True,
    num_layers: int = 4,
    num_heads: int = 64,
    dff: int = 128,
    dropout_rate: float = 0.1,
    interaction_regularization: float = 0.0,
):
    """Build a Transformer encoder.

    Based on https://www.tensorflow.org/text/tutorials/transformer.
    Adapted according to https://academic.oup.com/gigascience/article/8/11/giz134/5626377?login=true
    and https://arxiv.org/abs/1711.03905.

    Args:
        input_shape (tuple): shape of the input data
        edge_feature_shape (tuple): shape of the adjacency matrix to use in the graph attention layers. Should be time x edges x features.
        adjacency_matrix (np.ndarray): adjacency matrix for the mice connectivity graph. Shape should be nodes x nodes.
        latent_dim (int): dimensionality of the latent space
        use_gnn (bool): If True, the encoder uses a graph representation of the input, with coordinates and speeds as node attributes, and distances as edge attributes. If False, a regular 3D tensor is used as input.
        num_layers (int): number of transformer layers to include
        num_heads (int): number of heads of the multi-head-attention layers used on the transformer encoder
        dff (int): dimensionality of the token embeddings
        dropout_rate (float): dropout rate
        interaction_regularization (float): regularization parameter for the interaction features

    """
    # Define feature and adjacency inputs
    x = Input(shape=input_shape)
    a = Input(shape=edge_feature_shape)

    if use_gnn:
        x_reshaped = tf.transpose(
            tf.reshape(
                tf.transpose(x),
                [
                    -1,
                    adjacency_matrix.shape[-1],
                    x.shape[1],
                    input_shape[-1] // adjacency_matrix.shape[-1],
                ][::-1],
            )
        )
        a_reshaped = tf.transpose(
            tf.reshape(
                tf.transpose(a),
                [
                    -1,
                    edge_feature_shape[-1],
                    a.shape[1],
                    1,
                ][::-1],
            )
        )

    else:
        x_reshaped = tf.expand_dims(x, axis=1)

    transformer_embedding = TimeDistributed(
        deepof.clustering.model_utils_new.TransformerEncoder(
            num_layers=num_layers,
            seq_dim=input_shape[-1],
            key_dim=input_shape[-1],
            num_heads=num_heads,
            dff=dff,
            maximum_position_encoding=input_shape[0],
            rate=dropout_rate,
        )
    )(x_reshaped, training=False)
    transformer_embedding = tf.reshape(
        transformer_embedding,
        [
            -1,
            (adjacency_matrix.shape[0] if x_reshaped.shape[1] != 1 else 1),
            input_shape[0] * input_shape[1],
        ],
    )

    if use_gnn:

        # Embed edge features too
        transformer_a_embedding = TimeDistributed(
            deepof.clustering.model_utils_new.TransformerEncoder(
                num_layers=num_layers,
                seq_dim=input_shape[-1],
                key_dim=input_shape[-1],
                num_heads=num_heads,
                dff=dff,
                maximum_position_encoding=input_shape[0],
                rate=dropout_rate,
            )
        )(a_reshaped, training=False)

        transformer_a_embedding = tf.reshape(
            transformer_a_embedding,
            [-1, adjacency_matrix.shape[0], input_shape[0] * input_shape[1]],
        )

        spatial_block = CensNetConv(
            node_channels=latent_dim,
            edge_channels=latent_dim,
            activation="relu",
            node_regularizer=tf.keras.regularizers.l1(interaction_regularization),
        )

        # Process adjacency matrix
        laplacian, edge_laplacian, incidence = spatial_block.preprocess(
            adjacency_matrix
        )

        # Get and concatenate node and edge embeddings
        x_nodes, x_edges = spatial_block(
            [
                transformer_embedding,
                (laplacian, edge_laplacian, incidence),
                transformer_a_embedding,
            ],
            mask=None,
        )

        x_nodes = tf.reshape(
            x_nodes,
            [-1, adjacency_matrix.shape[-1] * latent_dim],
        )

        x_edges = tf.reshape(
            x_edges,
            [-1, edge_feature_shape[-1] * latent_dim],
        )

        transformer_embedding = tf.concat([x_nodes, x_edges], axis=-1)

    else:
        transformer_embedding = tf.squeeze(transformer_embedding, axis=1)

    encoder = tf.keras.layers.Dense(2 * latent_dim, activation="relu")(
        transformer_embedding
    )
    encoder = tf.keras.layers.BatchNormalization()(encoder)
    encoder = tf.keras.layers.Dense(latent_dim, activation="relu")(encoder)
    encoder = tf.keras.layers.BatchNormalization()(encoder)
    encoder = tf.keras.layers.Dense(latent_dim)(encoder)

    return tf.keras.models.Model([x, a], encoder, name="transformer_encoder")

In [3]:
def _act(name: str) -> nn.Module:
    name = (name or "relu").lower()
    if name == "relu": return nn.ReLU()
    if name == "gelu": return nn.GELU()
    if name == "tanh": return nn.Tanh()
    if name == "leaky_relu": return nn.LeakyReLU(0.2)
    if name in {"linear", "identity", "none"}: return nn.Identity()
    raise ValueError(f"Unsupported activation: {name}")


class BatchNorm1dKerasFP32(nn.BatchNorm1d):
    """Keras-like BatchNorm with eps=1e-3 and momentum=0.01 (Keras uses 0.99)."""
    def __init__(self, num_features, eps=1e-3, momentum=0.01, affine=True, track_running_stats=True):
        super().__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = super().forward(x.float())
        return y.to(dtype=x.dtype)


def sinusoidal_positional_encoding(max_len: int, d_model: int, device=None, dtype=torch.float32) -> torch.Tensor:
    """Generate sinusoidal positional encodings."""
    pe = torch.zeros(max_len, d_model, dtype=dtype, device=device)
    position = torch.arange(0, max_len, dtype=dtype, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=dtype, device=device) * (-np.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    n_odd = pe[:, 1::2].shape[1]
    pe[:, 1::2] = torch.cos(position * div_term)[:, :n_odd]
    return pe.unsqueeze(0)  # (1, max_len, d_model)


class MultiHeadAttentionPT(nn.Module):
    """Multi-head attention layer compatible with Keras implementation."""
    def __init__(self, in_dim: int, num_heads: int, key_dim: int, dropout: float = 0.0):
        super().__init__()
        self.in_dim = int(in_dim)
        self.num_heads = int(num_heads)
        self.key_dim = int(key_dim)
        self.inner_dim = self.num_heads * self.key_dim

        self.q_proj = nn.Linear(self.in_dim, self.inner_dim, bias=True)
        self.k_proj = nn.Linear(self.in_dim, self.inner_dim, bias=True)
        self.v_proj = nn.Linear(self.in_dim, self.inner_dim, bias=True)
        self.out_proj = nn.Linear(self.inner_dim, self.in_dim, bias=True)
        self.dropout = nn.Dropout(dropout)

        for m in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor:
        B, T, _ = x.shape

        def proj(linear: nn.Linear):
            y = linear(x)
            return y.reshape(B, T, self.num_heads, self.key_dim).permute(0, 2, 1, 3).contiguous()

        q = proj(self.q_proj)
        k = proj(self.k_proj)
        v = proj(self.v_proj)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.key_dim ** 0.5)
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask.unsqueeze(1).unsqueeze(2), float("-inf"))

        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        ctx = torch.matmul(attn, v)
        ctx = ctx.permute(0, 2, 1, 3).contiguous().reshape(B, T, self.inner_dim)
        out = self.out_proj(ctx)

        if attn_mask is not None:
            out = out.masked_fill(attn_mask.unsqueeze(-1), 0.0)
        return out


class TransformerEncoderLayerPT(nn.Module):
    """Transformer encoder layer with post-normalization."""
    def __init__(self, key_dim: int, num_heads: int, dff: int, rate: float = 0.1):
        super().__init__()
        self.mha = MultiHeadAttentionPT(in_dim=key_dim, num_heads=num_heads, key_dim=key_dim, dropout=rate)
        self.dropout1 = nn.Dropout(rate)
        self.norm1 = nn.LayerNorm(key_dim, eps=1e-6)

        self.ffn1 = nn.Linear(key_dim, dff)
        self.act = nn.ReLU()
        self.ffn2 = nn.Linear(dff, key_dim)
        self.dropout2 = nn.Dropout(rate)
        self.norm2 = nn.LayerNorm(key_dim, eps=1e-6)

        for m in [self.ffn1, self.ffn2]:
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor:
        attn_out = self.mha(x, attn_mask=attn_mask)
        x = self.norm1(x + self.dropout1(attn_out))
        ff = self.ffn2(self.act(self.ffn1(x)))
        x = self.norm2(x + self.dropout2(ff))
        return x


class TransformerCorePT(nn.Module):
    """Core transformer: Conv1D embedding -> positional encoding -> transformer layers."""
    def __init__(self, in_channels: int, key_dim: int, num_layers: int, num_heads: int, dff: int, max_pos: int, rate: float = 0.1):
        super().__init__()
        self.key_dim = int(key_dim)
        self.max_pos = int(max_pos)
        self.dropout = nn.Dropout(rate)

        self.embed = nn.Conv1d(in_channels, self.key_dim, kernel_size=1, bias=True)
        nn.init.xavier_uniform_(self.embed.weight)
        nn.init.zeros_(self.embed.bias)

        self.layers = nn.ModuleList([
            TransformerEncoderLayerPT(key_dim=self.key_dim, num_heads=num_heads, dff=dff, rate=rate) 
            for _ in range(int(num_layers))
        ])

        pe = sinusoidal_positional_encoding(self.max_pos, self.key_dim)
        self.register_buffer("pos_encoding", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        
        # Compute mask for all-zero timesteps
        with torch.no_grad():
            mask = torch.all(x == 0.0, dim=-1)

        # Embedding with Conv1D
        y = self.embed(x.transpose(1, 2)).transpose(1, 2)
        y = F.relu(y)
        y = y * (self.key_dim ** 0.5)

        # Add positional encoding
        if T > self.pos_encoding.size(1):
            self.pos_encoding = sinusoidal_positional_encoding(T, self.key_dim, device=x.device).to(self.pos_encoding.dtype)
        y = y + self.pos_encoding[:, :T, :].to(y.dtype)
        y = self.dropout(y)

        # Apply transformer layers
        for layer in self.layers:
            y = layer(y, attn_mask=mask)
        return y


class TFMEncoderPT(nn.Module):
    """PyTorch implementation of TensorFlow Transformer Encoder with optional GNN."""
    def __init__(
        self,
        input_shape: Tuple[int, int, int],        # (W, N, NF)
        edge_feature_shape: Tuple[int, int, int], # (W, E, EF)
        adjacency_matrix: np.ndarray,
        latent_dim: int,
        use_gnn: bool = True,
        num_layers: int = 4,
        num_heads: int = 8,
        dff: int = 128,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        self.use_gnn = use_gnn
        self.latent_dim = int(latent_dim)
        self.W, self.N, self.NF = input_shape
        _, self.E, self.EF = edge_feature_shape
        assert adjacency_matrix.shape[0] == self.N == adjacency_matrix.shape[1], "Adjacency must be NxN"

        key_dim = self.N * self.NF

        if use_gnn:
            # Node transformer
            self.node_tf = TransformerCorePT(
                in_channels=self.NF, key_dim=key_dim,
                num_layers=num_layers, num_heads=num_heads, dff=dff, max_pos=self.W, rate=dropout_rate
            )
            # Edge transformer
            self.edge_tf = TransformerCorePT(
                in_channels=1, key_dim=key_dim,
                num_layers=num_layers, num_heads=num_heads, dff=dff, max_pos=self.W, rate=dropout_rate
            )

            # Spatial GNN
            self.spatial_gnn = CensNetConvPT(node_channels=self.latent_dim, edge_channels=self.latent_dim, activation="relu")
            lap, edge_lap, inc = self.spatial_gnn.preprocess(torch.tensor(adjacency_matrix))
            self.register_buffer("laplacian", lap.float())
            self.register_buffer("edge_laplacian", edge_lap.float())
            self.register_buffer("incidence", inc.float())

            final_in = 2 * self.N * self.latent_dim
        else:
            # Single transformer for flattened input
            self.flat_tf = TransformerCorePT(
                in_channels=self.N * self.NF, key_dim=key_dim,
                num_layers=num_layers, num_heads=num_heads, dff=dff, max_pos=self.W, rate=dropout_rate
            )
            final_in = self.W * self.N * self.NF

        # MLP head
        self.head = nn.Sequential(
            nn.Linear(final_in, 2 * self.latent_dim),
            nn.ReLU(),
            BatchNorm1dKerasFP32(2 * self.latent_dim, eps=1e-3),
            nn.Linear(2 * self.latent_dim, self.latent_dim),
            nn.ReLU(),
            BatchNorm1dKerasFP32(self.latent_dim, eps=1e-3),
            nn.Linear(self.latent_dim, self.latent_dim),
        )
        
        # Initialize head weights
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        B, W, N, NF = x.shape
        B, W, E, EF = a.shape
        assert (W, N, NF) == (self.W, self.N, self.NF)

        if self.use_gnn:
            # Process nodes using TF's transpose-reshape-transpose pattern
            x_flat = x.view(B, W, N * NF)
            x_transposed = x_flat.permute(2, 1, 0)  # (N*NF, W, B)
            x_reshaped = x_transposed.reshape(NF, W, N, B)  # (NF, W, N, B)
            x_nodes = x_reshaped.permute(3, 2, 1, 0)  # (B, N, W, NF)
            
            node_in = x_nodes.reshape(B * N, W, NF)
            node_out = self.node_tf(node_in).view(B, N, W, -1)
            nodes_flat = node_out.reshape(B, N, W * (self.N * self.NF))

            # Process edges using TF's transpose-reshape-transpose pattern
            EEF = E * EF
            a_flat = a.view(B, W, EEF)
            a_transposed = a_flat.permute(2, 1, 0)  # (EEF, W, B)
            a_reshaped = a_transposed.reshape(1, W, EEF, B)  # (1, W, EEF, B)
            a_edges = a_reshaped.permute(3, 2, 1, 0)  # (B, EEF, W, 1)
            
            edge_in = a_edges.reshape(B * EEF, W, 1)
            edge_out = self.edge_tf(edge_in).view(B, EEF, W, -1)
            edges_flat = edge_out.reshape(B, self.N, W * (self.N * self.NF))
                    
            # Apply spatial GNN
            x_nodes_g, x_edges_g = self.spatial_gnn([
                nodes_flat, (self.laplacian, self.edge_laplacian, self.incidence), edges_flat
            ])

            # Concatenate node and edge features
            enc = torch.cat([x_nodes_g, x_edges_g], dim=1).reshape(B, -1)
            
        else:
            # Non-GNN path: simple transformer on flattened input
            x_flat = x.view(B, W, N * NF)
            seq_out = self.flat_tf(x_flat)
            enc = seq_out.reshape(B, -1)

        # Apply MLP head
        out = self.head(enc.float()).to(enc.dtype)
        return out

In [4]:
# 3) WEIGHT TRANSFER UTILITIES (TF -> PT)

def _tf_conv1d_to_torch(w_keras: np.ndarray) -> torch.Tensor:
    # TF Conv1D [K, Cin, Cout] -> PT Conv1d [Cout, Cin, K]
    return torch.from_numpy(np.transpose(w_keras, (2, 1, 0)))


def _transfer_layernorm(tf_ln: LayerNormalization, pt_ln: nn.LayerNorm):
    gamma, beta = tf_ln.get_weights()
    pt_ln.weight.data = torch.from_numpy(gamma)
    pt_ln.bias.data = torch.from_numpy(beta)


def _flatten_qkv_kernel_bias(w: np.ndarray, b: np.ndarray, in_dim: int, num_heads: int, key_dim: int):
    """
    Returns w2: (in_dim, H*K), b2: (H*K,)
    Accepts Keras MHA kernels that may be 2D or 3D.
    """
    # Kernel
    if w.ndim == 2:
        # (in_dim, H*K) already
        w2 = w
    elif w.ndim == 3:
        # Common layouts:
        # 1) (in_dim, num_heads, key_dim)
        # 2) (num_heads, in_dim, key_dim)
        # 3) (in_dim, key_dim, num_heads)
        if w.shape == (in_dim, num_heads, key_dim):
            w2 = w.reshape(in_dim, num_heads * key_dim)
        elif w.shape == (num_heads, in_dim, key_dim):
            w2 = np.transpose(w, (1, 0, 2)).reshape(in_dim, num_heads * key_dim)
        elif w.shape == (in_dim, key_dim, num_heads):
            w2 = np.transpose(w, (0, 2, 1)).reshape(in_dim, num_heads * key_dim)
        else:
            raise ValueError(f"Unexpected qkv kernel shape: {w.shape}")
    else:
        raise ValueError(f"Unexpected qkv kernel rank: {w.ndim}")

    # Bias
    if b is None:
        b2 = None
    elif b.ndim == 1:
        # (H*K,)
        b2 = b
    elif b.ndim == 2:
        # Common: (num_heads, key_dim)
        if b.shape == (num_heads, key_dim):
            b2 = b.reshape(num_heads * key_dim)
        else:
            b2 = b.reshape(-1)
    else:
        b2 = b.reshape(-1)

    return w2, b2


def _flatten_out_kernel_bias(w: np.ndarray, b: np.ndarray, out_dim: int, num_heads: int, key_dim: int):
    """
    Returns w2: (H*K, out_dim), b2: (out_dim,)
    Accepts Keras MHA output kernels that may be 2D or 3D.
    """
    inner = num_heads * key_dim
    if w.ndim == 2:
        # (H*K, out_dim)
        w2 = w
    elif w.ndim == 3:
        # Common layouts:
        # 1) (num_heads, key_dim, out_dim)
        # 2) (key_dim, num_heads, out_dim)
        if w.shape == (num_heads, key_dim, out_dim):
            w2 = w.reshape(inner, out_dim)
        elif w.shape == (key_dim, num_heads, out_dim):
            w2 = np.transpose(w, (1, 0, 2)).reshape(inner, out_dim)
        else:
            raise ValueError(f"Unexpected out kernel shape: {w.shape}")
    else:
        raise ValueError(f"Unexpected out kernel rank: {w.ndim}")

    # Bias
    if b is None:
        b2 = None
    elif b.ndim == 1:
        # (out_dim,)
        b2 = b
    else:
        b2 = b.reshape(-1)

    return w2, b2


def _transfer_mha_keras_to_pt(tf_mha: MultiHeadAttention, pt_mha: MultiHeadAttentionPT):
    """
    Robustly map Keras MHA Dense kernels/biases to our PT MHA, flattening 3D tensors if needed.
    """
    def get_dense(obj, primary, fallback):
        d = getattr(obj, primary, None)
        if d is None:
            d = getattr(obj, fallback, None)
        return d

    qd = get_dense(tf_mha, "query_dense", "_query_dense")
    kd = get_dense(tf_mha, "key_dense", "_key_dense")
    vd = get_dense(tf_mha, "value_dense", "_value_dense")
    od = get_dense(tf_mha, "output_dense", "_output_dense")
    assert all([qd, kd, vd, od]), "Could not find Keras MHA Dense sublayers (query/key/value/output)."

    in_dim = pt_mha.in_dim
    H = pt_mha.num_heads
    K = pt_mha.key_dim
    inner = H * K

    # q
    Wq, bq = qd.get_weights()
    Wq2, bq2 = _flatten_qkv_kernel_bias(Wq, bq, in_dim, H, K)
    pt_mha.q_proj.weight.data = torch.from_numpy(Wq2.T)  # (inner, in_dim)
    pt_mha.q_proj.bias.data = torch.from_numpy(bq2) if bq2 is not None else torch.zeros(inner)

    # k
    Wk, bk = kd.get_weights()
    Wk2, bk2 = _flatten_qkv_kernel_bias(Wk, bk, in_dim, H, K)
    pt_mha.k_proj.weight.data = torch.from_numpy(Wk2.T)
    pt_mha.k_proj.bias.data = torch.from_numpy(bk2) if bk2 is not None else torch.zeros(inner)

    # v
    Wv, bv = vd.get_weights()
    Wv2, bv2 = _flatten_qkv_kernel_bias(Wv, bv, in_dim, H, K)
    pt_mha.v_proj.weight.data = torch.from_numpy(Wv2.T)
    pt_mha.v_proj.bias.data = torch.from_numpy(bv2) if bv2 is not None else torch.zeros(inner)

    # out
    Wo, bo = od.get_weights()
    Wo2, bo2 = _flatten_out_kernel_bias(Wo, bo, in_dim, H, K)  # (inner, in_dim_out==in_dim)
    pt_mha.out_proj.weight.data = torch.from_numpy(Wo2.T)      # (in_dim, inner)
    pt_mha.out_proj.bias.data = torch.from_numpy(bo2) if bo2 is not None else torch.zeros(in_dim)


def _collect_tf_te(tf_te_layer):
    """Collect TF sublayers from TransformerEncoder (inner of TimeDistributed)."""
    convs = [m for m in tf_te_layer.submodules if isinstance(m, Conv1D)]
    enc_layers = list(getattr(tf_te_layer, "enc_layers"))
    return convs, enc_layers


def transfer_td_transformer_weights(tf_td: TimeDistributed, pt_core: TransformerCorePT):
    assert isinstance(tf_td, TimeDistributed), "Expected a TimeDistributed layer"
    tf_te = tf_td.layer  # inner TransformerEncoder

    convs, enc_layers = _collect_tf_te(tf_te)
    assert len(convs) >= 1, "No Conv1D embedding found in TF transformer"
    k, b = convs[0].get_weights()
    pt_core.embed.weight.data = _tf_conv1d_to_torch(k)
    pt_core.embed.bias.data = torch.from_numpy(b)

    assert len(enc_layers) == len(pt_core.layers), "Transformer layer count mismatch"
    for i, (tf_el, pt_el) in enumerate(zip(enc_layers, pt_core.layers)):
        _transfer_mha_keras_to_pt(tf_el.mha, pt_el.mha)
        # FFN Dense
        d1, d2 = tf_el.ffn.layers  # Dense(dff, relu), Dense(key_dim)
        W1, B1 = d1.get_weights(); W2, B2 = d2.get_weights()
        pt_el.ffn1.weight.data = torch.from_numpy(W1.T); pt_el.ffn1.bias.data = torch.from_numpy(B1)
        pt_el.ffn2.weight.data = torch.from_numpy(W2.T); pt_el.ffn2.bias.data = torch.from_numpy(B2)
        # LayerNorms
        _transfer_layernorm(tf_el.layernorm1, pt_el.norm1)
        _transfer_layernorm(tf_el.layernorm2, pt_el.norm2)


def transfer_censnet_weights(tf_layer, pt_layer: CensNetConvPT):
    """
    Transfer CensNetConv weights from TF to PyTorch.
    The TF layer returns weights in this actual order (despite misleading names):
    [node_kernel, edge_kernel, node_weights, edge_weights, node_bias, edge_bias]
    """
    weights = tf_layer.get_weights()
    
    # Map based on actual shapes, not the misleading variable names
    pt_layer.node_kernel.data = torch.from_numpy(weights[0])   # (825, 6)
    pt_layer.edge_kernel.data = torch.from_numpy(weights[1])   # (825, 6) 
    pt_layer.node_weights.data = torch.from_numpy(weights[2])  # (825, 1)
    pt_layer.edge_weights.data = torch.from_numpy(weights[3])  # (825, 1)
    pt_layer.node_bias.data = torch.from_numpy(weights[4])     # (6,)
    pt_layer.edge_bias.data = torch.from_numpy(weights[5])     # (6,)


def transfer_head_mlp(tf_model: tf.keras.Model, pt_head: nn.Sequential):
    # Dense(2*latent)->BN->Dense(latent)->BN->Dense(latent)
    tail = [l for l in tf_model.layers if isinstance(l, (Dense, BatchNormalization))]
    d1, bn1, d2, bn2, d3 = tail[-5:]

    lin1: nn.Linear = pt_head[0]; bn1_pt: BatchNorm1dKerasFP32 = pt_head[2]
    lin2: nn.Linear = pt_head[3]; bn2_pt: BatchNorm1dKerasFP32 = pt_head[5]
    lin3: nn.Linear = pt_head[6]

    # Dense 1
    w, b = d1.get_weights(); lin1.weight.data = torch.from_numpy(w.T); lin1.bias.data = torch.from_numpy(b)
    # BN 1
    gamma, beta, mmean, mvar = bn1.get_weights()
    bn1_pt.weight.data = torch.from_numpy(gamma); bn1_pt.bias.data = torch.from_numpy(beta)
    bn1_pt.running_mean.data = torch.from_numpy(mmean); bn1_pt.running_var.data = torch.from_numpy(mvar)
    # Dense 2
    w, b = d2.get_weights(); lin2.weight.data = torch.from_numpy(w.T); lin2.bias.data = torch.from_numpy(b)
    # BN 2
    gamma, beta, mmean, mvar = bn2.get_weights()
    bn2_pt.weight.data = torch.from_numpy(gamma); bn2_pt.bias.data = torch.from_numpy(beta)
    bn2_pt.running_mean.data = torch.from_numpy(mmean); bn2_pt.running_var.data = torch.from_numpy(mvar)
    # Dense 3
    w, b = d3.get_weights(); lin3.weight.data = torch.from_numpy(w.T); lin3.bias.data = torch.from_numpy(b)


def transfer_transformer_encoder_weights(tf_model: tf.keras.Model, pt_model: TFMEncoderPT, use_gnn: bool):
    # Head first
    transfer_head_mlp(tf_model, pt_model.head)

    # Collect all TimeDistributed(TransformerEncoder) layers
    td_layers = [
        l for l in tf_model.layers
        if isinstance(l, TimeDistributed) and hasattr(l.layer, "embedding") and hasattr(l.layer, "enc_layers")
    ]

    if use_gnn:
        assert len(td_layers) >= 2, "Expected node and edge TimeDistributed(TransformerEncoder) for GNN=True"

        # Identify which TD is nodes vs edges by inspecting the Conv1D embedding in_channels
        def td_in_channels(td):
            tf_te = td.layer
            convs, enc_layers = _collect_tf_te(tf_te)  # reuse your helper
            assert len(convs) >= 1, "No Conv1D embedding found in TF transformer"
            k, _ = convs[0].get_weights()  # k shape: (kernel_size=1, in_channels, out_channels)
            return int(k.shape[1])

        # Keras TD order can vary; determine by in_channels
        td_info = [(td, td_in_channels(td)) for td in td_layers]
        # Nodes TD: in_channels == NF; Edges TD: in_channels == 1 (since TF reshapes edges to last dim 1)
        nodes_td = next(td for td, in_ch in td_info if in_ch == pt_model.NF)
        edges_td = next(td for td, in_ch in td_info if in_ch == 1)

        # Transfer weights into the correct PT cores
        transfer_td_transformer_weights(nodes_td, pt_model.node_tf)
        transfer_td_transformer_weights(edges_td, pt_model.edge_tf)

        # Ensure CensNetConvPT is built before weight copy (warm-up if needed)
        needs_build = any(
            getattr(pt_model.spatial_gnn, name, None) is None
            for name in ("node_kernel", "edge_kernel", "node_weights", "edge_weights", "node_bias", "edge_bias")
        )
        if needs_build:
            with torch.no_grad():
                B = 2
                W, N, NF = pt_model.W, pt_model.N, pt_model.NF
                E, EF = pt_model.E, pt_model.EF
                x_dummy = torch.zeros(B, W, N, NF)
                a_dummy = torch.zeros(B, W, E, EF)
                _ = pt_model(x_dummy, a_dummy)

        # Copy CensNetConv weights
        from deepof.clustering.model_utils_new import CensNetConv as CensNetConvTF
        gnn_layer = next(l for l in tf_model.layers if isinstance(l, CensNetConvTF))
        transfer_censnet_weights(gnn_layer, pt_model.spatial_gnn)

    else:
        assert len(td_layers) >= 1, "Expected one TimeDistributed(TransformerEncoder) for non-GNN"
        transfer_td_transformer_weights(td_layers[0], pt_model.flat_tf)

In [5]:
import numpy as np
import torch
import tensorflow as tf
from tensorflow.keras.layers import Dense

try:
    from scipy.optimize import linear_sum_assignment
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False


def _extract_tf_prehead(tf_model, x_np, a_np):
    """
    Get the exact TF tensor fed into the first Dense of the head (d1.input).
    x_np: (B, W, N*NF), a_np: (B, W, E*EF)
    """
    denses = [l for l in tf_model.layers if isinstance(l, Dense)]
    d1 = None
    for l in denses:
        if getattr(l, "name", "") == "dense_4":  # based on your model print
            d1 = l
            break
    if d1 is None:
        assert len(denses) >= 3, "Could not find three Dense layers for the head"
        d1 = denses[-3]
    sub = tf.keras.Model(tf_model.inputs, d1.input)
    return sub([x_np, a_np], training=False).numpy()  # (B, D)


def _capture_pt_prehead(pt_model, x_np, a_np):
    """
    Run PT forward once and capture the tensor fed into self.head via a pre-hook.
    Accepts either split dims (B,W,N,NF)/(B,W,E,EF) or flattened (B,W,N*NF)/(B,W,E*EF).
    Returns: (B, D) numpy array
    """
    # reshape flattened -> split if needed
    if x_np.ndim == 3:
        B, W, D = x_np.shape
        N, NF = pt_model.N, pt_model.NF
        assert D == N * NF, f"x_np last dim {D} != N*NF {N*NF}"
        x_np = x_np.reshape(B, W, N, NF)
    if a_np.ndim == 3:
        B, W, D2 = a_np.shape
        E, EF = pt_model.E, pt_model.EF
        assert D2 == E * EF, f"a_np last dim {D2} != E*EF {E*EF}"
        a_np = a_np.reshape(B, W, E, EF)

    captured = {}
    def hook(_mod, inputs):
        captured["enc"] = inputs[0].detach().cpu().numpy()

    h = pt_model.head.register_forward_pre_hook(hook)
    with torch.no_grad():
        _ = pt_model(torch.from_numpy(x_np), torch.from_numpy(a_np))
    h.remove()
    return captured["enc"]  # (B, D)


def _standardize_columns(X):
    """Zero-mean, unit-variance per column."""
    mu = X.mean(axis=0, keepdims=True)
    sd = X.std(axis=0, keepdims=True) + 1e-8
    return (X - mu) / sd


def _learn_perm_from_batch(tf_model, pt_model, x_tf_flat, a_tf_flat, use_cosine=True):
    """
    Find perm such that PT_prehead[:, perm] ~= TF_prehead on one batch.
    x_tf_flat: (B, W, N*NF), a_tf_flat: (B, W, E*EF)
    Returns: torch.LongTensor of shape (D,)
    """
    tf_pre = _extract_tf_prehead(tf_model, x_tf_flat, a_tf_flat)  # (B, D)
    pt_pre = _capture_pt_prehead(pt_model, x_tf_flat, a_tf_flat)  # (B, D)
    B, D = tf_pre.shape

    tf_std = _standardize_columns(tf_pre)
    pt_std = _standardize_columns(pt_pre)

    if use_cosine:
        tf_norm = tf_std / (np.linalg.norm(tf_std, axis=0, keepdims=True) + 1e-8)
        pt_norm = pt_std / (np.linalg.norm(pt_std, axis=0, keepdims=True) + 1e-8)
        sim = tf_norm.T @ pt_norm  # (D, D)
        cost = 1.0 - sim
    else:
        diff = np.abs(tf_pre[:, :, None] - pt_pre[:, None, :])  # (B, D, D)
        cost = diff.mean(axis=0)                                # (D, D)

    if _HAS_SCIPY:
        row_ind, col_ind = linear_sum_assignment(cost)
        perm = np.full(D, -1, dtype=np.int64)
        perm[row_ind] = col_ind
    else:
        # Greedy fallback
        perm = np.full(D, -1, dtype=np.int64)
        used_pt = np.zeros(D, dtype=bool)
        pairs = [(cost[i, j], i, j) for i in range(D) for j in range(D)]
        pairs.sort(key=lambda t: t[0])
        assigned = 0
        for _, i, j in pairs:
            if perm[i] == -1 and not used_pt[j]:
                perm[i] = j
                used_pt[j] = True
                assigned += 1
                if assigned == D:
                    break
        assert (perm >= 0).all(), "Failed to assign a unique PT column to each TF column"

    # Quality check
    pt_matched = pt_pre[:, perm]
    mean_abs = np.mean(np.abs(tf_pre - pt_matched))
    max_abs = np.max(np.abs(tf_pre - pt_matched))
    print(f"Permutation quality -> mean abs diff: {mean_abs:.6g}, max abs diff: {max_abs:.6g}")

    return torch.from_numpy(perm)

def apply_perm_into_head_first_linear(pt_model, perm_idx: torch.LongTensor):
    """
    Compose the learned input permutation into the first Linear of the head:
      new_weight = old_weight[:, perm]
    Bias unchanged. Removes the need to permute activations at runtime.
    """
    lin1 = pt_model.head[0]
    with torch.no_grad():
        lin1.weight.copy_(lin1.weight[:, perm_idx])
    # If you had set prehead_perm for runtime, clear it:
    if hasattr(pt_model, "prehead_perm"):
        pt_model.prehead_perm = None

In [6]:
import unittest, time
import tensorflow as tf
import torch
import numpy as np
from tensorflow.keras.layers import TimeDistributed, Conv1D, Dense, BatchNormalization
from spektral.layers import CensNetConv as CensNetConvTF



class TestTransformerEncoderPT(unittest.TestCase):
    def setUp(self):
        tf.keras.backend.clear_session()
        # Your TCN-style init
        self.batch_size = 128
        self.W = 25
        self.N = 11
        self.NF = 3
        self.E = 11
        self.EF = 1
        self.latent_dim = 6

        # Adjacency
        m = np.zeros((self.N, self.N), dtype=np.float32)
        ui = np.triu_indices(self.N)
        num_possible = len(ui[0])
        c = np.random.choice(num_possible, min(self.E, num_possible), replace=False)
        m[ui[0][c], ui[1][c]] = 1
        self.adj_matrix = (m + m.T).astype(np.float32)

        # Data
        rng = np.random.default_rng(0)
        self.x_np = rng.normal(size=(self.batch_size, self.W, self.N, self.NF)).astype(np.float32)
        self.a_np = rng.normal(size=(self.batch_size, self.W, self.E, self.EF)).astype(np.float32)

        # Transformer params (keep small to avoid OOM)
        self.num_layers = 1
        self.num_heads = 4
        self.dff = 64
        self.dropout = 0.0

    def test_non_gnn_parity(self):
        B, W, N, NF, E, EF = self.batch_size, self.W, self.N, self.NF, self.E, self.EF

        # TF model: IMPORTANT — 2D shapes (W, N*NF) and (W, E*EF)
        tf_model = get_transformer_encoder(
            input_shape=(W, N * NF),
            edge_feature_shape=(W, E * EF),
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            use_gnn=False,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            dff=self.dff,
            dropout_rate=self.dropout,
        )

        # PT model: split dims
        pt_model = TFMEncoderPT(
            input_shape=(W, N, NF),
            edge_feature_shape=(W, E, EF),
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            use_gnn=False,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            dff=self.dff,
            dropout_rate=self.dropout,
        )
        pt_model.eval()

        # Transfer weights
        transfer_transformer_encoder_weights(tf_model, pt_model, use_gnn=False)

        # Inputs to TF are flattened 2D per time step
        x_tf = self.x_np.reshape(B, W, N * NF)
        a_tf = self.a_np.reshape(B, W, E * EF)

        y_tf = tf_model([x_tf, a_tf], training=False).numpy()
        with torch.no_grad():
            y_pt = pt_model(torch.from_numpy(self.x_np), torch.from_numpy(self.a_np)).cpu().numpy()

        np.testing.assert_allclose(y_tf, y_pt, rtol=1e-5, atol=2e-4)
        print("✅ TransformerEncoderPT non-GNN parity PASSED")


    def test_gnn_parity(self):
        B, W, N, NF, E, EF = self.batch_size, self.W, self.N, self.NF, self.E, self.EF

        # TF model: IMPORTANT — 2D shapes (W, N*NF) and (W, E*EF)
        tf.keras.backend.clear_session()
        tf_model = get_transformer_encoder(
            input_shape=(self.W, self.N * self.NF),
            edge_feature_shape=(self.W, self.E * self.EF),
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            use_gnn=True,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            dff=self.dff,
            dropout_rate=self.dropout,
        )
        pt_model = TFMEncoderPT(
            input_shape=(self.W, self.N, self.NF),
            edge_feature_shape=(self.W, self.E, self.EF),
            adjacency_matrix=self.adj_matrix,
            latent_dim=self.latent_dim,
            use_gnn=True,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            dff=self.dff,
            dropout_rate=self.dropout,
        )
        pt_model.eval()

        # Warm up PT model to build parameters
        with torch.no_grad():
            _ = pt_model(torch.from_numpy(self.x_np[:2]), torch.from_numpy(self.a_np[:2]))
        transfer_transformer_encoder_weights(tf_model, pt_model, use_gnn=True)

        # Inputs to TF are flattened 2D per time step
        x_tf = self.x_np.reshape(B, W, N * NF)
        a_tf = self.a_np.reshape(B, W, E * EF)

        y_tf = tf_model([x_tf, a_tf], training=False).numpy()
        with torch.no_grad():
            y_pt = pt_model(torch.from_numpy(self.x_np), torch.from_numpy(self.a_np)).cpu().numpy()

        np.testing.assert_allclose(y_tf, y_pt, rtol=1e-5, atol=2e-4)
        print("✅ TransformerEncoderPT GNN parity PASSED")


# Run tests
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.TestLoader().loadTestsFromTestCase(TestTransformerEncoderPT)
runner.run(suite)

test_gnn_parity (__main__.TestTransformerEncoderPT) ... The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.
ok
test_non_gnn_parity (__main__.TestTransformerEncoderPT) ... ok

----------------------------------------------------------------------
Ran 2 tests in 1.724s

OK


✅ TransformerEncoderPT GNN parity PASSED
✅ TransformerEncoderPT non-GNN parity PASSED


<unittest.runner.TextTestResult run=2 errors=0 failures=0>