<a href="https://colab.research.google.com/github/cs-amy/project-codebase/blob/main/Word_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CNN Sliding-Window Model for 3-Letter Word De-Obfuscation**
Stage 2 of MSc Project — Ashraf Muhammed Yusuf

# **1. Colab Environment Setup**

In [1]:
# Import dependencies
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf, os, numpy as np
from tensorflow.keras import mixed_precision
from google.colab import drive
from tensorflow.keras.callbacks import (ModelCheckpoint, EarlyStopping, ReduceLROnPlateau)
from sklearn.metrics import classification_report, confusion_matrix

In [None]:
# 1.3 Mount Google Drive and point BASE_PATH at the project folder.
# Every later cell reads datasets from, and writes checkpoints to, BASE_PATH.
# Dataset (shared Drive folder):
#   https://drive.google.com/drive/folders/1sfNG1PkmTPBe1wOSQXZmfdkvR97Hn9lk?usp=sharing
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

BASE_PATH = "/content/drive/MyDrive/MScProject"

In [None]:
# Optional but useful: turn on XLA JIT compilation globally for extra speed.
# Flip USE_XLA to False to run without XLA auto-clustering.
USE_XLA = True
tf.config.optimizer.set_jit(USE_XLA)

# **2. Load & Freeze the Single-Char Model**

In [None]:
# Stage-1 single-character CNN, reused below as a frozen per-letter classifier.
# NOTE(review): the path has no file extension, so it is presumably a TF
# SavedModel directory; Keras 3 `load_model` only accepts .keras/.h5 files —
# confirm against the TF/Keras version on the runtime.
BASE_MODEL_PATH = f"{BASE_PATH}/char_cnn_ckpt_best"
base_model = tf.keras.models.load_model(BASE_MODEL_PATH)
base_model.trainable = False  # freeze weights

n_base_params = base_model.count_params()
print("Base model frozen — params:", n_base_params)

# **3. Dataset: 3-Letter Words**

In [None]:
IMG_H, IMG_W = 64, 64
SEED = 42
AUTOTUNE = tf.data.AUTOTUNE

# ACC is not defined by any earlier cell in this notebook, so detect the
# accelerator here (otherwise this cell raises NameError on a fresh kernel).
if tf.config.list_physical_devices("GPU"):
  ACC = "GPU"
elif tf.config.list_physical_devices("TPU"):
  ACC = "TPU"
else:
  ACC = "CPU"
BATCH = 64 if ACC != "CPU" else 16

def preprocess(img, label):
  """Cast images to float32 and ensure they are IMG_H x IMG_W.

  image_dataset_from_directory already yields float32 images of size
  IMG_H x IMG_W, so both calls below are effectively no-ops:
  convert_image_dtype does not rescale float -> float input.
  NOTE(review): pixel values therefore stay in [0, 255]. This must match the
  preprocessing the single-char base model was trained with — confirm before
  switching to `img / 255.0`.
  """
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = tf.image.resize(img, [IMG_H, IMG_W])
  return img, label

def _load_split(split, shuffle=True, seed=None):
  """Load data/words3/<split> as batched (grayscale image, one-hot label) pairs."""
  return tf.keras.preprocessing.image_dataset_from_directory(
    f"{BASE_PATH}/data/words3/{split}",
    labels="inferred", label_mode="categorical",
    batch_size=BATCH, image_size=(IMG_H, IMG_W),
    color_mode="grayscale", shuffle=shuffle, seed=seed
  )

raw_train = _load_split("train", seed=SEED)
raw_val = _load_split("val")
raw_test = _load_split("test", shuffle=False)  # fixed order: predictions line up with labels

class_names = raw_train.class_names  # capture once, before .map() drops the attribute

# One-hot indices come from each split's own sorted folder list, so a split
# missing a word folder would silently shift every label after it.
for split_name, raw in (("val", raw_val), ("test", raw_test)):
  assert raw.class_names == class_names, (
    f"{split_name} split classes differ from train classes")

# Note: the data is already batched here, so shuffle() reorders whole batches.
train_ds = (raw_train
            .map(preprocess, num_parallel_calls=AUTOTUNE)
            .cache().shuffle(1000, seed=SEED).prefetch(AUTOTUNE))

val_ds = (raw_val
          .map(preprocess, num_parallel_calls=AUTOTUNE)
          .cache().prefetch(AUTOTUNE))

test_ds = (raw_test
           .map(preprocess, num_parallel_calls=AUTOTUNE)
           .cache().prefetch(AUTOTUNE))

