### Original model (real-valued FaceNet on a ResNet-50 backbone)

In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet50

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

model_urls = {
    # Published FaceNet checkpoints (GitHub release assets), keyed by the
    # name that ``load_state`` looks up.
    'acc_920': 'https://github.com/khrlimam/facenet/releases/download/acc-0.920/model920-6be7e3e9.pth',
    'acc_921': 'https://github.com/khrlimam/facenet/releases/download/acc-0.92135/model921-af60fb4f.pth',
}



In [2]:
def load_state(arch, progress=True):
    """Fetch the checkpoint registered under ``arch`` in ``model_urls``.

    The file is downloaded via ``load_state_dict_from_url`` (which caches it
    in the torch hub directory) and the loaded object is returned unchanged;
    callers read its ``'state_dict'`` entry.

    Args:
        arch: key into ``model_urls``.
        progress: whether to show a download progress bar.

    Raises:
        KeyError: if ``arch`` is not a key of ``model_urls``. Without this
            check an unknown name turned into ``url=None`` and failed with an
            unrelated error from inside the download helper.
    """
    if arch not in model_urls:
        raise KeyError(
            f"Unknown checkpoint {arch!r}; expected one of {sorted(model_urls)}"
        )
    return load_state_dict_from_url(model_urls[arch], progress=progress)


def model_920(pretrained=True, progress=True):
    """Build a ``FaceNetModel``; when ``pretrained`` (the default) also load
    the ``acc_920`` checkpoint's ``'state_dict'`` into it."""
    net = FaceNetModel()
    if not pretrained:
        return net
    checkpoint = load_state('acc_920', progress)
    net.load_state_dict(checkpoint['state_dict'])
    return net


def model_921(pretrained=True, progress=True):
    """Build a ``FaceNetModel``; when ``pretrained`` (the default) also load
    the ``acc_921`` checkpoint's ``'state_dict'`` into it."""
    net = FaceNetModel()
    if not pretrained:
        return net
    checkpoint = load_state('acc_921', progress)
    net.load_state_dict(checkpoint['state_dict'])
    return net


class Flatten(nn.Module):
    """Flatten every dimension after the leading (batch) one.

    ``(N, d1, d2, ...) -> (N, d1 * d2 * ...)``. ``torch.flatten`` is used
    instead of ``x.view(x.size(0), -1)`` because the ``view`` form raises on
    an empty batch (N == 0, where ``-1`` is ambiguous) and on non-contiguous
    inputs; ``torch.flatten`` still returns a view whenever one is possible.
    """

    def forward(self, x):
        if x.dim() == 1:
            # Keep the historical result for 1-D input: (N,) -> (N, 1).
            return x.reshape(x.shape[0], 1)
        return torch.flatten(x, start_dim=1)


In [3]:

