# Training Notebook

In [1]:
import sys
import os

# Add the parent directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from kyber.mlwe import MLWE
from ml_attack import check_secret, clean_secret, get_no_mod, LWEDataset, get_filename_from_params
from ml_attack.utils import get_lwe_default_params, get_reduction_default_params, get_continuous_reduction_default_params, get_default_params, get_b_distribution, get_percentage_true_b, get_true_mask, \
    get_vector_distribution
from ml_attack.train import LinearComplex, train_until_stall

import numpy as np
import torch
import torch.nn as nn

from sklearn.metrics import confusion_matrix, classification_report
from sklearn.linear_model import HuberRegressor

  from .siever_params import SieverParams  # noqa


## Dataset creation

Training debug:

In [4]:
params = get_default_params()
params.update(get_reduction_default_params())
params.update({
    'n': 32,
    'q': 3329,
    'secret_type': 'binary',

    'num_gen': 4,
    'seed': 0,

    'reduction_std': 2,
    'reduction_factor': 0.875,
    'reduction_resampling': True,
    'approximation_std': 3,
    
    'penalty': 4,
    'verbose': True,
    "save_to": "./../data/"
})

filename = get_filename_from_params(params)

#filename = "./../data/data_n_150_k_1_s_binary_ff7b0.pkl"

reload = False
if os.path.exists(filename) and reload:
    print(f"Loading dataset from {filename}")
    dataset = LWEDataset.load_reduced(filename)
    params = dataset.params
else:
    print(f"Generating dataset and saving to {filename}")
    dataset = LWEDataset(params)
    dataset.initialize()
    dataset.reduction()
    dataset.approximate_b()
    dataset.save_reduced()

Generating dataset and saving to ./../data/data_n_32_k_1_s_binary_a081a.pkl
Reducing 39 matrices using 8 threads.
 - Starting std: 1.0381403665259166
Starting new flatter run.
 - Starting std: 1.0479145484710382
Starting new flatter run.
 - Starting std: 1.0256035579943001
Starting new flatter run.
 - Starting std: 1.0623034874305537
Starting new flatter run.
 - Starting std: 1.0424691987224404
Starting new flatter run.
 - Starting std: 1.022070646549283
Starting new flatter run.
 - Starting std: 1.0318361839234185
Starting new flatter run.
 - Starting std: 1.04455331701888
Starting new flatter run.
 - Solvable for: 60 out of 60
Reduction is solvable.
 - Starting std: 1.0401050038409954
Starting new flatter run.
 - Solvable for: 60 out of 60
Reduction is solvable.
 - Starting std: 1.0473287851852748
Starting new flatter run.
 - Solvable for: 60 out of 60
Reduction is solvable.
 - Starting std: 1.0474728674522353
Starting new flatter run.
 - Solvable for: 60 out of 60
Reduction is solva

In [29]:
get_percentage_true_b(dataset, verbose=True)

True B is the best candidate: 2340 / 2340 (100.00%)


np.float64(1.0)

In [7]:
#for idx, value in enumerate(b_real):
#  print(f"Index {idx}: True B = {value}, best_b = {dataset.b_candidates[idx][np.argmax(dataset.b_probs[idx])]}, prob = {np.max(dataset.b_probs[idx]):.4f}")

In [8]:
A_reduced = dataset.get_A()
best_b = np.array(dataset.best_b)

# 1) Train until stall
use_gradient = False
if use_gradient:
    lr = 1e-3
    check_every = 10

    # Check if GPU is available and use it if possible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LinearComplex(params).to(device)
    A_reduced = torch.tensor(A_reduced, dtype=torch.float).to(device)
    best_b = torch.tensor(best_b, dtype=torch.float).to(device)

    epoch = 0
    loss, epoch = train_until_stall(model, A_reduced, best_b, dataset, epoch=epoch)
    if loss == 0:
        print("Secret guessed correctly at epoch {}!".format(epoch))
    else:
        print(f"Stalling detected at loss {loss:.4f}.")

    raw_guessed_secret = model.guessed_secret.detach().cpu()
else:
    model = HuberRegressor(fit_intercept=True, max_iter=10000, alpha=0.0001, epsilon=1.25)
    raw_guessed_secret = model.fit(A_reduced, best_b).coef_

    guessed_secret = clean_secret(raw_guessed_secret, params)
    if check_secret(guessed_secret, dataset.A, dataset.B, params):
        print("Secret guessed correctly!")
    else:
        print("Wrong secret guessed!")

Secret guessed correctly!


In [9]:
real_mask = get_true_mask(dataset)
outlier_mask = model.outliers_
# Compare not mask to outlier_mask to see how many outliers are actual outliers
num_outliers = np.sum(outlier_mask)
num_actual_outliers = np.sum(outlier_mask & (~real_mask))
print(f"Number of outliers detected by model: {num_outliers}")
print(f"Number of actual outliers among detected: {num_actual_outliers}")
if num_outliers > 0:
  print(f"Fraction of detected outliers that are actual: {num_actual_outliers / num_outliers:.2%}")

Number of outliers detected by model: 1117
Number of actual outliers among detected: 0
Fraction of detected outliers that are actual: 0.00%


In [10]:
non_outlier_indices = np.where(~outlier_mask)[0]
get_percentage_true_b(dataset, verbose=True, indices=non_outlier_indices)

True B is the best candidate: 1223 / 1223 (100.00%)


np.float64(1.0)

In [11]:
# Check the guessed secret
raw_guessed_secret = raw_guessed_secret.cpu().detach().numpy() if use_gradient else raw_guessed_secret
guessed_secret = clean_secret(raw_guessed_secret, params)

real_secret = dataset.get_secret()

print("Raw Guessed secret:", raw_guessed_secret)
print("Guessed secret:", guessed_secret)
print("Actual secret:", real_secret)

