# Semi-supervised image classification using Pseudo Labeling CNN

## Setup

In [None]:
import datetime
import os
import pathlib
import shutil

import cv2
import pandas as pd

# Select the TensorFlow backend for Keras. This assignment sits above the
# `import keras` statement further down in this cell.
os.environ["KERAS_BACKEND"] = "tensorflow"


# Make sure we are able to handle large datasets: raise this process's soft
# limit on open file descriptors to the current hard limit.
# NOTE(review): the `resource` module is Unix-only, so this cell presumably
# is not expected to run on Windows — confirm.
import resource

# getrlimit returns (soft limit, hard limit); only the hard limit is reused.
low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))

import math
import time

import keras

%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from IPython.display import display
from ipywidgets import IntProgress
from keras import layers
from sklearn.metrics import classification_report, confusion_matrix

import tensorflow_datasets as tfds

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Set seed for evaluation purpose (remove in production)
# Every seed below uses the same value (5) so repeated runs are comparable.
keras.utils.set_random_seed(5)
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes — confirm.
os.environ["PYTHONHASHSEED"] = str(5)
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
# NOTE(review): keras.utils.set_random_seed is documented to seed Python,
# NumPy and the backend already; the two calls below look redundant but
# harmless — confirm before removing.
np.random.seed(5)
tf.random.set_seed(5)
# Request deterministic op implementations from TensorFlow.
tf.config.experimental.enable_op_determinism()

## Hyperparameter setup

In [None]:
# Dataset locations and training hyperparameters
unlabeled_dataset_path = "../data_ssl/unlabeled/"
labeled_dataset_path = "../data_ssl/train/"

# Count every file found (recursively) below each dataset root.
unlabeled_dataset_size = sum(
    len(filenames) for _, _, filenames in os.walk(unlabeled_dataset_path)
)
labeled_dataset_size = sum(
    len(filenames) for _, _, filenames in os.walk(labeled_dataset_path)
)

# Image resolution used when loading images and building the models
img_height = 224
img_width = 224
width = 224

# Training schedule
num_epochs = 50
batch_size = 30

print(f"Unlabeled Images: {unlabeled_dataset_size}")
print(f"Labeled Images: {labeled_dataset_size}")

## Dataset

In [None]:
def prepare_dataset():
    """Load the labeled train/validation split and the labeled test set.

    The labeled training directory is split 60/40 into a training and a
    validation dataset; the test set is read from its own directory. All
    datasets share the same image size, batch size, seed, colour mode and
    one-hot ("categorical") label encoding.

    Returns:
        A tuple ``(labeled_train_ds, val_ds, test_dataset)``.
    """
    # Options common to every dataset so all splits are loaded the same way.
    shared_options = dict(
        seed=5,
        image_size=(img_height, img_width),
        batch_size=batch_size,
        shuffle=True,
        color_mode="rgb",
        label_mode="categorical",
    )

    # subset="both" returns the training and the validation part together.
    train_part, validation_part = tf.keras.utils.image_dataset_from_directory(
        labeled_dataset_path,
        validation_split=0.4,
        subset="both",
        **shared_options,
    )

    # Held-out labeled test images
    test_part = tf.keras.utils.image_dataset_from_directory(
        "../data_ssl/test_labeled/",
        **shared_options,
    )

    return train_part, validation_part, test_part


# Load dataset
labeled_train_dataset, validation_dataset, test_dataset = prepare_dataset()

# Number of output classes, taken from the class names attached to the
# labeled training dataset.
num_classes = len(labeled_train_dataset.class_names)

## Image augmentations

In [None]:
# Preview a handful of labeled training images
def visualize_augmentations(num_images):
    """Plot the first `num_images` images of one labeled training batch in a row.

    The only transformation applied before plotting is a 1/255 rescaling.

    Args:
        num_images: Number of images (and subplot columns) to show.
    """
    # Take the image tensor of the first batch and keep the leading images.
    first_batch = next(iter(labeled_train_dataset))
    sample = first_batch[0][:num_images]

    rescale = keras.Sequential([layers.Rescaling(1 / 255)])
    scaled = rescale(sample)

    plt.figure(figsize=(num_images * 1.6, 2), dpi=100)
    for position, picture in enumerate(scaled):
        plt.subplot(1, num_images, position + 1)
        plt.imshow(picture)
        if position == 0:
            # Only the left-most image carries the row title.
            plt.title("Images:", loc="left")
        plt.axis("off")
    plt.tight_layout()
    plt.show()


