In [1]:
import numpy as np

def _read_csv(data_dir, file_name):
    import csv
    with open(data_dir + '/' + file_name + '.csv', 'r') as csvfile:
        reader = csv.reader(csvfile, delimiter=',')
        return list(reader)

def _read_text(data_dir, file_name):
    with open(data_dir + '/' + file_name + '.txt', 'r') as textfile:
        return textfile.readlines()
    
def load(data_dir):
    alphabet = _read_csv(data_dir, 'alphabet')[0]
    P = np.array(_read_csv(data_dir, 'letter_probabilities')[0], dtype=np.float64)
    M = np.array(_read_csv(data_dir, 'letter_transition_matrix'), dtype=np.float64)
    cipher_function = _read_csv(data_dir, 'cipher_function')[0]
    plaintext = _read_text(data_dir, 'plaintext')[0]
    ciphertext = _read_text(data_dir, 'ciphertext')[0]
    return alphabet, P, M, cipher_function, plaintext, ciphertext

alphabet, P, M, cipher_function, plaintext, ciphertext = load('./data')

logP = np.log(P)
M[M==0] = 1
logM = np.log(M)
logM[logM==0] = -1e10

plaintext = plaintext
ciphertext = ciphertext

In [2]:
dictionary = dict(zip(alphabet, range(len(alphabet))))

def translate(text):
    return np.array([dictionary[t] for t in text if t != '\n'])

plaincode = translate(plaintext)
ciphercode = translate(ciphertext)
f_true = translate(cipher_function)

In [5]:
def initialize(num, order):
    x = np.zeros((num, order), dtype=np.float64)
    for i in range(num):
        x[i,] = np.random.permutation(order)
    return x

def proposal_change(x):
    def _swap(_x):
        idx = np.random.choice(_x.shape[0], 2, replace=False)
        temp = _x[idx[0]]
        _x[idx[0]] = _x[idx[1]]
        _x[idx[1]] = temp
        return _x
    xp = np.apply_along_axis(_swap, 1, np.copy(x))
    return xp

def cal_log_prob_ratio(x, xp, logP, logM, ciphercode):
    def _get_log_probs(x):
        ps = np.zeros((x.shape[0], ciphercode.shape[0]))
        xinv = np.argsort(x, axis=1)
        ps[:,0] = logP[xinv[:,ciphercode[0]]]
        ps[:,1:] = logM[xinv[:,ciphercode[1:]],xinv[:,ciphercode[:-1]]]
        return np.sum(ps, axis=1)
    return (_get_log_probs(xp)-_get_log_probs(x))

def random_step(logpr):
    pr = np.exp(np.clip(logpr, -np.inf, 0))
    return (np.random.rand(pr.shape[0]) < pr).astype(np.int)

def update(x, xp, rs):
    rs = rs.reshape((rs.shape[0],1))
    return (rs*xp+(1-rs)*x)

def get_acc(x, ciphercode, plaincode):
    xinv = np.argsort(x, axis=1)
    plaincode = np.repeat(plaincode.reshape((1,-1)), x.shape[0], axis=0)
    return np.sum((xinv[:,ciphercode] == plaincode).astype(np.int))/x.shape[0]/ciphercode.shape[0]


def main(num, order, logP, logM, ciphercode, plaincode, maxiter):
    x = initialize(num, order)
    for i in range(maxiter):
        xp = proposal_change(x)
        logpr = cal_log_prob_ratio(x, xp, logP, logM, ciphercode)
        rs = random_step(logpr)
        x = update(x, xp, rs)
        acc = get_acc(x, ciphercode, plaincode)
        print(acc)
        
        
main(100, 28, logP, logM, ciphercode, plaincode, 10000)

0.031383601756954614
0.030202415812591508
0.03179282576866765
0.032871156661786236
0.033417642752562225
0.03488396778916544
0.03406588579795022
0.03558894582723279
0.034748535871156665
0.03463909224011713
0.03595644216691069
0.03707650073206442
0.036358711566617866
0.03793265007320644
0.0371314055636896
0.03666874084919473
0.03694838945827233
0.03672437774524158
0.03765519765739385
0.037780746705710105
0.03953367496339678
0.04141398243045388
0.04162408491947291
0.042480600292825764
0.04272803806734993
0.04267679355783309
0.042050146412884334
0.04321412884333821
0.0431303074670571
0.04458125915080527
0.044451683748169844
0.042833089311859446
0.04223206442166911
0.04265959004392387
0.042236456808199124
0.043046486090775986
0.045237554904831626
0.0473993411420205
0.04954795021961933
0.04747693997071742
0.04779612005856515
0.048950585651537336
0.05031551976573938
0.05034773060029282
0.050938140556368965
0.052396046852122985
0.052599560761346996
0.053046120058565155
0.05256112737920937
0.05

KeyboardInterrupt: 