## HF-Net

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import init

def image_normalization(image, pixel_value_offset=128.0, pixel_value_scale=128.0):
    return (image - pixel_value_offset) / pixel_value_scale

class VLAD(nn.Module):
    def __init__(self, config):
        super(VLAD, self).__init__()
        self.intermediate_proj = config.get('intermediate_proj', None)
        if self.intermediate_proj:
            self.pre_proj = nn.Conv2d(240, self.intermediate_proj, kernel_size=1)

        self.n_clusters = config['n_clusters']
        self.memberships = nn.Conv2d(240, self.n_clusters, kernel_size=1)

        # Cluster centers
        self.clusters = nn.Parameter(torch.empty(1, self.n_clusters, 240))
        self._initialize_weights()  # Initialize the cluster weights

    def _initialize_weights(self):
        # Xavier initialization for cluster weights
        init.xavier_uniform_(self.clusters)

    def forward(self, feature_map, mask=None):
        if self.intermediate_proj:
            feature_map = self.pre_proj(feature_map)

        batch_size, _, h, w = feature_map.size()

        # Compute memberships (soft-assignment)
        memberships = F.softmax(self.memberships(feature_map), dim=1)

        # Reshape feature_map and clusters for broadcasting
        feature_map = feature_map.permute(0, 2, 3, 1).unsqueeze(3)  # (B, H, W, 1, D)
        residuals = self.clusters - feature_map  # Compute residuals
        residuals = residuals * memberships.permute(0, 2, 3, 1).unsqueeze(4)  # Weight residuals by memberships

        if mask is not None:
            residuals = residuals * mask.unsqueeze(-1).unsqueeze(-1)

        # Sum residuals to form the VLAD descriptor
        descriptor = residuals.sum(dim=[1, 2])

        # Intra-normalization
        descriptor = F.normalize(descriptor, p=2, dim=-1)

        # Flatten descriptor and apply L2 normalization
        descriptor = descriptor.view(batch_size, -1)
        descriptor = F.normalize(descriptor, p=2, dim=1)

        return descriptor

class DimensionalityReduction(nn.Module):
    def __init__(self, config, proj_regularizer=None):
        """
        Initializes the Dimensionality Reduction module.

        Args:
            input_dim (int): Dimension of the input feature descriptor.
            output_dim (int): Dimension of the reduced descriptor.
            proj_regularizer (float, optional): L2 regularization strength. If None, no regularization is applied.
        """
        super(DimensionalityReduction, self).__init__()
        input_dim = config['n_clusters'] * 240
        output_dim = config['dimensionality_reduction']
        self.proj_regularizer = proj_regularizer

        # Fully connected layer with Xavier initialization
        self.fc = nn.Linear(input_dim, output_dim)
        nn.init.xavier_uniform_(self.fc.weight)

        # Optional L2 regularization
        if proj_regularizer is not None:
            self.regularizer = lambda w: proj_regularizer * torch.sum(w ** 2)
        else:
            self.regularizer = None

    def forward(self, descriptor):
        """
        Forward pass for the Dimensionality Reduction module.

        Args:
            descriptor (torch.Tensor): Input feature descriptor of shape (batch_size, input_dim).

        Returns:
            torch.Tensor: Reduced and normalized descriptor of shape (batch_size, output_dim).
        """
        # Normalize the input descriptor
        descriptor = F.normalize(descriptor, p=2, dim=-1)

        # Apply the fully connected layer
        descriptor = self.fc(descriptor)

        # Normalize the output descriptor
        descriptor = F.normalize(descriptor, p=2, dim=-1)

        # Apply regularization if specified
        if self.regularizer is not None:
            reg_loss = self.regularizer(self.fc.weight)
            return descriptor, reg_loss

        return descriptor


