In [1]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import torch.optim as optim
import time


def train(train_dir='./dataset/train', test_dir='./dataset/test',
          num_epochs=20, batch_size=16, learning_rate=0.001, log_every=20):
    """Train a two-branch CNN image classifier and print its test accuracy.

    The function builds the data pipeline, defines the network, trains it with
    Adam + cross-entropy, and finally evaluates it on the test split. All
    progress is reported through ``print``; nothing is returned.

    Args:
        train_dir: Root directory handed to ``datasets.ImageFolder`` for the
            training split.
        test_dir: Root directory handed to ``datasets.ImageFolder`` for the
            test split.
        num_epochs: Number of full passes over the training set.
        batch_size: Mini-batch size used by both data loaders.
        learning_rate: Learning rate passed to ``optim.Adam``.
        log_every: Print the mean training loss (and the elapsed wall-clock
            time) once every ``log_every`` mini-batches.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.

    Note:
        The classifier head emits 3 logits, so the datasets are assumed to
        contain at most 3 classes -- TODO confirm against the dataset layout.
    """
    if log_every < 1:
        raise ValueError('log_every must be >= 1, got %r' % (log_every,))

    # TRANSFORMATION
    # Resize every image to 224x224, then convert it to a torch.Tensor
    # (pixel values 0..255 become floats in 0.0..1.0).
    data_transform = transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ])

    # DATA SETS
    train_data = datasets.ImageFolder(root=train_dir, transform=data_transform)
    test_data = datasets.ImageFolder(root=test_dir, transform=data_transform)

    # DATA LOADERS
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Accuracy does not depend on sample order, so the test split is not shuffled.
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    class Net(nn.Module):
        """CNN with a shared stem followed by a "left" and a "right" branch.

        Spatial sizes quoted in the comments below are for a 3x224x224 input.
        """

        def __init__(self):
            super(Net, self).__init__()
            # Stem: 7x7 convolution + batch norm, resolution preserved.
            self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3)
            self.bn1 = nn.BatchNorm2d(32)
            # Shared convolutional blocks
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
            self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

            # Branch entry points: both read the 64-channel shared features.
            self.conv4_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
            self.conv4_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
            # Left side
            self.conv5 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
            self.conv6 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
            self.conv7 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
            # Right side: 192 input channels = 128 (conv4_1) + 64 (conv4_2).
            self.conv8 = nn.Conv2d(192, 256, kernel_size=3, padding=1)
            self.conv9 = nn.Conv2d(192, 256, kernel_size=3, padding=1)

            # Fusion: 512 input channels = 256 (left) + 256 (right).
            self.conv10 = nn.Conv2d(512, 920, kernel_size=3, padding=1)

            # Pooling layers
            self.maxpool = nn.MaxPool2d(kernel_size=5, stride=5, padding=0)
            self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=4, padding=0)
            # Adaptive pooling pins the spatial size to 3x3, so the flattened
            # feature length fed to fc1 is always 920 * 3 * 3.
            self.avgpool2 = nn.AdaptiveAvgPool2d((3, 3))

            # Fully connected head
            self.fc1 = nn.Linear(920 * 3 * 3, 920)
            self.fc6 = nn.Linear(920, 3)

        def forward(self, x):
            """Return the raw class logits (3 per sample) for a batch of images."""
            x = F.relu(self.bn1(self.conv1(x)))

            # Pooling: 224x224 -> 44x44
            x = self.maxpool(x)
            x = F.relu(self.conv2(x))
            x = F.relu(self.conv3(x))
            # Left side: two parallel convolutions over x_L, summed.
            x_1 = F.relu(self.conv4_1(x))
            x_L = F.relu(self.conv5(x_1))
            x_L_1 = F.relu(self.conv6(x_L))
            x_L_2 = F.relu(self.conv7(x_L))
            x_L = F.relu(torch.add(x_L_1, x_L_2))
            # Right side: reuses x_1 from the left side via channel concatenation.
            x_2 = F.relu(self.conv4_2(x))
            x_R = torch.cat((x_1, x_2), 1)
            x_R_1 = F.relu(self.conv8(x_R))
            x_R_2 = F.relu(self.conv9(x_R))
            x_R = F.relu(torch.add(x_R_1, x_R_2))
            # Merge both branches along the channel axis: 256 + 256 = 512.
            x = torch.cat((x_L, x_R), 1)
            x = F.relu(self.avgpool1(x))  # 44x44 -> 11x11
            x = F.relu(self.conv10(x))
            x = F.relu(self.avgpool2(x))  # -> 3x3
            # Flatten to (batch, 920 * 3 * 3)
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc1(x))
            # No activation on the last layer: CrossEntropyLoss expects raw logits.
            x = self.fc6(x)

            return x

    net = Net()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    net.to(device)
    summary(net, input_size=(3, 224, 224))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        # Training mode: BatchNorm normalises with the statistics of the
        # current batch and updates its running estimates.
        net.train()

        running_loss = 0.0
        time_start = time.time()
        for i, (inputs, labels) in enumerate(train_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % log_every == log_every - 1:  # once every `log_every` mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0
                print('Time:', time.time() - time_start)
                time_start = time.time()

    print('Finished Training')

    # Evaluation mode: BatchNorm must use its learned running statistics.
    # Without this call every test batch would be normalised with its own
    # statistics (and would keep updating the running estimates), making the
    # reported accuracy depend on how the test set happens to be batched.
    net.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            # Predicted class = index of the largest logit.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Report the number of images actually evaluated, not a hard-coded count.
    print('Accuracy of the network on the %d test images: %d %%' % (
        total, 100 * correct / total))


# Entry point: start the full train-and-evaluate run when this code is
# executed as the main program.
if __name__ == "__main__":
    train()

Training on: cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 224, 224]           4,736
       BatchNorm2d-2         [-1, 32, 224, 224]              64
         MaxPool2d-3           [-1, 32, 44, 44]               0
            Conv2d-4           [-1, 64, 44, 44]          18,496
            Conv2d-5           [-1, 64, 44, 44]          36,928
            Conv2d-6          [-1, 128, 44, 44]          73,856
            Conv2d-7          [-1, 128, 44, 44]         147,584
            Conv2d-8          [-1, 256, 44, 44]         295,168
            Conv2d-9          [-1, 256, 44, 44]         295,168
           Conv2d-10           [-1, 64, 44, 44]          36,928
           Conv2d-11          [-1, 256, 44, 44]         442,624
           Conv2d-12          [-1, 256, 44, 44]         442,624
        AvgPool2d-13          [-1, 512, 11, 11]               0
           Conv2d-14 