<a href="https://colab.research.google.com/github/camao-tec/transfer-learning-demo/blob/main/transfer_learning_metal_defects.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from tqdm.notebook import tqdm

%matplotlib inline

In [None]:
# Release cached GPU memory left over from a previous run in this kernel
# (only meaningful when a CUDA device is present).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Dataset download
To get the dataset, follow these steps:
1. Open the dataset on Kaggle: https://www.kaggle.com/datasets/kaustubhdikshit/neu-surface-defect-database
2. Click *Download* (login required) and save the zip archive.
3. Unpack the zip archive and upload the `NEU-DET` folder into the `data` directory, so that the images are found under `data/NEU-DET/train/images` and `data/NEU-DET/validation/images`.

### Brief overview of the data

In [None]:
SIZE_OF_IMAGE = 112
# Only resize here: the overview plot displays the loaded images directly,
# so no tensor conversion or normalization is applied.
image_transforms = torchvision.transforms.Compose(
    [torchvision.transforms.Resize((SIZE_OF_IMAGE,) * 2)]
)
sample_folder = torchvision.datasets.ImageFolder(
    'data/NEU-DET/train/images',
    transform=image_transforms,
)

In [None]:
# Get the classes present in the dataset
classes = sample_folder.classes
# Class index of every image, built once and reused for every column.
all_targets = torch.tensor(sample_folder.targets)

# Plot three samples of each class: one column per class, one row per sample.
fig, ax = plt.subplots(3, len(classes), figsize=(10, 5))

for col, class_name in enumerate(classes):
    class_positions = torch.nonzero(all_targets == sample_folder.class_to_idx[class_name])
    for row in range(3):
        pil_image, _ = sample_folder[class_positions[row].item()]
        cell = ax[row][col]
        cell.imshow(pil_image)
        cell.axis('off')
        if row == 0:
            cell.set_title(class_name)
fig.tight_layout()
plt.show()

### Load dataset for training and evaluation

In [None]:
BATCH_SIZE = 8
SIZE_OF_IMAGE = 112
SEED = 42

# Seed torch's global RNG so later random operations (e.g. initialising the new
# classifier layer, dropout) are repeatable across Restart & Run All. The training
# loader gets its own seeded generator so its shuffle order does not depend on how
# many other random draws happen before it.
torch.manual_seed(SEED)

