In [1]:
import torch
import torch.nn as nn
import torch.fft as fft
import torchvision
import torchvision.models as models
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [2]:
import os
import numpy as np
import cv2
import math
from tqdm import tqdm

In [3]:
# Side length, in pixels, of the square images used throughout the notebook.
IMG_SIZE = 128

# Prefer the first CUDA device when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [4]:
class ImageDataset(Dataset):
    """Folder-per-class image dataset that is fully decoded and cached at construction.

    ``image_path`` must contain one sub-directory per class. Every file inside a
    class directory is read as a grayscale image, resized to
    ``IMG_SIZE x IMG_SIZE``, passed through ``self.transform`` and stored
    together with its class index, both already moved to ``device``.

    Class indices are assigned from the alphabetically sorted directory names,
    so two datasets built over trees with the same class folders (for example
    a train and a test split) always agree on the label mapping.

    Args:
        image_path: Root directory holding the class sub-directories.
        device: Device the cached image and label tensors are moved to.

    Raises:
        ValueError: If a file inside a class directory cannot be decoded.
    """

    def __init__(self, image_path, device):
        self.image_path = image_path
        self.device = device
        # List of (image_tensor, label_tensor) pairs, filled by load_data().
        self.data = []
        # NOTE(review): images are read as single-channel grayscale but are
        # normalised with three-channel statistics (the values commonly used
        # for ImageNet); confirm the resulting channel count is what the model
        # expects.
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.load_data()

    def load_data(self):
        """Decode every image under ``self.image_path`` and append it to ``self.data``."""
        # os.listdir() gives no ordering guarantee; sorting pins the
        # class-name -> index mapping. Stray files at the root are ignored so
        # they neither crash the load nor consume a class index.
        class_names = sorted(
            entry for entry in os.listdir(self.image_path)
            if os.path.isdir(os.path.join(self.image_path, entry))
        )

        for class_idx, class_name in enumerate(class_names):
            class_dir = os.path.join(self.image_path, class_name)
            print(class_dir)
            for img_file in tqdm(os.listdir(class_dir)):
                img_file_path = os.path.join(class_dir, img_file)
                img_gray = cv2.imread(img_file_path, cv2.IMREAD_GRAYSCALE)
                if img_gray is None:
                    # cv2.imread signals failure by returning None; fail with
                    # the offending path instead of an opaque resize error.
                    raise ValueError(f'Could not read image file: {img_file_path}')
                img = cv2.resize(img_gray, (IMG_SIZE, IMG_SIZE))
                img = self.transform(img)

                # Kept as a float tensor of shape (1,): the training and test
                # loops flatten it and cast it to long themselves.
                label = torch.Tensor([class_idx])
                self.data.append((img.to(self.device), label.to(self.device)))

    def __len__(self):
        """Return the number of cached samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the cached ``(image, label)`` pair at position ``idx``."""
        return self.data[idx]

In [5]:
# Mini-batch size shared by the training and test loaders.
batch_size = 8

# Training split: cached in memory once, reshuffled by the loader.
image_path = os.path.join('data', 'train_set')
train_data = ImageDataset(image_path=image_path, device=device)
train_dl = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

# Test split: same pipeline, iterated in a fixed order.
image_path = os.path.join('data', 'test_set')
test_data = ImageDataset(image_path=image_path, device=device)
test_dl = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

data\train_set\bacterial


100%|██████████████████████████████████████████████████████████████████████████████| 3001/3001 [00:50<00:00, 58.93it/s]


data\train_set\normal


100%|██████████████████████████████████████████████████████████████████████████████| 3270/3270 [01:19<00:00, 41.36it/s]


data\train_set\viral


100%|██████████████████████████████████████████████████████████████████████████████| 1656/1656 [00:30<00:00, 55.11it/s]


data\test_set\bacterial


100%|███████████████████████████████████████████████████████████████████████████████| 242/242 [00:01<00:00, 186.81it/s]


data\test_set\normal


100%|████████████████████████████████████████████████████████████████████████████████| 234/234 [00:05<00:00, 46.31it/s]


data\test_set\viral