class LocalHead(nn.Module):
    def __init__(self, config):
        super(LocalHead, self).__init__()
        descriptor_dim = config['descriptor_dim']
        detector_grid = config['detector_grid']

        # Descriptor Head
        self.desc_conv1 = nn.Conv2d(config['input_channels'], descriptor_dim, kernel_size=3, stride=1, padding=1)
        self.desc_bn1 = nn.BatchNorm2d(descriptor_dim)
        self.desc_conv2 = nn.Conv2d(descriptor_dim, descriptor_dim, kernel_size=1, stride=1, padding=0)

        # Detector Head
        self.det_conv1 = nn.Conv2d(config['input_channels'], 128, kernel_size=3, stride=1, padding=1)
        self.det_bn1 = nn.BatchNorm2d(128)
        self.det_conv2 = nn.Conv2d(128, 1 + detector_grid ** 2, kernel_size=1, stride=1, padding=0)

        self.detector_grid = detector_grid

    def forward(self, features):
        # Descriptor Head
        desc = F.relu6(self.desc_bn1(self.desc_conv1(features)))
        desc = self.desc_conv2(desc)
        desc = F.normalize(desc, p=2, dim=1)

        # Detector Head
        logits = F.relu6(self.det_bn1(self.det_conv1(features)))
        logits = self.det_conv2(logits)

        prob_full = F.softmax(logits, dim=1)  # Compute softmax over the last dimension
        prob = prob_full[:, :-1, :, :]  # Exclude the "no interest point" dustbin
        prob = F.pixel_shuffle(prob, self.detector_grid)  # Convert to dense scores
        prob = torch.squeeze(prob, dim=1)  # Remove unnecessary channel dimension

        return desc, logits, prob_full, prob


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def conv_3x3_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        hidden_dim = round(inp * expand_ratio)
        self.identity = stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.identity:
            return x + self.conv(x)
        else:
            return self.conv(x)