class FaceNetModel(nn.Module):
    """FaceNet-style face-embedding network on top of a torchvision ResNet-50.

    The ResNet-50 trunk (``conv1`` ... ``layer4``) is reused unchanged.
    ``forward`` bypasses torchvision's average-pool head: it flattens the whole
    ``layer4`` feature map and projects it to ``embedding_size`` dimensions
    (head design following https://arxiv.org/abs/1703.07737).

    Args:
        pretrained: forwarded to ``resnet50`` to optionally start the trunk
            from ImageNet weights.
        embedding_size: length of the embedding returned by ``forward``.
            Defaults to 128; keep the default when loading the published
            acc_920 / acc_921 checkpoints.
        num_classes: number of logits returned by ``forward_classifier``.
            Defaults to 500; keep the default for the published checkpoints.
    """

    def __init__(self, pretrained=False, embedding_size=128, num_classes=500):
        super(FaceNetModel, self).__init__()

        self.model = resnet50(pretrained)
        # Convolutional trunk. These are the *same* module objects held by
        # ``self.model`` (not copies), so parameters are shared; state_dict()
        # lists them under both the ``model.`` and ``cnn.`` prefixes.
        self.cnn = nn.Sequential(
            self.model.conv1,
            self.model.bn1,
            self.model.relu,
            self.model.maxpool,
            self.model.layer1,
            self.model.layer2,
            self.model.layer3,
            self.model.layer4)

        # Replace the fc head, based on https://arxiv.org/abs/1703.07737.
        # 100352 = 2048 * 7 * 7, the flattened layer4 map for 224x224 inputs;
        # other input resolutions change that size and need a different
        # in_features here.
        self.model.fc = nn.Sequential(
            Flatten(),
            nn.Linear(100352, embedding_size))

        self.model.classifier = nn.Linear(embedding_size, num_classes)

    def l2_norm(self, input):
        """Scale ``input`` to unit L2 norm along dim 1 (per sample for 2-D input).

        1e-10 is added to the squared norm so all-zero rows do not divide by 0.
        """
        norm = torch.sqrt(torch.sum(torch.pow(input, 2), 1, keepdim=True) + 1e-10)
        return input / norm

    @staticmethod
    def _set_requires_grad(params, flag):
        """Set ``requires_grad = flag`` on every parameter in ``params``."""
        for param in params:
            param.requires_grad = flag

    def freeze_all(self):
        """Disable gradients for every parameter of ``self.model``."""
        self._set_requires_grad(self.model.parameters(), False)

    def unfreeze_all(self):
        """Enable gradients for every parameter of ``self.model``."""
        self._set_requires_grad(self.model.parameters(), True)

    def freeze_fc(self):
        """Disable gradients for the embedding head ``self.model.fc``."""
        self._set_requires_grad(self.model.fc.parameters(), False)

    def unfreeze_fc(self):
        """Enable gradients for the embedding head ``self.model.fc``."""
        self._set_requires_grad(self.model.fc.parameters(), True)

    def freeze_only(self, freeze):
        """Freeze children of ``self.model`` whose name is in ``freeze``;
        unfreeze all other children."""
        for name, child in self.model.named_children():
            self._set_requires_grad(child.parameters(), name not in freeze)

    def unfreeze_only(self, unfreeze):
        """Unfreeze children of ``self.model`` whose name is in ``unfreeze``;
        freeze all other children."""
        for name, child in self.model.named_children():
            self._set_requires_grad(child.parameters(), name in unfreeze)

    # returns face embedding(embedding_size)
    def forward(self, x):
        """Return L2-normalised embeddings scaled by alpha = 10."""
        x = self.cnn(x)
        x = self.model.fc(x)

        features = self.l2_norm(x)
        # Multiply by alpha = 10 as suggested in https://arxiv.org/pdf/1703.09507.pdf
        alpha = 10
        features = features * alpha
        return features

    def forward_classifier(self, x):
        """Return identity logits computed from the scaled embedding."""
        features = self.forward(x)
        res = self.model.classifier(features)
        return res

In [4]:
model = FaceNetModel(pretrained=False)

# Print number of parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Optional: Detailed layer-wise parameter count
print("\nLayer-wise parameters:")
for name, param in model.named_parameters():
    print(f"{name:50} | Shape: {str(param.shape):20} | Params: {param.numel():,}")

# Dummy input and forward pass. This is only a shape check, so run it under
# no_grad to avoid building an autograd graph over every ResNet activation.
# NOTE: the model is in train mode, so BatchNorm running stats still update.
dummy_input = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    output = model(dummy_input)
assert output.shape == (4, 128), output.shape
print("\nOutput shape:", output.shape)



Total parameters: 36,417,716
Trainable parameters: 36,417,716

Layer-wise parameters:
model.conv1.weight                                 | Shape: torch.Size([64, 3, 7, 7]) | Params: 9,408
model.bn1.weight                                   | Shape: torch.Size([64])     | Params: 64
model.bn1.bias                                     | Shape: torch.Size([64])     | Params: 64
model.layer1.0.conv1.weight                        | Shape: torch.Size([64, 64, 1, 1]) | Params: 4,096
model.layer1.0.bn1.weight                          | Shape: torch.Size([64])     | Params: 64
model.layer1.0.bn1.bias                            | Shape: torch.Size([64])     | Params: 64
model.layer1.0.conv2.weight                        | Shape: torch.Size([64, 64, 3, 3]) | Params: 36,864
model.layer1.0.bn2.weight                          | Shape: torch.Size([64])     | Params: 64
model.layer1.0.bn2.bias                            | Shape: torch.Size([64])     | Params: 64
model.layer1.0.conv3.weight              

### Modified model (quaternion-valued ResNet-50 FaceNet)

