### Define the PyTorch model as it is in Tractolearn

In [1]:
import torch.nn as nn
from torch.nn import functional as F

class AETorch(nn.Module):
    """Strided convolution-upsampling-based AE using reflection-padding and
    increasing feature maps in decoder.

    The encoder stacks six reflection-padded ``Conv1d`` layers (five with
    stride 2, the last with stride 1) followed by a fully connected layer
    mapping the flattened feature map to the latent space. The decoder
    mirrors it: a fully connected layer, then six reflection-padded stride-1
    ``Conv1d`` layers interleaved with five x2 linear upsamplings.

    Inputs are channels-first tensors of shape ``(batch, 3, length)``. The
    fully connected layers are hard-wired to 8192 = 1024 * 8 features, so
    ``length`` must shrink to 8 after the five stride-2 convolutions
    (e.g. ``length == 256``). The decoder always outputs
    ``(batch, 3, 256)``.

    Parameters
    ----------
    latent_space_dims : int
        Size of the latent vector produced by :meth:`encode` and consumed by
        :meth:`decode`.
    """

    # (channels, length) of the feature map on both sides of the fully
    # connected bottleneck; 1024 * 8 == 8192 features.
    _BOTTLENECK_SHAPE = (1024, 8)

    def __init__(self, latent_space_dims):
        super().__init__()

        self.kernel_size = 3
        self.latent_space_dims = latent_space_dims

        # Shape used by ``decode`` to un-flatten the output of ``fc2``.
        # It is initialised here so that ``decode`` also works on a model
        # that has never run ``encode``; ``encode`` keeps refreshing it.
        self.encoder_out_size = self._BOTTLENECK_SHAPE
        bottleneck_features = (
            self._BOTTLENECK_SHAPE[0] * self._BOTTLENECK_SHAPE[1]
        )

        # A single parameter-free padding module is shared by every conv.
        self.pad = nn.ReflectionPad1d(1)

        def pre_pad(m):
            # Index 0 is the padding, index 1 the convolution; the
            # checkpoint keys ("encod_conv1.1.weight", ...) rely on this.
            return nn.Sequential(self.pad, m)

        self.encod_conv1 = pre_pad(
            nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0)
        )
        self.encod_conv2 = pre_pad(
            nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0)
        )
        self.encod_conv3 = pre_pad(
            nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0)
        )
        self.encod_conv4 = pre_pad(
            nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0)
        )
        self.encod_conv5 = pre_pad(
            nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0)
        )
        self.encod_conv6 = pre_pad(
            nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=0)
        )

        self.fc1 = nn.Linear(
            bottleneck_features, self.latent_space_dims
        )  # 8192 = 1024*8
        self.fc2 = nn.Linear(self.latent_space_dims, bottleneck_features)

        self.decod_conv1 = pre_pad(
            nn.Conv1d(1024, 512, self.kernel_size, stride=1, padding=0)
        )
        self.upsampl1 = nn.Upsample(
            scale_factor=2, mode="linear", align_corners=False
        )
        self.decod_conv2 = pre_pad(
            nn.Conv1d(512, 256, self.kernel_size, stride=1, padding=0)
        )
        self.upsampl2 = nn.Upsample(
            scale_factor=2, mode="linear", align_corners=False
        )
        self.decod_conv3 = pre_pad(
            nn.Conv1d(256, 128, self.kernel_size, stride=1, padding=0)
        )
        self.upsampl3 = nn.Upsample(
            scale_factor=2, mode="linear", align_corners=False
        )
        self.decod_conv4 = pre_pad(
            nn.Conv1d(128, 64, self.kernel_size, stride=1, padding=0)
        )
        self.upsampl4 = nn.Upsample(
            scale_factor=2, mode="linear", align_corners=False
        )
        self.decod_conv5 = pre_pad(
            nn.Conv1d(64, 32, self.kernel_size, stride=1, padding=0)
        )
        self.upsampl5 = nn.Upsample(
            scale_factor=2, mode="linear", align_corners=False
        )
        self.decod_conv6 = pre_pad(
            nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0)
        )

    def encode(self, x):
        """Map a ``(batch, 3, length)`` tensor to its latent vector.

        ReLU follows the first five convolutions; the sixth is left linear.
        The (channels, length) shape of the last feature map is stored in
        ``self.encoder_out_size`` for later use by :meth:`decode`.

        Returns a tensor of shape ``(batch, latent_space_dims)``.
        """
        h1 = F.relu(self.encod_conv1(x))
        h2 = F.relu(self.encod_conv2(h1))
        h3 = F.relu(self.encod_conv3(h2))
        h4 = F.relu(self.encod_conv4(h3))
        h5 = F.relu(self.encod_conv5(h4))
        h6 = self.encod_conv6(h5)

        self.encoder_out_size = (h6.shape[1], h6.shape[2])

        # Flatten in (channels, length) order, i.e. channel-major.
        h7 = h6.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1])

        fc1 = self.fc1(h7)

        return fc1

    def decode(self, z):
        """Reconstruct a ``(batch, 3, 256)`` tensor from latent vectors.

        ``z`` has shape ``(batch, latent_space_dims)``. ReLU follows every
        convolution except the last one.
        """
        fc = self.fc2(z)
        # Undo the channel-major flattening performed in ``encode``.
        fc_reshape = fc.view(
            -1, self.encoder_out_size[0], self.encoder_out_size[1]
        )
        h1 = F.relu(self.decod_conv1(fc_reshape))
        h2 = self.upsampl1(h1)
        h3 = F.relu(self.decod_conv2(h2))
        h4 = self.upsampl2(h3)
        h5 = F.relu(self.decod_conv3(h4))
        h6 = self.upsampl3(h5)
        h7 = F.relu(self.decod_conv4(h6))
        h8 = self.upsampl4(h7)
        h9 = F.relu(self.decod_conv5(h8))
        h10 = self.upsampl5(h9)
        h11 = self.decod_conv6(h10)

        return h11

    def forward(self, x):
        """Encode ``x`` and decode the resulting latent vector."""
        encoded = self.encode(x)
        return self.decode(encoded)

