In [1]:
# TEST IPYNB FOR DAGNNs

In [2]:
# ML Imports
import numpy as np
import numpy.ma as ma
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib as mpl
import matplotlib.pyplot as plt

# DGL Graph Learning Imports
from dgl import save_graphs, load_graphs, batch
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader
from dgl.data.utils import save_info, load_info

# PyTorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as Functional
import torch.optim as optim
from dgl.data.utils import Subset

# PyTorch Ignite Imports
from ignite.engine import Engine, Events, EventEnum, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.contrib.metrics import ROC_AUC, RocCurve
from ignite.contrib.handlers.tensorboard_logger import *
from ignite.handlers import global_step_from_engine, EarlyStopping

# Miscellaneous Imports
import os

Using backend: pytorch


In [3]:
#--------------------------------------------------#
# Model Definitions
# - ApplyNodeFunc
# - MLP
# - GIN
#--------------------------------------------------#

from __future__ import absolute_import, division, print_function

# DGL Graph Learning Imports
import dgl
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling

# PyTorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as Functional

# GIN ARCHITECTURE

"""
How Powerful are Graph Neural Networks
https://arxiv.org/abs/1810.00826
https://openreview.net/forum?id=ryGs6iA5Km
Author's implementation: https://github.com/weihua916/powerful-gnns
"""

class ApplyNodeFunc(nn.Module):
    """Update the node feature hv with MLP, BN and ReLU."""
    def __init__(self, mlp):
        super(ApplyNodeFunc, self).__init__()
        self.mlp = mlp
        self.bn = nn.BatchNorm1d(self.mlp.output_dim)

    def forward(self, h):
        h = self.mlp(h)
        h = self.bn(h)
        h = Functional.relu(h)
        return h


class MLP(nn.Module):
    """MLP with linear output"""
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """MLP layers construction

        Paramters
        ---------
        num_layers: int
            The number of linear layers
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The number of classes for prediction

        """
        super(MLP, self).__init__()
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()
            self.batch_norms = torch.nn.ModuleList()

            self.linears.append(nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(nn.Linear(hidden_dim, output_dim))

            for layer in range(num_layers - 1):
                self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))

    def forward(self, x):
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        else:
            # If MLP
            h = x
            for i in range(self.num_layers - 1):
                h = Functional.relu(self.batch_norms[i](self.linears[i](h)))
            return self.linears[-1](h)


class GIN(nn.Module):
    """GIN model"""
    def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim,
                 output_dim, final_dropout, learn_eps, graph_pooling_type,
                 neighbor_pooling_type):
        """model parameters setting

        Paramters
        ---------
        num_layers: int
            The number of linear layers in the neural network
        num_mlp_layers: int
            The number of linear layers in mlps
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The number of classes for prediction
        final_dropout: float
            dropout ratio on the final linear layer
        learn_eps: boolean
            If True, learn epsilon to distinguish center nodes from neighbors
            If False, aggregate neighbors and center nodes altogether.
        neighbor_pooling_type: str
            how to aggregate neighbors (sum, mean, or max)
        graph_pooling_type: str
            how to aggregate entire nodes in a graph (sum, mean or max)

        """
        super(GIN, self).__init__()
        self.num_layers = num_layers
        self.learn_eps = learn_eps

        # List of MLPs
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(self.num_layers - 1):
            if layer == 0:
                mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)
            else:
                mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim)

            self.ginlayers.append(
                GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Linear function for graph poolings of output of each layer
        # which maps the output of different layers into a prediction score
        self.linears_prediction = torch.nn.ModuleList()

        for layer in range(num_layers):
            if layer == 0:
                self.linears_prediction.append(
                    nn.Linear(input_dim, output_dim))
            else:
                self.linears_prediction.append(
                    nn.Linear(hidden_dim, output_dim))

        self.drop = nn.Dropout(final_dropout)

        if graph_pooling_type == 'sum':
            self.pool = SumPooling()
        elif graph_pooling_type == 'mean':
            self.pool = AvgPooling()
        elif graph_pooling_type == 'max':
            self.pool = MaxPooling()
        else:
            raise NotImplementedError

    def forward(self, g, key='data'):
        # list of hidden representation at each layer (including input)
        h = g.ndata[key].float()
        hidden_rep = [h]

        for i in range(self.num_layers - 1):
            h = self.ginlayers[i](g, h)
            h = self.batch_norms[i](h)
            h = Functional.relu(h)
            hidden_rep.append(h)

        score_over_layer = 0

        # perform pooling over all nodes in each graph in every layer
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            score_over_layer += self.drop(self.linears_prediction[i](pooled_h))

        return score_over_layer

    @property
    def name(self):
        """Name of model."""
        return "GIN"
    
class Classifier(nn.Module):
    """
        Classifier
    """
    def __init__(self, input_size=512, num_classes=10):
        super(Classifier, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )
        
    def forward(self, h):
        c = self.layer(h)
        return c
    
class Discriminator(nn.Module):
    """
        Simple Discriminator w/ MLP
    """
    def __init__(self, input_size=512, num_classes=1):
        super(Discriminator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, num_classes),
            nn.Sigmoid(),
        )
    
    def forward(self, h):
        y = self.layer(h)
        return y
    
