In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GINConv
import random
import numpy as np
import copy
import gc
from torch_geometric.loader import DataLoader

# The specific imports for Graph Neural Networks
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv

# --- 1. RAAG DEFINITION ---
# Commutation graph of the right-angled Artin group, indexed by absolute letter
# value: generator 2k commutes with generator 2k + 1 (k = 0 .. 999) and no other
# pair commutes. Entry [a][b] == 1 means generators a and b commute.
RAAG_ADJ = np.zeros((2001, 2001))
RAAG_ADJ[np.arange(0, 2000, 2), np.arange(1, 2001, 2)] = 1  # edges 2k -> 2k+1
RAAG_ADJ = RAAG_ADJ + RAAG_ADJ.T  # commutation is symmetric

class RAAGWord:
    """A word in a right-angled Artin group, kept in reduced form.

    Letters are integers; ``-g`` denotes the inverse of generator ``g``.
    ``adj_matrix[a][b]`` truthy means generators ``a`` and ``b`` commute.
    NOTE(review): since ``0 == -0``, a letter 0 acts as its own inverse here;
    callers that generate letter 0 should confirm that is intended.
    """

    # Maximum number of letters kept AFTER reduction (memory guard against
    # unbounded word growth under repeated multiplication).
    MAX_LEN = 60

    def __init__(self, letters, adj_matrix=None):
        """Build a word from ``letters``.

        Args:
            letters: iterable of int letters.
            adj_matrix: commutation matrix indexed by absolute letter value;
                defaults to the module-level ``RAAG_ADJ`` (looked up at call time).
        """
        self.adj_matrix = RAAG_ADJ if adj_matrix is None else adj_matrix
        self.letters = self.reduce(letters)

    def reduce(self, lts):
        """Return the reduced form of ``lts``, truncated to ``MAX_LEN`` letters.

        Each incoming letter is moved left past already-placed letters it
        commutes with; if it meets its inverse the two cancel, otherwise it is
        appended at the end.
        """
        reduced = []
        for letter in lts:
            pos = len(reduced) - 1
            while pos >= 0:
                if reduced[pos] == -letter:
                    del reduced[pos]  # cancels with its inverse
                    break
                if not self.adj_matrix[abs(letter)][abs(reduced[pos])]:
                    reduced.append(letter)  # blocked by a non-commuting letter
                    break
                pos -= 1
            else:
                # Commutes with everything placed so far (or nothing placed yet).
                reduced.append(letter)
        # Cap only after reducing: truncating the raw input first would drop
        # letters that should cancel (e.g. w * w.inv() for a long w) and thus
        # change which group element the word represents.
        return reduced[: self.MAX_LEN]

    def __mul__(self, other):
        """Concatenate two words (sharing this word's adjacency matrix) and reduce."""
        return RAAGWord(self.letters + other.letters, self.adj_matrix)

    def inv(self):
        """Return the inverse word: letters reversed and negated."""
        return RAAGWord([-letter for letter in reversed(self.letters)], self.adj_matrix)

# --- 2. ALGEBRAIC MOVES & FAN TOPOLOGY ---
def apply_nielsen_moves_fixed(subgroup_basis, num_moves=10):
    """Scramble a basis of word pairs with random Nielsen moves.

    Each move picks two distinct positions ``i != j`` and replaces pair ``i``
    by either ``(a_i * a_j, b_i * b_j)`` or ``(a_i * a_j^-1, b_i * b_j^-1)``,
    i.e. both coordinates are transformed in lockstep.

    Args:
        subgroup_basis: sequence of pairs whose elements support ``*`` and ``.inv()``.
        num_moves: number of random moves to apply.

    Returns:
        A new list; ``subgroup_basis`` itself is not modified. With fewer than
        two pairs no Nielsen move exists, so an unchanged copy is returned.
    """
    scrambled = list(subgroup_basis)
    rank = len(scrambled)
    if rank < 2:
        # random.sample(range(rank), 2) would raise ValueError here.
        return scrambled
    for _ in range(num_moves):
        i, j = random.sample(range(rank), 2)
        a_j, b_j = scrambled[j][0], scrambled[j][1]
        if random.choice(['mul', 'inv_mul']) == 'inv_mul':
            a_j, b_j = a_j.inv(), b_j.inv()
        scrambled[i] = (scrambled[i][0] * a_j, scrambled[i][1] * b_j)
    return scrambled