100%|███████████████████████████████████████████████████████████████████████████████| 148/148 [00:01<00:00, 136.63it/s]


In [6]:
class FFTConvNet(nn.Module):
    """Frequency-domain stand-in for an ``nn.Conv2d`` layer.

    The input and the wrapped layer's kernel (zero-padded to the input size)
    are transformed with a 2-D FFT, optionally masked with a circular low- or
    high-pass filter, multiplied point-wise while summing over input channels,
    and transformed back to the spatial domain.

    Limitations: the point-wise spectral product is a *circular* convolution.
    The wrapped layer's stride, padding, dilation and groups are not applied,
    so the output always has the same height and width as the input.

    Args:
        conv_layer: The ``nn.Conv2d`` whose ``weight`` (and ``bias``, when
            present) is used. It stays registered as a sub-module, so its
            parameters keep training.
        fft_filter: ``None`` for no filtering, ``'high'`` for a high-pass
            mask, any other value for a low-pass mask.
        mask_radius: Radius, in frequency bins measured from the centred DC
            component, of the circular mask.
    """

    def __init__(self, conv_layer, fft_filter=None, mask_radius=30):
        super().__init__()
        self.conv_layer = conv_layer
        # Stored on the instance: forward() and fft_filter_def() read it back.
        self.fft_filter = fft_filter
        self.mask_radius = mask_radius

    def fft_filter_def(self, fft_x, height, width):
        """Apply the circular frequency mask to a centred (fft-shifted) spectrum.

        Args:
            fft_x: Complex spectrum whose last two dimensions are
                ``(height, width)`` with the DC component at the centre.
            height: Size of the second-to-last dimension of ``fft_x``.
            width: Size of the last dimension of ``fft_x``.

        Returns:
            ``fft_x`` with the masked-out frequencies set to zero.
        """
        cht, cwt = height // 2, width // 2

        # Build the grid on the spectrum's own device so the layer works
        # wherever the model has been moved.
        fy, fx = torch.meshgrid(
            torch.arange(0, height, device=fft_x.device),
            torch.arange(0, width, device=fft_x.device),
            indexing='ij',
        )
        # Distance of every frequency bin from the centred DC bin.
        mask_area = torch.sqrt(((fx - cwt) ** 2 + (fy - cht) ** 2).float())

        if self.fft_filter == 'high':
            mask = (mask_area > self.mask_radius).float()
        else:
            mask = (mask_area <= self.mask_radius).float()

        return fft_x * mask

    def forward(self, x):
        """Convolve ``x`` with the wrapped kernel in the frequency domain.

        Args:
            x: Real tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Real tensor of shape ``(batch, out_channels, height, width)``.
        """
        height, width = x.shape[-2:]
        # Shift only the spatial axes; the default would also roll the batch
        # and channel axes.
        spatial_dims = (-2, -1)

        # apply fft on input image
        fft_x = fft.fftshift(fft.fft2(x), dim=spatial_dims)

        # zero-pad the kernel to the input size so both spectra line up
        kernel_fft = fft.fft2(self.conv_layer.weight, s=(height, width))
        kernel_fft = fft.fftshift(kernel_fft, dim=spatial_dims)

        # apply the optional low-/high-pass filter
        if self.fft_filter is not None:
            fft_x = self.fft_filter_def(fft_x, height, width)

        # element-wise complex multiplication, summed over the input channels
        fft_output = torch.einsum('bixy,oixy->boxy', fft_x, kernel_fft)

        # apply inverse fft
        fft_output = fft.ifftshift(fft_output, dim=spatial_dims)
        # Real input, real kernel and a mask symmetric around DC give a real
        # result; .real drops the numerical-noise imaginary part so that
        # downstream real-valued layers can consume the output.
        spatial_output = fft.ifft2(fft_output).real

        bias = self.conv_layer.bias
        if bias is not None:
            spatial_output = spatial_output + bias.view(1, -1, 1, 1)

        return spatial_output

