In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, NamedTuple, Optional
import scipy.sparse as sp
# from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score, precision_score, recall_score
import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningDataModule
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from tqdm import tqdm

import sys
sys.path.append('../graph_sage')
import parser
import dataset

import datetime
from datetime import timedelta

from custom_parser import get_parser

import numpy as np 
import pandas as pd 
import torch
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder

import torch_geometric
from torch_geometric.utils import from_networkx, to_undirected
from torch_geometric.data import Data, DataLoader, Dataset, NeighborSampler
from torch_geometric.nn import SAGEConv, MessagePassing

from torch_cluster import random_walk

from torch import Tensor
# from torch_geometric.data import InMemoryDataset

from tqdm import tqdm, tqdm_notebook, trange
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from sklearn.linear_model import LogisticRegression

from collections import defaultdict
import random
from xgboost import XGBClassifier
%config Completer.use_jedi = False


In [16]:
# Load the raw transaction table and build the experiment configuration.
data = dataset.Tdata(path='../../tdata.csv')
# NOTE(review): this rebinds `parser`, shadowing the module brought in by
# `import parser` in the imports cell.
parser = get_parser()
# Flags are supplied as an explicit list instead of being read from the
# command line, so the configuration is fixed inside the notebook.
args = parser.parse_args(args=
                         ["--data","real-t", 
                          "--sampling","xgb",
                          "--mode","scratch",
                          "--train_from","20170101",
                          "--test_from","20190101",
                          "--test_length","365",
                          "--valid_length","90",
                          "--initial_inspection_rate", "5",
                          "--final_inspection_rate", "10",
                         ])

In [17]:
# Unpack the parsed arguments into notebook-level names used by later cells.
seed = args.seed
epochs = args.epoch
dim = args.dim
lr = args.lr
weight_decay = args.l2
initial_inspection_rate = args.initial_inspection_rate
inspection_rate_option = args.inspection_plan
mode = args.mode
train_begin = args.train_from 
test_begin = args.test_from
test_length = args.test_length
valid_length = args.valid_length
chosen_data = args.data
numWeeks = args.numweeks
semi_supervised = args.semi_supervised
save = args.save
gpu_id = args.device

# Seed NumPy and PyTorch (CPU and, when present, CUDA) for reproducibility.
# NOTE(review): Python's own `random` module is not seeded here.
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Initial dataset split
# The date strings are sliced as YYYYMMDD (chars 0-3, 4-5, 6-7).
train_start_day = datetime.date(int(train_begin[:4]), int(train_begin[4:6]), int(train_begin[6:8]))
test_start_day = datetime.date(int(test_begin[:4]), int(test_begin[4:6]), int(test_begin[6:8]))
# `test_length` / `valid_length` are rebound from day counts to timedeltas.
test_length = timedelta(days=test_length)    
test_end_day = test_start_day + test_length
valid_length = timedelta(days=valid_length)
# The validation window is the `valid_length` days immediately before the test start.
valid_start_day = test_start_day - valid_length

# data
data.split(train_start_day, valid_start_day, test_start_day, test_end_day, valid_length, test_length, args)
data.featureEngineering()

Data size:
Train labeled: (77391, 41), Train unlabeled: (1470434, 41), Valid labeled: (134457, 41), Valid unlabeled: (0, 13), Test: (703090, 41)
Checking label distribution
Training: 0.09757342825942052
Validation: 0.09589052260946108
Testing: 0.10476480792437651


In [18]:
sys.path.append('../graph_sage')
from utils import *
from pygData_util import *

In [19]:
# Columns passed to GraphData as `categories`; `use_xgb=True` makes GraphData
# train an XGBoost model (see the "Training XGBoost model..." output below).
categories=["importer.id","HS6"]
gdata = GraphData(data,use_xgb=True, categories=categories)

Training XGBoost model...


In [20]:
# XGBoost baseline: pick a threshold on the labeled validation split, then
# report metrics on validation and (after the separator line) on test.
# NOTE(review): `best_auc` is not used again in this cell.
best_thresh, best_auc = find_best_threshold(gdata.xgb,data.dfvalidx_lab, data.valid_cls_label)
# NOTE(review): despite its name, `xgb_test_pred` holds *validation* predictions here.
xgb_test_pred = gdata.xgb.predict_proba(data.dfvalidx_lab)[:,-1]
overall_f1,auc,pr, re, f, rev = metrics(xgb_test_pred, data.valid_cls_label,data.valid_reg_label,best_thresh)
print("-"*50)
# Same evaluation on the test split, reusing the validation-selected threshold.
xgb_test_pred = gdata.xgb.predict_proba(data.dftestx)[:,-1]
overall_f1,auc,pr, re, f, rev = metrics(xgb_test_pred, data.test_cls_label,data.test_reg_label,best_thresh)

Checking top 1% suspicious transactions: 1345
Precision: 0.6349, Recall: 0.0726, Revenue: 0.0725
Checking top 2% suspicious transactions: 2690
Precision: 0.5520, Recall: 0.1262, Revenue: 0.1187
Checking top 5% suspicious transactions: 6723
Precision: 0.4261, Recall: 0.2435, Revenue: 0.2367
Checking top 10% suspicious transactions: 13446
Precision: 0.3139, Recall: 0.3588, Revenue: 0.3490
--------------------------------------------------
Checking top 1% suspicious transactions: 7031
Precision: 0.7042, Recall: 0.0743, Revenue: 0.1217
Checking top 2% suspicious transactions: 14062
Precision: 0.6039, Recall: 0.1274, Revenue: 0.2024
Checking top 5% suspicious transactions: 35155
Precision: 0.4473, Recall: 0.2359, Revenue: 0.3511
Checking top 10% suspicious transactions: 70309
Precision: 0.3258, Recall: 0.3436, Revenue: 0.4705


In [21]:
# Build the graph for the labeled training split and attach `node_idx`,
# the indices returned by `gdata.get_AttNode(stage)`.
# NOTE(review): this cell is repeated almost verbatim for each stage below.
stage = "train_lab"
trainLab_data = gdata.get_data(stage)
train_nodeidx = torch.tensor(gdata.get_AttNode(stage))
trainLab_data.node_idx = train_nodeidx
trainLab_data

