In [1]:
%%javascript
// Override OutputArea's _should_scroll hook so that it always answers "no",
// regardless of the `lines` argument it is given.
// NOTE(review): presumably this keeps long cell outputs fully expanded instead of
// being placed in a scroll box; it relies on the classic-notebook `IPython` global
// being available -- confirm for the front end in use.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

<IPython.core.display.Javascript object>

In [2]:
%matplotlib inline
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import seaborn as sns

def _read_csv(data_dir, file_name):
    import csv
    with open(data_dir + '/' + file_name + '.csv', 'r') as csvfile:
        reader = csv.reader(csvfile, delimiter=',')
        return list(reader)

def _read_text(data_dir, file_name):
    with open(data_dir + '/' + file_name + '.txt', 'r') as textfile:
        return textfile.readlines()
    
def load(data_dir):
    """Load every input file of the experiment from ``data_dir``.

    Returns a 6-tuple ``(alphabet, P, M, cipher_function, plaintext,
    ciphertext)``:

    * ``alphabet`` -- first row of ``alphabet.csv`` (list of str)
    * ``P`` -- first row of ``letter_probabilities.csv`` as a float64 array
    * ``M`` -- all rows of ``letter_transition_matrix.csv`` as a float64 array
    * ``cipher_function`` -- first row of ``cipher_function.csv`` (list of str)
    * ``plaintext`` / ``ciphertext`` -- first line of the matching ``.txt`` file
    """
    def first_row(name):
        # Only the first CSV record of these files is used.
        return _read_csv(data_dir, name)[0]

    def first_line(name):
        # Only the first line of the text files is used.
        return _read_text(data_dir, name)[0]

    alphabet = first_row('alphabet')
    P = np.array(first_row('letter_probabilities'), dtype=np.float64)
    transition_rows = _read_csv(data_dir, 'letter_transition_matrix')
    M = np.array(transition_rows, dtype=np.float64)
    cipher_function = first_row('cipher_function')
    plaintext = first_line('plaintext')
    ciphertext = first_line('ciphertext')
    return alphabet, P, M, cipher_function, plaintext, ciphertext

# Read the alphabet, the probability tables, the cipher mapping and both texts.
alphabet, P, M, cipher_function, plaintext, ciphertext = load('../data/part1-data')

# Map every alphabet symbol to its position in `alphabet`.
dictionary = {symbol: position for position, symbol in enumerate(alphabet)}

# Move both probability tables to log space.
# NOTE(review): zero entries of M are overwritten with 1 *in place* before the log,
# so they end up with log-value 0 (the same as probability 1) and M itself no longer
# holds the values that were loaded -- confirm this is the intended treatment.
M[M == 0] = 1
logM = np.log(M)
logP = np.log(P)

def translate(text):
    """Convert an iterable of symbols into an array of their `dictionary` codes.

    Newline characters are skipped; any other symbol missing from the
    module-level `dictionary` raises ``KeyError``.
    """
    codes = []
    for symbol in text:
        if symbol == '\n':
            continue
        codes.append(dictionary[symbol])
    return np.array(codes)

# Integer-coded versions of the two texts (newlines dropped by `translate`).
plaincode = translate(plaintext)
ciphercode = translate(ciphertext)
# The loaded cipher mapping, encoded with the same symbol -> index table.
f_true = translate(cipher_function)

In [3]:
def initialize(num, order):
    """Draw `num` independent random permutations of ``range(order)``.

    Parameters
    ----------
    num : int
        Number of permutations (rows) to generate.
    order : int
        Length of each permutation.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(num, order)``; every row is a permutation
        of ``0 .. order-1`` drawn with ``np.random.permutation``.
    """
    # `np.int` was only an alias of the builtin `int` and has been removed from
    # NumPy (1.24+); the builtin gives the same dtype without the AttributeError.
    x = np.zeros((num, order), dtype=int)
    for i in range(num):
        x[i] = np.random.permutation(order)
    return x

def _roll(_x):
    idx = np.random.choice(_x.shape[0], 2, replace=False)
    _x[idx] = _x[np.roll(idx, 1)]
    return _x

def proposal_change(x):
    """Return a copy of `x` in which `_roll` has been applied to every row.

    `x` itself is left untouched because the rows are modified on a copy.
    """
    return np.apply_along_axis(_roll, 1, np.copy(x))