In [7]:
def switch_conv_layers(model):
    """Recursively swap selected ``nn.Conv2d`` children for ``FFTConvNet`` wrappers.

    A direct ``nn.Conv2d`` child is replaced in place (with a ``'low'`` filter)
    when its attribute name contains "Inception3", "Inception4" or
    "Inception5". Every other child is descended into. The same ``model``
    object is returned.

    NOTE(review): ``named_children()`` yields only the immediate attribute
    name, not the dotted path, and the match is case-sensitive. Verify that
    any layer of the target network is actually replaced.
    """
    name_tags = ("Inception3", "Inception4", "Inception5")

    for child_name, child in model.named_children():
        if not isinstance(child, nn.Conv2d):
            # Every remaining child is an nn.Module, so always recurse.
            switch_conv_layers(child)
            continue

        if any(tag in child_name for tag in name_tags):
            setattr(model, child_name, FFTConvNet(child, 'low'))

    return model

In [8]:
# GoogLeNet backbone, randomly initialised (no pretrained weights).
model = models.googlenet(weights=None, init_weights=True)

# Disable the auxiliary branches since we're working with images smaller than 256.
model.aux1, model.aux2 = None, None

# Replace the stock classifier with a small MLP head that ends in
# log-probabilities over the 3 classes.
# NOTE(review): the notebook later pairs this head with nn.CrossEntropyLoss,
# which applies log-softmax itself; confirm the LogSoftmax here is intended.
head_layers = [
    nn.Linear(model.fc.in_features, 1024),
    nn.LeakyReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(1024, 3),
    nn.LogSoftmax(dim=1),
]
model.fc = nn.Sequential(*head_layers)

# Switch nn.Conv2d layers to the custom FFTConvNet layers, then move the
# whole network to the selected device.
model = switch_conv_layers(model).to(device)

In [9]:
# Adam hyper-parameters.
learning_rate = 1e-3
weight_decay = 1e-6

optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
loss_fn = nn.CrossEntropyLoss()

# Echo the configuration so it is recorded next to the results.
print(optimizer, loss_fn)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 1e-06
) CrossEntropyLoss()


In [12]:
def train(model, train_dl, loss_fn, optimizer, epochs):
    """Train ``model`` for ``epochs`` passes over ``train_dl``.

    After every epoch the mean loss and the training accuracy are printed.
    Whenever the training accuracy improves on the best value seen so far, the
    model's ``state_dict`` is written to
    ``models/fft_google_model_e{epochs}.pth``, overwriting the previous best.

    Args:
        model: Network to optimise. Its forward pass may return either a
            logits tensor or a tuple whose first element is the logits.
        train_dl: Iterable of ``(images, labels)`` batches. Labels may have
            shape ``(batch,)`` or ``(batch, 1)`` and any numeric dtype.
        loss_fn: Callable taking ``(logits, labels)`` and returning a scalar loss.
        optimizer: Optimizer bound to ``model``'s parameters.
        epochs: Number of passes over ``train_dl``.

    Raises:
        ValueError: If ``train_dl`` yields no samples during an epoch.
    """
    best_accuracy = 0.0
    checkpoint_dir = 'models'

    model.train()
    for epoch in range(epochs):
        print(f'Epoch: {epoch+1}')
        running_loss = 0
        running_corrects = 0
        total_entries = 0

        for images, labels in tqdm(train_dl):
            # reshape(-1) rather than squeeze(): squeeze() collapses a batch
            # of one to a 0-dim tensor, which breaks the calls below.
            labels = labels.reshape(-1).long()
            optimizer.zero_grad()
            outputs = model(images)
            # Some models return a tuple of outputs with the main logits first.
            logits = outputs[0] if isinstance(outputs, tuple) else outputs

            loss = loss_fn(logits, labels)
            prediction = logits.argmax(dim=1)

            loss.backward()
            optimizer.step()

            # loss is a batch mean; weight it by the batch size so the epoch
            # average stays exact when the last batch is smaller.
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum((prediction == labels)).item()
            total_entries += labels.size(0)

        if total_entries == 0:
            raise ValueError('train_dl yielded no samples; cannot compute epoch metrics')

        epoch_loss = running_loss / total_entries
        epoch_accuracy = running_corrects / total_entries
        print(f'Epoch Loss: {epoch_loss}\tEpoch Accuracy:{epoch_accuracy}')

        if epoch_accuracy > best_accuracy:
            best_accuracy = epoch_accuracy
            # torch.save does not create missing parent directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'fft_google_model_e{epochs}.pth'))

    print('Training Complete.')
    

