In [5]:
"""
model.py - Embedding Model for Keyword Spotting

Provides:
- Embedding models that return vector embeddings for word images
- Similarity functions (cosine similarity, euclidean distance)
- Support for multiple backbones (SimpleCNN, ResNet18)

Project: Keyword Spotting
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


# ============================================================================
#                           BACKBONE ARCHITECTURES
# ============================================================================

class SimpleCNN(nn.Module):
    """
    Four-block convolutional backbone mapping a single-channel word image to a
    fixed-size embedding vector.

    Every block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2),
    with channels growing 1 -> 32 -> 64 -> 128 -> 256. A global average pool
    then collapses the spatial dimensions, followed by dropout and a linear
    projection to ``embedding_dim``.

    Args:
        embedding_dim: Size of the output embedding (default 128).
        dropout_rate: Dropout probability applied right before the final
            linear layer (default 0.4).
    """
    def __init__(self, embedding_dim=128, dropout_rate=0.4):
        super().__init__()

        channel_plan = ((1, 32), (32, 64), (64, 128), (128, 256))

        # Register conv/bn/pool under the names conv1, bn1, pool1, ... in the
        # same order as a hand-written layout, so state_dict keys (and hence
        # saved checkpoints) and seeded parameter initialization are unchanged.
        for idx, (c_in, c_out) in enumerate(channel_plan, start=1):
            setattr(self, f"conv{idx}", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"bn{idx}", nn.BatchNorm2d(c_out))
            setattr(self, f"pool{idx}", nn.MaxPool2d(2, 2))

        # Collapse whatever spatial size remains to 1x1.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Regularize and project the 256-d pooled features to the embedding.
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(channel_plan[-1][1], embedding_dim)

    def forward(self, x):
        """
        Embed a batch of word images.

        Args:
            x: Input tensor; presumably (batch, 1, H, W), since conv1 expects a
               single input channel.

        Returns:
            Tensor of shape (batch, embedding_dim).
        """
        features = x
        for idx in (1, 2, 3, 4):
            conv = getattr(self, f"conv{idx}")
            norm = getattr(self, f"bn{idx}")
            pool = getattr(self, f"pool{idx}")
            features = pool(F.relu(norm(conv(features))))

        pooled = self.global_pool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(self.dropout(flat))


class ResNet18Backbone(nn.Module):
    """
    ResNet18 backbone adapted for grayscale word images.

    Builds torchvision's ResNet18 (optionally with ImageNet weights), swaps the
    first conv layer for a single-channel one, and replaces the 1000-way
    classifier with a linear projection to ``embedding_dim``. Provides
    stronger feature extraction than SimpleCNN at the cost of more parameters
    and slower inference.

    Args:
        embedding_dim: Size of the output embedding (default 128).
        pretrained: If True, load torchvision's ``'IMAGENET1K_V1'`` weights
            (torchvision may need to download them) and adapt the RGB conv1
            kernel to one input channel. If False, every layer is randomly
            initialized.
    """
    def __init__(self, embedding_dim=128, pretrained=True):
        super(ResNet18Backbone, self).__init__()

        # Load ResNet18, with ImageNet weights if requested.
        weights = 'IMAGENET1K_V1' if pretrained else None
        resnet = models.resnet18(weights=weights)

        # Replace the first conv layer: 1 input channel instead of 3, same
        # geometry as the original (64 filters, 7x7, stride 2, padding 3).
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        if pretrained:
            # Sum (rather than average) the RGB filters. By linearity, a
            # grayscale image g fed as (g, g, g) through the RGB kernel gives
            # the same response as g through the summed kernel, so first-layer
            # activations keep the magnitude the pretrained layers were trained
            # with; averaging would shrink them by a factor of 3.
            with torch.no_grad():
                self.conv1.weight.copy_(self._rgb_to_gray_weight(resnet.conv1.weight))

        # Reuse the remaining (pretrained or freshly initialized) layers as-is.
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool

        # Replace the ImageNet classifier with an embedding projection;
        # 512 is the feature width coming out of ResNet18's avgpool.
        self.fc = nn.Linear(512, embedding_dim)

    @staticmethod
    def _rgb_to_gray_weight(weight):
        """Collapse an (out, 3, kH, kW) conv kernel to (out, 1, kH, kW) by
        summing over the input-channel axis."""
        return weight.sum(dim=1, keepdim=True)

    def forward(self, x):
        """
        Forward pass through the ResNet18 backbone.

        Args:
            x: Input tensor; presumably (batch, 1, H, W), since conv1 expects a
               single input channel.

        Returns:
            Tensor of shape (batch, embedding_dim).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