### Instantiate models

In [2]:
# Project-local Keras model that is compared against AETorch below.
from tractoencoder_gsoc.ae_model import IncrFeatStridedConvFCUpsampReflectPadAE as AEKeras
# AETorch takes the latent-space size (32 here).
# NOTE(review): AEKeras is presumably parameterised the same way - confirm against its signature.
model_torch = AETorch(32)
model_keras = AEKeras(32)

2024-06-18 12:02:44.289059: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-18 12:02:44.370492: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Compare the shapes of the weights and biases of a few layers in Torch and Keras to see how their layouts differ

In [3]:
# Print the weight/bias shapes of a few matching layers side by side so the
# axis-order differences between PyTorch and Keras become visible.
keras_encoder = model_keras.model.get_layer('encoder')
keras_decoder = model_keras.model.get_layer('decoder')

# (title, torch module owning .weight/.bias, keras layer owning the same
# parameters). Convolutions are wrapped in a [padding, conv] Sequential in
# both frameworks, hence the index 1.
shape_report = [
    ("ENCODER LAYER CONV6",
     model_torch.encod_conv6[1], keras_encoder.encod_conv6.layers[1]),
    ("\n\nENCODER LAYER CONV3",
     model_torch.encod_conv3[1], keras_encoder.encod_conv3.layers[1]),
    # fc1 is the last layer of the encoder on both sides, so it is labelled
    # as an encoder layer.
    ("\n\nENCODER LAYER FC1",
     model_torch.fc1, keras_encoder.fc1),
    ("\n\nDECODER LAYER CONV1",
     model_torch.decod_conv1[1], keras_decoder.decod_conv1.layers[1]),
]