# **4. Visual Sanity Check**

In [None]:
def show_examples(ds, title, n=6):
  """Plot the first n images of one batch from ds, titled with their class name.

  Args:
    ds: batched dataset yielding (images, one-hot labels).
    title: figure super-title.
    n: number of images to show; capped at the batch size. Odd values are
       allowed (the grid gets an extra, hidden cell).
  """
  imgs, lbls = next(iter(ds))
  n = min(n, int(imgs.shape[0]))
  n_rows = 2 if n > 1 else 1
  n_cols = max(1, (n + n_rows - 1) // n_rows)  # ceil(n / n_rows)

  fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 3), squeeze=False)
  flat_axes = axes.ravel()
  for i, ax in enumerate(flat_axes):
    ax.axis('off')
    if i >= n:
      continue  # unused grid cell
    ax.imshow(imgs[i].numpy().squeeze(), cmap='gray')
    ax.set_title(class_names[lbls[i].numpy().argmax()])
  fig.suptitle(title)
  plt.show()

show_examples(train_ds, "Train Samples")

# **5. Build the Sliding-Window Model**

In [None]:
PATCH_W = IMG_W // 3  # width of one letter strip (3 * PATCH_W may be < IMG_W)

def extract_patch(x, idx):
  """Return the idx-th vertical letter strip of a (batch, H, W, C) tensor.

  Kept as a plain tensor helper; the model below uses the equivalent
  Cropping2D layer from _letter_crop so the saved model holds no Python lambdas.
  """
  start = idx * PATCH_W
  return x[:, :, start:start+PATCH_W, :]

def _letter_crop(idx):
  """Built-in-layer equivalent of extract_patch(z, idx) (serializes cleanly)."""
  start = idx * PATCH_W
  return tf.keras.layers.Cropping2D(
    cropping=((0, 0), (start, IMG_W - start - PATCH_W)),
    name=f"letter{idx}_crop")

# Spatial input size of the frozen char model (None means any size is accepted).
# Assumes base_model has a single image input.
BASE_H, BASE_W = base_model.input_shape[1:3]

inputs = tf.keras.Input(shape=(IMG_H, IMG_W, 1))
logits = []

for i in range(3):
  patch = _letter_crop(i)(inputs)  # (batch, IMG_H, PATCH_W, 1)
  # If base_model has a fixed input size that differs from the strip size,
  # calling it on the raw strip would fail, so resize the strip to fit.
  # NOTE(review): this stretches the strip horizontally — confirm it matches
  # how the single-char training images were framed.
  if BASE_H is not None and BASE_W is not None and (BASE_H, BASE_W) != (IMG_H, PATCH_W):
    patch = tf.keras.layers.Resizing(BASE_H, BASE_W, name=f"letter{i}_resize")(patch)
  # Re-use the same base_model for every letter (shared weights). training=False
  # keeps any BatchNorm/Dropout in inference mode, also after fine-tune unfreezing.
  logits.append(base_model(patch, training=False))  # presumably (batch, 26) per-letter scores

concat = tf.keras.layers.Concatenate()(logits)
# One output unit per word class found in the dataset (not a fixed 78 = 3 * 26).
dense = tf.keras.layers.Dense(len(class_names), activation='softmax')(concat)

word_model = tf.keras.Model(inputs, dense)
word_model.compile(
  optimizer=tf.keras.optimizers.Adam(1e-3),
  loss='categorical_crossentropy',
  metrics=['accuracy']
)

# Print model summary
word_model.summary()

# **6. Callbacks**

In [None]:
# Path of the best full-model checkpoint (a single .keras file, despite the name).
CKPT_DIR = f"{BASE_PATH}/words3_ckpt_best.keras"

# Keep only the epoch with the lowest validation loss.
checkpoint_cb = ModelCheckpoint(CKPT_DIR, save_best_only=True, monitor='val_loss')
# Stop after 5 epochs without val_loss improvement and roll back to the best weights.
early_stop_cb = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
# Halve the learning rate after 2 stagnant epochs, never going below 1e-6.
reduce_lr_cb = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6)

callbacks = [checkpoint_cb, early_stop_cb, reduce_lr_cb]

# **7. Train (Frozen Base)**

In [None]:
# Phase 1: base_model is frozen, so only the new Dense head is trained.
history = word_model.fit(
  train_ds,
  epochs=20,
  validation_data=val_ds,
  callbacks=callbacks,
)

