In [None]:
!pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html > /dev/null 2>&1
!pip install torch-geometric > /dev/null 2>&1
!pip install pyg-lib -f https://data.pyg.org/whl/torch-2.0.0+cu118.html > /dev/null 2>&1
!pip install git+https://github.com/datamol-io/graphium.git@2.2.0 > /dev/null 2>&1
!pip install rdkit > /dev/null 2>&1

In [None]:
import os
from google.cloud import storage
# Point the Google Cloud client library at a service-account key file, but do not clobber a
# credential that the environment already provides.
# NOTE(review): the key path is hard-coded to a Drive location; keep the key itself out of VCS.
os.environ.setdefault('GOOGLE_APPLICATION_CREDENTIALS', "/content/drive/MyDrive/instant-tape-******.json")
storage_client = storage.Client()

In [None]:
# Random Walk Structural Encoding
from typing import Tuple, Union, Optional, List, Dict, Any, Iterable
import torch
from torch_geometric.utils import to_dense_adj
from torch_scatter import scatter_add
from torch_geometric.utils.num_nodes import maybe_num_nodes

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def compute_rwse(
    adj: torch.Tensor,
    ksteps: Union[int, List[int]],
    num_nodes: int,
    cache: Dict[str, Any],
    pos_type: str = "rw_return_probs",
    space_dim: int = 0,
) -> Tuple[torch.Tensor, str, Dict[str, Any]]:
    """Random-walk structural encoding from powers ``P**k`` of the transition matrix.

    Args:
        adj: Adjacency matrix; its non-zero entries are used as edge weights.
        ksteps: Either an int ``K`` (meaning steps ``1..K``) or an explicit list of step counts.
        num_nodes: Number of nodes in the graph.
        cache: Dict shared between encoders. Reads/writes ``"ksteps"`` (cached step counts) and
            ``"Pk"`` (dict ``k -> P**k``) so later calls can reuse earlier powers.
        pos_type: ``"rw_return_probs"`` keeps only the diagonal of each ``P**k`` (node level);
            any other value keeps the full ``P**k`` matrices (nodepair level).
        space_dim: Return probabilities are scaled by ``k ** (space_dim / 2)``.

    Returns:
        ``(pe, base_level, cache)``: ``pe`` stacks one slice per requested step along the last
        dimension; ``base_level`` is ``"node"`` or ``"nodepair"``.
    """
    base_level = "node" if pos_type == "rw_return_probs" else "nodepair"

    if not isinstance(ksteps, Iterable):
        ksteps = list(range(1, ksteps + 1))

    # Single-atom molecules: the walk trivially stays put, so every return probability is 1.
    if num_nodes == 1:
        if pos_type == "rw_return_probs":
            return torch.ones((1, len(ksteps))), base_level, cache
        return torch.ones((1, 1, len(ksteps))), base_level, cache

    # Both use row-major order, so edge_weight[i] belongs to (edge_index[0][i], edge_index[1][i]).
    edge_index, edge_weight = adj.nonzero(as_tuple=True), adj[adj != 0]

    if "ksteps" in cache:
        cached_k = cache["ksteps"]
        missing_k = [k for k in ksteps if k not in cached_k]
        if missing_k:
            first_missing = min(missing_k)
            lower_cached = [k for k in cached_k if k < first_missing]
            if lower_cached:
                # Resume from the largest cached power below the first missing step. That key is
                # guaranteed to exist in cache["Pk"]; `first_missing` itself never is.
                start_k = max(lower_cached)
                Pk_dict = get_Pks(
                    missing_k,
                    edge_index=edge_index,
                    edge_weight=edge_weight,
                    num_nodes=num_nodes,
                    start_Pk=cache["Pk"][start_k],
                    start_k=start_k,
                )
            else:
                # Every cached power is larger than what we need: compute from scratch.
                Pk_dict = get_Pks(missing_k, edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes)
            cache["ksteps"] = sorted(cached_k + missing_k)
            for k in missing_k:
                cache["Pk"][k] = Pk_dict[k]
    else:
        Pk_dict = get_Pks(ksteps, edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes)
        cache["ksteps"] = list(Pk_dict.keys())
        cache["Pk"] = Pk_dict

    if pos_type == "rw_return_probs":
        pe_list = [torch.diagonal(cache["Pk"][k]) * (k ** (space_dim / 2)) for k in ksteps]
    else:
        pe_list = [cache["Pk"][k] for k in ksteps]

    pe = torch.stack(pe_list, dim=-1)
    return pe, base_level, cache