Data(edge_attr=[154782], edge_index=[2, 309564], edge_label=[309564], node_idx=[77391], rev=[92606], x=[92606, 100], y=[92606])

In [22]:
# Same procedure for the unlabeled training split.
stage = "train_unlab"
unlab_data = gdata.get_data(stage)
unlab_nodeidx = torch.tensor(gdata.get_AttNode(stage))
unlab_data.node_idx = unlab_nodeidx
unlab_data

Data(edge_attr=[2940868], edge_index=[2, 5881736], edge_label=[5881736], node_idx=[1470434], rev=[1515745], x=[1515745, 100], y=[1515745])

In [23]:
# Same procedure for the validation split.
stage = "valid"
valid_data = gdata.get_data(stage)
valid_nodeidx = torch.tensor(gdata.get_AttNode(stage))
valid_data.node_idx = valid_nodeidx
valid_data

Data(edge_attr=[268914], edge_index=[2, 537828], edge_label=[537828], node_idx=[134457], rev=[137035], x=[137035, 100], y=[137035])

In [24]:
# Same procedure for the test split.
stage = "test"
test_data = gdata.get_data(stage)
test_nodeidx = torch.tensor(gdata.get_AttNode(stage))
test_data.node_idx = test_nodeidx
test_data

Data(edge_attr=[1406180], edge_index=[2, 2812360], edge_label=[2812360], node_idx=[703090], rev=[718094], x=[718094, 100], y=[718094])

In [25]:
# Persist the four per-stage graphs to the working directory so later cells
# can reload them with torch.load instead of rebuilding them.
torch.save(trainLab_data, 'train_lab_data.pt')
torch.save(unlab_data, 'train_unlab_data.pt')
torch.save(valid_data, 'valid_data.pt')
torch.save(test_data, 'test_data.pt')

In [27]:
# trainLab_data, gdata.leaf_dim
# gdata.leaf_dim # 1562

In [310]:
# torch.save(trainLab_data, 'train_data.pt')
from data import StackData


In [8]:
# Reload the cached per-stage graphs written by the torch.save cell.
# NOTE(review): the bare names between the loads are not the last expression
# of the cell, so they display nothing.
train_lab_data = torch.load('train_lab_data.pt')
train_lab_data

train_unlab_data = torch.load('train_unlab_data.pt')
train_unlab_data

valid_data = torch.load('valid_data.pt')
valid_data

test_data = torch.load('test_data.pt')
# Displayed value: the labels selected by `node_idx`.
test_data.y[test_data.node_idx]

tensor([0., 0., 0.,  ..., 0., 1., 0.])

In [13]:
test_data

Data(edge_attr=[1716360], edge_index=[2, 3432720], edge_label=[3432720], node_idx=[858180], rev=[901165], x=[901165, 100], y=[901165])

# Data

In [16]:
class Batch(NamedTuple):
    '''
    Mini-batch container handed to the pytorch-lightning modules.

    Fields are the node features ``x``, targets ``y``, revenue ``rev`` and
    ``adjs_t``, an iterable of ``(adj, e_id, size)`` triples.
    '''
    x: Tensor
    y: Tensor
    rev: Tensor
    adjs_t: NamedTuple

    def to(self, *args, **kwargs):
        '''Return a new Batch with ``.to(*args, **kwargs)`` applied to every tensor.'''

        def move(obj):
            return obj.to(*args, **kwargs)

        moved_x = move(self.x)
        moved_y = move(self.y)
        moved_rev = move(self.rev)

        # `size` is passed through untouched; only adj and e_id are moved.
        moved_adjs = []
        for adj_t, e_id, size in self.adjs_t:
            moved_adjs.append((move(adj_t), move(e_id), size))

        return Batch(x=moved_x, y=moved_y, rev=moved_rev, adjs_t=moved_adjs)
    
class UnsupNeighborSampler(NeighborSampler):
    '''
    Neighbor sampler for the unsupervised objective: every seed node is
    extended with one positive and one negative example before sampling.
    '''

    def sample(self, batch):
        anchors = torch.tensor(batch)
        src, dst, _ = self.adj_t.coo()

        # Positive example: the node reached by a one-step random walk
        # from each anchor (column 1 of the walk matrix).
        walks = random_walk(src, dst, anchors, walk_length=1,
                            coalesced=False)
        positives = walks[:, 1]

        # Negative example: a node id drawn uniformly at random.
        num_nodes = self.adj_t.size(1)
        negatives = torch.randint(0, num_nodes, (anchors.numel(), ),
                                  dtype=torch.long)

        # Layout is [anchors | positives | negatives]; consumers split in thirds.
        extended = torch.cat([anchors, positives, negatives], dim=0)
        return super().sample(extended)
    
class CustomData(LightningDataModule):
    '''
    Data module for unsupervised GraphSAGE training.

    Training batches come from ``UnsupNeighborSampler`` (anchor / positive /
    negative triples); validation batches come from a plain ``NeighborSampler``.

    Args:
        train_data: graph object exposing ``x``, ``y``, ``rev``, ``edge_index``
            and ``node_idx`` for training.
        valid_data: same structure, used for validation.
        sizes: per-layer neighbor sample sizes forwarded to the samplers.
        batch_size: number of seed nodes per batch.
        num_workers: worker processes for the training sampler (default 16,
            the value that was previously hard-coded).
    '''

    def __init__(self, train_data, valid_data, sizes, batch_size, num_workers=16):
        super(CustomData, self).__init__()

        self.train_data = train_data
        self.valid_data = valid_data

        self.sizes = sizes
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return UnsupNeighborSampler(self.train_data.edge_index,
                                    node_idx=self.train_data.node_idx,
                                    sizes=self.sizes,
                                    return_e_id=True,
                                    batch_size=self.batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    transform=self.convert_train_batch,
                                    num_workers=self.num_workers)

    def val_dataloader(self):
        # drop_last=False: dropping the tail batch would silently exclude
        # validation nodes from the reported metrics.
        return NeighborSampler(self.valid_data.edge_index,
                               node_idx=self.valid_data.node_idx,
                               sizes=self.sizes,
                               return_e_id=True,
                               batch_size=self.batch_size,
                               shuffle=False,
                               drop_last=False,
                               transform=self.convert_valid_batch)

    @staticmethod
    def _to_batch(graph, batch_size, n_id, adjs):
        '''Gather features for all sampled nodes and targets for the seed nodes.'''
        seeds = n_id[:batch_size]  # seed nodes are the first `batch_size` ids
        return Batch(
            x=graph.x[n_id],
            y=graph.y[seeds],
            rev=graph.rev[seeds],
            adjs_t=adjs,
        )

    def convert_train_batch(self, batch_size, n_id, adjs):
        '''Sampler transform: build a ``Batch`` from the training graph.'''
        return self._to_batch(self.train_data, batch_size, n_id, adjs)

    def convert_valid_batch(self, batch_size, n_id, adjs):
        '''Sampler transform: build a ``Batch`` from the validation graph.'''
        return self._to_batch(self.valid_data, batch_size, n_id, adjs)