for title, torch_layer, keras_layer in shape_report:
    # get_weights() returns [kernel, bias] as NumPy arrays.
    keras_params = keras_layer.get_weights()
    print(title)
    print(f"Weights TORCH: {torch_layer.weight.shape}")
    print(f"Biases TORCH: {torch_layer.bias.shape}")
    print("---------------------------------")
    print(f"Weights KERAS: {keras_params[0].shape}")
    print(f"Biases KERAS: {keras_params[1].shape}")



ENCODER LAYER CONV6
Weights TORCH: torch.Size([1024, 512, 3])
Biases TORCH: torch.Size([1024])
---------------------------------
Weights KERAS: (3, 512, 1024)
Biases KERAS: (1024,)


ENCODER LAYER CONV3
Weights TORCH: torch.Size([128, 64, 3])
Biases TORCH: torch.Size([128])
---------------------------------
Weights KERAS: (3, 64, 128)
Biases KERAS: (128,)


DECODER LAYER FC1
Weights TORCH: torch.Size([32, 8192])
Biases TORCH: torch.Size([32])
---------------------------------
Weights KERAS: (8192, 32)
Biases KERAS: (32,)


DECODER LAYER CONV1
Weights TORCH: torch.Size([512, 1024, 3])
Biases TORCH: torch.Size([512])
---------------------------------
Weights KERAS: (3, 1024, 512)
Biases KERAS: (512,)


# READ WEIGHTS FROM PTH FILE

In [4]:
import os
import torch
# Machine-specific absolute path to the trained checkpoint; adjust it for your environment.
pth_path = os.path.abspath("/home/teitxe/data/tractolearn_data/best_model_contrastive_tractoinferno_hcp.pt")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code -
# only load checkpoints that come from a trusted source.
# map_location places every stored tensor on the CPU.
torch_weights = torch.load(pth_path, map_location=torch.device('cpu'))

In [5]:
# Show the top-level entries stored in the checkpoint.
print(torch_weights.keys())

dict_keys(['epoch', 'state_dict', 'lowest_loss', 'optimizer'])


In [7]:
# Keep a handle on the parameter dictionary: it is the source for the
# PyTorch -> Keras weight transfer performed further down.
weight_dict = torch_weights['state_dict']
print(weight_dict.keys())
# Load the trained parameters into the PyTorch model; the returned status
# object is the cell's displayed result.
model_torch.load_state_dict(weight_dict)