def get_Pks(
    ksteps: List[int],
    edge_index: Tuple[torch.Tensor, torch.Tensor],
    edge_weight: Optional[torch.Tensor] = None,
    num_nodes: Optional[int] = None,
    start_Pk: Optional[torch.Tensor] = None,
    start_k: Optional[int] = None,
) -> Dict[int, torch.Tensor]:
    """Powers of the random-walk transition matrix ``P = D^-1 A``.

    Args:
        ksteps: Requested step counts; every power in ``range(min(ksteps), max(ksteps) + 1)`` is returned.
        edge_index: ``(src, dst)`` index tensors of the (directed) edges.
        edge_weight: Per-edge weights; defaults to 1 for every edge. Duplicate edges are summed.
        num_nodes: Node count; inferred as ``max index + 1`` when ``None``.
        start_Pk: Optional precomputed ``P**start_k`` to continue from instead of starting at ``P**0``.
        start_k: The exponent of ``start_Pk``.

    Returns:
        Dict mapping each ``k`` in the range above to a float32 ``(num_nodes, num_nodes)`` tensor ``P**k``,
        placed on the module-level ``device``.
    """
    src = edge_index[0].to(device, dtype=torch.int64)
    dst = edge_index[1].to(device, dtype=torch.int64)

    if edge_weight is None:
        edge_weight = torch.ones(src.numel(), device=device, dtype=torch.float32)
    else:
        edge_weight = edge_weight.to(device).float()

    # Resolve the node count before allocating anything sized by it.
    if num_nodes is None:
        num_nodes = int(torch.cat([src, dst]).max()) + 1 if src.numel() > 0 else 0
    num_nodes = int(num_nodes)

    # Dense weighted adjacency (duplicates accumulate) and weighted out-degrees.
    adj_dense = torch.zeros((num_nodes, num_nodes), device=device, dtype=torch.float32)
    adj_dense.index_put_((src, dst), edge_weight, accumulate=True)
    deg = torch.zeros(num_nodes, device=device, dtype=torch.float32).index_add_(0, src, edge_weight)
    deg_inv = deg.pow(-1.0)
    deg_inv.masked_fill_(deg_inv == float("inf"), 0)  # isolated nodes keep an all-zero row

    # Row-normalise. For an edgeless graph this is a float32 zero matrix (never an integer one).
    P = deg_inv.unsqueeze(1) * adj_dense

    first_k = min(ksteps)
    if start_Pk is not None:
        Pk = start_Pk.to(device).float() @ torch.linalg.matrix_power(P, first_k - start_k)
    else:
        Pk = torch.linalg.matrix_power(P, first_k)

    Pk_dict = {first_k: Pk}
    for k in range(first_k + 1, max(ksteps) + 1):
        Pk = Pk @ P
        Pk_dict[k] = Pk

    return Pk_dict

In [None]:
# Laplacian Positional Encoding
import torch
from typing import Tuple, Union, Dict, Any