# Unsupervised Graph Sage Model

In [11]:
class SAGE(nn.Module):
    '''
    GraphSAGE encoder over embedded index features.

    Each node's integer indices are embedded, summed over dim 1,
    layer-normalised, and passed through ``num_layers`` SAGEConv layers.
    '''

    def __init__(self, in_channels, hidden_channels, num_layers, leaf_len):
        super().__init__()

        self.num_layers = num_layers
        self.convs = nn.ModuleList()

        # Index 0 is the padding entry of the embedding table.
        self.emb = nn.Embedding(leaf_len, in_channels, padding_idx=0)
        self.ln = nn.LayerNorm(in_channels)

        # First layer consumes the embedding width, later ones the hidden width.
        for depth in range(num_layers):
            width = hidden_channels if depth else in_channels
            self.convs.append(SAGEConv(width, hidden_channels))

    def forward(self, x, adjs):
        '''
        Encode a sampled sub-graph.

        `x` is looked up in the embedding table and `adjs` is iterated as
        ``(edge_index, _, size)`` triples, one per layer.
        '''
        h = self.emb(x).sum(dim=1)  # summation over the leaves
        h = F.relu(self.ln(h))

        final = self.num_layers - 1
        for depth, (edge_index, _, size) in enumerate(adjs):
            h_target = h[:size[1]]  # Target nodes are always placed first.
            h = self.convs[depth]((h, h_target), edge_index)
            if depth == final:
                continue
            # Non-linearity and dropout on every layer except the last.
            h = F.dropout(h.relu(), p=0.5, training=self.training)

        return h

#     def full_forward(self, x, edge_index):
        
#         x = self.emb(x)

#         x = torch.sum(x,dim=1) # summation over the leaves
#         x = F.relu(self.bn(x))
        
#         for i, conv in enumerate(self.convs):
#             x = conv(x, edge_index)
#             if i != self.num_layers - 1:
#                 x = x.relu()
#                 x = F.dropout(x, p=0.5, training=self.training)
#         return x

In [33]:
from models import SAGE

class SageModel(pl.LightningModule):
    '''
    Lightning wrapper for unsupervised GraphSAGE training.

    Training minimises the skip-gram style loss over (anchor, positive,
    negative) triples. After each validation epoch a logistic regression is
    fitted on the anchor embeddings collected during training and its macro-F1
    on the validation embeddings is logged as ``f1``.

    Args:
        in_channels, hidden_channels, num_layers, leaf_len: forwarded to
            ``SAGE``. The defaults reproduce the previously hard-coded
            configuration, so ``SageModel()`` and ``load_from_checkpoint``
            keep working.
    '''

    def __init__(self, in_channels=128, hidden_channels=128, num_layers=2, leaf_len=1562):
        super(SageModel, self).__init__()

        self.sage = SAGE(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            num_layers=num_layers,
            leaf_len=leaf_len)

        # Anchor embeddings / labels accumulated over the current training epoch.
        self.X = None
        self.y = None

    def forward(self, x, adjs):
        return self.sage(x, adjs)

    def training_step(self, batch, batch_idx):
        x, y, rev, adjs = batch

        out = self.forward(x, adjs)

        # The sampler lays the batch out as [anchors | positives | negatives].
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        out_y = y[:out.size(0)]

        pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean()
        neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean()
        loss = -pos_loss - neg_loss

        # Keep anchor embeddings on CPU for the downstream probe classifier.
        emb = out.detach().cpu()
        labels = out_y.flatten().detach().cpu()
        if self.X is None:
            self.X, self.y = emb, labels
        else:
            self.X = torch.cat([self.X, emb])
            self.y = torch.cat([self.y, labels])

        return loss

    def validation_step(self, batch, batch_idx):
        x, y, rev, adjs = batch

        features = self.forward(x, adjs)

        return {'features': features, 'y': y.flatten()}

    def validation_epoch_end(self, outputs):
        # No training embeddings yet (e.g. the sanity-check validation pass):
        # there is nothing to fit the probe on, so skip instead of crashing.
        if self.X is None or not outputs:
            return

        valid_X = torch.cat([x["features"] for x in outputs], dim=0).detach().cpu()
        y_true = torch.cat([x["y"] for x in outputs], dim=0).detach().cpu()

        # max_iter must be an int; recent scikit-learn rejects the float 1e3.
        clf = LogisticRegression(max_iter=1000)
        clf.fit(self.X.numpy(), self.y.numpy())

        y_pred = clf.predict(valid_X.numpy())

        f1 = f1_score(y_true.numpy(), y_pred, average='macro')

        self.log('f1', f1, prog_bar=True)

        # Start the next epoch with an empty feature buffer.
        self.X = None
        self.y = None

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None, on_tpu=False, using_native_amp=True, using_lbfgs=False):
        optimizer.step(closure=second_order_closure)

        # Fractional-epoch step for the warm-restart schedule. The batch count
        # comes from the trainer rather than from rebuilding a sampler out of
        # the notebook-global `data_loader` on every optimisation step.
        steps_per_epoch = self.trainer.num_training_batches
        self.scheduler.step(current_epoch + batch_nb / steps_per_epoch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-6)

        # NOTE(review): T_0 still reads the notebook-level `epochs` variable.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs, T_mult=1, eta_min=0, last_epoch=-1, verbose=True)

        return optimizer

