In [51]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import torch.optim as optim
import time


def train(num_epochs=5, batch_size=16, learning_rate=0.001, log_every=20):
    """Train the two-branch CNN on ``./dataset/train`` and report test accuracy.

    Images are loaded with ``ImageFolder`` from ``./dataset/train`` and
    ``./dataset/test``, resized to 224x224 and converted to tensors. The
    network is trained with Adam and cross-entropy loss, then evaluated on
    the test split. Progress and the final accuracy are printed; nothing is
    returned.

    Args:
        num_epochs: Number of passes over the training set.
        batch_size: Mini-batch size used by both data loaders.
        learning_rate: Learning rate passed to the Adam optimizer.
        log_every: Print the mean training loss every this many mini-batches.
    """
    # Resize first so every image has the same spatial size, then convert to
    # a float tensor (ToTensor scales pixel values from 0..255 to 0.0..1.0).
    data_transform = transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ])

    # DATA SETS
    train_data = datasets.ImageFolder(root='./dataset/train', transform=data_transform)
    test_data = datasets.ImageFolder(root='./dataset/test', transform=data_transform)

    # DATA LOADERS
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Accuracy does not depend on sample order, so the test set is not shuffled.
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    class Net(nn.Module):
        """CNN with a shared stem followed by a 5x5 branch and a 3x3 branch.

        The two branches each produce 256 channels; they are concatenated,
        pooled to 3x3 and classified by three fully connected layers.
        """

        def __init__(self, num_classes=3):
            super().__init__()
            # Stem
            self.conv0 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3)
            self.bn1 = nn.BatchNorm2d(32)
            self.conv1 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
            self.bn2 = nn.BatchNorm2d(64)

            # 5x5 branch, left side
            self.conv2 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
            self.conv3 = nn.Conv2d(128, 128, kernel_size=5, padding=2)
            self.conv4 = nn.Conv2d(128, 256, kernel_size=5, padding=2)
            self.conv5 = nn.Conv2d(128, 256, kernel_size=5, padding=2)
            # 5x5 branch, right side
            self.conv6 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
            self.conv7 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
            self.conv8 = nn.Conv2d(64, 256, kernel_size=5, padding=2)
            self.conv9 = nn.Conv2d(64, 256, kernel_size=5, stride=1, padding=2)

            # Fuse the two 256-channel sides of the 5x5 branch (512 -> 256)
            self.conv10 = nn.Conv2d(512, 512, kernel_size=5, padding=2)
            self.conv11 = nn.Conv2d(512, 256, kernel_size=5, padding=2)

            # 3x3 branch trunk
            self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
            self.conv13 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
            self.conv14 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
            self.conv15 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
            # 3x3 branch, left side
            self.conv16 = nn.Conv2d(64, 96, kernel_size=3, padding=1)
            self.conv17 = nn.Conv2d(96, 128, kernel_size=3, padding=1)
            # 3x3 branch, right side
            self.conv18 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

            # Applied to the concatenated 128 + 128 channels of the 3x3 branch
            self.conv19 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
            self.conv20 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
            self.conv21 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

            # Pooling layers
            self.maxpool1 = nn.MaxPool2d(kernel_size=5, stride=5, padding=0)
            self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=4, padding=0)
            # Adaptive pooling fixes the spatial size at 3x3, so fc1's input
            # width (512 * 3 * 3) holds for any input resolution.
            self.avgpool2 = nn.AdaptiveAvgPool2d((3, 3))

            # Classifier head
            self.fc1 = nn.Linear(512 * 3 * 3, 128)
            self.fc2 = nn.Linear(128, 64)
            self.fc3 = nn.Linear(64, num_classes)

        def forward(self, x):
            """Return raw class logits for a batch of images."""
            # Stem
            x = F.relu(self.bn1(self.conv0(x)))
            x = self.maxpool1(x)
            x = F.relu(self.bn2(self.conv1(x)))
            x_avg1 = self.avgpool1(x)

            # Branch 5x5, left side
            x_1 = F.relu(self.conv2(x_avg1))
            x_1 = F.relu(self.conv3(x_1))
            x_1_L = F.relu(self.conv4(x_1))
            x_1_R = F.relu(self.conv5(x_1))
            x_1 = F.relu(torch.add(x_1_L, x_1_R))
            # Branch 5x5, right side
            x_2 = F.relu(self.conv6(x_avg1))
            x_2 = F.relu(self.conv7(x_2))
            x_2_L = F.relu(self.conv8(x_2))
            x_2_R = F.relu(self.conv9(x_2))
            x_2 = F.relu(torch.add(x_2_L, x_2_R))
            # Concatenate both sides along the channel axis and fuse them
            x_12 = torch.cat((x_1, x_2), 1)
            x_12 = F.relu(self.conv10(x_12))
            x_12 = F.relu(self.conv11(x_12))

            # Branch 3x3
            x_3 = F.relu(self.conv12(x_avg1))
            x_3 = F.relu(self.conv13(x_3))
            x_3 = F.relu(self.conv14(x_3))
            x_3 = F.relu(self.conv15(x_3))
            x_3_L = F.relu(self.conv16(x_3))
            x_3_L = F.relu(self.conv17(x_3_L))
            x_3_R = F.relu(self.conv18(x_3))

            x_3 = torch.cat((x_3_L, x_3_R), 1)
            x_3 = F.relu(self.conv19(x_3))
            x_3 = F.relu(self.conv20(x_3))
            x_3 = F.relu(self.conv21(x_3))

            # Merge both branches (256 + 256 channels) and pool to 3x3.
            x = torch.cat((x_12, x_3), 1)
            # Both inputs are ReLU outputs, so this ReLU leaves the pooled
            # values unchanged; kept to mirror the original architecture.
            x = F.relu(self.avgpool2(x))

            x = x.view(x.size(0), -1)  # Flatten
            x = F.relu(self.fc1(x))
            # NOTE(review): there is no activation between fc2 and fc3, so the
            # two layers compose into a single linear map — confirm intended.
            x = self.fc2(x)
            x = self.fc3(x)

            return x

    # One output logit per class folder found by ImageFolder; CrossEntropyLoss
    # requires every label index to be smaller than the number of logits.
    net = Net(num_classes=len(train_data.classes))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    net.to(device)
    summary(net, input_size=(3, 224, 224))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        # Training mode: BatchNorm uses batch statistics and updates its
        # running estimates.
        net.train()

        running_loss = 0.0
        time_start = time.time()
        for i, (inputs, labels) in enumerate(train_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if (i + 1) % log_every == 0:
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
                running_loss = 0.0
                print('Time:', time.time() - time_start)
                time_start = time.time()

    print('Finished Training')

    # Evaluation mode: BatchNorm must use its learned running statistics and
    # must not be updated by test batches.
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_dataloader:
            images, labels = images.to(device), labels.to(device)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Avoid a ZeroDivisionError if the loader yielded no samples.
        print('No test images found; accuracy not computed')
    else:
        print(f'Accuracy of the network on the {total} test images: '
              f'{100 * correct / total:.2f} %')


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if "__main__" == __name__:
    train()

Training on: cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 224, 224]           4,736
       BatchNorm2d-2         [-1, 32, 224, 224]              64
         MaxPool2d-3           [-1, 32, 44, 44]               0
            Conv2d-4           [-1, 64, 44, 44]          51,264
       BatchNorm2d-5           [-1, 64, 44, 44]             128
         AvgPool2d-6           [-1, 64, 11, 11]               0
            Conv2d-7          [-1, 128, 11, 11]         204,928
            Conv2d-8          [-1, 128, 11, 11]         409,728
            Conv2d-9          [-1, 256, 11, 11]         819,456
           Conv2d-10          [-1, 256, 11, 11]         819,456
           Conv2d-11           [-1, 64, 11, 11]         102,464
           Conv2d-12           [-1, 64, 11, 11]         102,464
           Conv2d-13          [-1, 256, 11, 11]         409,856
           Conv2d-14 