<a href="https://colab.research.google.com/github/cs-amy/project-codebase/blob/main/Word_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CNN Sliding-Window Model for 3-Letter Word De-Obfuscation**
Stage 2 of MSc Project — Ashraf Muhammed Yusuf

# **1. Colab Environment Setup**

In [None]:
# Install dependencies
!pip install -q tensorflow matplotlib

In [None]:
# Import dependencies
import os, random, itertools, pathlib, math, shutil
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf, os, numpy as np
from glob import glob
from tensorflow.keras import mixed_precision
from google.colab import drive
from tensorflow.keras.callbacks import (ModelCheckpoint, EarlyStopping, ReduceLROnPlateau)
from sklearn.metrics import classification_report, confusion_matrix
from collections import defaultdict
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

In [None]:
# Mount Google Drive
# Mount Drive so you can read datasets and write checkpoints.
# (The base path itself is defined later, as BASE_PATH in the generator cell.)
# Link to dataset:
# https://drive.google.com/drive/folders/1sfNG1PkmTPBe1wOSQXZmfdkvR97Hn9lk?usp=sharing
drive.mount('/content/drive')

In [None]:
# (Optional—but useful) turn on XLA JIT for extra speed.
# This is a global TensorFlow optimizer setting, so it must run before the
# models below are built/trained to have any effect on them.
tf.config.optimizer.set_jit(True)

In [None]:
# ------------------------------------------------------------------
# 3-Letter Word Dataset Generator (single Colab cell)
# ------------------------------------------------------------------
# - Builds data/words3/{train|val|test}/{cls}/{cls}_{k}.png
# - Each cls = 3-letter word from AAA … ZZZ  (26³ = 17 576 classes)
# - 5 variants per word (random Stage-1 glyph or PIL-rendered text,
#   the latter with optional obfuscation)
# - Re-uses single-letter glyphs created in Stage-1 of the project
# ------------------------------------------------------------------

# Define paths & knobs
BASE_PATH = "/content/drive/MyDrive/MScProject"  # project root on the mounted Drive
GLYPH_DIR = f"{BASE_PATH}/data/characters/train" # Stage-1 glyphs, one sub-folder per letter
OUT_ROOT = f"{BASE_PATH}/data/words3"            # generated word images are written here
IMG_H, IMG_W = 64, 64                            # size of every generated word image
PATCH_W = IMG_W // 3 # 21 px per letter slot; 3 * 21 = 63, so the last pixel column stays background
VARIANTS_PER_WORD = 5                            # images generated per word
# Sampling weights used by obfuscate(), in the order ("none", "leet", "homo")
SPREAD = (0.5, 0.4, 0.1)

# Obfuscation maps - leet
# Keys are the uppercase Latin letters A-Z; each value is the list of
# substitute characters from which obfuscate() picks one uniformly at random
# when the "leet" mode is drawn.
# NOTE(review): some substitutes appear under more than one key (e.g. the
# same glyph is listed for both 'F' and 'Q'), and the second 'V' entry looks
# like a plain Latin V — confirm these are intended.
LEET = {
  'A': ['Α', '4', 'Д', 'Ä', 'Á', 'À', 'Â', '@', 'Δ'],
  'B': ['8', 'β', 'Β', 'В'],
  'C': ['Ç', 'Ć', 'Č', 'С'],
  'D': ['Ð', 'Ď'],
  'E': ['3', 'Σ', 'Έ', 'Ε', 'Е', 'Ë', 'É', 'È', 'Ê'],
  'F': ['Φ', 'Ƒ'],
  'G': ['6', 'Ğ', 'Ģ', 'Γ'],
  'H': ['Η', 'Н'],
  'I': ['1', '|', 'Í', 'Ì', 'Î', 'Ï', 'И'],
  'J': ['Ј'],
  'K': ['Κ', 'К'],
  'L': ['Ι', 'Ł', 'Ĺ', 'Л'],
  'M': ['Μ', 'М'],
  'N': ['Ν', 'Ń', 'Ñ', 'Н'],
  'O': ['0', 'Θ', 'Ο', 'Ө', 'Ø', 'Ö', 'Ó', 'Ò', 'Ô'],
  'P': ['Ρ', 'Р'],
  'Q': ['Φ'],
  'R': ['®', 'Я', 'Ř', 'Ŕ'],
  'S': ['5', '$', 'Ѕ', 'Ś', 'Š'],
  'T': ['Τ', 'Т'],
  'U': ['Υ', 'Ц', 'Ü', 'Ú', 'Ù', 'Û'],
  'V': ['Ѵ', 'V'],
  'W': ['Ω', 'Ѡ', 'Ψ', 'Ш', 'Щ'],
  'X': ['Χ', 'Ж', 'Х'],
  'Y': ['Υ', 'Ү', 'Ý', 'Ÿ'],
  'Z': ['Ζ', 'Ż', 'Ź', 'Ž', 'З', '2']
}

