## **KAN for AI-Generated Image Classification**

## Install required libraries

In [None]:
# Install necessary libraries
!pip install torch torchvision




## Set up the environment

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18
import torch.nn.functional as F
from torch.optim import lr_scheduler
from sklearn.model_selection import train_test_split
import numpy as np

# Set seed for reproducibility.
# Fix the torch RNG seed (and explicitly every CUDA device when one is visible),
# then make cuDNN use deterministic kernels instead of auto-benchmarking.
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset

In [None]:
from sklearn.metrics import accuracy_score, precision_score, f1_score

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from torch.optim import lr_scheduler
from sklearn.model_selection import train_test_split
import numpy as np



## Dataset preprocessing

In [None]:
import os
import torch
import kagglehub
from torchvision import transforms,datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split

# Download the latest version of the dataset. kagglehub returns the local
# directory the dataset files were placed in.
path = kagglehub.dataset_download("birdy654/cifake-real-and-ai-generated-synthetic-images")

# Define dataset directories relative to the download location. A hard-coded
# cache path only works for one home directory and one dataset version.
test_dir = os.path.join(path, 'test')
train_dir = os.path.join(path, 'train')

# Print dataset files for sanity check
print("Files in 'test' directory:", os.listdir(test_dir))
print("Files in 'train' directory:", os.listdir(train_dir))

# Define transformations
# NOTE(review): the ResNet18 backbone loads ImageNet-pretrained weights, while
# this uses (0.5, 0.5, 0.5) mean/std rather than ImageNet statistics — confirm
# this is intended.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize for ResNet input
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize
])

# 1. Load the full train dataset (ImageFolder: one sub-directory per class)
full_trainset = datasets.ImageFolder(root=train_dir, transform=transform)

# Limit the train dataset to (at most) 1000 randomly selected images
num_train_samples = 1000
train_indices = torch.randperm(len(full_trainset))[:num_train_samples]
subset_trainset = torch.utils.data.Subset(full_trainset, train_indices)

# Split the subset 80/20 into train and validation. Sizes come from the actual
# subset length, so random_split still works if the dataset holds fewer than
# num_train_samples images (800/200 for the full dataset).
train_size = int(0.8 * len(subset_trainset))
val_size = len(subset_trainset) - train_size
trainset, valset = random_split(subset_trainset, [train_size, val_size])

# 2. Load the full test dataset
full_testset = datasets.ImageFolder(root=test_dir, transform=transform)

# Limit the test dataset to (at most) 1000 randomly selected images
num_test_samples = 1000
test_indices = torch.randperm(len(full_testset))[:num_test_samples]
testset = torch.utils.data.Subset(full_testset, test_indices)

# 3. Create DataLoaders (only the training loader is shuffled)
train_loader = DataLoader(trainset, batch_size=32, shuffle=True)
val_loader = DataLoader(valset, batch_size=32, shuffle=False)
test_loader = DataLoader(testset, batch_size=32, shuffle=False)


# 4. Print dataset sizes for confirmation
print(f"Total images in training subset: {len(trainset)}")
print(f"Total images in validation subset: {len(valset)}")
print(f"Total images in test subset: {len(testset)}")






Files in 'test' directory: ['FAKE', 'REAL']
Files in 'train' directory: ['FAKE', 'REAL']
Total images in training subset: 800
Total images in validation subset: 200
Total images in test subset: 1000