In [39]:
# Restore the checkpointed SageModel and export only the encoder weights
# (`m.sage`) so they can be loaded into a bare SAGE module later.
m = SageModel.load_from_checkpoint(checkpoint_path='unsup_checkpoints/epoch=01-f1=0.6181.ckpt')
torch.save(m.sage.state_dict(), 'large_neighb_unsup_sage.pt')

# Unsupervised training of Graph Sage

In [92]:
# Hyper-parameters for the unsupervised run.
# NOTE(review): leaf_len / hidden_channels here differ from the 1562 / 128
# used in the SageModel cell above — confirm which configuration is intended.
leaf_len = 1549

batch_size = 256
sizes = [-1, 200]
# `epochs` is also read as T_0 by the models' configure_optimizers.
epochs=100

# NOTE(review): `valid_lab_data` is not assigned in any cell visible here
# (the load cell binds `valid_data`) — verify it exists on a fresh kernel.
data_loader = CustomData(train_lab_data, valid_lab_data, sizes, batch_size)

# NOTE(review): check these keyword arguments against the SageModel
# definition currently in scope.
model = SageModel( 
    in_channels=128, 
    hidden_channels=64, 
    num_layers=2, 
    leaf_len=leaf_len)

# The sanity-check validation pass is disabled (num_sanity_val_steps=0).
trainer = Trainer(
    gpus=[0],
    progress_bar_refresh_rate=0,
    num_sanity_val_steps=0,
    max_epochs=epochs)
