In [None]:

import torch
import sys
import os

import argparse
# The HypFormer sources live in ./Hypformer (relative to the working directory);
# put that directory on the import path so `from hypformer import ...` resolves.
sys.path.append(os.path.dirname("Hypformer/hypformer.py"))


from hypformer import HypFormer  


import numpy as np

from sklearn.metrics import roc_auc_score



import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.graphgym import train
from torch_geometric.loader import DataLoader
from torch_geometric.utils import scatter
from torch_geometric.nn.pool import global_mean_pool


from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.data import Data, Batch

from tqdm import tqdm







### Loading Model


In [None]:
# Hyper-parameters handed to HypFormer through its `args` namespace.
# NOTE(review): hidden_dim / trans_num_heads / trans_dropout below differ from the
# explicit constructor keywords (1024 / 8 / 0.1); which of the two HypFormer reads
# is not visible in this notebook — confirm against the checkpoint's training config.
args = argparse.Namespace(
    batch_size=16,
    hidden_dim=128,
    trans_num_layers=2,
    trans_num_heads=4,
    trans_dropout=0.2,
    lr=1e-5,
    weight_decay=1e-4,
    epochs=5,
    k_in=1.0,
    k_out=1.0,
    decoder_type='hyp',
    # Fall back to the CPU when CUDA is unavailable instead of hard-coding a GPU,
    # matching the device selection done later in the notebook.
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
    add_positional_encoding=False,
    attention_type='full',
    power_k=2,
    trans_heads_concat=False,
    aggregate=False,
    extra_hyp_linears=10,
)

# The architecture must match the one the checkpoint was trained with, so the
# constructor arguments are kept exactly as they were.
model = HypFormer(
    in_channels=3,
    hidden_channels=1024,
    extra_dims=[64,256,512,1024,2048,4096,2048,1024,512,1024,1024],
    out_channels=2, 
    trans_num_layers=2,
    trans_num_heads=8,
    trans_dropout=0.1,
    trans_use_bn=True,
    trans_use_residual=True,
    trans_use_weight=True,
    trans_use_act=True,
    args=args
)


In [3]:
checkpoint_path = "checkpoints/best_model.pt"
# Deserialize onto the CPU regardless of the device the checkpoint was saved
# from, so loading also works on hosts without CUDA.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu")


In [4]:
# Restore the network weights first; the optimizer below is built on these parameters.
model.load_state_dict(checkpoint['model_state_dict'])

# Re-create the optimizer and LR scheduler with the notebook's hyper-parameters,
# then overwrite their internal state with what the checkpoint recorded.
optimizer = optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

# Bookkeeping stored alongside the weights.
epoch, val_rocauc = checkpoint['epoch'], checkpoint['val_rocauc']

print(f"Loaded model from epoch {epoch} with validation ROC-AUC: {val_rocauc:.4f}")


Loaded model from epoch 17 with validation ROC-AUC: 0.5347


In [5]:
# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the weights to that device and switch the model to evaluation mode.
model = model.to(device)
model.eval()


In [7]:
def prepare_dataloaders(batch_size):
    """Build unshuffled loaders for the train/valid/test splits of ogbg-molhiv.

    Args:
        batch_size: number of graphs per batch, used for all three loaders.

    Returns:
        A tuple ``(train_loader, valid_loader, test_loader, num_features)`` where
        ``num_features`` is ``dataset.num_features`` of the full dataset.
    """
    dataset = PygGraphPropPredDataset(name='ogbg-molhiv')
    split_idx = dataset.get_idx_split()

    # Every split is iterated in dataset order (shuffle=False).
    train_loader, valid_loader, test_loader = (
        DataLoader(dataset[split_idx[part]], batch_size=batch_size, shuffle=False)
        for part in ('train', 'valid', 'test')
    )
    return train_loader, valid_loader, test_loader, dataset.num_features

In [8]:
# Build the three split loaders with the configured batch size; `in_channels`
# receives dataset.num_features as returned by prepare_dataloaders.
train_loader, valid_loader, test_loader, in_channels = prepare_dataloaders(args.batch_size)

#### Extracting Embeddings

In [None]:
from pathlib import Path
import h5py

