In [1]:
import random
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

In [2]:
# The 26 lowercase ASCII letters kept by preprocessing; every other character is a separator.
letters = list('abcdefghijklmnopqrstuvwxyz')

def preprocess(content):
    """Normalise raw text: lowercase, keep only a-z, collapse everything else to one space.

    The input is lowercased and stripped of surrounding whitespace. Every
    character outside a-z (punctuation, digits, whitespace, non-ASCII) acts
    as a separator. Each run of separators becomes a single space.

    Args:
        content: raw text.

    Returns:
        Lowercase words separated by single spaces. The result never starts
        with a space. It ends with one if the stripped input ends with a
        non-letter (e.g. "Hello!" -> "hello ").
    """
    allowed = frozenset('abcdefghijklmnopqrstuvwxyz')
    pieces = []
    last = ''  # last character emitted; '' while nothing has been emitted yet
    for ch in content.lower().strip():
        if ch in allowed:
            pieces.append(ch)
            last = ch
        elif last and last != ' ':
            # Emit one space per separator run. Separators before the first
            # letter are skipped (previously text[-1] on an empty string
            # raised IndexError here).
            pieces.append(' ')
            last = ' '
    return ''.join(pieces)

# Load the corpus and normalise it to lowercase words separated by single spaces.
file_path = "shakespeare.txt"
with open(file_path, mode="r", encoding="utf-8") as f:
    raw_corpus = f.read()
plaintext = preprocess(raw_corpus)
print(len(plaintext))

# Cut the text into segments of exactly 92 words. A segment ends right after
# its 92nd space, so each one keeps the trailing space of its last word.
# A final partial segment (< 92 words) is dropped.
# On shakespeare.txt this gives 10068 segments of ~470 characters each.
words_per_segment = 92
tokens = plaintext.split(' ')
n_full_segments = plaintext.count(' ') // words_per_segment
segments = [
    ' '.join(tokens[k * words_per_segment:(k + 1) * words_per_segment]) + ' '
    for k in range(n_full_segments)
]
print(len(segments))

# Deterministic shuffle, then disjoint train / valid / test slices of segments.
# Segments past index 10000 are not used.
random.seed(42)
random.shuffle(segments)
train_text, valid_text, test_text = segments[:7000], segments[7000:8500], segments[8500:10000]

4708753
10068


In [3]:
class crypt:
    """Random monoalphabetic substitution cipher over the letters a-z.

    Spaces pass through unchanged. Any other character is rejected with
    ValueError.

    Attributes:
        seed: the seed given to the constructor (or None).
        key: 26 letters; plaintext letter chr(97 + i) encrypts to key[i].
        inv_key: 26 letters; ciphertext letter chr(97 + i) decrypts to inv_key[i].
    """

    _ALPHABET = 'abcdefghijklmnopqrstuvwxyz'

    def __init__(self, seed=None):
        # NOTE: this reseeds the *global* `random` module, so creating a crypt
        # with a seed also resets the RNG state seen by any later code.
        # With seed=None the current global state is used as-is.
        if seed is not None:
            random.seed(seed)
        self.seed = seed
        self.key = list(self._ALPHABET)
        random.shuffle(self.key)
        # Invert the permutation: if plain letter i encrypts to key[i], then
        # cipher letter key[i] decrypts back to plain letter i.
        self.inv_key = [None] * 26
        for i, cipher_ch in enumerate(self.key):
            self.inv_key[ord(cipher_ch) - ord('a')] = chr(i + ord('a'))

    def reveal(self):
        """Return the (key, inv_key) pair."""
        return self.key, self.inv_key

    @staticmethod
    def _substitute(text, table):
        """Map each a-z letter of text through the 26-entry table; keep spaces."""
        out = []
        for ch in text:
            if ch == ' ':
                out.append(ch)
            elif 'a' <= ch <= 'z':
                out.append(table[ord(ch) - ord('a')])
            else:
                # Without this check, characters just below 'a' in code-point
                # order ('G'-'Z', '[', '_', '`', ...) give negative indices that
                # silently wrap to the end of the table and yield wrong output.
                raise ValueError(f'unsupported character {ch!r}: expected a-z or space')
        return ''.join(out)

    def encrypt(self, plain):
        """Encrypt plain (a-z and spaces) with self.key."""
        return self._substitute(plain, self.key)

    def decrypt(self, cipher, inv_key=None):
        """Decrypt cipher (a-z and spaces).

        Args:
            cipher: ciphertext.
            inv_key: optional 26-entry decryption table (e.g. a guessed key).
                Defaults to the true self.inv_key.
        """
        table = self.inv_key if inv_key is None else inv_key
        return self._substitute(cipher, table)

def match_freq(crypt, cipher):
    #通过单个字母频率排序解密 (Baseline)
    #输入：密文，输出：推断的解密密钥
    true_freq_order = ['e','t','a','o','i','n','s','h','r','d','l','c','u',
                       'm','w','f','g','y','p','b','v','k','j','x','q','z'] #真实字母频率排序（高到低）
    freq = defaultdict(int)
    for i in cipher:
        if i != ' ':
            freq[i] += 1
    for i in true_freq_order:
        if i not in freq:
            freq[i] = 0
    freq = sorted(dict(freq).items(), key=lambda x:x[1], reverse=True)
    freq_order = [i[0] for i in freq]
    compare = list(zip(freq_order, true_freq_order))
    compare = sorted(compare, key=lambda x:x[0])
    inv_key = [i[1] for i in compare]

    t = 0
    for i in range(26):
        if crypt.inv_key[i] == inv_key[i]:
            t += 1
    return round(100*t/26)

