# C8Steerable CNN


In [1]:
try:
    done
except:
    import os

    os.chdir("../../../")
    from utils.download import download
    from utils.extract import extract

    args = {"model": "Model-1"}
    download(args)
    extract("data/Model_I.tgz", "data/")
    extract("data/Model_I_test.tgz", "data/")
    done = True

In [2]:
import copy
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from itertools import cycle
from PIL import Image
from sklearn.metrics import (
    auc,
    confusion_matrix,
    ConfusionMatrixDisplay,
    roc_auc_score,
    roc_curve,
)
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms


from tqdm import tqdm

warnings.filterwarnings("ignore")

# Data Preparation


In [3]:
train_transforms = transforms.Compose([
    transforms.RandomRotation(180),
    transforms.ColorJitter(brightness=(0.8, 1.2)),
    transforms.Resize(128),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor()
])

test_transforms = transforms.Compose([
    transforms.Resize(128),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor()
])

In [4]:
class AxionDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.class_name = "axion"
        self.class_folder = os.path.join(self.root_dir, self.class_name)
        self.file_list = [f for f in os.listdir(self.class_folder) if f.endswith('.npy')]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_name = self.file_list[idx]
        file_path = os.path.join(self.class_folder, file_name)
        data = np.load(file_path, allow_pickle=True)
        image = data[0]
        weight = data[1]
        image = 255 * (image / image.max())
        image = Image.fromarray(image.astype('uint8')).convert("RGB")
        if self.transform:
            image = self.transform(image)

        return image.float(), weight.astype('float32')
    

trainset = AxionDataset(
    root_dir='data/Model_I',
    transform = train_transforms
)

testset = AxionDataset(
    root_dir='data/Model_I_test',
    transform = test_transforms
)

train_loader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True,pin_memory=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=True)

In [5]:
lr = 0.0001
epochs = 10
gamma = 0.7
batch_size = 64
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model


In [6]:
model = torchvision.models.resnet50(pretrained=False)
model.fc = torch.nn.Linear(2048, 1, bias=True)
model = model.to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15)

# Training


In [None]:
all_train_loss = []
all_test_loss = []


best_loss = np.inf

for epoch in range(epochs):
    epoch_loss = 0
    model.train()
    tr_loss_epoch = []
    test_loss_epoch = []
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label).float()
        tr_loss_epoch.append(loss.item())
       
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


        epoch_loss += loss / len(train_loader)
    scheduler.step()
    all_train_loss.append(np.asarray(tr_loss_epoch))


    torch.cuda.empty_cache()
    with torch.no_grad():
        epoch_val_loss = 0
        for data, label in tqdm(test_loader):
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)
            test_loss_epoch.append(val_loss.item())

            epoch_val_loss += val_loss.item() / len(test_loader)
        all_test_loss.append(np.asarray(test_loss_epoch))

    if epoch_val_loss < best_loss:
        best_loss = epoch_val_loss
        best_model = copy.deepcopy(model)

    print(
        f"Epoch : {epoch+1} - train loss : {epoch_loss:.4f} - val_loss : {epoch_val_loss:.4f}%"
    )
torch.save(best_model.state_dict(), "equi_nn.pth")
all_train_loss_mean = [j.mean() for j in all_train_loss]
all_test_loss_mean = [j.mean() for j in all_test_loss]

100%|███████████████████████████████████████████████████| 452/452 [01:05<00:00,  6.89it/s]
100%|█████████████████████████████████████████████████████| 79/79 [00:09<00:00,  8.07it/s]


Epoch : 1 - train loss : 0.0176 - val_loss : 0.0048%


100%|███████████████████████████████████████████████████| 452/452 [01:04<00:00,  7.04it/s]
100%|█████████████████████████████████████████████████████| 79/79 [00:09<00:00,  8.59it/s]


Epoch : 2 - train loss : 0.0048 - val_loss : 0.0038%


100%|███████████████████████████████████████████████████| 452/452 [01:05<00:00,  6.91it/s]
100%|█████████████████████████████████████████████████████| 79/79 [00:09<00:00,  8.60it/s]


Epoch : 3 - train loss : 0.0031 - val_loss : 0.0022%


100%|███████████████████████████████████████████████████| 452/452 [01:06<00:00,  6.81it/s]
100%|█████████████████████████████████████████████████████| 79/79 [00:09<00:00,  8.19it/s]


Epoch : 4 - train loss : 0.0019 - val_loss : 0.0019%


100%|███████████████████████████████████████████████████| 452/452 [01:06<00:00,  6.78it/s]
100%|█████████████████████████████████████████████████████| 79/79 [00:09<00:00,  8.49it/s]


Epoch : 5 - train loss : 0.0018 - val_loss : 0.0015%


100%|███████████████████████████████████████████████████| 452/452 [01:06<00:00,  6.76it/s]
100%|█████████████████████████████████████████████████████| 79/79 [00:09<00:00,  8.01it/s]


Epoch : 6 - train loss : 0.0014 - val_loss : 0.0010%


100%|███████████████████████████████████████████████████| 452/452 [01:06<00:00,  6.75it/s]
100%|█████████████████████████████████████████████████████| 79/79 [00:09<00:00,  8.09it/s]


Epoch : 7 - train loss : 0.0012 - val_loss : 0.0006%


100%|███████████████████████████████████████████████████| 452/452 [01:06<00:00,  6.77it/s]
100%|█████████████████████████████████████████████████████| 79/79 [00:09<00:00,  8.00it/s]


Epoch : 8 - train loss : 0.0013 - val_loss : 0.0018%


100%|███████████████████████████████████████████████████| 452/452 [01:06<00:00,  6.79it/s]
100%|█████████████████████████████████████████████████████| 79/79 [00:09<00:00,  7.98it/s]


Epoch : 9 - train loss : 0.0007 - val_loss : 0.0010%


 25%|████████████▋                                      | 112/452 [00:16<00:49,  6.84it/s]

# Plotting Loss and Accuracy


In [None]:
figure = plt.figure(figsize=(12, 8))
cols, rows = 2, 1
figure.add_subplot(rows, cols, 1)
plt.title("Train loss Mean")
plt.plot(all_train_loss_mean)
figure.add_subplot(rows, cols, 2)
plt.title("Test loss Mean")
plt.plot(all_test_loss_mean)

# Loading Best Model


In [None]:
model = torchvision.models.resnet50(pretrained=False)
model.fc = torch.nn.Linear(2048, 1, bias=True)
model = model.to(device)
model.load_state_dict(torch.load("equi_nn.pth"), strict=False)

# Testing


In [None]:
import torch.nn.functional as F
with torch.no_grad():
    model.eval()
    rmse_list = []
    mse_list = []
    mae_list = []
    
    for i, (x, y) in enumerate(tqdm(test_loader)):
        x = x.to(device)
        y = y.to(device)
        _y = model(x)
        
        # Calculate metrics
        rmse = torch.sqrt(F.mse_loss(_y, y))
        mse = F.mse_loss(_y, y)
        mae = F.l1_loss(_y, y)
        
        rmse_list.append(rmse.item())
        mse_list.append(mse.item())
        mae_list.append(mae.item())

# Calculate overall metrics
avg_rmse = np.mean(rmse_list)
avg_mse = np.mean(mse_list)
avg_mae = np.mean(mae_list)

print("Average RMSE:", avg_rmse)
print("Average MSE:", avg_mse)
print("Average MAE:", avg_mae)