In [1]:
import pandas as pd
import numpy as np
import tqdm
import torch

# Pre processing

In [2]:
import torch
from torch_geometric.data import Data, DataLoader
import os
import os.path as osp
import math
import argparse
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from graph_data import GraphDataset

gdata = GraphDataset(root='/anomalyvol/data/graph100particles',n_particles=100)

input_dim = 4
big_dim = 32
hidden_dim = 2
fulllen = len(gdata)
tv_frac = 0.10
tv_num = math.ceil(fulllen*tv_frac)
splits = np.cumsum([fulllen-2*tv_num,tv_num,tv_num])
batch_size = 32
n_epochs = 1
lr = 0.0001
patience = 10
device = 'cuda:0'
model_fname = 'EdgeNetPool'

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import EdgeConv, global_mean_pool

class EdgeNetPool(nn.Module):
    def __init__(self, input_dim=4, big_dim=32, hidden_dim=2, n_particles=100, aggr='mean'):
        self.input_dim = input_dim
        self.big_dim = big_dim
        self.hidden_dim = hidden_dim
        self.n_particles = n_particles
        self.aggr = aggr
        super(EdgeNetPool, self).__init__()
        encoder_nn = nn.Sequential(nn.Linear(2*(self.input_dim), self.big_dim),
                               nn.ReLU(),
                               nn.Linear(self.big_dim, self.big_dim),
                               nn.ReLU(),
        )
        
        decoder_nn = nn.Sequential(nn.Linear(2*self.big_dim, self.n_particles*self.big_dim),
                               nn.ReLU(),
                               nn.Linear(self.n_particles*self.big_dim, self.n_particles*self.input_dim)
        )
        
        self.batchnorm = nn.BatchNorm1d(input_dim)

        self.encoder = EdgeConv(nn=encoder_nn,aggr=aggr)
        self.decoder = EdgeConv(nn=decoder_nn,aggr=aggr)

    def forward(self, data):
        input_len = data.x.size()[0] // self.n_particles
        data.x = self.batchnorm(data.x)
        data.x = self.encoder(data.x, data.edge_index)
        data.u = global_mean_pool(data.x, data.batch)
        data.x = self.decoder(data.u, data.edge_index)
        data.x = torch.reshape(data.x, (input_len * self.n_particles, self.input_dim))
        return data.x

In [4]:
model = EdgeNetPool(input_dim=input_dim, big_dim=big_dim, hidden_dim=hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = lr)

In [5]:
train_dataset = GraphDataset(root='/anomalyvol/data/graph100particles',n_particles=100,start=0,stop=splits[0])
valid_dataset = GraphDataset(root='/anomalyvol/data/graph100particles',n_particles=100,start=splits[1],stop=splits[2])
test_dataset = GraphDataset(root='/anomalyvol/data/graph100particles',n_particles=100,start=splits[0 ],stop=splits[1])

train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)

train_samples = len(train_dataset)
valid_samples = len(valid_dataset)
test_samples = len(test_dataset)

print(train_samples)
print(valid_samples)
print(test_samples)
print(train_dataset[0])

24043
2405
2405
Data(edge_index=[2, 420], x=[100, 4], y=[100, 4])


In [6]:
@torch.no_grad()
def test(model,loader,total,batch_size):
    model.eval()
    
    mse = nn.MSELoss(reduction='mean')

    sum_loss = 0.
    t = tqdm.tqdm(enumerate(loader),total=total/batch_size)
    for i,data in t:
        data = data.to(device)
        batch_output = model(data)
        batch_loss_item = mse(batch_output, data.y).item()
        sum_loss += batch_loss_item
        t.set_description("loss = %.5f" % (batch_loss_item))
        t.refresh() # to show immediately the update

    return sum_loss/(i+1)

def train(model, optimizer, loader, total, batch_size):
    model.train()
    
    mse = nn.MSELoss(reduction='mean')

    sum_loss = 0.
    t = tqdm.tqdm(enumerate(loader),total=total/batch_size)
    for i,data in t:
        data = data.to(device)
        optimizer.zero_grad()
        print(data.y.size())
        print(data.x.size())
        batch_output = model(data)             # illegal memory error due to output
        print(batch_output.size())
        batch_loss = mse(batch_output, data.y)
        batch_loss.backward()
        batch_loss_item = batch_loss.item()
        t.set_description("loss = %.5f" % batch_loss_item)
        t.refresh() # to show immediately the update
        sum_loss += batch_loss_item
        optimizer.step()
    
    return sum_loss/(i+1)