class HFNet(nn.Module):
    def __init__(self, config, width_mult=1.0):
        super(HFNet, self).__init__()
        # [expand_ratio, channels, repeats, stride]
        self.cfgs = [
            [1,  16, 1, 1],
            [6,  24, 2, 2],
            [6,  32, 1, 2],
            [6,  64, 1, 1],
            [6, 128, 1, 1],
            [6, 64, 4, 2],
            [6,  96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # Feature Extractor (MobileNetV2 backbone)
        input_channel = _make_divisible(32 * width_mult, 8)
        layers = [conv_3x3_bn(3, input_channel, 2)]
        block = InvertedResidual
        for t, c, n, s in self.cfgs:
            output_channel = _make_divisible(c * width_mult, 8)
            for i in range(n):
                layers.append(block(input_channel, output_channel, s if i == 0 else 1, t))
                input_channel = output_channel
        self.features = nn.Sequential(*layers)

        # Keypoint Detector Head
        self.local_head = LocalHead(config['local_head'])

        # Descriptor Head
        self.global_head = nn.Sequential(
            VLAD(config['global_head']),
            DimensionalityReduction(config['global_head'])
        )

        self._initialize_weights()

    def forward(self, x):
        x = image_normalization(x)
        # Backbone
        features_1 = self.features[:7](x)
        features_2 = self.features[7:](features_1)

        # local_head
        desc, logits, prob_full, prob = self.local_head(features_1)

        # Classification (if needed)
        descriptor = self.global_head(features_2)

        return {'local_descriptor_map':desc,
                'logits':logits,
                'prob_full':prob_full,
                'scores_dense':prob,
                'global_descriptor':descriptor,
                'image_shape':x.shape
                }

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

### Test_model

In [14]:
config= {
        'loss_weights': 'uncertainties',
        'local_head': {
            'descriptor_dim': 128,
            'detector_grid': 8,
            'input_channels': 96
        },
        'global_head': {
            'n_clusters': 32,
            'intermediate_proj': 0,
            'dimensionality_reduction': 4096
        }
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the initial weights
state_dict = torch.load('/kaggle/input/mobilenetv2-0.75/pytorch/default/1/mobilenetv2_0.75-dace9791.pth',
                        weights_only=False,
                        map_location=device)

# Create the model
model = HFNet(config, width_mult=0.75)

# Get the model's state_dict
model_dict = model.state_dict()

# Filter out mismatched keys
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_dict and model_dict[k].shape == v.shape}

# Update the model's weights
model_dict.update(filtered_state_dict)
model.load_state_dict(model_dict)

# # Set the model to evaluation mode
# image = torch.randn(1, 3, 480,640 ) # [B, C, H, W]
# output = model(image)

# for key, value in output.items():
#     if key != 'image_shape':
#         print(key, value.shape)
#     else:
#         print(key, value)

<All keys matched successfully>

## Post-processing

In [15]:
import torch
import torch.nn.functional as F

# Non-Maximum Suppression (NMS) function to filter out lower-scoring keypoints

def simple_nms(scores, radius, iterations=3):
    """Performs non-maximum suppression (NMS) on the heatmap using max-pooling."""
    size = 2 * radius + 1  # Kernel size based on radius

    def max_pool(x):
        return F.max_pool2d(x.unsqueeze(0).unsqueeze(0), kernel_size=size, stride=1, padding=radius).squeeze(0).squeeze(0)

    zeros = torch.zeros_like(scores, device=scores.device)  # Initialize zero tensor on the same device
    max_mask = scores == max_pool(scores)  # Mask for local maxima

    # Apply NMS iteratively
    for _ in range(iterations - 1):
        supp_mask = max_pool(max_mask.float()).bool()
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == max_pool(supp_scores)
        max_mask = max_mask | (new_max_mask & ~supp_mask)

    return torch.where(max_mask, scores, zeros)  # Keep maxima, zero out others

# Main prediction function to process keypoints and extract descriptors

def predict(ret, config):
    """Processes keypoints and extracts descriptors."""
    scores_dense = ret['scores_dense']
    # Apply NMS if configured
    if config['local']['nms_radius']:
        scores_dense = simple_nms(scores_dense, config['local']['nms_radius'])

    batch_size = scores_dense.shape[0]  # Get batch size

    # Initialize lists to collect outputs for each batch element
    keypoints_list = []
    scores_list = []
    descriptors_list = []
    min_keypoints = float('inf')  # Track the minimum number of keypoints in the batch

    for b in range(batch_size):  # Process each image in the batch
        # Extract keypoints where scores are above the threshold
        keypoints = (scores_dense[b] >= config['local']['detector_threshold']).nonzero(as_tuple=False).to(scores_dense.device)
        scores = scores_dense[b][keypoints[:, 0], keypoints[:, 1]]  # Get corresponding scores

        # Select top-k keypoints based on scores
        if config['local']['num_keypoints']:
            k = min(len(scores), config['local']['num_keypoints'])  # Limit to the specified number of keypoints
            topk_indices = scores.topk(k, largest=True, sorted=True).indices  # Get top-k indices
            keypoints = keypoints[topk_indices]  # Select top-k keypoints
            scores = scores[topk_indices]  # Select top-k scores
            
        # Get descriptor map for this image and adjust dimensions
        desc_map = ret['local_descriptor_map'][b].unsqueeze(0).to(scores_dense.device)  # [1, D, H', W']
        H_desc, W_desc = desc_map.shape[-2:]  # Descriptor map size

        # Unpack original image dimensions directly (assuming consistent size across the batch)
        _, _, H_img, W_img = ret['image_shape']  # Unpack dimensions from [B, C, H, W]

        # Scale keypoints from image space to descriptor space
        keypoints_scaled = keypoints.clone().float().to(scores_dense.device)
        keypoints_scaled[:, 0] *= (H_desc - 1) / (H_img - 1)  # Scale y-coordinates
        keypoints_scaled[:, 1] *= (W_desc - 1) / (W_img - 1)  # Scale x-coordinates

        # Normalize keypoints to [-1, 1] for grid_sample
        keypoints_scaled[:, 0] = (keypoints_scaled[:, 0] / (H_desc - 1)) * 2 - 1
        keypoints_scaled[:, 1] = (keypoints_scaled[:, 1] / (W_desc - 1)) * 2 - 1

        keypoints_scaled = keypoints_scaled.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, K, 2]

        # Sample descriptors using grid_sample for interpolation
        local_descriptors = F.grid_sample(desc_map, keypoints_scaled, mode='bilinear', align_corners=True)
        local_descriptors = local_descriptors.squeeze(2).squeeze(0).transpose(0, 1)  # [K, D]

        # Normalize descriptors to unit length
        local_descriptors = F.normalize(local_descriptors, p=2, dim=-1)

        # Append results for this batch
        keypoints_list.append(keypoints)
        scores_list.append(scores)
        descriptors_list.append(local_descriptors)

    # Convert lists to tensors for GPU compatibility
    ret.update({
        'keypoints': torch.stack(keypoints_list).to(scores_dense.device),
        'scores': torch.stack(scores_list).to(scores_dense.device),
        'local_descriptors': torch.stack(descriptors_list).to(scores_dense.device)
    })
    
    return ret


In [16]:
# ret = {
#     'scores_dense': torch.rand(1,480, 640),
#     'local_descriptor_map': torch.rand(1,256,60, 80),  # [H', W', D]
#     'logits': torch.rand(1,65,60, 80),
#     'prob_full': torch.rand(1,65,60, 80),
#     'global_descriptor': torch.rand(1,4096),
#     'image_shape': torch.tensor([1,3, 480, 640])
# }

# config = {
#     'local': {
#         'detector_threshold': 0.005,
#         'nms_radius': 0,
#         'num_keypoints': 10000
#     }
# }
# ret = predict(ret, config)

# for k, v in ret.items():
#     print(f"{k}: {v.shape}")

## Losses

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------------
# 🔹 Optimized Global Descriptor Loss
# -------------------------------
def descriptor_global_loss(inp, out):
    """Computes L2 loss between input and output global descriptors."""
    d = (inp['global_descriptor'] - out['global_descriptor']) ** 2
    return d.mean(dim=-1)  # Use mean instead of sum to avoid large gradients

# -------------------------------
# 🔹 Optimized Local Descriptor Loss
# -------------------------------
def descriptor_local_loss(inp, out):
    """Computes L2 loss for local descriptors with aligned dimensions."""
    inp_desc = inp['local_descriptors']
    out_desc = out['local_descriptors']

    # Trim to smallest sequence length
    min_size = min(inp_desc.size(1), out_desc.size(1))
    inp_desc = inp_desc[:, :min_size, :]
    out_desc = out_desc[:, :min_size, :]

    # Compute squared L2 loss efficiently
    d = (inp_desc - out_desc).pow(2).mean(dim=(1, 2))  # Mean over keypoints & descriptors
    return d

# -------------------------------
# 🔹 Optimized Detector Loss
# -------------------------------
def detector_loss(inp, out, config):
    """Computes cross-entropy loss for keypoints or dense scores."""
    logits = out['logits']  # [B, C, H, W]
    B, C, H, W = logits.shape

    if 'keypoints' in inp:
        grid_size = config['local_head']['detector_grid']

        # Precompute keypoint heatmaps instead of dynamically creating them
        keypoints = inp['keypoints']
        keypoint_map = torch.zeros((B, 1, H * grid_size, W * grid_size), device=logits.device)

        for b in range(B):
            kp = keypoints[b].long()
            mask = (0 <= kp[:, 0]) & (kp[:, 0] < keypoint_map.shape[3]) & (0 <= kp[:, 1]) & (kp[:, 1] < keypoint_map.shape[2])
            valid_kp = kp[mask]  # Filter out invalid keypoints
            keypoint_map[b, 0, valid_kp[:, 1], valid_kp[:, 0]] = 1.0

        # Downsample the keypoint heatmap
        patches = F.unfold(keypoint_map, kernel_size=grid_size, stride=grid_size)  # [B, grid_size^2, new_H*new_W]
        patches = patches.permute(0, 2, 1).reshape(B, -1, grid_size * grid_size)

        # Construct labels efficiently
        labels = torch.argmax(torch.cat([2 * patches, torch.ones(B, patches.size(1), 1, device=logits.device)], dim=-1), dim=-1)

        # Reshape logits for cross-entropy
        logits = logits.permute(0, 2, 3, 1).reshape(B, H * W, C)  # [B, H*W, C]
        loss = F.cross_entropy(logits.reshape(-1, C), labels.reshape(-1).long(), reduction='mean')

    elif 'dense_scores' in inp:
        labels = inp['dense_scores']
        logits = logits.permute(0, 2, 3, 1).reshape(B, H * W, C)  # [B, H*W, C]
        loss = F.cross_entropy(logits.reshape(-1, C), labels.reshape(-1, labels.size(-1)), reduction='mean')

    else:
        raise ValueError("Input must contain 'keypoints' or 'dense_scores'.")

    return loss

# -------------------------------
# 🔹 Optimized Overall Loss Function
# -------------------------------
def loss(inputs, outputs, config):
    """Computes the total loss as a weighted combination of all components."""
    desc_g = descriptor_global_loss(inputs, outputs).mean()
    desc_l = descriptor_local_loss(inputs, outputs).mean()
    detect = detector_loss(inputs, outputs, config)

    # Efficient weighting
    if config['loss_weights'] == 'uncertainties':
        logvars = [nn.Parameter(torch.tensor(1.0, device=desc_g.device)) for _ in range(3)]
        precisions = [torch.exp(-logvar) for logvar in logvars]

        loss = desc_g * precisions[0] + logvars[0]
        loss += desc_l * precisions[1] + logvars[1]
        loss += 2 * detect * precisions[2] + logvars[2]

    else:
        w = config['loss_weights']
        total = sum(w.values())
        loss = (w['global_desc'] * desc_g + w['local_desc'] * desc_l + w['detector'] * detect) / total

    return loss


## Metric

In [18]:
def compute_metrics(outputs, inputs, config, global_loss_fn, local_loss_fn, detector_loss_fn):
    """
    Compute various loss metrics for the model, with optional uncertainty-based weighting.

    Args:
        outputs (dict): The model's predictions.
        inputs (dict): The ground truth or input data for comparison.
        config (dict): Configuration parameters, including 'loss_weights'.
        global_loss_fn (function): Custom function to compute global descriptor loss.
        local_loss_fn (function): Custom function to compute local descriptor loss.
        detector_loss_fn (function): Custom function to compute detector cross-entropy loss.

    Returns:
        dict: A dictionary containing computed loss values and optional trainable log-variance weights.
    """
    # Compute individual losses using custom functions
    global_desc_l2 = global_loss_fn(inputs, outputs)
    local_desc_l2 = local_loss_fn(inputs, outputs)
    detector_crossentropy = detector_loss_fn(inputs, outputs, config)

    # Initialize the metrics dictionary
    ret = {
        'global_desc_l2': global_desc_l2,
        'local_desc_l2': local_desc_l2,
        'detector_crossentropy': detector_crossentropy,
    }

    # Add trainable log-variance weights if 'loss_weights' is set to 'uncertainties'
    if config.get('loss_weights') == 'uncertainties':
        logvars = []
        for i in range(3):  # Assuming 3 metrics
            # Define log-variance as trainable parameters
            logvar = torch.nn.Parameter(torch.zeros(1))
            logvars.append(logvar)
            ret[f'logvar{i}'] = logvar  # Add to the metrics dictionary

    return ret


## Load Data

In [19]:
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.data import Dataset
from PIL import Image
import os

In [20]:
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, image_dir, glo_dir, loc_dir, num, num_keypoints, transform=None):
        self.image_dir = image_dir
        self.glo_dir = glo_dir
        self.loc_dir = loc_dir
        self.num_keypoints = num_keypoints
        self.image_files = sorted(os.listdir(image_dir))[:num]
        self.glo_files = sorted(os.listdir(glo_dir))[:num]
        self.loc_files = sorted(os.listdir(loc_dir))[:num]
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        glo_path = os.path.join(self.glo_dir, self.glo_files[idx])
        loc_path = os.path.join(self.loc_dir, self.loc_files[idx])
        
        # Load image
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Load global and local descriptors
        glo = {key: torch.from_numpy(value) for key, value in np.load(glo_path).items()}
        loc = {key: torch.from_numpy(value) for key, value in np.load(loc_path).items()}
        
        scores = loc['scores']
        keypoints = loc['keypoints']
        local_descriptors = loc['local_descriptors']

        # Select top-k keypoints
        k = min(len(scores), self.num_keypoints)
        if k > 0:
            topk_indices = scores.topk(k, largest=True, sorted=True).indices
            keypoints = keypoints[topk_indices]
            scores = scores[topk_indices]
            local_descriptors = local_descriptors[topk_indices]
        else:
            keypoints = torch.empty((0, 2))  # Empty tensor if no keypoints
            scores = torch.empty((0,))
            local_descriptors = torch.empty((0, local_descriptors.shape[-1]))  # [0, D] shape

        inp = {
            "global_descriptor": glo['global_descriptor'],  # ❌ No .to(self.device)
            "keypoints": keypoints,
            "scores": scores,
            "local_descriptors": local_descriptors,
        }

        return image, inp


