# Training Notebook

In [55]:
import sys
import os

# Add the parent directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from ml_attack import *

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from sklearn.metrics import confusion_matrix, classification_report


## Dataset creation

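Training debug: define a linear model whose weight vector is the guessed secret, a full-batch training loop, and then generate the Module-LWE dataset.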

In [56]:
# Linear solver model for Module-LWE: predicts b by contracting each sample of A with a
# learnable guess of the flattened secret (length n * k).
class LinearComplex(nn.Module):
    """Linear model whose weight vector ``guessed_secret`` is the current secret estimate.

    Args:
        params: configuration dict. Reads 'q', 'n', 'k' and 'secret_type', and 'hw'
            when present (defaulting to -1); these are forwarded to
            ``get_vector_distribution`` to choose the initialisation mean and std.
    """

    def __init__(self, params):
        super().__init__()
        self.q, self.n, self.k = params['q'], params['n'], params['k']
        self.secret_type = params['secret_type']

        secret_len = self.n * self.k
        hamming_weight = params.get('hw', -1)
        # Only the first and third returned values (mean, std) are used here.
        mean_s, _, std_s = get_vector_distribution(params, self.secret_type, hamming_weight)

        # Start the guess from a normal distribution matching the secret's statistics.
        self.guessed_secret = nn.Parameter(
            nn.init.normal_(torch.empty(secret_len, dtype=torch.float), mean=mean_s, std=std_s)
        )

        # NOTE(review): C is registered as a trainable parameter but forward() never reads it.
        self.C = nn.Parameter(
            nn.init.normal_(torch.empty(secret_len, dtype=torch.float), mean=0, std=1)
        )

    def forward(self, A_batch):
        """Contract the last axis of ``A_batch`` with ``guessed_secret`` (i.e. A @ s)."""
        return torch.tensordot(A_batch, self.guessed_secret, dims=1)

In [57]:
def train_model(model, dataset, params, n_epochs=10, lr=0.01, check_every=10, print_every=1):
    """Fit ``model`` to the samples of ``dataset`` with full-batch Adam on an MSE loss.

    Every ``check_every`` epochs the rounded ``model.guessed_secret`` is passed to
    ``check_secret``; training stops early as soon as it reports success.

    Args:
        model: nn.Module mapping A to predicted b; must expose ``guessed_secret``.
        dataset: object providing ``get_A()`` and ``get_B()``.
        params: parameter dict, forwarded unchanged to ``check_secret``.
        n_epochs: maximum number of epochs (one optimiser step per epoch).
        lr: Adam learning rate.
        check_every: epochs between secret checks; 0 or None disables checking.
        print_every: epochs between loss log lines; 0 or None disables logging.
            The default of 1 logs every epoch, as before.

    Returns:
        The same ``model`` instance, trained in place.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # Fetch the samples once: the numpy arrays feed check_secret, the tensors feed training.
    A_np = dataset.get_A()
    b_np = dataset.get_B()
    A = torch.tensor(A_np, dtype=torch.float, device=device)
    b = torch.tensor(b_np, dtype=torch.float, device=device)

    for epoch in range(1, n_epochs + 1):
        optimizer.zero_grad()
        pred_b = model(A)
        b_loss = loss_fn(pred_b, b)
        b_loss.backward()
        optimizer.step()

        # Logged value is the loss *before* this epoch's optimiser step.
        if print_every and epoch % print_every == 0:
            print(f"Epoch {epoch}/{n_epochs}, Loss: {b_loss.item():.4f}")

        # Guard against check_every=0, which would otherwise raise ZeroDivisionError.
        if check_every and epoch % check_every == 0:
            with torch.no_grad():
                guessed_secret = model.guessed_secret.round().cpu().numpy()
            if check_secret(guessed_secret, A_np, b_np, params):
                print(f"Secret guessed correctly at epoch {epoch}!")
                break

    return model

In [72]:
import random

# Seed every RNG that dataset generation and model initialisation might draw from, so that
# Restart & Run All reproduces the same secret, samples and initial weights.
# NOTE(review): which RNG LWEDataset actually uses is not visible here; seeding all three is a precaution.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Module-LWE parameters: ring dimension n, modulus q, module rank k (secret length n * k).
params = {
    'n': 256,
    'q': 3329,
    'k': 4,
    'secret_type': 'cbd',
    'error_type': 'cbd',
    'eta': 2,
    # NOTE(review): presumably disables reduction of b mod q inside LWEDataset — confirm.
    'mod_q': False,
}

lwe_dataset = LWEDataset(params)
lwe_dataset.initialize()

In [73]:
# Gradient-descent attack: fit the linear model's guessed secret to the (A, b) samples.
# NOTE(review): in the logged run below the loss swings by orders of magnitude (~1e8 to ~1e12)
# instead of decreasing, suggesting lr=1 is too large for Adam here.
model = LinearComplex(params)

train_model(
    model,
    lwe_dataset,
    params,
    n_epochs=20000,
    lr=1,
    check_every=10,
)

Epoch 1/20000, Loss: 6824699904.0000
Epoch 2/20000, Loss: 851795574784.0000
Epoch 3/20000, Loss: 58348142592.0000
Epoch 4/20000, Loss: 190285938688.0000
Epoch 5/20000, Loss: 485796085760.0000
Epoch 6/20000, Loss: 342155427840.0000
Epoch 7/20000, Loss: 80125345792.0000
Epoch 8/20000, Loss: 8454282752.0000
Epoch 9/20000, Loss: 137255878656.0000
Epoch 10/20000, Loss: 246357442560.0000
Epoch 11/20000, Loss: 203243945984.0000
Epoch 12/20000, Loss: 78626504704.0000
Epoch 13/20000, Loss: 4036806656.0000
Epoch 14/20000, Loss: 34003214336.0000
Epoch 15/20000, Loss: 108655624192.0000
Epoch 16/20000, Loss: 134597369856.0000
Epoch 17/20000, Loss: 88845467648.0000
Epoch 18/20000, Loss: 24463495168.0000
Epoch 19/20000, Loss: 2375588608.0000
Epoch 20/20000, Loss: 31982256128.0000
Epoch 21/20000, Loss: 69662031872.0000
Epoch 22/20000, Loss: 71592394752.0000
Epoch 23/20000, Loss: 38659407872.0000
Epoch 24/20000, Loss: 6889768960.0000
Epoch 25/20000, Loss: 5246318592.0000
Epoch 26/20000, Loss: 274828288

LinearComplex()

In [74]:
# Closed-form attack: solve the least-squares problem min ||A s - b|| for s.
# With mod_q=False (presumably no mod-q wrap-around in b) the solution lands close to the
# integer secret, as the printed raw values in the next cell show.
A = lwe_dataset.get_A()
b = lwe_dataset.get_B()

# lstsq solves the system directly via SVD instead of materialising the (n*k x m)
# pseudo-inverse and then multiplying; same minimum-norm least-squares solution.
raw_guessed_secret, *_ = np.linalg.lstsq(A, b, rcond=None)
raw_guessed_secret = torch.tensor(raw_guessed_secret, dtype=torch.float)

In [75]:
# Round the raw estimate and map both secrets to centred representatives (values above
# q // 2 become negative).
# Alternative source, from the gradient-descent model:
#     raw_guessed_secret = model.guessed_secret.detach().cpu()
guessed_secret = raw_guessed_secret.round().numpy()
# -0.0 == 0.0, so this rewrites every zero as +0.0, dropping the sign of negative zeros.
guessed_secret[guessed_secret == -0.0] = 0.0
guessed_secret[guessed_secret > params['q'] // 2] -= params['q']

# Copy before centring: get_secret() may return the dataset's own array, and the in-place
# subtraction below would otherwise modify it and make this cell non-idempotent.
real_secret = np.array(lwe_dataset.get_secret(), copy=True)
real_secret[real_secret > params['q'] // 2] -= params['q']

print("Raw Guessed secret:", raw_guessed_secret)
print("Guessed secret:", guessed_secret)
print("Actual secret:", real_secret)

Raw Guessed secret: tensor([ 1.0398,  0.0269,  0.0503,  ...,  0.9741, -0.9928, -0.0079])
Guessed secret: [ 1.  0.  0. ...  1. -1.  0.]
Actual secret: [ 1  0  0 ...  1 -1  0]


In [76]:
# Inspect the coefficients where the rounded guess disagrees with the real secret.
diff = guessed_secret - real_secret
mismatch = diff != 0

# Raw (unrounded, centred) estimates at the mismatched positions.
raw_diff = raw_guessed_secret[torch.from_numpy(mismatch)]
raw_diff[raw_diff > params['q'] // 2] -= params['q']

if mismatch.any():
    print(f"Mismatched coefficients: {int(mismatch.sum())}")
    print("Raw guessed values at mismatches:", raw_diff)
    print("real_secret:", real_secret[mismatch])
    print("guessed_secret:", guessed_secret[mismatch])
else:
    print("All coefficients match.")

In [77]:
from itertools import product

# A coefficient is "certain" when its raw estimate is within UNCERTAINTY_TOL of an integer.
UNCERTAINTY_TOL = 0.4
close_to_integer = torch.abs(raw_guessed_secret - torch.round(raw_guessed_secret)) < UNCERTAINTY_TOL
uncertain_indices = torch.where(~close_to_integer)[0]
uncertain_count = uncertain_indices.numel()
print("Number of uncertain values:", uncertain_count)

# Centre the raw estimates, then clamp them to the secret's support [-eta, eta] instead of
# dropping out-of-range ones, so every entry stays aligned with uncertain_indices.
raw_uncertain_secret = raw_guessed_secret[uncertain_indices]
raw_uncertain_secret[raw_uncertain_secret > params['q'] // 2] -= params['q']
raw_uncertain_secret = raw_uncertain_secret.clamp(-params['eta'], params['eta'])

lower_values = torch.floor(raw_uncertain_secret).long()
upper_values = torch.ceil(raw_uncertain_secret).long()

# Each index contributes 2 candidates, or 1 when clamping collapsed floor and ceil onto the
# same bound. Python int exponentiation avoids int64 overflow for large counts.
brute_force_attempts = 2 ** int((lower_values != upper_values).sum())
print("Number of brute force attempts required:", brute_force_attempts)

real_uncertain_secret = real_secret[uncertain_indices.numpy()]
print("Real uncertain secret:", real_uncertain_secret)

# Lazy iterator over every candidate assignment, ordered like uncertain_indices.
# TODO: write each assignment into guessed_secret and test it with check_secret.
candidate_assignments = product(*zip(lower_values.tolist(), upper_values.tolist()))


Number of uncertain values: 0
Number of brute force attempts required: 1
Real uncertain secret: []


In [78]:
def report(real_secret, guessed_secret):
    """Print a confusion matrix, a classification report and MSE/MAE for a secret guess.

    Args:
        real_secret: array of true secret coefficients.
        guessed_secret: array of guessed coefficients, same length as ``real_secret``.

    The confusion-matrix "Accuracy" column is each row's diagonal count divided by the
    row total (per-class recall, since rows are true labels); empty rows show 0.0%.
    """
    # Sorted union of values from both vectors, so every observed value gets a row and column.
    labels = np.unique(np.concatenate((real_secret, guessed_secret)))
    matrix = confusion_matrix(real_secret, guessed_secret, labels=labels)

    header = "       |" + "".join(f"{lab:>6}" for lab in labels) + " | Accuracy"
    print("Confusion Matrix:")
    print(header)
    print("-" * len(header))

    for idx, (lab, counts) in enumerate(zip(labels, matrix)):
        row_total = counts.sum()
        row_acc = counts[idx] / row_total if row_total > 0 else 0.0
        cells = "".join(f"{count:6}" for count in counts)
        print(f"{lab:>6} |{cells} | {row_acc:4.1%}")

    print("\nClassification Report:")
    print(classification_report(real_secret, guessed_secret, zero_division=0))

    errors = real_secret - guessed_secret
    print(f"Mean Squared Error (MSE): {np.mean(errors ** 2):.4f}")
    print(f"Mean Absolute Error (MAE): {np.mean(np.abs(errors)):.4f}")



report(real_secret=real_secret, guessed_secret=guessed_secret)

Confusion Matrix:
       |  -2.0  -1.0   0.0   1.0   2.0 | Accuracy
-------------------------------------------------
  -2.0 |    64     0     0     0     0 | 100.0%
  -1.0 |     0   255     0     0     0 | 100.0%
   0.0 |     0     0   382     0     0 | 100.0%
   1.0 |     0     0     0   248     0 | 100.0%
   2.0 |     0     0     0     0    75 | 100.0%

Classification Report:
              precision    recall  f1-score   support

          -2       1.00      1.00      1.00        64
          -1       1.00      1.00      1.00       255
           0       1.00      1.00      1.00       382
           1       1.00      1.00      1.00       248
           2       1.00      1.00      1.00        75

    accuracy                           1.00      1024
   macro avg       1.00      1.00      1.00      1024
weighted avg       1.00      1.00      1.00      1024

Mean Squared Error (MSE): 0.0000
Mean Absolute Error (MAE): 0.0000
