In [1]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from PIL import ImageEnhance, ImageFilter
from torchmetrics import Accuracy
from torchinfo import summary

# pretrained models (AlexNet) from torchvision

import torchvision.models as models

import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
sys.path.append('../')  

from Models.alexnet import AlexNet

import numpy as np

In [2]:
base_path = os.getcwd() + '\\data'

In [3]:
class EdgeEnhancement:
    def __call__(self, img):
        return img.filter(ImageFilter.FIND_EDGES)

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.Resize(256),  # Resize slightly larger than final size
    transforms.RandomResizedCrop(224),  # Random crop back down to 224x224
    transforms.RandomHorizontalFlip(),  # Randomly flip the images horizontally
    transforms.RandomRotation(15),  # Rotate by +/- 15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, ),  # Randomly change brightness and contrast
    # transforms.Resize(224),  # Resize to 224x224 to match AlexNet input size
    # transforms.Lambda(lambda img: EdgeEnhancement()(img)),
    transforms.ToTensor(),   # Convert the image to a tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize the images
])




class DropletDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []

        # Load images and labels
        for label in ['background', 'droplets']:
            class_dir = os.path.join(data_dir, label)
            for filename in os.listdir(class_dir):
                if filename.endswith('.jpg'):  # Modify if needed for different extensions
                    img_path = os.path.join(class_dir, filename)
                    self.images.append(img_path)
                    self.labels.append(1 if label == 'droplets' else 0)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')  # Convert to RGB if not already
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)
        
        label = torch.tensor(label)
        return image, label
    

droplet_dataset = DropletDataset(data_dir=base_path, transform=transform)

