In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, accuracy_score
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil

# Define Gish activation function
class Gish(nn.Module):
    def __init__(self):
        super(Gish, self).__init__()

    def forward(self, x):
        x = torch.clamp(x, min=1e-5)  # Prevent nan issues in log
        return x * torch.log(2 - torch.exp(-x))

# Custom transformation function
def custom_transform():
    return transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(),  # Augmentation
        transforms.RandomRotation(10),     # Augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

# CNN Model
class CNNModel(nn.Module):
    def __init__(self, num_classes):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 32 * 32, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.gish = Gish()

    def forward(self, x):
        x = self.pool(self.gish(self.conv1(x)))
        x = self.pool(self.gish(self.conv2(x)))
        x = x.view(-1, 32 * 32 * 32)
        x = self.gish(self.fc1(x))
        x = self.fc2(x)
        return x

# Load dataset
train_data = datasets.ImageFolder('dataset/CaltechTinySplit/train', transform=custom_transform())
test_data = datasets.ImageFolder('dataset/CaltechTinySplit/test', transform=custom_transform())

num_classes = len(train_data.classes)
print(f"Detected {num_classes} classes in the dataset.")

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

# Model setup
model = CNNModel(num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.3)

# Training loop
epochs = 50
best_accuracy = 0.0

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    scheduler.step()

    # Validation
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            probabilities = torch.softmax(outputs, dim=1).numpy()  # Get probabilities
            preds = np.argmax(probabilities, axis=1)
            all_preds.extend(probabilities)  # Append probabilities
            all_labels.extend(labels.numpy())

    acc = accuracy_score(all_labels, np.argmax(all_preds, axis=1))  # Correctly compute accuracy
    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}, Accuracy: {acc:.4f}")

    # Save the best model weights
    if acc > best_accuracy:
        best_accuracy = acc
        torch.save(model.state_dict(), 'best_cnn_model.pth')

# Load the best model weights
model.load_state_dict(torch.load('best_cnn_model.pth', map_location=torch.device('cpu')))
model.eval()

# Evaluate on test set
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        probabilities = torch.softmax(outputs, dim=1).numpy()
        all_preds.extend(probabilities)  # Append probabilities
        all_labels.extend(labels.numpy())

# Convert to arrays
all_preds = np.array(all_preds)
all_labels_bin = label_binarize(all_labels, classes=list(range(num_classes)))

# Metrics
cm = confusion_matrix(np.argmax(all_labels_bin, axis=1), np.argmax(all_preds, axis=1))
print("Confusion Matrix:\n", cm)
print("Classification Report:\n", classification_report(np.argmax(all_labels_bin, axis=1), np.argmax(all_preds, axis=1)))

# Compute AUC
if len(np.unique(np.argmax(all_preds, axis=1))) > 1:  # At least two classes predicted
    auc = roc_auc_score(all_labels_bin, all_preds, multi_class='ovr')
else:
    auc = "N/A"
print(f"AUC Score: {auc}")


Detected 9 classes in the dataset.
Epoch 1/50, Loss: 1.7031, Accuracy: 0.4655
Epoch 2/50, Loss: 1.2540, Accuracy: 0.6782
Epoch 3/50, Loss: 1.0770, Accuracy: 0.6897
Epoch 4/50, Loss: 0.8874, Accuracy: 0.7299
Epoch 5/50, Loss: 0.7618, Accuracy: 0.7874
Epoch 6/50, Loss: 0.6488, Accuracy: 0.7816
Epoch 7/50, Loss: 0.5864, Accuracy: 0.7701
Epoch 8/50, Loss: 0.5324, Accuracy: 0.7701
Epoch 9/50, Loss: 0.4956, Accuracy: 0.8103
Epoch 10/50, Loss: 0.4819, Accuracy: 0.7931
Epoch 11/50, Loss: 0.4517, Accuracy: 0.7931
Epoch 12/50, Loss: 0.4223, Accuracy: 0.7989
Epoch 13/50, Loss: 0.3929, Accuracy: 0.8276
Epoch 14/50, Loss: 0.3840, Accuracy: 0.8103
Epoch 15/50, Loss: 0.3638, Accuracy: 0.8276
Epoch 16/50, Loss: 0.3625, Accuracy: 0.8276
Epoch 17/50, Loss: 0.3335, Accuracy: 0.8218
Epoch 18/50, Loss: 0.3234, Accuracy: 0.8161
Epoch 19/50, Loss: 0.3198, Accuracy: 0.8391
Epoch 20/50, Loss: 0.2917, Accuracy: 0.8276
Epoch 21/50, Loss: 0.2677, Accuracy: 0.8448
Epoch 22/50, Loss: 0.2600, Accuracy: 0.8448
Epoch 

  model.load_state_dict(torch.load('best_cnn_model.pth', map_location=torch.device('cpu')))


Confusion Matrix:
 [[43  0  0  0  0  1  0  0  0]
 [ 0 79  0  1  0  0  0  1  0]
 [ 0  0  3  1  0  0  0  1  0]
 [ 0  1  0  2  0  0  0  2  0]
 [ 0  0  0  0  6  1  0  0  0]
 [ 2  1  0  0  0  2  1  1  1]
 [ 0  0  0  0  0  0 10  0  0]
 [ 0  1  0  0  0  2  2  2  1]
 [ 0  0  0  0  0  1  0  0  5]]
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.98      0.97        44
           1       0.96      0.98      0.97        81
           2       1.00      0.60      0.75         5
           3       0.50      0.40      0.44         5
           4       1.00      0.86      0.92         7
           5       0.29      0.25      0.27         8
           6       0.77      1.00      0.87        10
           7       0.29      0.25      0.27         8
           8       0.71      0.83      0.77         6

    accuracy                           0.87       174
   macro avg       0.72      0.68      0.69       174
weighted avg       0.87      0.87   

In [2]:
from google.colab import files
uploaded = files.upload()



Saving CaltechTinySplit.zip to CaltechTinySplit.zip


In [3]:
!unzip CaltechTinySplit.zip -d ./dataset


Archive:  CaltechTinySplit.zip
   creating: ./dataset/CaltechTinySplit/
   creating: ./dataset/CaltechTinySplit/test/
   creating: ./dataset/CaltechTinySplit/test/camera/
  inflating: ./dataset/CaltechTinySplit/test/camera/image_0024.jpg  
  inflating: ./dataset/CaltechTinySplit/test/camera/image_0035.jpg  
  inflating: ./dataset/CaltechTinySplit/test/camera/image_0037.jpg  
  inflating: ./dataset/CaltechTinySplit/test/camera/image_0040.jpg  
  inflating: ./dataset/CaltechTinySplit/test/camera/image_0046.jpg  
  inflating: ./dataset/CaltechTinySplit/test/camera/Thumbs.db  
   creating: ./dataset/CaltechTinySplit/test/cannon/
  inflating: ./dataset/CaltechTinySplit/test/cannon/image_0024.jpg  
  inflating: ./dataset/CaltechTinySplit/test/cannon/image_0035.jpg  
  inflating: ./dataset/CaltechTinySplit/test/cannon/image_0037.jpg  
  inflating: ./dataset/CaltechTinySplit/test/cannon/image_0038.jpg  
  inflating: ./dataset/CaltechTinySplit/test/cannon/image_0040.jpg  
  inflating: ./dataset