In [7]:
n_epochs = 20

In [None]:
stale_epochs = 0
best_valid_loss = 99999
for epoch in range(0, n_epochs):
    loss = train(model, optimizer, train_loader, train_samples, batch_size)
    valid_loss = test(model, valid_loader, valid_samples, batch_size)
    print('Epoch: {:02d}, Training Loss:   {:.4f}'.format(epoch, loss))
    print('               Validation Loss: {:.4f}'.format(valid_loss))

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        modpath = osp.join('/anomalyvol/models/gnn/',model_fname+'.best.pth')
        print('New best model saved to:',modpath)
        torch.save(model.state_dict(),modpath)
        stale_epochs = 0
    else:
        print('Stale epoch')
        stale_epochs += 1
    if stale_epochs >= patience:
        print('Early stopping after %i stale epochs'%patience)
        break

  0%|          | 0/751.34375 [00:00<?, ?it/s]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = 64045658964754432.00000:   0%|          | 2/751.34375 [00:11<1:22:30,  6.61s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   0%|          | 3/751.34375 [00:15<1:10:30,  5.65s/it]                    

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   1%|          | 4/751.34375 [00:20<1:08:48,  5.52s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   1%|          | 5/751.34375 [00:23<59:03,  4.75s/it]  

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   1%|          | 6/751.34375 [00:26<53:44,  4.33s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   1%|          | 7/751.34375 [00:31<54:37,  4.40s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   1%|          | 8/751.34375 [00:34<52:06,  4.21s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   1%|          | 9/751.34375 [00:38<48:47,  3.94s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   1%|▏         | 10/751.34375 [00:45<1:00:49,  4.92s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   1%|▏         | 11/751.34375 [00:51<1:02:48,  5.09s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   2%|▏         | 12/751.34375 [00:53<52:35,  4.27s/it]  

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   2%|▏         | 13/751.34375 [00:56<48:02,  3.90s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   2%|▏         | 14/751.34375 [00:58<42:56,  3.49s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   2%|▏         | 15/751.34375 [01:02<42:07,  3.43s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   2%|▏         | 16/751.34375 [01:05<42:21,  3.46s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   2%|▏         | 17/751.34375 [01:08<38:10,  3.12s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   2%|▏         | 18/751.34375 [01:10<36:36,  2.99s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   3%|▎         | 19/751.34375 [01:13<36:48,  3.02s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   3%|▎         | 20/751.34375 [01:16<34:59,  2.87s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   3%|▎         | 22/751.34375 [01:25<43:09,  3.55s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   3%|▎         | 23/751.34375 [01:27<39:10,  3.23s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   3%|▎         | 24/751.34375 [01:30<38:35,  3.18s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   3%|▎         | 25/751.34375 [01:33<36:20,  3.00s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   3%|▎         | 26/751.34375 [01:36<35:53,  2.97s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   4%|▎         | 27/751.34375 [01:38<33:57,  2.81s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   4%|▎         | 28/751.34375 [01:44<44:52,  3.72s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   4%|▍         | 29/751.34375 [01:46<38:56,  3.23s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   4%|▍         | 30/751.34375 [01:49<37:14,  3.10s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   4%|▍         | 31/751.34375 [01:51<34:38,  2.89s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   4%|▍         | 32/751.34375 [01:54<34:45,  2.90s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   4%|▍         | 33/751.34375 [01:58<38:45,  3.24s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   5%|▍         | 34/751.34375 [02:04<49:35,  4.15s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   5%|▍         | 35/751.34375 [02:08<46:10,  3.87s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   5%|▍         | 36/751.34375 [02:12<47:28,  3.98s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   5%|▌         | 38/751.34375 [02:25<1:04:35,  5.43s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   5%|▌         | 39/751.34375 [02:28<56:13,  4.74s/it]  

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   5%|▌         | 40/751.34375 [02:34<1:00:09,  5.07s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   5%|▌         | 41/751.34375 [02:36<49:54,  4.22s/it]  

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   6%|▌         | 42/751.34375 [02:39<43:34,  3.69s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   6%|▌         | 43/751.34375 [02:41<38:55,  3.30s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   6%|▌         | 44/751.34375 [02:43<35:19,  3.00s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   6%|▌         | 45/751.34375 [02:51<53:27,  4.54s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   6%|▌         | 46/751.34375 [02:56<52:10,  4.44s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   6%|▋         | 47/751.34375 [02:58<43:35,  3.71s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   6%|▋         | 48/751.34375 [03:01<41:37,  3.55s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   7%|▋         | 49/751.34375 [03:03<37:56,  3.24s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   7%|▋         | 50/751.34375 [03:06<35:39,  3.05s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   7%|▋         | 51/751.34375 [03:09<34:59,  3.00s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   7%|▋         | 52/751.34375 [03:14<44:00,  3.78s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   7%|▋         | 53/751.34375 [03:18<41:53,  3.60s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   7%|▋         | 54/751.34375 [03:23<47:39,  4.10s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   7%|▋         | 55/751.34375 [03:26<42:42,  3.68s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   7%|▋         | 56/751.34375 [03:27<35:55,  3.10s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   8%|▊         | 57/751.34375 [03:30<33:49,  2.92s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   8%|▊         | 58/751.34375 [03:33<33:00,  2.86s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   8%|▊         | 59/751.34375 [03:36<35:33,  3.08s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   8%|▊         | 60/751.34375 [03:41<41:02,  3.56s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   8%|▊         | 61/751.34375 [03:44<37:57,  3.30s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   8%|▊         | 62/751.34375 [03:48<40:48,  3.55s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   8%|▊         | 63/751.34375 [03:54<49:14,  4.29s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   9%|▊         | 64/751.34375 [03:58<48:16,  4.21s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   9%|▊         | 65/751.34375 [04:01<43:59,  3.85s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   9%|▉         | 66/751.34375 [04:04<41:12,  3.61s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   9%|▉         | 67/751.34375 [04:06<36:27,  3.20s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   9%|▉         | 68/751.34375 [04:09<36:51,  3.24s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   9%|▉         | 69/751.34375 [04:12<35:54,  3.16s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   9%|▉         | 70/751.34375 [04:14<32:01,  2.82s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:   9%|▉         | 71/751.34375 [04:22<49:22,  4.35s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:  10%|▉         | 72/751.34375 [04:29<58:04,  5.13s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:  10%|▉         | 73/751.34375 [04:33<52:02,  4.60s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:  10%|▉         | 74/751.34375 [04:35<45:12,  4.00s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:  10%|▉         | 75/751.34375 [04:42<53:47,  4.77s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:  10%|█         | 76/751.34375 [04:46<50:45,  4.51s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:  10%|█         | 77/751.34375 [04:49<46:21,  4.12s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:  10%|█         | 78/751.34375 [04:55<54:13,  4.83s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:  11%|█         | 79/751.34375 [05:00<52:42,  4.70s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


loss = nan:  11%|█         | 80/751.34375 [05:02<45:06,  4.03s/it]

torch.Size([3200, 4])
torch.Size([3200, 4])
torch.Size([3200, 4])


In [None]:
model.load_state_dict(torch.load(modpath))
input_x = []
output_x = []

t = tqdm.tqdm(enumerate(test_loader),total=test_samples/batch_size)
for i, data in t:
    data.to(device)
    input_x.append(data.x.cpu().numpy())
    output_x.append(model(data).cpu().detach().numpy())

In [None]:
diff = []
output_px = []
input_px = []
for i in range(len(input_x)):
    diff.append(((output_x[i][:,3]-input_x[i][:,3])/input_x[i][:,3]).flatten())
    output_px.append(output_x[i][:,3].flatten())
    input_px.append(input_x[i][:,3].flatten())
    
all_diff = np.concatenate(diff)
all_input_px = np.concatenate(input_px)
all_output_px = np.concatenate(output_px)

print(all_input_px.shape)
print(all_output_px.shape)

plt.figure()
plt.hist(all_input_px, bins=np.linspace(-1, 40, 101),alpha=0.5)
plt.hist(all_output_px, bins=np.linspace(-1, 40, 101),alpha=0.5)

plt.figure()
plt.hist(all_diff, bins=np.linspace(-5, 5, 101))