def baseline(test_text, N=1500, seed_offset=0):
    """Mean accuracy (%) of the frequency-matching baseline over N samples.

    Sample i uses seed `seed_offset + i` twice:
    - to build the cipher key via crypt(seed);
    - to draw 10 segments from test_text.

    With the default seed_offset=0, the keys are the same as those of the
    first N training rows of generate_dataset (seeds 0..N-1). Passing
    seed_offset=20000 (with N=1500) reproduces the (key, sample) pairs of
    generate_dataset's default test split, so the two numbers are directly
    comparable.

    Note: match_freq already rounds each per-sample score to a whole
    percent, so this is a mean of rounded values.

    Args:
        test_text: list of text segments to sample from (needs >= 10 items).
        N: number of samples.
        seed_offset: first seed to use.

    Returns:
        Integer mean accuracy in percent.
    """
    total_acc = 0
    for seed in range(seed_offset, seed_offset + N):
        cipher_obj = crypt(seed)
        random.seed(seed)
        plain = ''.join(random.sample(test_text, 10))
        cipher = cipher_obj.encrypt(plain)
        total_acc += match_freq(cipher_obj, cipher)
    return round(total_acc / N)
print(f'Baseline accuracy: {baseline(test_text)}%')

Baseline accuracy: 29%


In [4]:
def extract_freq(cipher):
    """Turn a text (ciphertext or plaintext) into a 702-dim frequency feature vector.

    Layout: 26 unigram features for a..z, then 676 bigram features in
    row-major order aa, ab, ..., az, ba, ..., zz (26 + 26*26 = 702).

    * Unigram x: count(x) / total number of letters, rounded to 6 places.
    * Bigram xy: number of times x is immediately followed by y within a
      word, divided by the total count of x, rounded to 6 places. Word-final
      occurrences of x are in the denominator but start no bigram, so the
      26 bigrams of a row need not sum to 1.

    Spaces separate words and are not counted; bigrams never span a space.
    Assumes the text contains only a-z and spaces.

    Returns:
        List of 702 numbers. Unigrams are floats. Bigrams that never occur
        stay the int 0. Text with no letters at all yields all-zero features
        instead of dividing by zero.
    """
    letter_counts = [0] * 26
    pair_counts = [[0] * 26 for _ in range(26)]
    total = 0
    prev = None  # index of the previous letter in the current word, if any
    for ch in cipher:
        if ch == ' ':
            prev = None
            continue
        idx = ord(ch) - ord('a')
        total += 1
        letter_counts[idx] += 1
        if prev is not None:
            pair_counts[prev][idx] += 1
        prev = idx

    if total == 0:
        return [0.0] * 26 + [0] * (26 * 26)

    # A non-zero pair count implies letter_counts[first] >= 1, so no division by zero.
    bigram_feats = [
        round(count / letter_counts[first], 6) if count else 0
        for first, row in enumerate(pair_counts)
        for count in row
    ]
    unigram_feats = [round(count / total, 6) for count in letter_counts]
    return unigram_feats + bigram_feats

def _build_split(texts, seeds):
    """Build one split: a DataFrame with columns inv_key, freq_0..freq_701, one row per seed."""
    rows = []
    for seed in tqdm(seeds):
        cryptor = crypt(seed=seed)
        random.seed(seed)
        sample_text = ''.join(random.sample(texts, 10))
        ciphertext = cryptor.encrypt(sample_text)
        rows.append([cryptor.inv_key] + extract_freq(ciphertext))
    return pd.DataFrame(rows, columns=['inv_key'] + [f'freq_{i}' for i in range(702)])


def generate_dataset(train_text, valid_text, test_text, len_dataset=(10000, 1500, 1500)):
    """Generate (train, valid, test) DataFrames of cipher frequency features.

    Each row comes from one integer seed, which is used for two things:
    - the substitution key, via crypt(seed);
    - the choice of 10 segments from the split's text.
    The row stores the true inv_key (the label) plus the 702 features from
    extract_freq of the ciphertext.

    Args:
        train_text, valid_text, test_text: lists of text segments per split.
        len_dataset: number of rows (train, valid, test).

    Returns:
        (train, valid, test) DataFrames, each with 703 columns.
    """
    n_train, n_valid, n_test = len_dataset[0], len_dataset[1], len_dataset[2]
    # The same seed always yields the same key (the label), so seed ranges
    # must not overlap across splits. Keep the historical starting seeds
    # 0 / 10000 / 20000, but shift a range up when a larger earlier split
    # would otherwise run into it.
    valid_start = max(10000, n_train)
    test_start = max(20000, valid_start + n_valid)

    train = _build_split(train_text, range(0, n_train))
    valid = _build_split(valid_text, range(valid_start, valid_start + n_valid))
    test = _build_split(test_text, range(test_start, test_start + n_test))
    return train, valid, test

# Build the three splits, then report their shapes (rows, 1 label column + 702 features).
train, valid, test = generate_dataset(train_text, valid_text, test_text)
print(*(split.shape for split in (train, valid, test)))

100%|██████████| 10000/10000 [00:14<00:00, 688.27it/s]
100%|██████████| 1500/1500 [00:02<00:00, 700.65it/s]
100%|██████████| 1500/1500 [00:02<00:00, 711.68it/s]


(10000, 703) (1500, 703) (1500, 703)


In [6]:
# Persist each split as CSV (no index column, UTF-8).
for frame, csv_path in ((train, 'new_dataset_train.csv'),
                        (valid, 'new_dataset_valid.csv'),
                        (test, 'new_dataset_test.csv')):
    frame.to_csv(csv_path, index=False, encoding='utf-8')