# ============================================================================
#                        MAIN EMBEDDING MODEL
# ============================================================================

class EmbeddingModel(nn.Module):
    """
    Main embedding model for keyword spotting.

    Takes word images as input and returns fixed-size embedding vectors from
    the selected backbone. The intent is that images of the same word end up
    close together in embedding space; that property comes from how the model
    is trained, not from this class itself.

    Args:
        backbone: ``'simple_cnn'`` or ``'resnet18'``.
        embedding_dim: Size of the output embedding (default 128).
        pretrained: Forwarded to the ResNet18 backbone only; ignored for
            ``'simple_cnn'``.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.

    Attributes:
        backbone_name: The ``backbone`` string passed in (stored so that
            checkpoints can rebuild the same architecture).
        embedding_dim: The embedding size passed in.
        backbone: The instantiated backbone module.
    """
    def __init__(self, backbone='simple_cnn', embedding_dim=128, pretrained=False):
        super(EmbeddingModel, self).__init__()

        self.backbone_name = backbone
        self.embedding_dim = embedding_dim

        # Select and initialize backbone
        if backbone == 'simple_cnn':
            self.backbone = SimpleCNN(embedding_dim=embedding_dim)
        elif backbone == 'resnet18':
            self.backbone = ResNet18Backbone(
                embedding_dim=embedding_dim,
                pretrained=pretrained
            )
        else:
            raise ValueError(
                f"Unknown backbone: '{backbone}'. "
                f"Available options: 'simple_cnn', 'resnet18'"
            )

    def forward(self, x):
        """
        Compute embedding(s) for input word image(s) by running the backbone.
        """
        return self.backbone(x)

    def get_embedding(self, x):
        """
        Return the embedding for ``x``; equivalent to calling ``model(x)``.

        Provided for clarity and backward compatibility. It goes through
        ``nn.Module.__call__`` rather than ``self.forward`` so that registered
        forward pre-hooks and hooks run exactly as they do when the model is
        called directly.
        """
        return self(x)


# ============================================================================
#                        SIMILARITY FUNCTIONS
# ============================================================================

def cosine_similarity(embedding1, embedding2):
    """
    Cosine of the angle between ``embedding1`` and ``embedding2``.

    Thin wrapper over ``torch.nn.functional.cosine_similarity`` that reduces
    over the last dimension, so inputs of shape (..., D) produce a result of
    shape (...).

    The result lies in [-1, 1]:
        1.0  -> same direction (most similar)
        0.0  -> orthogonal (no similarity)
        -1.0 -> opposite directions (most dissimilar)
    """
    feature_axis = -1
    return F.cosine_similarity(embedding1, embedding2, dim=feature_axis)


def euclidean_distance(embedding1, embedding2):
    """
    L2 distance between ``embedding1`` and ``embedding2``; lower means more
    similar.

    Delegates to ``torch.nn.functional.pairwise_distance`` with p=2. PyTorch
    adds a small ``eps`` (1e-6 by default) to the difference before taking the
    norm, so identical inputs yield a tiny positive value rather than exactly 0.
    """
    norm_order = 2
    return F.pairwise_distance(embedding1, embedding2, p=norm_order)


def similarity_matrix(embeddings, other=None):
    """
    Compute a pairwise cosine similarity matrix.

    With only ``embeddings`` of shape (N, D), returns the (N, N) matrix of
    cosine similarities among its rows (diagonal ~1.0 for non-zero rows).
    With ``other`` of shape (M, D), returns the (N, M) matrix comparing every
    row of ``embeddings`` (e.g. queries) against every row of ``other``
    (e.g. candidates), which is what ranking and retrieval need.

    Rows are L2-normalized with ``F.normalize``, which clamps the norm at a
    small eps, so an all-zero row produces similarities of 0 instead of NaN.

    Args:
        embeddings: 2-D tensor (N, D).
        other: Optional 2-D tensor (M, D). Defaults to ``embeddings``.

    Returns:
        Tensor of shape (N, N), or (N, M) when ``other`` is given.
    """
    # Normalize rows to unit length.
    left = F.normalize(embeddings, p=2, dim=1)
    right = left if other is None else F.normalize(other, p=2, dim=1)

    # Dot products of unit vectors are cosines:
    # (normalized A) @ (normalized B)^T
    return torch.mm(left, right.t())


