In [2]:
import torch
import numpy as np
import pandas as pd

In [12]:
# Preview the 4x4 integer grid (values 0..15); the sliding-window cells below
# rebuild the same grid as a float tensor and use it as the convolution input.
torch.arange(16).reshape(4, 4)

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

In [36]:
%%timeit
kernel = torch.ones(4).reshape(2, 2).float()
x = torch.arange(16).reshape(4, 4).float()

# Unfold the input tensor to extract sliding windows
y = x.unfold(0, 2, 1).unfold(1, 2, 1)

# Reshape the unfolded tensor to prepare for matrix multiplication
y = y.contiguous().view(-1, 2 * 2)
print(y)
# Flatten the kernel
kernel_flat = kernel.view(-1, 1)
print(kernel_flat)
# Perform the matrix multiplication
result = y @ kernel_flat

# Reshape the result to match the expected output shape
output = result.view(3, 3)

print(output)

tensor([[ 0.,  1.,  4.,  5.],
        [ 1.,  2.,  5.,  6.],
        [ 2.,  3.,  6.,  7.],
        [ 4.,  5.,  8.,  9.],
        [ 5.,  6.,  9., 10.],
        [ 6.,  7., 10., 11.],
        [ 8.,  9., 12., 13.],
        [ 9., 10., 13., 14.],
        [10., 11., 14., 15.]])
tensor([[1.],
        [1.],
        [1.],
        [1.]])
tensor([[10., 14., 18.],
        [26., 30., 34.],
        [42., 46., 50.]])
tensor([[ 0.,  1.,  4.,  5.],
        [ 1.,  2.,  5.,  6.],
        [ 2.,  3.,  6.,  7.],
        [ 4.,  5.,  8.,  9.],
        [ 5.,  6.,  9., 10.],
        [ 6.,  7., 10., 11.],
        [ 8.,  9., 12., 13.],
        [ 9., 10., 13., 14.],
        [10., 11., 14., 15.]])
tensor([[1.],
        [1.],
        [1.],
        [1.]])
tensor([[10., 14., 18.],
        [26., 30., 34.],
        [42., 46., 50.]])
