In [1]:
import numpy as np
import networkx as nx

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [3]:
from sentence_transformers import SentenceTransformer

In [4]:
from sklearn.metrics.pairwise import cosine_similarity
from typing import List

In [5]:
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List

class QuantumWalkRetriever(nn.Module):
    def __init__(self, embed_model_name: str = 'all-MiniLM-L6-v2', k: int = 5, hidden_dim: int = 128, walk_steps: int = 3):
        super().__init__()
        self.embedder = SentenceTransformer(embed_model_name)
        for param in self.embedder.parameters():
            param.requires_grad = False
        self.k = k
        self.walk_steps = walk_steps
        embedding_dim = self.embedder.get_sentence_embedding_dimension()
        self.coin_net = nn.Sequential(
            nn.Linear(embedding_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, k)
        )

    def embed_sentences(self, sentences: List[str]) -> np.ndarray:
        return self.embedder.encode(sentences, convert_to_numpy=True)

    def build_graph(self, embeddings: np.ndarray) -> nx.Graph:
        sim = cosine_similarity(embeddings)
        n = embeddings.shape[0]
        G = nx.Graph()
        G.add_nodes_from(range(n))
        for i in range(n):
            neighbors = np.argsort(sim[i])[::-1][1:self.k+1]
            for j in neighbors:
                G.add_edge(i, j, weight=sim[i, j])
        return G

    def quantum_walk(self, G: nx.Graph, query_vec: np.ndarray, embeddings: np.ndarray) -> torch.Tensor:
        n = G.number_of_nodes()
        state = torch.ones(n, self.k, dtype=torch.cfloat) / float(np.sqrt(n * self.k))
        nbrs = [list(G.neighbors(i)) for i in range(n)]
        q_tensor = torch.from_numpy(query_vec).float()
        emb_tensor = torch.from_numpy(embeddings).float()
        for _ in range(self.walk_steps):
            coins = []
            for i in range(n):
                inp = torch.cat([emb_tensor[i], q_tensor], dim=0)
                amps = self.coin_net(inp)
                c_real = amps.unsqueeze(1) * amps.unsqueeze(0)
                c = c_real.to(torch.cfloat) / torch.norm(c_real)
                coins.append(c)
            new_state = torch.zeros_like(state)
            for i in range(n):
                s_prime = coins[i] @ state[i]
                for idx, j in enumerate(nbrs[i][:self.k]):
                    new_state[j, idx] += s_prime[idx]
            state = new_state / torch.norm(new_state)
        logits = state.abs().sum(dim=1)
        return logits

    def forward(self, question: str, sentences: List[str]) -> List[tuple[int, float]]:
        sent_emb = self.embed_sentences(sentences)
        q_emb = self.embedder.encode([question], convert_to_numpy=True)[0]
        G = self.build_graph(sent_emb)
        logits = self.quantum_walk(G, q_emb, sent_emb).detach().cpu().numpy()
        return sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    def train_coin_operator(self, train_data: List[dict], epochs: int = 5, lr: float = 1e-3):
        optimizer = optim.Adam(self.coin_net.parameters(), lr=lr)
        self.train()
        for epoch in range(1, epochs + 1):
            total_loss = 0.0
            for ex in train_data:
                labels = torch.tensor(ex['labels'], dtype=torch.float)
                if labels.sum() > 0:
                    labels = labels / labels.sum()
                sent_emb = self.embed_sentences(ex['sentences'])
                q_emb = self.embedder.encode([ex['question']], convert_to_numpy=True)[0]
                G = self.build_graph(sent_emb)
                logits = self.quantum_walk(G, q_emb, sent_emb)
                probs = torch.softmax(logits, dim=0)
                loss = F.kl_div(probs.log(), labels, reduction='batchmean')
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print(f"Epoch {epoch}/{epochs}, Loss: {total_loss/len(train_data):.4f}")
        self.eval()


In [6]:

retriever = QuantumWalkRetriever(embed_model_name='all-MiniLM-L6-v2', k=5, hidden_dim=128, walk_steps=3)

In [7]:
documents = [
    "The Eiffel Tower is in Paris.",
    "Paris is the capital of France.",
    "The Louvre houses the Mona Lisa.",
    "The Mona Lisa is a portrait by Leonardo da Vinci."
]
question = "Where can you find the Mona Lisa?"

In [8]:
rankings = retriever.forward(question, documents)
for idx, score in rankings[:3]:
    print(f"{documents[idx]} -> score: {score:.4f}")

The Eiffel Tower is in Paris. -> score: 0.7101
Paris is the capital of France. -> score: 0.6072
The Louvre houses the Mona Lisa. -> score: 0.5173


In [9]:
train_examples = [
    {'question': question, 'sentences': documents, 'labels': [0, 0, 1, 0]}
]
retriever.train_coin_operator(train_examples, epochs=3, lr=1e-3)

Epoch 1/3, Loss: 0.3518
Epoch 2/3, Loss: 0.3172
Epoch 3/3, Loss: 0.2721