# Define dataset class
class GraphDataset(DGLDataset):
    _url = None
    _sha1_str = None
    mode = "mode"
    num_classes = 2
    dataset = None

    def __init__(self, name, dataset=None, raw_dir=None, force_reload=False, verbose=False, num_classes=2):
        self.dataset = dataset
        super(GraphDataset, self).__init__(name=name,
                                          url=self._url,
                                          raw_dir=raw_dir,
                                          force_reload=force_reload,
                                          verbose=verbose
                                          )
        self.num_classes = num_classes

    def process(self):
        mat_path = os.path.join(self.raw_path,self.mode+'_dgl_graph.bin')
        # process data to a list of graphs and a list of labels
        if self.dataset != None:
            self.graphs, self.labels = self.dataset["data"], torch.LongTensor(self.dataset["target"])
        else:
            self.graphs, self.labels = load_graphs(mat_path)

    def __getitem__(self, idx):
        """ Get graph and label by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        (dgl.DGLGraph, Tensor)
        """
        return self.graphs[idx], self.labels[idx]

    def __len__(self):
        """Number of graphs in the dataset"""
        return len(self.graphs)

    def save(self):
        # save graphs and labels
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        save_graphs(graph_path, self.graphs, {'labels': self.labels})
        # save other information in python dict
        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
        save_info(info_path, {'num_classes': self.num_classes})
    
    def load(self):
        # load processed data from directory `self.save_path`
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        self.graphs, label_dict = load_graphs(graph_path)
        self.labels = label_dict['labels']
        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
        self.num_classes = load_info(info_path)['num_classes']

    def has_cache(self):
        # check whether there are processed data in `self.save_path`
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
        return os.path.exists(graph_path) and os.path.exists(info_path)
    
    @property
    def num_labels(self):
        """Number of labels for each graph, i.e. number of prediction tasks."""
        return 2
    

def load_graph_dataset(dataset="",prefix="",split=0.75,max_events=1e5,batch_size=1024,drop_last=False,shuffle=True,num_workers=0,pin_memory=True, verbose=True):

    # Load training data
    train_dataset = GraphDataset(prefix+dataset) # Make sure this is copied into ~/.dgl folder
    train_dataset.load()
    num_labels = train_dataset.num_labels
    node_feature_dim = train_dataset.graphs[0].ndata["data"].shape[-1]
    index = int(min(len(train_dataset),max_events)*split)
    train_dataset = Subset(train_dataset,range(index))

    # Create training dataloader
    train_loader = GraphDataLoader(
        train_dataset,
        batch_size=batch_size,
        drop_last=drop_last,
        shuffle=shuffle,
        pin_memory=pin_memory,
        num_workers=num_workers)

    # Load validation data
    val_dataset = GraphDataset(prefix+dataset) # Make sure this is copied into ~/.dgl folder
    val_dataset.load()
    val_dataset = Subset(val_dataset,range(index,len(val_dataset)))

    # Create testing dataloader
    val_loader = GraphDataLoader(
        val_dataset,
        batch_size=batch_size,
        drop_last=drop_last,
        shuffle=False,
        pin_memory=pin_memory,
        num_workers=num_workers)

    return train_loader, val_loader, num_labels, node_feature_dim   