def save_hyperbolic_embeddings(model, loader, device, save_path, layers=('fc3', 'fc11')):
    """Run ``model`` over ``loader`` and store per-graph layer activations in an HDF5 file.

    A forward hook captures the output of ``model.trans_conv.fcs[i]`` for every
    requested layer name ``'fc<i>'``. After each batch the captured activations
    are split by graph (using ``batch.batch``) and written out.

    File layout::

        graphs/graph_<idx:06d>/<layer>     activations of that graph's nodes (gzip)
        graphs/graph_<idx:06d>/label       ``batch.y`` entry of the graph
        graphs/graph_<idx:06d>.attrs       'num_nodes'
        metadata.attrs                     'manifold', 'curvature', 'total_graphs'

    Args:
        model: network exposing ``trans_conv.fcs`` and ``trans_conv.manifold_hidden.k``;
            it is switched to eval mode.
        loader: iterable of batches providing ``edge_index``, ``edge_attr``,
            ``num_nodes``, ``batch`` and ``y``.
        device: device the batch is moved to before the forward pass.
        save_path: destination file; parent directories are created and an
            existing file is overwritten.
        layers: layer names of the form ``'fc<i>'`` selecting ``fcs[i]``.
    """
    model.eval()
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # Most recent output of every hooked layer, refreshed on each forward pass.
    captured = {}

    def make_hook(layer_name):
        def hook(module, inputs, output):
            captured[layer_name] = output.detach().cpu()
        return hook

    handles = []
    try:
        # Hooks are registered before the file is opened so that an invalid
        # layer name fails without truncating an existing output file.
        for layer in layers:
            target = model.trans_conv.fcs[int(layer[2:])]
            handles.append(target.register_forward_hook(make_hook(layer)))

        # Gradients are never needed here; disabling autograd avoids building
        # a graph for every batch.
        with torch.no_grad(), h5py.File(save_path, 'w') as hf:
            graph_group = hf.create_group("graphs")
            metadata_group = hf.create_group("metadata")

            graph_idx = 0
            for batch in tqdm(loader, desc="Processing batches"):
                # Node features fed to the model: per-node sum of the attributes
                # of all edges whose source (first edge_index row) is that node.
                src, _ = batch.edge_index
                batch.x = scatter(batch.edge_attr, src,
                                  dim=0,
                                  dim_size=batch.num_nodes,
                                  reduce='sum')
                batch = batch.to(device)

                # The model output itself is discarded; the hooks fill `captured`.
                model(batch.x)

                batch_indices = batch.batch.cpu().numpy()
                for graph_num in np.unique(batch_indices):
                    current_group = graph_group.create_group(f"graph_{graph_idx:06d}")

                    # Rows of the batch that belong to this graph.
                    node_mask = (batch_indices == graph_num)

                    for layer in layers:
                        emb = captured[layer][node_mask].numpy()
                        current_group.create_dataset(
                            layer,
                            data=emb,
                            shape=emb.shape,
                            compression='gzip'
                        )

                    current_group.create_dataset("label", data=batch.y[graph_num].cpu().numpy())
                    current_group.attrs['num_nodes'] = np.sum(node_mask)

                    graph_idx += 1

            metadata_group.attrs['manifold'] = 'Lorentz'
            metadata_group.attrs['curvature'] = model.trans_conv.manifold_hidden.k.item()
            metadata_group.attrs['total_graphs'] = graph_idx
    finally:
        # Always detach the hooks, even if the forward pass or a write failed,
        # so they do not stay attached to the model.
        for handle in handles:
            handle.remove()


In [10]:
# Export the hooked layer activations for the test split.
save_hyperbolic_embeddings(model, test_loader, device, 'embedding/test_hyperbolic_embeddings.h5')

Processing batches: 100%|██████████| 258/258 [00:30<00:00,  8.36it/s]


In [11]:
# Export the hooked layer activations for the validation split.
save_hyperbolic_embeddings(model, valid_loader, device, 'embedding/valid_hyperbolic_embeddings.h5')

Processing batches: 100%|██████████| 258/258 [00:32<00:00,  7.87it/s]


In [12]:
# Export the hooked layer activations for the training split.
save_hyperbolic_embeddings(model, train_loader, device, 'embedding/train_hyperbolic_embeddings.h5')

Processing batches: 100%|██████████| 2057/2057 [04:54<00:00,  7.00it/s]


#### Code to load the embeddings


In [None]:
from manifolds.lorentz import Lorentz 

