In [1]:
import uproot
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from scipy import sparse
import os
from sklearn import metrics

os.environ["DGLBACKEND"] = "pytorch"
import dgl
from dgl.dataloading import GraphDataLoader
from dgl.data import DGLDataset
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
device = "cuda:1" if torch.cuda.is_available() else "cpu"

In [3]:
class MCDataset_dgl(DGLDataset):
    def __init__(self, path):
        self.path = path
        self.label = np.load(os.path.join(self.path, f"label.npy"))
        self.label = torch.tensor(self.label, dtype=torch.long)

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        g, _ = dgl.load_graphs(os.path.join(self.path, f"graph.bin"), [index])
        return (
            g[0],
            self.label[index],
        )

In [4]:
MCdataset_train = MCDataset_dgl("/tmp/hky/gnndata_gin_Pgamma/train_data/")
MCdataset_val = MCDataset_dgl("/tmp/hky/gnndata_gin_Pgamma/val_data/")

In [5]:
train_dataloader = GraphDataLoader(MCdataset_train, batch_size=64, drop_last=False,num_workers=8,shuffle=True)
val_dataloader = GraphDataLoader(MCdataset_val, batch_size=64, drop_last=False,num_workers=8,shuffle=True)

In [6]:
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling


class MLP(nn.Module):
    """Construct two-layer MLP-type aggreator for GIN model"""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.linears = nn.ModuleList()
        # two-layer MLP
        self.linears.append(nn.Linear(input_dim, hidden_dim, bias=False))
        self.linears.append(nn.Linear(hidden_dim, output_dim, bias=False))
        self.batch_norm = nn.BatchNorm1d((hidden_dim))

    def forward(self, x):
        h = x
        h = F.relu(self.batch_norm(self.linears[0](h)))
        return self.linears[1](h)


class GIN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.ginlayers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        num_layers = 5
        # five-layer GCN with two-layer MLP aggregator and sum-neighbor-pooling scheme
        for layer in range(num_layers - 1):  # excluding the input layer
            if layer == 0:
                mlp = MLP(input_dim, hidden_dim, hidden_dim)
            else:
                mlp = MLP(hidden_dim, hidden_dim, hidden_dim)
            self.ginlayers.append(
                GINConv(mlp, learn_eps=False)
            )  # set to True if learning epsilon
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # linear functions for graph sum poolings of output of each layer
        self.linear_prediction = nn.ModuleList()
        for layer in range(num_layers):
            if layer == 0:
                self.linear_prediction.append(nn.Linear(input_dim, output_dim))
            else:
                self.linear_prediction.append(nn.Linear(hidden_dim, output_dim))
        self.drop = nn.Dropout(0.5)
        self.pool = (
            MaxPooling()
        )  # change to mean readout (AvgPooling) on social network datasets

    def forward(self, g, h):
        # list of hidden representation at each layer (including the input layer)
        hidden_rep = [h]
        for i, layer in enumerate(self.ginlayers):
            h = layer(g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            hidden_rep.append(h)
        score_over_layer = 0
        # perform graph sum pooling over all nodes in each layer
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            score_over_layer += self.drop(self.linear_prediction[i](pooled_h))
        return score_over_layer

In [7]:
MCdataset_train[0][0].to(device)
# model = GIN(9, 64, 2).to(device)
# for batched_graph, labels in train_dataloader:
#     print(batched_graph.num_nodes())
# print(labels)
# batched_graph, labels = batched_graph.to(device), labels.to(device)

Graph(num_nodes=83, num_edges=6806,
      ndata_schemes={'xdata': Scheme(shape=(9,), dtype=torch.float32)}
      edata_schemes={})

In [8]:
model = GIN(9, 64, 2).to(device)
maxtpoch = 100
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, maxtpoch
)
m = nn.Softmax(dim=1)

for epoch in range(maxtpoch):
    sumloss = list()
    print(epoch)
    model.train()
    for batched_graph, labels in train_dataloader:
        if(batched_graph.num_edges()>2e8):
                continue
        batched_graph, labels = batched_graph.to(device), labels.to(device)
        pred = model(batched_graph, batched_graph.ndata["xdata"])
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        sumloss.append(loss.item())
        optimizer.step()
    lr_scheduler.step()
    print(f"loss:{np.mean(sumloss):.4f}")
    y_pred = list()
    y_orgin = list()
    model.eval()
    with torch.no_grad():
        for batched_graph, labels in val_dataloader:
            if(batched_graph.num_edges()>2e8):
                continue
            batched_graph, labels = batched_graph.to(device), labels.to(device)
            pred = model(batched_graph, batched_graph.ndata["xdata"])
           
            y_pred.append(pred[:, 1].cpu().numpy())
            y_orgin.append(labels.cpu().numpy())
    y_pred = np.concatenate(y_pred)
    y_orgin = np.concatenate(y_orgin)
    auc = metrics.roc_auc_score(y_orgin, y_pred)
    print(f"auc:{auc:.4f}")

0


  assert input.numel() == input.storage().size(), (


loss:0.5017
auc:0.8922
1
loss:0.3481
auc:0.9237
2
loss:0.3412
auc:0.9200
3
loss:0.3300
auc:0.8973
4
loss:0.3323
auc:0.9256
5
loss:0.3172
auc:0.9270
6
loss:0.3140
auc:0.9036
7
loss:0.3057
auc:0.9326
8
loss:0.3069
auc:0.9385
9
loss:0.3030
auc:0.9220
10
loss:0.3022
auc:0.9198
11
loss:0.2936
auc:0.9073
12
loss:0.2952
auc:0.9428
13
loss:0.2930
auc:0.9332
14
loss:0.3027
auc:0.9319
15
loss:0.2929
auc:0.9267
16
loss:0.2858
auc:0.9371
17
loss:0.2858
auc:0.9097
18
loss:0.2801
auc:0.9278
19
loss:0.2828
auc:0.9402
20
loss:0.2818
auc:0.9437
21
loss:0.2746
auc:0.9332
22
loss:0.2738
auc:0.9172
23
loss:0.2774
auc:0.9146
24
loss:0.2736
auc:0.9279
25
loss:0.2730
auc:0.9127
26
loss:0.2687
auc:0.9240
27
loss:0.2711
auc:0.9173
28
loss:0.2677
auc:0.9405
29
loss:0.2664
auc:0.9195
30
loss:0.2687
auc:0.9387
31
loss:0.2648
auc:0.9323
32
loss:0.2662
auc:0.9383
33
loss:0.2657
auc:0.9102
34
loss:0.2661
auc:0.9298
35
loss:0.2657
auc:0.9162
36
loss:0.2645
auc:0.9261
37
loss:0.2617
auc:0.9254
38
loss:0.2636
auc:0.940