# **Model for Character De-Obfuscation**
Stage 3 of MSc Project — Ashraf Muhammed Yusuf

# **1. Colab Environment Setup** - Imports & Constants

In [None]:
# Mount Google Drive so the datasets can be read and checkpoints written.
# Link to dataset:
# https://drive.google.com/drive/folders/1kygA17GiCeCs8qTeDBEndU6TkXnEu-m7?usp=drive_link
# `drive` is provided by the Colab runtime; it has to be imported explicitly,
# otherwise the mount call below fails with NameError.
from google.colab import drive

drive.mount('/content/drive')

# Install dependencies
!pip install -q tensorflow matplotlib

import string
import random
import itertools
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from pathlib import Path
from tensorflow.keras import layers, models, backend as K
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# Dirs
# Dataset folders on the mounted Drive.
train_dir = "/content/drive/MyDrive/MScProject/data/words3/train"
test_dir = "/content/drive/MyDrive/MScProject/data/words3/test"
# NOTE: despite its name, CKPT_DIR is the path of a single ``.keras`` file; it
# is handed as-is to ModelCheckpoint and to load_model.
CKPT_DIR = "/content/drive/MyDrive/MScProject/ctc_ckpt_best.keras"

# Input geometry: images are resized to IMG_H x IMG_W pixels (same as your
# synthetic data). BATCH is the mini-batch size used by the data pipeline.
IMG_H, IMG_W = 64, 64
BATCH = 64

# CTC parameters
#
# Class-index layout of the softmax output:
#   0          reserved (never produced by char_to_num; decoding skips it)
#   1 .. 26    'A' .. 'Z'
#   27         CTC blank
#
# K.ctc_batch_cost and K.ctc_decode reserve the LAST class index
# (NUM_CLASSES - 1) for the blank symbol. With only len(CHARS) + 1 classes the
# blank index would be 26 and collide with 'Z', so one extra class is added and
# the blank is placed at the end.
CHARS        = list(string.ascii_uppercase)           # ['A', … 'Z']
char_to_num  = {c: i + 1 for i, c in enumerate(CHARS)}  # 'A'→1 … 'Z'→26
num_to_char  = {i: c for c, i in char_to_num.items()}   # inverse mapping
BLANK_LABEL  = len(CHARS) + 1                         # CTC blank = last index
NUM_CLASSES  = len(CHARS) + 2                         # reserved 0 + letters + blank

# max label length (3 letters)
MAX_LABEL_LEN = 3