def compute_laplacian_pe(
    adj: torch.Tensor,
    num_pos: int,
    cache: Dict[str, Any],
    normalization: str = "none",
) -> Tuple[torch.Tensor, torch.Tensor, str, Dict[str, Any]]:
    """Laplacian positional encoding: the ``num_pos`` lowest eigenpairs of the graph Laplacian.

    Args:
        adj: Square adjacency matrix (dense or sparse). ``torch.linalg.eigh`` only reads one
            triangle, so this assumes an undirected (symmetric) graph.
        num_pos: Number of eigenpairs to keep (padded by ``_get_positional_eigvecs`` if larger than n).
        cache: Dict shared between encoders. Reads/writes ``"csr_adj"`` (sparse copy of ``adj``),
            ``"L_<normalization>_sp"`` (normalised Laplacian), ``"lap_eig"`` and ``"lap_eig_params"``.
        normalization: Passed to ``normalize_matrix`` (``"none"``, ``"sym"`` or ``"inv"``).

    Returns:
        ``(eigvals, eigvecs, "node", cache)``. ``eigvals`` is repeated per node, shape ``(n, num_pos)``;
        ``eigvecs`` has shape ``(n, num_pos)``. Non-finite entries are replaced with 0.
    """
    base_level = "node"

    # Sparse copy of the adjacency, cached under the historical key "csr_adj"
    # (it is a COO tensor produced by `to_sparse`, despite the name).
    if not adj.is_sparse:
        if "csr_adj" not in cache:
            cache["csr_adj"] = adj.to_sparse()
        adj = cache["csr_adj"]
    adj_dense = adj.to_dense()
    num_nodes = adj_dense.shape[0]

    # Combinatorial Laplacian L = D - A, normalised and cached per normalisation scheme.
    lap_key = f"L_{normalization}_sp"
    if lap_key not in cache:
        degrees = adj_dense.sum(dim=1)
        laplacian = torch.diag(degrees) - adj_dense
        cache[lap_key] = normalize_matrix(laplacian, degree_vector=degrees, normalization=normalization)
    L_norm = cache[lap_key]

    # The eigen-decomposition depends on both `normalization` and `num_pos`; recompute whenever either
    # differs from what is cached instead of silently returning stale eigenpairs.
    eig_params = (normalization, num_pos)
    if "lap_eig" not in cache or cache.get("lap_eig_params") != eig_params:
        epsilon = 1e-8
        L_dense = L_norm.to_dense()
        # Out-of-place shift so the cached Laplacian is never mutated.
        L_dense = L_dense + epsilon * torch.eye(L_dense.size(0), device=L_dense.device, dtype=L_dense.dtype)

        eigvals, eigvecs = _get_positional_eigvecs(L_dense, num_pos=num_pos)

        eigvecs[~torch.isfinite(eigvecs)] = 0.0
        eigvals[~torch.isfinite(eigvals)] = 0.0

        # Repeat the eigenvalues for each node so they can be used as a node-level feature.
        eigvals = eigvals.unsqueeze(0).repeat(num_nodes, 1)

        cache["lap_eig"] = (eigvals, eigvecs)
        cache["lap_eig_params"] = eig_params
    else:
        eigvals, eigvecs = cache["lap_eig"]

    return eigvals, eigvecs, base_level, cache

def _get_positional_eigvecs(
    matrix: torch.Tensor,
    num_pos: int
) -> Tuple[torch.Tensor, torch.Tensor]:

    matrix = matrix.to_dense()
    eigvals, eigvecs = torch.linalg.eigh(matrix)

    # Pad with non-sense eigenvectors if required
    if num_pos > matrix.shape[0]:
        temp_EigVal = torch.ones(num_pos - matrix.shape[0], dtype=torch.float64, device=device) + float("inf")
        temp_EigVec = torch.zeros((matrix.shape[0], num_pos - matrix.shape[0]), dtype=torch.float64, device=device)
        eigvals = torch.cat([eigvals, temp_EigVal], dim=0)
        eigvecs = torch.cat([eigvecs, temp_EigVec], dim=1)

    # Sort and keep only the first `num_pos` elements
    sort_idx = eigvals.argsort()
    eigvals = eigvals[sort_idx]
    eigvals = eigvals[:num_pos]
    eigvecs = eigvecs[:, sort_idx]
    eigvecs = eigvecs[:, :num_pos]

    # Normalize the eigvecs
    eigvecs = eigvecs / torch.maximum(torch.sqrt(torch.sum(eigvecs**2, dim=0, keepdim=True)), torch.tensor(1e-4, device=matrix.device, dtype=matrix.dtype))

    return eigvals, eigvecs