# Obfuscation maps - homoglyphs (non-Latin look-alikes)
# Exactly one replacement per letter; letters without an entry are left
# unchanged by obfuscate() when the "homo" mode is drawn.
# NOTE(review): this map was originally labelled "cyrillic look-alikes", but
# several targets (e.g. those for 'Y' and 'Z') appear to be Greek capitals
# rather than Cyrillic — confirm the code points.
HOMO = {
  'A':'Α',
  'B':'Β',
  'C':'С',
  'E':'Ε',
  'H':'Н',
  'K':'Κ',
  'M':'Μ',
  'O':'О',
  'P':'Р',
  'T':'Τ',
  'X':'Χ',
  'Y':'Υ',
  'Z':'Ζ'
}

# Obfuscation helper
def obfuscate(char):
  """
  Return `char`, possibly swapped for a look-alike character.

  One of the modes "none" / "leet" / "homo" is drawn with the weights in
  SPREAD. "leet" picks a random entry from LEET[char]; "homo" returns
  HOMO[char]. If the drawn mode has no mapping for `char` (or the mode is
  "none"), `char` is returned unchanged.
  """
  modes = ("none", "leet", "homo")
  picked = random.choices(modes, weights=SPREAD)[0]

  if picked == "leet":
    if char in LEET:
      return random.choice(LEET[char])
    return char

  if picked == "homo":
    return HOMO[char] if char in HOMO else char

  return char

# Word-stitch helper
def stitch_word(word, save_path):
  """
  Compose an IMG_W×IMG_H (64×64) grayscale PNG for a 3-char word.

  Each character occupies one PATCH_W-wide slot on a white canvas. Per
  character, a random Stage-1 glyph from `letter_pool` is used 70% of the
  time; otherwise the character is rendered with PIL's default font after
  passing through obfuscate(). Note that obfuscation is therefore only
  applied on the rendered-text path.

  With probability 0.3 the finished canvas is shifted horizontally by a
  random offset in [-2, 2] px.

  Args:
      word: iterable of characters; each must be a key of `letter_pool`.
      save_path: destination passed to PIL's Image.save().
  """
  canvas = Image.new("L", (IMG_W, IMG_H), color=255)
  for idx, ch in enumerate(word):
    # choose glyph source
    if random.random() < 0.7: # reuse Stage-1 glyph
      # Context manager releases the file handle immediately; resize()
      # returns an independent in-memory copy.
      with Image.open(random.choice(letter_pool[ch])) as src:
        glyph = src.resize((PATCH_W, IMG_H))
    else: # render fresh
      glyph = Image.new("L", (PATCH_W, IMG_H), color=255)
      draw  = ImageDraw.Draw(glyph)
      # NOTE(review): no font is passed, so PIL's default font is used; it
      # may not cover every non-Latin substitute in LEET/HOMO — confirm.
      draw.text((4, 4), obfuscate(ch), fill=0)
    canvas.paste(glyph, (idx * PATCH_W, 0))
  if random.random() < 0.3: # light affine jitter
    dx = random.randint(-2, 2)
    # fillcolor=255 keeps the strip exposed by the shift white like the rest
    # of the background (the default fill is black, which would add a dark
    # bar at the image edge).
    canvas = canvas.transform(
      canvas.size, Image.AFFINE, (1, 0, dx, 0, 1, 0), fillcolor=255
    )
  canvas.save(save_path)

