In [1]:
import duckdb 
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
import logging
from uuid import UUID
import torch.nn as nn
import torch.optim as optim
import json

In [2]:
%load_ext sql
%sql duckdb://

%sql INSTALL vss
%sql LOAD vss

  """
  """


Success


In [3]:
DIMENSIONS = 768
EMBEDDING_MODEL = "BAAI/bge-base-en-v1.5"

# Basic Table Setup

Create the relations table, which will store shuffle masks

In [4]:
%%sql 
CREATE TABLE IF NOT EXISTS relations (
    id TEXT PRIMARY KEY, 
    shuffle_mask FLOAT[{{DIMENSIONS}}],
    projection_weights FLOAT[{{DIMENSIONS}}]
)

Count


Create the entities table, which will store the entity data

In [5]:
%%sql 
CREATE TABLE IF NOT EXISTS entities (
    id TEXT PRIMARY KEY,  
    version INTEGER,
    embedding FLOAT[{{DIMENSIONS}}]
)

Count


In [6]:
%%sql 
CREATE TABLE IF NOT EXISTS adjacency (
    head_id TEXT, 
    tail_id TEXT, 
    relation_id TEXT,
    PRIMARY KEY (head_id, tail_id, relation_id)
)

Count


# Embedding Model Setup

In [7]:
embedding_model = SentenceTransformer(EMBEDDING_MODEL)

# Database Functions


In [8]:
def get_or_create_entity(term: str):
    result = %sql SELECT embedding, version FROM entities WHERE id = '{{term}}'
    if result:
        logging.info(f"Found Result: {term}")
        d = result.dict()
        emb = np.array(d['embedding'][0])
        return torch.from_numpy(emb).clone(), d['version'][0]
    else:
        logging.info(f"Cold Start: Embedding {term}")
        emb = embedding_model.encode(term)
        emb_list = emb.tolist()
        %sql INSERT INTO entities (id, embedding, version) VALUES ('{{term}}', {{emb_list}}, 1)
        return torch.from_numpy(emb).clone(), 1 

In [9]:
def get_or_create_relation_with_weights(rel_id):
    """
    Fetches (Permutation Mask, Projection Weights).
    If new, initializes Mask to Random and Weights to 1.0 (Strict).
    """
    # 1. Check if relation exists
    result = %sql SELECT shuffle_mask, projection_weights FROM relations WHERE id = '{{rel_id}}'
    
    if result:
        d = result.dict()
        mask = np.array(d['shuffle_mask'][0])
        
        # Handle case where weights might be null (legacy data)
        if d.get('projection_weights') and d['projection_weights'][0]:
            weights = np.array(d['projection_weights'][0])
        else:
            weights = np.ones(DIMENSIONS) # Default to 1.0
            
        return (
            torch.from_numpy(mask).long().clone(), 
            torch.from_numpy(weights).float().clone()
        )
    else:
        # 2. Create New
        mask = np.random.permutation(DIMENSIONS) # Random Shuffle
        weights = np.ones(DIMENSIONS)            # Strict Attention (1.0)
        
        mask_list = mask.tolist()
        weights_list = weights.tolist()
        
        # Note: You might need to ALTER TABLE relations ADD COLUMN projection_weights FLOAT[768]
        # if using the old schema.
        %sql INSERT INTO relations (id, shuffle_mask, projection_weights) VALUES ('{{rel_id}}', {{mask_list}}, {{weights_list}})
        
        return (
            torch.from_numpy(mask).long().clone(), 
            torch.from_numpy(weights).float().clone()
        )

In [10]:
def get_neighbors(ent_id: UUID):
    result = %sql SELECT head_id, relation_id, tail_id FROM adjacency WHERE head_id = '{{ent_id}}' OR tail_id = '{{ent_id}}' 
    return list(result.dicts())

In [11]:
def cas_update(updates):
    for ent_id, (new_emb, old_ver) in updates.items():
        emb = new_emb.detach().numpy().astype(np.float32).tolist()
        %sql UPDATE entities SET embedding = {{emb}}, version = {{old_ver + 1}} WHERE id = '{{ent_id}}' and version = {{old_ver}}
    return True

In [12]:
def cas_update_relation(rel_id, new_weights):
    """Updates the projection weights for a relation."""
    w_list = new_weights.detach().numpy().tolist()
    %sql UPDATE relations SET projection_weights = {{w_list}} WHERE id = '{{rel_id}}'
    return True