visualize_augmentations(num_images=8)

## Supervised baseline model

In [None]:
# Baseline supervised training: a frozen, ImageNet-pretrained EfficientNetV2-S
# backbone followed by a small, newly initialized classification head.
pretrained_model = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
    # Keras image models expect (height, width, channels).
    input_shape=(img_height, img_width, 3),
    include_top=False,
    weights="imagenet",
    pooling="avg",
)
# Freeze the backbone so only the head below is trained.
pretrained_model.trainable = False

# Classification head: two dense layers with dropout, then a softmax over
# the classes found in the labeled training data.
inputs = pretrained_model.input
x = tf.keras.layers.Dense(132, activation="relu")(pretrained_model.output)
x = tf.keras.layers.Dropout(0.4)(x)
x = tf.keras.layers.Dense(64, activation="relu")(x)
x = tf.keras.layers.Dropout(0.4)(x)
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
baseline_model = tf.keras.Model(
    inputs,
    outputs,
    name="baseline_model",
)

# Compile model (the softmax output means the loss receives probabilities)
baseline_model.compile(
    optimizer=tf.keras.optimizers.Nadam(learning_rate=0.001),
    loss=keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=[keras.metrics.CategoricalAccuracy(name="acc")],
)

# One TensorBoard log directory per run, named after the start time.
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
callbacks = [
    # Stop once val_loss has not improved for 5 epochs and keep the best weights.
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=5, restore_best_weights=True
    ),
    tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1),
]

# Fit model (training)
baseline_history = baseline_model.fit(
    labeled_train_dataset,
    epochs=num_epochs,
    validation_data=validation_dataset,
    callbacks=callbacks,
)

best_baseline_val_acc = max(baseline_history.history["val_acc"])
print(f"Maximal validation accuracy: {best_baseline_val_acc * 100:.2f}%")

## CNN Model

In [None]:
# Collect the paths of all unlabeled *.jpg files (searched recursively)
image_filepath = [
    str(candidate)
    for candidate in pathlib.Path(unlabeled_dataset_path).rglob("*.jpg")
    if candidate.is_file()
]

# One row per unlabeled image; the row position is later used to match each
# image with its predicted pseudo label.
df = pd.DataFrame({"image_filepath": image_filepath})
print(df)

In [None]:
max_count = len(df)

f = IntProgress(min=0, max=max_count)  # instantiate the bar
display(f)  # display the bar

unlabeled_train_dataset = []
pseudo_labels = []

# Predict pseudo labels, one image at a time, in the row order of `df` so
# that pseudo_labels[i] belongs to df row i.
for image_path in df["image_filepath"]:
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals an unreadable file by returning None; fail with
        # the offending path instead of an opaque AttributeError.
        raise ValueError(f"Could not read image: {image_path}")
    # OpenCV decodes to BGR, but the model was trained on RGB images
    # (color_mode="rgb"), so the channels must be reordered first.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # NumPy image arrays are (height, width, channels) ...
    if image.shape != (img_height, img_width, 3):
        # ... while cv2.resize takes its target size as (width, height).
        image = cv2.resize(image, (img_width, img_height))
    # Add the batch dimension expected by predict().
    image = image.reshape((1, img_height, img_width, 3))
    pseudo_labels.append(baseline_model.predict(image, verbose=0))
    f.value += 1

In [None]:
# Save predictions and model for later use
# (both files are loaded again in the cell that defines the finetuned model)
baseline_model.save("pseudo2_eff.keras")
# pseudo_labels is a list holding one prediction array per unlabeled image.
np.save("pseudo2_eff.npy", pseudo_labels)