In [13]:
# Fit the model for 20 passes over the training loader.
train(model, train_dl, loss_fn, optimizer, epochs=20)

Epoch: 1


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.50it/s]


Epoch Loss: 0.7034156816866203	Epoch Accuracy:0.6981203481771162
Epoch: 2


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [00:56<00:00, 17.51it/s]


Epoch Loss: 0.6116966794745324	Epoch Accuracy:0.7340734199571086
Epoch: 3


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [00:57<00:00, 17.22it/s]


Epoch Loss: 0.5826395503717006	Epoch Accuracy:0.7553929607669989
Epoch: 4


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [00:58<00:00, 16.98it/s]


Epoch Loss: 0.5441184398715007	Epoch Accuracy:0.7692695849627854
Epoch: 5


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [00:59<00:00, 16.73it/s]


Epoch Loss: 0.5050358860139605	Epoch Accuracy:0.7876876498044657
Epoch: 6


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [00:59<00:00, 16.59it/s]


Epoch Loss: 0.5200941762249572	Epoch Accuracy:0.7828939069004668
Epoch: 7


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.25it/s]


Epoch Loss: 0.5033794308432087	Epoch Accuracy:0.786678440772045
Epoch: 8


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.27it/s]


Epoch Loss: 0.4738137605467886	Epoch Accuracy:0.7943736596442538
Epoch: 9


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.29it/s]


Epoch Loss: 0.4690879781940658	Epoch Accuracy:0.8033303898069888
Epoch: 10


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:01<00:00, 16.22it/s]


Epoch Loss: 0.4395810821960323	Epoch Accuracy:0.8112779109373028
Epoch: 11


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:01<00:00, 16.15it/s]


Epoch Loss: 0.43105847287090815	Epoch Accuracy:0.8115302131954081
Epoch: 12


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:01<00:00, 16.04it/s]


Epoch Loss: 0.5627564734102454	Epoch Accuracy:0.7647281443168916
Epoch: 13


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:02<00:00, 15.79it/s]


Epoch Loss: 0.4853103851496053	Epoch Accuracy:0.7980320423867794
Epoch: 14


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:02<00:00, 15.80it/s]


Epoch Loss: 0.442538764202288	Epoch Accuracy:0.8121609688406711
Epoch: 15


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:02<00:00, 15.89it/s]


Epoch Loss: 0.4089081070135743	Epoch Accuracy:0.825280686262142
Epoch: 16


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:02<00:00, 15.76it/s]


Epoch Loss: 0.3924959112596464	Epoch Accuracy:0.8271729531979312
Epoch: 17


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:03<00:00, 15.64it/s]


Epoch Loss: 0.37832326311263276	Epoch Accuracy:0.8348681720701401
Epoch: 18


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:03<00:00, 15.72it/s]


Epoch Loss: 0.3595465754936679	Epoch Accuracy:0.8487447962659266
Epoch: 19


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:02<00:00, 15.75it/s]


Epoch Loss: 0.3485470133938973	Epoch Accuracy:0.8489970985240318
Epoch: 20


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:03<00:00, 15.67it/s]

Epoch Loss: 0.33766414639597	Epoch Accuracy:0.8554308061057146
Training Complete.





In [16]:
def test(model, test_dl):
    corrects = 0
    total_entries = 0
    model_accuracy = 0
    
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(test_dl):
            labels = labels.squeeze().long()
            outputs = model(images)
            _, predictions = torch.max(outputs, 1)

            corrects += torch.sum(predictions == labels).item()
            total_entries += labels.size(0)
        model_accuracy = corrects / total_entries
        print(f'Model Accuracy: {model_accuracy}')
        
    print('Testing Complete')
        

In [17]:
# Report the trained model's accuracy on the loader built from the test split.
test(model, test_dl)

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 69.86it/s]

Model Accuracy: 0.8028846153846154
Testing Complete



