<a href="https://colab.research.google.com/github/iamaidenok/adversarial-quantum-cryptography/blob/main/Full.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys
sys.path.append("..")

from src.bb84 import BB84Core
from src.defender import Defender
from src.adversary import Adversary

# Adversarial training loop.
# Each round: the Defender maps random noise to per-qubit tweak angles, the
# BB84 simulator runs with those angles, and after sifting the Adversary
# ("Eve") predicts ak[i] from the features [bk[i], basis-is-'X' flag, angle].
# Eve minimises BCE on that prediction; the Defender maximises it.

N_BITS = 128       # qubits per protocol run (BB84Core size and sifting range)
STATE_DIM = 16     # noise width fed to the Defender -- assumed to match its input layer
MIN_SIFTED = 10    # skip rounds whose sifted batch is too small to train on
N_EPOCHS = 5000
LOG_EVERY = 100

bb84 = BB84Core(N_BITS)
defender = Defender()
adversary = Adversary()

# Optionally start from pre-trained weights instead of a fresh initialisation:
# defender.load_state_dict(torch.load("src/defender.pth"))
# adversary.load_state_dict(torch.load("src/adversary.pth"))

opt_d = optim.Adam(defender.parameters(), lr=0.001)
opt_a = optim.Adam(adversary.parameters(), lr=0.001)
bce = nn.BCELoss()

eve_accs = []

for epoch in tqdm(range(N_EPOCHS)):
    ak, ab, bb = bb84.generate_keys_bases()
    state = torch.randn(N_BITS, STATE_DIM)

    # Keep the Defender output attached to the autograd graph: the angle
    # feature given to Eve is the only differentiable path from the
    # Defender's parameters to Eve's loss. The simulator itself gets a plain
    # list, so no gradient flows through run_protocol.
    angle_t = defender(state).flatten()
    angles = angle_t.detach().cpu().tolist()

    bk = bb84.run_protocol(ak, ab, bb, tweak_angles=angles)

    # Sifting: keep only positions where both parties chose the same basis.
    keep = [i for i in range(N_BITS) if ab[i] == bb[i]]
    if len(keep) < MIN_SIFTED:
        continue

    idx = torch.tensor(keep)
    X = torch.stack(
        [
            torch.tensor([bk[i] for i in keep], dtype=torch.float32),
            torch.tensor([1.0 if ab[i] == 'X' else 0.0 for i in keep],
                         dtype=torch.float32),
            angle_t[idx].to(torch.float32),
        ],
        dim=1,
    )
    y = torch.tensor([ak[i] for i in keep], dtype=torch.float32)

    # 1) Eve's update, on detached features so it cannot touch the Defender.
    pred = adversary(X.detach())
    loss_a = bce(pred, y)
    opt_a.zero_grad()
    loss_a.backward()
    opt_a.step()

    # 2) Defender's update via a fresh forward pass through the updated Eve.
    #    Re-using the first graph after opt_a.step() (an in-place weight
    #    update) can trip autograd's in-place-modification check. Gradients
    #    this leaves in Eve's .grad are cleared by opt_a.zero_grad() next round.
    loss_d = -bce(adversary(X), y)
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    if epoch % LOG_EVERY == LOG_EVERY - 1:
        with torch.no_grad():
            acc = ((pred > 0.5) == y.bool()).float().mean().item()
        eve_accs.append(100 * acc)
        # tqdm.write keeps the progress bar intact, unlike a bare print.
        tqdm.write(f"Epoch {epoch+1} → Eve accuracy: {100*acc:.1f}%")

# Save final models
torch.save(defender.state_dict(), "src/defender_trained.pth")
torch.save(adversary.state_dict(), "src/adversary_trained.pth")

plt.plot(eve_accs); plt.title("Eve gets weaker over time"); plt.show()