In [None]:
# Define finetuned model
# Class names are the sub-directory names of the labeled training directory,
# sorted so that position i corresponds to output unit i of the models.
# Hidden directories (e.g. ".ipynb_checkpoints" created by Jupyter) are
# skipped; otherwise they would shift every label index.
# NOTE(review): this assumes image_dataset_from_directory ignores dot-directories
# when inferring class names — confirm that len(labels) == num_classes.
folder = labeled_dataset_path
labels = sorted(
    name
    for name in os.listdir(folder)
    if os.path.isdir(os.path.join(folder, name)) and not name.startswith(".")
)

# Same frozen ImageNet-pretrained backbone as the baseline, with a wider head.
pretrained_finetuned_model = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
    # Keras image models expect (height, width, channels).
    input_shape=(img_height, img_width, 3),
    include_top=False,
    weights="imagenet",
    pooling="avg",
)
pretrained_finetuned_model.trainable = False

inputs = pretrained_finetuned_model.input
x = tf.keras.layers.Dense(640, activation="relu")(pretrained_finetuned_model.output)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.Dense(320, activation="relu")(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
finetuned_model = tf.keras.Model(
    inputs,
    outputs,
    name="finetuned_model",
)

# Compile model
finetuned_model.compile(
    optimizer=tf.keras.optimizers.Nadam(learning_rate=0.001),
    loss=keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=[keras.metrics.CategoricalAccuracy(name="acc")],
)

# Reload the saved pseudo-label predictions and baseline model so the
# notebook can be resumed from here without re-running the prediction loop.
pseudo_labels = np.load("pseudo2_eff.npy")
baseline_model = keras.models.load_model("pseudo2_eff.keras")

In [None]:
# Keep only pseudo labels whose top class probability exceeds 0.995
confident = []  # top class index for every unlabeled image
labels_list = []  # top class name for every unlabeled image
most_confident = []  # class names of the predictions that pass the threshold
not_confident_indices = []  # positions of the predictions that do not

for position, prediction in enumerate(pseudo_labels):
    # Each prediction carries a leading batch axis of size 1.
    probabilities = prediction[0]
    top_class = np.argmax(probabilities)
    top_label = labels[top_class]

    confident.append(top_class)
    labels_list.append(top_label)

    if not probabilities[top_class] > 0.995:
        not_confident_indices.append(position)
        continue
    most_confident.append(top_label)

print(len(confident))
print(len(most_confident))

In [None]:
# Remove the rows of low-confidence predictions, renumber the remaining rows
# and attach the pseudo label of each kept image.
data_df = (
    df.drop(index=not_confident_indices)
    .reset_index(drop=True)
    .assign(label=most_confident)
)

In [None]:
# Normalize class distribution: randomly drop pseudo-labeled rows of
# over-represented classes so no class dominates the combined training set.
value_counts = data_df["label"].value_counts().sort_index()
value_counts_normalized = data_df["label"].value_counts(normalize=True).sort_index()
print(value_counts)
print(data_df)

# Size of the smallest class (named so the builtin `min` is not shadowed).
min_count = value_counts.min()

# Iterate over the classes that actually occur in data_df. Looking labels up
# by position in `labels` would be wrong whenever a class has no confident
# prediction and is therefore missing from value_counts.
for label, count in value_counts.items():
    # Each class may keep min_count * (1 + class share)^2 rows.
    keep = int(min_count * math.pow(value_counts_normalized.loc[label] + 1, 2))
    drop_indices = np.random.choice(
        data_df.loc[data_df["label"] == label].index,
        max(0, count - keep),
        replace=False,
    )
    # Re-number after every drop so the next class sees a clean 0..n-1 index.
    data_df = data_df.drop(drop_indices).reset_index(drop=True)
print(data_df["label"].value_counts().sort_index())

In [None]:
# Load the labeled images from files and append them to the pseudo-labeled ones
labeled_root = pathlib.Path(labeled_dataset_path)
image_filepath = []
anomaly_class = []

for file in labeled_root.rglob("*.jpg"):
    if file.is_file():
        image_filepath.append(str(file))
        # The class is the first directory below the labeled root. Deriving
        # it from the relative path does not depend on how deep
        # labeled_dataset_path itself is nested.
        anomaly_class.append(file.relative_to(labeled_root).parts[0])

# Concatenate labeled and pseudo-labeled images. Appending in one step with a
# fresh index avoids overwriting existing pseudo-labeled rows and is much
# faster than growing the frame row by row.
labeled_df = pd.DataFrame({"image_filepath": image_filepath, "label": anomaly_class})
data_df = pd.concat([data_df, labeled_df], ignore_index=True)

# NOTE(review): this adds every image below labeled_dataset_path, including
# the 40 % that prepare_dataset() holds out as validation_dataset. The
# finetuned model is later validated on validation_dataset, so its validation
# metrics are computed on images it was trained on — confirm whether that is
# intended.
print(data_df)

In [None]:
# Load labeled and pseudo-labeled images through a Keras generator and apply
# random augmentation (rotation, shifts, flips, brightness, zoom).
generator = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=30,
    fill_mode="nearest",
    width_shift_range=0.15,
    height_shift_range=0.15,
    horizontal_flip=True,
    vertical_flip=True,
    brightness_range=[0.8, 1.2],
    zoom_range=0.2,
    validation_split=0.2,
)

