# Training Notebook

In [1]:
from kyber_py import *
from fourier import *

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import sympy as sp
from sympy import ntheory

import math
import matplotlib.pyplot as plt
from scipy.linalg import circulant


## Dataset creation

In [2]:
params = {
    'n': 8,
    'q': 17,
    'k': 2,
    'secret_type': 'binary',
    'error_type': 'binary',
    'seed': 0,
    'hw': 3,
}

In [3]:
def neg_circ(a):
    """
    Generates a negative circulant matrix from the input vector a.
    """
    n = len(a)
    A = circulant(a)
    tri = np.triu_indices(n, 1)
    A[tri] *= -1
    return A

n = params['n']
q = params['q']
k = params['k']
l = params['l'] if 'l' in params else k

# Generate the MLWE matrices
mlwe = MLWE(params)
random_bytes = mlwe.get_random_bytes()
result = mlwe.generate(random_bytes)
A, s, e, B = result
# The generated LWE parameters will have shape (kn, ln)


In [4]:
A_lwe = torch.zeros((k*n, l*n), dtype=torch.int32)
for i in range(k): # rows
  for j in range(l): # columns
    a = A[i,j].to_list()
    neg_circ_a = neg_circ(a)
    A_lwe[i*n:(i+1)*n, j*n:(j+1)*n] = torch.from_numpy(neg_circ_a)

print(A_lwe.shape)
print(A_lwe[0:8, 0:8])
print(A_lwe[8:16, 0:8])
print(A_lwe[0:8, 8:16])
print(A_lwe[8:16, 8:16])
print("A: ", A.to_list())

torch.Size([16, 16])
tensor([[  1,  -5, -10, -11,  -2, -16, -12, -10],
        [ 10,   1,  -5, -10, -11,  -2, -16, -12],
        [ 12,  10,   1,  -5, -10, -11,  -2, -16],
        [ 16,  12,  10,   1,  -5, -10, -11,  -2],
        [  2,  16,  12,  10,   1,  -5, -10, -11],
        [ 11,   2,  16,  12,  10,   1,  -5, -10],
        [ 10,  11,   2,  16,  12,  10,   1,  -5],
        [  5,  10,  11,   2,  16,  12,  10,   1]], dtype=torch.int32)
tensor([[ 12,  -4,  -1, -10,  -4, -16,  -2, -11],
        [ 11,  12,  -4,  -1, -10,  -4, -16,  -2],
        [  2,  11,  12,  -4,  -1, -10,  -4, -16],
        [ 16,   2,  11,  12,  -4,  -1, -10,  -4],
        [  4,  16,   2,  11,  12,  -4,  -1, -10],
        [ 10,   4,  16,   2,  11,  12,  -4,  -1],
        [  1,  10,   4,  16,   2,  11,  12,  -4],
        [  4,   1,  10,   4,  16,   2,  11,  12]], dtype=torch.int32)
