In [None]:
import torch
import torch.nn as nn
import torch.fft as fft
import torchvision
import torchvision.models as models
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt

In [None]:
import os
import numpy as np
import cv2
import math
from tqdm import tqdm

In [None]:
# Side length (pixels) every image is resized to before entering the network.
IMG_SIZE = 128

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [None]:
class ImageDataset(Dataset):
    """In-memory image dataset read from a one-folder-per-class directory tree.

    ``image_path`` must contain one sub-directory per class. Every readable
    image inside a class directory is loaded as grayscale, transformed and
    stored on ``device`` at construction time, so ``__getitem__`` is a plain
    list lookup.

    Args:
        image_path: Root directory holding the class sub-directories.
        device: Device the image and label tensors are moved to.

    Each item is a tuple ``(image, label)`` where ``label`` is a float tensor
    of shape ``(1,)`` holding the class index.
    """

    def __init__(self, image_path, device):
        self.image_path = image_path
        self.device = device
        self.data = []
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((IMG_SIZE, IMG_SIZE)),
            v2.Normalize(mean=[0.5,], std=[0.5,]),
            # The network expects 3 input channels.
            v2.Grayscale(num_output_channels=3)
        ])
        self.load_data()

    def load_data(self):
        """Read every image under ``self.image_path`` into ``self.data``."""
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # class -> index mapping identical across runs and across splits.
        # Non-directory entries are ignored so they neither crash the inner
        # listdir nor consume a class index.
        class_names = sorted(
            entry for entry in os.listdir(self.image_path)
            if os.path.isdir(os.path.join(self.image_path, entry))
        )

        for class_idx, label in enumerate(class_names):
            class_dir = os.path.join(self.image_path, label)
            print(class_dir)
            for img_file in tqdm(sorted(os.listdir(class_dir))):
                img = cv2.imread(os.path.join(class_dir, img_file), cv2.IMREAD_GRAYSCALE)
                if img is None:
                    # cv2.imread returns None for files it cannot decode.
                    continue

                img = self.transform(img)

                # A dedicated name is used for the label tensor so the
                # enumerate() counter is never overwritten.
                target = torch.tensor([float(class_idx)])
                self.data.append((img.to(self.device), target.to(self.device)))

    def __len__(self):
        """Number of images that were successfully loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` tuple stored at position ``idx``."""
        return self.data[idx]

In [None]:
batch_size = 16

# Training split: shuffled every epoch.
image_path = os.path.join('data', 'train_set')
train_data = ImageDataset(image_path, device)
train_dl = DataLoader(train_data, batch_size, shuffle=True)

# Evaluation split: the final cell calls test(model, test_dl), so the loader
# has to exist after a fresh "Restart & Run All". It wraps test_data (not
# train_data) and is not shuffled because order is irrelevant for accuracy.
test_image_path = os.path.join('data', 'test_set')
test_data = ImageDataset(test_image_path, device)
test_dl = DataLoader(test_data, batch_size, shuffle=False)