# **8. Optional Fine-Tuning of the Last Base-Model Layers**

In [None]:
# Number of trailing base_model layers to unfreeze, and extra epochs to train.
N_UNFROZEN = 3
FT_EPOCHS = 3

# `base_model.trainable = False` froze the whole sub-model, and a non-trainable
# parent model exposes no trainable weights even if a child layer is flipped
# back to trainable. So unfreeze the parent first, then re-freeze every layer
# except the last N_UNFROZEN.
base_model.trainable = True
for layer in base_model.layers[:-N_UNFROZEN]:
  layer.trainable = False

# Re-compile (required for the new trainable flags to take effect) with a lower LR
word_model.compile(
  optimizer=tf.keras.optimizers.Adam(1e-4),
  loss='categorical_crossentropy',
  metrics=['accuracy']
)

# Continue the epoch counter from where phase 1 stopped.
start_epoch = history.epoch[-1] + 1
ft_history = word_model.fit(
  train_ds, validation_data=val_ds,
  initial_epoch=start_epoch,
  epochs=start_epoch + FT_EPOCHS,
  callbacks=callbacks
)

# **9. Evaluation**

In [None]:
# Restore the best checkpointed weights into the in-memory model. Loading only
# the weights avoids deserializing the saved graph, which Keras refuses by
# default (safe_mode) for .keras files containing Lambda layers.
word_model.load_weights(CKPT_DIR)  # best checkpoint
test_loss, test_acc = word_model.evaluate(test_ds)
print(f"Test accuracy: {test_acc:.4f}")

# Util for plotting confusion matrix
# Util for plotting confusion matrix
def plot_confusion_matrix(cm, class_names, title="Confusion Matrix", fmt="d"):
  """Draw cm as an annotated heat-map.

  Args:
      cm (np.ndarray): square confusion matrix
      class_names (List[str]): labels in the same order used to build cm
      title (str): axes title
      fmt (str): format spec for the cell annotations; "d" for raw counts,
          e.g. ".2f" for a row-normalised (float) matrix.
  """
  fig, ax = plt.subplots(figsize=(10, 9))
  im = ax.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
  ax.figure.colorbar(im, ax=ax, fraction=0.045)

  # axes & ticks
  ax.set(
    xticks=np.arange(len(class_names)),
    yticks=np.arange(len(class_names)),
    xticklabels=class_names,
    yticklabels=class_names,
    ylabel="True label",
    xlabel="Predicted label",
    title=title,
  )
  # va="top" keeps the rotated labels below the axis instead of centring them on it.
  plt.setp(ax.get_xticklabels(), rotation=90, ha="center", va="top")

  # annotate cells; switch text colour on dark cells for contrast
  thresh = cm.max() / 2.0
  for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
      ax.text(
        j, i, format(cm[i, j], fmt),
        ha="center", va="center",
        color="white" if cm[i, j] > thresh else "black",
        fontsize=8
      )

  fig.tight_layout()
  plt.show()

# Classification report
y_pred, y_true = [], []
for x, y in test_ds:
  y_pred.extend(np.argmax(word_model.predict(x), axis=1))
  y_true.extend(np.argmax(y.numpy(), axis=1))
print(classification_report(y_true, y_pred, target_names=class_names))

# Confusion matrix heat-map (optional)
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, class_names, title="3-Letter Word Confusion Matrix")

# **10. Qualitative Error Analysis**

In [None]:
# Plot a few misclassified 3-letter words with their true and predicted labels.
N_SHOW = 6
mis_idx = [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t != p]
print(f"Misclassified: {len(mis_idx)} / {len(y_true)} test images")

if mis_idx:
  shown = mis_idx[:N_SHOW]  # ascending test-set indices
  wanted = set(shown)
  # test_ds is unshuffled, so the k-th unbatched element is test example k;
  # stop reading once the last wanted example has been reached.
  mis_imgs = [img.numpy().squeeze()
              for k, (img, _) in enumerate(test_ds.unbatch().take(shown[-1] + 1))
              if k in wanted]

  fig, axes = plt.subplots(1, len(shown), figsize=(2 * len(shown), 2.5), squeeze=False)
  for ax, k, img in zip(axes[0], shown, mis_imgs):
    ax.imshow(img, cmap='gray')
    ax.set_title(f"{class_names[y_true[k]]} → {class_names[y_pred[k]]}", fontsize=9)
    ax.axis('off')
  fig.suptitle("Misclassified examples (true → predicted)")
  plt.show()