tensor([[ 12, -12, -11, -10, -16, -16, -10, -13],
        [ 13,  12, -12, -11, -10, -16, -16, -10],
        [ 10,  13,  12, -12, -11, -10, 

In [5]:
# Transform the secret vector
s_lwe = torch.tensor(s.to_list(), dtype=torch.int32)
s_lwe = s_lwe.reshape(n*l, 1)
print("s_lwe shape:", s_lwe.shape)
print("s_lwe:", s_lwe)
print("s: ", s.to_list())

s_lwe shape: torch.Size([16, 1])
s_lwe: tensor([[0],
        [0],
        [0],
        [1],
        [1],
        [0],
        [0],
        [1],
        [1],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0],
        [1]], dtype=torch.int32)
s:  [[[0, 0, 0, 1, 1, 0, 0, 1]], [[1, 1, 0, 0, 0, 0, 0, 1]]]


In [6]:
b_lwe = torch.tensor(B.to_list(), dtype=torch.int32)
b_lwe = b_lwe.reshape(n*k, 1)
print("b_lwe shape:", b_lwe.shape)
print("b_lwe:", b_lwe)
print("B: ", B.to_list())

b_lwe shape: torch.Size([16, 1])
b_lwe: tensor([[15],
        [16],
        [10],
        [ 4],
        [ 6],
        [10],
        [16],
        [ 3],
        [11],
        [14],
        [ 5],
        [ 5],
        [ 5],
        [ 9],
        [ 8],
        [ 8]], dtype=torch.int32)
B:  [[[15, 16, 10, 4, 6, 10, 16, 3]], [[11, 14, 5, 5, 5, 9, 8, 8]]]


In [7]:
e_lwe = torch.tensor(e.to_list(), dtype=torch.int32)
e_lwe = e_lwe.reshape(n*k, 1)
print("e_lwe shape:", e_lwe.shape)
print("e_lwe:", e_lwe)
print("e: ", e.to_list())

e_lwe shape: torch.Size([16, 1])
e_lwe: tensor([[0],
        [0],
        [0],
        [0],
        [1],
        [0],
        [1],
        [0],
        [0],
        [0],
        [1],
        [1],
        [0],
        [1],
        [0],
        [0]], dtype=torch.int32)
e:  [[[0, 0, 0, 0, 1, 0, 1, 0]], [[0, 0, 1, 1, 0, 1, 0, 0]]]


Check correctness:

In [8]:
(A_lwe @ s_lwe + e_lwe) % q == b_lwe

tensor([[True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True]])

Function to transform MLWE and RLWE into LWE:

In [9]:
def neg_circ(a):
    """
    Generates a negative circulant matrix from the input vector a.
    """
    n = len(a)
    A = circulant(a)
    tri = np.triu_indices(n, 1)
    A[tri] *= -1
    return A

def get_lwe(params):
  n = params['n']
  k = params['k']
  l = params['l'] if 'l' in params else k

  # Generate the MLWE matrices
  mlwe = MLWE(params)
  random_bytes = mlwe.get_random_bytes()
  A, s, e, B = mlwe.generate(random_bytes)
  
  # Transform A:
  A_lwe = torch.zeros((k*n, l*n), dtype=torch.int32)
  for i in range(k): # rows
    for j in range(l): # columns
      a = A[i,j].to_list()
      neg_circ_a = neg_circ(a)
      A_lwe[i*n:(i+1)*n, j*n:(j+1)*n] = torch.from_numpy(neg_circ_a)

  # Transform s:
  s_lwe = torch.tensor(s.to_list(), dtype=torch.int32)
  s_lwe = s_lwe.reshape(n*l, 1)

  # Transform B:
  b_lwe = torch.tensor(B.to_list(), dtype=torch.int32)
  b_lwe = b_lwe.reshape(n*k, 1)

  # Transform e:
  e_lwe = torch.tensor(e.to_list(), dtype=torch.int32)
  e_lwe = e_lwe.reshape(n*k, 1)

  return A_lwe, s_lwe, b_lwe, e_lwe

def transform_lwe(A : list, b : list):
  k = len(A)
  l = len(A[0])
  n = len(A[0][0])

  # Transform A:
  A_lwe = torch.zeros((k*n, l*n), dtype=torch.int32)
  for i in range(k): # rows
    for j in range(l): # columns
      a = A[i][j]
      neg_circ_a = neg_circ(a)
      A_lwe[i*n:(i+1)*n, j*n:(j+1)*n] = torch.from_numpy(neg_circ_a)

  # Transform b:
  b_lwe = torch.tensor(b, dtype=torch.int32)
  b_lwe = b_lwe.reshape(n*k, 1)
  
  return A_lwe, b_lwe

In [10]:
params = {
    'n': 8,
    'q': 17,
    'k': 1,
    'secret_type': 'binary',
    'error_type': 'binary',
    'seed': 0,
    'hw': 3,
}
A_lwe, s_lwe, b_lwe, e_lwe = get_lwe(params)
print("secret: ", s_lwe.flatten().tolist())

secret:  [0, 1, 0, 1, 0, 0, 1, 0]


In [11]:
class LWEDataset(Dataset):
    def __init__(self, params, num_samples=1000):
        """
        params for the LWE scheme
        """
        n = params['n']
        k = params['k']

        self.mlwe = MLWE(params)
        num_gen = num_samples // (n*k) + 1

        self.A = torch.zeros((num_gen *n*k, k*n), dtype=torch.float64)
        self.B = torch.zeros((num_gen *n*k), dtype=torch.float64)
        
        random_byte = self.mlwe.get_random_bytes()
        secret = self.mlwe.generate_secret(random_byte)

        for i in range(num_gen):
            updated_byte = self._increase_byte(random_byte, i)
            A, B = self.mlwe.generate_A_B(secret, updated_byte)
            A_lwe, B_lwe = transform_lwe(A.to_list(), B.to_list())
            self.A[i*n*k:(i+1)*n*k, :] = A_lwe.squeeze()
            self.B[i*n*k:(i+1)*n*k] = B_lwe.squeeze()

        self.A = self.A[:num_samples, :]
        self.B = self.B[:num_samples]
        self.secret = torch.tensor(secret.to_list(), dtype=torch.float64).squeeze()

    def _increase_byte(self, input_bytes, N):
        return (int.from_bytes(input_bytes, byteorder='big') + N).to_bytes(len(input_bytes), byteorder='big')
        
    def __len__(self):
        return len(self.A)
    
    def __getitem__(self, index):
        return self.A[index], self.B[index]
    
    def get_secret(self):
        return self.secret

In [73]:
params = {
    'n': 50,
    'q': 3329,
    'k': 1,
    'secret_type': 'binary',
    'error_type': 'binary',
    'seed': 0,
    'hw': 1
}

lwe_dataset = LWEDataset(params, num_samples=100000)
lwe_dataloader = DataLoader(lwe_dataset, batch_size=8, shuffle=True)

Training debug:

In [85]:
# Neural solver model for Module-LWE using both Fourier mapping and FFT transformation.
class SimpleSolver(nn.Module):
    def __init__(self, params):
        """
        n: Secret dimension (e.g., 8)
        q: Modulus
        """
        super(SimpleSolver, self).__init__()
        self.q = params['q']
        self.n = params['n']
        self.k = params['k']
        self.secret_type = params['secret_type']
        self.error_type = params['error_type']
        
        self.guessed_secret = nn.Parameter(nn.init.normal_(torch.empty(self.n, dtype=torch.float64), mean=0.5, std=0.5), requires_grad=True)

    def forward(self, A_batch, B_batch=None):
        
        # --- Process s_hat ---
        s_complex = fourier_int_to_complex(self.guessed_secret, self.q)  # shape: (8,), complex
        
        result_complex = s_complex ** A_batch
        
        result_complex = torch.prod(result_complex, dim=1)

        # Transform also B_batch to complex if present (for easier loss evaluation)
        if B_batch is not None:
            return result_complex, fourier_int_to_complex(B_batch, self.q)
        else:
            return result_complex

In [86]:
def parabolic_regularization(weights):
    return torch.mean(torch.pow(torch.pow(weights, 2) - weights, 2))

def angular_loss(pred, target):
    """
    Calculates the angular distance between two complex tensors.
    The angular distance is defined as 1 - cos(theta), where theta is the angle between the two complex numbers.
    Maximum loss = 1
    """
    pred_angle = torch.angle(pred)
    target_angle = torch.angle(target)
    angular_distance = 1 - torch.cos(pred_angle - target_angle)
    return torch.mean(angular_distance)


In [None]:
# Calculate the initial loss for the model
model = SimpleSolver(params)
pred_B, target_B = model(lwe_dataset.A, lwe_dataset.B)
loss = angular_loss(pred_B, target_B)
print("Initial loss:", loss.item())

# Calculate the initial regularization loss
regularization_loss = parabolic_regularization(model.guessed_secret)
print("Regularization loss:", regularization_loss.item())

# Calculate the total loss
total_loss = loss + regularization_loss
print("Initial total loss:", total_loss.item())

# Calculate the secret loss
secret_loss = torch.mean(torch.pow(model.guessed_secret - lwe_dataset.secret, 2))
print("Initial secret loss:", secret_loss.item())


Initial loss: 0.9994718788430906
Regularization loss: 0.14554728711308396
Initial total loss: 1.1450191659561746
Initial secret loss: 0.5611939316359915


In [88]:
# Calculate the global minimum of the loss function
A_lwe, b_lwe = lwe_dataset.A, lwe_dataset.B
real_secret = lwe_dataset.get_secret()

s_complex = fourier_int_to_complex(real_secret, params['q'])
        
result_complex = s_complex ** A_lwe
        
result_complex = torch.prod(result_complex, dim=1)

# Calculate the loss
loss = angular_loss(result_complex, fourier_int_to_complex(b_lwe, params['q']))
print("Global minimum loss:", loss.item())

Global minimum loss: 8.592832149590702e-07


In [89]:
def compute_total_loss(main_loss, regularization, step, max_step):
    # Sigmoid-based schedule
    reg_weight = 1 / (1 + math.exp(-((step / max_step) * 10 - 5)))
    
    return main_loss + reg_weight * regularization * 50

def train_model(model, dataloader, log_interval=100, n_epochs=10, lr=0.01):
    # Get secret dimension from first sample.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()
    # Get the secret from the dataset
    secret = dataloader.dataset.get_secret()
    secret = secret.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    for epoch in range(1, n_epochs + 1):
        step_loss = 0.0
        for step, (A_batch, B_batch) in enumerate(dataloader):
            A_batch = A_batch.to(device)
            B_batch = B_batch.to(device)
            
            optimizer.zero_grad()
            pred_B_complex, B_complex = model(A_batch, B_batch)

            b_loss = angular_loss(pred_B_complex, B_complex)

            # Regularization term
            reg_loss = parabolic_regularization(model.guessed_secret)
            
            # Combine losses
            total_loss = compute_total_loss(b_loss, reg_loss, step, len(dataloader))
            
            total_loss.backward()
            optimizer.step()
            step_loss += total_loss.item()

            if step % log_interval == 0:
                avg_loss = step_loss / log_interval
                step_loss = 0.0

                # Calculate loss between the secret and s_hat
                s_loss = criterion(model.guessed_secret, secret)
                print(f"Epoch [{epoch}/{n_epochs}], Step [{step}/{len(dataloader)}], "
                      f"Loss: {avg_loss:.4f}, Secret Loss: {s_loss.item():.4f}, "
                      f"Regularization Loss: {reg_loss.item():.4f}")

        
    
    return model

In [90]:
train_model(model, lwe_dataloader, log_interval=500, n_epochs=10, lr=0.01)

Epoch [1/10], Step [0/12500], Loss: 0.0109, Secret Loss: 0.5656, Regularization Loss: 0.1455
Epoch [1/10], Step [100/12500], Loss: 1.0517, Secret Loss: 0.5873, Regularization Loss: 0.1306
Epoch [1/10], Step [200/12500], Loss: 1.0490, Secret Loss: 0.5946, Regularization Loss: 0.1190
Epoch [1/10], Step [300/12500], Loss: 1.0722, Secret Loss: 0.5135, Regularization Loss: 0.0967
Epoch [1/10], Step [400/12500], Loss: 1.0618, Secret Loss: 0.4024, Regularization Loss: 0.1265
Epoch [1/10], Step [500/12500], Loss: 1.0277, Secret Loss: 0.4229, Regularization Loss: 0.0851
Epoch [1/10], Step [600/12500], Loss: 1.0482, Secret Loss: 0.4145, Regularization Loss: 0.0778
Epoch [1/10], Step [700/12500], Loss: 1.0859, Secret Loss: 0.3978, Regularization Loss: 0.0795
Epoch [1/10], Step [800/12500], Loss: 1.0607, Secret Loss: 0.4068, Regularization Loss: 0.0638
Epoch [1/10], Step [900/12500], Loss: 1.0857, Secret Loss: 0.3921, Regularization Loss: 0.0752
Epoch [1/10], Step [1000/12500], Loss: 1.0515, Secre

KeyboardInterrupt: 

In [None]:
# Check the guessed secret
guessed_secret = model.guessed_secret.detach().cpu().numpy()
print("Guessed secret:", guessed_secret)
print("Actual secret:", lwe_dataset.get_secret().numpy())

Guessed secret: [0.28077018 0.6506888  0.11807197 1.0863562  0.14392364 0.4651821
 0.267927   0.08936887]
Actual secret: [0. 0. 0. 1. 0. 0. 0. 0.]