## KAN implementation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# KAN Class
class KAN(torch.nn.Module):
    """A single Kolmogorov-Arnold Network (KAN) layer.

    The output is the sum of two branches:

    * a "base" branch: a linear map applied to ``base_activation(x)``;
    * a "spline" branch: a linear map applied to the B-spline basis expansion
      of every input feature.

    The B-spline bases live on a uniform knot grid over ``grid_range``. The
    grid is padded with ``spline_order`` extra knots on each side.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        grid_size: Number of grid intervals inside ``grid_range``.
        spline_order: Order of the B-spline bases (3 = cubic).
        scale_noise: Amplitude of the random curve used to initialise the
            spline coefficients.
        scale_base: Constant every entry of ``base_weight`` is initialised to.
        scale_spline: Multiplier applied to the initial spline coefficients.
        base_activation: Module *class* (instantiated here) for the base branch.
        grid_eps: Stored on the module; not used by this implementation.
        grid_range: ``(low, high)`` interval covered by the inner grid.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),  # tuple: avoid a shared mutable default
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector over grid_range, padded by `spline_order` knots on
        # each side, giving grid_size + 2 * spline_order + 1 knots. Every input
        # feature gets its own (identical) row.
        low, high = grid_range
        h = (high - low) / grid_size
        grid = (
            (torch.arange(-spline_order, grid_size + spline_order + 1) * h + low)
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        # Allocated uninitialised; reset_parameters() fills both in.
        self.base_weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``base_weight`` and ``spline_weight``.

        ``base_weight`` is set to the constant ``scale_base``. ``spline_weight``
        is a least-squares fit of small uniform noise sampled at the
        ``grid_size + 1`` inner knots, scaled by ``scale_spline``.
        """
        # NOTE(review): every base weight starts identical; reference KAN
        # implementations commonly use a random (e.g. Kaiming) init here —
        # confirm the constant init is intended.
        torch.nn.init.constant_(self.base_weight, self.scale_base)
        with torch.no_grad():
            # Noise in [-0.5, 0.5) * scale_noise / grid_size, one value per
            # (inner knot, input feature, output feature).
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                self.scale_spline
                * self.curve2coeff(
                    # Inner knots only (drop the padding), shape (grid_size + 1, in).
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )

    def b_splines(self, x: torch.Tensor):
        """Evaluate the B-spline bases at ``x`` (Cox-de Boor recursion).

        Args:
            x: Tensor of shape ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch, in_features, grid_size + spline_order)``.
            Inputs outside the padded knot range get all-zero bases. The order-0
            intervals are half-open ``[t_i, t_{i+1})``.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = self.grid
        x = x.unsqueeze(-1)
        # Order 0: indicator of the knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step drops one basis.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """Fit spline coefficients so the splines pass (in least squares) through ``(x, y)``.

        Args:
            x: Sample points, shape ``(n, in_features)``.
            y: Target values, shape ``(n, in_features, out_features)``.

        Returns:
            Coefficients of shape ``(out_features, in_features, grid_size + spline_order)``.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        # One least-squares problem per input feature (batched over dim 0).
        A = self.b_splines(x).transpose(0, 1)  # (in, n, coeff)
        B = y.transpose(0, 1)  # (in, n, out)
        solution = torch.linalg.lstsq(A, B).solution  # (in, coeff, out)
        result = solution.permute(2, 0, 1)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    def forward(self, x: torch.Tensor):
        """Apply the layer.

        Args:
            x: Tensor of shape ``(..., in_features)``. Leading dimensions are
                flattened into one batch dimension and restored on output.

        Returns:
            Tensor of shape ``(..., out_features)``.
        """
        assert x.size(-1) == self.in_features

        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output
        return output.reshape(*original_shape[:-1], self.out_features)




## Model definition

In [None]:
# Model Definition with ResNet18 + KAN
class KANResNet18(nn.Module):
    """ResNet-18 backbone whose final fully-connected layer is replaced by a KAN layer.

    Args:
        num_classes: Number of output logits.
        pretrained: If True (default), load torchvision's default ImageNet
            weights for the backbone. If False, initialise randomly, with no
            download.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        # `pretrained=` is deprecated since torchvision 0.13; the weights enum is
        # the supported API (DEFAULT is the same ImageNet checkpoint).
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet18(weights=weights)
        in_features = self.resnet.fc.in_features
        # The KAN head receives the flattened pooled features.
        self.resnet.fc = KAN(in_features, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.resnet(x)


## Define the training function

In [None]:
# Training Function
def train(model, loader, criterion, optimizer, device):
    """Run one optimisation epoch of ``model`` over ``loader``.

    The model is switched to training mode. For every batch the gradients are
    cleared, the loss is back-propagated, and the optimizer takes a step.

    Returns:
        ``(mean_loss, accuracy)``. The loss is averaged per sample, i.e. each
        batch's loss is weighted by its batch size.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * batch_x.size(0)
        preds = logits.max(dim=1).indices
        n_seen += batch_y.size(0)
        n_correct += preds.eq(batch_y).sum().item()

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