def normalize_matrix(
    matrix: torch.Tensor,
    degree_vector: torch.Tensor,
    normalization: str = None
) -> torch.Tensor:
    """Degree-normalise a square matrix.

    Args:
        matrix: ``(n, n)`` matrix, dense or sparse.
        degree_vector: Length-``n`` degrees (dense or sparse, any shape that flattens to ``n``).
        normalization: ``None``/``"none"`` (return ``matrix`` unchanged), ``"sym"``
            (``D^-1/2 M D^-1/2``) or ``"inv"`` (``D^-1 M``); case-insensitive.

    Returns:
        The normalised matrix, sparse if ``matrix`` was sparse, dense otherwise. Zero degrees get a
        zero inverse (instead of inf).

    Raises:
        ValueError: unknown ``normalization``, or ``degree_vector`` is ``None`` for ``"sym"``/``"inv"``.
    """
    if (normalization is None) or (normalization.lower() == "none"):
        return matrix

    mode = normalization.lower()
    if mode not in ("sym", "inv"):
        raise ValueError(
            f'`normalization` should be `None`, `"None"`, `"sym"` or `"inv"`, but `{normalization}` was provided'
        )
    if degree_vector is None:
        raise ValueError("`degree_vector` is required for `sym` or `inv` normalization")

    degrees = degree_vector.to_dense().reshape(-1)
    if not degrees.is_floating_point():
        degrees = degrees.float()
    inv_sqrt = degrees.pow(-0.5)
    inv_sqrt[torch.isinf(inv_sqrt)] = 0

    # Work densely: sparse (n, 1) * (n, n) broadcasting is not supported by torch.sparse.
    was_sparse = matrix.is_sparse
    dense = matrix.to_dense()
    if mode == "sym":
        out = inv_sqrt.unsqueeze(1) * dense * inv_sqrt.unsqueeze(0)
    else:
        out = inv_sqrt.pow(2).unsqueeze(1) * dense

    return out.to_sparse() if was_sparse else out

In [None]:
# Transfer Position Levels
from typing import Tuple, Union, List, Dict, Any, Optional
import torch
from torch_geometric.utils import from_scipy_sparse_matrix