# Arguments shared by the training and the validation iterator.
flow_options = dict(
    dataframe=data_df,
    x_col="image_filepath",
    y_col="label",
    # target_size is (height, width), matching the model's input_shape.
    target_size=(img_height, img_width),
    color_mode="rgb",
    class_mode="categorical",
    # Fix the class order so the one-hot columns match the model outputs.
    classes=labels,
    batch_size=batch_size,
    seed=5,
    shuffle=True,
)

data = generator.flow_from_dataframe(subset="training", **flow_options)

# NOTE(review): data_val is not used by the training cell, which validates on
# validation_dataset instead. The 20 % of rows reserved by validation_split is
# therefore neither trained on nor validated on — confirm whether intended.
data_val = generator.flow_from_dataframe(subset="validation", **flow_options)

In [None]:
# Train the finetuned model on the combined labeled + pseudo-labeled data
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

# Stop once val_loss has not improved for 5 epochs and keep the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
# Log to a fresh, timestamped TensorBoard directory.
tensorboard_logger = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
callbacks = [early_stopping, tensorboard_logger]

# Fit the finetuned model (training); validation uses the labeled hold-out split.
finetuned_history = finetuned_model.fit(
    data,
    epochs=num_epochs,
    validation_data=validation_dataset,
    callbacks=callbacks,
)

best_finetuned_val_acc = max(finetuned_history.history["val_acc"])
print(f"Maximal validation accuracy: {best_finetuned_val_acc * 100:.2f}%")

## Comparison against the baseline

In [None]:
# Compare the validation curves of the baseline and the finetuned model
def plot_training_curves(finetuning_history, baseline_history):
    """Plot validation accuracy and validation loss of both training runs.

    One figure is drawn per metric, each showing the baseline curve and the
    finetuned curve over the epochs.

    Args:
        finetuning_history: History object returned by fitting the finetuned
            model; its ``history`` dict must contain ``val_acc`` and ``val_loss``.
        baseline_history: History object returned by fitting the baseline
            model, with the same keys.
    """
    curves = (
        (baseline_history, "supervised baseline"),
        (finetuning_history, "supervised finetuning"),
    )
    for metric_key, metric_name in (("acc", "accuracy"), ("loss", "loss")):
        plt.figure(figsize=(8, 5), dpi=100)
        for run_history, legend_label in curves:
            plt.plot(run_history.history[f"val_{metric_key}"], label=legend_label)
        plt.legend()
        plt.title(f"Classification {metric_name} during training")
        plt.xlabel("epochs")
        plt.ylabel(f"validation {metric_name}")
        plt.show()


plot_training_curves(finetuned_history, baseline_history)

In [None]:
# Evaluate both models on the test set; evaluate() returns [loss, acc], so
# index 1 is the categorical accuracy configured at compile time.

# Baseline model
results_baseline = baseline_model.evaluate(test_dataset, verbose=0)
baseline_accuracy_percent = np.round(results_baseline[1] * 100, 2)
print(f"Baseline Test Accuracy: {baseline_accuracy_percent}%")