## Define the validation function

In [None]:
# Validation Function
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating it.

    The model is put in eval mode and autograd is disabled for the pass.

    Returns:
        ``(mean_loss, accuracy)``. The loss is averaged per sample, i.e. each
        batch's loss is weighted by its batch size.
    """
    model.eval()
    weighted_loss = 0.0
    hits = 0
    count = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            scores = model(inputs)
            weighted_loss += criterion(scores, targets).item() * inputs.size(0)
            hits += scores.max(dim=1).indices.eq(targets).sum().item()
            count += targets.size(0)
    return weighted_loss / count, hits / count


## Train the model

In [None]:
# Main Training Loop
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Two-class ResNet-18 + KAN model; cross-entropy on the logits; Adam over all
# parameters (backbone included).
model = KANResNet18(num_classes=2)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    # One optimisation pass over the training split, then an evaluation pass.
    train_loss, train_acc = train(
        model, train_loader, criterion, optimizer, device
    )
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(
        f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )



Epoch [1/10] Train Loss: 0.4631, Train Acc: 0.7788, Val Loss: 0.7292, Val Acc: 0.6700
Epoch [2/10] Train Loss: 0.3464, Train Acc: 0.8475, Val Loss: 0.5895, Val Acc: 0.7600
Epoch [3/10] Train Loss: 0.2543, Train Acc: 0.8988, Val Loss: 0.9409, Val Acc: 0.7250
Epoch [4/10] Train Loss: 0.3054, Train Acc: 0.8675, Val Loss: 0.7764, Val Acc: 0.7700
Epoch [5/10] Train Loss: 0.1846, Train Acc: 0.9287, Val Loss: 1.2411, Val Acc: 0.6900
Epoch [6/10] Train Loss: 0.2290, Train Acc: 0.9075, Val Loss: 0.3224, Val Acc: 0.8750
Epoch [7/10] Train Loss: 0.1744, Train Acc: 0.9263, Val Loss: 0.3311, Val Acc: 0.8550
Epoch [8/10] Train Loss: 0.1588, Train Acc: 0.9300, Val Loss: 0.5008, Val Acc: 0.8150
Epoch [9/10] Train Loss: 0.1199, Train Acc: 0.9537, Val Loss: 0.4418, Val Acc: 0.8750
Epoch [10/10] Train Loss: 0.1122, Train Acc: 0.9525, Val Loss: 0.3172, Val Acc: 0.9000


## Test the model

In [None]:
# Test Loop
# validate() returns only (loss, accuracy). Unpacking four values from it
# raises ValueError, so precision and F1 are computed from a separate
# prediction pass.
test_loss, test_acc = validate(model, test_loader, criterion, device)

model.eval()
test_preds, test_targets = [], []
with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images.to(device))
        test_preds.extend(logits.argmax(dim=1).cpu().tolist())
        test_targets.extend(labels.tolist())

# NOTE(review): macro averaging (unweighted mean over the classes) is an
# assumption — switch to average="binary"/"weighted" if a different summary is wanted.
test_precision = precision_score(test_targets, test_preds, average="macro", zero_division=0)
test_f1 = f1_score(test_targets, test_preds, average="macro", zero_division=0)
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test Precision: {test_precision:.4f}, Test F1: {test_f1:.4f}")

Test Loss: 0.3383, Test Acc: 0.8820, Test Precision: 0.8842, Test F1: 0.8819