def subgroup_to_graph(subgroup, num_hubs=5):
    """FAN TOPOLOGY: connect every letter of a generator pair directly to its hub.

    Nodes ``0 .. num_hubs - 1`` are hubs with feature ``501 + i``. Pair ``i``
    of ``subgroup`` gets one node per letter of ``w_a.letters + w_b.letters``
    (feature = the letter value), each linked to hub ``i`` in both directions.
    An empty pair is represented by a single letter ``0``.
    NOTE(review): hub features 501..505 and the placeholder 0 can coincide
    with real letter values; confirm that is acceptable.

    Args:
        subgroup: sequence of ``(w_a, w_b)`` pairs exposing ``.letters``.
        num_hubs: number of hub nodes placed first in the graph.

    Returns:
        ``Data`` with ``x`` of shape ``(num_nodes, 1)`` and ``edge_index`` of
        shape ``(2, num_edges)``.

    Raises:
        ValueError: if ``subgroup`` has more pairs than hubs. Otherwise the
            extra pairs would be wired to letter nodes instead of hubs.
    """
    if len(subgroup) > num_hubs:
        raise ValueError(
            f"subgroup has {len(subgroup)} generators but only {num_hubs} hub nodes"
        )
    x = [[float(501 + i)] for i in range(num_hubs)]
    edges = []
    for hub_idx, (w_a, w_b) in enumerate(subgroup):
        full_word = (w_a.letters + w_b.letters) or [0]
        for letter in full_word:
            node_idx = len(x)
            x.append([float(letter)])
            edges.append([node_idx, hub_idx])
            edges.append([hub_idx, node_idx])

    # reshape(-1, 2) keeps an empty edge list at shape (2, 0) after transposing.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    return Data(x=torch.tensor(x, dtype=torch.float), edge_index=edge_index)


# --- 4. KERNEL-SAFE DATA GENERATION ---
def triplet_generator(batch_size, nielsen_moves):
    """LAZY GENERATOR: yield (anchor, positive, negative) batches forever.

    Only one batch of triplets is built at a time. For each sample a base
    subgroup of 5 pairs ``(shift + i, shift + i + 100)`` is drawn with
    ``shift`` in [-400, 400]:

    * anchor   - the base lightly scrambled with 1-5 Nielsen moves;
    * positive - the base scrambled with ``nielsen_moves`` moves;
    * negative - the anchor with the first word of one random pair replaced
      by a single random letter ("surgical" hard negative).

    Yields:
        A tuple of three ``Batch`` objects: anchors, positives, negatives.
    """
    while True:
        anchors, positives, negatives = [], [], []
        for _ in range(batch_size):
            shift = random.randint(-400, 400)
            base_sub = [(RAAGWord([shift + i]), RAAGWord([shift + i + 100])) for i in range(5)]

            anchor = apply_nielsen_moves_fixed(base_sub, num_moves=random.randint(1, 5))
            pos = apply_nielsen_moves_fixed(base_sub, num_moves=nielsen_moves)

            # Surgical Negative: change one generator only.
            # A shallow copy suffices: we replace a whole (immutable) tuple and
            # nothing here mutates RAAGWord objects. copy.deepcopy would also
            # duplicate each word's shared adj_matrix (2001x2001 float64,
            # ~32 MB) on every sample.
            neg = list(anchor)
            idx = random.randint(0, 4)
            neg[idx] = (RAAGWord([random.randint(-1000, 1000)]), neg[idx][1])

            anchors.append(subgroup_to_graph(anchor))
            positives.append(subgroup_to_graph(pos))
            negatives.append(subgroup_to_graph(neg))

        yield (Batch.from_data_list(anchors),
               Batch.from_data_list(positives),
               Batch.from_data_list(negatives))