# Letter-image reservoir
# Maps each uppercase letter to the list of Stage-1 glyph PNG paths found
# under GLYPH_DIR/<letter>/. Paths are sorted so the pool order does not
# depend on the file-system listing order.
letter_pool = defaultdict(list)
for letter in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
  letter_pool[letter] = sorted(glob(f"{GLYPH_DIR}/{letter}/*.png"))
  # Explicit raise instead of assert: the check must survive `python -O`.
  if not letter_pool[letter]:
    raise FileNotFoundError(
      f"No glyphs found for '{letter}' in {GLYPH_DIR}/{letter}"
    )

# SAFE-GUARD: Skip everything if dataset already exists
# NOTE: a dataset produced by an earlier version of this cell (words split
# across train/val/test) is also detected here; delete OUT_ROOT to regenerate.
if os.path.exists(OUT_ROOT) and os.listdir(OUT_ROOT):
  print("words3 dataset already exists — skipping generation.")
else:
  # Every word is a class, so every word must be present in EVERY split.
  # Splitting the word list itself would give train/val/test disjoint class
  # sets: validation/test classes would never be seen during training and
  # the folder-inferred label indices would not line up between splits.
  # Instead, the VARIANTS_PER_WORD images of each word are divided between
  # the splits (target ≈ 70 / 15 / 15, with at least one image each).
  words = ["".join(p) for p in itertools.product("ABCDEFGHIJKLMNOPQRSTUVWXYZ", repeat=3)]

  n_val = max(1, round(0.15 * VARIANTS_PER_WORD))
  n_test = max(1, round(0.15 * VARIANTS_PER_WORD))
  n_train = VARIANTS_PER_WORD - n_val - n_test
  if n_train < 1:
    raise ValueError(
      f"VARIANTS_PER_WORD={VARIANTS_PER_WORD} is too small to give every "
      "split at least one image per word (need >= 3)."
    )
  variants_per_split = {"train": n_train, "val": n_val, "test": n_test}
  # Each split covers the full vocabulary.
  splits = {split_name: words for split_name in variants_per_split}

  # Generate images; the variant index k is unique per word across splits
  offset = 0
  for split_name, count in variants_per_split.items():
    for word in tqdm(splits[split_name], desc=f"Generating {split_name}"):
      cls_dir = pathlib.Path(OUT_ROOT, split_name, word)
      cls_dir.mkdir(parents=True, exist_ok=True)
      for k in range(offset, offset + count):
        out_file = cls_dir / f"{word}_{k}.png"
        stitch_word(word, out_file)
    offset += count

  # Print results
  print("Dataset generation complete.")
  print("train / val / test words :", [len(splits[s]) for s in ('train','val','test')])
  print("variants per word        :", [variants_per_split[s] for s in ('train','val','test')])
  print("images per split         :", [len(splits[s])*variants_per_split[s] for s in ('train','val','test')])

# **2. Load & Freeze the Single-Char Model**

In [None]:
# Load the Stage-1 single-character classifier from its saved checkpoint.
base_model = tf.keras.models.load_model(f"{BASE_PATH}/char_cnn_ckpt_best.keras")
# Freeze it so that, initially, only layers added on top of it are trained.
base_model.trainable = False # freeze weights
print("Base model frozen — params:", base_model.count_params())

# **3. Dataset: 3-Letter Words**

In [None]:
BATCH = 64 if tf.config.list_physical_devices('TPU') else 32
IMG_H, IMG_W = 64, 64
IMG_SHAPE   = (IMG_H, IMG_W)
AUTOTUNE = tf.data.AUTOTUNE
# NUM_CLASSES is assigned further down, right after the training set has been
# indexed; it cannot be computed here because the class names do not exist yet.

def preprocess(img, label, num_classes=None):
  """
  Convert one batch to float32 images and one-hot labels.

  Args:
      img: image batch as produced by image_dataset_from_directory.
      label: integer class indices for the batch.
      num_classes: one-hot depth; falls back to the global NUM_CLASSES
          when omitted (original behaviour).

  Returns:
      (img, label_oh) with label_oh of depth `num_classes`.
  """
  depth = NUM_CLASSES if num_classes is None else num_classes
  # NOTE(review): image_dataset_from_directory is expected to yield float32
  # already, in which case this conversion does not rescale pixel values —
  # confirm the base model was trained on the same value range.
  img = tf.image.convert_image_dtype(img, tf.float32)
  img = tf.image.resize(img, IMG_SHAPE)
  label_oh = tf.one_hot(label, depth)
  return img, label_oh

