In [1]:
import torch
import torch.nn as nn
import torch.fft as fft
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

In [2]:
import torchvision
from torchvision import models
from torchvision.transforms import v2

In [3]:
import matplotlib.pyplot as plt
import os
import numpy as np
import cv2
from tqdm import tqdm
import lmdb
import pickle

In [4]:
from src.FFTConv import *
from src.ImageHandler import *

In [5]:
IMG_SIZE = 128
device = 'cuda' if torch.cuda.is_available() else 'cpu'
REBUILD_DATA = False

In [6]:
lmdb_path = os.path.join('lmdb')

if REBUILD_DATA:
    image_path = os.path.join('data', 'train_set')
    train_data = ImageDataset(image_path=image_path, device=device, lmdb_path=lmdb_path, save_lmdb=True)
    
    image_path = os.path.join('data', 'val_set')
    val_data = ImageDataset(image_path=image_path, device=device, lmdb_path=lmdb_path, save_lmdb=True, mode="val")
    REBUILD_DATA = False
else:
    train_data = ImageDataset(image_path=None, device=device, lmdb_path=lmdb_path, save_lmdb=False)
    val_data = ImageDataset(image_path=None, device=device, lmdb_path=lmdb_path, save_lmdb=False, mode="val")

Loaded train dataset
Loaded val dataset


In [7]:
len(train_data), len(val_data)

(9959, 4200)

In [8]:
batch_size = 16

train_dl = DataLoader(train_data, batch_size, shuffle=True, pin_memory=True)
val_dl = DataLoader(val_data, batch_size, shuffle=True, pin_memory=True)

In [9]:
for images, labels in train_dl:
    labels = labels.squeeze().long()
    print("Image shape: ", images.shape)
    print("Label shape: ", labels.shape)
    break

for images, labels in val_dl:
    labels = labels.squeeze().long()
    print("Image shape: ", images.shape)
    print("Label shape: ", labels.shape)
    break

Image shape:  torch.Size([16, 3, 128, 128])
Label shape:  torch.Size([16])
Image shape:  torch.Size([16, 3, 128, 128])
Label shape:  torch.Size([16])


In [10]:
learning_rate = 1e-4
weight_decay = 5e-4
# momentum = 0.9

model = FFTAlex(apply_fft=True, device=device, IMG_SIZE=IMG_SIZE)

Total Layers replaced:  1


In [11]:
# Testing if model is working
model.eval()
dummy_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
outputs = model(dummy_input)
print("Output shape: ", outputs.data)

Output shape:  tensor([[-0.0276, -0.0251, -0.0691]], device='cuda:0')


In [12]:
epochs = 10
optimizer = optim.AdamW(model.parameters(), learning_rate, weight_decay=weight_decay)
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
loss_fn = nn.CrossEntropyLoss()
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
# scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

print(optimizer, loss_fn)

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.0001
    lr: 0.0001
    maximize: False
    weight_decay: 0.0005
) CrossEntropyLoss()


In [13]:
def train(model, train_dl, val_dl, loss_fn, optimizer, scheduler, epochs, name, device):
    best_acc = 0.0
    tr_acc_list = []
    val_acc_list = []
    
    # Mixed precision training
    scaler = torch.GradScaler(device)  

    for epoch in range(epochs):
        print(f"Epoch [{epoch+1}/{epochs}]")
        running_loss = 0.0
        running_corrects = 0
        total_samples = 0

        # Training loop
        model.train()
        for images, labels in tqdm(train_dl):
            # Move data to GPU
            labels = labels.squeeze().long()
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            with torch.autocast(device):
                outputs = model(images)
                loss = loss_fn(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            # Compute metrics
            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(preds == labels.data)
            total_samples += labels.size(0)

        epoch_loss = running_loss / total_samples
        epoch_acc = running_corrects / total_samples
        tr_acc_list.append(epoch_acc.cpu().item())
        print(f"Training Loss: {epoch_loss:0.6f}, Training Accuracy: {epoch_acc:0.6f}")

        # Validation loop
        val_loss, val_acc = validate(model, val_dl, loss_fn, device)
        val_acc_list.append(val_acc.cpu().item())

        # Save the best model
        if val_acc > best_acc:
            best_acc = val_acc
            model.save_model_dict(os.path.join("models", "alex"), f"{name}_model.pth")
            # torch.save(model.state_dict(), os.path.join("models", "alex", f"{name}_model.pth"))

        # Step the scheduler
        scheduler.step()

    print('Training Complete.')
    return tr_acc_list, val_acc_list

In [14]:
def validate(model, val_dl, loss_fn, device):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(val_dl):
            # Move data to GPU
            labels = labels.squeeze().long()
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs.data, labels)

            # Compute metrics
            _, preds = torch.max(outputs.data, 1)
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(preds == labels.data)
            total_samples += labels.size(0)

        val_loss = running_loss / total_samples
        val_acc = running_corrects / total_samples
        print(f"Validation Loss: {val_loss:0.6f}, Validation Accuracy: {val_acc:0.6f}")

    return val_loss, val_acc

In [15]:
name = "fft_alex"
tr_acc_list, val_acc_list = train(model, train_dl, val_dl, loss_fn, optimizer, scheduler, epochs=epochs, name=name, device=device)

Epoch [1/10]


100%|████████████████████████████████████████████████████████████████████████████████| 623/623 [01:56<00:00,  5.33it/s]


Training Loss: 0.620304, Training Accuracy: 0.718345


100%|████████████████████████████████████████████████████████████████████████████████| 263/263 [00:18<00:00, 14.37it/s]


Validation Loss: 0.504076, Validation Accuracy: 0.764762
Epoch [2/10]


100%|████████████████████████████████████████████████████████████████████████████████| 623/623 [02:08<00:00,  4.84it/s]


Training Loss: 0.489365, Training Accuracy: 0.784818


100%|████████████████████████████████████████████████████████████████████████████████| 263/263 [00:18<00:00, 14.46it/s]


Validation Loss: 0.474654, Validation Accuracy: 0.779762
Epoch [3/10]


100%|████████████████████████████████████████████████████████████████████████████████| 623/623 [02:05<00:00,  4.95it/s]


Training Loss: 0.440170, Training Accuracy: 0.810121


100%|████████████████████████████████████████████████████████████████████████████████| 263/263 [00:19<00:00, 13.43it/s]


Validation Loss: 0.454747, Validation Accuracy: 0.791429
Epoch [4/10]


100%|████████████████████████████████████████████████████████████████████████████████| 623/623 [02:08<00:00,  4.83it/s]


Training Loss: 0.401637, Training Accuracy: 0.827894


100%|████████████████████████████████████████████████████████████████████████████████| 263/263 [00:18<00:00, 14.27it/s]


Validation Loss: 0.438481, Validation Accuracy: 0.798333
Epoch [5/10]


 51%|████████████████████████████████████████▉                                       | 319/623 [02:38<02:30,  2.01it/s]


KeyboardInterrupt: 

In [None]:
tr_accuracy = np.array(tr_acc_list, dtype=np.float32)
val_accuracy = np.array(val_acc_list, dtype=np.float32)
np.save(os.path.join('models', 'alex', 'tr_fft_accuracy.npy'), tr_accuracy)
np.save(os.path.join('models', 'alex', 'val_fft_accuracy.npy'), val_accuracy)