In [None]:
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch_geometric.data import DataLoader
from torch_geometric.data import InMemoryDataset
from torch_geometric.nn import GATConv, GraphNorm, global_mean_pool

In [132]:
class PDBBindDataset(InMemoryDataset):
    def __init__(
        self,
        metadata_csv = "data/processed/refined_dataset_metadata.csv",
        complex_dir = "data/processed/complex_graphs",
        split_dir = "data/processed/splits",
        root = "data/processed/full_dataset",
        transform=None,
        pre_transform=None,
        force_rebuild=False
    ):
        self.metadata_csv = Path(metadata_csv)
        self.complex_dir = Path(complex_dir)
        self.root = Path(root)
        self.root.mkdir(parents=True, exist_ok=True)
        self.force_rebuild = force_rebuild
        
        self.cache_file = self.root / f"full_dataset.pt"

        super().__init__(self.root, transform, pre_transform)

        print("Processing complex dataset...")
        self.process()
        self.data, self.slices = torch.load(self.cache_file)
        print(self.data, self.slices)

    @property
    def raw_file_names(self):
        return [self.metadata_csv.name]

    @property
    def processed_file_names(self):
        return [self.cache_file.name]

    def process(self):
        df = pd.read_csv(self.metadata_csv)
        data_list = []
        skipped = 0

        for _, row in df.iterrows():
            cid = row["complex_id"]
            affinity = float(row["affinity"])
            graph_file = self.complex_dir / f"{cid}.pt"

            if not graph_file.exists():
                print(f"Complex graph missing for {cid}")
                skipped += 1
                continue

            try:
                graph_data = torch.load(graph_file)
                graph_data.y = torch.tensor([affinity], dtype=torch.float32)
                graph_data.complex_id = cid
                data_list.append(graph_data)
            except Exception as e:
                print(f"Failed to load complex {cid}: {e}")
                skipped += 1
                continue

            if len(data_list) % 200 == 0:
                print(f"Loaded {len(data_list)} complex graphs...")

        print(f"Finished. Loaded {len(data_list)} complexes. Skipped: {skipped}")

        if len(data_list) == 0:
            raise RuntimeError("No valid complex graphs loaded. Check paths and files.")

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.cache_file)
        print(f"Saved processed dataset to {self.cache_file}")

    def len(self):
        return super().len()

    def get(self, idx):
        return super().get(idx)

In [215]:
class CoreGNN(nn.Module):
    def __init__(
        self,
        node_dim: int,
        edge_dim: int,
        hidden_dim: int = 128,
        num_layers: int = 3,
        dropout: float = 0.2,
    ):
        super().__init__()

        self.encoder_layers = nn.ModuleList([
            GATConv(
                in_channels=node_dim if i == 0 else hidden_dim,
                out_channels=hidden_dim,
                edge_dim=edge_dim,
                heads=4,
                concat=False
            )
            for i in range(num_layers)
        ])
        self.encoder_norms = nn.ModuleList([GraphNorm(hidden_dim) for _ in range(num_layers)])

        self.cross_layers = nn.ModuleList([
            GATConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                edge_dim=edge_dim,
                heads=4,
                concat=False
            )
            for _ in range(num_layers)
        ])
        self.cross_norms = nn.ModuleList([GraphNorm(hidden_dim) for _ in range(num_layers)])

        self.readout = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.dropout = dropout
        self.hidden_dim = hidden_dim
        
    def forward(self, data):
        x, edge_index, edge_attr, batch, node_type = (
            data.x, data.edge_index, data.edge_attr, data.batch, data.node_type
        )
        src, dst = edge_index
        print(batch)
        print('src', src)
        print('dst', dst)
        print('node_type', node_type)
        print('len(node_type)', len(node_type))
        print('max(edge_index)', max(edge_index.flatten()))
        print(node_type[src])
        print(node_type[dst])
        same_mask = (node_type[src] == node_type[dst])
        print(edge_attr)
        cross_mask = (node_type[src] != node_type[dst])

        intra_edge_index = edge_index[:, same_mask]
        intra_edge_attr = edge_attr[same_mask]
        cross_edge_index = edge_index[:, cross_mask]
        cross_edge_attr = edge_attr[cross_mask]

        # ligand + protein separately
        for conv, norm in zip(self.encoder_layers, self.encoder_norms):
            x = conv(x, intra_edge_index, intra_edge_attr)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # ligand <-> protein)
        for conv, norm in zip(self.cross_layers, self.cross_norms):
            x = conv(x, cross_edge_index, cross_edge_attr)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # pooling
        out = global_mean_pool(x, batch)
        
        # regression
        y_hat = self.readout(out)
        print(y_hat)
        return y_hat.squeeze(1)