def get_log_probs_mc(x, logP, logM, c1, c2, num, order):
    """Score every candidate row of `x` against the count tables.

    For each row ``k`` the returned value is::

        sum_a   c1[x[k, a]]          * logP[a]
      + sum_a,b c2[x[k, a], x[k, b]] * logM[a, b]

    Parameters
    ----------
    x : numpy.ndarray of int, shape (num, order)
        One candidate index mapping per row; its entries index `c1` / `c2`.
    logP : numpy.ndarray, shape (order,)
        Per-symbol log weights.
    logM : numpy.ndarray, shape (order, order)
        Pairwise log weights.
    c1 : numpy.ndarray, shape (order,)
        Single-symbol counts (see `count`).
    c2 : numpy.ndarray, shape (order, order)
        Pair counts (see `count`).
    num, order : int
        Kept for interface compatibility; the shapes are taken from `x`.

    Returns
    -------
    numpy.ndarray, shape (num,)
        One score per row of `x`.
    """
    # Broadcasting yields the same element-wise products as the explicit
    # np.repeat / np.tile copies, without materialising the tiled index and
    # logM arrays on every call.
    lp = np.sum(c1[x] * logP[np.newaxis, :], axis=1)
    # pair_counts[k, a, b] == c2[x[k, a], x[k, b]]
    pair_counts = c2[x[:, :, np.newaxis], x[:, np.newaxis, :]]
    lp += np.sum(pair_counts * logM[np.newaxis, :, :], axis=(1, 2))
    return lp

def count(order, ciphercode):
    """Tabulate symbol and adjacent-pair frequencies of `ciphercode`.

    Returns ``(c1, c2)`` where ``c1[s]`` is the number of occurrences of
    symbol ``s`` and ``c2[cur, prev]`` is the number of positions at which
    ``cur`` immediately follows ``prev``. Both are float arrays of size
    ``order`` and ``(order, order)`` respectively.
    """
    unigram = np.zeros(order)
    bigram = np.zeros((order, order))
    following = ciphercode[1:]
    preceding = ciphercode[:-1]
    # np.add.at accumulates correctly even when an index appears many times.
    np.add.at(unigram, ciphercode, 1)
    np.add.at(bigram, (following, preceding), 1)
    return unigram, bigram

def get_log_probs_gm(x, ciphercode):
    """Per-row term for the "27 must be followed by 26" pattern check.

    Each row of `x` is inverted with ``argsort`` and used to decode
    `ciphercode`; a row is flagged when any decoded 27 is followed by
    something other than 26. The flag is then multiplied by
    ``-0 * len(ciphercode)``, so as written the result is zero for every row.
    NOTE(review): 27 / 26 are presumably the codes of '.' and ' ' -- confirm
    against the alphabet file.
    """
    decoded = np.argsort(x, axis=1)[:, ciphercode]
    is_27 = decoded[:, :-1] == 27
    next_not_26 = decoded[:, 1:] != 26
    flagged = np.any(is_27 & next_not_26, axis=1)
    weight = -0 * ciphercode.shape[0]
    return flagged * weight

def random_step(logpr):
    """Decide, per chain, whether the proposal is taken.

    Despite the name the rule is deterministic: a proposal is accepted (1)
    only when its log-ratio is strictly positive, otherwise rejected (0).
    A stochastic Metropolis rule would instead accept with probability
    ``exp(min(logpr, 0))``; that variant is intentionally not active here.

    Parameters
    ----------
    logpr : array_like of float
        Log-ratio (proposal minus current) for every chain.

    Returns
    -------
    numpy.ndarray of int
        1 where the proposal is accepted, 0 elsewhere.
    """
    # `np.int` has been removed from NumPy (1.24+); the builtin `int` is the
    # dtype it used to alias. The previously computed acceptance probability
    # was never used, so it is no longer evaluated.
    return (np.asarray(logpr) > 0).astype(int)

