In [133]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from PIL import ImageEnhance, ImageFilter
from torchmetrics import Accuracy
from torchinfo import summary

import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
sys.path.append('../')  

from Models.alexnet import AlexNet

import numpy as np

In [97]:
# Root folder of the image data, resolved against the notebook's working
# directory. It is later handed to DropletDataset as `data_dir`.
# os.path.join inserts the separator of the host OS; the previous hard-coded
# '\\' only produced a usable path on Windows.
base_path = os.path.join(os.getcwd(), 'data')

In [129]:
class EdgeEnhancement:
    """Callable that runs PIL's ``ImageFilter.FIND_EDGES`` over an image.

    NOTE(review): despite the class name, the filter used is FIND_EDGES
    rather than EDGE_ENHANCE -- confirm this is the intended filter.
    """

    def __call__(self, img):
        """Return the result of ``img.filter(ImageFilter.FIND_EDGES)``."""
        edge_map = img.filter(ImageFilter.FIND_EDGES)
        return edge_map

# Augmentation and preprocessing pipeline applied to each PIL image.
# Disabled alternatives (not part of the pipeline): a plain Resize(224) in
# place of the random crop, and the EdgeEnhancement filter step.
transform = transforms.Compose([
    # Collapse the image to a single luminance channel.
    transforms.Grayscale(num_output_channels=1),
    # Scale up first so the random crop below has some margin to work with.
    transforms.Resize(size=256),
    # Take a random region and resize it to 224x224.
    transforms.RandomResizedCrop(size=224),
    # Mirror left/right at random.
    transforms.RandomHorizontalFlip(),
    # Rotate by a random angle within +/- 15 degrees.
    transforms.RandomRotation(degrees=15),
    # Randomly perturb brightness and contrast.
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    # PIL image -> tensor.
    transforms.ToTensor(),
    # Shift/scale the single channel with mean 0.5 and std 0.5.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])




class DropletDataset(Dataset):
    """Two-class image dataset read from ``<data_dir>/background`` and
    ``<data_dir>/droplets``.

    Files under ``background`` are labelled 0 and files under ``droplets``
    are labelled 1. Only file paths are collected at construction time; the
    image itself is opened when a sample is requested.
    """

    # Sub-directory names; a class's integer label is its index in this tuple.
    CLASS_NAMES = ('background', 'droplets')

    def __init__(self, data_dir, transform=None, extensions=('.jpg',)):
        """
        Args:
            data_dir (string): Directory containing one sub-directory per
                class (see ``CLASS_NAMES``).
            transform (callable, optional): Optional transform to be applied
                on a sample.
            extensions (str or tuple of str, optional): File-name suffixes to
                accept, matched case-insensitively. Defaults to ``('.jpg',)``.

        Raises:
            OSError: propagated from ``os.listdir`` when a class
                sub-directory cannot be listed (e.g. it does not exist).
        """
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []

        # Accept a bare string so that '.png' is not iterated char by char.
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)

        # Load image paths and labels
        for label, class_name in enumerate(self.CLASS_NAMES):
            class_dir = os.path.join(data_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps
            # sample indices stable across runs and machines.
            for filename in sorted(os.listdir(class_dir)):
                if filename.lower().endswith(suffixes):
                    self.images.append(os.path.join(class_dir, filename))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of image paths collected."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened, converted to RGB and, when a transform was
        supplied, passed through it.
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # The context manager releases the file handle deterministically;
        # convert() yields an independent image that stays valid afterwards.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')

        # Identity check rather than truthiness: a callable container that
        # happens to be empty/falsy must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label
    

# Index the labelled images under base_path; `transform` is applied to each
# sample when it is fetched.
droplet_dataset = DropletDataset(
    data_dir=base_path,
    transform=transform,
)


# Use the first CUDA device when PyTorch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [135]:
# Build a 10-class, single-channel AlexNet so its parameter shapes line up
# with the saved weights (presumably MNIST-trained, going by the file name).
alexnet_droplet = AlexNet(num_classes=10, channels=1).to(device)

# map_location remaps the stored tensors onto `device`; without it a
# checkpoint written from a GPU cannot be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (consider weights_only=True where the installed PyTorch
# supports it).
alexnet_droplet.load_state_dict(
    torch.load('alexnet_model_mnist_full.pth', map_location=device)
)

<All keys matched successfully>

In [137]:
# Layer-by-layer report for one single-channel 224x224 input.
summary(
    alexnet_droplet,
    input_size=(1, 1, 224, 224),
)

Layer (type:depth-idx)                   Output Shape              Param #
AlexNet                                  [1, 10]                   --
├─Sequential: 1-1                        [1, 256, 6, 6]            --
│    └─Conv2d: 2-1                       [1, 64, 55, 55]           7,808
│    └─ReLU: 2-2                         [1, 64, 55, 55]           --
│    └─MaxPool2d: 2-3                    [1, 64, 27, 27]           --
│    └─Conv2d: 2-4                       [1, 192, 27, 27]          307,392
│    └─ReLU: 2-5                         [1, 192, 27, 27]          --
│    └─MaxPool2d: 2-6                    [1, 192, 13, 13]          --
│    └─Conv2d: 2-7                       [1, 384, 13, 13]          663,936
│    └─ReLU: 2-8                         [1, 384, 13, 13]          --
│    └─Conv2d: 2-9                       [1, 256, 13, 13]          884,992
│    └─ReLU: 2-10                        [1, 256, 13, 13]          --
│    └─Conv2d: 2-11                      [1, 256, 13, 13]          

In [139]:
# Replace the last classifier layer so the network outputs 2 classes instead
# of 10. The input width is read from the layer being replaced rather than
# hard-coded, so it cannot drift from the model definition. The new layer
# starts from a fresh initialisation; all other layers keep their loaded
# weights.
previous_head = alexnet_droplet.classifier[6]
alexnet_droplet.classifier[6] = nn.Linear(previous_head.in_features, 2).to(device)


In [141]:
# Classification loss.
criterion = nn.CrossEntropyLoss()

# Two optimizers over the same model parameters. Each one gets its own
# parameters() call because the returned iterator is consumed on use.
# Nothing in this cell chooses between them.
optimizer_adam = optim.Adam(params=alexnet_droplet.parameters(), lr=1e-4)
optimizer_sgd = optim.SGD(
    params=alexnet_droplet.parameters(),
    lr=1e-3,
    momentum=0.9,
)