In [1]:
import torch.optim as optim
from torch.utils.data import DataLoader,ConcatDataset
from torchvision.datasets import ImageFolder
from torchvision import models
from torchvision.models.alexnet import AlexNet_Weights
import torch
from torch import nn, optim
import os
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, DataLoader
import torch.optim as optim
import torch.nn.init as init
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
from PIL import ImageOps

from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
from torch.utils.data import Dataset
from torchvision import transforms, models

train_path = '/Users/Gaia/Desktop/CVPR-project/CVPR-project/data/train'
test_path = '/Users/Gaia/Desktop/CVPR-project/CVPR-project/data/test'

def transformation_for_AlexNet(img):
  resize = transforms.Compose([transforms.Resize([224,224]),
  transforms.ToTensor()])
  i = resize(img)
  return i


train=ImageFolder(root=train_path,transform=transformation_for_AlexNet)
test=ImageFolder(root=test_path,transform=transformation_for_AlexNet)


#split training set into training and validation
train_size=int(0.85*len(train))
validation_size=len(train)-train_size
training_set,validation_set=torch.utils.data.random_split(train,[train_size,validation_size])




class TransformedDataSet():
    """Wrap a datset (created with imagefolder) to apply a transformation"""
    def __init__(self, ds):
        self.ds = ds
        self.transformation=transforms.RandomHorizontalFlip(1)

    def __getitem__(self, index):
        """Get a sample from the dataset at the given index"""
        img, label = self.ds[index]

        # Apply the transformation if it is provided
        if self.transformation:
            img = self.transformation(img)

        return img, label

    def __len__(self):
        """Number of batches"""
        return len(self.ds)
    


batch_size = 32

augmented_training_set=TransformedDataSet(training_set)
concatenated_dataset = torch.utils.data.ConcatDataset([training_set, augmented_training_set])
augmented_train_loader = DataLoader(concatenated_dataset, batch_size=batch_size,shuffle=True,pin_memory=True)

validation_loader = DataLoader(validation_set, batch_size=batch_size,shuffle=False,pin_memory=True)
test_loader = DataLoader(test, batch_size=batch_size,shuffle=False,pin_memory=True)



model=models.alexnet(weights=AlexNet_Weights.DEFAULT)



for param in model.parameters():
    param.requires_grad = False


model.classifier[6]=nn.Linear(model.classifier[6].in_features,15)
for param in model.classifier[6].parameters():
    param.requires_grad=True



loss_function=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=0.00001,momentum=0.9,weight_decay= 0.0005)


def train_one_epoch(model,epoch_index,loader, loss_function,optimizer):
  running_loss=0

  for i, data in enumerate(loader):

    inputs,labels=data #get the minibatch

    outputs=model(inputs) #forward pass

    loss=loss_function(outputs,labels) #compute loss
    running_loss+=loss.item() #sum up the loss for the minibatches processed so far

    optimizer.zero_grad() #reset gradients
    loss.backward() #compute gradient
    optimizer.step() #update weights

  return running_loss/(i+1) # average loss per minibatch