def make_ds(dir_path, shuffle, batch, num_classes=None):
  """
  Build a batched tf.data pipeline from a class-per-folder image directory.

  Args:
      dir_path: directory containing one sub-folder per class.
      shuffle: if True, files are shuffled on load and the mapped dataset is
          cached and shuffled (buffer of 1000 batches).
      batch: batch size.
      num_classes: one-hot depth. Defaults to the number of class folders
          found in `dir_path`; pass the training-set value for val/test so
          all splits share one label width.

  Returns:
      (dataset, class_names) where class_names is the folder-name list in
      label-index order.
  """
  raw = tf.keras.preprocessing.image_dataset_from_directory(
        dir_path,
        labels="inferred",
        label_mode="int", # int first
        batch_size=batch,
        image_size=IMG_SHAPE,
        color_mode="grayscale",
        shuffle=shuffle,
        seed=42
  )
  # capture the folder names before we lose them
  class_names = raw.class_names
  depth = len(class_names) if num_classes is None else num_classes

  ds = raw.map(
    lambda img, label: preprocess(img, label, depth),
    num_parallel_calls=AUTOTUNE
  )
  # Written as an explicit branch: only the shuffled (training) pipeline is
  # cached and re-shuffled; evaluation pipelines stream in directory order.
  if shuffle:
    ds = ds.cache().shuffle(1000)
  ds = ds.prefetch(AUTOTUNE)
  return ds, class_names

# Create datasets — training first, because it defines the label space
train_ds, train_ds_class_names = make_ds(f"{OUT_ROOT}/train", shuffle=True,  batch=BATCH)
NUM_CLASSES = len(train_ds_class_names)
val_ds, val_ds_class_names = make_ds(f"{OUT_ROOT}/val",   shuffle=False, batch=BATCH, num_classes=NUM_CLASSES)
test_ds, test_ds_class_names = make_ds(f"{OUT_ROOT}/test",  shuffle=False, batch=BATCH, num_classes=NUM_CLASSES)

# Label indices are inferred per directory, so they are only comparable across
# splits when every split contains exactly the same class folders.
for _split, _names in (("val", val_ds_class_names), ("test", test_ds_class_names)):
  if _names != train_ds_class_names:
    print(f"WARNING: class folders in '{_split}' differ from 'train' — "
          "label indices are not aligned and metrics will be meaningless.")

# **4. Visual Sanity Check**

In [None]:
def show_examples(ds, title, n=6, class_names=None):
  """
  Plot up to `n` images from the first element of `ds` in a 2-row grid.

  Args:
      ds: dataset yielding (images, one-hot labels); either batched or a
          single unbatched (image, label) pair per element.
      title: figure title.
      n: maximum number of images to show (capped at what is available).
      class_names: label-index -> name list used for the subplot titles;
          defaults to `train_ds_class_names`.
  """
  if class_names is None:
    class_names = train_ds_class_names
  imgs, lbls = next(iter(ds))
  # An unbatched element is a single (H, W, C) image: add a batch axis so the
  # indexing below works for both cases.
  if len(imgs.shape) == 3:
    imgs, lbls = imgs[tf.newaxis], lbls[tf.newaxis]
  n = min(n, int(imgs.shape[0]))
  cols = math.ceil(n / 2) # enough columns for odd n as well
  plt.figure(figsize=(8,3))
  for i in range(n):
    plt.subplot(2, cols, i+1)
    plt.imshow(imgs[i].numpy().squeeze(), cmap='gray')
    plt.title(class_names[lbls[i].numpy().argmax()])
    plt.axis('off')
  plt.suptitle(title); plt.show()

show_examples(train_ds, "Train Samples")

# **5. Build the Sliding-Window Model**

In [None]:
PATCH_W = IMG_W // 3 # 21 when IMG_W = 64

def extract_patch(x, idx):
  """Return the idx-th PATCH_W-wide vertical strip of a (N, H, W, C) batch."""
  start = idx * PATCH_W
  return x[:, :, start:start+PATCH_W, :] # (None, 64, 21, 1)

