In [3]:
import torch
import torch.nn as nn
import torch.fft as fft
from torch.utils.data import Dataset

import torchvision
from torchvision.transforms import v2
from torchvision import models
import matplotlib.pyplot as plt

In [2]:
import os
import math
import numpy
from tqdm import tqdm
import cv2
import lmdb
import pickle

In [None]:
def estimate_lmdb_size(image_path):
    total_size = 0
    for label in os.listdir(image_path):
        label_path = os.path.join(image_path, label)
        if os.path.isdir(label_path):
            for img_file in os.listdir(label_path):
                img_path = os.path.join(label_path, img_file)
                total_size += os.path.getsize(img_path)
    
    buffer_factor = 1.5
    return int(total_size * buffer_factor)

In [None]:
image_path = os.path.join("data", "val_set")
map_size = estimate_lmdb_size(image_path)
print(f"Estimated LMDB size: {map_size / (1024**3):.2f} GB")

In [None]:
def load_image(path):
    transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Resize((128,128))
    ])
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    image = transform(image)
    return image

In [None]:
path = os.path.join('..', 'image.jpg')
image = load_image(path)
fft_image = fft.fft2(image, dim=(-2,-1))
fft_shift = fft.fftshift(fft_image, dim=(-2,-1))
# fft_image = fft.fftshift(fft_image, dim=(-2,-1))
magnitude = torch.log(torch.abs(fft_shift)+ 1)

In [None]:
plt.subplot(1,2,1)
plt.title('Original Image')
plt.imshow(image.permute(1,2,0), cmap='gray')
plt.axis('off')

plt.subplot(1,2,2)
plt.title('Magnitude Spectrum')
plt.imshow(magnitude.permute(1,2,0), cmap='gray')
plt.axis('off')

plt.show()

In [None]:
rows, cols = image.squeeze(0).shape
crow, ccol = rows//2, cols//2
radius = 50

mask = torch.zeros((rows, cols))
y, x = torch.meshgrid(torch.arange(0, rows), torch.arange(0, cols), indexing='ij')
mask_area = torch.sqrt((x - ccol)**2 + (y - crow)**2)

mask_h = (mask_area > radius).float()
mask_l = (mask_area <= radius).float()

In [None]:
high_filtered_fft = fft_shift * mask_h
high_filtered_image = torch.abs(fft.ifftshift(high_filtered_fft))
magnitude_high_filtered_image = torch.log(torch.abs(high_filtered_image)+ 1)

low_filtered_fft = fft_shift * mask_l
low_filtered_image = torch.abs(fft.ifftshift(low_filtered_fft))
magnitude_low_filtered_image = torch.log(torch.abs(low_filtered_image)+ 1)

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(15, 5))
axes = axes.flatten()

axes[0].set_title('Original')
axes[0].imshow(image.permute(1,2,0), cmap='gray')
axes[0].axis('off')

axes[1].set_title('FFT')
axes[1].imshow(magnitude.permute(1,2,0), cmap='gray')
axes[1].axis('off')

axes[2].set_title('Low Pass Filter')
axes[2].imshow(magnitude_low_filtered_image.permute(1,2,0), cmap='gray')
axes[2].axis('off')

axes[3].set_title('High Pass Filter')
axes[3].imshow(magnitude_high_filtered_image.permute(1,2,0), cmap='gray')
axes[3].axis('off')

plt.show()

In [None]:
class FFTConvNet(nn.Module):
    def __init__(self, conv_layer, fft_filter=None):
        super().__init__()
        self.conv_layer = conv_layer
        self.fft_filter = fft_filter

    def fft_filter_def(self, fft_x, height, width):
        cht, cwt = height // 2, width // 2
        mask_radius = 30

        # Create a meshgrid for the mask
        fy, fx = torch.meshgrid(
            torch.arange(0, height, device=fft_x.device),
            torch.arange(0, width, device=fft_x.device),
            indexing='ij'
        )
        mask_area = torch.sqrt((fx - cwt) ** 2 + (fy - cht) ** 2)

        # Create the mask based on the filter type
        if self.fft_filter == 'high':
            mask = (mask_area > mask_radius).float()
        else:
            mask = (mask_area <= mask_radius).float()

        # Apply the mask to the FFT of the input
        filtered_fft = fft_x * mask
        return filtered_fft

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        # Apply FFT on input image
        fft_x = fft.fft2(x)
        fft_x = fft.fftshift(fft_x)

        # Apply FFT on the convolutional kernel
        kernel_fft = fft.fft2(self.conv_layer.weight, s=(height, width))
        kernel_fft = fft.fftshift(kernel_fft)

        # Apply FFT filter (low-pass or high-pass)
        if self.fft_filter is not None:
            fft_x = self.fft_filter_def(fft_x, height, width)

        # Perform element-wise complex multiplication
        fft_output = fft_x * kernel_fft

        # Apply inverse FFT
        fft_output = fft.ifftshift(fft_output, dim=(-2, -1))
        spatial_output = fft.ifft2(fft_output, dim=(-2, -1)).real

        # Ensure the output has the same number of channels as the original convolution
        spatial_output = spatial_output[:, :self.conv_layer.out_channels, :, :]

        # Add bias (if applicable)
        if self.conv_layer.bias is not None:
            spatial_output += self.conv_layer.bias.view(1, -1, 1, 1)

        # Return output
        return spatial_output

In [None]:
def switch_conv_layers(model):
    for name, module in model.named_children():
        if name.startswith('inception') and isinstance(module, nn.Conv2d):
            print(name, module)
            # fft_conv = FFTConvNet(module, 'low')
            # setattr(model, name, fft_conv)
        elif isinstance(module, nn.Sequential) or isinstance(module, nn.Module):
            switch_conv_layers(module)
    return model

In [None]:
model = models.googlenet(weights='GoogLeNet_Weights.DEFAULT')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
def change_layer(layer):
    fft_conv = FFTConvNet(layer, 'low')
    return fft_conv

In [None]:
# Check the output channels of the original Conv2d layer
original_conv = model.inception3a.branch2[0].conv
print("Original Conv2d output channels:", original_conv.out_channels)

# Replace the Conv2d layer with FFTConvNet
fft_conv = change_layer(original_conv).to(device)
model.inception3a.branch2[0].conv = fft_conv

# Verify the output channels of the FFTConvNet layer
print("FFTConvNet output channels:", fft_conv.conv_layer.out_channels)

In [33]:
model = models.alexnet(weights=None)

In [35]:
model.features

Sequential(
  (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)

In [31]:
model.features = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False),
    nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False),
    nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(inplace=True),
    nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)

model.classifier = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(9216, out_features=1024),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=1024, out_features=512),
    nn.LeakyReLU(),
    nn.Linear(in_features=512, out_features=3)
)

In [32]:
model

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=1024, bias=True)
   

In [24]:
dummy_input = torch.randn(1, 3, 227, 227)
outputs = model(dummy_input)
print("Outputs: ", outputs)

Outputs:  tensor([[-0.0209,  0.0312, -0.0249]], grad_fn=<AddmmBackward0>)