def train_model(model,train_loader,validation_loader,loss_function,optimizer,EPOCHS):
  best_validation_loss=np.inf

  train_losses = []
  validation_losses = []
  validation_accuracies = []

  for epoch in range(EPOCHS):
    print('EPOCH{}:'.format(epoch+1))

    model.train(True)
    train_loss=train_one_epoch(model,epoch,train_loader, loss_function,optimizer) ##train for each epoch

    running_validation_loss=0.0

    model.eval()

    with torch.no_grad(): # Disable gradient computation and reduce memory consumption
      correct=0
      total=0
      for i,vdata in enumerate(validation_loader):
        vinputs,vlabels=vdata
        voutputs=model(vinputs)
        _,predicted=torch.max(voutputs.data,1)
        vloss=loss_function(voutputs,vlabels)
        running_validation_loss+=vloss
        total+=vlabels.size(0)
        correct+=(predicted==vlabels).sum().item()
    validation_loss=running_validation_loss/(i+1)
    validation_acc = 100*correct/total
    print(f'LOSS train: {train_loss} validation: {validation_loss} | validation_accuracy: {validation_acc}% ')

    if validation_loss<best_validation_loss: #save the model if it's the best so far
      timestamp=datetime.now().strftime('%Y%m%d_%H%M%S')
      best_validation_loss=validation_loss
      model_path='model_{}_{}'.format(timestamp,epoch)
      torch.save(model.state_dict(),model_path)

    train_losses.append(train_loss)
    validation_losses.append(validation_loss)
    validation_accuracies.append(validation_acc)


  plt.plot(train_losses, color='tab:red', linewidth=3, label='train loss')
  validation_losses_np = torch.stack(validation_losses).cpu().numpy() #move validation losses to cpu to plot with matplotlib
  plt.plot(validation_losses_np, color='tab:green', linewidth=3, label='validation loss')
  plt.xlabel('Epoch')
  plt.ylabel('CE loss')

  ax_right = plt.gca().twinx()
  #validation_accuracies_np = torch.stack(validation_accuracies).cpu().numpy() #move validation accuracies to cpu to plot with matplotlib
  ax_right.plot(validation_accuracies, color='tab:green', linestyle='--', label='validation accuracy')
  ax_right.set_ylabel('accuracy (%)')

  plt.gcf().legend(ncol=3)
  plt.gcf().set_size_inches(6, 3)

  return model_path



model_path=train_model(model,augmented_train_loader,validation_loader,loss_function,optimizer,30)

EPOCH1:
LOSS train: 2.6629636943340302 validation: 2.4765138626098633 | validation_accuracy: 18.666666666666668% 
EPOCH2:
LOSS train: 2.4052781015634537 validation: 2.2252233028411865 | validation_accuracy: 39.55555555555556% 
EPOCH3:
LOSS train: 2.2026075273752213 validation: 2.035402774810791 | validation_accuracy: 49.77777777777778% 
EPOCH4:
LOSS train: 2.0283912286162376 validation: 1.8644353151321411 | validation_accuracy: 55.55555555555556% 
EPOCH5:
LOSS train: 1.8935782179236411 validation: 1.728905439376831 | validation_accuracy: 60.888888888888886% 
EPOCH6:
LOSS train: 1.7894921600818634 validation: 1.6111842393875122 | validation_accuracy: 64.0% 
EPOCH7:
LOSS train: 1.6892506271600722 validation: 1.5114099979400635 | validation_accuracy: 65.33333333333333% 
EPOCH8:
LOSS train: 1.5956883952021599 validation: 1.4263393878936768 | validation_accuracy: 67.11111111111111% 
EPOCH9:
LOSS train: 1.5143862083554267 validation: 1.3539365530014038 | validation_accuracy: 67.1111111111111

KeyboardInterrupt: 

In [None]:
import torch
import torchvision
from torch import nn, optim
from torchvision import transforms, models

# Assuming you have a DataLoader for your dataset named `data_loader`
# and the number of classes in your dataset is `num_classes`

# Step 1: Load the pre-trained AlexNet model
alexnet = models.alexnet(pretrained=True)

# Step 2: Freeze all layers except the last fully connected layer
for param in alexnet.parameters():
    param.requires_grad = False

# Step 3: Modify the last fully connected layer for your task
in_features = alexnet.classifier[6].in_features
alexnet.classifier[6] = nn.Linear(in_features, num_classes)

# Step 4: Define loss function, optimizer, and train the modified model
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(alexnet.parameters(), lr=0.001, momentum=0.9)

# Train the model for a few epochs (you can adjust the number of epochs)
num_epochs = 5
for epoch in range(num_epochs):
    for inputs, labels in data_loader:
        optimizer.zero_grad()
        outputs = alexnet(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Evaluate the model on the test set
# (make sure to use a separate test set that was not seen during training)
alexnet.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = alexnet(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'Test Accuracy: {accuracy * 100:.2f}%')