trainer.fit(model, data_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using environment variable NODE_RANK for node rank (0).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name | Type | Params
------------------------------
0 | sage | SAGE | 223 K 
------------------------------
223 K     Trainable params
0         Non-trainable params
223 K     Total params


1

# Supervised data 

In [40]:
class SupData(LightningDataModule):
    '''
    Data module for supervised training on sampled neighborhoods.

    Both loaders use a plain ``NeighborSampler`` and emit ``Batch`` tuples.

    Args:
        train_data: graph object exposing ``x``, ``y``, ``rev``, ``edge_index``
            and ``node_idx`` for training.
        valid_data: same structure, used for validation.
        sizes: per-layer neighbor sample sizes forwarded to the samplers.
        batch_size: number of seed nodes per batch.
        num_workers: worker processes for the training sampler (default 16,
            the value that was previously hard-coded).
    '''

    def __init__(self, train_data, valid_data, sizes, batch_size, num_workers=16):
        super(SupData, self).__init__()

        self.train_data = train_data
        self.valid_data = valid_data

        self.sizes = sizes
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return NeighborSampler(self.train_data.edge_index,
                               node_idx=self.train_data.node_idx,
                               sizes=self.sizes,
                               return_e_id=True,
                               batch_size=self.batch_size,
                               shuffle=True,
                               drop_last=True,
                               transform=self.convert_train_batch,
                               num_workers=self.num_workers)

    def val_dataloader(self):
        # drop_last=False: dropping the tail batch would silently exclude
        # validation nodes from the reported metrics.
        return NeighborSampler(self.valid_data.edge_index,
                               node_idx=self.valid_data.node_idx,
                               sizes=self.sizes,
                               return_e_id=True,
                               batch_size=self.batch_size,
                               shuffle=False,
                               drop_last=False,
                               transform=self.convert_valid_batch)

    @staticmethod
    def _to_batch(graph, batch_size, n_id, adjs):
        '''Gather features for all sampled nodes and targets for the seed nodes.'''
        seeds = n_id[:batch_size]  # seed nodes are the first `batch_size` ids
        return Batch(
            x=graph.x[n_id],
            y=graph.y[seeds],
            rev=graph.rev[seeds],
            adjs_t=adjs,
        )

    def convert_train_batch(self, batch_size, n_id, adjs):
        '''Sampler transform: build a ``Batch`` from the training graph.'''
        return self._to_batch(self.train_data, batch_size, n_id, adjs)

    def convert_valid_batch(self, batch_size, n_id, adjs):
        '''Sampler transform: build a ``Batch`` from the validation graph.'''
        return self._to_batch(self.valid_data, batch_size, n_id, adjs)

# LaGNN + GraphSAGE

In [17]:
class LabelPredictor(nn.Module):
    '''
    Single-layer pairwise scorer.

    A pair of embeddings (a, b) is described by ``[|a - b|, a + b, a * b]``
    and mapped through one linear layer and a sigmoid.
    '''

    def __init__(self, in_channels, **kwargs):
        super().__init__()

        # Three pair features of width `in_channels` each -> one logit.
        self.linear = nn.Linear(in_channels * 3, 1)

    def forward(self, x):
        '''Score each row of `x` against the row before it (rows rolled by one).'''
        return self.edge_forward(x, x.roll(1, 0))

    def edge_forward(self, emb_a, emb_b):
        '''Score explicitly supplied embedding pairs.'''
        pair_features = torch.cat(
            [(emb_a - emb_b).abs(), emb_a + emb_b, emb_a * emb_b],
            dim=-1,
        )
        return self.linear(pair_features).sigmoid()
     

In [43]:

    
class LabelPredictor(nn.Module):
    '''
    Two-layer pairwise scorer.

    A pair of embeddings (a, b) is described by ``[|a - b|, a + b, a * b]``
    and passed through ``lin1 -> LayerNorm -> ReLU -> (dropout) -> lin2 ->
    sigmoid``.

    Args:
        in_channels: width of each input embedding.
        hidden_channels: width of the hidden layer.
    '''

    def __init__(self, in_channels, hidden_channels):
        super(LabelPredictor, self).__init__()

        self.lin1 = nn.Linear(in_channels * 3, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, 1)

        self.ln = nn.LayerNorm(hidden_channels)

    @staticmethod
    def _pair_features(emb_a, emb_b):
        '''Symmetric description of an embedding pair.'''
        return torch.cat(
            [torch.abs(emb_a - emb_b), emb_a + emb_b, emb_a * emb_b],
            dim=-1,
        )

    def _score(self, pair_features, use_dropout):
        '''Shared scoring head so training and inference use the same network.'''
        hidden = F.relu(self.ln(self.lin1(pair_features)))
        if use_dropout:
            hidden = F.dropout(hidden, p=0.5, training=self.training)
        return torch.sigmoid(self.lin2(hidden))

    def forward(self, x):
        '''
        Score every row of `x` against a randomly chosen partner row.

        Returns:
            (scores, index): `scores` has one row per input row; `index` is
            the permutation used to pick partners, so callers can build the
            matching targets.
        '''
        batch_size = x.size(0)
        # Create the permutation on the input's device to avoid mixing devices.
        index = torch.randperm(batch_size, device=x.device)

        # Embeddings are detached: only the predictor is trained here.
        emb_a = x.detach()
        emb_b = x[index].detach()

        scores = self._score(self._pair_features(emb_a, emb_b), use_dropout=True)

        return scores, index

    def edge_forward(self, emb_a, emb_b):
        '''
        Score explicitly supplied embedding pairs.

        Uses the same LayerNorm as `forward` (previously skipped here, which
        made inference disagree with the trained network). Dropout is never
        applied, so scores are deterministic.
        '''
        return self._score(self._pair_features(emb_a, emb_b), use_dropout=False)

class LabelModel(pl.LightningModule):
    '''
    Trains a ``LabelPredictor`` on top of a frozen, pre-trained SAGE encoder.

    For each batch the predictor scores random pairs of seed nodes; the target
    is 1 when both nodes of a pair carry the same label.
    '''

    def __init__(self):
        super(LabelModel, self).__init__()

        self.unsup_sage = SAGE(
            in_channels=128,
            hidden_channels=128,
            num_layers=2,
            leaf_len=1562)
        # map_location keeps the load working on machines without the device
        # the weights were saved from.
        self.unsup_sage.load_state_dict(
            torch.load('large_neighb_unsup_sage.pt', map_location='cpu'))

        # The encoder is frozen; only the label predictor receives gradients.
        for p in self.unsup_sage.parameters():
            p.requires_grad = False

        self.label_predictor = LabelPredictor(in_channels=128, hidden_channels=128)

    def forward(self, x, adjs):
        x = self.unsup_sage(x, adjs)
        return self.label_predictor(x)

    def _pair_targets(self, y, index):
        '''1.0 where a node and its sampled partner share a label, else 0.0.'''
        return (y == y[index]).float()

    def training_step(self, batch, batch_idx):
        x, y, rev, adjs = batch

        y_pred, index = self.forward(x, adjs)
        y_true = self._pair_targets(y, index)

        loss = F.binary_cross_entropy(y_pred.flatten(), y_true.flatten())

        return loss

    def validation_step(self, batch, batch_idx):
        x, y, rev, adjs = batch

        threshold = 0.5

        y_pred, index = self.forward(x, adjs)
        y_true = self._pair_targets(y, index)

        val_loss = F.binary_cross_entropy(y_pred.flatten(), y_true.flatten())

        # Hard decisions without mutating the probability tensor in place.
        y_hat = (y_pred >= threshold).float()

        return {'y_pred': y_hat.flatten(), 'y_true': y_true.flatten(), 'val_loss': val_loss.item()}

    def validation_epoch_end(self, outputs):
        if not outputs:
            return

        val_loss = np.mean([x["val_loss"] for x in outputs])
        y_pred = torch.cat([x["y_pred"] for x in outputs], dim=0).detach().cpu()
        y_true = torch.cat([x["y_true"] for x in outputs], dim=0).detach().cpu()

        f1 = f1_score(y_true.numpy(), y_pred.numpy(), average='macro')
        # Accuracy computed directly: `accuracy_score` is not among the
        # names imported from sklearn.metrics at the top of the notebook.
        acc = (y_pred == y_true).float().mean().item()

        self.log('f1', f1, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        self.log('val_loss', val_loss, prog_bar=True)

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None, on_tpu=False, using_native_amp=True, using_lbfgs=False):
        optimizer.step(closure=second_order_closure)

        # Fractional-epoch step for the warm-restart schedule. The batch count
        # comes from the trainer rather than from rebuilding a sampler out of
        # the notebook-global `data_loader` on every optimisation step.
        steps_per_epoch = self.trainer.num_training_batches
        self.scheduler.step(current_epoch + batch_nb / steps_per_epoch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-6)

        # NOTE(review): T_0 still reads the notebook-level `epochs` variable.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs, T_mult=1, eta_min=0, last_epoch=-1, verbose=True)

        return optimizer


# Restore the checkpointed LabelModel and export the label-predictor weights.
m = LabelModel.load_from_checkpoint('label_checkpoints/epoch=03-f1=0.6386.ckpt')

# torch.save(m.unsup_sage.state_dict(), 'la_sage.pt')
torch.save(m.label_predictor.state_dict(), 'la_predictor.pt')

In [216]:
class LabelModel(pl.LightningModule):
    '''
    Jointly trains a SAGE encoder and a label predictor on "same label?" pairs.

    Each seed node is paired with the previous seed node in the batch (rows
    rolled by one); the target is 1 when both carry the same label.

    NOTE(review): this notebook defines ``LabelModel`` more than once; the
    definition executed last is the one in effect.

    Args:
        in_channels: embedding width fed to the first SAGE layer.
        hidden_channels: SAGE output width (also the predictor's input width).
        **kwargs: optional ``num_layers`` (default 2) and ``leaf_len``
            (defaults to the notebook-level ``leaf_len``).
    '''

    def __init__(self, in_channels, hidden_channels, **kwargs):
        super(LabelModel, self).__init__()

        # Only touch the notebook-level `leaf_len` when it was not passed in.
        sage_leaf_len = kwargs["leaf_len"] if "leaf_len" in kwargs else leaf_len

        self.unsup_sage = SAGE(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            num_layers=kwargs.get("num_layers", 2),
            leaf_len=sage_leaf_len)

        # Keyword form is accepted by both LabelPredictor variants in this notebook.
        self.label_predictor = LabelPredictor(in_channels=hidden_channels,
                                              hidden_channels=hidden_channels)

    def forward(self, x, adjs):
        '''Return one similarity score per seed node (vs. the previous seed node).'''
        emb = self.unsup_sage(x, adjs)
        # edge_forward has the same signature in both LabelPredictor variants.
        return self.label_predictor.edge_forward(emb, emb.roll(1, 0))

    @staticmethod
    def _pair_targets(y):
        '''1.0 where a node and its rolled partner share a label, else 0.0.'''
        return (y == y.roll(1, 0)).float().flatten()

    def training_step(self, batch, batch_idx):
        x, y, rev, adjs = batch

        y_pred = self.forward(x, adjs).flatten()
        y_true = self._pair_targets(y)  # BCE requires float targets

        loss = F.binary_cross_entropy(y_pred, y_true)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y, rev, adjs = batch

        threshold = 0.5

        scores = self.forward(x, adjs).flatten()
        y_pred = (scores >= threshold).float()
        # Compare against the pair targets, not the raw node labels.
        y_true = self._pair_targets(y)

        return {'y_pred': y_pred, 'y_true': y_true}

    def validation_epoch_end(self, outputs):
        if not outputs:
            return

        y_pred = torch.cat([x["y_pred"] for x in outputs], dim=0).detach().cpu()
        y_true = torch.cat([x["y_true"] for x in outputs], dim=0).detach().cpu()

        f1 = f1_score(y_true.numpy(), y_pred.numpy())

        self.log('f1', f1, prog_bar=True)

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None, on_tpu=False, using_native_amp=True, using_lbfgs=False):
        optimizer.step(closure=second_order_closure)

        # Fractional-epoch step for the warm-restart schedule. The batch count
        # comes from the trainer rather than from rebuilding a sampler out of
        # the notebook-global `data_loader` on every optimisation step.
        steps_per_epoch = self.trainer.num_training_batches
        self.scheduler.step(current_epoch + batch_nb / steps_per_epoch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=1, eta_min=0, last_epoch=-1, verbose=True)

        return optimizer