In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet50
from core_qnn import *
from ReLu import Relu
from core_qnn.quaternion_layers import QuaternionConv
from BatchNormalization import QuaternionBatchNorm2d
from InstanceNormalization import QuaternionInstanceNorm2d
from core_qnn.quaternion_layers import QuaternionLinearAutograd
import numpy



try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

model_urls = {
    # Published FaceNet checkpoints (GitHub release assets), keyed by the
    # name that ``load_state`` looks up.
    'acc_920': 'https://github.com/khrlimam/facenet/releases/download/acc-0.920/model920-6be7e3e9.pth',
    'acc_921': 'https://github.com/khrlimam/facenet/releases/download/acc-0.92135/model921-af60fb4f.pth',
}



In [2]:
def load_state(arch, progress=True):
    """Fetch the checkpoint registered under ``arch`` in ``model_urls``.

    The file is downloaded via ``load_state_dict_from_url`` (which caches it
    in the torch hub directory) and the loaded object is returned unchanged;
    callers read its ``'state_dict'`` entry.

    Args:
        arch: key into ``model_urls``.
        progress: whether to show a download progress bar.

    Raises:
        KeyError: if ``arch`` is not a key of ``model_urls``. Without this
            check an unknown name turned into ``url=None`` and failed with an
            unrelated error from inside the download helper.
    """
    if arch not in model_urls:
        raise KeyError(
            f"Unknown checkpoint {arch!r}; expected one of {sorted(model_urls)}"
        )
    return load_state_dict_from_url(model_urls[arch], progress=progress)


def model_920(pretrained=True, progress=True):
    """Build the quaternion ``FaceNetModel``.

    The published ``acc_920`` checkpoint was saved from the real-valued
    ResNet-50 FaceNet. Its parameter names (e.g. ``model.conv1.weight`` and
    the ``cnn.*`` aliases) do not exist in the quaternion model, which has
    ``model.conv1.r_weight`` etc. A strict ``load_state_dict`` therefore can
    never succeed. Instead of downloading the file only to fail afterwards,
    the pretrained path raises immediately.

    Args:
        pretrained: must be False for the quaternion model (see above).
        progress: kept for signature compatibility; unused until a quaternion
            checkpoint exists.

    Returns:
        A randomly initialised quaternion ``FaceNetModel``.

    Raises:
        RuntimeError: if ``pretrained`` is True.
    """
    if pretrained:
        raise RuntimeError(
            "The 'acc_920' checkpoint was trained for the real-valued FaceNetModel "
            "and cannot be loaded into the quaternion model; "
            "use model_920(pretrained=False)."
        )
    return FaceNetModel()


def model_921(pretrained=True, progress=True):
    """Build the quaternion ``FaceNetModel``.

    The published ``acc_921`` checkpoint was saved from the real-valued
    ResNet-50 FaceNet. Its parameter names (e.g. ``model.conv1.weight`` and
    the ``cnn.*`` aliases) do not exist in the quaternion model, which has
    ``model.conv1.r_weight`` etc. A strict ``load_state_dict`` therefore can
    never succeed. Instead of downloading the file only to fail afterwards,
    the pretrained path raises immediately.

    Args:
        pretrained: must be False for the quaternion model (see above).
        progress: kept for signature compatibility; unused until a quaternion
            checkpoint exists.

    Returns:
        A randomly initialised quaternion ``FaceNetModel``.

    Raises:
        RuntimeError: if ``pretrained`` is True.
    """
    if pretrained:
        raise RuntimeError(
            "The 'acc_921' checkpoint was trained for the real-valued FaceNetModel "
            "and cannot be loaded into the quaternion model; "
            "use model_921(pretrained=False)."
        )
    return FaceNetModel()


class Flatten(nn.Module):
    """Flatten every dimension after the leading (batch) one.

    ``(N, d1, d2, ...) -> (N, d1 * d2 * ...)``. ``torch.flatten`` is used
    instead of ``x.view(x.size(0), -1)`` because the ``view`` form raises on
    an empty batch (N == 0, where ``-1`` is ambiguous) and on non-contiguous
    inputs; ``torch.flatten`` still returns a view whenever one is possible.
    """

    def forward(self, x):
        if x.dim() == 1:
            # Keep the historical result for 1-D input: (N,) -> (N, 1).
            return x.reshape(x.shape[0], 1)
        return torch.flatten(x, start_dim=1)