Raw Guessed secret: [-0.08226637  1.04515821  0.97493014  1.0288899  -0.02086501  0.05166768
  1.06212252 -0.04446577  1.00778634  0.94790364  0.97576549  0.05087982
  0.02434674  0.97663085 -0.04844022  0.00364799  0.94075949  0.93244438
  0.94255363 -0.05538213  0.01427396  0.94886504  0.98788517  1.01927622
  0.00538173  1.01750243  0.03316809  0.94129635 -0.05706733 -0.0177757
  1.02416635  0.00304447]
Guessed secret: [0. 1. 1. 1. 0. 0. 1. 0. 1. 1. 1. 0. 0. 1. 0. 0. 1. 1. 1. 0. 0. 1. 1. 1.
 0. 1. 0. 1. 0. 0. 1. 0.]
Actual secret: [0 1 1 1 0 0 1 0 1 1 1 0 0 1 0 0 1 1 1 0 0 1 1 1 0 1 0 1 0 0 1 0]


In [12]:
# Check the differences between the guessed and actual secret
diff = guessed_secret - real_secret
raw_diff = raw_guessed_secret[diff != 0]
raw_diff[raw_diff > params['q'] // 2] -= params['q']
diff_indices = np.nonzero(diff)
if len(diff[diff != 0]) > 0:
    print("Number of differences:", len(diff[diff != 0]))
    print("Difference:", raw_diff)
    print("real_secret:", real_secret[diff != 0])
    print("guessed_secret:", guessed_secret[diff != 0])
    print("Indices of differences:", diff_indices)

In [13]:
close_to_integer = np.abs(raw_guessed_secret - np.round(raw_guessed_secret))
sorted_indices = np.argsort(-close_to_integer)
print("Sorted uncertain indices:", sorted_indices)
print("Sorted uncertain values:", np.round(close_to_integer[sorted_indices], 3))

if len(diff_indices[0]) > 0:
  diff_indices_in_sorted = [np.where(sorted_indices == i)[0][0] for i in diff_indices[0]]
  print("Worst case scenario:", max(diff_indices_in_sorted))

Sorted uncertain indices: [ 0 17  6 16 27 18 28 19  9  5 21 11 14  1  7 26  3  2 12 10 30 13  4 23
 29 25 20 22  8 24 15 31]
Sorted uncertain values: [0.082 0.068 0.062 0.059 0.059 0.057 0.057 0.055 0.052 0.052 0.051 0.051
 0.048 0.045 0.044 0.033 0.029 0.025 0.024 0.024 0.024 0.023 0.021 0.019
 0.018 0.018 0.014 0.012 0.008 0.005 0.004 0.003]


In [14]:
from itertools import product

# Find values in raw_guessed_secret that are within ±0.1 of an integer
close_to_integer = np.abs(raw_guessed_secret - np.round(raw_guessed_secret)) < 0.2
uncertain_count = np.sum(~close_to_integer)
print("Number of uncertain values:", uncertain_count)

# Calculate the number of brute force attacks to perform
brute_force_attempts = 2 ** uncertain_count
print("Number of brute force attempts required:", brute_force_attempts)

# Get the indices of uncertain values
uncertain_indices = np.where(~close_to_integer)[0]

real_uncertain_secret = real_secret[uncertain_indices]
print("Real uncertain secret:", real_uncertain_secret)

# Perform brute force attack
raw_uncertain_secret = raw_guessed_secret[uncertain_indices]
raw_uncertain_secret[raw_uncertain_secret > params['q'] // 2] -= params['q']
raw_uncertain_secret = raw_uncertain_secret[np.abs(raw_uncertain_secret) <= params['eta']]

lower_values = np.floor(raw_uncertain_secret)
upper_values = np.ceil(raw_uncertain_secret)

#values = product(*zip(lower_values, upper_values))

#for value in values:
#    print("Trying values:", value)
    # Create a copy of the guessed secret
#    brute_force_secret = copy.deepcopy(guessed_secret)
    # Update the uncertain values with the current combination
#    for idx, val in zip(uncertain_indices, value):
#        brute_force_secret[idx] = val
    # Check if the guessed secret is correct
#    if check_secret(brute_force_secret, dataset.A, dataset.B, params):
#        print("Brute force attack successful! Guessed secret:", brute_force_secret)
#        break

Number of uncertain values: 0
Number of brute force attempts required: 1
Real uncertain secret: []


In [15]:
def report(real_secret, guessed_secret):
    """
    Print classification report and confusion matrix.
    """
  
    # Get unique sorted labels and compute confusion matrix
    labels = np.unique(np.concatenate((real_secret, guessed_secret)))
    cm = confusion_matrix(real_secret, guessed_secret, labels=labels)

    # Header
    header = "       |" + "".join([f"{l:>6}" for l in labels]) + " | Accuracy"
    print("Confusion Matrix:")
    print(header)
    print("-" * len(header))

    # Rows
    for i, row in enumerate(cm):
        label = f"{labels[i]:>6} |"
        values = "".join([f"{v:6}" for v in row])

        correct = row[i]
        total = row.sum()
        acc = correct / total if total > 0 else 0.0
        print(label + values + f" | {acc:4.1%}")

    # Print classification report
    print("\nClassification Report:")
    print(classification_report(real_secret, guessed_secret, zero_division=0))

report(real_secret, guessed_secret)

Confusion Matrix:
       |   0.0   1.0 | Accuracy
-------------------------------
   0.0 |    15     0 | 100.0%
   1.0 |     0    17 | 100.0%

Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        15
           1       1.00      1.00      1.00        17

    accuracy                           1.00        32
   macro avg       1.00      1.00      1.00        32
weighted avg       1.00      1.00      1.00        32