In [21]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

# Define Transformations
transform = transforms.Compose([
    transforms.Resize((480, 640)),
    transforms.ToTensor(),
])

# Paths
dataset_path = '/kaggle/input/google-landmarks-tiny/images'
global_path = '/kaggle/input/global-descriptors/global_descriptors'
local_path = '/kaggle/input/silk-prediction/silk_prediction'

# Dataset Initialization
data_set = CustomImageDataset(
    image_dir=dataset_path, 
    glo_dir=global_path, 
    loc_dir=local_path,  
    num=1000,  # Number of images
    num_keypoints=10000,  
    transform=transform
)

# Train, Validation, Test Split
TRAIN_SIZE = 0.8
VAL_SIZE = 0.1
BATCH_SIZE = 1

train_size = int(TRAIN_SIZE * len(data_set))
val_size = int(VAL_SIZE * len(data_set))
test_size = len(data_set) - train_size - val_size

torch.manual_seed(42)  # Ensure reproducibility
train_dataset, val_dataset, test_dataset = random_split(data_set, [train_size, val_size, test_size])

# ⚡ Optimized DataLoaders ⚡
num_workers = 2  # Use more workers for faster data loading

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=num_workers, pin_memory=True, prefetch_factor=2
)

val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=num_workers, pin_memory=True, prefetch_factor=2
)