In [17]:


class FaceNetModel(nn.Module):
    """Quaternion-valued ResNet-50 FaceNet embedding model.

    Starts from a randomly initialised torchvision ResNet-50 and swaps its
    stem, every Bottleneck convolution / batch-norm and its fc head for
    quaternion layers (``QuaternionConv``, ``QuaternionBatchNorm2d``,
    ``QuaternionLinearAutograd``). Inputs must have 4 channels, because
    ``conv1`` is built with ``in_channels=4``.

    Args:
        pretrained: accepted for signature compatibility only. There are no
            ImageNet-pretrained quaternion weights, so the network is always
            randomly initialised; passing ``True`` emits a warning instead of
            being silently ignored.
        embedding_size: size passed as ``out_features`` of the fc head.
            Defaults to 128. Presumably must be divisible by 4 for the
            quaternion layers -- TODO confirm against core_qnn.
        num_classes: ``out_features`` of the classifier used by
            ``forward_classifier``. Defaults to 500 (same divisibility caveat).
    """

    def __init__(self, pretrained=False, embedding_size=128, num_classes=500):
        super(FaceNetModel, self).__init__()

        if pretrained:
            import warnings
            warnings.warn(
                "FaceNetModel: pretrained=True is ignored; no ImageNet weights "
                "exist for the quaternion layers, so the model is randomly "
                "initialised.",
                stacklevel=2,
            )

        # Randomly initialised ResNet-50 used only as a skeleton whose layers
        # are replaced below. ``resnet50()`` without arguments yields untrained
        # weights on both old and new torchvision, without the deprecation
        # warning that ``pretrained=False`` triggers on torchvision >= 0.13.
        self.model = resnet50()

        # Quaternion stem: 4 real input channels (1 quaternion channel) ->
        # 64 real output channels (16 quaternion channels).
        self.model.conv1 = QuaternionConv(
            in_channels=4,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False
        )
        self.model.bn1 = QuaternionBatchNorm2d(64)
        self.model.relu = nn.ReLU()
        # Bottleneck stages; each outputs expansion (4) x its width:
        # 256, 512, 1024 and 2048 real channels respectively.
        self.model.layer1 = self._replace_resnet_layer(self.model.layer1, 64, 64)
        self.model.layer2 = self._replace_resnet_layer(self.model.layer2, 256, 128)
        self.model.layer3 = self._replace_resnet_layer(self.model.layer3, 512, 256)
        self.model.layer4 = self._replace_resnet_layer(self.model.layer4, 1024, 512)

        # Head: global average pool, then a quaternion linear layer taking the
        # 2048 real layer4 features to ``embedding_size`` outputs (presumably
        # real features, i.e. embedding_size / 4 quaternions -- TODO confirm
        # against QuaternionLinearAutograd).
        self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.model.fc = QuaternionLinearAutograd(2048, embedding_size)

        # Identity classifier applied by ``forward_classifier``.
        self.classifier = QuaternionLinearAutograd(embedding_size, num_classes)

    def _replace_resnet_layer(self, layer, in_channels, out_channels):
        """Swap every Bottleneck conv / batch-norm in ``layer`` for quaternion ones.

        Args:
            layer: a torchvision ResNet-50 stage (sequence of Bottleneck blocks).
            in_channels: real input channels of the first block in the stage.
            out_channels: bottleneck width; each block outputs
                ``out_channels * 4`` real channels.

        Returns:
            ``layer`` itself, modified in place.
        """
        expansion = 4  # For ResNet50 Bottleneck blocks
        for block in layer:
            # conv1: 1x1 reduction. Strides are copied from the original convs
            # so ResNet-50's spatial downsampling pattern is preserved.
            block.conv1 = QuaternionConv(in_channels, out_channels, kernel_size=1, stride=block.conv1.stride, bias=False)
            block.bn1 = QuaternionBatchNorm2d(out_channels)

            # conv2: 3x3 convolution.
            block.conv2 = QuaternionConv(out_channels, out_channels, kernel_size=3, stride=block.conv2.stride, padding=1, bias=False)
            block.bn2 = QuaternionBatchNorm2d(out_channels)

            # conv3: 1x1 expansion back to out_channels * expansion.
            block.conv3 = QuaternionConv(out_channels, out_channels * expansion, kernel_size=1, stride=1, bias=False)
            block.bn3 = QuaternionBatchNorm2d(out_channels * expansion)

            # The shortcut must also output out_channels * expansion channels.
            if block.downsample is not None:
                block.downsample[0] = QuaternionConv(in_channels, out_channels * expansion, kernel_size=1, stride=block.downsample[0].stride, bias=False)
                block.downsample[1] = QuaternionBatchNorm2d(out_channels * expansion)
            in_channels = out_channels * expansion  # input width of the next block
        return layer

    def quaternion_l2_norm(self, x):
        """Normalise each quaternion of ``x`` to unit length.

        ``x`` is split along dim 1 into four equal blocks (real, i, j, k), and
        each (r, i, j, k) quadruple is divided by its own norm (1e-10 added to
        avoid division by zero).

        NOTE(review): this is a per-quaternion normalisation, not an L2 norm
        over the whole embedding. The full vector therefore has norm
        ~sqrt(features / 4) rather than 1.
        """
        real, i, j, k = torch.chunk(x, 4, dim=1)
        norm = torch.sqrt(real**2 + i**2 + j**2 + k**2 + 1e-10)
        return torch.cat([real/norm, i/norm, j/norm, k/norm], dim=1)

    def forward(self, x):
        """Return quaternion-normalised embeddings for a batch of 4-channel images."""
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.model.fc(x)
        return self.quaternion_l2_norm(x)

    def forward_classifier(self, x):
        """Return class logits: ``self.classifier`` applied to ``forward(x)``."""
        return self.classifier(self.forward(x))


