In [1]:
import torch
import torch.nn as nn
import torch.fft as fft
import torchvision
import torchvision.models as models
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [2]:
import os
import numpy as np
import cv2
import math
from tqdm import tqdm

In [3]:
IMG_SIZE = 128
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
class ImageDataset(Dataset):
    def __init__(self, image_path, device):
        self.image_path = image_path
        self.device = device
        self.data = []
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize((IMG_SIZE, IMG_SIZE)),
            v2.Normalize(mean=[0.5,], std=[0.5,]),
            v2.Grayscale(num_output_channels=3)
        ])
        self.load_data()

    def load_data(self):
        for idx, label in enumerate(os.listdir(os.path.join(self.image_path))):
            print(os.path.join(self.image_path, label))
            for img_file in tqdm(os.listdir(os.path.join(self.image_path, label))):
                img = cv2.imread(os.path.join(self.image_path, label, img_file), cv2.IMREAD_GRAYSCALE)

                img = self.transform(img)

                idx = torch.Tensor([idx])
                self.data.append((img, idx))
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [9]:
batch_size = 24

image_path = os.path.join('data', 'train_set')
data = ImageDataset(image_path, device)

data\train_set\bacterial


100%|██████████████████████████████████████████████████████████████████████████████| 3001/3001 [00:53<00:00, 56.01it/s]


data\train_set\normal


100%|██████████████████████████████████████████████████████████████████████████████| 3270/3270 [01:26<00:00, 37.91it/s]


data\train_set\viral


100%|██████████████████████████████████████████████████████████████████████████████| 2706/2706 [00:34<00:00, 78.21it/s]


In [10]:
train_data, val_data = train_test_split(data, test_size=0.1)
train_dl = DataLoader(train_data, batch_size, shuffle=True)
val_dl = DataLoader(val_data, batch_size, shuffle=True)

In [11]:
len(train_data), len(val_data), len(data)

(8079, 898, 8977)

In [12]:
class FFTConvNet(nn.Module):
    def __init__(self, conv_layer, fft_filter=None):
        super().__init__()
        self.conv_layer = conv_layer  # Original Conv2d layer
        self.fft_filter = fft_filter

    def fft_filter_def(self, fft_x, height, width):
        cht, cwt = height // 2, width // 2
        mask_radius = 30

        # Create a meshgrid for the mask
        fy, fx = torch.meshgrid(
            torch.arange(0, height, device=fft_x.device),
            torch.arange(0, width, device=fft_x.device),
            indexing='ij'
        )
        mask_area = torch.sqrt((fx - cwt) ** 2 + (fy - cht) ** 2)

        # Create the mask based on the filter type
        if self.fft_filter == 'high':
            mask = (mask_area > mask_radius).float()
        else:
            mask = (mask_area <= mask_radius).float()

        # Apply the mask to the FFT of the input
        filtered_fft = fft_x * mask
        return filtered_fft

    def forward(self, x):
        batch_size, in_channels, height, width = x.size()
        out_channels = self.conv_layer.out_channels  # Number of output channels

        # Apply FFT on input image
        fft_x = fft.fft2(x)  # Shape: [batch_size, in_channels, height, width]
        fft_x = fft.fftshift(fft_x)

        # Apply FFT on the convolutional kernel
        kernel_fft = fft.fft2(self.conv_layer.weight, s=(height, width))  # Shape: [out_channels, in_channels, height, width]
        kernel_fft = fft.fftshift(kernel_fft)

        # Apply FFT filter (low-pass or high-pass)
        if self.fft_filter is not None:
            fft_x = self.fft_filter_def(fft_x, height, width)

        # Perform element-wise complex multiplication
        fft_output = fft_x.unsqueeze(1) * kernel_fft.unsqueeze(0)  # Broadcast multiplication
        fft_output = torch.sum(fft_output, dim=2)  # Sum over input channels

        # Apply inverse FFT
        fft_output = fft.ifftshift(fft_output, dim=(-2, -1))
        spatial_output = fft.ifft2(fft_output, dim=(-2, -1)).real  # Shape: [batch_size, out_channels, height, width]

        # Add bias (if applicable)
        if self.conv_layer.bias is not None:
            spatial_output += self.conv_layer.bias.view(1, -1, 1, 1)

        # Debug: Print shapes
        # print(f"Input shape: {x.shape}")
        # print(f"Output shape: {spatial_output.shape}")

        # Return output
        return spatial_output

In [13]:
# Function to replace Conv2d with FFTConvNet
def change_layer(layer):
    fft_conv = FFTConvNet(layer, 'low')
    return fft_conv

In [14]:
# Load GoogleNet
learning_rate = 0.001
weight_decay = 0.000001

model = models.googlenet(weights=None, init_weights=True)

# Disable auxiliary branches
model.aux1 = None
model.aux2 = None

# Modify the fully connected layer
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 1024),
    nn.LeakyReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(1024, 3),
    nn.LogSoftmax(dim=1)
)

# Replace Conv2d layers with FFTConvNet
for name, module in model.named_modules():
    if name.startswith('inception') and isinstance(module, nn.Conv2d):
        # Replace only layers with kernel size > 1x1
        if module.kernel_size[0] > 1:
            fft_conv = change_layer(module).to(device)  # Move to GPU
            parent_name, attr_name = name.rsplit(".", 1)
            parent_module = dict(model.named_modules())[parent_name]
            setattr(parent_module, attr_name, fft_conv)

# Move the model to the GPU
model = model.to(device)