In [13]:
def add_edge(head_id: UUID, relation_id: UUID, tail_id: UUID):
    %sql INSERT INTO adjacency (head_id, relation_id, tail_id) VALUES ('{{head_id}}', '{{relation_id}}', '{{tail_id}}')
    return True

# ==========================================
# Reshuffle Core
# ==========================================


In [14]:
class StreamingReshuffleModule(nn.Module):
    """
    PyTorch Logic Core. 
    Now implements Automatic Symmetric Tension in the forward pass.
    """

    def __init__(self, embeddings_dict, relations_dict, weights_dict):
        super().__init__()
        self.ent_map = {k: i for i, k in enumerate(embeddings_dict.keys())}
        self.rel_map = {k: i for i, k in enumerate(relations_dict.keys())}

        # 1. Load Entities
        if embeddings_dict:
            self.entities = nn.Parameter(torch.stack(list(embeddings_dict.values())))
        else:
            self.entities = nn.Parameter(torch.empty(0, DIMENSIONS))

        # 2. Load Relations & Pre-compute Inverses
        if relations_dict:
            self.shuffles = torch.stack(list(relations_dict.values()))
            # AUTOMATIC: Pre-compute inverse masks (The "Un-shuffle")
            # If Forward is A -> B, Inverse is B -> A
            self.inv_shuffles = torch.argsort(self.shuffles, dim=1)

            # --- Learnable Projection Weights ---
            # We initialize them to 1.0 (Full Attention).
            # The system will learn to lower them to 0.0 for irrelevant features.
            # num_rels = len(relations_dict)
            # self.rel_weights = nn.Parameter(torch.ones(num_rels, DIMENSIONS))
            self.rel_weights = nn.Parameter(torch.stack(list(weights_dict.values())))

        else:
            self.shuffles = torch.empty(0, DIMENSIONS, dtype=torch.long)
            self.inv_shuffles = torch.empty(0, DIMENSIONS, dtype=torch.long)
            self.rel_weights = nn.Parameter(torch.empty(0, DIMENSIONS))

    def forward(self, h_idx, r_idx, t_idx):
        """
        Calculates Symmetric Tension.
        Minimizing this value forces the embeddings to satisfy the relationship 
        bi-directionally (Structural Equality).
        """
        h = self.entities[h_idx]
        t = self.entities[t_idx]
        
        # --- Apply Projection (Focus) ---
        # Fetch the weight mask for this specific relation
        w = self.rel_weights[r_idx]
        
        # Project both entities into the "Relation Subspace"
        # If w[i] is near 0, that dimension effectively vanishes.
        h_proj = h * w
        t_proj = t * w

        # 1. Forward Logic: Head vs Shuffled Tail
        # "Does Head look like a permutation of Tail?"
        mask = self.shuffles[r_idx]
        t_shuffled = t_proj[mask]
        fwd_tension = torch.relu(h_proj - t_shuffled).sum()

        # 2. Inverse Logic: Tail vs Un-shuffled Head
        # "Does Tail look like the inverse permutation of Head?"
        inv_mask = self.inv_shuffles[r_idx]
        h_unshuffled = h_proj[inv_mask]
        inv_tension = torch.relu(t_proj - h_unshuffled).sum()

        # C. Regularization (The "Anti-Collapse" Term)
        # We penalize weights that are too small to prevent the "Zero Solution."
        # This keeps the system honest.
        # Logic: maximize sum(w) -> minimize -sum(w)
        reg_loss = -torch.sum(w) * 0.05

        # 3. Combine
        # We average them so the optimizer cares about both directions equally.
        return ((fwd_tension + inv_tension) / 2.0) + reg_loss

# ==========================================
# LAYER 3: THE WORKER (Ripple Update)
# ==========================================