In [None]:
class FFTConvNet(nn.Module):
    """Apply a convolution layer's weights in the frequency domain.

    The input and the (zero-padded) kernel are transformed with a 2-D FFT,
    multiplied per frequency while summing over input channels, and
    transformed back. The output has the same height and width as the input.

    Only ``conv_layer.weight`` and ``conv_layer.bias`` are used: the wrapped
    layer's stride, padding, dilation and groups are not applied. The product
    of spectra is a circular (wrap-around) convolution, so the result is not
    numerically identical to ``conv_layer(x)``.

    Args:
        conv_layer: Layer whose ``weight`` has shape ``(out, in, kh, kw)``.
        fft_filter: ``None`` for no filtering, ``'high'`` for a high-pass
            mask; any other value selects a low-pass mask.
        mask_radius: Radius, in frequency bins from the spectrum centre, of
            the circular mask. A low-pass mask leaves the spectrum unchanged
            whenever ``sqrt((H//2)**2 + (W//2)**2) <= mask_radius``.
    """

    def __init__(self, conv_layer, fft_filter=None, mask_radius=30):
        super().__init__()
        self.conv_layer = conv_layer
        self.fft_filter = fft_filter
        self.mask_radius = mask_radius

    def fft_filter_def(self, fft_x, height, width, device):
        """Multiply a centred spectrum by a circular low- or high-pass mask.

        Args:
            fft_x: Spectrum whose last two dimensions are ``(height, width)``
                with the zero frequency shifted to the centre.
            height: Size of the second-to-last dimension of ``fft_x``.
            width: Size of the last dimension of ``fft_x``.
            device: Device on which the mask is built.

        Returns:
            ``fft_x`` with frequencies outside (low-pass) or inside
            (high-pass) the circle set to zero.
        """
        cht, cwt = height // 2, width // 2

        fy, fx = torch.meshgrid(
            torch.arange(height, device=device, dtype=torch.float32),
            torch.arange(width, device=device, dtype=torch.float32),
            indexing='ij')
        # Euclidean distance of every bin from the spectrum centre.
        distance = torch.sqrt((fx - cwt) ** 2 + (fy - cht) ** 2)

        if self.fft_filter == 'high':
            mask = (distance > self.mask_radius).float()
        else:
            mask = (distance <= self.mask_radius).float()

        return fft_x * mask

    def forward(self, x):
        """Return the frequency-domain convolution of ``x`` with the kernel.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Real tensor of shape ``(batch, out_channels, height, width)``.
        """
        height, width = x.shape[-2:]
        spatial_dims = (-2, -1)

        # fftshift rearranges EVERY dimension unless `dim` is given. The
        # shift is restricted to the spatial axes so the batch and channel
        # axes stay in place and each output row belongs to its own sample.
        fft_x = fft.fftshift(fft.fft2(x), dim=spatial_dims)

        # Zero-pad the kernel to the input size so both spectra line up.
        kernel_fft = fft.fft2(self.conv_layer.weight, s=(height, width))
        kernel_fft = fft.fftshift(kernel_fft, dim=spatial_dims)

        # Optional frequency mask on the input spectrum.
        if self.fft_filter is not None:
            fft_x = self.fft_filter_def(fft_x, height, width, x.device)

        # Per-frequency complex multiplication, summed over input channels.
        fft_output = torch.einsum('bixy,oixy->boxy', fft_x, kernel_fft)

        # Undo the centring and return to the spatial domain.
        fft_output = fft.ifftshift(fft_output, dim=spatial_dims)
        spatial_output = fft.ifft2(fft_output, dim=spatial_dims).real

        # Keep the wrapped layer's bias, if it has one.
        bias = self.conv_layer.bias
        if bias is not None:
            spatial_output = spatial_output + bias.view(1, -1, 1, 1)

        return spatial_output

In [None]:
def switch_conv_layers(model, target_blocks=('Inception3', 'Inception4', 'Inception5'),
                       fft_filter='low', _prefix=''):
    """Replace selected ``nn.Conv2d`` layers with ``FFTConvNet`` wrappers, in place.

    The module tree is walked recursively. A convolution is replaced when its
    dotted path from the root (e.g. ``inception3a.branch1.conv``) contains one
    of ``target_blocks``, compared case-insensitively. Matching on the full
    path is required because the convolution's own attribute name (such as
    ``conv``) does not carry the name of the block it lives in.

    Args:
        model: Root module to modify.
        target_blocks: Substrings identifying the blocks whose convolutions
            should be replaced.
        fft_filter: Filter mode passed to every ``FFTConvNet`` created.
        _prefix: Dotted path of ``model`` inside the root module; used by the
            recursion and not meant to be set by callers.

    Returns:
        The same ``model`` object, modified in place.
    """
    wanted = tuple(block.lower() for block in target_blocks)

    for name, module in model.named_children():
        qualified_name = f'{_prefix}.{name}' if _prefix else name

        if isinstance(module, FFTConvNet):
            # Already wrapped: descending would wrap its inner conv again
            # when the cell is re-run.
            continue

        if isinstance(module, nn.Conv2d):
            if any(block in qualified_name.lower() for block in wanted):
                setattr(model, name, FFTConvNet(module, fft_filter))
        else:
            switch_conv_layers(module, target_blocks, fft_filter, qualified_name)
    return model

In [None]:
# GoogLeNet, trained from scratch (no pretrained weights).
learning_rate = 1e-3
weight_decay = 1e-6

model = models.googlenet(weights=None, init_weights=True)

# Disable the auxiliary branches since we're working with images smaller than 256.
model.aux1, model.aux2 = None, None

# Swap the stock classifier for a small 3-class head that emits log-probabilities.
fc_in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(fc_in_features, 1024),
    nn.LeakyReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(1024, 3),
    nn.LogSoftmax(dim=1),
)