In [18]:
model = FaceNetModel(pretrained=False)

# Print number of parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Optional: Detailed layer-wise parameter count
print("\nLayer-wise parameters:")
for name, param in model.named_parameters():
    print(f"{name:50} | Shape: {str(param.shape):20} | Params: {param.numel():,}")

# Dummy input: random 3-channel images plus one all-zero channel, giving the
# 4 channels the quaternion stem expects.
# NOTE(review): the zero channel is appended *last*. If the quaternion layers
# read channel blocks as [r | i | j | k], RGB lands in (r, i, j) and the zero
# in k; the common pure-quaternion colour encoding (0 + R i + G j + B k) would
# put the zero channel first. Confirm the intended encoding before real data.
rgb_input = torch.randn(4, 3, 224, 224)
dummy_input = torch.cat([rgb_input, torch.zeros(4, 1, 224, 224)], dim=1)  # Shape: [4, 4, 224, 224]

# Shape check only: no_grad avoids building an autograd graph over every
# activation. The model stays in train mode, so batch-norm statistics may
# still update.
with torch.no_grad():
    output = model(dummy_input)
assert output.shape[0] == dummy_input.shape[0], output.shape
print("\nOutput shape:", output.shape)

Total parameters: 5,979,876
Trainable parameters: 5,979,876

Layer-wise parameters:
model.conv1.r_weight                               | Shape: torch.Size([16, 1, 7, 7]) | Params: 784
model.conv1.i_weight                               | Shape: torch.Size([16, 1, 7, 7]) | Params: 784
model.conv1.j_weight                               | Shape: torch.Size([16, 1, 7, 7]) | Params: 784
model.conv1.k_weight                               | Shape: torch.Size([16, 1, 7, 7]) | Params: 784
model.bn1.gamma                                    | Shape: torch.Size([1, 16, 1, 1]) | Params: 16
model.bn1.beta                                     | Shape: torch.Size([1, 64, 1, 1]) | Params: 64
model.layer1.0.conv1.r_weight                      | Shape: torch.Size([16, 16, 1, 1]) | Params: 256
model.layer1.0.conv1.i_weight                      | Shape: torch.Size([16, 16, 1, 1]) | Params: 256
model.layer1.0.conv1.j_weight                      | Shape: torch.Size([16, 16, 1, 1]) | Params: 256
model.layer1.0.

In [None]:
import numpy as np

# Report the NumPy version installed in this kernel's environment.
numpy_version = np.__version__
print(numpy_version)
