In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
!pip install scipy==1.10.1



In [4]:
import scipy
print(scipy.__version__)

1.10.1


In [5]:
# All Losses Together
import torch
import torch.nn as nn

class MyHashLoss(nn.Module):
    def __init__(self, q_wt=0.1, adv_wt=0.01, kern_wt=0.1):
        super(MyHashLoss, self).__init__()
        self.q_wt = q_wt        # weight for quant loss
        self.adv_wt = adv_wt    # weight for adv loss
        self.kern_wt = kern_wt  # (not used now but maybe later)

    def quant_loss(self, h):
        # want values close to -1 or 1, so abs - 1 should be small
        return torch.mean((h.abs() - 1) ** 2)

    def sim_loss(self, real_sim, hash_sim):
        # check how diff the sim matrices are
        return torch.mean((real_sim - hash_sim) ** 2)

    def kernel_loss(self, pred_kernel, true_kernel):
        # maybe use later - align kernel with GT
        return torch.mean((pred_kernel - true_kernel) ** 2)

    def forward(self, out, gt_sim, adv):
        cont = out['continuous_hash']
        bin_hash = out['binary_hash']  # not used now
        kern_sim = out['kernel_similarity']

        q_loss = self.quant_loss(cont)
        s_loss = self.sim_loss(gt_sim, kern_sim)

        total = s_loss + self.q_wt * q_loss + self.adv_wt * adv

        return total, {
            'quant_loss': q_loss.item(),
            'sim_loss': s_loss.item(),
            'adv_loss': adv.item(),
            'total_loss': total.item()
        }


In [6]:
!pip install geoopt



In [7]:
import torch
import geoopt

import torch
import geoopt

# Function to get the average point (Fréchet mean) on Lorentz manifold
def lorentz_mean(manifold, points, max_iter=10, tol=1e-5):
    """
    Finds average point (kind of like center) for given points on Lorentz manifold
    points: [N, D] shaped tensor
    """
    dev = points.device  # use same device as input
    mu = points[0].to(dev)  # just take the first point as starting point

    for _ in range(max_iter):
        tangents = manifold.logmap(mu, points)  # get tangent vectors from mu to all points
        avg_tangent = tangents.mean(dim=0)      # average of those tangents
        if avg_tangent.norm().item() < tol:     # stop if changes are very small
            break
        mu = manifold.expmap(mu, avg_tangent)   # move mu a bit in direction of avg_tangent

    return mu


import torch
import geoopt

# Basic k-means but for hyperbolic space using Lorentz model
def hyperbolic_kmeans(x, k, manifold, max_iter=20):
    """
    Does k-means in hyperbolic space
    x: [B, D] points on manifold
    k: number of clusters
    """
    B, D = x.shape
    perm = torch.randperm(B)  # shuffle indices
    centroids = x[perm[:k]].clone()  # pick k random points as starting centroids

    for _ in range(max_iter):
        # calculate distance from every x to each centroid
        dists = manifold.dist(
            x.unsqueeze(1),           # [B, 1, D]
            centroids.unsqueeze(0)    # [1, k, D]
        ).squeeze(1)                  # final shape: [B, k]

        labels = dists.argmin(dim=1)  # assign each x to nearest centroid

        new_centroids = []
        for ci in range(k):
            pts = x[labels == ci]  # all points belonging to cluster ci
            if pts.numel() == 0:
                # if no point got assigned, just pick random one again
                rand_idx = torch.randint(0, B, (1,)).item()
                new_centroids.append(x[rand_idx])
            else:
                mu = lorentz_mean(manifold, pts)  # get average point
                new_centroids.append(mu)

        centroids = torch.stack(new_centroids, dim=0)

    return centroids, labels  # return final cluster centers + which point goes where



In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import geoopt

import torch
import torch.nn as nn
import torch.nn.functional as F
import geoopt

# Generator model that gives both continuous and binary hashes
class HashGen(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim, manifold=None):
        """
        in_dim  : input embedding size (from HypFormer maybe)
        hid_dim : size of middle layer
        out_dim : size of hash code (final output)
        """
        super().__init__()
        self.manifold = manifold or geoopt.manifolds.Lorentz()  # use Lorentz if not given

        self.fc1 = nn.Linear(in_dim, hid_dim)    # input to hidden
        self.fc_cont = nn.Linear(hid_dim, out_dim)  # for continuous hash
        self.fc_bin = nn.Linear(hid_dim, out_dim)   # for binary hash

    def forward(self, x):
        dev = x.device  # stick to same device

        # Step 1: take input from manifold to flat space
        x_tan = self.manifold.logmap0(x)

        # Step 2: pass through hidden layer
        h = F.relu(self.fc1(x_tan))

        # Step 3: one branch gives continuous hash
        cont_out = self.fc_cont(h)
        cont_hash = self.manifold.expmap0(cont_out)  # back to curved space

        # Step 4: other branch gives binary hash
        bin_out = self.fc_bin(h)
        bin_hash = torch.sign(bin_out)  # just hard sign

        return cont_hash.to(dev), bin_hash.to(dev)  # return both



