In [1]:
import torch
import torch.nn as nn
import numpy as np

# ============ Helper Classes (Kept for reference/future use) ============
# Note: The current implementation uses a fully differentiable neural architecture.
# These classes can be used as additional security layers if needed.

class VPDEncoder:
    """Virtual Planet Domain encoding (non-differentiable).

    Every scalar of the input is mapped to one point on a sphere of the
    configured radius, so an input of shape ``S`` becomes an array of shape
    ``S + (3,)`` holding Cartesian ``(x, y, z)`` coordinates.
    """

    def __init__(self, radius=1.0):
        # Radius of the sphere the values are projected onto.
        self.radius = radius

    def encode(self, data):
        """Project each element of ``data`` onto the sphere.

        The input is min-max normalised over the whole array; a normalised
        value ``t`` is turned into the azimuth ``2*pi*t`` and the polar angle
        ``pi*t``, which are then converted to Cartesian coordinates.

        Args:
            data: ``torch.Tensor``, numpy array or array-like of numbers.
                Tensors are detached and copied to the CPU first, so the
                result carries no gradient information.

        Returns:
            numpy array of shape ``data.shape + (3,)`` with the ``x``, ``y``
            and ``z`` coordinates along the last axis.
        """
        if isinstance(data, torch.Tensor):
            # detach() is required for tensors that take part in autograd;
            # calling .numpy() on them directly raises a RuntimeError.
            data = data.detach().cpu().numpy()
        else:
            # Accept lists/tuples as well as ndarrays.
            data = np.asarray(data)

        lowest = data.min()
        spread = data.max() - lowest
        # The epsilon keeps a constant input from dividing by zero; such an
        # input normalises to all zeros, i.e. every point is (0, 0, radius).
        data_norm = (data - lowest) / (spread + 1e-8)

        theta = data_norm * 2 * np.pi  # azimuth
        phi = data_norm * np.pi        # polar angle

        sin_phi = np.sin(phi)
        x = self.radius * sin_phi * np.cos(theta)
        y = self.radius * sin_phi * np.sin(theta)
        z = self.radius * np.cos(phi)
        return np.stack([x, y, z], axis=-1)


class PixelIntershuffler:
    """3D pixel intershuffling (non-differentiable).

    Permutes the elements of an array with a permutation that is fully
    determined by ``key_sequence`` and the number of elements.
    """

    def __init__(self, key_sequence):
        # Numeric sequence the permutation seed is derived from.
        self.key_sequence = key_sequence

    def _seed(self):
        """Derive a 32-bit RNG seed from the key sequence."""
        # np.asarray lets plain lists work; arrays pass through unchanged so
        # the seed (and therefore the permutation) is the same as before.
        total = np.sum(np.asarray(self.key_sequence) * 1e10)
        # The trailing integer modulo guards against the float remainder
        # rounding up to exactly 2**32, which is not a valid seed.
        return int(total % (2**32)) % (2**32)

    def generate_shuffle_indices(self, length):
        """Return the key-dependent permutation of ``range(length)``.

        A private ``RandomState`` is used instead of ``np.random.seed`` so
        that the process-wide NumPy random state is left untouched. The
        legacy ``RandomState`` produces the same stream as the seeded global
        generator, so the permutation itself is unchanged.

        Args:
            length: number of elements to permute.

        Returns:
            1-D integer numpy array containing each index exactly once.
        """
        rng = np.random.RandomState(self._seed())
        indices = np.arange(length)
        rng.shuffle(indices)
        return indices

    def shuffle(self, data):
        """Shuffle the elements of ``data`` while keeping its shape.

        Args:
            data: array exposing ``shape``, ``flatten()`` and integer-array
                indexing (e.g. a numpy array).

        Returns:
            Tuple ``(shuffled, indices)`` where ``shuffled`` has the shape of
            ``data`` and ``shuffled.flatten()[i] == data.flatten()[indices[i]]``.
            The permutation is returned so the caller can invert it, e.g.
            ``restored[indices] = shuffled.flatten()``.
        """
        original_shape = data.shape
        flat_data = data.flatten()
        indices = self.generate_shuffle_indices(len(flat_data))
        shuffled = flat_data[indices]
        return shuffled.reshape(original_shape), indices