In [216]:
dataset = PDBBindDataset()

Processing...


Complex graph missing for 2r58
Complex graph missing for 3c2f
Complex graph missing for 3g2y
Complex graph missing for 3pce
Complex graph missing for 4qsu
Complex graph missing for 4qsv
Complex graph missing for 4u54
Complex graph missing for 3ao4
Complex graph missing for 4cs9
Complex graph missing for 2w8w
Complex graph missing for 3gv9
Complex graph missing for 6r9u
Complex graph missing for 6abx
Complex graph missing for 4q90
Complex graph missing for 5cs3
Complex graph missing for 4tim
Complex graph missing for 5fe6
Complex graph missing for 6ghj
Complex graph missing for 3gqz
Complex graph missing for 4y3j
Complex graph missing for 5oxk
Complex graph missing for 5z5f
Complex graph missing for 4ahr
Complex graph missing for 4ahs
Complex graph missing for 4mre
Complex graph missing for 1x8d
Complex graph missing for 4g0z
Complex graph missing for 1m0n
Complex graph missing for 2aac
Complex graph missing for 4aci
Complex graph missing for 4ury
Complex graph missing for 3ao5
Complex 

Complex graph missing for 4b34
Complex graph missing for 6g27
Complex graph missing for 1e6q
Complex graph missing for 4de5
Complex graph missing for 4fzj
Complex graph missing for 5flo
Complex graph missing for 5fnf
Complex graph missing for 6a87
Complex graph missing for 2ymd
Complex graph missing for 3f34
Complex graph missing for 5fot
Complex graph missing for 5fox
Complex graph missing for 5ose
Complex graph missing for 1lbk
Complex graph missing for 3bug
Complex graph missing for 3g1v
Complex graph missing for 4gzx
Complex graph missing for 6dyn
Complex graph missing for 6ssy
Complex graph missing for 5cs6
Complex graph missing for 5ijr
Complex graph missing for 1s5z
Complex graph missing for 1br6
Complex graph missing for 1e6s
Complex graph missing for 1ew9
Complex graph missing for 1k1y
Complex graph missing for 1zc9
Complex graph missing for 2ri9
Complex graph missing for 2uy3
Complex graph missing for 4n5d
Complex graph missing for 1px4
Complex graph missing for 3zsy
Complex 

Done!


In [217]:
len(dataset)

24

In [218]:
node_dim = 22
edge_dim = 5

model = CoreGNN(node_dim=node_dim, edge_dim=edge_dim, hidden_dim=128)

In [219]:
dataloader = DataLoader(dataset, batch_size=4)



In [220]:
for batch in dataloader:
    model(batch)
    break

tensor([0, 0, 0,  ..., 3, 3, 3])
src tensor([   0,   11,    0,  ..., 2512, 2367, 2643])
dst tensor([  11,    0,   10,  ..., 2367, 2643, 2367])
node_type tensor([0, 0, 0,  ..., 1, 1, 1])
len(node_type) 3132
max(edge_index) tensor(3131)
tensor([0, 0, 0,  ..., 1, 0, 1])
tensor([0, 0, 0,  ..., 0, 1, 0])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
        ...,
        [4.0560, 0.0000, 0.0000, 0.0000, 0.0000],
        [4.3776, 0.0000, 0.0000, 0.0000, 0.0000],
        [4.3776, 0.0000, 0.0000, 0.0000, 0.0000]])
tensor([[-0.0053],
        [ 0.1294],
        [ 0.0473],
        [ 0.1194]], grad_fn=<AddmmBackward0>)