tensor([[ 0.,  1.,  4.,  5.],
        [ 1.,  2.,  5.,  6.],
        [ 2.,  3.,  6.,  7.],
        [ 4.,  5.,  8.,  9.],
        [ 5.,  6.,  9., 10.],
        [ 6.,  7., 10., 11.],
        [ 8.

In [90]:
import torch

# Setting seed
torch.manual_seed(42)

# Define the mapping function
def mapping_function(row, nan_limit=3):
    """Neutralise or invalidate the NaN entries of one flattened patch.

    Args:
        row: Tensor holding one patch (this cell passes the rows of the
            unfolded, masked input).
        nan_limit: NaN count at which the patch is considered unusable.
            Defaults to 3, the value that was previously hard-coded.

    Returns:
        A tensor shaped like ``row``: entirely NaN when ``row`` contains
        ``nan_limit`` or more NaNs, otherwise ``row`` with each NaN replaced by 0.
    """
    nan_entries = row.isnan()
    # The comparison is ">=": a row with exactly `nan_limit` NaNs is rejected too.
    if nan_entries.sum() >= nan_limit:
        return torch.full_like(row, torch.nan)
    # zeros_like keeps the fill value on the same dtype and device as `row`.
    return torch.where(nan_entries, torch.zeros_like(row), row)

kernel = torch.ones(4).reshape(2, 2).float().view(-1, 1)
x = torch.arange(16).reshape(4, 4).float()

# Random boolean mask with one entry per patch element (9 patches x 4 values)
mask = torch.randint(0, 2, (9, 4)).bool()

# Unfold the input into its 2x2 sliding windows, one flattened window per row
windows = x.unfold(0, 2, 1).unfold(1, 2, 1)
patches = windows.contiguous().view(-1, 2 * 2)

# Masked positions become NaN
y = patches.masked_fill(mask, torch.nan)

# A row is usable only when it has fewer than 3 NaNs
nan_entries = y.isnan()
valid_rows = nan_entries.sum(1) < 3

# Vectorised form of the per-row rule: NaNs in usable rows become 0,
# unusable rows become NaN everywhere (so their dot product is NaN as well).
result = torch.where(nan_entries, torch.zeros_like(y), y)
result = result.masked_fill(~valid_rows.unsqueeze(1), torch.nan)

convolution = result @ kernel

print(y)
convolution.reshape(3, 3)

tensor([[ 0., nan,  4.,  5.],
        [ 1., nan,  5.,  6.],
        [ 2., nan,  6.,  7.],
        [ 4.,  5., nan,  9.],
        [nan, nan, nan, 10.],
        [nan,  7., nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, 13., 14.],
        [nan, nan, nan, 15.]])


tensor([[ 9., 12., 15.],
        [18., nan, nan],
        [nan, 27., nan]])

In [9]:
import torch
import torch.nn as nn
import torch._dynamo as dynamo
import time

# Setting seed
torch.manual_seed(42)

# Define the mapping function
def mapping_function(row, nan_limit=3):
    """Neutralise or invalidate the NaN entries of one flattened patch.

    Args:
        row: Tensor holding one patch.
        nan_limit: NaN count at which the patch is considered unusable.
            Defaults to 3, the value that was previously hard-coded.

    Returns:
        A tensor shaped like ``row``: entirely NaN when ``row`` contains
        ``nan_limit`` or more NaNs, otherwise ``row`` with each NaN replaced by 0.
    """
    nan_entries = row.isnan()
    # The comparison is ">=": a row with exactly `nan_limit` NaNs is rejected too.
    if nan_entries.sum() >= nan_limit:
        return torch.full_like(row, torch.nan)
    # zeros_like keeps the fill value on the same dtype and device as `row`.
    return torch.where(nan_entries, torch.zeros_like(row), row)

class GeneralizedMaskedConvolution(nn.Module):
    """2-D convolution over unfolded patches that tolerates masked entries.

    Masked patch elements are set to NaN. For every (channel, patch) row of
    ``K_H * K_W`` values, a row with ``nan_limit`` or more NaNs is turned into
    all NaN (so every output that uses it is NaN); in the remaining rows the
    NaNs are replaced by 0. The patches are then multiplied with the flattened
    kernel.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Int or ``(height, width)`` tuple.
        stride: Int or ``(height, width)`` tuple. Defaults to 1.
        padding: Int or ``(height, width)`` tuple of zero padding applied to
            both sides of each spatial dimension. Defaults to 0.
        nan_limit: NaN count at which a patch row is rejected. Defaults to 3,
            the value that was previously hard-coded.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, nan_limit=3):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        self.nan_limit = nan_limit
        # All-ones float32 kernel, laid out like nn.Conv2d weights: (out, in, K_H, K_W).
        self.kernel = nn.Parameter(torch.ones(out_channels, in_channels, *self.kernel_size, dtype=torch.float32))

    def forward(self, x, mask):
        """Apply the masked convolution.

        Args:
            x: Input of shape (N, C, H, W).
            mask: Boolean tensor broadcastable to
                (N, C, num_patches_h * num_patches_w, K_H * K_W); True marks a
                patch element to be treated as missing.

        Returns:
            Tensor of shape (N, out_channels, num_patches_h, num_patches_w).
        """
        pad_h, pad_w = self.padding
        # F.pad takes the padding for the last dimension (width) first.
        x = torch.nn.functional.pad(x, (pad_w, pad_w, pad_h, pad_h))

        # Unfold the input tensor to extract sliding windows
        patches = x.unfold(2, self.kernel_size[0], self.stride[0]).unfold(3, self.kernel_size[1], self.stride[1])
        N, C, num_patches_h, num_patches_w, K_H, K_W = patches.shape
        patches = patches.contiguous().view(N, C, num_patches_h * num_patches_w, K_H * K_W)

        # Add NaNs where the mask is
        patches = patches.masked_fill(mask, torch.nan)

        # Vectorised per-row rule (no Python loop, no data-dependent branching):
        # NaNs become 0, then rows with too many NaNs are overwritten with NaN.
        nan_entries = patches.isnan()
        rejected = nan_entries.sum(dim=-1, keepdim=True) >= self.nan_limit
        result = patches.masked_fill(nan_entries, 0.0).masked_fill(rejected, torch.nan)
        result = result.view(N, C, num_patches_h, num_patches_w, K_H * K_W)

        # Flatten the kernel to (out_channels, in_channels * K_H * K_W)
        kernel = self.kernel.view(self.out_channels, self.in_channels * K_H * K_W)

        # Put channels next to the patch values so each row matches the kernel layout
        patches_reshaped = result.permute(0, 2, 3, 1, 4).reshape(-1, self.in_channels * K_H * K_W)

        # Matrix multiplication to get the output: one row per output location
        output_reshaped = torch.mm(patches_reshaped, kernel.t())

        # Back to (N, out_channels, num_patches_h, num_patches_w)
        output = output_reshaped.view(N, num_patches_h, num_patches_w, self.out_channels)
        return output.permute(0, 3, 1, 2)

tensor([[[[ 21.,  24.,  27.,  21.,  nan,  nan,  nan,  35.,  nan],
          [ 52.,  23.,  nan,  nan,  38.,  67.,  86.,  45.,  nan],
          [ 50.,  nan,  54.,  81.,  nan,  nan,  nan,  nan,  nan],
          [ 72.,  nan,  75.,  78.,  79.,  nan, 166., 170.,  nan],
          [142., 144., 137.,  96.,  89., 102., 103.,  nan, 214.],
          [162., 165., 177., 170., 129., 186.,  nan, 193., 254.],
          [132., 123.,  nan, 127.,  nan,  nan, 144., 222., 216.],
          [232., 153., 237.,  nan, 149.,  nan, 162.,  nan, 256.],
          [261.,  nan, 175., 260.,  nan,  nan, 183., 272.,  nan]]]],
       grad_fn=<PermuteBackward0>)
tensor([[[[ 21.,  24.,  27.,  21.,  nan,  nan,  nan,  35.,  nan],
          [ 52.,  23.,  nan,  nan,  38.,  67.,  86.,  45.,  nan],
          [ 50.,  nan,  54.,  81.,  nan,  nan,  nan,  nan,  nan],
          [ 72.,  nan,  75.,  78.,  79.,  nan, 166., 170.,  nan],
          [142., 144., 137.,  96.,  89., 102., 103.,  nan, 214.],
          [162., 165., 177., 170., 129

In [16]:
import torch
import torch.nn as nn
import torch._dynamo as dynamo
import time

# Setting seed
torch.manual_seed(42)

# Define the mapping function
def mapping_function(row, nan_limit=3):
    """Neutralise or invalidate the NaN entries of one flattened patch.

    Args:
        row: Tensor holding one patch.
        nan_limit: NaN count at which the patch is considered unusable.
            Defaults to 3, the value that was previously hard-coded.

    Returns:
        A tensor shaped like ``row``: entirely NaN when ``row`` contains
        ``nan_limit`` or more NaNs, otherwise ``row`` with each NaN replaced by 0.
    """
    nan_entries = row.isnan()
    # The comparison is ">=": a row with exactly `nan_limit` NaNs is rejected too.
    if nan_entries.sum() >= nan_limit:
        return torch.full_like(row, torch.nan)
    # zeros_like keeps the fill value on the same dtype and device as `row`.
    return torch.where(nan_entries, torch.zeros_like(row), row)

class GeneralizedMaskedConvolutionForward(nn.Module):
    """Masked patch convolution whose ``forward`` is wrapped by TorchDynamo.

    Masked patch elements are set to NaN. For every (channel, patch) row of
    ``K_H * K_W`` values, a row with ``nan_limit`` or more NaNs is turned into
    all NaN (so every output that uses it is NaN); in the remaining rows the
    NaNs are replaced by 0. The patches are then multiplied with the flattened
    kernel.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Int or ``(height, width)`` tuple.
        stride: Int or ``(height, width)`` tuple. Defaults to 1.
        padding: Int or ``(height, width)`` tuple of zero padding applied to
            both sides of each spatial dimension. Defaults to 0.
        nan_limit: NaN count at which a patch row is rejected. Defaults to 3,
            the value that was previously hard-coded.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, nan_limit=3):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        self.nan_limit = nan_limit
        # All-ones float32 kernel, laid out like nn.Conv2d weights: (out, in, K_H, K_W).
        self.kernel = nn.Parameter(torch.ones(out_channels, in_channels, *self.kernel_size, dtype=torch.float32))

    # NOTE(review): torch.compile is the public entry point for this; the
    # decorator is left as written so the benchmark keeps measuring the same thing.
    @dynamo.optimize("inductor")
    def forward(self, x, mask):
        """Apply the masked convolution.

        Args:
            x: Input of shape (N, C, H, W).
            mask: Boolean tensor broadcastable to
                (N, C, num_patches_h * num_patches_w, K_H * K_W); True marks a
                patch element to be treated as missing.

        Returns:
            Tensor of shape (N, out_channels, num_patches_h, num_patches_w).
        """
        pad_h, pad_w = self.padding
        # F.pad takes the padding for the last dimension (width) first.
        x = torch.nn.functional.pad(x, (pad_w, pad_w, pad_h, pad_h))

        # Unfold the input tensor to extract sliding windows
        patches = x.unfold(2, self.kernel_size[0], self.stride[0]).unfold(3, self.kernel_size[1], self.stride[1])
        N, C, num_patches_h, num_patches_w, K_H, K_W = patches.shape
        patches = patches.contiguous().view(N, C, num_patches_h * num_patches_w, K_H * K_W)

        # Add NaNs where the mask is
        patches = patches.masked_fill(mask, torch.nan)

        # Vectorised per-row rule. A Python loop with an `if` on tensor data
        # would be traced row by row and split the graph at every branch.
        nan_entries = patches.isnan()
        rejected = nan_entries.sum(dim=-1, keepdim=True) >= self.nan_limit
        result = patches.masked_fill(nan_entries, 0.0).masked_fill(rejected, torch.nan)
        result = result.view(N, C, num_patches_h, num_patches_w, K_H * K_W)

        # Flatten the kernel to (out_channels, in_channels * K_H * K_W)
        kernel = self.kernel.view(self.out_channels, self.in_channels * K_H * K_W)

        # Put channels next to the patch values so each row matches the kernel layout
        patches_reshaped = result.permute(0, 2, 3, 1, 4).reshape(-1, self.in_channels * K_H * K_W)

        # Matrix multiplication to get the output: one row per output location
        output_reshaped = torch.mm(patches_reshaped, kernel.t())

        # Back to (N, out_channels, num_patches_h, num_patches_w)
        output = output_reshaped.view(N, num_patches_h, num_patches_w, self.out_channels)
        return output.permute(0, 3, 1, 2)

In [20]:
def timed_call(fn, *args):
    """Call ``fn(*args)`` once and return ``(elapsed_seconds, result)``.

    Uses ``time.perf_counter``, a monotonic high-resolution clock, because the
    measured intervals are short.
    """
    start = time.perf_counter()
    result = fn(*args)
    return time.perf_counter() - start, result


x = torch.arange(100).reshape(1, 1, 10, 10).float()
mask = torch.randint(0, 2, (1, 1, 81, 4)).bool()

# Initialize the generalized masked convolution layer
masked_conv = GeneralizedMaskedConvolution(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0)

# Eager (non-compiled) baseline
non_compiled_time_naive, output = timed_call(masked_conv, x, mask)
print(output)

# torch.compile: the first call triggers compilation, so it is timed separately
# from the second, already-compiled call.
compiled_masked_conv = torch.compile(masked_conv)
first_call_time, output = timed_call(compiled_masked_conv, x, mask)
compiled_time, output = timed_call(compiled_masked_conv, x, mask)
print(output)

# Plain nn.Conv2d for reference; the layer is built outside the timed region
# so only the forward pass is measured, like for the other variants.
convolution = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0)
convolution_time, output = timed_call(convolution, x)

# torch._dynamo.optimize with the "eager" backend
dynamo_masked_conv = dynamo.optimize("eager")(masked_conv)
first_call_time_dynamo, output = timed_call(dynamo_masked_conv, x, mask)
compiled_time_dynamo, output = timed_call(dynamo_masked_conv, x, mask)

# Variant whose forward() is decorated with Dynamo
forward_masked_conv = GeneralizedMaskedConvolutionForward(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0)

# The first call includes compilation; the second one is the compiled path.
non_compiled_time, output = timed_call(forward_masked_conv, x, mask)
compiled_forward, output = timed_call(forward_masked_conv, x, mask)

print(f"Non-compiled model time: {non_compiled_time_naive:.6f} seconds")
print(f"Compiled model time: {compiled_time:.6f} seconds")
print(f"Compiled model time dynamo: {compiled_time_dynamo:.6f} seconds")
print(f"Compiled model time forward: {compiled_forward:.6f} seconds")
print(f"Convolution model time: {convolution_time:.6f} seconds")
print(f"First call incl. compilation (torch.compile): {first_call_time:.6f} seconds")
print(f"First call incl. compilation (dynamo eager): {first_call_time_dynamo:.6f} seconds")
print(f"First call incl. compilation (forward decorator): {non_compiled_time:.6f} seconds")



tensor([[[[ nan,  nan,   5.,  30.,  23.,  21.,  39.,  nan,  35.],
          [ nan,  66.,  36.,  47.,  54.,  66.,  59.,  90.,  46.],
          [ 82.,  84., 110.,  67.,  58., 122.,  73.,  nan,  66.],
          [ nan,  63.,  75.,  87., 123., 162.,  93.,  nan,  nan],
          [101.,  83.,  94.,  97., 154.,  nan, 102.,  nan, 108.],
          [111., 226., 168.,  nan,  nan, 177., 113., 192., 117.],
          [132., 205.,  nan, 137., 204., 207., 153., 145., 148.],
          [222., 143.,  nan,  nan,  nan, 322., 249.,  nan, 256.],
          [262., 183., 165., 187., 178., 277., 184.,  nan, 188.]]]],
       grad_fn=<PermuteBackward0>)
tensor([[[[ nan,  nan,   5.,  30.,  23.,  21.,  39.,  nan,  35.],
          [ nan,  66.,  36.,  47.,  54.,  66.,  59.,  90.,  46.],
          [ 82.,  84., 110.,  67.,  58., 122.,  73.,  nan,  66.],
          [ nan,  63.,  75.,  87., 123., 162.,  93.,  nan,  nan],
          [101.,  83.,  94.,  97., 154.,  nan, 102.,  nan, 108.],
          [111., 226., 168.,  nan,  na