# Preprocessing used for training, validation and prediction:
# resize, convert to a [0, 1] tensor, then map each channel to [-1, 1]
# via (x - 0.5) / 0.5.
prediction_trans = torchvision.transforms.Compose([
    torchvision.transforms.Resize((SIZE_OF_IMAGE, SIZE_OF_IMAGE)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
])

train_folder = torchvision.datasets.ImageFolder('data/NEU-DET/train/images', transform=prediction_trans, )
test_folder = torchvision.datasets.ImageFolder('data/NEU-DET/validation/images', transform=prediction_trans, )

train_loader = torch.utils.data.DataLoader(
    train_folder,
    shuffle=True,
    batch_size=BATCH_SIZE,
    generator=torch.Generator().manual_seed(SEED),
)
test_loader = torch.utils.data.DataLoader(test_folder, shuffle=False, batch_size=BATCH_SIZE)

## Model instantiation and adaptation to the specific use case

In [None]:
# Use the `cuda` device (GPU) if one is available, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# which would select `cuda` even on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
def get_transfermodel(num_classes=6):
    """Build an ImageNet-pretrained VGG19 adapted to defect classification.

    The convolutional feature extractor (`model.features`) is frozen. The last
    classifier layer is replaced by a new linear layer with `num_classes`
    outputs, and every classifier parameter is left trainable.

    Args:
        num_classes: Number of output classes. Defaults to 6, the number of
            defect classes used in this notebook; adjust it if the dataset
            changes.

    Returns:
        The adapted VGG19 model, moved to the global `device`.
    """
    # 1) get the pretrained model from the `torchvision` model library
    model = torchvision.models.vgg19(weights="VGG19_Weights.IMAGENET1K_V1")

    # 2) freeze the convolutional feature extractor so only the classifier learns
    for param in model.features.parameters():
        param.requires_grad = False

    # 3) replace the final classifier layer so it outputs `num_classes` logits
    num_features = model.classifier[6].in_features
    model.classifier[6] = torch.nn.Linear(num_features, num_classes)
    # retrain the entire classifier head, not just the new output layer
    for param in model.classifier.parameters():
        param.requires_grad = True

    # move to the target device once, after all layers have been replaced
    return model.to(device)

In [None]:
# Instantiate the adapted VGG19 and look at its entire architecture
# (the bare `transfer_model` expression is rendered as the cell output).
transfer_model = get_transfermodel()
transfer_model

In [None]:
# Inspect the pretrained feature extractor: every parameter listed here
# should report requires_grad = False (frozen).
for param_no, weight in enumerate(transfer_model.features.parameters()):
    print("Param #{:0>3d}: shape = {} , requires_grad = {}".format(param_no, weight.shape, weight.requires_grad))

In [None]:
# Inspect the classifier head: every parameter listed here should report
# requires_grad = True (trainable).
for param_no, weight in enumerate(transfer_model.classifier.parameters()):
    print("Param #{:0>3d}: shape = {} | requires_grad = {}".format(param_no, weight.shape, weight.requires_grad))

## Model training and evaluation

In [None]:
# Directory for the checkpoints written during training;
# `exist_ok` makes this cell safe to re-run.
os.makedirs(name='models', exist_ok=True)

In [None]:
def validate(model, valid_data, loss_fn):
    """Evaluate `model` on `valid_data` without updating any weights.

    Args:
        model: The classifier to evaluate. It is switched to eval mode.
        valid_data: Iterable of `(X_batch, y_batch)` pairs, e.g. a DataLoader.
        loss_fn: Loss called as `loss_fn(logits, targets)`. It is assumed to
            return the mean loss over the batch, which is the default reduction
            of `torch.nn.CrossEntropyLoss`.

    Returns:
        Tuple `(mean_loss, accuracy)` computed over all samples. Each batch is
        weighted by its size, so a smaller final batch does not skew the
        result. Both values are NaN when `valid_data` yields no samples.
    """
    total_loss, total_correct, total_samples = 0.0, 0, 0
    model.eval()
    with torch.no_grad():
        for X_batch, y_batch in tqdm(valid_data, leave=False):
            X_batch, y_batch = X_batch.to(device).float(), y_batch.to(device).long()
            logits = model(X_batch)
            batch_size = y_batch.size(0)
            # loss_fn yields a per-batch mean; re-weight it by the batch size
            total_loss += loss_fn(logits, y_batch).item() * batch_size
            preds = torch.argmax(logits, dim=1)
            total_correct += (preds == y_batch).sum().item()
            total_samples += batch_size
    if total_samples == 0:
        return float('nan'), float('nan')
    return total_loss / total_samples, total_correct / total_samples

In [None]:
def _train_one_epoch(model, train_data, loss_fn, opt):
    """Run one optimisation pass over `train_data` and return `(mean_loss, accuracy)`.

    Both metrics are weighted by batch size. They are NaN when `train_data`
    yields no samples.
    """
    total_loss, total_correct, total_samples = 0.0, 0, 0
    model.train()
    for X_batch, y_batch in tqdm(train_data, leave=False):
        opt.zero_grad()
        X_batch, y_batch = X_batch.to(device).float(), y_batch.to(device).long()
        logits = model(X_batch)
        loss = loss_fn(logits, y_batch)
        loss.backward()
        opt.step()

        batch_size = y_batch.size(0)
        # loss_fn is assumed to return a per-batch mean; re-weight by batch size
        total_loss += loss.item() * batch_size
        total_correct += (torch.argmax(logits, dim=1) == y_batch).sum().item()
        total_samples += batch_size
    if total_samples == 0:
        return float('nan'), float('nan')
    return total_loss / total_samples, total_correct / total_samples


def train(model, train_data, valid_data, loss_fn, opt, epoches, name,
          lr_step_size=4, lr_gamma=0.5):
    """Train `model` for `epoches` epochs, validating and checkpointing after each.

    Args:
        model: The classifier to train.
        train_data: Iterable of `(X_batch, y_batch)` training pairs.
        valid_data: Iterable of `(X_batch, y_batch)` validation pairs.
        loss_fn: Loss called as `loss_fn(logits, targets)`.
        opt: Optimizer over the model's trainable parameters.
        epoches: Number of epochs to run.
        name: Checkpoint file prefix. Files are written to
            `models/{name}_{epoch}.pt`, with a 0-based `epoch`.
        lr_step_size: StepLR period in epochs (default 4).
        lr_gamma: StepLR decay factor (default 0.5).

    Each checkpoint stores the epoch, the model and optimizer state dicts, and
    that epoch's validation loss (a float) under the key 'loss'.

    Returns:
        `(model, train_losses, train_accuracies, valid_losses, valid_accuracies)`,
        where each list holds one value per epoch.
    """
    train_losses, valid_losses = [], []
    train_accuracies, valid_accuracies = [], []
    # with the defaults, halve the learning rate every 4 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=lr_step_size, gamma=lr_gamma)

    for epoch in tqdm(range(epoches)):
        train_loss, train_acc = _train_one_epoch(model, train_data, loss_fn, opt)
        scheduler.step()

        valid_loss, valid_accuracy = validate(model, valid_data, loss_fn)

        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_accuracy)

        # report this epoch's training loss, not the running mean over all epochs
        print(f'epoch: {epoch}: train_loss: {train_loss}, train_acc: {train_acc}, val_loss: {valid_loss}, val_acc: {valid_accuracy}')

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            # a plain float instead of the pickled loss module keeps the
            # checkpoint loadable with `torch.load(..., weights_only=True)`
            'loss': valid_loss,
        }, f'models/{name}_{epoch}.pt')

    return model, train_losses, train_accuracies, valid_losses, valid_accuracies