In [17]:
model.inception3a

Inception(
  (branch1): BasicConv2d(
    (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch2): Sequential(
    (0): BasicConv2d(
      (conv): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicConv2d(
      (conv): FFTConvNet(
        (conv_layer): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (branch3): Sequential(
    (0): BasicConv2d(
      (conv): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicConv2d(
      (conv): FFTConvNet(
        (conv_layer): Conv2d(16, 32, kerne

In [18]:
optimizer = optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)
loss_fn = nn.CrossEntropyLoss()

print(optimizer, loss_fn)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 1e-06
) CrossEntropyLoss()


In [20]:
def train(model, train_dl, loss_fn, optimizer, epochs):
    best_acc = 0.0
    model.train()
    for epoch in range(epochs):
        print(f"Epoch [{epoch+1}/{epochs}]")
        running_loss = 0.0
        running_corrects = 0
        total_samples = 0

        for images, labels in tqdm(train_dl):
            # print("Data loaded")

            # Move data to GPU
            labels = labels.squeeze().long()
            images = images.to(device)
            labels = labels.to(device)
            # print("Data moved to GPU")

            # Forward pass
            outputs = model(images)
            # print("Forward pass completed")

            # Compute loss
            loss = loss_fn(outputs[0], labels)
            # print("Loss computed")

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # print("Backward pass completed")

            # Compute metrics
            _, preds = torch.max(outputs[0], 1)
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(preds == labels.data)
            total_samples += labels.size(0)
            
        epoch_loss = running_loss / total_samples
        epoch_acc = running_corrects / total_samples
        print(f"Epoch Loss: {epoch_loss:.4f}, Epoch Accuracy: {epoch_acc:.4f}")
    print('Training Complete.')

In [21]:
train(model, train_dl, loss_fn, optimizer, epochs=20)

Epoch [1/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:32<00:00,  2.95s/it]


Epoch Loss: 0.7808, Epoch Accuracy: 0.5969
Epoch [2/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:30<00:00,  2.94s/it]


Epoch Loss: 0.6613, Epoch Accuracy: 0.6552
Epoch [3/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:30<00:00,  2.94s/it]


Epoch Loss: 0.5905, Epoch Accuracy: 0.7241
Epoch [4/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:30<00:00,  2.94s/it]


Epoch Loss: 0.5393, Epoch Accuracy: 0.7583
Epoch [5/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:30<00:00,  2.94s/it]


Epoch Loss: 0.5020, Epoch Accuracy: 0.7851
Epoch [6/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:30<00:00,  2.94s/it]


Epoch Loss: 0.4754, Epoch Accuracy: 0.7896
Epoch [7/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:30<00:00,  2.94s/it]


Epoch Loss: 0.4614, Epoch Accuracy: 0.7996
Epoch [8/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:30<00:00,  2.94s/it]


Epoch Loss: 0.4458, Epoch Accuracy: 0.8023
Epoch [9/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:18<00:00,  2.90s/it]


Epoch Loss: 0.4319, Epoch Accuracy: 0.8112
Epoch [10/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:28<00:00,  2.93s/it]


Epoch Loss: 0.4259, Epoch Accuracy: 0.8143
Epoch [11/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:29<00:00,  2.94s/it]


Epoch Loss: 0.4025, Epoch Accuracy: 0.8249
Epoch [12/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:29<00:00,  2.93s/it]


Epoch Loss: 0.4043, Epoch Accuracy: 0.8251
Epoch [13/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:29<00:00,  2.94s/it]


Epoch Loss: 0.4014, Epoch Accuracy: 0.8258
Epoch [14/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:28<00:00,  2.93s/it]


Epoch Loss: 0.3764, Epoch Accuracy: 0.8360
Epoch [15/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:29<00:00,  2.93s/it]


Epoch Loss: 0.3799, Epoch Accuracy: 0.8357
Epoch [16/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:29<00:00,  2.93s/it]


Epoch Loss: 0.3735, Epoch Accuracy: 0.8406
Epoch [17/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:29<00:00,  2.94s/it]


Epoch Loss: 0.3598, Epoch Accuracy: 0.8454
Epoch [18/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:29<00:00,  2.94s/it]


Epoch Loss: 0.3482, Epoch Accuracy: 0.8523
Epoch [19/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:28<00:00,  2.93s/it]


Epoch Loss: 0.3310, Epoch Accuracy: 0.8574
Epoch [20/20]


100%|████████████████████████████████████████████████████████████████████████████████| 337/337 [16:08<00:00,  2.87s/it]

Epoch Loss: 0.3218, Epoch Accuracy: 0.8616
Training Complete.





In [22]:
torch.save(model.state_dict(), os.path.join('models', 'fft_googlenet_final.pth'))

In [28]:
def validate(model, val_dl):
    corrects = 0
    total_entries = 0
    model_accuracy = 0
    
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(val_dl):
            images = images.to(device)
            labels = labels.to(device)
            labels = labels.squeeze().long()
            outputs = model(images)
            _, predictions = torch.max(outputs, 1)

            corrects += torch.sum(predictions == labels).item()
            total_entries += labels.size(0)
        model_accuracy = corrects / total_entries
        print(f'Model Accuracy: {model_accuracy}')
        
    print('Validation Complete')

In [29]:
validate(model, val_dl)

100%|██████████████████████████████████████████████████████████████████████████████████| 38/38 [00:35<00:00,  1.07it/s]

Model Accuracy: 0.8585746102449888
Validation Complete