def generate_triplets_hardened_v2(num_samples, nielsen_moves, surgical_prob=0.5):
    """
    Generate ``num_samples`` (anchor, positive, negative) graph triplets.

    For each sample a base subgroup of 5 pairs ``(shift + i, shift + i + 100)``
    is drawn with ``shift`` in [-600, 600]:

    * anchor   - the base lightly scrambled with 1-5 Nielsen moves, so it is
      not "perfect";
    * positive - the base scrambled with ``nielsen_moves`` moves;
    * negative - with probability ``surgical_prob`` a 'Surgical' (hard)
      negative: the anchor with the first word of one random pair replaced
      by a single random letter in [-1000, 1000]. Otherwise an 'Alphabet'
      (easy) negative: a fresh base subgroup whose shift differs from
      ``shift`` by more than 200, scrambled with 1-5 moves.

    Returns:
        A list of ``num_samples`` tuples of graphs built by ``subgroup_to_graph``.
    """
    triplets = []
    for _ in range(num_samples):
        # 1. Base subgroup; generators A and B are 100 apart in the alphabet.
        shift = random.randint(-600, 600)
        base_sub = [(RAAGWord([shift + i]), RAAGWord([shift + i + 100])) for i in range(5)]

        # 2. Positive pair: light scramble for the anchor, full scramble for the positive.
        anchor_sub = apply_nielsen_moves_fixed(base_sub, num_moves=random.randint(1, 5))
        pos_sub = apply_nielsen_moves_fixed(base_sub, num_moves=nielsen_moves)

        # 3. Negative.
        if random.random() < surgical_prob:
            # HARD NEGATIVE: replace one generator of the anchor with a random letter.
            # A shallow copy suffices: we replace a whole (immutable) tuple and
            # nothing here mutates RAAGWord objects. copy.deepcopy would also
            # duplicate each word's shared adj_matrix (2001x2001 float64,
            # ~32 MB) on every sample.
            neg_sub = list(anchor_sub)
            idx = random.randint(0, 4)
            neg_sub[idx] = (RAAGWord([random.randint(-1000, 1000)]), neg_sub[idx][1])
        else:
            # EASY NEGATIVE: a base subgroup over an alphabet range far from the original.
            alt_shift = random.choice([s for s in range(-600, 600) if abs(s - shift) > 200])
            neg_base = [(RAAGWord([alt_shift + i]), RAAGWord([alt_shift + i + 100])) for i in range(5)]
            neg_sub = apply_nielsen_moves_fixed(neg_base, num_moves=random.randint(1, 5))

        triplets.append((
            subgroup_to_graph(anchor_sub),
            subgroup_to_graph(pos_sub),
            subgroup_to_graph(neg_sub),
        ))
    return triplets

