# Training Notebook

In [2]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset


# own code
from kyber_py import *
from fourier import *

## Dataset creation

In [8]:
class MLWEDataset(Dataset):
    """In-memory dataset of NTT-domain MLWE pairs that all share one secret.

    Every item is a tuple ``(A_tensor, B_tensor)`` of ``float64`` tensors built
    from ``A_hat`` and ``B_hat = A_hat @ secret_hat``. No explicit error term
    is added to ``B_hat`` inside this class.
    """

    def __init__(self, params, num_samples):
        """
        Generate ``num_samples`` pairs up front and keep them in memory.

        Args:
            params: Configuration for the MLWE scheme; passed unchanged to ``MLWE``.
            num_samples: Number of ``(A, B)`` pairs to generate.
        """
        self.mlwe = MLWE(params)
        self.samples = []
        random_byte = self.mlwe.get_random_bytes()
        secret = self.mlwe.generate_secret(random_byte)
        secret_hat = secret.to_ntt()
        for i in range(num_samples):
            # Use the base seed as a counter so each sample gets its own seed.
            # NOTE(review): for i == 0 the seed for A equals the seed used for
            # the secret above -- confirm this reuse is intended.
            updated_byte = self._increase_byte(random_byte, i)
            A_hat = self.mlwe.generate_A_hat(updated_byte)
            # B is derived directly from A and the secret (the earlier
            # generate_A_B_hat code path is no longer used).
            B_hat = A_hat @ secret_hat

            A_tensor = torch.tensor(A_hat.to_list()).to(dtype=torch.float64)
            B_tensor = torch.tensor(B_hat.to_list()).to(dtype=torch.float64)
            self.samples.append((A_tensor, B_tensor))

        # Stored in the coefficient (non-NTT) form returned by generate_secret,
        # as float32 -- unlike the float64 samples above.
        self.secret = torch.tensor(secret.to_list()).float()

    def _increase_byte(self, input_bytes, N):
        """Return ``input_bytes``, read as a big-endian unsigned integer, plus ``N``.

        The result has the same length as ``input_bytes``. The sum wraps modulo
        ``256 ** len(input_bytes)`` so a seed close to the maximum value cannot
        raise ``OverflowError`` part-way through dataset generation.
        """
        width = len(input_bytes)
        total = int.from_bytes(input_bytes, byteorder='big') + N
        return (total % (1 << (8 * width))).to_bytes(width, byteorder='big')

    def __len__(self):
        """Number of generated samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return the ``(A_tensor, B_tensor)`` pair stored at ``index``."""
        return self.samples[index]

    def get_secret(self):
        """Return the ground-truth secret tensor shared by all samples."""
        return self.secret

## Training with NTT

In [9]:
# Toy MLWE configuration; the dict is handed as-is to MLWE and MLWESolver.
params = {
    'n': 8,    # secret dimension
    'q': 17,   # modulus
    'k': 1,
    'secret_type': 'binary',
    'error_type': 'binary',
    'seed': 0,
    'hw': 3    # presumably the Hamming weight of the secret -- TODO confirm
}

# Seed PyTorch's global RNG so the DataLoader shuffle order (and any later
# random parameter initialisation) is repeatable under Restart & Run All.
torch.manual_seed(params['seed'])

mlwe_dataset = MLWEDataset(params, 1024)
mlwe_loader = DataLoader(mlwe_dataset, batch_size=32, shuffle=True)

