In [None]:
import os
import glob
import sklearn
from sklearn.model_selection import train_test_split

import PIL
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchinfo import summary

import torch.optim as optim
from IPython.display import Image
from torch.utils.data import DataLoader, Dataset
from galaxy_mnist import GalaxyMNIST

from torchvision.datasets import ImageFolder

from torchvision.transforms import transforms
from torch.utils.data import TensorDataset

import cv2
import torchvision.models as models

In [None]:
# Use GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
#  Import galaxyMNIST dataset
dataset = GalaxyMNIST(
    root='ML_DP/gal_mnist',
    download=True,
    train=False  # by default, or False for canonical test set
)

images = dataset.data
labels = dataset.targets

# Split in test, validation, train
images_tv, images_test, y_tv, y_test  = train_test_split(images, labels, shuffle=True, test_size=0.2, random_state=123)
images_train, images_val, y_train, y_val  = train_test_split(images_tv, y_tv, shuffle=True, test_size=0.25, random_state=123)

# Create PyTorch dataset
train_dataset = TensorDataset(images_train, y_train)
val_dataset = TensorDataset(images_val, y_val)
test_dataset = TensorDataset(images_test, y_test)

In [None]:
VGG_model = models.vgg16(pretrained=True)

for name, param in VGG_model.named_parameters():
    param.requires_grad = False

# define out classifier
binary_classifier = nn.Sequential(
   nn.Linear(in_features=25088, out_features=2048),
   nn.ReLU(),
   nn.Linear(in_features=2048, out_features=1024),
   nn.ReLU(),
   nn.Linear(in_features=1024, out_features=512),
   nn.ReLU(),
   nn.Linear(in_features=512, out_features=4)
)

# replace model class classifier attribute:
VGG_model.classifier = binary_classifier

In [40]:
# define training function

def train_model(model, train_dataset, val_dataset, test_dataset, device,
                lr=0.0001, epochs=30, batch_size=32, l2=0.00001, gamma=0.5,
                patience=7):
    model = model.to(device)

    # construct dataloader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # history
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    # set up loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2)  # pass in the parameters to be updated and learning rate
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=patience, gamma=gamma)

    # Training Loop
    print("Training Start:")
    for epoch in range(epochs):
        model.train()  # start to train the model, activate training behavior

        train_loss = 0
        train_acc = 0
        val_loss = 0
        val_acc = 0

        for i, (images, labels) in enumerate(train_loader):
            # reshape images
            images = images.to(device)
            images = images.float()# reshape: from (128, 1, 28, 28) -> (128, 28 * 28) = (128, 284), move batch to device
            labels = labels.to(device)  # move to device
            # forward
            outputs = model(images)  # forward
            
            _, pred = torch.max(outputs.data, 1)
            
            cur_train_loss = criterion(outputs, labels)  # loss
            cur_train_acc = (pred == labels).sum().item() / batch_size

            # backward
            cur_train_loss.backward()   # run back propagation
            optimizer.step()            # optimizer update all model parameters
            optimizer.zero_grad()       # set gradient to zero, avoid gradient accumulating

            # loss
            train_loss += cur_train_loss
            train_acc += cur_train_acc

        # valid
        model.eval()  # start to train the model, activate training behavior
        with torch.no_grad():  # tell pytorch not to update parameters
            for images, labels in val_loader:
                # calculate validation loss
                images = images.to(device)
                labels = labels.to(device)
                images = images.float()
                
                outputs = model(images)

                # loss
                cur_valid_loss = criterion(outputs, labels)
                val_loss += cur_valid_loss
                # acc
                _, pred = torch.max(outputs.data, 1)
                val_acc += (pred == labels).sum().item() / batch_size

        # learning schedule step
        scheduler.step()

        # print training feedback
        train_loss = train_loss / len(train_loader)
        train_acc = train_acc / len(train_loader)
        val_loss = val_loss / len(val_loader)
        val_acc = val_acc / len(val_loader)

        print(f"Epoch:{epoch + 1} / {epochs}, lr: {optimizer.param_groups[0]['lr']:.5f} train loss:{train_loss:.5f}, train acc: {train_acc:.5f}, valid loss:{val_loss:.5f}, valid acc:{val_acc:.5f}")

        # update history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

    test_acc = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            images = images.float()
            # calculate outputs by running images through the network
            outputs = model(images)
            # the class with the highest energy is what we choose as prediction
            _, pred = torch.max(outputs.data, 1)
            test_acc += (pred == labels).sum().item()

    print(f'Test Accuracy:  {(test_acc / len(test_loader))}')

    return history

In [None]:
# Train the CNN model
hist = train_model(VGG_model, train_dataset, val_dataset, test_dataset, device, lr=0.0001, batch_size=32, epochs=20, l2=0.2
                   , patience=15)

Training Start:
Epoch:1 / 20, lr: 0.00010 train loss:1.29270, train acc: 0.46382, valid loss:1.04060, valid acc:0.53365
Epoch:2 / 20, lr: 0.00010 train loss:0.84632, train acc: 0.61678, valid loss:0.91110, valid acc:0.55048
Epoch:3 / 20, lr: 0.00010 train loss:0.79488, train acc: 0.63158, valid loss:0.82139, valid acc:0.62740
Epoch:4 / 20, lr: 0.00010 train loss:0.75555, train acc: 0.65378, valid loss:0.76019, valid acc:0.66587
Epoch:5 / 20, lr: 0.00010 train loss:0.75766, train acc: 0.64803, valid loss:0.81753, valid acc:0.62260


In [None]:
# plot training curves
epochs = range(1, len(hist['train_loss']) + 1)

fig, ax = plt.subplots(1,2, figsize=(20,6))
ax[0].plot(epochs, hist['train_loss'], 'r-', label='Train')
ax[0].plot(epochs, hist['val_loss'], 'b-', label='Evaluation')
ax[0].set_title('Loss')
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].legend()


ax[1].plot(epochs, hist['train_acc'], 'r-', label='Train')
ax[1].plot(epochs, hist['val_acc'], 'b-', label='Evaluation')
ax[1].set_title('Accuracy')
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Acc')
ax[1].legend()

plt.show()