# Swap the selected convolutions for FFT-based ones, then move everything to the device.
model = switch_conv_layers(model).to(device)

In [None]:
# Adam over every model parameter, with L2 regularisation via weight_decay.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# NOTE(review): CrossEntropyLoss applies log-softmax itself and the model head
# already ends in LogSoftmax; the result is unchanged, the extra layer is redundant.
loss_fn = nn.CrossEntropyLoss()

print(optimizer, loss_fn)

In [None]:
# Display the module tree to check which layers the network contains.
model

In [None]:
def train(model, train_dl, loss_fn, optimizer, epochs):
    """Train ``model`` and checkpoint the weights of the most accurate epoch.

    Args:
        model: Network to optimise. Its output may be a tensor of logits or a
            tuple whose first element is the logits.
        train_dl: Iterable of ``(images, labels)`` batches; labels may have
            shape ``(batch,)`` or ``(batch, 1)``.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar loss.
        optimizer: Optimizer bound to ``model``'s parameters.
        epochs: Number of passes over ``train_dl``.

    Raises:
        ValueError: If ``train_dl`` yields no samples.

    Whenever the training accuracy of an epoch exceeds the best seen so far,
    the state dict is written to ``models/fft_google_model_e{epochs}.pth``.
    The file name depends on the total epoch count, so later improvements
    overwrite earlier ones.
    """
    best_accuracy = 0.0

    # torch.save does not create missing directories.
    save_dir = 'models'
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, f'fft_google_model_e{epochs}.pth')

    model.train()
    for epoch in range(epochs):
        print(f'Epoch: {epoch+1}')
        running_loss = 0
        running_corrects = 0
        total_entries = 0

        for images, labels in tqdm(train_dl):
            # Flatten explicitly: squeeze() would turn a batch of one label
            # into a 0-d tensor and break the loss computation.
            labels = labels.reshape(-1).long()
            optimizer.zero_grad()
            outputs = model(images)

            # Some models return a tuple (logits first) in training mode and
            # a bare tensor otherwise; accept both.
            logits = outputs[0] if isinstance(outputs, tuple) else outputs

            loss = loss_fn(logits, labels)
            _, prediction = torch.max(logits, 1)

            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch is averaged correctly.
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum((prediction == labels)).item()
            total_entries += labels.size(0)

        if total_entries == 0:
            raise ValueError('train_dl yielded no samples')

        epoch_loss = running_loss / total_entries
        epoch_accuracy = running_corrects / total_entries
        print(f'Epoch Loss: {epoch_loss:0.4f}\tEpoch Accuracy:{epoch_accuracy:0.4f}')

        if epoch_accuracy > best_accuracy:
            best_accuracy = epoch_accuracy
            torch.save(model.state_dict(), checkpoint_path)

    print('Training Complete.')

In [None]:
# Run 20 training epochs over train_dl.
train(model, train_dl, loss_fn, optimizer, epochs=20)

In [None]:
def test(model, test_dl):
    """Print the classification accuracy of ``model`` on ``test_dl``.

    Args:
        model: Network to evaluate. Its output may be a tensor of logits or a
            tuple whose first element is the logits.
        test_dl: Iterable of ``(images, labels)`` batches; labels may have
            shape ``(batch,)`` or ``(batch, 1)``.

    Raises:
        ValueError: If ``test_dl`` yields no samples.
    """
    running_corrects = 0
    total_entries = 0

    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(test_dl):
            # Labels must be 1-D: comparing (batch,) predictions with
            # (batch, 1) labels would broadcast to a (batch, batch) matrix
            # and inflate the number of "correct" predictions.
            labels = labels.reshape(-1).long()

            outputs = model(images)
            # In eval mode the model may return a bare tensor; indexing it
            # with [0] would select the first sample instead of the logits.
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            _, prediction = torch.max(logits, 1)

            running_corrects += torch.sum((prediction == labels)).item()
            total_entries += labels.size(0)

        if total_entries == 0:
            raise ValueError('test_dl yielded no samples')

        running_accuracy = running_corrects / total_entries
        print(f'Model Accuracy: {running_accuracy:0.6f}')
        print('Testing Complete.')

In [None]:
# Report accuracy on the evaluation loader; test_dl must be defined by an earlier cell.
test(model, test_dl)