In [1]:
import torch
import torch as tr
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from spatial_correlation_sampler import SpatialCorrelationSampler

In [2]:
# Problem dimensions shared by every cell below.
batch = 1
channel = 4
patch_size = 5
height = 14
width = 14

# Seed the generator so the random features (and therefore the comparison
# printed at the end of the notebook) are reproducible across kernel restarts.
torch.manual_seed(0)

# Use the GPU when one is present; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allocate directly on the target device. Calling `.cuda()` on a tensor created
# with `requires_grad=True` returns a non-leaf copy, so `.grad` would never be
# populated on `input1` / `input2` after a backward pass.
input1 = torch.randn(batch, channel, height, width, dtype=torch.float32,
                     device=device, requires_grad=True)
input2 = torch.randn(batch, channel, height, width, dtype=torch.float32,
                     device=device, requires_grad=True)

In [1]:
def corr_abs_to_rel(corr, h, w, patch_size):
    """Convert an all-pairs correlation into a local, displacement-indexed one.

    `corr` is reshaped to ``(b, h, w, h, w)``; the first ``(h, w)`` pair indexes
    dimension 1 of the input and the second pair indexes dimension 2. For every
    position ``(y, x)`` of the second pair and every displacement ``(dy, dx)``
    with ``|dy|, |dx| <= patch_size // 2`` the result holds

        out[b, (dy + r) * patch_size + (dx + r), y * w + x]
            = corr[b, (y + dy) * w + (x + dx), y * w + x]

    where ``r = patch_size // 2``. Entries whose displaced position falls
    outside the ``h x w`` grid are zero.

    Args:
        corr: tensor with ``b * (h * w) * (h * w)`` elements, normally shaped
            ``(b, h * w, h * w)``.
        h: grid height.
        w: grid width.
        patch_size: odd side length of the displacement window.

    Returns:
        Tensor of shape ``(b, patch_size * patch_size, h * w)`` on the same
        device and with the same dtype as `corr`.

    Raises:
        ValueError: if `patch_size` is even (the window would have no centre
            and the last band would index past the buffer).
    """
    if patch_size % 2 == 0:
        raise ValueError(f"patch_size must be odd, got {patch_size}")
    max_d = patch_size // 2

    b = corr.size(0)
    corr = corr.reshape(b, h, w, h, w)

    # Horizontal pass: collect the diagonals of the two width axes. A diagonal
    # at offset -d pairs column x of the last axis with column x + d of axis 2;
    # zero-padding re-aligns the shortened diagonal to column x.
    # The buffers inherit device and dtype from `corr` instead of assuming CUDA.
    w_diag = corr.new_zeros((b, h, h, patch_size, w))
    for d in range(-max_d, max_d + 1):
        band = torch.diagonal(corr, offset=-d, dim1=2, dim2=4)
        w_diag[:, :, :, max_d + d] = F.pad(band, (max(-d, 0), max(d, 0)))

    # Vertical pass: same construction over the two height axes of `w_diag`.
    hw_diag = corr.new_zeros((b, patch_size, w, patch_size, h))
    for d in range(-max_d, max_d + 1):
        band = torch.diagonal(w_diag, offset=-d, dim1=1, dim2=2)
        hw_diag[:, :, :, max_d + d] = F.pad(band, (max(-d, 0), max(d, 0)))

    # (b, dx, x, dy, y) -> (b, dy, dx, y, x), then flatten window and grid.
    hw_diag = hw_diag.permute(0, 3, 1, 4, 2).contiguous()
    hw_diag = hw_diag.view(-1, patch_size * patch_size, h * w)

    return hw_diag

In [2]:
# Shared ReLU module; the default out-of-place mode is spelled out so callers'
# tensors are never modified.
relu = nn.ReLU(inplace=False)

# L2 normalization function
def L2normalize(x, d=1):
    """Divide `x` by its L2 norm taken along dimension `d`.

    A small constant (1e-6) is added to the sum of squares before the square
    root, so an all-zero slice yields zeros rather than a division by zero.

    Args:
        x: input tensor.
        d: dimension along which the norm is computed (default 1).

    Returns:
        Tensor with the same shape as `x`.
    """
    stabilizer = 1e-6
    squared_norm = x.pow(2).sum(dim=d, keepdim=True)
    return x / (squared_norm + stabilizer).pow(0.5)