# **2. Data Pipeline for CTC**

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Layout of the feature dictionary the CTC training model consumes (one entry
# per example, before batching). This is documentation only: the previous dict
# literal referenced names that are never defined and raised NameError as soon
# as the cell ran.
#
#   "image"      image tensor, (64, 64, 1)
#   "label"      integer-encoded label sequence, (MAX_LABEL_LEN,)
#   "input_len"  int32 scalar, number of time steps after the CNN
#   "label_len"  int32 scalar, number of characters in the label

def parse_image_and_label(path, label):
    """Load one PNG as a normalised greyscale tensor and encode its label.

    Args:
        path: file path of a PNG image.
        label: the word shown in the image, e.g. "CAT". It is iterated
            character by character in plain Python and looked up in
            ``char_to_num``, so a character missing from that table raises
            ``KeyError``.

    Returns:
        ``(image, codes)`` where ``image`` is a float32 tensor resized to
        ``[IMG_H, IMG_W]`` with one channel and values divided by 255, and
        ``codes`` is an int32 tensor holding one code per character.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.image.resize(
        tf.io.decode_png(raw_bytes, channels=1),
        [IMG_H, IMG_W],
    )
    pixels = tf.cast(pixels, tf.float32) / 255.0

    # Character -> integer code, in reading order.
    codes = tf.convert_to_tensor(
        [char_to_num[ch] for ch in label],
        dtype=tf.int32,
    )
    return pixels, codes

def prepare_for_ctc(image, label):
    """Wrap one (image, label) pair in the feature dict used for CTC training.

    Args:
        image: image tensor, (H, W, 1).
        label: integer label sequence, (MAX_LABEL_LEN,).

    Returns:
        ``(features, dummy_target)``. ``features`` has the keys "image",
        "label", "input_len" and "label_len"; ``dummy_target`` is a scalar
        zero, because the loss is computed inside the model.
    """
    # The CNN is built with two 2×2 poolings, so it emits IMG_W // 4 time steps.
    n_time_steps = tf.cast(IMG_W // 4, tf.int32)
    n_label_chars = tf.shape(label)[0]

    features = {
        "image": image,
        "label": label,
        "input_len": n_time_steps,
        "label_len": n_label_chars,
    }
    dummy_target = tf.zeros(())
    return features, dummy_target

def _encode_class_names(class_names):
    """Return, for every class (folder) name, its list of character codes."""
    encoded = []
    for name in class_names:
        if len(name) != MAX_LABEL_LEN:
            raise ValueError(
                f"Class name {name!r} must have exactly {MAX_LABEL_LEN} characters"
            )
        try:
            encoded.append([char_to_num[ch] for ch in name])
        except KeyError as err:
            raise ValueError(
                f"Class name {name!r} contains a character that is not in CHARS"
            ) from err
    return encoded


def make_dataset(
    data_dir,
    subset=None, # "training", "validation", or None (no split)
    val_frac=0.2,
    batch_size=BATCH,
    img_size=(64, 64),
    seed=42,
    ctc=False,
):
    """Build a batched ``tf.data`` pipeline from per-class folders of PNG files.

    Every sub-directory of ``data_dir`` is one class; its name is the class
    name and its ``*.png`` files are the samples.

    Args:
        data_dir: root folder containing one sub-directory per class.
        subset: "training" or "validation" to take one side of a seeded split,
            anything else (normally ``None``) to use every file.
        val_frac: fraction of the shuffled files assigned to "validation".
        batch_size: number of examples per batch.
        img_size: ``(height, width)`` the images are resized to.
        seed: seed for the file-level shuffle (and the training shuffle buffer).
        ctc: when ``False`` (default) each element is ``(image, one_hot)`` with
            the one-hot vector indexing ``class_names``. When ``True`` each
            element is what ``prepare_for_ctc`` returns: a feature dict with
            "image", "label", "input_len", "label_len" plus a dummy target. The
            label is the class name encoded through ``char_to_num``, so every
            class name must consist of ``MAX_LABEL_LEN`` characters from CHARS.
            ``prepare_for_ctc`` derives the time-step count from ``IMG_W``, so
            ``img_size`` should match the model's input size in this mode.

    Returns:
        ``(dataset, class_names)`` - always a 2-tuple.

    Raises:
        ValueError: if no PNG files are found, if the requested subset is
            empty, or (``ctc=True``) if a class name cannot be encoded.
    """
    data_dir = Path(data_dir)

    # 1) discover classes
    class_names = sorted(p.name for p in data_dir.iterdir() if p.is_dir())

    # 2) collect (filepath, class index) pairs. The glob results are sorted
    #    because directory listing order is not guaranteed: without a fixed
    #    starting order the seeded shuffle below would not reproduce the same
    #    split in the separate "training" and "validation" calls.
    samples = [
        (str(img_path), class_idx)
        for class_idx, cls in enumerate(class_names)
        for img_path in sorted((data_dir / cls).glob("*.png"))
    ]
    if not samples:
        raise ValueError(f"No .png files found in the class folders of {data_dir}")

    # 3) shuffle
    random.Random(seed).shuffle(samples)

    # 4) split
    if subset in ("training", "validation"):
        n_val = int(len(samples) * val_frac)
        samples = samples[:n_val] if subset == "validation" else samples[n_val:]
        if not samples:
            raise ValueError(
                f"Subset {subset!r} of {data_dir} is empty (val_frac={val_frac})"
            )

    filepaths = [path for path, _ in samples]
    class_ids = [class_idx for _, class_idx in samples]

    # 5) build tf.data.Dataset
    if ctc:
        codes_per_class = _encode_class_names(class_names)
        targets = [codes_per_class[class_idx] for class_idx in class_ids]
    else:
        targets = class_ids
    ds = tf.data.Dataset.from_tensor_slices((filepaths, targets))

    n_classes = len(class_names)

    def _load_and_preprocess(path, target):
        # Read + decode
        img = tf.io.read_file(path)
        img = tf.io.decode_image(img, channels=1, expand_animations=False)
        # Resize + normalize
        img = tf.image.resize(img, img_size)
        img = img / 255.0
        if ctc:
            # feature dict + dummy target, as the CTC training model expects
            return prepare_for_ctc(img, tf.cast(target, tf.int32))
        # one-hot encode label
        return img, tf.one_hot(target, depth=n_classes)

    ds = ds.map(_load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)

    # 6) shuffle training only
    if subset == "training":
        ds = ds.shuffle(buffer_size=1000, seed=seed)

    # 7) batch & prefetch
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return ds, class_names

# make_dataset always returns (dataset, class_names); unpack every call so the
# *_ds names hold datasets rather than tuples. ctc=True yields the feature
# dicts that the CTC training model and the decoding code consume.
train_ds, class_names   = make_dataset(train_dir, subset="training", ctc=True)
val_ds, _               = make_dataset(train_dir, subset="validation", ctc=True)
test_ds, _              = make_dataset(test_dir, subset=None, ctc=True)

# **3. Build CTC Model**

In [None]:
# 3.1. image input
img_in = layers.Input(shape=(IMG_H,IMG_W,1), name="image")

# 3.2. convolutional feature extractor
x = layers.Conv2D(32, 3, padding="same", activation="relu")(img_in)
x = layers.MaxPool2D(pool_size=2)(x)   # → (IMG_H/2 × IMG_W/2 × 32)
x = layers.Conv2D(64, 3, padding="same", activation="relu")(x)
x = layers.MaxPool2D(pool_size=2)(x)   # → (IMG_H/4 × IMG_W/4 × 64)

# collapse height dimension to 1: after two 2×2 pools the feature map is
# IMG_H//4 rows high, so the (unpadded) kernel must span all IMG_H//4 rows.
# A shorter kernel leaves more than one row and the Reshape below fails.
x = layers.Conv2D(128, (IMG_H//4,1), activation="relu")(x)
# → (1 × time_steps= IMG_W/4 × 128)
x = layers.Reshape((IMG_W//4, 128))(x)

# 3.3. Bi-LSTM sequence modelling
x = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(x)
x = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(x)

# 3.4. output layer (softmax over all NUM_CLASSES indices, blank included)
y_pred = layers.Dense(NUM_CLASSES, activation="softmax", name="y_pred")(x)

# 3.5. additional inputs for CTC loss (the two lengths are one scalar per sample)
labels      = layers.Input(shape=(MAX_LABEL_LEN,), dtype="int32", name="label")
input_len   = layers.Input(shape=(),              dtype="int32", name="input_len")
label_len   = layers.Input(shape=(),              dtype="int32", name="label_len")

# 3.6. CTC loss computation in-graph
def ctc_lambda(args):
    """Return the per-sample CTC loss for [y_pred, labels, input_len, label_len]."""
    y_pred, labels, input_len, label_len = args
    # K.ctc_batch_cost expects both length tensors with shape (batch, 1); the
    # model inputs are (batch,), so add the trailing axis here.
    input_len = tf.reshape(input_len, (-1, 1))
    label_len = tf.reshape(label_len, (-1, 1))
    return K.ctc_batch_cost(labels, y_pred, input_len, label_len)

ctc_loss = layers.Lambda(ctc_lambda, output_shape=(1,), name="ctc")(
    [y_pred, labels, input_len, label_len]
)

# 3.7. training model: its only output is the loss value itself
training_model = models.Model(
    inputs=[img_in, labels, input_len, label_len],
    outputs=ctc_loss
)

# The "loss function" just passes the model output through, because the real
# CTC loss has already been computed by the "ctc" layer.
training_model.compile(
    optimizer="adam",
    loss={"ctc": lambda y_true, y_pred: y_pred}
)

# 3.8. inference model (for decoding later): image → per-time-step softmax
inference_model = models.Model(img_in, y_pred)

training_model.summary()

# **4. Train with CTC**

In [None]:
# Keep only the model with the lowest validation loss on disk
# (full model, not just the weights).
best_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    CKPT_DIR,
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=False,
)

# Multiply the learning rate by 0.5 once val_loss has stopped improving
# for 3 epochs.
lr_plateau_cb = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=3,
)

callbacks = [best_checkpoint_cb, lr_plateau_cb]

history = training_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=30,
    callbacks=callbacks,
)

# **5. Decode & Evaluate**

In [None]:
# 5.1. restore best
# The checkpoint contains a Lambda layer wrapping a Python function, which
# Keras only deserialises when unsafe deserialisation is enabled.
tf.keras.config.enable_unsafe_deserialization()
restored_model = tf.keras.models.load_model(
    CKPT_DIR,
    compile=False
)

# The checkpoint stores the *training* model: four inputs and the CTC loss as
# its output. For decoding we need image → softmax, so cut a sub-model from the
# image input (the first model input) to the "y_pred" layer. No compile() is
# needed, the model is only used for prediction.
inference_model = models.Model(
    inputs=restored_model.inputs[0],
    outputs=restored_model.get_layer("y_pred").output
)
inference_model.summary()

# 5.2. helper to convert preds→strings
def decode_batch(batch_pred):
    """Greedy-CTC-decode a batch of softmax outputs into strings.

    Args:
        batch_pred: array of shape (batch, time_steps, NUM_CLASSES).

    Returns:
        One string per batch item. Decoded indices that are not positive are
        skipped, and indices missing from ``num_to_char`` contribute nothing.
    """
    n_items, n_steps = batch_pred.shape[0], batch_pred.shape[1]
    # every sequence uses the full number of time steps
    seq_lengths = np.full(n_items, n_steps, dtype=np.float64)

    # greedy ctc decode; the first returned candidate list holds the best path
    candidates, _ = K.ctc_decode(batch_pred, seq_lengths, greedy=True)
    best_paths = candidates[0].numpy()  # tensor → numpy

    # map ints back to chars
    return [
        "".join(num_to_char.get(code, "") for code in path if code > 0)
        for path in best_paths
    ]

# 5.3. run on test set
# Each element of test_ds is (features, dummy_target); features is the dict
# with the "image" and "label" entries.
y_true, y_pred = [], []
for batch in test_ds:
    features = batch[0]
    true_codes = features["label"].numpy().astype(int)
    # predict_on_batch instead of predict: predict() is meant for whole
    # datasets and would print one progress bar per batch inside this loop.
    probs = inference_model.predict_on_batch(features["image"])
    y_pred.extend(decode_batch(probs))
    # turn the integer label sequences back into strings
    y_true.extend(
        "".join(num_to_char[code] for code in codes if code > 0)
        for codes in true_codes
    )

# 5.4. compute word-level accuracy
print("Word accuracy:", accuracy_score(y_true, y_pred))
print(classification_report(y_true, y_pred, zero_division=0))

# **6. Visualisation**

In [None]:
def plot_training_curves(history, ft_history=None):
    """
    Plots accuracy and loss curves.
    If ft_history (fine-tune history) is provided, it will be appended.

    Only the metrics that were actually recorded are drawn: a model compiled
    without an accuracy metric (such as the CTC training model) gets the loss
    panel alone instead of failing on a missing 'accuracy' key. Neither history
    object is modified.
    """
    # merge histories if needed; copy every list so that extending the merged
    # curves never mutates history.history
    merged = {name: list(values) for name, values in history.history.items()}
    if ft_history is not None:
        for name, values in ft_history.history.items():
            merged.setdefault(name, []).extend(values)

    # (history key, panel title / y-label, legend suffix)
    panel_specs = [
        ("accuracy", "Accuracy", "Acc"),
        ("loss", "Loss", "Loss"),
    ]
    panels = [
        spec for spec in panel_specs
        if merged.get(spec[0]) or merged.get("val_" + spec[0])
    ]
    if not panels:
        return  # nothing recorded, nothing to draw

    plt.figure(figsize=(6 * len(panels), 5))
    for position, (key, title, short) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        for series_key, prefix in ((key, "Train"), ("val_" + key, "Val")):
            values = merged.get(series_key)
            if values:
                # each series gets its own x axis so a missing or shorter
                # validation curve cannot cause a length mismatch
                plt.plot(range(1, len(values) + 1), values,
                         label=f"{prefix} {short}")
        plt.title(f"{title} over epochs")
        plt.xlabel('Epoch')
        plt.ylabel(title)
        plt.legend()
    plt.tight_layout()
    plt.show()


def plot_confusion(cm, class_names, figsize=(8,8), fontsize=6):
    """
    Draws a confusion matrix as a blue heatmap with the cell values printed
    on top. Rows are the true labels, columns the predicted labels, and
    class_names supplies the tick labels for both axes.
    """
    plt.figure(figsize=figsize)
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.colorbar(fraction=0.046, pad=0.04)

    tick_positions = np.arange(len(class_names))
    plt.xticks(tick_positions, class_names, rotation=90, fontsize=fontsize)
    plt.yticks(tick_positions, class_names, fontsize=fontsize)

    # Cells darker than half the maximum get white text so the number
    # stays readable.
    half_max = cm.max() / 2.
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            cell_value = cm[row, col]
            text_colour = "white" if cell_value > half_max else "black"
            plt.text(col, row, cell_value,
                     horizontalalignment="center",
                     color=text_colour,
                     fontsize=fontsize)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()


def show_sample_predictions(model, dataset, class_names, num=9):
    """
    Displays num examples (default 9) from dataset alongside
    their true & predicted labels, highlighting mistakes in red.

    This helper is for classifiers: dataset must yield batched
    (image, one_hot_label) pairs and model must return one score per class,
    because both the label and the prediction are reduced with argmax and
    used as indices into class_names.

    The subplot grid is 3×3 for up to 9 examples and grows automatically for
    larger num (a fixed 3×3 grid made plt.subplot fail from the 10th image).
    """
    # at least 3×3 (the previous fixed layout), larger when num needs it
    n_cols = max(3, int(np.ceil(np.sqrt(num))))
    n_rows = max(3, int(np.ceil(num / n_cols)))

    # unbatch and take
    ds = dataset.unbatch().take(num)
    plt.figure(figsize=(10,10))
    for i, (img, label) in enumerate(ds):
        # add a batch axis for the single image; verbose=0 avoids printing a
        # progress bar for every sample
        scores = model.predict(img[None,...], verbose=0)
        pred = np.argmax(scores, axis=1)[0]
        true = np.argmax(label.numpy())
        plt.subplot(n_rows, n_cols, i+1)
        plt.imshow(img.numpy().squeeze(), cmap='gray')
        title = f"T: {class_names[true]}\nP: {class_names[pred]}"
        plt.title(title, color=('red' if pred!=true else 'black'), fontsize=10)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# 1) Plot training curves (ft_history is optional and may not exist).
plot_training_curves(history, globals().get("ft_history"))

# 2) Evaluate on test set and build confusion matrix.
# The model under test is the CTC inference model (no `word_model` exists in
# this notebook): softmax sequences are decoded to words and compared with the
# true words. test_ds elements are (features, dummy_target), where features is
# the dict holding the "image" and "label" entries.
y_true, y_pred = [], []
for batch in test_ds:
    features = batch[0]
    probs = inference_model.predict_on_batch(features["image"])
    y_pred.extend(decode_batch(probs))
    for codes in features["label"].numpy().astype(int):
        y_true.append("".join(num_to_char[code] for code in codes if code > 0))

# labels=class_names fixes the row/column order to match the tick labels;
# decoded strings that are not one of the known words are left out of the matrix.
cm = confusion_matrix(y_true, y_pred, labels=class_names)
plot_confusion(cm, class_names, figsize=(12,12), fontsize=4)

print(classification_report(y_true, y_pred, labels=class_names, zero_division=0))

# 3) Show a few prediction examples (first test batch, at most 9 images).
# Done inline because show_sample_predictions expects a classifier with
# one-hot labels, not CTC feature dicts.
sample_features = next(iter(test_ds))[0]
sample_images = sample_features["image"][:9]
sample_true = [
    "".join(num_to_char[code] for code in codes if code > 0)
    for codes in sample_features["label"][:9].numpy().astype(int)
]
sample_pred = decode_batch(inference_model.predict_on_batch(sample_images))

plt.figure(figsize=(10,10))
for i, (img, true_word, pred_word) in enumerate(
    zip(sample_images, sample_true, sample_pred)
):
    plt.subplot(3,3,i+1)
    plt.imshow(img.numpy().squeeze(), cmap='gray')
    plt.title(f"T: {true_word}\nP: {pred_word}",
              color=('red' if pred_word != true_word else 'black'),
              fontsize=10)
    plt.axis('off')
plt.tight_layout()
plt.show()