In [None]:
def StackData(train_data, unlab_data, valid_data, test_data=None):
    '''
    Stack the per-stage pyG datasets into a single ``Data`` object,
    because the valid/test data should include train/unlab edges.

    Args:
        train_data, unlab_data, valid_data: per-stage graphs exposing ``x``,
            ``y``, ``rev``, ``edge_index`` and (train/valid) ``node_idx``.
        test_data: test-stage graph. When omitted, the notebook-level
            ``test_data`` variable is used, which is what the three-argument
            form of this function always did implicitly.

    Returns:
        A ``Data`` object with concatenated ``x`` / ``y`` / ``rev``, the
        cumulative edge sets ``train_edge`` / ``valid_edge`` / ``test_edge``
        and the index tensors ``train_idx`` / ``valid_idx`` / ``test_idx``.

    Raises:
        NameError: if ``test_data`` is neither passed nor defined globally.

    NOTE(review): node features are concatenated but ``edge_index`` and
    ``node_idx`` are copied without any node-id offset. That is only correct
    if the stages already share one node numbering — confirm against
    ``GraphData.get_data``.
    '''
    if test_data is None:
        try:
            test_data = globals()["test_data"]
        except KeyError:
            raise NameError("name 'test_data' is not defined; pass it as the fourth argument")

    stages = (train_data, unlab_data, valid_data, test_data)
    stack = Data()

    # feature / target / revenue: stacked in train, unlab, valid, test order
    stack.x = torch.cat([stage.x for stage in stages], dim=0)
    stack.y = torch.cat([stage.y for stage in stages], dim=-1)
    stack.rev = torch.cat([stage.rev for stage in stages], dim=-1)

    # edge index: each later stage also sees every earlier stage's edges
    stack.train_edge = torch.cat((train_data.edge_index, unlab_data.edge_index), dim=1)
    stack.valid_edge = torch.cat((stack.train_edge, valid_data.edge_index), dim=1)
    stack.test_edge = torch.cat((stack.valid_edge, test_data.edge_index), dim=1)

    # transaction index
    stack.train_idx = train_data.node_idx
    stack.valid_idx = valid_data.node_idx
    stack.test_idx = test_data.node_idx

    return stack