In [None]:
# Training setup and hyper-parameters: cross-entropy loss for the multi-class
# problem, Adam over the model parameters, and the number of training epochs.
EPOCHES = 10
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(params=transfer_model.parameters(), lr=1e-5)

In [None]:
# Train the model; checkpoints are written as models/model_epoch_<epoch>.pt.
(transfer_model, train_losses, train_accuracies,
 valid_losses, valid_accuracies) = train(
    model=transfer_model,
    train_data=train_loader,
    valid_data=test_loader,
    loss_fn=loss_fn,
    opt=opt,
    epoches=EPOCHES,
    name="model_epoch",
)

In [None]:
def plot_losses_and_acc(train_losses, train_accuracies, valid_losses, valid_accuracies):
    """Plot per-epoch training vs. validation loss and accuracy side by side.

    Args:
        train_losses: Training loss per epoch.
        train_accuracies: Training accuracy per epoch.
        valid_losses: Validation loss per epoch.
        valid_accuracies: Validation accuracy per epoch.

    Epochs are shown 1-based on the x-axis. The figure is displayed by the
    notebook's inline backend; nothing is returned.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 10))
    epochs = np.arange(1, len(train_losses) + 1)
    panels = [
        (axes[0], 'Losses', 'Loss', train_losses, valid_losses),
        (axes[1], 'Accuracy', 'Accuracy', train_accuracies, valid_accuracies),
    ]
    for ax, title, ylabel, train_values, valid_values in panels:
        ax.plot(epochs, train_values, label='Training')
        ax.plot(epochs, valid_values, label='Validation')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        # a single epoch would give identical x-limits, which matplotlib warns about
        if len(epochs) > 1:
            ax.set_xlim([1, len(epochs)])

In [None]:
plot_losses_and_acc(
    train_losses=train_losses,
    train_accuracies=train_accuracies,
    valid_losses=valid_losses,
    valid_accuracies=valid_accuracies,
)

## Model Prediction
Now that we are satisfied with the model's accuracy on the validation data, let's look at some predictions.
Load a trained checkpoint from disk and predict the classes of a few validation images.

In [None]:
# Load a trained checkpoint from disk (the epoch index in the file name is 0-based).
best_model_epoch = 5
# `map_location` lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE(review): recent torch versions default `torch.load` to weights_only=True,
# which rejects checkpoints that contain pickled nn.Module objects.
checkpoint = torch.load(f'models/model_epoch_{best_model_epoch}.pt', map_location=device)
model = get_transfermodel()
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# Switch to inference mode. This disables the dropout layers in torchvision's
# VGG classifier, so predictions are deterministic.
model.eval();

In [None]:
# Reuse the exact preprocessing from training (`prediction_trans`: resize,
# ToTensor, Normalize with mean=std=0.5) instead of re-declaring a copy that
# could silently drift from what the model was trained on.
sample_folder_prediction = torchvision.datasets.ImageFolder(
    'data/NEU-DET/validation/images',
    transform=prediction_trans,
)

In [None]:
# Get the classes present in the dataset
classes = sample_folder_prediction.classes
# Class index of every sample, computed once instead of once per subplot.
prediction_targets = torch.tensor(sample_folder_prediction.targets)

# Plot two samples with their corresponding label and prediction per class
fig, ax = plt.subplots(len(classes), 2, figsize=(6, 16))

# Inference only: make sure dropout is off and no autograd graph is built.
model.eval()
with torch.no_grad():
    for j in range(2):
        for i, label_class_name in enumerate(classes):
            idx = sample_folder_prediction.class_to_idx[label_class_name]
            sample_idx = torch.nonzero(prediction_targets == idx)[j].item()
            image = sample_folder_prediction[sample_idx][0]

            logits = model(image.unsqueeze(0).to(device).float())
            prediction = torch.argmax(logits, dim=1).item()
            prediction_class_name = classes[prediction]

            # undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1]; clip float round-off
            image_display = ((image.permute(1, 2, 0).cpu().numpy() + 1) / 2).clip(0, 1)
            ax[i][j].imshow(image_display)
            ax[i][j].axis('off')
            ax[i][j].set_title(f'Label = `{label_class_name}`\nPrediction = `{prediction_class_name}`')
fig.tight_layout()
plt.show()