class ZigzagXOR:
    """Zigzag XOR operation (non-differentiable).

    Values are quantised to 8 bits, XOR-ed with an equally quantised key and
    mapped back to floats in ``[0, 1]``.
    """

    def __init__(self, key_sequences):
        # Sequence of key sequences; one of them is picked per round.
        self.key_sequences = key_sequences

    @staticmethod
    def _quantize(values):
        """Map floats in ``[0, 1]`` to ``uint8`` levels ``0..255``.

        Rounding to the nearest level (rather than truncating) is what makes
        the XOR reversible: a float32 ``k / 255`` multiplied by 255 can land
        just below ``k``, and truncation would then yield ``k - 1``.
        Out-of-range values saturate at 0 / 255 instead of wrapping.
        """
        scaled = np.rint(np.asarray(values, dtype=np.float64) * 255)
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def apply_xor(self, data, round_num=0):
        """XOR ``data`` with the key selected for ``round_num``.

        Applying the method twice with the same ``round_num`` restores the
        8-bit quantised input.

        Args:
            data: sequence of floats, expected in ``[0, 1]``.
                NOTE(review): the key is matched against ``len(data)``, so
                this presumably expects 1-D input — confirm against callers.
            round_num: round index; wraps around the number of key sequences.

        Returns:
            float32 numpy array with values in ``[0, 1]``.
        """
        key = self.key_sequences[round_num % len(self.key_sequences)]
        length = len(data)
        if len(key) < length:
            # Repeat a short key cyclically until it covers the data.
            key = np.tile(key, (length // len(key) + 1))[:length]
        else:
            key = key[:length]

        xor_result = np.bitwise_xor(self._quantize(data), self._quantize(key))
        return xor_result.astype(np.float32) / 255.0


# ============ Alice Network (Encryption) ============
class AliceNetwork(nn.Module):
    """Differentiable encryption network ("Alice").

    The message is encoded by an MLP, then passed through ``num_rounds``
    rounds in which the current hidden state is concatenated with a learned
    projection of that round's key and fed to a small encryption MLP. A final
    MLP with ``tanh`` maps back to ``input_size`` values in ``[-1, 1]``.

    Args:
        input_size: length of a message and of every key sequence.
        hidden_size: width of the hidden representation.
        num_rounds: number of key-mixing rounds (and of keys consumed).
    """

    def __init__(self, input_size, hidden_size=128, num_rounds=3):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_rounds = num_rounds

        # Simplified architecture for differentiability
        # Encoder pathway: input_size -> hidden_size
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size * 2),
            nn.ReLU(),
            nn.Linear(hidden_size * 2, hidden_size * 2),
            nn.ReLU(),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
        )

        # Key mixing layers: one projection of the raw key per round
        self.key_mixer = nn.ModuleList(
            [nn.Linear(input_size, hidden_size) for _ in range(num_rounds)]
        )

        # Encryption layers: consume [hidden state, key features]
        self.encrypt_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_size * 2, hidden_size),
                    nn.ReLU(),
                    nn.Linear(hidden_size, hidden_size),
                    nn.Tanh(),
                )
                for _ in range(num_rounds)
            ]
        )

        # Output layer: hidden_size -> input_size, squashed to [-1, 1]
        self.output = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
            nn.Tanh(),
        )

    def _keys_to_tensor(self, key_sequences, reference):
        """Stack the first ``num_rounds`` keys into one tensor.

        Tensor keys are stacked with torch so they stay on their device and
        remain part of the autograd graph; anything else goes through numpy.
        The result is moved to the device and dtype of ``reference``.

        Raises:
            IndexError: if fewer than ``num_rounds`` keys are supplied.
        """
        selected = key_sequences[:self.num_rounds]
        if len(selected) < self.num_rounds:
            raise IndexError(
                f"expected at least {self.num_rounds} key sequences, "
                f"got {len(selected)}"
            )

        if isinstance(selected, torch.Tensor):
            keys = selected
        elif len(selected) > 0 and all(
            isinstance(key, torch.Tensor) for key in selected
        ):
            keys = torch.stack(list(selected))
        else:
            keys = torch.as_tensor(np.asarray(selected))
        return keys.to(device=reference.device, dtype=reference.dtype)

    def forward(self, x, key_sequences):
        """
        Forward pass through Alice network

        x: input message (batch_size, input_size)
        key_sequences: sequence of key sequences, each of length input_size
            (numpy arrays, lists or tensors). Only the first ``num_rounds``
            entries are used; extra entries are ignored.

        Returns the ciphertext, shape (batch_size, input_size), in [-1, 1].
        """
        batch_size = x.size(0)
        keys_tensor = self._keys_to_tensor(key_sequences, x)

        # Initial encoding
        encoded = self.encoder(x)

        # Multiple rounds of key-based encryption
        for round_key, mixer, encrypt in zip(
            keys_tensor, self.key_mixer, self.encrypt_layers
        ):
            # The same key is shared by every sample of the batch.
            key_features = mixer(round_key.unsqueeze(0).expand(batch_size, -1))

            # Combine message and key features, then encrypt
            combined = torch.cat([encoded, key_features], dim=1)
            encoded = encrypt(combined)

        # Final output
        return self.output(encoded)