test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=num_workers, pin_memory=True
)


In [22]:
# for image, label in train_loader:
#     print(image.shape)
#     for k, v in label.items():
#         print(k,v.shape)
#     break

## Training 

In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm

def train_model(model, train_loader, val_loader, config, lr=1e-3, patience=5, epochs=20):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    optimizer = optim.RMSprop(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=patience)
    best_val_loss = float('inf')
    patience_counter = 0

    log_file = "training_log.csv"
    pd.DataFrame(columns=['epoch', 'global_desc_l2', 'local_desc_l2', 'detector_crossentropy', 'val_loss', 'learning_rate']).to_csv(log_file, index=False)

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0
        metrics = {'global_desc_l2': 0, 'local_desc_l2': 0, 'detector_crossentropy': 0}
        
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
            images = images.to(device)
            inputs = {k: v.to(device) for k, v in labels.items()}

            optimizer.zero_grad()
            ret = model(images)
            outputs = predict(ret, config)
            
            # Compute loss metrics using the compute_metrics function
            loss_dict = compute_metrics(outputs, inputs, config, descriptor_global_loss, descriptor_local_loss, detector_loss)

            # Extract loss values
            global_desc_l2 = loss_dict['global_desc_l2']
            local_desc_l2 = loss_dict['local_desc_l2']
            detector_crossentropy = loss_dict['detector_crossentropy']

            # Compute total loss
            total_loss = loss(inputs, outputs, config)

            total_loss.backward()
            optimizer.step()
            
            train_loss += total_loss.item()
            metrics['global_desc_l2'] += global_desc_l2.mean().item()
            metrics['local_desc_l2'] += local_desc_l2.mean().item()
            metrics['detector_crossentropy'] += detector_crossentropy.mean().item()

        # Normalize loss metrics
        num_batches = len(train_loader)
        for key in metrics:
            metrics[key] /= num_batches
        train_loss /= num_batches

        torch.cuda.empty_cache()  # Free unused GPU memory
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                inputs = {k: v.to(device) for k, v in labels.items()}

                ret = model(images)
                outputs = predict(ret, config)
                
                # Compute validation loss
                loss_dict = compute_metrics(outputs, inputs, config, descriptor_global_loss, descriptor_local_loss, detector_loss)
                total_loss = loss(inputs, outputs, config)
                
                val_loss += total_loss.item()

        val_loss /= len(val_loader)
        scheduler.step(val_loss)

        # Logging
        lr = scheduler.get_last_lr()[0]  # Retrieve current learning rate
        print(f"Epoch {epoch}/{epochs} | Val Loss: {val_loss:.4f}|"
              f"Local Loss: {metrics['local_desc_l2']:.4f} | Detector Loss: {metrics['detector_crossentropy']:.4f}| Global Loss: {metrics['global_desc_l2']:.4f}| LR: {lr:.6f} |")

        # Save logs
        pd.DataFrame([[epoch, metrics['global_desc_l2'], metrics['local_desc_l2'], metrics['detector_crossentropy'], val_loss, lr]],
                     columns=['epoch', 'global_desc_l2', 'local_desc_l2', 'detector_crossentropy', 'val_loss', 'learning_rate']).to_csv(log_file, mode='a', header=False, index=False)

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), "best_model.pth")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break


