In [None]:
import torch
import torch.nn as nn
import torch.fft as fft
import torchvision
from torchvision.transforms import v2
import matplotlib.pyplot as plt

In [None]:
import os
import math
import numpy
from tqdm import tqdm
import cv2
import scipy

In [None]:
def load_image(path):
    """Load an image file as a normalized, 3-channel grayscale tensor.

    The file is read with OpenCV as a single-channel grayscale array,
    converted to a float32 tensor scaled to [0, 1], resized to 256x256,
    normalized with mean 0.5 / std 0.5 and finally replicated to three
    channels.

    Args:
        path: Path of the image file to read.

    Returns:
        The transformed image tensor (three channels, 256x256).

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface later as an obscure error inside the transform.
    if image is None:
        raise FileNotFoundError(f'Could not read an image from {path!r}')

    transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Resize((256, 256)),
        v2.Normalize(mean=[0.5,], std=[0.5,]),
        v2.Grayscale(num_output_channels=3)
    ])
    return transform(image)

In [None]:
# Load the sample image and compute its centred 2-D spectrum.
path = 'image.jpg'
image = load_image(path)
fft_image = fft.fft2(image)
# Shift only the two spatial axes: without `dim`, fftshift rolls every axis,
# including the channel axis.
fft_image = fft.fftshift(fft_image, dim=(-2, -1))
# Log-compress the magnitude for display; the epsilon avoids log(0).
magnitude = torch.log(torch.abs(fft_image) + 1e-10)

In [None]:
# Inspect the shape of the tensor returned by load_image.
image.shape

In [None]:
# Both tensors are (C, H, W) with identical grayscale channels. Plot a single
# 2-D channel: for (H, W, 3) float input imshow ignores `cmap` and clips values
# to [0, 1], which distorts the [-1, 1] image and saturates the log spectrum.
plt.subplot(1, 2, 1)
plt.title('Original Image')
plt.imshow(image[0], cmap='gray')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title('Magnitude Spectrum')
plt.imshow(magnitude[0], cmap='gray')
plt.axis('off')

plt.show()

In [None]:
# Shape check before unpacking (channel, rows, cols) in the next cell.
image.shape

In [None]:
# Circular masks centred on the middle of the (shifted) spectrum:
# mask_l keeps everything within `radius` of the centre, mask_h keeps the rest.
channel, rows, cols = image.shape
crow = rows // 2
ccol = cols // 2
radius = 50

mask = torch.zeros((rows, cols))
y, x = torch.meshgrid(torch.arange(rows), torch.arange(cols), indexing='ij')
# Euclidean distance of every pixel from the centre (crow, ccol).
mask_area = ((x - ccol).pow(2) + (y - crow).pow(2)).sqrt()

mask_l = (mask_area <= radius).float()
mask_h = (mask_area > radius).float()

In [None]:
# Mask the centred spectrum, undo the shift on the spatial axes and apply the
# inverse FFT so the `*_image` tensors hold spatial-domain data (previously the
# inverse transform was missing and they still contained spectrum magnitudes).
high_filtered_fft = fft_image * mask_h
high_filtered_image = torch.abs(
    fft.ifft2(fft.ifftshift(high_filtered_fft, dim=(-2, -1)))
)
magnitude_high_filtered_image = torch.log(high_filtered_image + 1)

low_filtered_fft = fft_image * mask_l
low_filtered_image = torch.abs(
    fft.ifft2(fft.ifftshift(low_filtered_fft, dim=(-2, -1)))
)
magnitude_low_filtered_image = torch.log(low_filtered_image + 1)

In [None]:
# Side-by-side comparison. Each tensor is (C, H, W) with identical channels, so
# a single 2-D channel is drawn: for (H, W, 3) float input imshow ignores
# `cmap` and clips values to [0, 1].
panels = [
    ('Original', image),
    ('FFT', magnitude),
    ('Low Pass Filter', magnitude_low_filtered_image),
    ('High Pass Filter', magnitude_high_filtered_image),
]

fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
axes = axes.flatten()

for ax, (title, tensor) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(tensor[0], cmap='gray')
    ax.axis('off')

plt.show()

In [None]:
import torchvision.models as models

In [None]:
# GoogLeNet without pretrained weights (weights=None), freshly initialised.
google = models.googlenet(weights=None, init_weights=True)

In [None]:
# Prepend a batch axis of size 1: (C, H, W) -> (1, C, H, W).
# NOTE: this rebinds `image`, so the earlier plotting cells (which expect a
# 3-D tensor) will fail if re-run after this cell.
image = image.expand(1,-1,-1,-1)

In [None]:
# Display the module tree of the unmodified GoogLeNet.
google

In [None]:
class FFTConvNet(nn.Module):
    """Frequency-domain stand-in for an ``nn.Conv2d`` layer.

    The input and the wrapped layer's kernels are transformed with a 2-D FFT,
    multiplied per frequency bin (summing over input channels) and transformed
    back. The kernels are zero-padded to the input's height and width, so the
    output keeps the input's spatial size; the wrapped layer's ``stride``,
    ``padding``, ``dilation`` and ``groups`` are not taken into account.

    Args:
        conv_layer: Convolution whose ``weight`` (and ``bias``, if present)
            is used.
        fft_filter: ``None`` disables filtering, ``'high'`` applies a
            high-pass mask; any other value applies the low-pass mask.
        mask_radius: Radius of the circular mask, measured in frequency bins
            from the centre of the shifted spectrum.
        verbose: When true, ``forward`` prints intermediate tensor shapes.
    """

    # fft2/ifft2 act on the last two axes, so only those are (un)shifted.
    # Shifting every axis would also roll the batch and channel axes.
    _SPATIAL_DIMS = (-2, -1)

    def __init__(self, conv_layer, fft_filter=None, mask_radius=30, verbose=False):
        super().__init__()
        self.conv_layer = conv_layer
        self.fft_filter = fft_filter
        self.mask_radius = mask_radius
        self.verbose = verbose

    def _debug(self, message):
        """Print ``message`` only when verbose output was requested."""
        if self.verbose:
            print(message)

    def fft_filter_def(self, fft_x, height, width, device):
        """Apply the configured circular mask to a centred spectrum.

        Args:
            fft_x: Spectrum whose last two axes are ``(height, width)`` and
                whose zero frequency sits at ``(height // 2, width // 2)``.
            height: Size of the second-to-last axis of ``fft_x``.
            width: Size of the last axis of ``fft_x``.
            device: Device on which the mask is built.

        Returns:
            ``fft_x`` multiplied by the high-pass mask when ``fft_filter`` is
            ``'high'``, otherwise by the low-pass mask.
        """
        cht, cwt = height // 2, width // 2

        fy, fx = torch.meshgrid(torch.arange(0, height, device=device),
                                torch.arange(0, width, device=device),
                                indexing='ij')
        # Distance of every frequency bin from the spectrum centre.
        mask_area = torch.sqrt((fx - cwt)**2 + (fy - cht)**2)

        if self.fft_filter == 'high':
            mask = (mask_area > self.mask_radius).float()
        else:
            mask = (mask_area <= self.mask_radius).float()

        return fft_x * mask

    def forward(self, x):
        """Convolve ``x`` with the wrapped kernels via the frequency domain.

        Args:
            x: Input of shape ``(batch, in_channels, height, width)``.

        Returns:
            Real tensor of shape ``(batch, out_channels, height, width)``.
        """
        _, _, height, width = x.size()

        # Centred spectrum of the input.
        fft_x = fft.fftshift(fft.fft2(x), dim=self._SPATIAL_DIMS)
        self._debug(f'1. fft_x: {fft_x.shape}')

        # Kernels are zero-padded to the input size before the transform.
        kernel_fft = fft.fft2(self.conv_layer.weight, s=(height, width))
        kernel_fft = fft.fftshift(kernel_fft, dim=self._SPATIAL_DIMS)
        self._debug(f'2. kernel_fft: {kernel_fft.shape}')

        self._debug(f'3. fft_x: {fft_x.shape}')
        self._debug(f'4. kernel_fft: {kernel_fft.shape}')

        # Optional low-/high-pass filtering of the input spectrum.
        if self.fft_filter is not None:
            fft_x = self.fft_filter_def(fft_x, height, width, x.device)

        self._debug(f'5. fft_x: {fft_x.shape}')

        # Per-frequency product, summed over input channels:
        # (b, i, x, y) * (o, i, x, y) -> (b, o, x, y)
        fft_output = torch.einsum('bixy,oixy->boxy', fft_x, kernel_fft)
        self._debug(f'6. fft_output: {fft_output.shape}')

        # Undo the centring and return to the spatial domain.
        fft_output = fft.ifftshift(fft_output, dim=self._SPATIAL_DIMS)
        spatial_output = fft.ifft2(fft_output).real

        # Keep the wrapped layer's bias instead of silently dropping it.
        bias = getattr(self.conv_layer, 'bias', None)
        if bias is not None:
            spatial_output = spatial_output + bias.view(1, -1, 1, 1)
        self._debug(f'7. spatial_output: {spatial_output.shape}')

        return spatial_output

In [None]:
class ConvNet(nn.Module):
    """Small classifier: conv -> dropout -> LeakyReLU -> flatten -> two linear layers.

    Args:
        num_classes: Number of output units. With a single unit the output is
            ``logsigmoid`` of the logit; with more it is ``log_softmax`` over
            the class axis.
        verbose: When true, ``forward`` prints the shape after every stage.
    """

    def __init__(self, num_classes=1, verbose=False):
        super().__init__()
        self.verbose = verbose

        self.layer1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.act1 = nn.LeakyReLU()
        self.dropout = nn.Dropout(p=0.2)

        self.f_layer = nn.Flatten()
        # The flattened size depends on the input resolution and on whatever
        # module ends up in `layer1`, so it is inferred on the first forward
        # pass instead of being hard-coded.
        self.fc1 = nn.LazyLinear(50)
        self.fc2 = nn.Linear(50, num_classes)

    def _debug(self, message):
        """Print ``message`` only when verbose output was requested."""
        if self.verbose:
            print(message)

    def forward(self, inp):
        """Return log-probabilities of shape ``(batch, num_classes)``."""
        x = self.layer1(inp)
        self._debug(f'Conv: {x.shape}')
        x = self.dropout(x)
        self._debug(f'Drop: {x.shape}')
        x = self.act1(x)
        self._debug(f'Act: {x.shape}')
        x = self.f_layer(x)
        self._debug(f'Flat: {x.shape}')

        x = self.fc1(x)
        self._debug(f'FC1: {x.shape}')
        x = self.fc2(x)
        self._debug(f'FC2: {x.shape}')

        if x.shape[1] == 1:
            # log_softmax over a single logit is identically 0; for a
            # one-unit head the log-probability is logsigmoid of the logit.
            x = nn.functional.logsigmoid(x)
        else:
            x = nn.functional.log_softmax(x, dim=1)
        self._debug(f'OP: {x.shape}')

        return x

In [None]:
# Instantiate the small reference network.
model = ConvNet()

In [None]:
# Swap every direct nn.Conv2d child of `model` for a low-pass FFTConvNet
# wrapper around that same layer.
conv_children = [
    (child_name, child)
    for child_name, child in model.named_children()
    if isinstance(child, nn.Conv2d)
]
for child_name, child in conv_children:
    setattr(model, child_name, FFTConvNet(child, 'low'))

In [None]:
# Display the module tree after the Conv2d replacement.
model

In [None]:
# Forward the single-image batch through the modified model.
output = model(image)

In [None]:
# Show the model output.
output

In [None]:
def search_conv(model, fft_filter='low'):
    """Recursively replace every ``nn.Conv2d`` in ``model`` with ``FFTConvNet``.

    The model is modified in place. Layers that are already wrapped are left
    alone, so calling this function repeatedly on the same model is safe.

    Args:
        model: Module whose sub-modules are searched.
        fft_filter: Filter mode passed to each ``FFTConvNet`` wrapper.

    Returns:
        The same ``model`` object, for convenience.
    """
    # A wrapper's own `conv_layer` must stay a plain Conv2d.
    if isinstance(model, FFTConvNet):
        return model

    # Snapshot the children: attributes are reassigned while walking them.
    for name, module in list(model.named_children()):
        if isinstance(module, FFTConvNet):
            # Already converted; descending would wrap its inner conv again
            # and break FFTConvNet.forward, which reads `conv_layer.weight`.
            continue
        if isinstance(module, nn.Conv2d):
            setattr(model, name, FFTConvNet(module, fft_filter))
        else:
            search_conv(module, fft_filter)
    return model

In [None]:
# search_conv modifies `google` in place and returns it, so `test_model`
# refers to the same object as `google`.
test_model = search_conv(google)
test_model

In [None]:
# Forward the single-image batch through the converted GoogLeNet.
google_op = google(image)