# Classification and Visual Search

This notebook builds a dual-purpose architecture for an image-based product catalog system with two capabilities: **classifying** clothing items and powering a **visual search** engine for recommendations.

The approach emphasizes **efficiency** and **reusability** — a well-designed feature extractor can be leveraged for multiple downstream tasks.

**What we build:**
* `InvertedResidualBlock` — the efficient core component inspired by MobileNetV2
* `MobileNetBackbone` — stacked custom blocks for feature extraction
* `MobileNetLikeClassifier` — multi-class fashion item categorization with weighted loss
* `TripleDataset` — generates anchor, positive, and negative examples for similarity learning
* `SiameseEncoder` — reuses the trained backbone from the classifier
* `SiameseNetwork` — trained with `TripletMarginLoss` for visual similarity
* Visual search retrieval using image embeddings

## Table of Contents
- [Imports](#0)
- [1 - Building a Fashion Item Classifier](#1)
    - [1.1 - The Fashion Dataset](#1-1)
    - [1.2 - Preparing the Data Pipeline](#1-2)
    - [1.3 - Architecting the Classifier: Inverted Residuals](#1-3)
    - [1.4 - Assembling the Full Classifier](#1-4)
    - [1.5 - Training the Classifier](#1-5)
    - [1.6 - Evaluating the Classifier](#1-6)
- [2 - Building a Visual Search Engine](#2)
    - [2.1 - Similarity Learning for Recommendations](#2-1)
    - [2.2 - The Triplet Dataset](#2-2)
    - [2.3 - Architecting the Visual Search Model](#2-3)
        - [2.3.1 - The Siamese Encoder](#2-3-1)
        - [2.3.2 - The Siamese Network Wrapper](#2-3-2)
    - [2.4 - Training the Siamese Network](#2-4)
    - [2.5 - Performing Visual Search](#2-5)
- [3 - Conclusion](#3)

<a name='0'></a>
## Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch.utils.data import Dataset
import torchvision.utils as vutils
from torchvision import transforms
import torchinfo
import copy

import helper_utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

<a name='1'></a>
## 1 - Building a Fashion Item Classifier

<a name='1-1'></a>
### 1.1 - The Fashion Dataset

The dataset is derived from the **clothing-dataset-small** collection, curated into training and validation sets with seven categories: `dress`, `hat`, `longsleeve`, `pants`, `shoes`, `shorts`, and `t-shirt`.

<a name='1-2'></a>
### 1.2 - Preparing the Data Pipeline

The data pipeline defines image transformations for training (with augmentation) and validation (preprocessing only), then loads the datasets and creates DataLoaders for batching.
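
To make the `Normalize(mean=[0.5] * 3, std=[0.5] * 3)` step concrete, the sketch below applies the same per-channel formula, `(x - mean) / std`, to individual pixel values in plain Python. `ToTensor` scales 8-bit pixels to [0, 1], so this normalization maps them to [-1, 1]. The helper name `normalize_pixel` is made up for illustration and is not part of the notebook.

```python
def normalize_pixel(value_0_255, mean=0.5, std=0.5):
    """Mimic ToTensor (scale to [0, 1]) followed by Normalize for one pixel value."""
    scaled = value_0_255 / 255.0   # what ToTensor does to an 8-bit pixel
    return (scaled - mean) / std   # what Normalize does per channel

# Black and white pixels land on the two ends of the [-1, 1] range.
for pixel in (0, 255):
    print(pixel, "->", normalize_pixel(pixel))
```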

In [None]:
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

val_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

dataset_path = "./clothing-dataset-small"

train_dataset, validation_dataset = helper_utils.load_datasets(
    dataset_path=dataset_path,
    train_transform=train_transform,
    val_transform=val_transform,
)

classes = train_dataset.classes
num_classes = len(classes)
print(f"Classes: {classes}")
print(f"Number of classes: {num_classes}")

In [None]:
helper_utils.show_sample_images(train_dataset)

In [None]:
train_loader, val_loader = helper_utils.create_dataloaders(
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    batch_size=32
)

<a name='1-3'></a>
### 1.3 - Architecting the Classifier: Inverted Residuals

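The classifier is inspired by **MobileNetV2**, whose **Inverted Residual Block** has a narrow → wide → narrow structure. The block expands the input channels with a 1x1 convolution, extracts spatial features with a lightweight 3x3 depthwise convolution (one filter per channel), then projects back down with another 1x1 convolution. A skip connection around the block aids gradient flow. Unlike the original MobileNetV2, this simplified version adds the skip connection in every block, using a 1x1 projection shortcut whenever the stride or channel count changes, and applies a ReLU after the addition.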

The architecture is built modularly:
1. `InvertedResidualBlock` — the core efficient block
2. `MobileNetBackbone` — stacks blocks for feature extraction
3. `MobileNetLikeClassifier` — adds a classification head
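
To see why the depthwise step keeps the block lightweight, the sketch below counts the weights of a 3x3 convolution over the expanded channels, once as a standard convolution and once as a depthwise convolution (`groups` equal to the channel count). `conv_weight_count` is a hypothetical helper that ignores biases and batch-norm parameters. The channel counts are those of the last backbone block (32 input channels, expansion factor 6).

```python
def conv_weight_count(in_channels, out_channels, kernel_size, groups=1):
    """Number of weights in a 2D convolution without bias."""
    # Each output channel only sees in_channels / groups input channels.
    return out_channels * (in_channels // groups) * kernel_size * kernel_size

hidden_dim = 32 * 6  # in_channels * expansion_factor = 192

standard = conv_weight_count(hidden_dim, hidden_dim, 3)                      # full 3x3 conv
depthwise = conv_weight_count(hidden_dim, hidden_dim, 3, groups=hidden_dim)  # one filter per channel

print(f"standard 3x3:  {standard:,} weights")
print(f"depthwise 3x3: {depthwise:,} weights")
print(f"reduction:     {standard // depthwise}x")
```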

In [None]:
class InvertedResidualBlock(nn.Module):
    """
    Implements an inverted residual block, often used in architectures like MobileNetV2.
    
    This block features an expansion phase (1x1 convolution), a depthwise
    convolution (3x3 convolution), and a projection phase (1x1 convolution).
    It utilizes a residual connection between the input and the output of the projection.
    """
    
    def __init__(self, in_channels, out_channels, stride, expansion_factor, shortcut=None):
        """
        Initializes the InvertedResidualBlock module.

        Args:
            in_channels: The number of channels in the input tensor.
            out_channels: The number of channels in the output tensor.
            stride (int): The stride to be used in the depthwise convolution.
            expansion_factor (int): The factor by which to expand the input channels.
            shortcut: An optional module for the shortcut connection.
        """
        super().__init__()
        hidden_dim = in_channels * expansion_factor

        self.expand = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
        )

        self.depthwise = nn.Sequential(
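            # groups=hidden_dim makes this a true depthwise convolution (one 3x3 filter per channel).
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1,
                      groups=hidden_dim, bias=False),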
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
        )

        self.project = nn.Sequential(
            nn.Conv2d(hidden_dim, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        self.shortcut = shortcut

    def forward(self, x):
        """
        Defines the forward pass of the InvertedResidualBlock.

        Args:
            x: The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying the block operations and residual connection.
        """
        skip = x
        out = self.expand(x)
        out = self.depthwise(out)
        out = self.project(out)

        if self.shortcut is not None:
            skip = self.shortcut(x)

        out = out + skip
        return F.relu(out)

In [None]:
class MobileNetBackbone(nn.Module):
    """
    Implements a simplified MobileNet-like backbone feature extractor.
    """

    def __init__(self):
        """Initializes the layers of the MobileNet backbone."""
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )

        self.blocks = nn.Sequential(
            self._make_block(16, 24, stride=2, expansion_factor=3),
            self._make_block(24, 32, stride=2, expansion_factor=3),
            self._make_block(32, 64, stride=2, expansion_factor=6),
        )

    def _make_block(self, in_channels, out_channels, stride=1, expansion_factor=6):
        """Helper method to create a single InvertedResidualBlock."""
        if in_channels != out_channels or stride != 1:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            shortcut = None

        return InvertedResidualBlock(in_channels, out_channels, stride, expansion_factor, shortcut)

    def forward(self, x):
        """Defines the forward pass of the backbone."""
        x = self.stem(x)
        x = self.blocks(x)
        return x

<a name='1-4'></a>
### 1.4 - Assembling the Full Classifier

The full classifier combines the `MobileNetBackbone` with a classification head that uses Adaptive Average Pooling to reduce spatial dimensions, flattens the result, and applies a Linear layer to produce class logits.
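
As a sanity check on the `nn.Linear(64, num_classes)` input size, the sketch below traces the spatial size of a 64x64 image through the backbone using the standard convolution output formula, `floor((n + 2p - k) / s) + 1`. It assumes the stride-2, padding-1, 3x3 layers defined in `MobileNetBackbone`. `conv_out_size` is an illustrative helper, not part of the notebook.

```python
def conv_out_size(n, kernel_size=3, stride=2, padding=1):
    """Spatial output size of a convolution along one dimension."""
    return (n + 2 * padding - kernel_size) // stride + 1

size = 64                    # input images are resized to 64x64
channels = [16, 24, 32, 64]  # stem output, then the three blocks

for out_channels in channels:
    size = conv_out_size(size)
    print(f"{out_channels:>2} channels, {size}x{size}")

# AdaptiveAvgPool2d(1) averages each channel's size x size map to one value,
# so Flatten yields a vector whose length is the final channel count.
embedding_dim = channels[-1]
print("features into the Linear layer:", embedding_dim)
```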

In [None]:
class MobileNetLikeClassifier(nn.Module):
    """
    A classifier model that combines a feature extraction backbone with a classification head.
    """
    
    def __init__(self, num_classes=10):
        """
        Initializes the classifier components.

        Args:
            num_classes (int): The number of output classes.
        """
        super().__init__()
        self.backbone = MobileNetBackbone()
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """
        Defines the forward pass of the classifier.

        Args:
            x: The input tensor (batch of images).

        Returns:
            torch.Tensor: The raw logits for each class.
        """
        x = self.backbone(x)
        x = self.head(x)
        return x

mobilenet_classifier = MobileNetLikeClassifier(num_classes=num_classes)

<a name='1-5'></a>
### 1.5 - Training the Classifier

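Training uses a class-weighted `CrossEntropyLoss` to counter class imbalance, the `Adam` optimizer, and a `StepLR` scheduler that multiplies the learning rate by `gamma` every `step_size` epochs. With `step_size=5` and `n_epochs=5` as set below, and assuming the helper steps the scheduler once per epoch, the first decay only lands after the final epoch, so the learning rate stays at 0.01 throughout this short run.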
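
The internals of `helper_utils.compute_class_weights` are not shown here. One common scheme, assumed purely for illustration, is inverse-frequency weighting, `weight_c = N / (C * count_c)`, which makes rare classes contribute more to the loss. The sketch uses made-up class counts and a hypothetical `inverse_frequency_weights` helper.

```python
from collections import Counter

def inverse_frequency_weights(labels, num_classes):
    """weight_c = N / (C * count_c); assumes every class appears at least once."""
    counts = Counter(labels)
    total = len(labels)
    return [total / (num_classes * counts[c]) for c in range(num_classes)]

# Made-up labels: class 0 is common, class 2 is rare.
labels = [0] * 6 + [1] * 3 + [2] * 1
weights = inverse_frequency_weights(labels, num_classes=3)
print(weights)  # the rare class 2 receives the largest weight
```

A list like this could be wrapped in `torch.tensor(weights, dtype=torch.float32)` and passed as the `weight` argument of `nn.CrossEntropyLoss`.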

In [None]:
class_weights = helper_utils.compute_class_weights(train_dataset).to(device)
loss_fcn = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(mobilenet_classifier.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

n_epochs = 5

trained_classifier = helper_utils.training_loop(
    mobilenet_classifier, 
    train_loader, 
    val_loader, 
    loss_fcn, 
    optimizer, 
    scheduler, 
    device, 
    n_epochs=n_epochs
)

<a name='1-6'></a>
### 1.6 - Evaluating the Classifier

Visualizing predictions from the trained model gives a qualitative sense of performance.

In [None]:
helper_utils.display_random_predictions_per_class(trained_classifier, val_loader, classes, device)

<a name='2'></a>
## 2 - Building a Visual Search Engine

<a name='2-1'></a>
### 2.1 - Similarity Learning for Recommendations

Instead of classifying items, a Siamese Network maps images into an embedding space where visually similar items are located close together. This enables visual search, smarter recommendations, and finding alternatives for out-of-stock items.

<a name='2-2'></a>
### 2.2 - The Triplet Dataset

The network learns from **triplets**: an **anchor** image, a **positive** (same category), and a **negative** (different category). Training pulls anchor-positive embeddings closer while pushing anchor-negative embeddings apart.
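
The sampling rule can be sketched on a toy list of labels before applying it to images. This version is an illustration only, built around a hypothetical `sample_triplet` function. Unlike the class below, it excludes the anchor itself from the positive candidates and draws the negative uniformly over items rather than first picking a label.

```python
import random

def sample_triplet(labels, anchor_idx, rng=random):
    """Return (anchor, positive, negative) indices for a list of class labels."""
    anchor_label = labels[anchor_idx]
    # Positive: same label, different item (falls back to the anchor if it is alone).
    positives = [i for i, lab in enumerate(labels)
                 if lab == anchor_label and i != anchor_idx]
    positive_idx = rng.choice(positives) if positives else anchor_idx
    # Negative: any item with a different label.
    negatives = [i for i, lab in enumerate(labels) if lab != anchor_label]
    negative_idx = rng.choice(negatives)
    return anchor_idx, positive_idx, negative_idx

labels = ["dress", "hat", "dress", "shoes", "hat", "dress"]
a, p, n = sample_triplet(labels, anchor_idx=0)
print(labels[a], labels[p], labels[n])
```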

In [None]:
class TripleDataset(Dataset):
    """
    A custom Dataset that returns triplets of images (anchor, positive, negative).
    """
    
    def __init__(self, dataset):
        """
        Initializes the TripleDataset.

        Args:
            dataset: The base dataset containing (data, label) pairs.
        """
        self.dataset = dataset
        self.labels = range(len(dataset.classes))
        self.labels_to_indices = self._get_labels_to_indices()

    def __len__(self):
        return len(self.dataset)

    def _get_labels_to_indices(self):
        """Creates a dictionary mapping each label to a list of indices."""
        labels_to_indices = {}
        for idx, (_, label) in enumerate(self.dataset):
            if label not in labels_to_indices:
                labels_to_indices[label] = []
            labels_to_indices[label].append(idx)
        return labels_to_indices

    def _get_positive_negative_indices(self, anchor_label):
        """Finds random indices for a positive and a negative sample."""
        positive_indices = self.labels_to_indices[anchor_label]
        positive_index = random.choice(positive_indices)

        negative_label = random.choice([label for label in self.labels if label != anchor_label])
        negative_indices = self.labels_to_indices[negative_label]
        negative_index = random.choice(negative_indices)

        return positive_index, negative_index

    def __getitem__(self, idx):
        """Retrieves a triplet (anchor, positive, negative) for a given index."""
        anchor_image, anchor_label = self.dataset[idx]
        positive_index, negative_index = self._get_positive_negative_indices(anchor_label)
        positive_image, _ = self.dataset[positive_index]
        negative_image, _ = self.dataset[negative_index]

        return (anchor_image, positive_image, negative_image)

In [None]:
triple_dataset = TripleDataset(train_dataset)

siamese_dataloader = torch.utils.data.DataLoader(
    triple_dataset,
    batch_size=32,
    shuffle=True,
)

<a name='2-3'></a>
### 2.3 - Architecting the Visual Search Model

<a name='2-3-1'></a>
#### 2.3.1 - The Siamese Encoder

The encoder reuses the `MobileNetBackbone` from the classifier, adding a representation head (pooling + flatten) to produce fixed-size embedding vectors.
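
A quick shape check illustrates why pooling plus flattening gives a fixed-size embedding: global average pooling collapses each channel to a single value, so the output length equals the channel count regardless of the spatial size. The feature maps below are random tensors standing in for backbone outputs, not real features.

```python
import torch
import torch.nn as nn

representation = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),  # (B, C, H, W) -> (B, C, 1, 1)
    nn.Flatten(),             # (B, C, 1, 1) -> (B, C)
)

# Stand-in feature maps with 64 channels and two different spatial sizes.
small = torch.randn(2, 64, 4, 4)
large = torch.randn(2, 64, 7, 7)

print(representation(small).shape)
print(representation(large).shape)
```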

In [None]:
class SiameseEncoder(nn.Module):
    """
    An encoder module for Siamese networks that produces fixed-size embeddings.
    """

    def __init__(self, backbone):
        """
        Initializes the SiameseEncoder.

        Args:
            backbone (nn.Module): The feature extractor network.
        """
        super().__init__()
        self.backbone = backbone
        self.representation = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        """Returns the embedding vector for the input."""
        features = self.backbone(x)
        return self.representation(features)

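# Deep-copy the trained backbone so fine-tuning the encoder does not alter the classifier's weights.
siamese_encoder = SiameseEncoder(backbone=copy.deepcopy(trained_classifier.backbone))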

<a name='2-3-2'></a>
#### 2.3.2 - The Siamese Network Wrapper

The `SiameseNetwork` passes each triplet image through the shared encoder and returns three embeddings.

In [None]:
class SiameseNetwork(nn.Module):
    """
    Siamese Network that processes triplets through a shared embedding network.
    """

    def __init__(self, embedding_network):
        """
        Initializes the SiameseNetwork.

        Args:
            embedding_network (nn.Module): The shared encoder network.
        """
        super().__init__()
        self.embedding_network = embedding_network

    def forward(self, anchor, positive, negative):
        """Returns embeddings for anchor, positive, and negative images."""
        anchor_output = self.embedding_network(anchor)
        positive_output = self.embedding_network(positive)
        negative_output = self.embedding_network(negative)
        return anchor_output, positive_output, negative_output

    def get_embedding(self, image):
        """Generates an embedding for a single image."""
        return self.embedding_network(image)

siamese_network = SiameseNetwork(embedding_network=siamese_encoder)

<a name='2-4'></a>
### 2.4 - Training the Siamese Network

Training uses `TripletMarginLoss`, which penalizes the model unless the anchor-positive distance is smaller than the anchor-negative distance by at least the margin.
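
In formula terms, the loss for one triplet is `max(d(a, p) - d(a, n) + margin, 0)`, where `d` is the Euclidean distance when `p=2.0`. The plain-Python sketch below evaluates it on hand-picked 2-D embeddings. The vectors and the `triplet_margin_loss` function are illustrative. PyTorch's implementation also adds a small epsilon inside the distance for numerical stability, so its values can differ very slightly.

```python
import math

def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """max(d(a, p) - d(a, n) + margin, 0) with Euclidean distance d."""
    d_ap = math.dist(anchor, positive)
    d_an = math.dist(anchor, negative)
    return max(d_ap - d_an + margin, 0.0)

anchor = (0.0, 0.0)
positive = (0.0, 1.0)       # distance 1 from the anchor

far_negative = (0.0, 3.0)   # distance 3: already beyond the margin, so no penalty
near_negative = (0.0, 1.5)  # distance 1.5: inside the margin, so the loss is positive

print(triplet_margin_loss(anchor, positive, far_negative))
print(triplet_margin_loss(anchor, positive, near_negative))
```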

In [None]:
loss_fcn = nn.TripletMarginLoss(margin=1.0, p=2.0)
optimizer = torch.optim.AdamW(siamese_network.parameters(), lr=0.001)

num_epochs = 5

helper_utils.siamese_training_loop(
    model=siamese_network,
    dataloader=siamese_dataloader,
    loss_fcn=loss_fcn,
    optimizer=optimizer,
    device=device,
    n_epochs=num_epochs,
)

<a name='2-5'></a>
### 2.5 - Performing Visual Search

Visual search works by generating embeddings for a query image and all catalog items, then finding the closest matches by Euclidean distance.
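
The retrieval step amounts to a nearest-neighbor search. How `helper_utils.find_closest` is implemented is not shown here. The sketch below is one straightforward way to do it with brute-force Euclidean distances, using made-up 2-D embeddings and a hypothetical `find_closest_sketch` function.

```python
import math

def find_closest_sketch(catalog_embeddings, query_embedding, num_samples):
    """Return indices of the num_samples catalog embeddings nearest to the query."""
    distances = [math.dist(emb, query_embedding) for emb in catalog_embeddings]
    # Sort catalog indices by distance, closest first.
    ranked = sorted(range(len(distances)), key=distances.__getitem__)
    return ranked[:num_samples]

catalog_embeddings = [(0.0, 0.0), (1.0, 1.0), (5.0, 5.0), (0.9, 1.2)]
query_embedding = (1.0, 1.0)

print(find_closest_sketch(catalog_embeddings, query_embedding, num_samples=2))
```

For a large catalog, the same ranking could be computed in a vectorized way, for example with `torch.cdist` followed by `topk(..., largest=False)`.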

In [None]:
def get_query_img_embedding(encoder, transform, img, device):
    """
    Generates an embedding vector for a single query PIL image.

    Args:
        encoder (nn.Module): The trained embedding model.
        transform (callable): The torchvision transforms to apply.
        img (PIL.Image): The input query image.
        device (torch.device): The device to perform inference on.

    Returns:
        np.ndarray: The embedding vector as a NumPy array.
    """
    tensor_img = transform(img)
    query_img_tensor = tensor_img.unsqueeze(0).to(device)
    
    encoder.eval()
    with torch.no_grad():
        query_img_embedding = encoder(query_img_tensor)
    
    return query_img_embedding.cpu().numpy()

In [None]:
image_path = './images/t_shirt.jpg'
query_img = helper_utils.get_query_img(image_path)
display(query_img)

query_img_embedding = get_query_img_embedding(siamese_encoder, val_transform, query_img, device)

catalog = validation_dataset
embeddings = helper_utils.get_embeddings(siamese_encoder, catalog, device)

num_samples = 5
closest_indices = helper_utils.find_closest(embeddings, query_img_embedding, num_samples)

print(f"\nTop {num_samples} similar items:")
for idx_c in closest_indices:
    img_c, label_idx_c = helper_utils.get_image(catalog, idx_c)
    print(f"Class: {catalog.classes[label_idx_c]}")
    display(img_c)

<a name='3'></a>
## 3 - Conclusion

This notebook demonstrated building a dual-purpose AI system: an efficient MobileNet-inspired classifier using `InvertedResidualBlock` and `MobileNetBackbone`, then repurposing that backbone for a Siamese-based visual search engine. Key takeaways include modular architecture design, transfer of learned features between tasks, and the triplet loss approach for learning visual similarity.