In [1]:
import torch
import torch.nn as nn
import torch.fft as fft
import torchvision
import torchvision.models as models
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [2]:
import os
import numpy as np
import cv2
import math
from tqdm import tqdm

In [3]:
IMG_SIZE = 128
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
class ImageDataset(Dataset):
    def __init__(self, image_path, device):
        self.image_path = image_path
        self.device = device
        self.data = []
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.load_data()

    def load_data(self):
        for idx, label in enumerate(os.listdir(os.path.join(self.image_path))):
            print(os.path.join(self.image_path, label))
            for img_file in tqdm(os.listdir(os.path.join(self.image_path, label))):
                img_gray = cv2.imread(os.path.join(self.image_path, label, img_file), cv2.IMREAD_GRAYSCALE)
                img = cv2.resize(img_gray, (IMG_SIZE, IMG_SIZE))
                img = self.transform(img)

                idx = torch.Tensor([idx])
                self.data.append((img.to(device), idx.to(device)))
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [5]:
batch_size = 8

image_path = os.path.join('data', 'train_set')
train_data = ImageDataset(image_path, device)
train_dl = DataLoader(train_data, batch_size, shuffle=True)

image_path = os.path.join('data', 'test_set')
test_data = ImageDataset(image_path, device)
test_dl = DataLoader(test_data, batch_size, shuffle=False)

data\train_set\bacterial


100%|██████████████████████████████████████████████████████████████████████████████| 3001/3001 [00:49<00:00, 60.92it/s]


data\train_set\normal


100%|██████████████████████████████████████████████████████████████████████████████| 3270/3270 [01:19<00:00, 41.37it/s]


data\train_set\viral


100%|██████████████████████████████████████████████████████████████████████████████| 1656/1656 [00:30<00:00, 54.38it/s]


data\test_set\bacterial


100%|███████████████████████████████████████████████████████████████████████████████| 242/242 [00:01<00:00, 209.32it/s]


data\test_set\normal


100%|████████████████████████████████████████████████████████████████████████████████| 234/234 [00:04<00:00, 47.25it/s]


data\test_set\viral


100%|███████████████████████████████████████████████████████████████████████████████| 148/148 [00:01<00:00, 139.31it/s]


In [7]:
class FFTConvNet(nn.Module):
    def __init__(self, conv_layer, fft_filter=None):
        super().__init__()
        self.conv_layer = conv_layer

    def fft_filter_def(self, fft_x, height, width):
        cht, cwt = height//2, width//2
        mask_radius = 30
        
        mask = torch.zeros((height, width))
        fy, fx = torch.meshgrid(torch.arange(0, height, device=device), torch.arange(0, width, device=device), indexing='ij')
        mask_area = torch.sqrt((fx - cwt)**2 + (fy - cht)**2)

        if fft_filter == 'high':
            mask = (mask_area > mask_radius).float()
        else:
            mask = (mask_area <= mask_radius).float()
        filtered_fft = fft_x * mask
        
        return filtered_fft

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        
        # apply fft on input image
        fft_x = fft.fft2(x)
        fft_x = fft.fftshift(fft_x)

        kernel_fft = fft.fft2(self.conv_layer.weight, s=(height, width))
        kernel_fft = fft.fftshift(kernel_fft)

        # apply fft filter (low pass)
        if fft_filter is not None:
            fft_x = fft_filter_def(fft_x, height, width)
        
        # perform element wise complex multiplcation
        # fft_output = fft_x * kernel_fft

        fft_output = torch.einsum('bixy,oixy->boxy', fft_x, kernel_fft)
        
        # apply inverse fft
        fft_output = fft.ifftshift(fft_out)
        spatial_output = fft.ifft2(fft_out)
        
        # return output
        return spatial_output

In [10]:
def switch_conv_layers(model):
    for name, module in model.named_children():
        if isinstance(module, nn.Conv2d):
            # if "Conv2d_1a" in name or "Conv2d_2a" in name or "Conv2d_2b" in name:
            if "Inception3" in name or "Inception4" in name or "Inception5" in name:
                conv_fft = FFTConvNet(module, 'low')
                setattr(model, name, conv_fft)
        elif isinstance(module, nn.Sequential) or isinstance(module, nn.Module):
            switch_conv_layers(module)
    return model

In [11]:
# googlenet
model = models.googlenet(weights='GoogLeNet_Weights.DEFAULT')

# disable auxillary branches since we're working with size of images lower than 256
model.aux1 = None
model.aux2 = None

model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 1024),
    nn.LeakyReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(1024, 3),
    nn.LogSoftmax(dim=1)
)
model = switch_conv_layers(model)
model = model.to(device)

In [12]:
learning_rate = 0.001
weight_decay = 0.000001

optimizer = optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)
loss_fn = nn.CrossEntropyLoss()

