# Evaluate classifier and key recovery
This notebook evaluates the performance of the classifier and the key-search algorithm on messages encrypted with a substitution cipher.

In [None]:
import sys
sys.path.append('src')

from src.eval_utils import *
from src.data_utils import *
from src.functional import *
import src.alphabet as alph
import src.cipher as Cipher
import src.cracker as Cracker
import seaborn as sns
import random, copy
import numpy as np

In [None]:
# Evaluation configuration.
# The alphabet is shared by the cipher, the message loader and the crackers
# used in the cells below.
alphabet = alph.basic_lower()
cipher = Cipher.Substitution(alphabet)

# Identifier of the corpus handed to `load_msgs`.
data_id = 'bible-test'

In [None]:
# Sample messages
# Load the corpus once, then draw the evaluation subset from it without
# replacement. The draw is unseeded, so every run picks different messages.
n_samples = 50
msgs = load_msgs(data_id, alphabet, 0, 0, lower=True)
samples = random.sample(msgs, k=n_samples)

# Character order is computed from the full corpus, not only from the sample.
order = get_char_order(msgs, alphabet)

In [None]:
# Encrypt every sampled message with a substitution cipher over `alphabet`.
# `encrypt_all` hands back the ciphertexts together with the keys it used, so
# each ciphertext can later be checked against its true key.
subst_cipher = Cipher.Substitution(alphabet)
encrypted = subst_cipher.encrypt_all(
    samples,
    order=order,
    key_rot=1,
    key_gen_strategy='random',
)
encs, keys = encrypted

In [None]:
# Initialize crackers with partial decryption classifiers.
# One cracker is built per partial-decryption length; each cracker receives a
# single-entry dict mapping that length to the classifier loaded for it.
partial_lengths = list(range(2, 27, 2))

# Classifiers trained on the normal data.
crackers = {
    length: Cracker.Substitution(
        alphabet,
        {length: load_model(f'models/bible_partial{length}_prop1.cnn', len(alphabet))},
    )
    for length in partial_lengths
}

# Classifiers trained on the low distance data.
crackers_lowdist = {
    length: Cracker.Substitution(
        alphabet,
        {length: load_model(f'models/bible_partial{length}_prop1_lowdist_shuffle6.cnn', len(alphabet))},
    )
    for length in partial_lengths
}

In [None]:
# Compute ranks of the correct decryption across the different samples.
# Each (ciphertext, true key) pair is scored against the whole set of crackers.
# NOTE(review): each result is presumably one rank per partial length, since
# the plotting cell averages over samples with axis=0 -- confirm against
# `get_rank_correct`.
all_ranks = [
    get_rank_correct(crackers, cipher, enc, key, order)
    for enc, key in zip(encs, keys)
]

# Same evaluation for the crackers trained on low distance data.
all_ranks_lowdist = [
    get_rank_correct(crackers_lowdist, cipher, enc, key, order)
    for enc, key in zip(encs, keys)
]

In [None]:
# Plot avg ranks of correct decryptions, one panel per training regime.
# NOTE(review): `plt` is not imported explicitly in this notebook's import
# cell; presumably it is re-exported by one of the `src.*` star imports --
# confirm.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes = axes.flatten()

# Average over samples (axis 0) and flatten to one value per partial length.
avg_ranks = np.mean(all_ranks, axis=0).reshape(-1)
avg_ranks_lowdist = np.mean(all_ranks_lowdist, axis=0).reshape(-1)

panels = (
    (avg_ranks, 'Train on: normal data'),
    (avg_ranks_lowdist, 'Train on: low distance data'),
)
for ax, (heights, title) in zip(axes, panels):
    # Reuse `partial_lengths` so the x axis always matches the lengths the
    # crackers were actually built for.
    sns.barplot(x=partial_lengths, y=heights, ax=ax)
    ax.set_title(title, fontsize=20)
    ax.set_xlabel('Size of partial decryption')
    ax.set_ylabel('Avg rank')
    # Fixed y range keeps the two panels directly comparable.
    ax.set_ylim(0, 160)