In [56]:
class UnsupSageNeighborSampler(NeighborSampler):
    '''
    Neighbor sampler that additionally returns an embedding from `sage`
    for every node of the sampled sub-graph.

    The first three constructor arguments (`x`, `y`, `sage`) are stored on
    the sampler; everything else is forwarded to ``NeighborSampler``.
    '''

    def __init__(self, x, y, sage, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.x, self.y, self.sage = x, y, sage

    def sample(self, batch):
        # First pass: the sub-graph around the seed nodes.
        batch_size, n_id, org_adjs = super().sample(batch)

        # Second pass: neighborhoods of every node found in the first pass,
        # used as input to the encoder.
        _, unsup_n_id, unsup_adjs = super().sample(n_id)
        with torch.no_grad():
            unsup_emb = self.sage(self.x[unsup_n_id], unsup_adjs)

        seeds = n_id[:batch_size]
        features = self.x[n_id]
        frozen_emb = unsup_emb.detach()
        targets = self.y[seeds]

        return features, frozen_emb, targets, org_adjs

class EmbedData(LightningDataModule):
    '''
    Data module over a stacked graph that yields
    ``(x, x_unsup, y, adjs)`` batches via ``UnsupSageNeighborSampler``.

    Args:
        sage: encoder used by the sampler to produce `x_unsup`.
        data: stacked graph exposing ``x``, ``y``, ``train_edge``,
            ``valid_edge``, ``train_idx`` and ``valid_idx``.
        sizes: per-layer neighbor sample sizes (used for both loaders).
        batch_size: number of seed nodes per batch.
        num_workers: worker processes per loader (default 16, the value that
            was previously hard-coded).
    '''

    def __init__(self, sage, data, sizes, batch_size = 128, num_workers = 16):
        super(EmbedData, self).__init__()

        self.sage = sage
        self.data = data
        self.sizes = sizes
        self.valid_sizes = sizes
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_sampler(self, edge_index, node_idx, sizes, shuffle, drop_last):
        '''Build a sampler over `edge_index` seeded from `node_idx`.'''
        return UnsupSageNeighborSampler(
            self.data.x,
            self.data.y,
            self.sage,
            edge_index,
            node_idx=node_idx,
            sizes=sizes,
            return_e_id=True,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=self.num_workers)

    def train_dataloader(self):
        return self._make_sampler(self.data.train_edge, self.data.train_idx,
                                  self.sizes, shuffle=True, drop_last=True)

    def val_dataloader(self):
        # drop_last=False: dropping the tail batch would silently exclude
        # validation nodes from the reported metrics.
        return self._make_sampler(self.data.valid_edge, self.data.valid_idx,
                                  self.valid_sizes, shuffle=False, drop_last=False)


In [57]:
# Encoder handed to the embedding sampler.
# NOTE(review): this SAGE is freshly initialised — no state dict is loaded in
# this cell — so the "unsupervised" embeddings come from random weights unless
# weights are loaded elsewhere.
sage = SAGE( 
    in_channels=128, 
    hidden_channels=128, 
    num_layers=2, 
    leaf_len=leaf_len)

sizes = [64, 32]
batch_size = 1024

# NOTE(review): StackData is called with four positional arguments — check
# this against the StackData definition currently in scope.
stacked_data = StackData(train_lab_data,train_unlab_data,valid_data, test_data)
# Rebinds the notebook-level `data_loader` name to the embedding data module.
data_loader = EmbedData(sage, stacked_data, sizes, batch_size)

# ns = UnsupSageNeighborSampler(train_lab_data.x, train_lab_data.y, sage, train_lab_data.edge_index, node_idx=train_lab_data.node_idx,
#                                sizes=[-1,-1], return_e_id=True,
#                                batch_size=1,
#                                shuffle=True,
#                                drop_last=True,
# #                                transform=train_batch,
#                                num_workers=16)

# x, x_unsup, y, adjs = ns.sample([0])
# x.shape, x_unsup.shape


In [58]:
from torch_scatter import scatter 

class LAConv(SAGEConv):
    '''
    SAGEConv variant with label-aware neighbor weighting.

    Node inputs are the concatenation ``[features | unsup embedding]``, where
    the first ``in_channels`` columns are the learned features. A frozen label
    predictor scores each (target, neighbor) pair from the unsupervised
    embeddings; the scores, normalised per target node, weight the neighbor
    features before sum-aggregation.

    Args:
        label_predictor: module exposing ``edge_forward(emb_a, emb_b)``; its
            parameters are frozen here.
        in_channels: width of the learned-feature part of the input.
        out_channels: width of the learned-feature part of the output.
        **kwargs: forwarded to ``SAGEConv``.
    '''

    def __init__(self, label_predictor, in_channels, out_channels, **kwargs):
        super(LAConv, self).__init__(in_channels, out_channels, aggr='add', **kwargs)

        self.label_predictor = label_predictor

        # freeze label predictor
        for p in self.label_predictor.parameters():
            p.requires_grad = False

    def forward(self, x, edge_index):
        '''
        Returns ``[updated features | unsup embedding]`` for the target nodes.
        '''
        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: OptPairTensor)
        aggregated = self.propagate(edge_index, x=x)

        # Only the feature columns of the aggregate go through the linear
        # layer; the aggregated unsup columns are discarded.
        out = self.lin_l(aggregated[:, :self.in_channels])

        x_target = x[1]
        out += self.lin_r(x_target[:, :self.in_channels])

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        # Pass the target nodes' OWN unsupervised embeddings through. The
        # aggregated columns would be a sum over neighbors, which is not the
        # fixed per-node embedding the next layer's predictor expects.
        out_unsup = x_target[:, self.in_channels:]

        return torch.cat([out, out_unsup], dim=-1)

    def message(self, x_j, x_i, index):
        '''
        i - central node that aggregates
        x_j - one row per edge, taken from the neighbor (source) node
        index - indices of the center node for each row in x_j
        '''
        # split neighbor rows into learned features and unsup embeddings
        x_unsup_j = x_j[:, self.in_channels:]
        x_j = x_j[:, :self.in_channels]

        x_unsup_i = x_i[:, self.in_channels:]

        # predicting similarity between center node and each of its neighbors
        with torch.no_grad():
            A_j = self.label_predictor.edge_forward(x_unsup_i, x_unsup_j)

        # sum all weights belonging to the same center node, then broadcast
        # each sum back to the edges it came from
        A_sum = scatter(A_j, index, dim=0, reduce="sum")[index]
        # normalizing weights; the clamp guards against 0/0 if every score
        # for a node underflows to zero
        A_j = A_j / A_sum.clamp_min(1e-12)

        x_j = A_j * x_j
        x_j = torch.cat([x_j, x_unsup_j], dim=-1)

        return x_j
        

In [59]:
# graph sage with label aware aggregation 