In [4]:
# Training routine #TODO: Put this in train_step and val_step functions.

device = 'cpu'
dataset = 'gangelmc_100k_2021-07-28_noEtaOldChi2_addKin_train'
dataset = 'test_dataset' #DEBUGGING

batch_size=1024

# Initialize data loaders
train_loader, val_loader, num_classes, node_feature_dim = load_graph_dataset(
                                                                            dataset=dataset,
                                                                            prefix="",
                                                                            split=0.75,
                                                                            max_events=1e5,
                                                                            batch_size=1024,
                                                                            drop_last=False,
                                                                            shuffle=True,
                                                                            num_workers=0,
                                                                            pin_memory=True,
                                                                            verbose=True
                                                                           )
# Get domain data loader
dataset2 = "gangelmc_100k_2021-07-28_noEtaOldChi2_addKin_test"
domain_loader, _, num_classes2, node_feature_dim2 = load_graph_dataset(
                                                                            dataset=dataset2,
                                                                            prefix="",
                                                                            split=1.00,
                                                                            max_events=1e5,
                                                                            batch_size=1024,
                                                                            drop_last=False,
                                                                            shuffle=True,
                                                                            num_workers=0,
                                                                            pin_memory=True,
                                                                            verbose=True
                                                                           )
print("num_classes = ",num_classes)
print("node_feature_dim = ",node_feature_dim)
print("num_classes2 = ",num_classes2)
print("node_feature_dim2 = ",node_feature_dim2)


# Initialize GIN parameters
num_layers = 3
num_mlp_layers = 3
input_dim = node_feature_dim
hidden_dim = 64
output_dim = 64
final_dropout = 0.8
learn_eps = False
graph_pooling_type = "max"
neighbor_pooling_type = "max"

num_domains = 1


# Set models
F = GIN(num_layers, num_mlp_layers, input_dim, hidden_dim,
                 output_dim, final_dropout, learn_eps, graph_pooling_type,
                 neighbor_pooling_type).to(device)
C = Classifier(input_size=output_dim,num_classes=num_classes).to(device)
D = Discriminator(input_size=output_dim,num_classes=num_domains).to(device)

# Initialize losses
bce = nn.BCELoss()
xe  = nn.CrossEntropyLoss()

# Initialize optimizers
F_opt = torch.optim.Adam(F.parameters())
C_opt = torch.optim.Adam(C.parameters())
D_opt = torch.optim.Adam(D.parameters())

# Set training parameters
max_epoch = 50
step = 1
n_critic = 1 # for training more k steps about Discriminator
n_batches = len(train_loader)//batch_size
# lamda = 0.01

# Set classification and domain labels
D_src = torch.ones(batch_size, 1).to(device) # Discriminator Label to real
D_tgt = torch.zeros(batch_size, 1).to(device) # Discriminator Label to fake
D_labels = torch.cat([D_src, D_tgt], dim=0)

print("DONE")


num_classes =  2
node_feature_dim =  6
num_classes2 =  2
node_feature_dim2 =  8
DONE


In [33]:
# Get domain data loader
# dataset2 = "gangelmc_100k_2021-07-28_noEtaOldChi2_addKin_test"
# domain_loader, _, num_classes2, node_feature_dim2 = load_graph_dataset(
#                                                                             dataset=dataset2,
#                                                                             prefix="",
#                                                                             split=1.00,
#                                                                             max_events=1e5,
#                                                                             batch_size=1024,
#                                                                             drop_last=False,
#                                                                             shuffle=True,
#                                                                             num_workers=0,
#                                                                             pin_memory=True,
#                                                                             verbose=True
#                                                                            )
print("num_classes = ",num_classes)
print("node_feature_dim = ",node_feature_dim)
print("num_classes2 = ",num_classes2)
print("node_feature_dim2 = ",node_feature_dim2)