In [10]:
# Neural solver model for Module-LWE using both Fourier mapping and FFT transformation.
class MLWESolver(nn.Module):
    """Module whose only trainable parameter is a guess of the MLWE secret.

    ``forward`` pushes the guess through the Fourier/NTT helpers and multiplies
    it with each public ``A`` so the result can be compared against ``B``.
    """

    def __init__(self, params):
        """
        Args:
            params: dict providing 'n' (secret dimension, e.g. 8), 'q' (modulus),
                'k', 'secret_type' and 'error_type'.

        Raises:
            ValueError: if ``find_root_of_unity`` returns ``None`` for (n, q).
        """
        super().__init__()
        self.q = params['q']
        self.n = params['n']
        self.k = params['k']
        self.secret_type = params['secret_type']
        self.error_type = params['error_type']

        root = find_root_of_unity(self.n, self.q)
        if root is None:
            raise ValueError("Root of unity not found for the given n and q.")

        # Bit width handed to br(): equals int(log2(n)) - 1, but computed on
        # integers so the class does not depend on `math` leaking in through a
        # wildcard import and is free of float rounding.
        bit_width = int(self.n).bit_length() - 2
        self.ntt_zetas = torch.tensor(
            [pow(root, br(i, bit_width), self.q) for i in range(self.n // 2)],
            dtype=torch.float64,
        )

        # Modular inverse of n/2 mod q; not used by forward() in this class.
        self.ntt_f = pow(self.n // 2, -1, self.q)

        # Trainable secret estimate of shape (k, 1, n), Xavier-normal initialised.
        self.guessed_secret = nn.Parameter(
            nn.init.xavier_normal_(torch.empty(self.k, 1, self.n), gain=1.0)
        )

    def forward(self, A_batch):
        """
        Predict B for every A in the batch from the current secret guess.

        Args:
            A_batch: Batch of public matrices. It is iterated along its first
                dimension and each element is passed to ``fourier_matmul``.

        Returns:
            The stacked per-sample products after ``fourier_complex_to_int``.
            Only this single value is returned; the target B and the secret
            estimate are not part of the output.
        """
        # Map the trainable secret into the complex domain, then apply the NTT.
        # NOTE(review): the parameter has shape (k, 1, n); the shape produced by
        # the fourier helpers is not visible here -- confirm.
        s_complex = fourier_int_to_complex(self.guessed_secret, self.q)
        s_complex_ntt = fourier_ntt(s_complex, self.ntt_zetas)

        # Multiply each sample separately, then restore the batch dimension.
        products = [
            fourier_matmul(A_hat, s_complex_ntt, self.ntt_zetas)
            for A_hat in A_batch
        ]
        result = torch.stack(products)

        return fourier_complex_to_int(result, self.q)

In [11]:
# Instantiate the solver and dump its trainable parameters for inspection.
model = MLWESolver(params)

for name, param in model.named_parameters():
    print(
        f"Parameter Name: {name}",
        f"Requires Grad: {param.requires_grad}",
        f"Shape: {param.shape}",
        f"Values: {param.data}\n",
        sep="\n",
    )

Parameter Name: guessed_secret
Requires Grad: True
Shape: torch.Size([1, 1, 8])
Values: tensor([[[ 0.5964,  0.4165, -0.1928, -0.2692,  0.4453, -0.6868, -0.5268,
           0.1806]]])



In [12]:
# Display the ground-truth secret held by the dataset, for comparison with the
# model's initial guess.
mlwe_dataset.get_secret()

tensor([[[0., 1., 0., 1., 0., 0., 1., 0.]]])

In [13]:
# Baseline: MSE between the untrained guess and the dataset's true secret.
mse = nn.MSELoss()
loss_secret = mse(input=model.guessed_secret, target=mlwe_dataset.get_secret())
print("Loss:", loss_secret.item())

Loss: 0.672230064868927


In [14]:
def train_model(model, dataloader, n_epochs=100, lr=1e-3):
    """Fit ``model`` so that ``model(A)`` matches ``B`` under MSE, using Adam.

    After every epoch one line is printed with the mean batch loss and the MSE
    between ``model.guessed_secret`` and the dataset's true secret. The secret
    is used for that progress metric only and never enters the training loss.

    Args:
        model: Module with a ``guessed_secret`` parameter; called as ``model(A)``.
        dataloader: Yields ``(A, B)`` batches; its ``dataset`` must provide
            ``get_secret()``.
        n_epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        The same ``model`` object, updated in place.
    """
    # Ground-truth secret, used only for monitoring below.
    secret = dataloader.dataset.get_secret()

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(1, n_epochs + 1):
        epoch_loss = 0.0
        num_batches = 0
        for A_hat, B_hat in dataloader:
            optimizer.zero_grad()
            pred_B = model(A_hat)

            loss = criterion(pred_B, B_hat)

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1

        # An empty loader yields no batches; report NaN instead of dividing by zero.
        avg_loss = epoch_loss / num_batches if num_batches else float("nan")

        # Monitoring only: no autograd graph is needed for this metric.
        with torch.no_grad():
            s_loss = criterion(model.guessed_secret, secret)

        print(f"Epoch {epoch}/{n_epochs} - Loss: {avg_loss:.6f} - Secret Loss: {s_loss.item():.6f}")

    return model


In [15]:
# Re-create the solver so training starts from a fresh random initialisation.
model = MLWESolver(params)
trained_model = train_model(
    model,
    mlwe_loader,
    n_epochs=100,
    lr=1e-3,
)

Epoch 1/100 - Loss: 109.264784 - Secret Loss: 0.749209
Epoch 2/100 - Loss: 111.600398 - Secret Loss: 0.733288
Epoch 3/100 - Loss: 112.104864 - Secret Loss: 0.717980
Epoch 4/100 - Loss: 111.640498 - Secret Loss: 0.704605
Epoch 5/100 - Loss: 110.221347 - Secret Loss: 0.694326
Epoch 6/100 - Loss: 113.190556 - Secret Loss: 0.683880
Epoch 7/100 - Loss: 110.973790 - Secret Loss: 0.676738
Epoch 8/100 - Loss: 111.989337 - Secret Loss: 0.665154
Epoch 9/100 - Loss: 112.597793 - Secret Loss: 0.656788
Epoch 10/100 - Loss: 113.094740 - Secret Loss: 0.649535
Epoch 11/100 - Loss: 111.309226 - Secret Loss: 0.642847
Epoch 12/100 - Loss: 112.508550 - Secret Loss: 0.638679
Epoch 13/100 - Loss: 112.506064 - Secret Loss: 0.631703


KeyboardInterrupt: 