dataloader = DataLoader(droplet_dataset, batch_size=32, shuffle=True, num_workers=0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# get alexnet pretrained on ImageNet
alexnet_droplet_v2 = models.alexnet(weights=True)


alexnet_droplet = AlexNet(num_classes=10, channels=1).to(device)

alexnet_droplet.load_state_dict(torch.load('alexnet_model_mnist_full.pth'))



<All keys matched successfully>

In [None]:
summary(alexnet_droplet, input_size=(1, 1, 224, 224))

In [None]:
summary(alexnet_droplet_v2, input_size=(1, 3, 224, 224))

In [5]:
# Modify the last layer of the classifier to output 2 classes instead of 10
alexnet_droplet.classifier[6] = nn.Linear(4096, 2).to(device)

alexnet_droplet_v2.classifier[6] = torch.nn.Linear(alexnet_droplet_v2.classifier[6].in_features, 2)

# image is also with 1 channel
alexnet_droplet_v2.features[0] = nn.Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
alexnet_droplet_v2 = alexnet_droplet_v2.to(device)


# accuracy = Accuracy(task='multiclass', num_classes=2)

accuracy = Accuracy(task='multiclass', num_classes=2)

accuracy = accuracy.to(device)

In [None]:
summary(alexnet_droplet_v2, input_size=(1, 1, 224, 224))

In [6]:
class_weights = torch.tensor([2.0, 1.0])  # Adjust these values based on your understanding of class importance
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

optimizer_adam = optim.Adam(alexnet_droplet.parameters(), lr=1e-1)
optimizer_sgd = torch.optim.SGD(alexnet_droplet.parameters(), lr=0.001, momentum=0.9)


# optim adam for alexnet_v2
optimizer_adam_v2 = optim.Adam(alexnet_droplet_v2.parameters(), lr=1e-4)
optimizer_sgd_v2 = torch.optim.SGD(alexnet_droplet_v2.parameters(), lr=0.001, momentum=0.9)

In [7]:
# Scheduler for learning rate decay
scheduler_adam = StepLR(optimizer_adam, step_size=10, gamma=0.1)


scheduler_adam_v2 = StepLR(optimizer_adam_v2, step_size=60, gamma=0.1)

In [175]:
# torch.Size([1, 224, 224])

# AlexNet on MNIST10

In [None]:
num_epochs = 40

for epoch in range(num_epochs):
    alexnet_droplet.train()
    train_loss, train_acc = 0, 0

    for images, labels in dataloader:

        images, labels = images.to(device), labels.to(device)

        optimizer_adam.zero_grad() # Zero the gradients

        outputs = alexnet_droplet(images) # Forward pass

        loss = criterion(outputs, labels) # Calculate the loss
        accuracy.update(outputs, labels)

        loss.backward() # Backward pass

        optimizer_adam.step() # Update the weights

        train_loss += loss.item()

    scheduler_adam.step() # Step the scheduler

    train_loss /= len(droplet_dataset)
    train_acc = accuracy.compute()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

# AlexNet on ImageNet

In [8]:
num_epochs = 400

for epoch in range(num_epochs):
    alexnet_droplet_v2.train()
    train_loss, train_acc = 0, 0

    for images, labels in dataloader:

        images, labels = images.to(device), labels.to(device)

        optimizer_adam_v2.zero_grad() # Zero the gradients

        outputs = alexnet_droplet_v2(images) # Forward pass

        loss = criterion(outputs, labels) # Calculate the loss
        accuracy.update(outputs, labels)

        loss.backward() # Backward pass

        optimizer_adam_v2.step() # Update the weights

        train_loss += loss.item()

    scheduler_adam_v2.step() # Step the scheduler

    train_loss /= len(droplet_dataset)
    train_acc = accuracy.compute()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

Epoch 1/400, Loss: 0.0237, Accuracy: 0.6267
Epoch 2/400, Loss: 0.0229, Accuracy: 0.5367
Epoch 3/400, Loss: 0.0239, Accuracy: 0.5733
Epoch 4/400, Loss: 0.0224, Accuracy: 0.5583
Epoch 5/400, Loss: 0.0216, Accuracy: 0.5573
Epoch 6/400, Loss: 0.0222, Accuracy: 0.5644
Epoch 7/400, Loss: 0.0211, Accuracy: 0.5790
Epoch 8/400, Loss: 0.0214, Accuracy: 0.5808
Epoch 9/400, Loss: 0.0211, Accuracy: 0.5844
Epoch 10/400, Loss: 0.0217, Accuracy: 0.5887
Epoch 11/400, Loss: 0.0204, Accuracy: 0.5976
Epoch 12/400, Loss: 0.0196, Accuracy: 0.6006
Epoch 13/400, Loss: 0.0195, Accuracy: 0.6036
Epoch 14/400, Loss: 0.0206, Accuracy: 0.6086
Epoch 15/400, Loss: 0.0202, Accuracy: 0.6089
Epoch 16/400, Loss: 0.0202, Accuracy: 0.6154
Epoch 17/400, Loss: 0.0191, Accuracy: 0.6173
Epoch 18/400, Loss: 0.0184, Accuracy: 0.6193
Epoch 19/400, Loss: 0.0193, Accuracy: 0.6239
Epoch 20/400, Loss: 0.0170, Accuracy: 0.6253
Epoch 21/400, Loss: 0.0183, Accuracy: 0.6324
Epoch 22/400, Loss: 0.0176, Accuracy: 0.6385
Epoch 23/400, Loss:

In [9]:
torch.save(alexnet_droplet_v2.state_dict(), 'alexnet_droplet_model_v0.pth')

In [10]:
# Load the model 
alexnet_droplet_v2.load_state_dict(torch.load('alexnet_droplet_model_v0.pth'))

<All keys matched successfully>

In [14]:
# Set the model to evaluation mode

alexnet_droplet_v2.eval()


# inference on a single image

# Load the image
img_path = 'M:\ML\ML_regs\pytorch_\Object_Classification\data\droplets\drop1.jpg'

image = Image.open(img_path).convert('RGB')

# Apply the transformations


image = transform(image).unsqueeze(0).to(device)

# Perform the inference

output = alexnet_droplet_v2(image)

# Get the predicted class


_, pred = torch.max(output, 1)

In [16]:
pred

tensor([1], device='cuda:0')