def transfer_pos_level(
    pe: torch.Tensor,
    in_level: str,
    out_level: str,
    adj: Union[torch.Tensor, torch.sparse.FloatTensor],
    num_nodes: int,
    cache: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    """Convert a positional encoding from ``in_level`` to ``out_level``.

    Levels are ``"node"``, ``"edge"``, ``"nodepair"`` and ``"graph"``. Implemented conversions:
    node -> node/edge/nodepair, nodepair -> node/edge/nodepair, graph -> node. Other known
    combinations raise ``NotImplementedError``; unknown level names raise ``ValueError``.

    ``pe`` and ``adj`` are moved to CUDA when it is available and to the CPU otherwise
    (CUDA used to be hard-coded, which crashed on CPU-only runtimes).

    Args:
        pe: Encoding at ``in_level``; a 2-D nodepair encoding gets a trailing feature axis.
        adj: Adjacency matrix, dense or sparse.
        num_nodes: Node count of the graph.
        cache: Shared cache dict (``"components"`` is read for graph -> node).

    Returns:
        The encoding at ``out_level``.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pe = pe.to(target)
    adj = adj.to(target)

    if cache is None:
        cache = {}

    if in_level == "node":
        if out_level == "node":
            pass
        elif out_level == "edge":
            pe, cache = node_to_edge(pe, adj, cache)
        elif out_level == "nodepair":
            pe = node_to_nodepair(pe, num_nodes)
        elif out_level == "graph":
            raise NotImplementedError("Transfer function (node -> graph) not yet implemented.")
        else:
            raise ValueError(f"Unknown `pos_level`: {out_level}")
    elif in_level == "edge":
        raise NotImplementedError("Transfer function (edge -> *) not yet implemented.")
    elif in_level == "nodepair":
        if len(pe.shape) == 2:
            pe = torch.unsqueeze(pe, -1)

        if out_level == "node":
            pe = nodepair_to_node(pe)
        elif out_level == "edge":
            pe, cache = nodepair_to_edge(pe, adj, cache)
        elif out_level == "nodepair":
            pass
        elif out_level == "graph":
            raise NotImplementedError("Transfer function (nodepair -> graph) not yet implemented.")
        else:
            raise ValueError(f"Unknown `pos_level`: {out_level}")
    elif in_level == "graph":
        if out_level == "node":
            pe = graph_to_node(pe, num_nodes, cache)
        elif out_level in ["edge", "nodepair"]:
            raise NotImplementedError("Transfer function (graph -> edge/nodepair) not yet implemented.")
        else:
            raise ValueError(f"Unknown `pos_level`: {out_level}")
    else:
        raise ValueError(f"Unknown `pos_level`: {in_level}")

    return pe

def node_to_edge(
    pe: torch.Tensor, adj: Union[torch.Tensor, torch.sparse.FloatTensor], cache: Optional[Dict[str, Any]] = None
) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Lift node encodings to edges.

    For each edge ``(u, v)`` (a non-zero of ``adj``, in row-major order) the edge encoding is
    ``cat(pe[u] + pe[v], |pe[u] - pe[v]|)``.

    Returns:
        ``(edge_pe, cache)`` with ``edge_pe`` of shape ``(num_edges, 2 * F)``.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pe = pe.to(target)
    adj = adj.to(target)

    if cache is None:
        cache = {}

    # Read the edge list directly from the torch tensor (`from_scipy_sparse_matrix` only accepts
    # SciPy matrices and failed on torch input).
    if adj.is_sparse:
        src, dst = adj.coalesce().indices()
    else:
        src, dst = adj.nonzero(as_tuple=True)

    pe_sum = pe[src] + pe[dst]
    pe_abs_diff = torch.abs(pe[src] - pe[dst])

    edge_pe = torch.cat((pe_sum, pe_abs_diff), dim=-1)

    return edge_pe, cache

def node_to_nodepair(pe: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Lift node encodings to node pairs.

    Assumes ``pe`` has shape ``(num_nodes, F)``. Returns ``(num_nodes, num_nodes, 2 * F)``, where
    entry ``[i, j]`` is ``cat(pe[i] + pe[j], |pe[i] - pe[j]|)``. The result is placed on CUDA when
    available, otherwise on the CPU.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pe = pe.to(target)

    # rows[i, j] = pe[i] and cols[i, j] = pe[j]; `expand` gives the values `repeat` used to,
    # without materialising a copy.
    rows = pe.unsqueeze(1).expand(-1, num_nodes, -1)
    cols = rows.transpose(0, 1)

    return torch.cat((rows + cols, torch.abs(rows - cols)), dim=-1)

def node_to_graph(pe: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Node -> graph transfer placeholder.

    Always raises ``NotImplementedError``; the parameters exist only so the signature matches
    the other transfer helpers.
    """
    message = "Transfer function (node -> graph) not yet implemented."
    raise NotImplementedError(message)

def edge_to_node(pe: torch.Tensor, adj: Union[torch.Tensor, torch.sparse.FloatTensor]) -> torch.Tensor:
    """Edge -> node transfer placeholder; always raises ``NotImplementedError``."""
    reason = "Transfer function (edge -> node) not yet implemented."
    raise NotImplementedError(reason)

def edge_to_nodepair(
    pe: torch.Tensor, adj: Union[torch.Tensor, torch.sparse.FloatTensor], num_nodes: int, cache: Optional[Dict[str, Any]] = None
) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Scatter edge encodings into a dense node-pair tensor.

    Row ``i`` of ``pe`` is written to ``[r, c]`` for the ``i``-th non-zero ``(r, c)`` of ``adj``
    in coalesced (row-major) order; every other pair is zero.

    Returns:
        ``(nodepair_pe, cache)`` with ``nodepair_pe`` of shape ``(num_nodes, num_nodes, F)``.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pe = pe.to(target)
    adj = adj.to(target)

    if cache is None:
        cache = {}

    # `to_sparse()` replaces the legacy `torch.sparse.FloatTensor(dense)` constructor; coalescing
    # is required before `indices()` and fixes the row-major edge order.
    sparse_adj = adj.coalesce() if adj.is_sparse else adj.to_sparse()
    rows, cols = sparse_adj.indices()

    nodepair_pe = torch.zeros((num_nodes, num_nodes, pe.shape[-1]), dtype=pe.dtype, device=pe.device)
    nodepair_pe[rows, cols] = pe  # vectorised replacement of the per-edge Python loop

    return nodepair_pe, cache

def edge_to_graph(pe: torch.Tensor) -> torch.Tensor:
    """Edge -> graph transfer placeholder; always raises ``NotImplementedError``."""
    reason = "Transfer function (edge -> graph) not yet implemented."
    raise NotImplementedError(reason)

def nodepair_to_node(pe: torch.Tensor, stats_list: Optional[List] = None) -> torch.Tensor:
    """Reduce a node-pair encoding to node features using summary statistics.

    For each statistic (default: ``torch.min``, ``torch.mean``, ``torch.std``) and each feature
    channel of ``pe`` (assumed shape ``(n, n, F)``), the statistic is taken along dim 0 and then along
    dim 1. Outputs are stacked on the last axis.

    Returns:
        Tensor of shape ``(n, 2 * F * len(stats_list))``.
    """
    if stats_list is None:  # avoid a shared mutable default argument
        stats_list = (torch.min, torch.mean, torch.std)

    node_pe_list = []
    for stat in stats_list:
        for i in range(pe.shape[-1]):
            for dim in (0, 1):
                reduced = stat(pe[..., i], dim=dim)
                # torch.min / torch.max with `dim` return a (values, indices) tuple.
                if isinstance(reduced, tuple):
                    reduced = reduced[0]
                node_pe_list.append(reduced)

    return torch.stack(node_pe_list, dim=-1)

def nodepair_to_edge(
    pe: torch.Tensor, adj: Union[torch.Tensor, torch.sparse.FloatTensor], cache: Optional[Dict[str, Any]] = None
) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Gather node-pair encodings at the edges of ``adj``.

    Edge ``i`` is the ``i``-th non-zero ``(r, c)`` of ``adj`` in coalesced (row-major) order and
    receives ``pe[r, c]``. A 2-D ``pe`` is treated as having one feature channel.

    Returns:
        ``(edge_pe, cache)`` with ``edge_pe`` of shape ``(num_edges, F)``.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pe = pe.to(target)
    adj = adj.to(target)

    if cache is None:
        cache = {}

    if pe.dim() == 2:
        pe = pe.unsqueeze(-1)

    # `to_sparse()` replaces the legacy `torch.sparse.FloatTensor(dense)` constructor; coalescing
    # is required before `indices()` and fixes the row-major edge order.
    sparse_adj = adj.coalesce() if adj.is_sparse else adj.to_sparse()
    rows, cols = sparse_adj.indices()

    edge_pe = pe[rows, cols]  # vectorised replacement of the per-edge Python loop

    return edge_pe, cache

def nodepair_to_graph(pe: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Nodepair -> graph transfer placeholder; always raises ``NotImplementedError``."""
    reason = "Transfer function (nodepair -> graph) not yet implemented."
    raise NotImplementedError(reason)

def graph_to_node(
    pe: Union[torch.Tensor, List], num_nodes: int, cache: Optional[Dict[str, Any]] = None
) -> torch.Tensor:
    """Broadcast a graph-level encoding to every node.

    If ``cache["components"]`` holds more than one connected component (an iterable of node-index
    collections), node ``v`` in component ``i`` receives ``pe[i]`` broadcast across a
    ``len(pe)``-wide row. The result is on CUDA when available, else the CPU. Otherwise
    ``pe`` is tiled ``num_nodes`` times along dim 0.
    """
    if cache is None:
        cache = {}

    # The key 'components' is only in cache if disconnected_comp == True when computing base pe.
    components = cache.get("components")
    if components is not None and len(components) > 1:
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        node_pe = torch.zeros((num_nodes, len(pe)), device=target)
        for i, component in enumerate(components):
            node_pe[list(component), :] = pe[i]
        return node_pe

    return pe.repeat(num_nodes, 1)

In [None]:
# Get Positional Encodings
from typing import Tuple, Union, Optional, Dict, Any, OrderedDict
from copy import deepcopy
import torch
from scipy.sparse import spmatrix
from collections import OrderedDict as OderedDictClass

# Define the device (GPU if available)
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

def get_all_positional_encodings(
    adj: Union[torch.Tensor, spmatrix],
    num_nodes: int,
    pos_kwargs: Optional[Dict] = None,
) -> "OrderedDict[str, torch.Tensor]":
    """Compute every positional encoding listed in ``pos_kwargs["pos_types"]``.

    Each entry of ``pos_types`` is a dict with ``"pos_type"``, ``"pos_level"`` and extra keyword
    arguments for the encoder. The input config is deep-copied, never mutated. One cache dict is
    threaded through all encoders so they can reuse shared intermediates (e.g. eigendecompositions).

    Returns:
        OrderedDict keyed by ``pos_type`` for node-level encodings and ``f"{pos_level}_{pos_type}"``
        otherwise. It is empty when ``pos_kwargs`` is ``None``/empty or has no ``"pos_types"``.
    """
    pos_types = (pos_kwargs or {}).get("pos_types", {})

    pe_dict = OderedDictClass()
    cache: Dict[str, Any] = {}

    for this_pos_kwargs in pos_types.values():
        this_pos_kwargs = deepcopy(this_pos_kwargs)
        pos_type = this_pos_kwargs.pop("pos_type", None)
        pos_level = this_pos_kwargs.pop("pos_level", None)
        this_pe, cache = graph_positional_encoder(
            adj.clone(),
            num_nodes,
            pos_type=pos_type,
            pos_level=pos_level,
            pos_kwargs=this_pos_kwargs,
            cache=cache,
        )
        key = f"{pos_type}" if pos_level == "node" else f"{pos_level}_{pos_type}"
        pe_dict[key] = this_pe

    return pe_dict

def graph_positional_encoder(
    adj: Union[torch.Tensor, spmatrix],
    num_nodes: int,
    pos_type: Optional[str] = None,
    pos_level: Optional[str] = None,
    pos_kwargs: Optional[Dict[str, Any]] = None,
    cache: Optional[Dict[str, Any]] = None,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
    """Compute one positional encoding and convert it to ``pos_level``.

    Args:
        adj: Adjacency matrix; moved to the module-level ``device``.
        num_nodes: Node count.
        pos_type: ``"laplacian_eigvec"``, ``"laplacian_eigval"`` or ``"rw_return_probs"``.
        pos_level: Target level passed to ``transfer_pos_level``.
        pos_kwargs: Extra keyword arguments for the encoder (deep-copied). Stray ``"pos_type"`` /
            ``"pos_level"`` keys are dropped.
        cache: Shared cache. A dict passed in (even an empty one) is filled in place and returned.

    Returns:
        ``(pe, cache)``.

    Raises:
        ValueError: unknown ``pos_type``.
    """
    pos_kwargs = deepcopy(pos_kwargs) if pos_kwargs else {}
    # `is None` rather than truthiness so a caller's empty dict is reused, not replaced.
    cache = {} if cache is None else cache

    # Routing keys that the compute_* functions do not accept.
    pos_kwargs.pop("pos_type", None)
    pos_kwargs.pop("pos_level", None)

    adj = adj.to(device)

    if pos_type in ("laplacian_eigvec", "laplacian_eigval"):
        eigvals, eigvecs, base_level, cache = compute_laplacian_pe(adj, cache=cache, **pos_kwargs)
        pe = eigvecs if pos_type == "laplacian_eigvec" else eigvals
    elif pos_type == "rw_return_probs":
        pe, base_level, cache = compute_rwse(
            adj, num_nodes=num_nodes, cache=cache, pos_type=pos_type, **pos_kwargs
        )
    else:
        raise ValueError(f"Unknown `pos_type`: {pos_type}")

    # Convert between different pos levels
    if isinstance(pe, (list, tuple)):
        pe = [transfer_pos_level(this_pe, base_level, pos_level, adj, num_nodes, cache) for this_pe in pe]
    else:
        pe = transfer_pos_level(pe, base_level, pos_level, adj, num_nodes, cache)

    return pe, cache

In [None]:
import torch
from io import BytesIO

# Check for GPU availability
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# GCS locations: raw graphs are read from SOURCE_BUCKET and encoded graphs written to TARGET_BUCKET.
SOURCE_BUCKET_NAME = 'pyg-molecular-sample'
TARGET_BUCKET_NAME = 'pyg-molecular-positional-2'
CLIENT = storage.Client()
SOURCE_BUCKET, TARGET_BUCKET = (CLIENT.bucket(name) for name in (SOURCE_BUCKET_NAME, TARGET_BUCKET_NAME))

# Load PyG object from GCS
def load_pyg_object_from_gcs(blob_name):
    """Download ``blob_name`` from SOURCE_BUCKET and return its ``'graph_with_features'`` entry on ``device``.

    SECURITY NOTE: ``torch.load`` unpickles arbitrary objects. ``weights_only=False`` is needed to
    restore PyG ``Data`` objects (and keeps the pre-2.6 default explicit), so only use this on
    buckets you trust.
    """
    payload = SOURCE_BUCKET.blob(blob_name).download_as_bytes()
    # map_location="cpu": files saved from a GPU session still load on a CPU-only runtime;
    # the graph is then moved to `device` explicitly.
    content = torch.load(BytesIO(payload), map_location="cpu", weights_only=False)
    return content['graph_with_features'].to(device)

# Save PyG object with encodings to GCS
def save_to_gcs(data, blob_name):
    """Serialize ``{'graph_with_features': data}`` (with ``data`` moved to the CPU) into TARGET_BUCKET/blob_name."""
    payload = BytesIO()
    torch.save({'graph_with_features': data.cpu()}, payload)
    payload.seek(0)  # upload reads from the current position
    TARGET_BUCKET.blob(blob_name).upload_from_file(payload)

# Compute the positional encodings and store them in the PyG object
def compute_and_store_encodings(data, pos_kwargs: Optional[Dict[str, Any]] = None):
    """Compute positional encodings for a PyG graph and attach them as attributes.

    Args:
        data: Graph exposing ``num_nodes`` and a ``(2, E)`` ``edge_index``; it is mutated in place.
        pos_kwargs: Encoder configuration for ``get_all_positional_encodings``; defaults to
            ``DEFAULT_POS_KWARGS``.

    Returns:
        ``data`` with one attribute per encoding key (placed on ``device``).
    """
    if pos_kwargs is None:
        pos_kwargs = DEFAULT_POS_KWARGS

    n = data.num_nodes
    # Symmetric binary adjacency built with vectorised indexing (previously a Python loop that
    # synchronised with the GPU on every edge via int()).
    adjacency_matrix = torch.zeros((n, n), device=device)
    src, dst = data.edge_index.to(device)
    adjacency_matrix[src, dst] = 1
    adjacency_matrix[dst, src] = 1

    results = get_all_positional_encodings(adjacency_matrix, n, pos_kwargs=pos_kwargs)

    # Directly attach the tensor attributes to the PyG Data object
    for key, encoding in results.items():
        setattr(data, key, encoding.to(device))

    return data


# Default encodings: 8 Laplacian eigenvectors/eigenvalues (un-normalised Laplacian) and random-walk
# return probabilities for steps 4 and 8.
DEFAULT_POS_KWARGS = {
    "pos_types": {
        "lap_eigvec": {
            "pos_level": "node",
            "pos_type": "laplacian_eigvec",
            "num_pos": 8,
            "normalization": "none",
        },
        "lap_eigval": {
            "pos_level": "node",
            "pos_type": "laplacian_eigval",
            "num_pos": 8,
            "normalization": "none",
        },
        "rw_return_probs": {
            "pos_type": "rw_return_probs",
            "pos_level": "node",
            "ksteps": [4, 8]
        }
    }
}

# Set to True to resume an interrupted run: graphs that already exist in TARGET_BUCKET are skipped.
# The default (False) recomputes and overwrites every graph, as before.
SKIP_EXISTING = False
existing_targets = {b.name for b in TARGET_BUCKET.list_blobs()} if SKIP_EXISTING else set()

for blob in SOURCE_BUCKET.list_blobs():
    if blob.name in existing_targets:
        continue
    data = load_pyg_object_from_gcs(blob.name)
    updated_data = compute_and_store_encodings(data)
    save_to_gcs(updated_data, blob.name)

    # Drop the references and release cached GPU memory between graphs
    del data
    del updated_data
    torch.cuda.empty_cache()

KeyboardInterrupt: ignored