In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import pandas as pd


def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` contribute; frozen parameters
    are skipped (buffers are never part of ``model.parameters()``).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


class PWConv_Batch_Relu(nn.Module):
    """Pointwise (1x1 by default) convolution -> BatchNorm2d -> ReLU6.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution and normalised by BN.
        kernel_size, stride, padding, bias: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False
    ):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
        )
        # Attribute names/order (pw_conv, bn, relu) define the state_dict keys.
        self.pw_conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU6()

    def forward(self, x):
        return self.relu(self.bn(self.pw_conv(x)))


class PWConv_Batch(nn.Module):
    """Pointwise (1x1 by default) convolution followed by BatchNorm2d.

    No activation is applied, so the output can be negative (a linear
    projection).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution and normalised by BN.
        kernel_size, stride, padding, bias: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False
    ):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
        )
        self.pw_conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.pw_conv(x))


class DwConv_Batch(nn.Module):
    """Depthwise convolution followed by BatchNorm2d (no activation).

    The convolution uses ``groups=in_channels`` and maps ``in_channels`` to
    ``in_channels``, so the block always outputs ``in_channels`` channels.
    ``out_channels`` is accepted for signature compatibility with the other
    conv blocks but does not change any layer shape.

    Args:
        in_channels: channels of the input (and output) feature map.
        out_channels: not used for layer shapes; kept for API compatibility.
        kernel_size, stride, padding, bias: forwarded to ``nn.Conv2d``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
    ):
        super(DwConv_Batch, self).__init__()
        self.dw_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        # BN must match the depthwise conv's output width (in_channels), as in
        # DwConv_Batch_Relu. Sizing it with out_channels made forward() fail
        # whenever in_channels != out_channels.
        self.bn = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        x = self.dw_conv(x)
        x = self.bn(x)
        return x


class DwConv_Batch_Relu(nn.Module):
    """Depthwise convolution -> BatchNorm2d -> ReLU6.

    With ``groups=in_channels`` every channel is convolved independently, so
    the block outputs ``in_channels`` channels; ``out_channels`` is accepted
    for signature compatibility but does not affect any layer shape.

    Args:
        in_channels: channels of the input (and output) feature map.
        out_channels: not used for layer shapes; kept for API compatibility.
        kernel_size, stride, padding, bias: forwarded to ``nn.Conv2d``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
    ):
        super().__init__()
        self.dw_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU6()

    def forward(self, x):
        return self.relu(self.bn(self.dw_conv(x)))


class ECA(nn.Module):
    """Efficient Channel Attention.

    Channels are reweighted by a sigmoid gate computed with a 1-D convolution
    that slides across the globally average-pooled channel descriptors.

    Args:
        in_channels: accepted for API symmetry; no layer depends on it (the
            1-D conv has one input and one output channel).
        kernel_size: width of the 1-D conv over the channel axis; padding is
            ``(kernel_size - 1) // 2`` so odd sizes preserve the channel count.
        bias: whether the 1-D conv has a bias term.
    """

    def __init__(self, in_channels, kernel_size=3, bias=False):
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=kernel_size, padding=same_padding, bias=bias
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # For a batched (B, C, H, W) input: pooled is (B, C, 1, 1) and the
        # descriptor is (B, 1, C), so the 1-D conv runs across channels.
        pooled = self.avg_pool(x)
        descriptor = pooled.squeeze(-1).transpose(-2, -1)
        gate = self.sigmoid(self.conv(descriptor))
        # Back to (B, C, 1, 1) so the gate broadcasts over H and W.
        gate = gate.transpose(-2, -1).unsqueeze(-1)
        return x * gate


class SGECA(nn.Module):
    """Bottleneck block built from 1x1 convs, ECA attention and a depthwise conv.

    Two layouts exist. The attribute names below are part of the
    ``state_dict`` and must not change.

    * ``config == 1``: pw_conv1 -> eca -> pw_conv2 -> dw_conv (stride
      ``final_stride``). No residual connection.
    * any other ``config`` (default 2): dw_conv -> pw_conv1 -> eca -> pw_conv2
      -> dw_conv2 (stride ``final_stride``). An identity shortcut is added
      when ``in_channels == out_channels`` and ``final_stride == 1``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        first_out_channel: hidden width between the two 1x1 convs.
        final_stride: stride of the last depthwise conv (spatial downsampling).
        config: layout selector, see above.
    """

    def __init__(
        self, in_channels, out_channels, first_out_channel, final_stride=2, config=2
    ):
        super(SGECA, self).__init__()
        self.config = config

        if config == 1:
            self.pw_conv1 = PWConv_Batch_Relu(
                in_channels,
                first_out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.eca = ECA(in_channels=first_out_channel, kernel_size=3, bias=False)
            self.pw_conv2 = PWConv_Batch(
                first_out_channel,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.dw_conv = DwConv_Batch(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=final_stride,
                padding=1,
                bias=False,
            )
            shortcut = False  # this layout never adds a residual
        else:
            self.dw_conv = DwConv_Batch_Relu(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False
            )
            self.pw_conv1 = PWConv_Batch_Relu(
                in_channels,
                first_out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.eca = ECA(in_channels=first_out_channel, kernel_size=3, bias=False)
            self.pw_conv2 = PWConv_Batch(
                first_out_channel,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.dw_conv2 = DwConv_Batch(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=final_stride,
                padding=1,
                bias=False,
            )
            shortcut = (in_channels == out_channels) and (final_stride == 1)

        # forward() consults activate_shortcut. use_shortcut is kept as an alias
        # for any code reading the old attribute. It used to be computed
        # separately and could be True for config 1, which has no residual.
        self.activate_shortcut = shortcut
        self.use_shortcut = shortcut

    def forward(self, x):
        if self.config == 1:
            x = self.pw_conv1(x)
            x = self.eca(x)
            x = self.pw_conv2(x)
            return self.dw_conv(x)

        # Every config other than 1 builds the dw -> pw -> ECA -> pw -> dw layers
        # in __init__, so run them here too. Previously configs other than 1/2
        # silently returned the input unchanged.
        residual = x
        out = self.dw_conv(x)
        out = self.pw_conv1(out)
        out = self.eca(out)
        out = self.pw_conv2(out)
        out = self.dw_conv2(out)
        if self.activate_shortcut:
            out = out + residual
        return out


class ParC_operator(nn.Module):
    """Depthwise "circular" convolution along one spatial axis with optional
    learned positional embedding.

    The input is extended along the chosen axis by appending all but its last
    row (``type == "H"``) or column (``type == "W"``). The result is then
    convolved, without padding, by a depthwise kernel of shape
    ``(global_kernel_size, 1)`` or ``(1, global_kernel_size)``.

    Args:
        dim: number of input (and output) channels.
        type: ``"H"`` (convolve along height) or ``"W"`` (along width).
        global_kernel_size: kernel length ``k`` along the chosen axis.
        use_pe: if True, add a learned embedding of shape ``(1, dim, k, 1)``
            (``"H"``) or ``(1, dim, 1, k)`` (``"W"``) before the convolution.
            It broadcasts against the input, so the chosen axis must have
            length ``k`` unless ``k == 1``.

    Shape note: for an input of length ``L`` along the chosen axis, the
    extended axis has length ``2L - 1``. The output length is therefore
    ``2L - k``, which equals ``L`` (a true circular conv) only when ``k == L``.

    Raises:
        ValueError: if ``type`` is neither ``"H"`` nor ``"W"``.
    """

    def __init__(self, dim, type, global_kernel_size, use_pe=True):
        super().__init__()
        if type not in ("H", "W"):
            raise ValueError(f"type must be 'H' or 'W', got {type!r}")
        self.type = type  # H or W
        self.dim = dim
        self.use_pe = use_pe
        self.global_kernel_size = global_kernel_size
        self.kernel_size = (
            (global_kernel_size, 1) if self.type == "H" else (1, global_kernel_size)
        )
        self.gcc_conv = nn.Conv2d(dim, dim, kernel_size=self.kernel_size, groups=dim)
        if use_pe:
            # Same orientation as the kernel: (1, dim, k, 1) or (1, dim, 1, k).
            self.pe = nn.Parameter(torch.randn(1, dim, *self.kernel_size))
            init.trunc_normal_(self.pe, std=0.02)

    def forward(self, x):
        if self.use_pe:
            # Plain broadcasting. This matches the old expand() to (1, dim, k, k)
            # wherever that worked, and no longer requires the *other* spatial
            # axis to equal k.
            x = x + self.pe

        if self.type == "H":
            x_cat = torch.cat((x, x[:, :, :-1, :]), dim=2)
        else:
            x_cat = torch.cat((x, x[:, :, :, :-1]), dim=3)
        return self.gcc_conv(x_cat)


class ParcSG(nn.Module):
    """ParC mixing block.

    The input channels are split in half. One half goes through a height-wise
    ParC_operator and the other through a width-wise one. Both results are
    bilinearly resized to a common spatial size and concatenated. The block
    then applies 1x1 conv+BN+ReLU6, 1x1 conv+BN, and a stride-2 depthwise
    conv+BN+ReLU6.

    Args:
        in_channels: channels entering ``pw_1`` (the concatenated halves).
        out_channels: channels produced by the block.
        dim: channel count of the incoming feature map. Each ParC branch is
            built for ``dim // 2`` channels, so ``dim`` is presumably even.
        global_kernel_size: kernel length of both ParC convolutions.
        use_pe: whether the ParC branches add positional embeddings.
        first_out_channel: hidden width between the two 1x1 convs.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dim,
        global_kernel_size=14,
        use_pe=True,
        first_out_channel=4,
    ):
        super().__init__()
        half_dim = dim // 2
        # Attribute names/order below define the state_dict keys.
        self.parc_H = ParC_operator(half_dim, "H", global_kernel_size, use_pe=use_pe)
        self.parc_W = ParC_operator(half_dim, "W", global_kernel_size, use_pe=use_pe)
        self.pw_1 = PWConv_Batch_Relu(
            in_channels, first_out_channel, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.pw2 = PWConv_Batch(
            first_out_channel, out_channels, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.dw = DwConv_Batch_Relu(
            out_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False
        )

    def forward(self, x):
        half_h, half_w = torch.chunk(x, 2, dim=1)
        feat_h = self.parc_H(half_h)
        feat_w = self.parc_W(half_w)

        # Each ParC branch can change the length of its own spatial axis, so
        # bring both to the element-wise maximum size before concatenating.
        # torch.jit.trace warns on this line and may bake the size in as a
        # constant for the traced input resolution.
        common_hw = (
            max(feat_h.shape[2], feat_w.shape[2]),
            max(feat_h.shape[3], feat_w.shape[3]),
        )
        resized = [
            F.interpolate(feat, size=common_hw, mode="bilinear", align_corners=False)
            for feat in (feat_h, feat_w)
        ]
        mixed = torch.cat(resized, dim=1)

        return self.dw(self.pw2(self.pw_1(mixed)))


# Define the model to test output shape
class LSGNet(nn.Module):
    """LSGNet image classifier.

    The stem is a 3x3 stride-2 conv + BN + ReLU followed by a 3x3 stride-2
    max-pool. It is followed by SGECA / ParcSG stages, a 1x1 conv to 1024
    channels, global average pooling and a linear classifier.

    Spatial sizes for a 3x224x224 input, derived from the layer strides:
    224 -> 112 (conv1) -> 56 (maxpool) -> 28 (sgeca1) -> 28 (sgeca2)
    -> 14 (sgeca3) -> 14 (sgeca4) -> 14 (parcsg1) -> 14 (parcsg2)
    -> 7 (sgeca5) -> 7 (parcsg3) -> 7 (conv2) -> 1 (global_avgpool).

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=8):
        super(LSGNet, self).__init__()
        # Submodules are registered in the original order, so parameter order
        # (optimizer state) and state_dict keys are unchanged.
        # Do NOT push a probe tensor through the layers here. In train mode that
        # would update every BatchNorm's running statistics with random data,
        # consume the global RNG and build an unused autograd graph.

        # Layer 1: Conv 3x3 (stride 2), BN, ReLU
        self.conv1 = nn.Conv2d(3, 24, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.relu = nn.ReLU(inplace=True)

        # Layer 2: max-pool 3x3 (stride 2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Layers 3-6: SGECA blocks (config 1 downsamples, config 2 is residual)
        self.sgeca1 = SGECA(24, 116, first_out_channel=32, config=1)
        self.sgeca2 = SGECA(116, 116, first_out_channel=32, final_stride=1, config=2)
        self.sgeca3 = SGECA(116, 232, first_out_channel=48, config=1)
        self.sgeca4 = SGECA(232, 232, first_out_channel=48, final_stride=1, config=2)

        # Layers 7-8: ParcSG blocks
        self.parcsg1 = ParcSG(
            in_channels=232,
            out_channels=232,
            dim=232,
            global_kernel_size=1,
            use_pe=True,
            first_out_channel=56,
        )
        self.parcsg2 = ParcSG(
            in_channels=232,
            out_channels=232,
            dim=232,
            global_kernel_size=1,
            use_pe=True,
            first_out_channel=56,
        )

        # Layer 9: SGECA (pw_eca_pw_dw)
        self.sgeca5 = SGECA(232, 464, first_out_channel=80, config=1)

        # Layer 10: ParcSG
        self.parcsg3 = ParcSG(
            in_channels=464,
            out_channels=464,
            dim=464,
            global_kernel_size=1,
            use_pe=True,
            first_out_channel=84,
        )

        # Layer 11: Conv 1x1 (stride 1), BN, ReLU
        self.conv2 = nn.Conv2d(464, 1024, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1024)
        self.relu2 = nn.ReLU(inplace=True)

        # Layer 12: global average pooling
        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Layer 13: classifier
        self.fc = nn.Linear(1024, num_classes, bias=True)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.sgeca1(x)
        x = self.sgeca2(x)
        x = self.sgeca3(x)
        x = self.sgeca4(x)
        x = self.parcsg1(x)
        x = self.parcsg2(x)
        x = self.sgeca5(x)
        x = self.parcsg3(x)
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.global_avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
    
    


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import time
import torch
import torch.nn as nn
import numpy as np

def save_model(model, filename="model_state_dict.pth"):
    """Save ``model`` as a full pickled module and as a ``state_dict``.

    Writes ``<dir>/<name>.pth`` via ``torch.save(model)`` and
    ``<dir>/s_<name>.pth`` via ``torch.save(model.state_dict())``.
    ``filename`` may include a directory and may or may not already end in
    ``.pth``. The suffix is never doubled, and the ``s_`` prefix is applied to
    the file name, not to the directory.

    Note: loading the full-module file unpickles arbitrary objects; only load
    files you trust.
    """
    directory, base = os.path.split(filename)
    if base.endswith(".pth"):
        base = base[: -len(".pth")]
    torch.save(model, os.path.join(directory, f"{base}.pth"))
    torch.save(model.state_dict(), os.path.join(directory, f"s_{base}.pth"))

def validate(device, model, val_loader):
    """Evaluate ``model`` on ``val_loader`` and print accuracy, precision and recall.

    Args:
        device: device each image batch is moved to (the model is assumed to
            already live there).
        model: classifier whose output has one score per class along dim 1.
        val_loader: iterable of ``(images, labels)`` batches whose ``dataset``
            exposes a ``classes`` sequence (e.g. a DataLoader over ImageFolder).

    Returns:
        ``(accuracy, precision, recall, confusion_matrix)``:
        - ``accuracy`` is a percentage.
        - ``precision`` / ``recall`` are macro averages over the classes that
          occur at least once in the labels.
        - ``confusion_matrix[true, pred]`` holds counts.
        An empty loader yields ``0.0`` accuracy instead of raising.

    Side effect: leaves ``model`` in eval mode.
    """
    model.eval()
    num_classes = len(val_loader.dataset.classes)
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=int)

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float32)
            predicted = model(images).argmax(dim=1)
            # Copy to host once per batch rather than one .item() sync per sample.
            y_true = labels.cpu().numpy()
            y_pred = predicted.cpu().numpy()
            np.add.at(confusion_matrix, (y_true, y_pred), 1)

    # Everything below is derived from the confusion matrix (as Python ints).
    tp = confusion_matrix.diagonal().tolist()
    support = confusion_matrix.sum(axis=1).tolist()  # tp + fn per class
    predicted_count = confusion_matrix.sum(axis=0).tolist()  # tp + fp per class
    present = [c for c in range(num_classes) if support[c] > 0]

    precision_list = [
        tp[c] / predicted_count[c] if predicted_count[c] > 0 else 0 for c in present
    ]
    recall_list = [tp[c] / support[c] for c in present]  # support[c] > 0 here

    total = sum(support)
    correct = sum(tp)
    overall_precision = sum(precision_list) / len(precision_list) if precision_list else 0
    overall_recall = sum(recall_list) / len(recall_list) if recall_list else 0
    accuracy = 100 * correct / total if total > 0 else 0.0

    print(f'Accuracy: {accuracy:.2f}%')
    print(f'Precision: {overall_precision:.2f}')
    print(f'Recall: {overall_recall:.2f}')

    return accuracy, overall_precision, overall_recall, confusion_matrix
    

def train(device, model, train_loader, criterion, optimizer, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs and print per-epoch loss and accuracy.

    Args:
        device: device the model and every batch are moved to.
        model: module to train; it is moved to ``device`` and cast to float32.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: loss callable, ``criterion(outputs, labels)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        ``(train_accuracy, train_losses, computation_time)``:
        - ``train_accuracy`` is the last epoch's accuracy as a fraction.
        - ``train_losses`` holds each epoch's mean batch loss.
        - ``computation_time`` is wall-clock seconds.
        With ``num_epochs == 0`` or an empty loader the accuracy is ``0.0``
        instead of raising.

    Side effect: leaves ``model`` in train mode.
    """
    model.to(device).float()
    model.train()
    train_losses = []
    train_accuracy = 0.0  # defined even if no epoch runs
    start = time.time()

    for epoch in range(num_epochs):
        print(f'Start epoch {epoch+1}/{num_epochs}')
        running_loss = 0.0
        train_correct = 0
        train_total = 0
        num_batches = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_total += labels.size(0)
            train_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_batches += 1

        # Guard against an empty loader instead of dividing by zero.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        train_losses.append(epoch_loss)
        train_accuracy = train_correct / train_total if train_total else 0.0
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {train_accuracy:.2f}')

    end = time.time()
    computation_time = end - start
    print(f'Training completed in {computation_time:.2f} seconds')
    print(f'Training accuracy: {train_accuracy:.2f}')
    print('--------------------------------')

    return train_accuracy, train_losses, computation_time

In [2]:
# Seed the global torch RNG: weight init, augmentation and shuffling draw from it.
SEED = 42
torch.manual_seed(SEED)

# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

data_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.5),
    transforms.ColorJitter(contrast=0.5),
    transforms.ColorJitter(saturation=0.5),
    transforms.RandomRotation(45),
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), ratio=(0.75, 1.33)),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

data_transforms_validation_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


num_epochs = 150
batch_size = 64
learning_rate = 0.0001


# Data loaders
dataset_path_training = '../../dataset-tomatoes/train'
dataset_path_validation = '../../dataset-tomatoes/validation'
validation_dataset = ImageFolder(root=dataset_path_validation, transform=data_transforms_validation_test)
train_dataset = ImageFolder(root=dataset_path_training, transform=data_transforms)
# Both splits must map folder names to the same label indices.
assert train_dataset.classes == validation_dataset.classes, "train/validation classes differ"
num_classes = len(train_dataset.classes)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Metrics do not depend on order; a fixed order keeps evaluation deterministic.
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

dataloaders = {
    'train': train_loader,
    'val': validation_loader
}

# Initialize the model, criterion, optimizer
model = LSGNet(num_classes=num_classes)
print(f"Number of trainable parameters: {count_parameters(model)}")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


Number of trainable parameters: 757203


In [19]:
# Train the model; validation runs in a separate cell.
training_accuracy, training_loss, time_elapsed_tr = train(
    device=device,
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)

print("Training completed, stats:")
print(f"Training accuracy: {training_accuracy}")
print(f"Training loss: {training_loss}")
print(f"Time elapsed: {time_elapsed_tr}")

Start epoch 1/150
Epoch 1/150, Loss: 1.4614, Accuracy: 0.48
Start epoch 2/150
Epoch 2/150, Loss: 1.0703, Accuracy: 0.63
Start epoch 3/150
Epoch 3/150, Loss: 0.9040, Accuracy: 0.70
Start epoch 4/150
Epoch 4/150, Loss: 0.7961, Accuracy: 0.73
Start epoch 5/150
Epoch 5/150, Loss: 0.7185, Accuracy: 0.75
Start epoch 6/150
Epoch 6/150, Loss: 0.6595, Accuracy: 0.78
Start epoch 7/150
Epoch 7/150, Loss: 0.6123, Accuracy: 0.79
Start epoch 8/150
Epoch 8/150, Loss: 0.5718, Accuracy: 0.80
Start epoch 9/150
Epoch 9/150, Loss: 0.5280, Accuracy: 0.82
Start epoch 10/150
Epoch 10/150, Loss: 0.4936, Accuracy: 0.83
Start epoch 11/150
Epoch 11/150, Loss: 0.4694, Accuracy: 0.84
Start epoch 12/150
Epoch 12/150, Loss: 0.4535, Accuracy: 0.85
Start epoch 13/150
Epoch 13/150, Loss: 0.4197, Accuracy: 0.86
Start epoch 14/150
Epoch 14/150, Loss: 0.3997, Accuracy: 0.86
Start epoch 15/150
Epoch 15/150, Loss: 0.3855, Accuracy: 0.87
Start epoch 16/150
Epoch 16/150, Loss: 0.3872, Accuracy: 0.87
Start epoch 17/150
Epoch 1

In [24]:
file_name = f"lsgnet-{num_epochs}-epochs"

# Export with TorchScript tracing. Switch to eval mode first. Tracing records
# BatchNorm's train/eval behaviour as a constant, and a train-mode forward pass
# would also overwrite the trained running statistics using the random example
# input.
model.eval()
input_test = torch.randn(1, 3, 224, 224)
input_test = input_test.to(device)

with torch.no_grad():
    # NOTE(review): ParcSG's resize size may be baked in by the tracer, so the
    # traced module presumably expects 224x224 inputs — confirm before reuse.
    traced_model = torch.jit.trace(model, input_test)

traced_model.save(f"{file_name}.pt")

  target_size = (max(x_H.shape[2], x_W.shape[2]), max(x_H.shape[3], x_W.shape[3]))


In [21]:

# Evaluate on the validation split; a RuntimeError (e.g. device/shape issues)
# is reported instead of aborting the notebook.
try:
    metrics = validate(device, model, validation_loader)
except RuntimeError as e:
    print(f"Runtime error: {e}")
else:
    validation_accuracy, validation_precision, validation_recall, confusion_matrix = metrics
    print(f"Validation Accuracy: {validation_accuracy}%")
    print(f"Validation Precision: {validation_precision}")
    print(f"Validation Recall: {validation_recall}")
    print(f"Confusion Matrix:\n{confusion_matrix}")
    

Accuracy: 96.38%
Precision: 0.95
Recall: 0.94
Validation Accuracy: 96.38084632516704%
Validation Precision: 0.9491139519066106
Validation Recall: 0.9448535313523843
Confusion Matrix:
[[224   0   1   2   6   1   1   0]
 [  2 120   4   2   3   0   0   0]
 [  1   3 204   1   0   1   0   0]
 [  0   0   0 118   1   0   1   0]
 [  3   4   1   6 191   2   2   1]
 [  1   0   0   4   0 619   0   0]
 [  1   1   3   2   0   2  55   2]
 [  0   0   0   0   0   0   0 200]]