# ============ Complete Alice Encryption Module ============
class AliceEncryption:
    """Convenience wrapper that owns an ``AliceNetwork`` on a given device."""

    def __init__(self, input_size, device='cuda' if torch.cuda.is_available() else 'cpu'):
        # Device the network lives on; incoming messages are moved to it.
        self.device = device
        self.alice = AliceNetwork(input_size).to(device)

    def _prepare_message(self, message):
        """Return ``message`` as a batched tensor on ``self.device``.

        numpy arrays are converted to float32 tensors, and a 1-D message is
        given a leading batch dimension of size one.
        """
        if isinstance(message, np.ndarray):
            message = torch.tensor(message, dtype=torch.float32)

        # Move message to the same device as the model
        message = message.to(self.device)

        if message.dim() == 1:
            message = message.unsqueeze(0)
        return message

    def encrypt(self, message, key_sequences):
        """
        Encrypt message using Alice network

        message: torch.Tensor or numpy array, shape (input_size,) or
            (batch_size, input_size)
        key_sequences: key sequences forwarded unchanged to the network

        Returns the ciphertext tensor of shape (batch_size, input_size).
        The network is switched to eval mode and no gradients are recorded.
        """
        message = self._prepare_message(message)

        self.alice.eval()
        with torch.no_grad():
            ciphertext = self.alice(message, key_sequences)

        return ciphertext

    def train_step(self, message, key_sequences, optimizer):
        """Forward half of a training step for Alice.

        Puts the network in train mode, clears the optimizer's gradients and
        returns the ciphertext with gradients enabled. No loss is computed
        and ``optimizer.step()`` is not called here; both are left to the
        caller.

        The message goes through the same preparation as in ``encrypt``
        (array conversion, device move, batching), so both entry points
        accept the same inputs.
        """
        message = self._prepare_message(message)

        self.alice.train()
        optimizer.zero_grad()

        return self.alice(message, key_sequences)


# ============ Usage Example ============
if __name__ == "__main__":
    # Smoke test: push one random batch through an untrained Alice network
    # and report the tensor shapes.
    input_size = 256
    batch_size = 4

    # Determine device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize Alice
    alice_enc = AliceEncryption(input_size, device=device)

    # Create dummy message and keys (on the correct device)
    message = torch.randn(batch_size, input_size).to(device)

    # Simulate key sequences from Part 1: eight random sequences, each as
    # long as a message.
    key_sequences = [np.random.rand(input_size) for _ in range(8)]

    print("Encrypting message...")
    ciphertext = alice_enc.encrypt(message, key_sequences)

    print(f"Message shape: {message.shape}")
    print(f"Ciphertext shape: {ciphertext.shape}")
    # Plain string: there is nothing to interpolate here.
    print("Encryption completed!")

Using device: cpu
Encrypting message...
Message shape: torch.Size([4, 256])
Ciphertext shape: torch.Size([4, 256])
Encryption completed!