num_classes =  2
node_feature_dim =  8
num_classes2 =  2
node_feature_dim2 =  8


In [40]:
step=1
if n_batches<1: n_batches = 1
# Function to decrease lambda with epoch
def get_lambda(epoch, max_epoch):
    p = epoch / max_epoch
    return 2. / (1+np.exp(-10.*p)) - 1.

# Function to continuously sample target domain data
domain_set = iter(domain_loader)
def sample_domain(step, n_batches):
    global domain_set
    if step % n_batches == 0:
        domain_set = iter(domain_loader)
    return domain_set.next()

# Not really sure yet...
ll_c, ll_d = [], []
acc_lst = []

for epoch in range(1, max_epoch+1):
    for idx, (src_images, labels) in enumerate(train_loader):
        tgt_images, _ = sample_domain(step, n_batches)
        # Training Discriminator
        src, labels_, tgt = src_images.to(device), labels.to(device), tgt_images.to(device)
        
        # Important: since we have kinematics added the labels are actually just first entries.
        labels = labels_[:,0].clone().detach().long()
#         print(labels_[0,:10])
#         print(labels_[1,:10])
        
        x = dgl.unbatch(src)
#       #print(type(x))#DEBUGGING
        for el in dgl.unbatch(tgt):
            x.append(el)#OLD: torch.cat([src, tgt], dim=0)
        x = dgl.batch(x)
        h = F(x)
        y = D(h.detach())
        
        Ld = bce(y, D_labels)
        D.zero_grad()
        Ld.backward()
        D_opt.step()
        
        
        c = C(h[:batch_size])
        y = D(h)
#         print(np.shape(labels))#DEBUGGING
        Lc = xe(c, labels)
        Ld = bce(y, D_labels)
        lamda = 0.1*get_lambda(epoch, max_epoch)
        Ltot = Lc -lamda*Ld
        
        
        F.zero_grad()
        C.zero_grad()
        D.zero_grad()
        
        Ltot.backward()
        
        C_opt.step()
        F_opt.step()
        
        if step %  == 0
            dt = datetime.datetime.now().strftime('%H:%M:%S')
            print('Epoch: {}/{}, Step: {}, D Loss: {:.4f}, C Loss: {:.4f}, lambda: {:.4f} ---- {}'.format(epoch, max_epoch, step, Ld.item(), Lc.item(), lamda, dt))
            ll_c.append(Lc)
            ll_d.append(Ld)
        
        if step % 5 == 0:
            F.eval()
            C.eval()
            with torch.no_grad():
                corrects = torch.zeros(1).to(device)
                for idx, (src, labels) in enumerate(val_loader):
                    src, labels = src.to(device), labels.to(device)
                    c = C(F(src))
                    _, preds = torch.max(c, 1)
                    corrects += (preds == labels).sum()
                acc = corrects.item() / len(val_loader.dataset)
                print('***** Eval Result: {:.4f}, Step: {}'.format(acc, step))
                
                corrects = torch.zeros(1).to(device)
                for idx, (tgt, labels) in enumerate(vdomain_loader):
                    tgt, labels = tgt.to(device), labels.to(device)
                    c = C(F(tgt))
                    _, preds = torch.max(c, 1)
                    corrects += (preds == labels).sum()
                acc = corrects.item() / len(domain_loader.dataset)
                print('***** Test Result: {:.4f}, Step: {}'.format(acc, step))
                acc_lst.append(acc)
                
            F.train()
            C.train()
        step += 1




ValueError: Using a target size (torch.Size([2048, 1])) that is different to the input size (torch.Size([2045, 1])) is deprecated. Please ensure they have the same size.

In [None]:
print(step)

In [None]:
n_batches

In [None]:
print(len(train_loader))
print(batch_size)