class LA_SAGE(nn.Module):
    '''
    GraphSAGE encoder built from label-aware ``LAConv`` layers.

    Args:
        unsup_channels: width of the fixed unsupervised embeddings `x_unsup`.
        in_channels: width of the learned embedding fed to the first layer.
        hidden_channels: output width of every LAConv layer.
        num_layers: number of LAConv layers.
        leaf_len: size of the embedding table.

    NOTE(review): the label predictor is created with fresh weights here and
    then frozen by LAConv; nothing in this class loads trained predictor
    weights.
    '''

    def __init__(self, unsup_channels, in_channels, hidden_channels, num_layers, leaf_len):
        super(LA_SAGE, self).__init__()

        self.num_layers = num_layers
        self.convs = nn.ModuleList()

        self.unsup_channels = unsup_channels

        # Keyword form works with both LabelPredictor definitions in this
        # notebook (one ignores extra kwargs, the other requires hidden_channels).
        self.label_predictor = LabelPredictor(in_channels=unsup_channels,
                                              hidden_channels=unsup_channels)

        self.emb = nn.Embedding(leaf_len, in_channels, padding_idx=0)
        self.ln = nn.LayerNorm(in_channels)

        for i in range(num_layers):
            layer_in = in_channels if i == 0 else hidden_channels
            self.convs.append(LAConv(self.label_predictor, layer_in, hidden_channels))

    def forward(self, x, x_unsup, adjs):
        '''
        Encode a sampled sub-graph.

        `x` is looked up in the embedding table, `x_unsup` is concatenated to
        the result column-wise, and `adjs` is iterated as
        ``(edge_index, _, size)`` triples, one per layer.

        Returns the learned-feature part of the last layer's output (the
        unsupervised columns are stripped).
        '''
        x = self.emb(x)

        x = torch.sum(x, dim=1)  # summation over the leaves
        x = F.relu(self.ln(x))

        x = torch.cat([x, x_unsup], dim=-1)  # cat learned features with fixed unsup embeddings

        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.

            x = self.convs[i]((x, x_target), edge_index)

            # Non-linearity and dropout apply to the learned features only.
            x_unsup = x[:, -self.unsup_channels:]
            x = x[:, :-self.unsup_channels]

            if i != self.num_layers - 1:
                x = x.relu()
                x = F.dropout(x, p=0.5, training=self.training)

            x = torch.cat([x, x_unsup], dim=-1)

        return x[:, :-self.unsup_channels]

# testing
# Smoke test: run one training batch through a fresh LA_SAGE and show the
# output shape.
# NOTE(review): the assignments below rebind the notebook-level `leaf_len`,
# `sizes` and `batch_size`; `data_loader` was built earlier and is not
# rebuilt here, so the new `sizes` / `batch_size` do not affect this batch.
leaf_len = 1550
sizes = [2,1]
batch_size = 1

la = LA_SAGE(
    unsup_channels=128,
    in_channels=128, 
    hidden_channels=128, 
    num_layers=2, 
    leaf_len=leaf_len)

# x, y, rev, adjs = next(iter(sub_data_module.train_dataloader()))

# The loader is expected to yield (x, x_unsup, y, adjs) tuples here.
x, x_unsup, y, adjs = next(iter(data_loader.train_dataloader()))
la(x, x_unsup, adjs).shape

torch.Size([1024, 128])

In [None]:
class LA_SAGE_Model(pl.LightningModule):
    '''
    Lightning wrapper around ``LA_SAGE`` using the unsupervised
    (anchor / positive / negative) loss and a logistic-regression probe
    at validation time.

    NOTE(review): this class looks like an unfinished draft; see the review
    notes on the individual methods before relying on it.
    '''

    def __init__(self, X, Y, **kwargs):
        super(LA_SAGE_Model, self).__init__()
        
        # All keyword arguments are forwarded unchanged to LA_SAGE.
        self.sage = LA_SAGE(**kwargs)

        # NOTE(review): the `X` / `Y` arguments are accepted but never stored;
        # both attributes start as None.
        self.X = None
        self.y = None

    def forward(self, x, adjs):
        
        # NOTE(review): LA_SAGE.forward in this notebook is declared as
        # forward(x, x_unsup, adjs); this two-argument call omits `x_unsup`.
        return self.sage(x, adjs)
        
    
    def full_forward(self, x, adjs):

        # NOTE(review): LA_SAGE as defined in this notebook has no
        # `full_forward` method.
        return self.sage.full_forward(x, adjs)
        
    def training_step(self, batch, batch_idx):
        
        x, y, rev, adjs = batch
        
        out = self.forward(x, adjs)
        
        # Split the output into thirds: anchors, positives, negatives.
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)

        pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean()
        neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean()
        loss = -pos_loss - neg_loss
             
        return loss

    def validation_step(self, batch, batch_idx):
        
        x, y, rev, adjs = batch
        
        features = self.forward(x, adjs)
                
        return {'features':features, 'y': y.flatten()}
        
    def validation_epoch_end(self, outputs):
        
        valid_X = torch.cat([x["features"] for x in outputs], dim=0)
        valid_y = torch.cat([x["y"] for x in outputs], dim=0)
        
        # NOTE(review): self.X / self.y are only ever assigned None in this
        # class, so this line and the fit below would fail as written.
        print(self.y.shape)
        
        clf = LogisticRegression()
        clf.fit(self.X, self.y)
        
        # NOTE(review): valid_X / valid_y are not moved to CPU before being
        # handed to scikit-learn — confirm this works on the training device.
        score = clf.score(valid_X, valid_y)
        
        self.log('val_acc', score, prog_bar=True)
            
        self.X = None
        self.y = None
            
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [57]:
# Training run for LA_SAGE_Model; the recorded output below ends in a
# RuntimeError (matrix dimension mismatch).
leaf_len = 1549

batch_size = 256
sizes = [-1, 200]

# NOTE(review): `valid_lab_data` is not assigned in any cell visible here
# (the load cell binds `valid_data`) — verify it exists on a fresh kernel.
data_loader = CustomData(train_lab_data, valid_lab_data, sizes, batch_size)

# NOTE(review): the keyword arguments are forwarded to LA_SAGE, whose
# constructor in this notebook also declares a required `unsup_channels`
# parameter that is not supplied here.
model = LA_SAGE_Model(
    train_lab_data.x, 
    train_lab_data.y, 
    in_channels=128, 
    hidden_channels=64, 
    num_layers=2, 
    leaf_len=leaf_len)

trainer = Trainer(num_sanity_val_steps=0)
trainer.fit(model, data_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using environment variable NODE_RANK for node rank (0).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name | Type    | Params
---------------------------------
0 | sage | LA_SAGE | 223 K 
---------------------------------
223 K     Trainable params
0         Non-trainable params
223 K     Total params


Epoch 0:   0%|          | 0/1865 [00:00<?, ?it/s] torch.Size([59938, 128]) torch.Size([59938, 128])
torch.Size([59938, 384])
torch.Size([26962, 128]) torch.Size([26962, 128])
torch.Size([26962, 384])
Epoch 0:   0%|          | 0/1865 [00:00<?, ?it/s]


RuntimeError: mat1 dim 1 must match mat2 dim 0