# --- REFINED MODEL FOR NUMERICAL REASONING ---
class UniversalHubGNN(nn.Module):
    """GIN encoder mapping a hub-and-spoke subgroup graph to a unit-norm embedding.

    The first ``num_hubs`` nodes of every graph are treated as its hub nodes.
    Their final node states are concatenated and projected to ``out_dim``,
    then L2-normalised.
    """

    def __init__(self, hidden=256, num_layers=4, num_hubs=5, out_dim=64):
        super().__init__()
        self.num_hubs = num_hubs
        # Linear encoder treats node IDs as coordinates, not just labels
        self.node_encoder = nn.Linear(1, hidden)
        self.lin_deg = nn.Linear(1, hidden)
        self.convs = nn.ModuleList([
            GINConv(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)))
            for _ in range(num_layers)
        ])
        self.lns = nn.ModuleList([nn.LayerNorm(hidden) for _ in range(num_layers)])
        self.fc = nn.Sequential(
            nn.Linear(num_hubs * hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, data):
        """Alias for :meth:`forward_one` so the module can be called directly."""
        return self.forward_one(data)

    def forward_one(self, data):
        """Embed a single graph (``Data``) or a batch of graphs (``Batch``).

        ``data.x`` must have a trailing dimension of 1 (``node_encoder`` is
        ``nn.Linear(1, hidden)``).

        Returns:
            Tensor of shape ``(num_graphs, out_dim)`` with L2-normalised rows.
            ``num_graphs`` is 1 when ``data`` has no ``ptr``.
        """
        # Normalize node values to help the Linear layer see the 'Alphabet'
        x = self.node_encoder(data.x / 1000.0)

        # Out-degree per node; equals the degree when edges are stored in both directions.
        row, _ = data.edge_index
        deg = torch.zeros((data.x.size(0), 1), device=data.x.device)
        deg.scatter_add_(0, row.unsqueeze(1), torch.ones((row.size(0), 1), device=data.x.device))
        x = x + self.lin_deg(deg)

        for conv, ln in zip(self.convs, self.lns):
            h = conv(x, data.edge_index)
            x = ln(F.relu(h) + x)  # residual + LayerNorm, no dropout

        # Hub pooling: gather the first num_hubs nodes of every graph with a
        # single indexing op instead of a Python loop over graphs.
        offsets = torch.arange(self.num_hubs, device=x.device)
        ptr = getattr(data, 'ptr', None)
        if ptr is not None:
            hub_rows = ptr[:-1].unsqueeze(1) + offsets  # (num_graphs, num_hubs)
        else:
            hub_rows = offsets.unsqueeze(0)  # single graph: (1, num_hubs)
        hub_embeddings = x[hub_rows].reshape(hub_rows.size(0), -1)

        return F.normalize(self.fc(hub_embeddings), p=2, dim=1)

# --- SPEED-OPTIMIZED TRAINING ---
# Fall back to CPU so this cell also runs on machines without a CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UniversalHubGNN().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
criterion = nn.TripletMarginLoss(margin=1.0)

# Generate a fixed buffer of triplets ONCE; every epoch re-iterates these
# same 2048 samples (nothing below refreshes them).
print("Generating Initial Buffer...")
train_buffer = generate_triplets_hardened_v2(2048, nielsen_moves=15)
# batch_size=128 is a safer fit for a 6 GB GPU than 512.
loader = DataLoader(train_buffer, batch_size=128, shuffle=True)

for epoch in range(1000):
    model.train()
    total_loss = 0.0

    for a, p, n in loader:
        optimizer.zero_grad(set_to_none=True)

        ea = model.forward_one(a.to(device))
        ep = model.forward_one(p.to(device))
        en = model.forward_one(n.to(device))

        loss = criterion(ea, ep, en)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if epoch % 5 == 0:
        print(f"Epoch {epoch} | Avg Loss: {total_loss/len(loader):.4f}")
        # TODO: the buffer is never refreshed; regenerate part of
        # train_buffer here if training should see new samples.


Generating Initial Buffer...
Epoch 0 | Avg Loss: 1.0731
Epoch 5 | Avg Loss: 0.9376
Epoch 10 | Avg Loss: 0.8802
Epoch 15 | Avg Loss: 0.8065
Epoch 20 | Avg Loss: 0.7299
Epoch 25 | Avg Loss: 0.5655
Epoch 30 | Avg Loss: 0.4641
Epoch 35 | Avg Loss: 0.4156
Epoch 40 | Avg Loss: 0.3521
Epoch 45 | Avg Loss: 0.3070
Epoch 50 | Avg Loss: 0.2988
Epoch 55 | Avg Loss: 0.2998
Epoch 60 | Avg Loss: 0.2783
Epoch 65 | Avg Loss: 0.2225
Epoch 70 | Avg Loss: 0.2015
Epoch 75 | Avg Loss: 0.1863
Epoch 80 | Avg Loss: 0.1630
Epoch 85 | Avg Loss: 0.1340
Epoch 90 | Avg Loss: 0.1578
Epoch 95 | Avg Loss: 0.1348
Epoch 100 | Avg Loss: 0.1234
Epoch 105 | Avg Loss: 0.0942
Epoch 110 | Avg Loss: 0.0866
Epoch 115 | Avg Loss: 0.1001
Epoch 120 | Avg Loss: 0.1005
Epoch 125 | Avg Loss: 0.0716
Epoch 130 | Avg Loss: 0.0768
Epoch 135 | Avg Loss: 0.0872
Epoch 140 | Avg Loss: 0.0921
Epoch 145 | Avg Loss: 0.0954
Epoch 150 | Avg Loss: 0.0649
Epoch 155 | Avg Loss: 0.0667
Epoch 160 | Avg Loss: 0.0683
Epoch 165 | Avg Loss: 0.0521
Epoch 1

In [19]:
def test_vault_crack(model, device, num_decoys=99, nielsen_moves=20):
    """Check whether ``model`` can find a scrambled secret subgroup among decoys.

    An unscrambled 'secret' base subgroup is embedded as the anchor. The vault
    holds ``num_decoys`` decoys plus one copy of the secret at a random slot.
    Each decoy is a base subgroup whose shift is more than 150 away from the
    secret's, lightly scrambled with 1-10 Nielsen moves. The secret copy is
    scrambled with ``nielsen_moves`` moves. The vault item with the highest
    cosine similarity to the anchor is the prediction.

    Prints the outcome and returns True iff the prediction is the secret's slot.
    Note: puts ``model`` in eval mode and leaves it there.
    """
    model.eval()

    # 1. Create the 'Secret' (the anchor), unscrambled.
    secret_shift = random.randint(-400, 400)
    secret_basis = [(RAAGWord([secret_shift + i]), RAAGWord([secret_shift + i + 100])) for i in range(5)]
    anchor_graph = subgroup_to_graph(secret_basis).to(device)

    with torch.no_grad():
        anchor_emb = model.forward_one(anchor_graph)

    # 2. Set up the vault.
    correct_idx = random.randint(0, num_decoys)
    print(f"🔓 Hiding secret in Vault Slot: {correct_idx}...")

    # Candidate decoy shifts do not depend on the loop, so build the list once.
    decoy_shifts = [s for s in range(-600, 600) if abs(s - secret_shift) > 150]
    vault_graphs = []
    for i in range(num_decoys + 1):
        if i == correct_idx:
            # The true secret, heavily scrambled.
            sub = apply_nielsen_moves_fixed(secret_basis, num_moves=nielsen_moves)
        else:
            # Decoy: different alphabet, different algebra.
            r_shift = random.choice(decoy_shifts)
            sub = [(RAAGWord([r_shift + j]), RAAGWord([r_shift + j + 100])) for j in range(5)]
            sub = apply_nielsen_moves_fixed(sub, num_moves=random.randint(1, 10))
        vault_graphs.append(subgroup_to_graph(sub))

    # 3. Embed the vault in batches.
    with torch.no_grad():
        vault_embs = torch.cat(
            [model.forward_one(batch.to(device)) for batch in DataLoader(vault_graphs, batch_size=32)],
            dim=0,
        )

    # 4. Cosine similarity of the anchor 'fingerprint' to every vault item.
    similarities = F.cosine_similarity(anchor_emb, vault_embs).cpu().numpy()
    predicted_idx = int(np.argmax(similarities))

    # 5. Results.
    print("\n--- CRACK RESULT ---")
    print(f"Top Similarity: {similarities[predicted_idx]:.4f}")
    print(f"Correct Index: {correct_idx} | Predicted Index: {predicted_idx}")

    success = predicted_idx == correct_idx
    if success:
        print("✅ SUCCESS: The GNN cracked the vault!")
    else:
        ranking = list(np.argsort(similarities)[::-1])  # best match first
        rank = ranking.index(correct_idx) + 1
        print(f"❌ FAILURE: The secret was ranked #{rank} out of {num_decoys+1}")

    return success

# Run the vault test several times to estimate retrieval accuracy (%).
trials = 10
success_count = 0

for t in range(trials):
    print(f"\nTrial {t+1}/{trials}")
    if test_vault_crack(model, device, num_decoys=99, nielsen_moves=20):
        success_count += 1

print(f"\n🏆 Final Accuracy: {success_count/trials * 100}%")


Trial 1/10
üîì Hiding secret in Vault Slot: 27...

--- CRACK RESULT ---
Top Similarity: 0.9979
Correct Index: 27 | Predicted Index: 27
✅ SUCCESS: The GNN cracked the vault!

Trial 2/10
üîì Hiding secret in Vault Slot: 4...

--- CRACK RESULT ---
Top Similarity: 0.9885
Correct Index: 4 | Predicted Index: 49
❌ FAILURE: The secret was ranked #4 out of 100

Trial 3/10
üîì Hiding secret in Vault Slot: 71...

--- CRACK RESULT ---
Top Similarity: 0.9592
Correct Index: 71 | Predicted Index: 11
❌ FAILURE: The secret was ranked #2 out of 100

Trial 4/10
üîì Hiding secret in Vault Slot: 48...

--- CRACK RESULT ---
Top Similarity: 0.9973
Correct Index: 48 | Predicted Index: 48
✅ SUCCESS: The GNN cracked the vault!

Trial 5/10
üîì Hiding secret in Vault Slot: 66...

--- CRACK RESULT ---
Top Similarity: 0.9822
Correct Index: 66 | Predicted Index: 66
✅ SUCCESS: The GNN cracked the vault!

Trial 6/10
üîì Hiding secret in Vault Slot: 33...

--- CRACK RESULT ---
Top Similarity: 0.9549
Co