# ============================================================================
#                        MODEL PERSISTENCE
# ============================================================================

def save_model(model, save_path, history=None, **kwargs):
    """
    Serialize ``model`` and its metadata to ``save_path`` via ``torch.save``.

    The saved dict contains, in order:
        'model_state_dict' -> model.state_dict()
        'backbone'         -> model.backbone_name
        'embedding_dim'    -> model.embedding_dim
        'history'          -> ``history`` (None if not supplied)
    followed by every extra keyword argument. Extra kwargs are applied last,
    so a kwarg sharing one of the names above overrides that entry.

    Prints a one-line confirmation after writing.
    """
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        backbone=model.backbone_name,
        embedding_dim=model.embedding_dim,
        history=history,
    )
    checkpoint.update(kwargs)

    torch.save(checkpoint, save_path)
    print(f" Model saved to {save_path}")


def load_model(checkpoint_path, device='cpu'):
    """
    Rebuild an ``EmbeddingModel`` from a checkpoint written by ``save_model``.

    Args:
        checkpoint_path: Path of the saved checkpoint file.
        device: Target device; also passed as ``map_location`` so tensors are
            loaded straight onto it.

    Returns:
        (model, checkpoint): the model moved to ``device`` and switched to
        eval mode, plus the raw checkpoint dict (history and any extra
        metadata stored by ``save_model``).

    Note:
        ``torch.load`` may unpickle arbitrary objects (depending on the
        installed torch version's ``weights_only`` default); only load
        checkpoints from trusted sources. The model is rebuilt with
        ``EmbeddingModel``'s default ``pretrained=False``; the strict
        ``load_state_dict`` call then overwrites all of its weights.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Recreate the architecture recorded in the checkpoint.
    backbone_name = checkpoint['backbone']
    embedding_dim = checkpoint['embedding_dim']
    model = EmbeddingModel(backbone=backbone_name, embedding_dim=embedding_dim)

    # Restore the trained weights, then place on device in inference mode.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device).eval()

    return model, checkpoint


# ============================================================================
#                        UTILITY FUNCTIONS
# ============================================================================

def count_parameters(model):
    """
    Return the total element count over all of ``model``'s parameters that
    have ``requires_grad`` set (i.e. trainable parameters).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def get_model_info(model):
    """
    Summarize ``model`` as a dict with keys:
        'backbone'       -> model.backbone_name
        'embedding_dim'  -> model.embedding_dim
        'num_parameters' -> trainable parameter count (``count_parameters``)
    """
    info = {}
    info['backbone'] = model.backbone_name
    info['embedding_dim'] = model.embedding_dim
    info['num_parameters'] = count_parameters(model)
    return info


# ============================================================================
#                        TEST / DEMO CODE
# ============================================================================