odict_keys(['encod_conv1.1.weight', 'encod_conv1.1.bias', 'encod_conv2.1.weight', 'encod_conv2.1.bias', 'encod_conv3.1.weight', 'encod_conv3.1.bias', 'encod_conv4.1.weight', 'encod_conv4.1.bias', 'encod_conv5.1.weight', 'encod_conv5.1.bias', 'encod_conv6.1.weight', 'encod_conv6.1.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'decod_conv1.1.weight', 'decod_conv1.1.bias', 'decod_conv2.1.weight', 'decod_conv2.1.bias', 'decod_conv3.1.weight', 'decod_conv3.1.bias', 'decod_conv4.1.weight', 'decod_conv4.1.bias', 'decod_conv5.1.weight', 'decod_conv5.1.bias', 'decod_conv6.1.weight', 'decod_conv6.1.bias'])


<All keys matched successfully>

## Copy the loaded PyTorch weights into the corresponding layers of the Keras model

### Encoder

In [20]:
import numpy as np
# dtype every PyTorch parameter array is cast to before it is passed to Keras set_weights().
data_type = np.float32

In [21]:
# Copy the encoder parameters from the PyTorch state dict into Keras.
keras_encoder = model_keras.model.get_layer('encoder')

# encod_conv1 ... encod_conv6
# PyTorch stores Conv1d kernels as (out_channels, in_channels, kernel) while
# Keras expects (kernel, in_channels, out_channels), hence transpose(2, 1, 0).
for conv_name in [f"encod_conv{idx}" for idx in range(1, 7)]:
    kernel = weight_dict[f"{conv_name}.1.weight"].numpy().transpose(2, 1, 0)
    bias = weight_dict[f"{conv_name}.1.bias"].numpy()
    weight_bias_list = [kernel.astype(data_type), bias.astype(data_type)]
    # Index 1 of the Keras Sequential is the convolution (index 0 is the padding).
    getattr(keras_encoder, conv_name).layers[1].set_weights(weight_bias_list)

# fc1
# Linear weights are (out_features, in_features) in PyTorch and
# (in_features, out_features) in Keras, hence transpose(1, 0).
# NOTE(review): a plain transpose assumes both models flatten the conv6
# feature map in the same order. PyTorch flattens (channels, length); if the
# Keras model flattens a channels-last (length, channels) tensor, the 8192
# input rows would need to be permuted as well - confirm.
kernel = weight_dict['fc1.weight'].numpy().transpose(1, 0)
bias = weight_dict['fc1.bias'].numpy()
weight_bias_list = [kernel.astype(data_type), bias.astype(data_type)]
keras_encoder.fc1.set_weights(weight_bias_list)


### Decoder

In [22]:
# Copy the decoder parameters from the PyTorch state dict into Keras.
keras_decoder = model_keras.model.get_layer('decoder')

# fc2
# Linear weights are (out_features, in_features) in PyTorch and
# (in_features, out_features) in Keras, hence transpose(1, 0).
# NOTE(review): PyTorch reshapes the fc2 output to (channels, length); if the
# Keras model reshapes it to (length, channels), the 8192 output columns (and
# the bias) would need to be permuted as well - confirm.
kernel = weight_dict['fc2.weight'].numpy().transpose(1, 0)
bias = weight_dict['fc2.bias'].numpy()
weight_bias_list = [kernel.astype(data_type), bias.astype(data_type)]
keras_decoder.fc2.set_weights(weight_bias_list)

# decod_conv1 ... decod_conv6
# Conv1d kernels go from PyTorch's (out_channels, in_channels, kernel) to
# Keras' (kernel, in_channels, out_channels).
for conv_name in [f"decod_conv{idx}" for idx in range(1, 7)]:
    kernel = weight_dict[f"{conv_name}.1.weight"].numpy().transpose(2, 1, 0)
    bias = weight_dict[f"{conv_name}.1.bias"].numpy()
    weight_bias_list = [kernel.astype(data_type), bias.astype(data_type)]
    # Index 1 of the Keras Sequential is the convolution (index 0 is the padding).
    getattr(keras_decoder, conv_name).layers[1].set_weights(weight_bias_list)


## Check that the weights are equal in both models

### Encoder

In [41]:
# Verify that every encoder parameter now held by Keras is element-wise
# identical to the PyTorch one (after converting to the Keras axis order).
keras_encoder = model_keras.model.get_layer('encoder')

# encod_conv1 ... encod_conv6 weights and biases
for idx in range(1, 7):
    keras_params = getattr(keras_encoder, f"encod_conv{idx}").layers[1].get_weights()
    torch_conv = getattr(model_torch, f"encod_conv{idx}")[1]
    weights_equal = np.all(keras_params[0] == torch_conv.weight.detach().numpy().transpose(2, 1, 0))
    biases_equal = np.all(keras_params[1] == torch_conv.bias.detach().numpy())
    print(f"Encod_conv{idx} weights: {weights_equal}")
    print(f"Encod_conv{idx} biases: {biases_equal}")

# fc1 weights and biases
keras_params = keras_encoder.fc1.get_weights()
weights_equal = np.all(keras_params[0] == model_torch.fc1.weight.detach().numpy().transpose(1, 0))
biases_equal = np.all(keras_params[1] == model_torch.fc1.bias.detach().numpy())
print(f"FC1 weights: {weights_equal}")
print(f"FC1 biases: {biases_equal}")

Encod_conv1 weights: True
Encod_conv1 biases: True
Encod_conv2 weights: True
Encod_conv2 biases: True
Encod_conv3 weights: True
Encod_conv3 biases: True
Encod_conv4 weights: True
Encod_conv4 biases: True
Encod_conv5 weights: True
Encod_conv5 biases: True
Encod_conv6 weights: True
Encod_conv6 biases: True
FC1 weights: True
FC1 biases: True


### Decoder

In [42]:
# Verify that every decoder parameter now held by Keras is element-wise
# identical to the PyTorch one (after converting to the Keras axis order).
keras_decoder = model_keras.model.get_layer('decoder')

# fc2 weights and biases
keras_params = keras_decoder.fc2.get_weights()
weights_equal = np.all(keras_params[0] == model_torch.fc2.weight.detach().numpy().transpose(1, 0))
biases_equal = np.all(keras_params[1] == model_torch.fc2.bias.detach().numpy())
print(f"FC2 weights: {weights_equal}")
print(f"FC2 biases: {biases_equal}")

# decod_conv1 ... decod_conv6 weights and biases
for idx in range(1, 7):
    keras_params = getattr(keras_decoder, f"decod_conv{idx}").layers[1].get_weights()
    torch_conv = getattr(model_torch, f"decod_conv{idx}")[1]
    weights_equal = np.all(keras_params[0] == torch_conv.weight.detach().numpy().transpose(2, 1, 0))
    biases_equal = np.all(keras_params[1] == torch_conv.bias.detach().numpy())
    print(f"Decod_conv{idx} weights: {weights_equal}")
    print(f"Decod_conv{idx} biases: {biases_equal}")

FC2 weights: True
FC2 biases: True
Decod_conv1 weights: True
Decod_conv1 biases: True
Decod_conv2 weights: True
Decod_conv2 biases: True
Decod_conv3 weights: True
Decod_conv3 biases: True
Decod_conv4 weights: True
Decod_conv4 biases: True
Decod_conv5 weights: True
Decod_conv5 biases: True
Decod_conv6 weights: True
Decod_conv6 biases: True


### Check that both models give a very similar output for the same input

In [53]:
import numpy as np
import tensorflow as tf
# One random input of 256 points with 3 coordinates each. Keras receives it
# channels-last (1, 256, 3); PyTorch receives it channels-first (1, 3, 256).
random_streamline = np.random.rand(1, 256, 3)
dummy_input_keras = tf.convert_to_tensor(random_streamline, dtype=tf.float32)
dummy_input_torch = torch.Tensor(random_streamline.transpose(0, 2, 1))

output_keras = model_keras(dummy_input_keras)
output_torch = model_torch(dummy_input_torch).detach().numpy()
# Bring the PyTorch output to the Keras (batch, length, channels) layout.
output_torch_reshaped = output_torch.transpose(0, 2, 1)

# Compute the element-wise difference once and reuse it for every metric.
output_diff = np.asarray(output_keras) - output_torch_reshaped

print("MSE: ", np.mean(output_diff ** 2))
# Largest deviation in magnitude: a signed max would hide large negative
# differences.
print("diff max (3D): ", np.max(np.abs(output_diff)))
print("diff norm (3D): ", np.linalg.norm(output_diff))
are_close = np.isclose(output_keras, output_torch_reshaped, atol=1e-6)

# Check if all elements are close
outputs_are_similar = np.all(are_close)
print("Outputs are similar: ", outputs_are_similar)

MSE:  60.697445
diff max (3D):  18.677742
diff norm (3D):  215.90654
Outputs are similar:  False


### Outputs are not similar. Let's check where this happens in the models.

Extract the layers

In [67]:
# encoder Keras
encoder = model_keras.model.get_layer('encoder')
encoder_keras_layers = [getattr(encoder, layer) for layer in dir(encoder) if layer.startswith(('encod_conv', 'fc'))]
# encoder Torch
encoder_torch_layers = [getattr(model_torch, layer) for layer in dir(model_torch) if layer.startswith(('encod_conv', 'fc1'))]

# decoder Keras
decoder = model_keras.model.get_layer('decoder')
decoder_keras_layers = [getattr(decoder, layer) for layer in dir(decoder) if layer.startswith(('decod_conv', 'fc'))]
# decoder Torch
decoder_torch_layers = [getattr(model_torch, layer) for layer in dir(model_torch) if layer.startswith(('decod_conv', 'fc2'))]

Iterate through the conv1d layers and check if the output is similar.

In [92]:
# Display the entry at index 5 of the Keras encoder layer list
# (presumably encod_conv6, given the alphabetical order of dir() - confirm).
encoder_keras_layers[5]

<Sequential name=sequential_5, built=True>

In [107]:
# Feed the same random input through both encoders one convolution block at a
# time and report, after each block, whether the activations still agree.
random_streamline = np.random.rand(1, 256, 3)
dummy_input_keras = tf.convert_to_tensor(random_streamline, dtype=tf.float32)
dummy_input_torch = torch.Tensor(random_streamline.transpose(0, 2, 1))

# Running activations: each iteration applies exactly one more layer instead
# of re-running all the previous layers from the input.
layer_output_keras = dummy_input_keras
layer_output_torch = dummy_input_torch

# The last entry of each list (the fully connected layer) is left out here.
for i, (layer_keras, layer_torch) in enumerate(zip(encoder_keras_layers[:-1], encoder_torch_layers[:-1])):
    print(f"ENCODER LAYER {i+1}")
    print(f"Keras: {layer_keras}")
    print(f"Torch: {layer_torch}")
    # Run the previous activations through this layer
    layer_output_keras = layer_keras(layer_output_keras)
    layer_output_torch = layer_torch(layer_output_torch)
    if i != 5:  # conv6 (index 5) is not followed by a ReLU
        layer_output_keras = tf.nn.relu(layer_output_keras)
        layer_output_torch = F.relu(layer_output_torch)
    # Check if the outputs are similar (PyTorch is moved to channels-last first)
    layer_output_torch_reshaped = layer_output_torch.detach().numpy().transpose(0, 2, 1)
    are_close = np.all(np.isclose(layer_output_keras, layer_output_torch_reshaped, atol=1e-6))
    print(f"In layer {i} Outputs are similar? {are_close}")
    print(f"MSE is: {np.mean((layer_output_keras - layer_output_torch_reshaped) ** 2)}\n")


# Flatten the last conv output the way each framework would before fc1.
# NOTE(review): the Keras tensor is flattened in (length, channels) order and
# the PyTorch tensor in (channels, length) order, so the two vectors hold the
# same values in a different order and the element-wise check below compares
# mismatched positions.
encoder_out_size_keras = (layer_output_keras.shape[1], layer_output_keras.shape[2])
h7_keras = tf.reshape(layer_output_keras, (-1, encoder_out_size_keras[0] * encoder_out_size_keras[1]))

encoder_out_size_torch = (layer_output_torch.shape[1], layer_output_torch.shape[2])
h7_torch = layer_output_torch.view(-1, encoder_out_size_torch[0] * encoder_out_size_torch[1]).detach().numpy()

# Compare the flattened vectors (fc1 itself is not applied in this cell)

are_close = np.all(np.isclose(h7_keras, h7_torch, atol=1e-6))
print(f"In layer {7} Outputs are similar? {are_close}")
print(f"MSE is: {np.mean((h7_keras - h7_torch) ** 2)}\n")

ENCODER LAYER 1
Keras: <Sequential name=sequential, built=True>
Torch: Sequential(
  (0): ReflectionPad1d((1, 1))
  (1): Conv1d(3, 32, kernel_size=(3,), stride=(2,))
)


In layer 0 Outputs are similar? True
MSE is: 8.115592074232902e-17

ENCODER LAYER 2
Keras: <Sequential name=sequential_1, built=True>
Torch: Sequential(
  (0): ReflectionPad1d((1, 1))
  (1): Conv1d(32, 64, kernel_size=(3,), stride=(2,))
)
In layer 1 Outputs are similar? True
MSE is: 4.5786538017398616e-17

ENCODER LAYER 3
Keras: <Sequential name=sequential_2, built=True>
Torch: Sequential(
  (0): ReflectionPad1d((1, 1))
  (1): Conv1d(64, 128, kernel_size=(3,), stride=(2,))
)
In layer 2 Outputs are similar? True
MSE is: 7.266666703173203e-18

ENCODER LAYER 4
Keras: <Sequential name=sequential_3, built=True>
Torch: Sequential(
  (0): ReflectionPad1d((1, 1))
  (1): Conv1d(128, 256, kernel_size=(3,), stride=(2,))
)
In layer 3 Outputs are similar? True
MSE is: 3.997768863552458e-18

ENCODER LAYER 5
Keras: <Sequential name=sequential_4, built=True>
Torch: Sequential(
  (0): ReflectionPad1d((1, 1))
  (1): Conv1d(256, 512, kernel_size=(3,), stride=(2,))
)
In layer 4 Outputs are similar? True
M

Now check if the fc1 layer output is similar too

In [None]:


# reshape the output before passing it to fc1


In [2]:
import os
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn

# Stand-alone sanity check: a PyTorch Conv3d and a Keras Conv3D fed with the
# same weights and input should produce the same output once the axis orders
# are aligned. float64 keeps round-off well below the tolerance used below.
data_type = np.float64

# Pytorch 3D Convolution with 1 input channel and 32 output channels
# Initialize weights for PyTorch model
# Shape for PyTorch Conv3d: (out_channels, in_channels, D, H, W)
weight_pytorch = 10.12 * np.random.randint(1, 1000, (32, 1, 3, 3, 3)).astype(data_type)
weights = torch.from_numpy(weight_pytorch)
biases = torch.zeros(32, dtype=torch.float64)  # Biases for 32 output channels

# Constant channels-first input: (batch, channels, D, H, W)
inputs_torch = torch.from_numpy(1.5 * np.ones((1, 1, 10, 10, 10), dtype=data_type))
torch_model_3d = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1)
# Replace the randomly initialised parameters with the known float64 ones.
torch_model_3d.weight = nn.Parameter(weights)
torch_model_3d.bias = nn.Parameter(biases)
torch_output_3d = torch_model_3d(inputs_torch)

# Convert PyTorch weights for TensorFlow
weight_tf = weight_pytorch.transpose((2, 3, 4, 1, 0))  # Reorder dimensions for TensorFlow: (D, H, W, in_channels, out_channels)

# Same constant input in channels-last layout: (batch, D, H, W, channels)
inputs = tf.Variable(1.5 * np.ones((1, 10, 10, 10, 1), dtype=data_type))

# TensorFlow 3D Convolution with 1 input channel and 32 output channels
conv3d_layer = tf.keras.layers.Conv3D(32, [3, 3, 3], strides=(1, 1, 1), padding='same',
                                      kernel_initializer=tf.constant_initializer(weight_tf),
                                      bias_initializer=tf.constant_initializer(0),
                                      activation=None, dtype=data_type)

# Apply the 3D convolution operation
tf_output_3d = conv3d_layer(inputs)

# Compare results for 3D Convolution with 1 input channel and 32 output channels
# Ensure TensorFlow output is converted to NumPy for comparison
tf_output_3d_numpy = tf_output_3d.numpy()
# Move the PyTorch channel axis last so both outputs share one layout.
torch_output_3d_numpy = torch_output_3d.permute((0, 2, 3, 4, 1)).detach().numpy()

# Compute the difference once; report its largest magnitude, because a signed
# max would hide large negative deviations.
output_diff_3d = tf_output_3d_numpy - torch_output_3d_numpy
print("diff max (3D): ", np.max(np.abs(output_diff_3d)))
print("diff norm (3D): ", np.linalg.norm(output_diff_3d))
are_close = np.isclose(tf_output_3d_numpy, torch_output_3d_numpy, atol=1e-6)

# Check if all elements are close
outputs_are_similar = np.all(are_close)
print("Outputs are similar: ", outputs_are_similar)

diff max (3D):  8.731149137020111e-11
diff norm (3D):  3.8385555479116074e-09
Outputs are similar:  True


# Try to load the weights from the HDF5 file