# In [1]:
from part2_alice import ZigzagXOR, PixelIntershuffler

# In [2]:
import torch
import torch.nn as nn
import numpy as np

# ============ Bob Network (Decryption) ============
class BobNetwork(nn.Module):
    """Learned decryption network ("Bob") conditioned on key sequences.

    The ciphertext is embedded by ``input_layer`` and then refined over
    ``num_rounds`` rounds. The rounds are visited in reverse order
    (``num_rounds - 1`` down to 0). In each round:

    1. The matching key sequence is projected by ``key_mixer[round]``.
    2. The projection is concatenated with the current features.
    3. The result is passed through ``decrypt_layers[round]``.

    ``decoder`` then maps the features back to ``input_size`` values in
    [-1, 1] (Tanh output).

    Args:
        input_size: Length of each ciphertext vector and of each key sequence.
        hidden_size: Width of the internal feature vectors.
        num_rounds: Number of key-mixing rounds; one key sequence per round.
    """

    def __init__(self, input_size, hidden_size=128, num_rounds=3):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_rounds = num_rounds

        # NOTE: submodules are registered in the original order so that
        # state_dict keys (and seeded weight initialisation) stay unchanged.

        # Input processing
        self.input_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )

        # Key mixing layers: one projection of a key sequence per round
        self.key_mixer = nn.ModuleList([
            nn.Linear(input_size, hidden_size) for _ in range(num_rounds)
        ])

        # Per-round decryption layers applied to [features, key_features]
        self.decrypt_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size * 2, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.Tanh()
            ) for _ in range(num_rounds)
        ])

        # Decoder pathway back to input_size, squashed to [-1, 1]
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.ReLU(),
            nn.Linear(hidden_size * 2, hidden_size * 2),
            nn.ReLU(),
            nn.Linear(hidden_size * 2, input_size),
            nn.Tanh()
        )

    def forward(self, ciphertext, key_sequences):
        """Decrypt a batch of ciphertexts.

        Args:
            ciphertext: Tensor of shape ``(batch_size, input_size)``.
            key_sequences: Indexable collection of key vectors, each of length
                ``input_size``. Each key may be a numpy array, a list, or a
                tensor on any device. Only the first ``num_rounds`` are used;
                the demo in this file passes 8.

        Returns:
            Tensor of shape ``(batch_size, input_size)`` with values in [-1, 1].

        Raises:
            ValueError: If fewer than ``num_rounds`` key sequences are given.
        """
        if len(key_sequences) < self.num_rounds:
            raise ValueError(
                f"BobNetwork needs {self.num_rounds} key sequences, "
                f"got {len(key_sequences)}"
            )

        batch_size = ciphertext.size(0)
        # Put the keys on the input's device and in the weights' dtype.
        # Converting each key separately also accepts tensors on a GPU, which
        # np.array() cannot handle.
        param_dtype = self.input_layer[0].weight.dtype
        round_keys = [
            torch.as_tensor(key, dtype=param_dtype, device=ciphertext.device)
            for key in key_sequences[: self.num_rounds]
        ]

        features = self.input_layer(ciphertext)
        # Rounds are applied in reverse order.
        for round_idx in reversed(range(self.num_rounds)):
            # The same key vector is shared by every sample in the batch.
            key_batch = round_keys[round_idx].unsqueeze(0).expand(batch_size, -1)
            key_features = self.key_mixer[round_idx](key_batch)
            combined = torch.cat([features, key_features], dim=1)
            features = self.decrypt_layers[round_idx](combined)

        return self.decoder(features)


# ============ Eve Network (Adversary) ============
class EveNetwork(nn.Module):
    """Adversary ("Eve") that tries to recover plaintext from ciphertext alone.

    Eve receives no key material; her only input is the ciphertext. She is a
    plain MLP:

    - widen to ``2 * hidden_size``, with dropout (p=0.3) after each of the
      first two hidden ReLUs;
    - narrow to ``hidden_size``;
    - project back to ``input_size`` with a Tanh output in [-1, 1].

    Used for adversarial validation.
    """

    def __init__(self, input_size, hidden_size=128):
        super().__init__()

        self.input_size = input_size

        wide = hidden_size * 2
        # Layers are listed in the same order as they are created, so the
        # Sequential indices (network.0, network.3, ...) are stable.
        self.network = nn.Sequential(
            nn.Linear(input_size, wide), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(wide, wide), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(wide, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, input_size), nn.Tanh(),
        )

    def forward(self, ciphertext):
        """Guess the plaintext for ``ciphertext`` of shape (batch_size, input_size)."""
        guess = self.network(ciphertext)
        return guess


# ============ Complete Bob Decryption Module ============
class BobDecryption:
    """Wrapper that owns a ``BobNetwork`` and prepares its inputs.

    Args:
        input_size: Length of each ciphertext vector (passed to ``BobNetwork``).
        device: Device the network is moved to; inputs are moved there too.
            The default is evaluated once, when this class is defined.
    """

    def __init__(self, input_size, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.bob = BobNetwork(input_size).to(device)

    def _as_batch(self, ciphertext):
        """Return ``ciphertext`` as a 2-D tensor on the model's device and dtype.

        Accepts numpy arrays, nested sequences or tensors. A single 1-D
        message is promoted to a batch of one. ``torch.as_tensor`` keeps the
        autograd history of an input tensor, so gradients still flow back to
        whatever produced the ciphertext.
        """
        # Casting to the weights' dtype avoids a dtype-mismatch error when,
        # e.g., a float64 tensor is passed to a float32 model.
        dtype = next(self.bob.parameters()).dtype
        batch = torch.as_tensor(ciphertext, dtype=dtype, device=self.device)
        if batch.dim() == 1:
            batch = batch.unsqueeze(0)
        return batch

    def decrypt(self, ciphertext, key_sequences):
        """Decrypt ``ciphertext`` with Bob in eval mode and without gradients.

        Args:
            ciphertext: Tensor, numpy array or sequence; 1-D (one message) or
                2-D ``(batch_size, input_size)``.
            key_sequences: Key sequences forwarded unchanged to the network
                (the same ones used for encryption).

        Returns:
            Tensor of decrypted values, shape ``(batch_size, input_size)``.

        Note:
            Leaves ``self.bob`` in eval mode.
        """
        batch = self._as_batch(ciphertext)
        self.bob.eval()
        with torch.no_grad():
            return self.bob(batch, key_sequences)

    def train_step(self, ciphertext, key_sequences, optimizer):
        """Run Bob's forward pass for one training step.

        Switches Bob to train mode, clears ``optimizer``'s gradients, and
        returns the (differentiable) output. The caller computes the loss and
        calls ``backward()`` / ``optimizer.step()``. Input is prepared exactly
        as in ``decrypt``.
        """
        self.bob.train()
        optimizer.zero_grad()
        return self.bob(self._as_batch(ciphertext), key_sequences)


# ============ Complete Eve Adversary Module ============
class EveAdversary:
    """Wrapper that owns an ``EveNetwork`` and prepares its inputs.

    Args:
        input_size: Length of each ciphertext vector (passed to ``EveNetwork``).
        device: Device the network is moved to; inputs are moved there too.
            The default is evaluated once, when this class is defined.
    """

    def __init__(self, input_size, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.eve = EveNetwork(input_size).to(device)

    def _as_batch(self, ciphertext):
        """Return ``ciphertext`` as a 2-D tensor on the model's device and dtype.

        Accepts numpy arrays, nested sequences or tensors. A single 1-D
        message is promoted to a batch of one. Autograd history of an input
        tensor is kept by ``torch.as_tensor``.
        """
        # Casting to the weights' dtype avoids a dtype-mismatch error when,
        # e.g., a float64 tensor is passed to a float32 model.
        dtype = next(self.eve.parameters()).dtype
        batch = torch.as_tensor(ciphertext, dtype=dtype, device=self.device)
        if batch.dim() == 1:
            batch = batch.unsqueeze(0)
        return batch

    def attack(self, ciphertext):
        """Let Eve guess the plaintext without keys (eval mode, no gradients).

        Args:
            ciphertext: Tensor, numpy array or sequence; 1-D (one message) or
                2-D ``(batch_size, input_size)``.

        Returns:
            Tensor with Eve's guess, shape ``(batch_size, input_size)``.

        Note:
            Leaves ``self.eve`` in eval mode, which disables her dropout.
        """
        batch = self._as_batch(ciphertext)
        self.eve.eval()
        with torch.no_grad():
            return self.eve(batch)

    def train_step(self, ciphertext, optimizer):
        """Run Eve's forward pass for one training step.

        Switches Eve to train mode, clears ``optimizer``'s gradients, and
        returns the (differentiable) output. The caller computes the loss and
        calls ``backward()`` / ``optimizer.step()``. Input is prepared exactly
        as in ``attack``.
        """
        self.eve.train()
        optimizer.zero_grad()
        return self.eve(self._as_batch(ciphertext))


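# ============ Helper Functions for Reverse Operations ============
# Placeholder: no helper functions are currently defined in this section.
# Decryption above is entirely learned (built only from differentiable nn
# layers), so it does not need non-differentiable inverse operations. Add them
# here if such operations are ever required.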


# ============ Usage Example ============
if __name__ == "__main__":
    # Smoke test: run Bob and Eve once on random data and report output shapes.
    input_size = 256
    batch_size = 4

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize Bob (with keys) and Eve (without keys)
    bob_dec = BobDecryption(input_size, device=device)
    eve_adv = EveAdversary(input_size, device=device)

    # Random stand-ins for a real ciphertext and the 8 chaotic key sequences;
    # the tensor is allocated directly on the target device.
    ciphertext = torch.randn(batch_size, input_size, device=device)
    key_sequences = [np.random.rand(input_size) for _ in range(8)]

    print("Bob attempting decryption with correct keys...")
    plaintext_bob = bob_dec.decrypt(ciphertext, key_sequences)

    print("Eve attempting decryption without keys...")
    plaintext_eve = eve_adv.attack(ciphertext)

    print(f"Ciphertext shape: {ciphertext.shape}")
    print(f"Bob's plaintext shape: {plaintext_bob.shape}")
    print(f"Eve's plaintext shape: {plaintext_eve.shape}")
    print("Decryption completed!")

# Output of the cell above (CPU run):
# Using device: cpu
# Bob attempting decryption with correct keys...
# Eve attempting decryption without keys...
# Ciphertext shape: torch.Size([4, 256])
# Bob's plaintext shape: torch.Size([4, 256])
# Eve's plaintext shape: torch.Size([4, 256])
# Decryption completed!