# Finetuned model
results_finetuned = finetuned_model.evaluate(test_dataset, verbose=0)
finetuned_accuracy_percent = np.round(results_finetuned[1] * 100, 2)
print(f"Finetuned Test Accuracy: {finetuned_accuracy_percent}%")

In [None]:
y_pred_baseline = []  # predicted class ids of the baseline model, per batch
y_pred_finetuned = []  # predicted class ids of the finetuned model, per batch
y_true = []  # one-hot true labels, per batch

max_count = int(len(test_dataset))  # number of batches in the test set

f = IntProgress(min=0, max=max_count)  # instantiate the bar
display(f)  # display the bar

# Both models predict on the same batch in the same iteration, so predictions
# and true labels stay aligned even though the test set is shuffled.
for image_batch, label_batch in test_dataset:
    y_true.append(label_batch)

    preds = baseline_model.predict(image_batch, verbose=0)
    y_pred_baseline.append(np.argmax(preds, axis=-1))

    preds = finetuned_model.predict(image_batch, verbose=0)
    y_pred_finetuned.append(np.argmax(preds, axis=-1))

    f.value += 1

# Join the per-batch results into one tensor each
correct_labels = tf.concat(y_true, axis=0)
predicted_labels_baseline = tf.concat(y_pred_baseline, axis=0)
predicted_labels_finetuned = tf.concat(y_pred_finetuned, axis=0)

# One-hot true labels -> class ids
correct_labels = correct_labels.numpy().argmax(axis=1)

# Generate reports. Passing the full list of class ids keeps the matrices and
# reports aligned with `labels` even when a class never occurs in the test
# labels or predictions.
class_ids = np.arange(len(labels))
matrix_baseline = confusion_matrix(
    correct_labels, predicted_labels_baseline, labels=class_ids
)
matrix_finetuned = confusion_matrix(
    correct_labels, predicted_labels_finetuned, labels=class_ids
)
report_baseline = classification_report(
    correct_labels,
    predicted_labels_baseline,
    labels=class_ids,
    target_names=labels,
    zero_division=0,
)
report_finetuned = classification_report(
    correct_labels,
    predicted_labels_finetuned,
    labels=class_ids,
    target_names=labels,
    zero_division=0,
)

In [None]:
# Show a few test images together with the class each model predicts for them
num_images = 10

# First `num_images` images of one test batch
images = next(iter(test_dataset))[0][:num_images]

# Rescale to [0, 1] for display only; the models receive the unscaled images.
display_images = keras.Sequential([layers.Rescaling(1 / 255)])(images)
row_titles = ["Images:"]

plt.figure(figsize=(num_images * 1.3, 1.3), dpi=100)
for position, picture in enumerate(display_images):
    plt.subplot(1, num_images, position + 1)
    plt.imshow(picture)
    if position == 0:
        # Only the left-most image carries the row title.
        plt.title(row_titles[0], loc="left")
    plt.axis("off")
plt.tight_layout()
plt.show()

# Translate each model's arg-max class ids into class names.
class_name_array = np.array(labels)
for model_name, model in (("Baseline", baseline_model), ("Finetuned", finetuned_model)):
    predicted_ids = np.argmax(model.predict(images, verbose=0), axis=-1)
    print(f"{model_name}: {class_name_array[predicted_ids]}")

In [None]:
# Plot one confusion-matrix heatmap per model
for model_title, matrix in (
    ("Baseline", matrix_baseline),
    ("Finetuned", matrix_finetuned),
):
    fig = plt.figure(figsize=(10, 10))
    sns.heatmap(matrix, annot=True, cmap="viridis", fmt="g")
    # Heatmap cells are one unit wide; +0.5 centres each tick on its cell.
    tick_positions = np.arange(num_classes) + 0.5
    plt.xticks(ticks=tick_positions, labels=labels, rotation=90)
    plt.yticks(ticks=tick_positions, labels=labels, rotation=0)
    plt.title(f"Confusion Matrix ({model_title})")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.show()

In [None]:
# Print the per-class precision/recall/F1 report of each model
for model_title, report in (
    ("Baseline", report_baseline),
    ("Finetuned", report_finetuned),
):
    print(f"Classification Report ({model_title}):\n", report)

In [None]:
# Load TensorBoard
%tensorboard --logdir logs