# Client to train a CharCNN model for the substitution cipher
This notebook demonstrates how to train a classifier that distinguishes correct from incorrect decryptions of text encrypted with a substitution cipher.

In [None]:
import sys
sys.path.append('src')

import torch
from src.models import CharCNN
from src.model_utils import save_model
from src.data_utils import gen_dl
from src.plot_utils import plot_loss, plot_confusion
import logging
import src.alphabet as alph
from src.trainer import Trainer
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from src.functional import *
mpl_logger = logging.getLogger('matplotlib') 
mpl_logger.setLevel(logging.WARNING) 
import random, copy
from collections import Counter

# Train the model on partially decrypted Bible data

In [None]:
# Experiment configuration
data_id = 'bible-train'
alphabet = alph.basic_lower()
alphabet_len = len(alphabet)
partial_lengths = list(range(2, 27, 2))  # even values 2, 4, ..., 26
cipher_props = [1]

# Sweep over every (cipher proportion, partial length) pair; each pair gets
# its own freshly initialised model, loss and trainer.
for p in cipher_props:
    # Relative mix of the two classes: 1 part plain to p parts substitution
    ciphers = {'plain': 1, 'substitution': p}

    for l in partial_lengths:
        # Build train/validation loaders restricted to this partial length
        dl_train, dl_val = gen_dl(
            data_id,
            alphabet,
            0,
            0,
            key_gen_strategy='random',
            ciphers=ciphers,
            partial_lengths=[l],
            lower=True,
            key_rot=1,
        )

        # Class weights [p, 1.0] are the inverse of the 1 : p mix above
        class_weights = torch.tensor([p, 1.0], dtype=torch.float32)
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

        # Fresh classifier that uses the weighted loss
        clf = CharCNN(alphabet_len, loss=criterion)

        # Fit with an epoch budget of 100
        trainer = Trainer(nb_epochs=100)
        loss_train, loss_val = trainer.fit(clf, dl_train, dl_val)

        # Report: loss curves, confusion matrix and balanced accuracy,
        # all computed on the validation loader
        model_name = f'partial{l}_prop{p}'
        plot_loss(loss_train, loss_val, f'Loss for {model_name}')
        confusion = trainer.score(clf, dl_val, normed_confusion)
        plot_confusion(confusion, f'Confusion matrix for {model_name}')
        acc = trainer.score(clf, dl_val, balanced_accuracy_score)

        # Persist the trained model under a name encoding the grid point
        save_path = f'models/bible_{model_name}_lowdist_shuffle6.cnn'
        save_model(clf, save_path)
        print(f'### Model {model_name} saved with balanced accuracy of {round(100*acc, 2)}%. ###')