inputs = tf.keras.Input(shape=(IMG_H, IMG_W, 1))
logits = []

for i in range(3):
  # Cropping2D takes the same slice as extract_patch(inputs, i) but is a
  # built-in, serialisable layer. A Lambda layer wrapping a Python lambda
  # makes the saved .keras checkpoint hard to reload (Keras refuses to
  # deserialise lambdas in safe mode).
  left = i * PATCH_W
  right = IMG_W - left - PATCH_W
  patch = tf.keras.layers.Cropping2D(
    cropping=((0, 0), (left, right)), name=f"patch_{i}"
  )(inputs) # (None, 64, 21, 1)
  # assumes base_model takes IMG_H x IMG_H single-channel input — TODO confirm
  patch = tf.keras.layers.Resizing(IMG_H, IMG_H)(patch) # -> (None, 64, 64, 1)
  # Re-use frozen base_model (shared weights)
  logits.append(base_model(patch)) # (None, 26)

concat = tf.keras.layers.Concatenate()(logits) # (None, 78)
outputs = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(concat)

word_model = tf.keras.Model(inputs, outputs)
word_model.compile(
  optimizer=tf.keras.optimizers.Adam(1e-3),
  loss='categorical_crossentropy',
  metrics=['accuracy']
)

# Print model summary
word_model.summary()

# **6. Callbacks**

In [None]:
# Despite the name, CKPT_DIR is the path of a single .keras checkpoint file.
CKPT_DIR = f"{BASE_PATH}/words3_ckpt_best.keras"

# Keep only the model with the lowest validation loss seen so far
checkpoint_cb = ModelCheckpoint(
  filepath=CKPT_DIR,
  monitor='val_loss',
  save_best_only=True
)

# Stop after 5 epochs without val_loss improvement and roll back to the best weights
early_stop_cb = EarlyStopping(
  monitor='val_loss',
  patience=5,
  restore_best_weights=True
)

# Halve the learning rate after 2 stagnant epochs, never going below 1e-6
lr_plateau_cb = ReduceLROnPlateau(
  monitor='val_loss',
  factor=0.5,
  patience=2,
  min_lr=1e-6
)

callbacks = [checkpoint_cb, early_stop_cb, lr_plateau_cb]

# **7. Train (Frozen Base)**

In [None]:
# First training phase: up to 20 epochs, monitored on val_ds.
# `history` is kept because the fine-tuning cell continues the epoch count
# from history.epoch.
history = word_model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=20,
  callbacks=callbacks
)

# **8. (Optional) Fine-Tune Last Conv Block**

In [None]:
# Un-freeze last 3 layers of base_model
# base_model.trainable was set to False when it was loaded. While the parent
# model is flagged non-trainable, Keras reports none of its weights as
# trainable, so flipping only the child layers would leave everything frozen.
# Re-enable the parent first, then freeze everything except the last 3 layers.
base_model.trainable = True
for layer in base_model.layers[:-3]:
  layer.trainable = False

# Re-compile with lower LR (required for the trainable changes to take effect)
word_model.compile(
  optimizer=tf.keras.optimizers.Adam(1e-4),
  loss='categorical_crossentropy',
  metrics=['accuracy']
)

# Continue the epoch numbering where the frozen-base phase stopped
FT_EPOCHS = 3
ft_start = history.epoch[-1] + 1 if history.epoch else 0

ft_history = word_model.fit(
  train_ds,
  validation_data=val_ds,
  initial_epoch=ft_start,
  epochs=ft_start + FT_EPOCHS,
  callbacks=callbacks
)

# **9. Evaluation**

In [None]:
# Replace the in-memory model with the one saved by ModelCheckpoint
# (save_best_only on val_loss), then score it on the held-out test set.
word_model = tf.keras.models.load_model(CKPT_DIR) # best checkpoint
test_loss, test_acc = word_model.evaluate(test_ds)
print(f"Test accuracy: {test_acc:.4f}")

