## Imports

In [None]:
import os
import h5py
import pandas as pd
import torch
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
import io
import numpy as np

## Set random seed for reproducibility

In [None]:
# Global seed used for every random number generator this notebook touches.
SEED = 42

# Seeds PyTorch's RNG on all devices (CPU and every CUDA device).
torch.manual_seed(SEED)

# NumPy is imported by this notebook as well; seed its global (legacy)
# RNG too so any np.random-based preprocessing or splitting is reproducible.
np.random.seed(SEED)


## Model Definition

In [None]:
# Pick the first CUDA device when PyTorch can see one, otherwise fall back
# to the CPU, and report the choice.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")


class SimpleCNN(nn.Module):
    """Two-stage convolutional classifier followed by a three-layer MLP head.

    Each convolutional stage is Conv2d(3x3, stride 1, padding 1) ->
    BatchNorm2d -> ReLU -> MaxPool2d(2, 2) -> Dropout, so the spatial size
    is halved (floored) twice before flattening. The head maps the flattened
    features through 256 -> 128 -> ``num_classes`` units and returns raw,
    unnormalised scores (no softmax is applied).

    Args:
        num_classes: Number of output units of the final linear layer.
        input_size: Spatial size of the 3-channel input images, either a
            single int (square images) or a ``(height, width)`` pair.
            Defaults to 128, which reproduces the original hard-coded
            ``64 * 32 * 32`` flattened feature size.

    Raises:
        ValueError: If either spatial dimension is smaller than 4, which
            would leave no features after the two pooling stages.
    """

    def __init__(self, num_classes, input_size=128):
        super().__init__()

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        if height < 4 or width < 4:
            raise ValueError(
                f"input_size must be at least 4x4, got {height}x{width}"
            )

        # Stage 1: 3 -> 32 channels; padding=1 keeps H and W, the pool halves them.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.3)

        # Stage 2: 32 -> 64 channels, halved again by the second pool.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.dropout2 = nn.Dropout(0.4)

        # Two floor-halvings equal floor division by 4; 64 is conv2's
        # output channel count. Derived rather than hard-coded so inputs
        # other than 128x128 work.
        flat_features = 64 * (height // 4) * (width // 4)
        self.fc1 = nn.Linear(flat_features, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``.

        ``x`` must be a 4-D ``(batch, 3, H, W)`` tensor whose H and W match
        the ``input_size`` given at construction.
        """
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.dropout1(x)

        x = self.pool2(self.relu(self.bn2(self.conv2(x))))
        x = self.dropout2(x)

        # Flatten everything except the batch dimension; unlike .view this
        # also works if an upstream op ever yields a non-contiguous tensor.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        # The conv-stage dropout modules are deliberately reused in the head
        # (p=0.3 then p=0.4); nn.Dropout has no parameters, so sharing is safe.
        x = self.dropout1(x)

        x = self.relu(self.fc2(x))
        x = self.dropout2(x)

        x = self.fc3(x)
        return x