if __name__ == "__main__":
    # Smoke test: build each backbone, run forward passes, exercise the
    # similarity helpers and round-trip a checkpoint. Each step asserts the
    # expected results, so "ALL TESTS PASSED" is printed only if they hold.
    # Run this file directly to execute it.
    import os
    import tempfile

    print("=" * 70)
    print("MODEL.PY - Test Script")
    print("=" * 70)
    print()

    # Test 1: SimpleCNN
    print("Test 1: SimpleCNN Backbone")
    print("-" * 70)
    model_simple = EmbeddingModel(backbone='simple_cnn', embedding_dim=128)
    print(" Model created")
    print(f"  Parameters: {count_parameters(model_simple):,}")

    batch_size = 4
    word_image = torch.randn(batch_size, 1, 64, 128)
    print(f" Input shape: {word_image.shape}")

    with torch.no_grad():
        embedding = model_simple(word_image)
    assert embedding.shape == (batch_size, 128), embedding.shape
    print(f" Output shape: {embedding.shape}")
    print(f" Embedding dimension: {embedding.shape[1]}")
    print()

    # Test 2: ResNet18
    print("Test 2: ResNet18 Backbone")
    print("-" * 70)
    model_resnet = EmbeddingModel(
        backbone='resnet18',
        embedding_dim=256,
        pretrained=False
    )
    print(" Model created")
    print(f"  Parameters: {count_parameters(model_resnet):,}")

    with torch.no_grad():
        embedding = model_resnet(word_image)
    assert embedding.shape == (batch_size, 256), embedding.shape
    print(f" Output shape: {embedding.shape}")
    print()

    # Test 3: Similarity functions
    print("Test 3: Similarity Functions")
    print("-" * 70)

    emb1 = torch.randn(5, 128)
    emb2 = torch.randn(5, 128)

    cos_sim = cosine_similarity(emb1, emb2)
    assert cos_sim.shape == (5,), cos_sim.shape
    assert cos_sim.abs().max().item() <= 1.0 + 1e-6
    print(f" Cosine similarity: {cos_sim}")
    print(f"  Shape: {cos_sim.shape}")
    print(f"  Range: [{cos_sim.min().item():.3f}, {cos_sim.max().item():.3f}]")

    euc_dist = euclidean_distance(emb1, emb2)
    assert euc_dist.shape == (5,), euc_dist.shape
    assert (euc_dist >= 0).all()
    print(f" Euclidean distance: {euc_dist}")
    print(f"  Mean: {euc_dist.mean().item():.3f}")

    embeddings = torch.randn(10, 128)
    sim_mat = similarity_matrix(embeddings)
    assert sim_mat.shape == (10, 10), sim_mat.shape
    assert torch.allclose(torch.diag(sim_mat), torch.ones(10), atol=1e-5)
    print(f" Similarity matrix shape: {sim_mat.shape}")
    print(f"  Diagonal (self-similarity): {torch.diag(sim_mat)[:3]}")
    print()

    # Test 4: Save and Load (in a temporary directory so the test is portable
    # and leaves no files behind)
    print("Test 4: Save and Load")
    print("-" * 70)
    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = os.path.join(tmp_dir, 'test_model.pth')

        save_model(
            model_simple,
            save_path,
            history={'train_loss': [0.5, 0.3, 0.2]},
            test_metric=0.85
        )

        loaded_model, checkpoint = load_model(save_path)

    assert checkpoint['backbone'] == 'simple_cnn'
    assert checkpoint['embedding_dim'] == 128
    assert checkpoint['test_metric'] == 0.85
    print(" Model loaded")
    print(f"  Backbone: {checkpoint['backbone']}")
    print(f"  Embedding dim: {checkpoint['embedding_dim']}")
    print(f"  Test metric: {checkpoint['test_metric']}")
    print()

    # Both models in eval mode with identical weights/buffers must agree.
    model_simple.eval()
    with torch.no_grad():
        test_output = loaded_model(word_image)
        reference_output = model_simple(word_image)
    assert test_output.shape == (batch_size, 128), test_output.shape
    assert torch.allclose(test_output, reference_output, atol=1e-6)
    print(f" Loaded model forward pass: {test_output.shape}")
    print()

    print("=" * 70)
    print(" ALL TESTS PASSED!")
    print("=" * 70)

MODEL.PY - Test Script

Test 1: SimpleCNN Backbone
----------------------------------------------------------------------
 Model created
  Parameters: 421,696
 Input shape: torch.Size([4, 1, 64, 128])
 Output shape: torch.Size([4, 128])
 Embedding dimension: 128

Test 2: ResNet18 Backbone
----------------------------------------------------------------------
 Model created
  Parameters: 11,301,568
 Output shape: torch.Size([4, 256])

Test 3: Similarity Functions
----------------------------------------------------------------------
 Cosine similarity: tensor([-0.0103,  0.0977,  0.0332,  0.1420,  0.0650])
  Shape: torch.Size([5])
  Range: [-0.010, 0.142]
 Euclidean distance: tensor([15.3223, 14.8142, 15.0832, 14.4881, 15.9881])
  Mean: 15.139
 Similarity matrix shape: torch.Size([10, 10])
  Diagonal (self-similarity): tensor([1.0000, 1.0000, 1.0000])

Test 4: Save and Load
----------------------------------------------------------------------
 Model saved to /tmp/test_model.pth
 Model l