def update(x, xp, rs, rs_cum, T):
    """Build the next state by picking, per row, either `xp` or `x`.

    A row takes the proposal when its entry in `rs` equals 1 or when its
    entry in `rs_cum` exceeds ``T // 2``; otherwise the current row is kept.

    Parameters
    ----------
    x, xp : numpy.ndarray, shape (num, order)
        Current states and proposed states.
    rs : numpy.ndarray, shape (num,)
        Acceptance flags for this step.
    rs_cum : numpy.ndarray, shape (num,)
        Running per-row counter compared against ``T // 2``.
    T : int
        Window length used for the threshold.

    Returns
    -------
    numpy.ndarray, shape (num, order)
    """
    # `np.int` has been removed from NumPy (1.24+); builtin `int` is equivalent.
    take = np.logical_or(rs == 1, rs_cum > T // 2).astype(int)
    # Column vector so the 0/1 mask broadcasts across each row.
    take = take.reshape((take.shape[0], 1))
    return take * xp + (1 - take) * x

def get_acc(x, ciphercode, plaincode):
    """Fraction of positions decoded correctly, averaged over all rows of `x`.

    Every row of `x` is inverted with ``argsort`` and used to decode
    `ciphercode`; the decoded sequences are compared with `plaincode`.

    Parameters
    ----------
    x : numpy.ndarray of int, shape (num, order)
        One candidate mapping per row.
    ciphercode, plaincode : numpy.ndarray of int, shape (length,)
        Encoded cipher text and the reference text of the same length.

    Returns
    -------
    float
        ``matches / num / length``.
    """
    xinv = np.argsort(x, axis=1)
    # Broadcasting a (1, length) row against (num, length) replaces the explicit
    # np.repeat copy; summing the boolean matches replaces the cast through the
    # removed `np.int` alias (NumPy 1.24+).
    matches = xinv[:, ciphercode] == plaincode.reshape((1, -1))
    return np.sum(matches) / x.shape[0] / ciphercode.shape[0]

def get_ktau(x, f_true):
    """Mean Kendall tau correlation between each row of `x` and `f_true`.

    Only the correlation statistic (first element of the
    ``stats.kendalltau`` result) is used; the p-value is discarded.
    """
    taus = np.array([stats.kendalltau(row, f_true)[0] for row in x], dtype=float)
    return np.mean(taus)

def get_func_acc(x, f_true, order):
    """Compare the inverse of every row of `x` with the inverse of `f_true`.

    Parameters
    ----------
    x : numpy.ndarray of int, shape (num, order)
        One candidate mapping per row.
    f_true : numpy.ndarray of int, shape (order,)
        Reference mapping.
    order : int
        Alphabet size; sets the shape of the error map.

    Returns
    -------
    (float, numpy.ndarray)
        * mean over rows of the fraction of positions where the inverted row
          agrees with the inverted reference;
        * an ``(order, order)`` float array in which entry ``[t, p]`` counts
          how often the reference inverse gives ``t`` while a row's inverse
          gives ``p`` (mismatches only).
    """
    x_inv = np.argsort(x, axis=1)
    f_inv = np.argsort(f_true)
    func_acc = np.zeros(x_inv.shape[0])
    error_map = np.zeros((order, order))
    for i in range(x_inv.shape[0]):
        wrong = x_inv[i] != f_inv
        func_acc[i] = np.mean(~wrong)
        # argsort already returns integer indices, so the casts through the
        # removed `np.int` alias (NumPy 1.24+) are unnecessary. np.add.at is
        # used so that repeated index pairs would still be counted once each.
        np.add.at(error_map, (f_inv[wrong], x_inv[i][wrong]), 1)
    return np.mean(func_acc), error_map

def get_gm_vd(x, ciphercode):
    """Share of rows whose decoding never has a 27 followed by a non-26.

    Each row of `x` is inverted with ``argsort`` and used to decode
    `ciphercode`; the result is one minus the fraction of rows in which at
    least one decoded 27 is followed by a value other than 26.
    NOTE(review): 27 / 26 are presumably the codes of '.' and ' ' -- confirm
    against the alphabet file.
    """
    decoded = np.argsort(x, axis=1)[:, ciphercode]
    current_is_27 = decoded[:, :-1] == 27
    next_not_26 = decoded[:, 1:] != 26
    row_flagged = (current_is_27 & next_not_26).any(axis=1)
    return 1 - np.mean(row_flagged)
    
def main(num, order, logP, logM, ciphercode, plaincode, f_true, maxiter, T):
    """Run `num` parallel search chains for `maxiter` iterations.

    Every iteration proposes a two-entry swap for each chain, scores current
    and proposed states with `get_log_probs_mc`, decides acceptance with
    `random_step` and advances the chains with `update`. Every `T`-th
    iteration a progress line is printed and the acceptance counter is reset.

    Parameters
    ----------
    num : int
        Number of chains run side by side.
    order : int
        Alphabet size (length of each mapping).
    logP, logM : numpy.ndarray
        Log-weight tables passed to `get_log_probs_mc`.
    ciphercode, plaincode : numpy.ndarray of int
        Encoded cipher text and reference text (the latter is only used for
        reporting accuracy).
    f_true : numpy.ndarray of int
        Reference mapping, used for reporting only.
    maxiter : int
        Number of iterations.
    T : int
        Reporting / counter-reset period.

    Returns
    -------
    (list, list, list, numpy.ndarray)
        Per-iteration mean score of the current states, per-iteration mean
        acceptance flag, per-iteration decoding accuracy, and the final
        ``(num, order)`` state array.
    """
    x = initialize(num, order)
    rs_cum = np.zeros(num)
    logp_list, accept_rate_list, acc_list = [], [], []
    c1, c2 = count(order, ciphercode)
    for i in range(maxiter):
        xp = proposal_change(x)
        logp = get_log_probs_mc(x, logP, logM, c1, c2, num, order)
        logp_list.append(np.mean(logp))
        logpp = get_log_probs_mc(xp, logP, logM, c1, c2, num, order)
        logpr = logpp - logp
        rs = random_step(logpr)
        accept_rate_list.append(np.mean(rs))
        rs_cum += rs
        x = update(x, xp, rs, rs_cum, T)
        acc = get_acc(x, ciphercode, plaincode)
        acc_list.append(acc)
        if i % T == 0:
            # These diagnostics are only ever printed, so they are evaluated on
            # reporting iterations only instead of on every step.
            ktau = get_ktau(x, f_true)
            func_acc, _error_map = get_func_acc(x, f_true, order)
            gm_vd = get_gm_vd(x, ciphercode)
            print("it:{}, log_p:{:1.4e}, acpt_r:{:1.4e}, acc:{:1.4e}, ktau:{:1.4e}, facc:{:1.4e}, gmvd:{:1.4e}".format(
                i, np.mean(logp), np.mean(rs_cum)/T, acc, ktau, func_acc, gm_vd))
            # Start a fresh acceptance window for the next T iterations.
            rs_cum = np.zeros(num)
    return logp_list, accept_rate_list, acc_list, x

# Use only the first 128*64 symbols of both texts for this run.
length = 128 * 64
# 1000 chains over a 28-symbol alphabet, 10000 iterations, report every 200.
logp_list, accept_rate_list, acc_list, _ = main(
    num=1000,
    order=28,
    logP=logP,
    logM=logM,
    ciphercode=ciphercode[:length],
    plaincode=plaincode[:length],
    f_true=f_true,
    maxiter=10000,
    T=200,
)

it:0, log_p:-8.3237e+04, acpt_r:3.0000e-03, acc:2.9590e-02, ktau:2.6455e-04, facc:2.3214e-02, gmvd:1.0000e-01
it:200, log_p:-5.2395e+04, acpt_r:1.9225e-01, acc:1.4258e-01, ktau:3.0159e-02, facc:1.2321e-01, gmvd:0.0000e+00
it:400, log_p:-4.9882e+04, acpt_r:6.3250e-02, acc:2.9039e-01, ktau:1.5847e-01, facc:2.3036e-01, gmvd:0.0000e+00
it:600, log_p:-4.8407e+04, acpt_r:3.8750e-02, acc:3.8260e-01, ktau:2.5450e-01, facc:3.3214e-01, gmvd:0.0000e+00
it:800, log_p:-4.7263e+04, acpt_r:3.1000e-02, acc:4.8533e-01, ktau:3.5714e-01, facc:4.2679e-01, gmvd:0.0000e+00
it:1000, log_p:-4.6447e+04, acpt_r:1.9000e-02, acc:5.5034e-01, ktau:4.2698e-01, facc:4.9107e-01, gmvd:0.0000e+00
it:1200, log_p:-4.5817e+04, acpt_r:1.5750e-02, acc:6.1382e-01, ktau:4.7963e-01, facc:5.5893e-01, gmvd:5.0000e-02
it:1400, log_p:-4.5347e+04, acpt_r:1.1250e-02, acc:6.3691e-01, ktau:4.9894e-01, facc:5.7321e-01, gmvd:0.0000e+00
it:1600, log_p:-4.4748e+04, acpt_r:1.0500e-02, acc:6.8619e-01, ktau:5.6402e-01, facc:6.1607e-01, gmvd:0

KeyboardInterrupt: 