def load_hyperbolic_embeddings(file_path, device='cpu', layers=('fc3', 'fc11')):
    """Load per-graph embeddings and labels from an HDF5 file.

    The file is expected to contain a ``graphs`` group with one sub-group per
    graph (holding one dataset per layer plus a ``label`` dataset) and a
    ``metadata`` group carrying a ``curvature`` attribute, i.e. the layout
    written by ``save_hyperbolic_embeddings``.

    No geometric validation is performed; the values are returned as stored.

    Args:
        file_path: path of the HDF5 file to read.
        device: device every loaded tensor is moved to.
        layers: names of the per-graph datasets to load.

    Returns:
        A tuple ``(embeddings, labels, metadata)``:
        ``embeddings`` maps each layer name to a list with one tensor per graph,
        ``labels`` is a list with one tensor per graph (same order), and
        ``metadata`` is a dict with keys 'manifold', 'curvature', 'total_graphs'.
    """
    with h5py.File(file_path, 'r') as hf:
        graphs = hf['graphs']
        metadata_attrs = hf['metadata'].attrs

        # Every requested layer gets an entry, even if the file holds no graphs.
        embeddings = {layer: [] for layer in layers}
        labels = []

        for graph_name in tqdm(graphs, desc="Loading graphs"):
            graph_group = graphs[graph_name]

            for layer in layers:
                embeddings[layer].append(torch.from_numpy(graph_group[layer][:]).to(device))

            labels.append(torch.from_numpy(graph_group['label'][:]).to(device))

        metadata = {
            # Prefer the name recorded in the file; fall back to 'Lorentz' if absent.
            'manifold': metadata_attrs.get('manifold', 'Lorentz'),
            'curvature': metadata_attrs['curvature'],
            'total_graphs': len(graphs)
        }

    return embeddings, labels, metadata


In [None]:

# Load onto the device selected earlier in the notebook (GPU if available,
# otherwise CPU) rather than assuming CUDA is present.
embeddings, labels, metadata = load_hyperbolic_embeddings(
    'embedding/train_hyperbolic_embeddings.h5',
    device=device,
)

# Note: this is the manifold's name from the metadata dict, not a manifold object.
manifold = metadata['manifold']


Loading graphs: 100%|██████████| 32901/32901 [02:24<00:00, 227.69it/s]


In [None]:

# Each layer name maps to a plain Python list holding one tensor per graph.
type(embeddings['fc3'])

list

In [16]:
# Number of graphs loaded from the training-split file.
len(embeddings['fc3'])

32901

In [20]:
# Shape of the third stored graph's 'fc3' embedding: one row per node of that graph.
embeddings['fc3'][2].shape

torch.Size([16, 1025])

#### DataLoader for the embeddings

In [None]:
class HyperbolicDataset(torch.utils.data.Dataset):
    """Map-style dataset over the per-graph groups of an embeddings HDF5 file.

    Each item is a dict with one tensor per requested layer plus a ``'label'``
    tensor, read from ``graphs/<graph_name>`` in the file.

    Args:
        file_path: path of the HDF5 file; it must contain a ``graphs`` group.
        layers: names of the per-graph datasets to return with every item.
    """

    def __init__(self, file_path, layers=('fc3', 'fc11')):
        self.file_path = file_path
        self.layers = tuple(layers)
        # Only the group names are read up front; the file is closed again
        # before the constructor returns.
        with h5py.File(file_path, 'r') as hf:
            self.graph_names = list(hf['graphs'].keys())

    def __len__(self):
        """Return the number of graphs found in the file."""
        return len(self.graph_names)

    def __getitem__(self, idx):
        """Read the embeddings and label of the ``idx``-th graph from disk."""
        # The file is reopened on every access instead of keeping a handle on
        # the instance.
        with h5py.File(self.file_path, 'r') as hf:
            group = hf['graphs'][self.graph_names[idx]]
            sample = {layer: torch.from_numpy(group[layer][:]) for layer in self.layers}
            sample['label'] = torch.from_numpy(group['label'][:])
            return sample

In [None]:

def collate_hyperbolic_graphs(samples):
    """Merge per-graph samples into one batch by concatenating along dim 0.

    The per-graph embedding tensors can have different numbers of rows, so the
    default collate function (which stacks tensors) would raise. Every key of
    the samples is concatenated instead, and a ``'batch'`` vector records, for
    each embedding row, the position of its graph within the batch.
    """
    merged = {key: torch.cat([s[key] for s in samples], dim=0) for key in samples[0]}
    merged['batch'] = torch.cat([
        torch.full((s['fc3'].shape[0],), i, dtype=torch.long)
        for i, s in enumerate(samples)
    ])
    return merged


dataset = HyperbolicDataset('embedding/train_hyperbolic_embeddings.h5')
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_hyperbolic_graphs,
)