# Simple Discriminator for GAN training
class Discrim(nn.Module):
    def __init__(self, hash_dim=64, hid_dim=128, manifold=None):
        """
        hash_dim : size of the input hash (should match generator output)
        hid_dim  : hidden layer size
        """
        super().__init__()
        self.manifold = manifold or geoopt.manifolds.Lorentz()

        self.fc1 = nn.Linear(hash_dim, hid_dim)
        self.fc2 = nn.Linear(hid_dim, 1)  # final binary output

    def forward(self, h):
        dev = h.device

        # Step 1: take from manifold to flat space
        h_tan = self.manifold.logmap0(h)

        # Step 2: normal 2-layer MLP
        x = F.relu(self.fc1(h_tan))
        out = self.fc2(x)
        real_score = torch.sigmoid(out).view(-1)  # squish to [0,1]

        return real_score.to(dev)

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import geoopt

# Trainer class to handle GAN-style training for hyperbolic hash learning
class HashTrainer:
    def __init__(
        self,
        gen: nn.Module,
        disc: nn.Module,
        manifold: geoopt.manifolds.Lorentz,
        kmeans_func,
        num_clusters: int,
        lr: float = 2e-4,
        beta1: float = 0.5,
        beta2: float = 0.999,
        noise_std: float = 0.1,
        alpha: float = 0.7,
    ):
        self.gen = gen.to("cpu")
        self.disc = disc.to("cpu")
        self.manifold = manifold
        self.kmeans_func = kmeans_func
        self.k = num_clusters
        self.noise_std = noise_std
        self.alpha = alpha

        self.adv_loss = nn.BCEWithLogitsLoss()

        self.opt_g = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(beta1, beta2))
        self.opt_d = torch.optim.Adam(self.disc.parameters(), lr=lr, betas=(beta1, beta2))

    def lorentz_dot(self, x, y):
        return -x[..., 0] * y[..., 0] + torch.sum(x[..., 1:] * y[..., 1:], dim=-1)

    def proj_tangent(self, x, v):
        inner = self.lorentz_dot(x, v).unsqueeze(-1)
        return v + inner * x

    def train_step(self, embeds: torch.Tensor):
        embeds = embeds.to("cpu")
        embeds = self.manifold.projx(embeds)

        B = embeds.shape[0]
        real_label = torch.ones(B)
        fake_label = torch.zeros(B)

        # Get pseudo-labels using k-means in hyperbolic space
        with torch.no_grad():
            cents, pseudo = self.kmeans_func(embeds, k=self.k)
            cents = self.manifold.projx(cents)

        # Run Generator to get both types of hash
        cont_raw, bin_hash = self.gen(embeds)
        cont_hash = self.manifold.projx(cont_raw)

        # Cluster loss: make points closer to their cluster centers
        dists = self.manifold.dist(
            embeds.unsqueeze(1),
            cents.unsqueeze(0),
        ).squeeze(1)
        c_loss = F.cross_entropy(-dists, pseudo)

        # Add hyperbolic noise
        rand_vec = torch.randn_like(embeds)
        tan_vec = self.proj_tangent(embeds, rand_vec)

        speed_sq = self.lorentz_dot(tan_vec, tan_vec).abs().clamp(min=1e-12)
        speed = speed_sq.sqrt().unsqueeze(-1)

        noise = tan_vec / speed * self.noise_std
        noisy_embed = self.manifold.expmap(embeds, noise)

        # Hash the noisy versions
        fake_raw, _ = self.gen(noisy_embed)
        fake_hash = self.manifold.projx(fake_raw)

        # Train Discriminator
        real_out = self.disc(cont_hash.detach())
        fake_out = self.disc(fake_hash.detach())

        loss_real = self.adv_loss(real_out, real_label)
        loss_fake = self.adv_loss(fake_out, fake_label)
        d_loss = 0.5 * (loss_real + loss_fake)

        self.opt_d.zero_grad()
        d_loss.backward()
        self.opt_d.step()

        # Train Generator
        for p in self.disc.parameters():
            p.requires_grad_(False)

        gen_out = self.disc(fake_hash)
        g_adv = self.adv_loss(gen_out, real_label)
        g_loss = self.alpha * c_loss + (1 - self.alpha) * g_adv

        self.opt_g.zero_grad()
        g_loss.backward()
        self.opt_g.step()

        for p in self.disc.parameters():
            p.requires_grad_(True)

        return d_loss.item(), g_loss.item(), c_loss.item()