In [24]:
config={
    'loss_weights': 'uncertainties',
    'local':{
        'detector_threshold': 0.005,
        'nms_radius': 0,
        'num_keypoints': 10000
    },
    'local_head': {
        'descriptor_dim': 128,
        'detector_grid': 8,
        'input_channels': 96
    },
    'global_head': {
        'n_clusters': 32,
        'intermediate_proj': 0,
        'dimensionality_reduction': 4096
    }
}

train_model(model, train_loader, val_loader, config, lr=1e-3, patience=5, epochs=30)

Epoch 1/30: 100%|██████████| 800/800 [00:36<00:00, 21.67it/s]


Epoch 1/30 | Val Loss: 3.8668|Local Loss: 0.0195 | Detector Loss: 1.1335| Global Loss: 0.0004| LR: 0.001000 |


Epoch 2/30: 100%|██████████| 800/800 [00:36<00:00, 21.83it/s]


Epoch 2/30 | Val Loss: 4.0000|Local Loss: 0.0194 | Detector Loss: 1.0971| Global Loss: 0.0004| LR: 0.001000 |


Epoch 3/30: 100%|██████████| 800/800 [00:37<00:00, 21.44it/s]


Epoch 3/30 | Val Loss: 3.8167|Local Loss: 0.0194 | Detector Loss: 1.0909| Global Loss: 0.0004| LR: 0.001000 |