# Match layer using matrix multiplication
def match_layer_mm(feature1, feature2):
    """Local correlation of two feature maps via an all-pairs matrix product.

    Both inputs are L2-normalised along dimension 1, flattened spatially and
    multiplied so that ``corr[b, j, i]`` is the dot product between position
    ``j`` of `feature2` and position ``i`` of `feature1`. The all-pairs result
    is then reduced to a ``patch_size x patch_size`` window of displacements
    with `corr_abs_to_rel` and passed through ReLU.

    Relies on the notebook-level names `patch_size`, `relu`, `L2normalize` and
    `corr_abs_to_rel`.

    Args:
        feature1: 4-D tensor ``(b, c, h, w)``.
        feature2: 4-D tensor with the same shape as `feature1`.

    Returns:
        Tensor of shape ``(b, patch_size * patch_size, h * w)``.

    Raises:
        ValueError: if the two feature maps do not have the same spatial size.
    """
    # L2 normalize input features
    feature1 = L2normalize(feature1)
    feature2 = L2normalize(feature2)

    # Get dimensions of input features
    b, c, h1, w1 = feature1.size()
    _, _, h2, w2 = feature2.size()
    if (h1, w1) != (h2, w2):
        # corr_abs_to_rel reshapes to (b, h, w, h, w), which needs one grid.
        raise ValueError(
            f"feature maps must share spatial size, got {(h1, w1)} and {(h2, w2)}"
        )

    # Reshape features for matrix multiplication
    feature1 = feature1.reshape(b, c, h1 * w1)
    feature2 = feature2.reshape(b, c, h2 * w2)

    # All-pairs correlation, shape (b, h2 * w2, h1 * w1)
    corr = torch.bmm(feature2.transpose(1, 2), feature1)

    # Absolute-to-relative transformation. The grid size comes from the actual
    # inputs rather than the notebook globals, and `patch_size` is passed
    # explicitly (the call previously omitted this required argument).
    corr = corr_abs_to_rel(corr, h1, w1, patch_size)

    # Apply ReLU activation
    return relu(corr)

# Spatial correlation sampler, built once and reused by `match_layer_scs`.
# The displacement window side comes from the notebook-level `patch_size`;
# every other option is fixed here.
correlation_sampler = SpatialCorrelationSampler(
    **{
        "kernel_size": 1,
        "patch_size": patch_size,
        "stride": 1,
        "padding": 0,
        "dilation_patch": 1,
    }
)

# Match layer using spatial correlation sampler
def match_layer_scs(feature1, feature2):
    """Local correlation of two feature maps via `correlation_sampler`.

    Both inputs are L2-normalised along dimension 1 and handed to the
    notebook-level `correlation_sampler`. Its output is reshaped to
    ``(b, -1, h1 * w1)`` using the batch and spatial sizes of `feature1`, then
    passed through ReLU.

    Args:
        feature1: 4-D tensor ``(b, c, h1, w1)``.
        feature2: 4-D tensor.

    Returns:
        Tensor of shape ``(b, -1, h1 * w1)`` with non-negative entries.
    """
    normed1, normed2 = L2normalize(feature1), L2normalize(feature2)

    # Both sizes are unpacked as 4-tuples, so a non-4-D input fails here;
    # only the first input's sizes are used afterwards.
    n, _, rows, cols = normed1.size()
    _, _, _, _ = normed2.size()

    # Sampler output is presumably (b, p, p, h, w) -- TODO confirm against the
    # SpatialCorrelationSampler documentation.
    sampled = correlation_sampler(normed1, normed2)

    return relu(sampled.view(n, -1, rows * cols))

NameError: name 'nn' is not defined

In [3]:
# Run both implementations on the same pair of inputs.
corr_mm = match_layer_mm(input1, input2)
corr_scs = match_layer_scs(input1, input2)

# Report the shape produced by each implementation.
print("Matrix Multiplication Correlation Size:", corr_mm.shape)
print("Spatial Correlation Sampler Correlation Size:", corr_scs.shape)

# Element-wise agreement within an absolute tolerance of 1e-6
# (torch.allclose's default relative tolerance also applies).
are_close = torch.allclose(corr_mm, corr_scs, atol=1e-6)
print("Are the correlation matrices close?", are_close)

NameError: name 'match_layer_mm' is not defined