In [None]:
import torch
import geoopt
import numpy as np
from tqdm import tqdm

device = torch.device("cpu")

# Load precomputed embeddings (saved as .npy file)
raw_embeddings = np.load('hyperbolic_embeddings.npy')  # shape [N, D]
manifold = geoopt.manifolds.Lorentz()
embeddings = torch.from_numpy(raw_embeddings).float().to(device)
embeddings = manifold.projx(embeddings)  # project to hyperboloid

# Model stuff
B, input_dim = embeddings.shape
hidden_dim = 512
hash_dim = 64
num_classes = 10
batch_size = 64
epochs = 50

# Models
gen = HashGen(input_dim, hidden_dim, hash_dim, manifold).to(device)
disc = Discrim(hash_dim).to(device)

# Trainer
trainer = HashTrainer(
    gen=gen,
    disc=disc,
    manifold=manifold,
    kmeans_func=lambda x, k: hyperbolic_kmeans(x, k, manifold),
    num_clusters=num_classes,
    lr=2e-4,
    noise_std=0.1,
    alpha=0.7,
)

# Training loop
for ep in range(epochs):
    d_total, g_total, c_total = 0.0, 0.0, 0.0
    steps = 0

    for i in tqdm(range(0, B, batch_size), desc=f"Epoch {ep + 1}/{epochs}"):
        batch = embeddings[i:i + batch_size]
        d_loss, g_loss, c_loss = trainer.train_step(batch)

        d_total += d_loss
        g_total += g_loss
        c_total += c_loss
        steps += 1

    avg_d = d_total / steps
    avg_g = g_total / steps
    avg_c = c_total / steps

    print(f"[Epoch {ep + 1}/{epochs}] D: {avg_d:.4f} | G: {avg_g:.4f} | Cluster: {avg_c:.4f}")


Epoch 1/50:   0%|          | 0/515 [00:00<?, ?it/s]

Epoch 1/50:  24%|██▍       | 124/515 [01:38<05:17,  1.23it/s]

In [22]:
cont_hash, bin_hash = generator(embeddings)

In [29]:
np.save('hashes.npy', bin_hash.detach().numpy())

In [12]:
import numpy as np
hashes = np.load("hashes.npy")

In [10]:
emb_np = np.load('hyperbolic_embeddings.npy')          # shape [N, D]
manifold = geoopt.manifolds.Lorentz()
emb = torch.from_numpy(emb_np).float()
embeddings = manifold.projx(emb)

In [13]:
def hamming_distance(a, b):
    return np.sum(a != b)

def lorentz_inner_product(x, y):
    return -x[0] * y[0] + np.dot(x[1:], y[1:])

def lorentz_distance(x, y):
    inner = -lorentz_inner_product(x, y)
    inner = np.clip(inner, 1 + 1e-5, None)
    return np.arccosh(inner)

def retrieve_neighbors(embeddings, hashes, query_index, hamming_k=100, top_k=5):
    query_embedding = embeddings[query_index]
    query_hash = hashes[query_index]

    # Step 1: Get top `hamming_k` neighbors by Hamming distance
    hamming_ranked = [
        (hamming_distance(query_hash, hashes[i]), i)
        for i in range(len(hashes)) if i != query_index
    ]
    hamming_ranked.sort(key=lambda x: x[0])
    top_candidates = [i for _, i in hamming_ranked[:hamming_k]]

    # Step 2: Re-rank using Lorentzian distance
    lorentz_ranked = [
        (lorentz_distance(query_embedding, embeddings[i]), i)
        for i in top_candidates
    ]
    lorentz_ranked.sort(key=lambda x: x[0])

    return [i for _, i in lorentz_ranked[:top_k]]

# Example
query_idx = 7
retrieved_indices = retrieve_neighbors(embeddings, hashes, query_idx)
print(f"Retrieved graph indices for query {query_idx}:")
for idx in retrieved_indices:
    print(f" - Graph index: {idx}")

Retrieved graph indices for query index 7:
 - Graph index: 78
 - Graph index: 53
 - Graph index: 83
 - Graph index: 37
 - Graph index: 34