# Util for plotting confusion matrix
def plot_confusion_matrix(cm, class_names, title="Confusion Matrix"):
  """
  Draw an annotated heat-map of a confusion matrix.

  Args:
      cm (np.ndarray): square confusion matrix of integer counts
      class_names (List[str]): labels in the same order used to build cm
      title (str): figure title
  """
  fig, ax = plt.subplots(figsize=(10, 9))
  heatmap = ax.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
  fig.colorbar(heatmap, ax=ax, fraction=0.045)

  # One tick per class on both axes, labelled with the class names
  tick_positions = np.arange(len(class_names))
  ax.set_xticks(tick_positions)
  ax.set_yticks(tick_positions)
  ax.set_xticklabels(class_names)
  ax.set_yticklabels(class_names)
  ax.set_ylabel("True label")
  ax.set_xlabel("Predicted label")
  ax.set_title(title)
  plt.setp(ax.get_xticklabels(), rotation=90, ha="center", va="center")

  # Write each count into its cell; switch to white text on dark cells
  half_max = cm.max() / 2.0
  n_rows, n_cols = cm.shape[0], cm.shape[1]
  for row, col in itertools.product(range(n_rows), range(n_cols)):
    count = cm[row, col]
    text_color = "white" if count > half_max else "black"
    ax.text(
      col, row, format(count, "d"),
      ha="center", va="center",
      color=text_color,
      fontsize=8
    )

  fig.tight_layout()
  plt.show()

# Classification report
# The model's output index i corresponds to training class i, so the training
# class names are used for reporting.
# assumes the test folders match the training folders — TODO confirm
report_names = train_ds_class_names

# test_ds is built with shuffle=False, so two passes see the same order:
# one to collect the labels, one batched predict() call for the outputs.
y_true = np.concatenate([np.argmax(y.numpy(), axis=1) for _, y in test_ds])
y_pred = np.argmax(word_model.predict(test_ds), axis=1)

# `labels` pins the report to the full label space; without it sklearn raises
# when some class is absent from y_true/y_pred and target_names no longer fits.
print(classification_report(
  y_true, y_pred,
  labels=np.arange(len(report_names)),
  target_names=report_names,
  zero_division=0
))

# Confusion matrix heat-map (optional)
# A full matrix over thousands of classes can be neither held in memory
# comfortably nor read as a plot, so only the classes with the most
# misclassified test samples are shown. Samples whose true or predicted class
# is outside that subset are not counted in the matrix.
MAX_CM_CLASSES = 20
error_counts = np.bincount(y_true[y_true != y_pred], minlength=len(report_names))
worst = np.argsort(error_counts)[::-1][:MAX_CM_CLASSES]
worst = np.sort(worst[error_counts[worst] > 0])
if worst.size:
  cm = confusion_matrix(y_true, y_pred, labels=worst)
  plot_confusion_matrix(
    cm,
    [report_names[i] for i in worst],
    title=f"3-Letter Word Confusion Matrix (top {worst.size} most-confused classes)"
  )
else:
  print("No misclassified test samples — confusion matrix skipped.")

# **10. Qualitative Error Analysis**

In [None]:
# Plot a few misclassified 3-letter words
mis_idx = [i for i,(t,p) in enumerate(zip(y_true, y_pred)) if t != p]

if not mis_idx:
  # Guard: indexing mis_idx[0] on a perfect run would raise IndexError
  print("No misclassified test samples to show.")
else:
  MAX_SHOWN = 6
  shown_idx = mis_idx[:MAX_SHOWN]
  wanted = set(shown_idx)

  # test_ds is unshuffled, so position i in the unbatched stream is the same
  # sample as y_true[i] / y_pred[i]. Only read as far as the last index needed.
  mis_imgs = {}
  for pos, (img, _) in enumerate(test_ds.unbatch().take(shown_idx[-1] + 1)):
    if pos in wanted:
      mis_imgs[pos] = img.numpy().squeeze()

  n_cols = math.ceil(len(shown_idx) / 2)
  plt.figure(figsize=(8, 3))
  for slot, pos in enumerate(shown_idx, start=1):
    plt.subplot(2, n_cols, slot)
    plt.imshow(mis_imgs[pos], cmap='gray')
    # True label indexes the test folders; the prediction indexes the
    # training classes the model was built on.
    plt.title(f"{test_ds_class_names[y_true[pos]]} → {train_ds_class_names[y_pred[pos]]}")
    plt.axis('off')
  plt.suptitle("Misclassified examples (true → predicted)"); plt.show()