In [15]:
class RippleWorker:
    """
    Stateless worker performing Local Relaxation.
    """

    def process_fact(self, h_id, r_id, t_id, stiffness=0.5, steps=50):
        """The Optimistic Ripple Update Algorithm (Unchanged)"""
        print(f"Worker: Updating {h_id} --[{r_id}]--> {t_id} (Stiffness: {stiffness})")

        # 1. Fetch Active Set
        active_ids = {h_id, t_id}
        neighbors = get_neighbors(h_id) + get_neighbors(t_id)
        for n in neighbors:
            active_ids.add(n['head_id'])
            active_ids.add(n['tail_id'])

        # 2. Fetch Vectors & Versions
        local_embs = {}
        local_vers = {}
        local_rels = {}
        local_weights = {}

        # Helper to load context
        def load_context(eid=None, rid=None):
            if eid and eid not in local_embs:
                e, v = get_or_create_entity(eid)
                local_embs[eid] = e
                local_vers[eid] = v
            if rid and rid not in local_rels:
                mask, w = get_or_create_relation_with_weights(rid)
                local_rels[rid] = mask
                local_weights[rid] = w

        # Load primary fact
        load_context(eid=h_id)
        load_context(eid=t_id)
        load_context(rid=r_id)
        
        # Load neighbors
        for n in neighbors:
            load_context(eid=n['head_id'])
            load_context(eid=n['tail_id'])
            load_context(rid=n['relation_id'])

        # 3. Initialize Compute
        model = StreamingReshuffleModule(local_embs, local_rels, local_weights)

        # Prepare Indices
        h_idx = torch.tensor(model.ent_map[h_id])
        t_idx = torch.tensor(model.ent_map[t_id])
        r_idx = torch.tensor(model.rel_map[r_id])
        
        anchor_embs = model.entities.clone().detach() # Elasticity Anchors

        # --- 3. THE TRAINING LOOP (Alternating Minimization) ---
        
        # PHASE 1: FIT (Move Entities, Freeze Weights)
        # Logic: "Assuming the rules are strict, can we fit the data?"
        model.entities.requires_grad = True
        model.rel_weights.requires_grad = False
        
        optimizer_ent = optim.SGD([model.entities], lr=0.1)

        for i in range(steps // 2):
            optimizer_ent.zero_grad()
            
            # Primary Tension
            loss = model(h_idx, r_idx, t_idx)
            
            # Anchor Loss (Stay close to previous state)
            loss_anchor = torch.sum((model.entities - anchor_embs) ** 2)
            
            # Neighbor Constraint (Don't break existing edges)
            loss_neighbor = 0
            for n in neighbors:
                ni_h = model.ent_map[n['head_id']]
                ni_t = model.ent_map[n['tail_id']]
                ni_r = model.rel_map[n['relation_id']]
                loss_neighbor += model(torch.tensor(ni_h), torch.tensor(ni_r), torch.tensor(ni_t))
            
            total_loss = loss + (stiffness * loss_anchor) + (0.1 * loss_neighbor)
            total_loss.backward()
            optimizer_ent.step()
        
        # PHASE 2: RELAX (Move Weights, Freeze Entities)
        # Logic: "If we still have tension, the rule must be too strict. Loosen it."
        model.entities.requires_grad = False
        model.rel_weights.requires_grad = True
        
        # Lower learning rate for weights to prevent collapse
        optimizer_rel = optim.SGD([model.rel_weights], lr=0.01)

        for i in range(steps // 2):
            optimizer_rel.zero_grad()
            
            # Recalculate Tension (Same formula)
            loss = model(h_idx, r_idx, t_idx)
            
            # We also care about neighbors! 
            # If we relax the rule for Vader, we relax it for everyone.
            # We must check if relaxing hurts other facts using this same relation.
            loss_neighbor = 0
            for n in neighbors:
                if n['relation_id'] == r_id: # Only check relevant neighbors
                    ni_h = model.ent_map[n['head_id']]
                    ni_t = model.ent_map[n['tail_id']]
                    ni_r = model.rel_map[n['relation_id']]
                    loss_neighbor += model(torch.tensor(ni_h), torch.tensor(ni_r), torch.tensor(ni_t))

            # Note: Reg_loss is already inside model.forward()
            total_loss = loss + (0.1 * loss_neighbor)
            total_loss.backward()
            optimizer_rel.step()
            
            # Clamp weights to keep them physical [0.0, 1.0]
            with torch.no_grad():
                model.rel_weights.clamp_(0.0, 1.0)
        
        
        # --- 4. CAS Update (Commit Changes) ---
        
        # A. Commit Entities
        updates = {}
        keys = list(local_embs.keys())
        for i, eid in enumerate(keys):
            updates[eid] = (model.entities[i], local_vers[eid])
        
        ent_success = cas_update(updates)
        
        # B. Commit Relation Weights
        # We only update the relation we focused on (r_id) to avoid race conditions on others
        new_w = model.rel_weights[model.rel_map[r_id]]
        rel_success = cas_update_relation(r_id, new_w)
        
        if ent_success and rel_success:
            # Only add the edge if both commits worked
            add_edge(h_id, r_id, t_id)
            print("  -> Update Committed (Entities & Weights).")
        else:
            print("  -> Write Conflict. Retrying needed.")

    def reject_fact(self, h_id, r_id, t_id, stiffness=0.5, steps=50, margin=10.0):
        """
        Negative Constraint: Learns that h --[r]--> t is FALSE.
        Uses Max-Margin Loss to push embeddings apart.
        Strategy: Updates Entities Only (Locks Weights).
        """
        print(f"Worker: REJECTING {h_id} --[{r_id}]--> {t_id} (Target Margin: {margin})")

        # --- 1. Load Context ---
        # active_ids = {h_id, t_id}
        neighbors = get_neighbors(h_id) + get_neighbors(t_id)
        
        local_embs = {}
        local_vers = {}
        local_rels = {}
        local_weights = {}

        def load_context(eid=None, rid=None):
            if eid and eid not in local_embs:
                e, v = get_or_create_entity(eid)
                local_embs[eid] = e
                local_vers[eid] = v
            if rid and rid not in local_rels:
                mask, w = get_or_create_relation_with_weights(rid)
                local_rels[rid] = mask
                local_weights[rid] = w

        load_context(eid=h_id)
        load_context(eid=t_id)
        load_context(rid=r_id)
        for n in neighbors:
            load_context(eid=n['head_id'])
            load_context(eid=n['tail_id'])
            load_context(rid=n['relation_id'])

        # --- 2. Initialize Physics Engine ---
        model = StreamingReshuffleModule(local_embs, local_rels, local_weights)
        
        h_idx = torch.tensor(model.ent_map[h_id])
        t_idx = torch.tensor(model.ent_map[t_id])
        r_idx = torch.tensor(model.rel_map[r_id])
        
        anchor_embs = model.entities.clone().detach()

        # --- 3. THE REJECTION LOOP ---
        # We LOCK weights. We only move entities to satisfy the constraint.
        model.rel_weights.requires_grad = False
        model.entities.requires_grad = True
        
        optimizer = optim.SGD([model.entities], lr=0.1)

        for i in range(steps):
            optimizer.zero_grad()
            
            # Calculate current tension (How true is it?)
            tension_val = model(h_idx, r_idx, t_idx)
            
            # --- CONTRASTIVE LOSS ---
            # We want tension > margin. 
            # If tension is 2.0 and margin is 10.0, loss is 8.0.
            # If tension is 15.0, loss is 0.
            loss_reject = torch.relu(margin - tension_val)
            
            # Standard Constraints (Anchors + Neighbors)
            # Crucial: We must push 'h' away from 't', but keep it connected to its other friends.
            loss_anchor = torch.sum((model.entities - anchor_embs) ** 2)
            
            loss_neighbor = 0
            for n in neighbors:
                ni_h = model.ent_map[n['head_id']]
                ni_t = model.ent_map[n['tail_id']]
                ni_r = model.rel_map[n['relation_id']]
                # Minimize neighbor tension (keep valid facts valid)
                loss_neighbor += model(torch.tensor(ni_h), torch.tensor(ni_r), torch.tensor(ni_t))

            # Total Loss
            loss = loss_reject + (stiffness * loss_anchor) + (0.1 * loss_neighbor)
            
            if loss_reject.item() == 0:
                # Early stopping if we pushed them far enough apart
                break
                
            loss.backward()
            optimizer.step()

        # --- 4. CAS Update (Entities Only) ---
        updates = {}
        keys = list(local_embs.keys())
        for i, eid in enumerate(keys):
            updates[eid] = (model.entities[i], local_vers[eid])

        if cas_update(updates):
            # If this edge existed in adjacency, we should delete it.
            # For now, we assume this is just correcting the vector space.
            print("  -> Rejection Committed.")
        else:
            print("  -> Write Conflict. Retrying needed.")

    def trace(self, h_id, r_id, t_id):
        """
        Traceability Engine.
        Returns Confidence, Tension, and Subspace Usage.
        """
        # 1. Load Minimal Context
        h, _ = get_or_create_entity(h_id)
        t, _ = get_or_create_entity(t_id)
        mask, w = get_or_create_relation_with_weights(r_id)
        
        # 2. Setup Temporary Model
        # (We bypass the dict overhead for speed here)
        model = StreamingReshuffleModule(
            {h_id: h, t_id: t}, 
            {r_id: mask}, 
            {r_id: w}
        )
        
        h_idx = torch.tensor(0) # head is first
        t_idx = torch.tensor(1) # tail is second
        r_idx = torch.tensor(0) # relation is first

        # 3. Compute Metrics
        with torch.no_grad():
            # Raw Tension (Energy)
            raw_tension = model(h_idx, r_idx, t_idx).item()
            
            # Subspace Analysis
            # How "strict" is this rule? 
            # If sum(w) is low, the rule is weak/vague.
            active_mass = torch.sum(w).item()
            sparsity = active_mass / DIMENSIONS
            
            # Normalized Tension (Error per Active Dimension)
            # If we only look at 10 dimensions, an error of 5.0 is huge.
            # If we look at 768 dimensions, an error of 5.0 is tiny.
            if active_mass > 0:
                avg_tension = raw_tension / active_mass
            else:
                avg_tension = raw_tension # Avoid div/0
                
            # Confidence Score (Sigmoid-like decay)
            # High tension -> Low Confidence
            confidence = 1.0 / (1.0 + (avg_tension * 10))

        return {
            "result": confidence > 0.5,
            "confidence": round(confidence, 4),
            "tension": round(raw_tension, 4),
            "sparsity": round(sparsity, 4) # 1.0 = Strict, 0.1 = Very Fuzzy
        }

# ==========================================
# LAYER 4: THE INTERFACE (API Simulation)
# ==========================================

In [16]:
class AgentInterface:
    """Simulates the API Gateway."""

    def __init__(self, worker):
        self.worker = worker

    def ingest(self, json_payload):
        """Implementation of POST /ingest"""
        data = json.loads(json_payload)
        for event in data["events"]:
            if event["type"] == "FACT":
                self.worker.process_fact(
                    event["h"],
                    event["r"],
                    event["t"],
                    stiffness=event.get("stiffness", 0.5),
                )
            elif event["type"] == "NEGATION":
                self.worker.reject_fact(
                    event["h"],
                    event["r"],
                    event["t"],
                    stiffness=event.get("stiffness", 0.5),
                    margin=event.get("margin", 10.0) # Default margin
                )

    def query(self, json_payload):
        """Implementation of POST /reason"""
        data = json.loads(json_payload)
        
        if data.get("trace"):
            return self.worker.trace(
                data["head"], 
                data["relation"], 
                data["tail"]
            )
            

# ==========================================
# EXECUTION DEMO
# ==========================================

## 1. Setup

In [17]:
worker = RippleWorker()
api = AgentInterface(worker)

## 2. Ingest Data

In [18]:
print("\n--- 2. Ingestion Stream ---")
payload = json.dumps(
    {
        "events": [
            # Fact 1: Strong Truth
            {
                "type": "FACT",
                "h": "Service_A",
                "r": "depends_on",
                "t": "Lib_Foo",
                "stiffness": 1.0,
            },
            # Fact 2: Strong Truth
            {
                "type": "FACT",
                "h": "Lib_Foo",
                "r": "has_status",
                "t": "Vulnerable",
                "stiffness": 1.0,
            },
            # Fact 3: Agent Hypothesis (Weaker stiffness)
            {
                "type": "FACT",
                "h": "Service_A",
                "r": "has_risk",
                "t": "High",
                "stiffness": 0.5,
            },
        ]
    }
)
api.ingest(payload)


--- 2. Ingestion Stream ---
Worker: Updating Service_A --[depends_on]--> Lib_Foo (Stiffness: 1.0)


  -> Update Committed (Entities & Weights).
Worker: Updating Lib_Foo --[has_status]--> Vulnerable (Stiffness: 1.0)


  -> Update Committed (Entities & Weights).
Worker: Updating Service_A --[has_risk]--> High (Stiffness: 0.5)


  -> Update Committed (Entities & Weights).


## 3. Query and Trace

In [19]:
print("\n--- 3. Reasoning & Traceability ---")
# Query: Is Service_A's risk High?
query_payload = json.dumps(
    {"head": "Service_A", "relation": "has_risk", "tail": "High", "trace": True}
)
result = api.query(query_payload)
print(f"Query Result: {result}")

# Query: Counter-factual (Is Service_A's risk Low?)
# Note: We haven't taught it "Low", but semantic init might separate High/Low.
query_payload_2 = json.dumps(
    {"head": "Service_A", "relation": "has_risk", "tail": "Low", "trace": True}
)
result_2 = api.query(query_payload_2)
print(f"Counter-factual Result: {result_2}")


--- 3. Reasoning & Traceability ---


Query Result: {'result': True, 'confidence': 1.4247, 'tension': -22.8904, 'sparsity': 0.9999}


Counter-factual Result: {'result': True, 'confidence': 1.5294, 'tension': -26.5807, 'sparsity': 0.9999}
