# batch.py

In [3]:
import random
import torch

num_examples = 128
message_length = 32

 
def dataset(num_examples):
   
    dataset = []
    for x in range(num_examples):
        ex_out = ''.join([random.choice(vocab) for x in range(message_length)])
        ex_in = encrypt(''.join(ex_out))
        ex_in = [vocab.index(x) for x in ex_in]
        ex_out = [vocab.index(x) for x in ex_out]
        dataset.append([torch.tensor(ex_in), torch.tensor(ex_out)])
    return dataset

#cipher.py

In [4]:
key = 3
vocab = [char for char in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ-']


def encrypt(text):
   
    indexes = [vocab.index(char) for char in text]
    encrypted_indexes = [(idx + key) % len(vocab) for idx in indexes]
    encrypted_chars = [vocab[idx] for idx in encrypted_indexes]
    encrypted = ''.join(encrypted_chars)
    return encrypted

print(encrypt('ABCDEFGHIJKLMNOPQRSTUVWXYZ-'))

DEFGHIJKLMNOPQRSTUVWXYZ-ABC


#model.py

In [5]:
embedding_dim = 10
hidden_dim = 10
vocab_size = len(vocab)

embed = torch.nn.Embedding(vocab_size, embedding_dim)
lstm = torch.nn.LSTM(embedding_dim, hidden_dim)
linear = torch.nn.Linear(hidden_dim, vocab_size)
softmax = torch.nn.functional.softmax
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(embed.parameters()) + list(lstm.parameters())
                             + list(linear.parameters()), lr=0.001)

# train.py

In [6]:
num_epochs = 10

accuracies, max_accuracy = [], 0
for x in range(num_epochs):
    print('Epoch: {}'.format(x))
    for encrypted, original in dataset(num_examples):
        lstm_in = embed(encrypted)
        lstm_in = lstm_in.unsqueeze(1)
        lstm_out, lstm_hidden = lstm(lstm_in, zero_hidden())
        scores = linear(lstm_out)        
        scores = scores.transpose(1, 2)        
        original = original.unsqueeze(1)        
        loss = loss_fn(scores, original) 
        # Backpropagate
        loss.backward()
        # Update weights
        optimizer.step()
    print('Loss: {:6.4f}'.format(loss.item()))

Epoch: 0
Loss: 2.7491
Epoch: 1
Loss: 1.7264
Epoch: 2
Loss: 0.8279
Epoch: 3
Loss: 0.4849
Epoch: 4
Loss: 0.2186
Epoch: 5
Loss: 0.1132
Epoch: 6
Loss: 0.0765
Epoch: 7
Loss: 0.0453
Epoch: 8
Loss: 0.0237
Epoch: 9
Loss: 0.0161


# valid.py

In [8]:
    with torch.no_grad():
        matches, total = 0, 0
        for encrypted, original in dataset(num_examples):
            lstm_in = embed(encrypted)
            lstm_in = lstm_in.unsqueeze(1)
            lstm_out, lstm_hidden = lstm(lstm_in, zero_hidden())
            scores = linear(lstm_out)            
            predictions = softmax(scores, dim=2)            
            _, batch_out = predictions.max(dim=2)            
            batch_out = batch_out.squeeze(1)
            # Calculate accuracy
            matches += torch.eq(batch_out, original).sum().item()
            total += torch.numel(batch_out)
        accuracy = matches / total
        print('Accuracy: {:4.2f}%'.format(accuracy * 100))


Accuracy: 100.00%
