In [1]:
import torch
import torch.nn as nn
import torch.fft as fft
import torchvision
import torchvision.models as models
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt

In [2]:
import os
import numpy as np
import cv2
import math
from tqdm import tqdm

In [3]:
IMG_SIZE = 128
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
class ImageDataset(Dataset):
    def __init__(self, image_path, device):
        self.image_path = image_path
        self.device = device
        self.data = []
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((IMG_SIZE, IMG_SIZE)),
            v2.Normalize(mean=[0.5,], std=[0.5,]),
            v2.Grayscale(num_output_channels=3)
        ])
        self.load_data()

    def load_data(self):
        for idx, label in enumerate(os.listdir(os.path.join(self.image_path))):
            print(os.path.join(self.image_path, label))
            for img_file in tqdm(os.listdir(os.path.join(self.image_path, label))):
                img = cv2.imread(os.path.join(self.image_path, label, img_file), cv2.IMREAD_GRAYSCALE)

                img = self.transform(img)

                idx = torch.Tensor([idx])
                self.data.append((img.to(device), idx.to(device)))
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [5]:
batch_size = 24

image_path = os.path.join('data', 'train_set')
train_data = ImageDataset(image_path, device)
train_dl = DataLoader(train_data, batch_size, shuffle=True)
# image_path = os.path.join('data', 'test_set')
# test_data = ImageDataset(image_path, device)

data\train_set\bacterial


100%|██████████████████████████████████████████████████████████████████████████████| 3001/3001 [01:02<00:00, 47.70it/s]


data\train_set\normal


100%|██████████████████████████████████████████████████████████████████████████████| 3270/3270 [01:28<00:00, 36.90it/s]


data\train_set\viral


100%|██████████████████████████████████████████████████████████████████████████████| 1656/1656 [00:32<00:00, 51.56it/s]


In [6]:
class FFTConvNet(nn.Module):
    def __init__(self, conv_layer, fft_filter=None):
        super().__init__()
        self.conv_layer = conv_layer
        self.fft_filter = fft_filter

    def fft_filter_def(self, fft_x, height, width, device):
        cht, cwt = height//2, width//2
        mask_radius = 30
        
        mask = torch.zeros((height, width))
        fy, fx = torch.meshgrid(torch.arange(0, height, device=device),
                                torch.arange(0, width, device=device),
                                indexing='ij')
        mask_area = torch.sqrt((fx - cwt)**2 + (fy - cht)**2)

        if self.fft_filter == 'high':
            mask = (mask_area > mask_radius).float()
        else:
            mask = (mask_area <= mask_radius).float()
        filtered_fft = fft_x * mask
        
        return filtered_fft

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        
        # apply fft on input image
        fft_x = fft.fft2(x)
        fft_x = fft.fftshift(fft_x)

        kernel_fft = fft.fft2(self.conv_layer.weight, s=(height, width))
        kernel_fft = fft.fftshift(kernel_fft)

        # apply fft filter (low pass)
        if self.fft_filter is not None:
            fft_x = self.fft_filter_def(fft_x, height, width, x.device)
        
        # perform element wise complex multiplcation
        # fft_output = torch.sum(fft_x * kernel_fft, dim=2)

        # trying einstein summation notation for implementing element wise complex multiplcation instead
        fft_output = torch.einsum('bixy,oixy->boxy', fft_x, kernel_fft)
        
        # apply inverse fft
        fft_output = fft.ifftshift(fft_output, dim=(-2,-1))
        spatial_output = fft.ifft2(fft_output, dim=(-2,-1)).real
        
        # return output
        return spatial_output

In [11]:
def change_layer(layer):
    fft_conv = FFTConvNet(layer, 'low')
    return fft_conv

In [12]:
# googlenet
learning_rate = 0.001
weight_decay = 0.000001

model = models.googlenet(weights=None, init_weights=True)

# disable auxillary branches since we're working with size of images lower than 256
model.aux1 = None
model.aux2 = None

model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 1024),
    nn.LeakyReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(1024, 3),
    nn.LogSoftmax(dim=1)
)

for name, module in model.named_modules():
    if name.startswith('inception') and isinstance(module, nn.Conv2d):
        fft_conv = change_layer(module)
        
        parent_name, attr_name = name.rsplit(".", 1)
        parent_module = dict(model.named_modules())[parent_name]
        
        setattr(parent_module, attr_name, fft_conv)

model = model.to(device)

In [13]:
model.inception3a

Inception(
  (branch1): BasicConv2d(
    (conv): FFTConvNet(
      (conv_layer): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch2): Sequential(
    (0): BasicConv2d(
      (conv): FFTConvNet(
        (conv_layer): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicConv2d(
      (conv): FFTConvNet(
        (conv_layer): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (branch3): Sequential(
    (0): BasicConv2d(
      (conv): FFTConvNet(
        (conv_layer): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, 

In [14]:
optimizer = optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)
loss_fn = nn.CrossEntropyLoss()

print(optimizer, loss_fn)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 1e-06
) CrossEntropyLoss()


In [15]:
def train(model, train_dl, loss_fn, optimizer, epochs):
    best_accuracy = 0.0

    model.train()
    for epoch in range(epochs):
        print(f'Epoch: {epoch+1}')
        running_loss = 0
        running_corrects = 0
        total_entries = 0

        for images, labels in tqdm(train_dl):
            labels = labels.squeeze().long()
            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(outputs[0], labels)
            _, prediction = torch.max(outputs[0], 1)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum((prediction == labels)).item()
            total_entries += labels.size(0)

        epoch_loss = running_loss / total_entries
        epoch_accuracy = running_corrects / total_entries
        print(f'Epoch Loss: {epoch_loss:0.4f}\tEpoch Accuracy:{epoch_accuracy:0.4f}')

        if epoch_accuracy > best_accuracy:
            best_accuracy = epoch_accuracy
            torch.save(model.state_dict(), 'fft_google_model_e1.pth')

    print('Training Complete.')

In [16]:
train(model, train_dl, loss_fn, optimizer, epochs=1)

Epoch: 1


100%|████████████████████████████████████████████████████████████████████████████████| 331/331 [05:31<00:00,  1.00s/it]

Epoch Loss: 0.6840	Epoch Accuracy:0.7010
Training Complete.





In [None]:
def test(model, test_dl):
    corrects = 0
    total_entries = 0
    model_accuracy = 0
    
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(test_dl):
            labels = labels.squeeze().long()
            outputs = model(images)
            _, predictions = torch.max(outputs, 1)

            corrects += torch.sum(predictions == labels).item()
            total_entries += labels.size(0)
        model_accuracy = corrects / total_entries
        print(f'Model Accuracy: {model_accuracy}')
        
    print('Testing Complete')

In [None]:
test(model, test_dl)