Epoch 4/30: 100%|██████████| 800/800 [00:37<00:00, 21.42it/s]


Epoch 4/30 | Val Loss: 3.7832|Local Loss: 0.0194 | Detector Loss: 1.0870| Global Loss: 0.0004| LR: 0.001000 |


Epoch 5/30: 100%|██████████| 800/800 [00:37<00:00, 21.58it/s]


Epoch 5/30 | Val Loss: 3.8141|Local Loss: 0.0194 | Detector Loss: 1.0844| Global Loss: 0.0004| LR: 0.001000 |


Epoch 6/30: 100%|██████████| 800/800 [00:37<00:00, 21.59it/s]


Epoch 6/30 | Val Loss: 3.7855|Local Loss: 0.0194 | Detector Loss: 1.0822| Global Loss: 0.0003| LR: 0.001000 |


Epoch 7/30: 100%|██████████| 800/800 [00:36<00:00, 21.73it/s]


Epoch 7/30 | Val Loss: 3.7848|Local Loss: 0.0194 | Detector Loss: 1.0804| Global Loss: 0.0003| LR: 0.001000 |


Epoch 8/30: 100%|██████████| 800/800 [00:37<00:00, 21.31it/s]


Epoch 8/30 | Val Loss: 3.7846|Local Loss: 0.0194 | Detector Loss: 1.0786| Global Loss: 0.0003| LR: 0.001000 |


Epoch 9/30: 100%|██████████| 800/800 [00:36<00:00, 21.73it/s]


Epoch 9/30 | Val Loss: 4.0126|Local Loss: 0.0193 | Detector Loss: 1.0773| Global Loss: 0.0003| LR: 0.001000 |
Early stopping triggered.