print(optimizer, loss_fn)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 1e-06
) CrossEntropyLoss()


In [13]:
def train(model, train_dl, loss_fn, optimizer, epochs):
    best_accuracy = 0.0

    model.train()
    for epoch in range(epochs):
        print(f'Epoch: {epoch+1}')
        running_loss = 0
        running_corrects = 0
        total_entries = 0

        for images, labels in tqdm(train_dl):
            labels = labels.squeeze().long()
            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(outputs, labels)
            _, prediction = torch.max(outputs, 1)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum((prediction == labels)).item()
            total_entries += labels.size(0)

        epoch_loss = running_loss / total_entries
        epoch_accuracy = running_corrects / total_entries
        print(f'Epoch Loss: {epoch_loss}\tEpoch Accuracy:{epoch_accuracy}')

        if epoch_accuracy > best_accuracy:
            best_accuracy = epoch_accuracy
            torch.save(model.state_dict(), os.path.join('models', f'fft_google_model_e{epochs}.pth'))

    print('Training Complete.')
    

In [15]:
train(model, train_dl, loss_fn, optimizer, epochs=20)

Epoch: 1


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:48<00:00,  9.15it/s]


Epoch Loss: 0.5784185090649394	Epoch Accuracy:0.7606913081872083
Epoch: 2


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:08<00:00, 14.49it/s]


Epoch Loss: 0.4601432132182573	Epoch Accuracy:0.8101425507758294
Epoch: 3


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:04<00:00, 15.28it/s]


Epoch Loss: 0.4413192358149456	Epoch Accuracy:0.8081241327109877
Epoch: 4


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:11<00:00, 13.85it/s]


Epoch Loss: 0.39775885756242557	Epoch Accuracy:0.8279298599722468
Epoch: 5


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:13<00:00, 13.53it/s]


Epoch Loss: 0.38980057885205593	Epoch Accuracy:0.8331020562634036
Epoch: 6


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.40it/s]


Epoch Loss: 0.35705858599091445	Epoch Accuracy:0.8467263782010849
Epoch: 7


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.47it/s]


Epoch Loss: 0.34692015104024937	Epoch Accuracy:0.852150876750347
Epoch: 8


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.31it/s]


Epoch Loss: 0.3398790622692331	Epoch Accuracy:0.8556831083638199
Epoch: 9


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.42it/s]


Epoch Loss: 0.3007652752982107	Epoch Accuracy:0.8670367099785543
Epoch: 10


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.38it/s]


Epoch Loss: 0.32381051698431956	Epoch Accuracy:0.8631260249779236
Epoch: 11


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.36it/s]


Epoch Loss: 0.26950852013589827	Epoch Accuracy:0.8863378327236029
Epoch: 12


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.50it/s]


Epoch Loss: 0.24663916368301314	Epoch Accuracy:0.8994575501450738
Epoch: 13


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.33it/s]


Epoch Loss: 0.20736213461467445	Epoch Accuracy:0.9152264412766494
Epoch: 14


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.26it/s]


Epoch Loss: 0.20461741307193468	Epoch Accuracy:0.9181279172448593
Epoch: 15


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:01<00:00, 16.17it/s]


Epoch Loss: 0.15315448569481757	Epoch Accuracy:0.9420966317648543
Epoch: 16


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:01<00:00, 16.17it/s]


Epoch Loss: 0.13428690529927625	Epoch Accuracy:0.9475211303141163
Epoch: 17


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:00<00:00, 16.25it/s]


Epoch Loss: 0.13065456488466892	Epoch Accuracy:0.9523148732181153
Epoch: 18


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:02<00:00, 15.74it/s]


Epoch Loss: 0.10605637891521147	Epoch Accuracy:0.960388545477482
Epoch: 19


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:02<00:00, 15.88it/s]


Epoch Loss: 0.13427056210043722	Epoch Accuracy:0.9535763845086414
Epoch: 20


100%|████████████████████████████████████████████████████████████████████████████████| 991/991 [01:02<00:00, 15.88it/s]


Epoch Loss: 0.07984952132236355	Epoch Accuracy:0.9708590891888482
Training Complete.


In [24]:
def test(model, test_dl):
    corrects = 0
    total_entries = 0
    model_accuracy = 0
    
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(test_dl):
            labels = labels.squeeze().long()
            outputs = model(images)
            _, predictions = torch.max(outputs, 1)

            corrects += torch.sum(predictions == labels).item()
            total_entries += labels.size(0)
        model_accuracy = corrects / total_entries
        print(f'Model Accuracy: {model_accuracy}')
        
    print('Testing Complete')
        

In [25]:
test(model, test_dl)

100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 65.38it/s]

Model Accuracy: 0.8878205128205128
Testing Complete



