In [None]:
import networkx as nx
import torch_geometric as tg
import numpy as np
from pathlib import Path
from pathlib import PurePath
import os
import re
import glob
import math
import json
from xml.dom import minidom
import torch
import torch.nn as nn
from numpy import argmax
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import Utils.utils as util
import Utils.config as conf
import Utils.bad_utils as b_utils

from Utils import BAD as b

# Experiment settings (batch size, action/object maps, epochs, ...) from Utils/config.py
cfg = conf.get_bac_cfg()

# Data locations
_FOLDER = "{}/BAD/".format(os.getcwd())  # processed BAD data (raw/ and processed/ live here)
MAIN_PATH = "/DREHER/bimacs_derived_data/"  # source BAD JSON objects; NOTE(review): absolute, machine-specific path

In [None]:
'''

    This cell creates new datasets.
    
    Example:
    train_set = BAD_DS("FOLDER_TO_BAD_JSON", split_type)
    
    All settings are defined in Utils/config.py, except the two flags
    below: _SAVE_RAW_DATA (save newly created data) and _CREATE_DATASET
    (build a new dataset from the BAD JSON objects).

'''
_SAVE_RAW_DATA = False
_CREATE_DATASET = False

if _CREATE_DATASET:
    train_set = b.BAD_DS(MAIN_PATH, "training")
    val_set = b.BAD_DS(MAIN_PATH, "validation")
    test_set = b.BAD_DS(MAIN_PATH, "test")


# USED TO SAVE RAW TRAINING DATA
if _SAVE_RAW_DATA:
    # The splits exist only if they were built above. Without this check the
    # cell would fail further down with an opaque NameError on train_set.
    if not _CREATE_DATASET:
        raise RuntimeError("_SAVE_RAW_DATA requires _CREATE_DATASET = True in the same run")

    raw_dir = os.path.join(_FOLDER, "raw")
    os.makedirs(raw_dir, exist_ok=True)

    # File names: bad_<split>_<time_window>w.pt
    for split_name, split_set in (("training", train_set), ("validation", val_set), ("test", test_set)):
        raw_path = os.path.join(raw_dir, "bad_{}_{}w.pt".format(split_name, cfg.time_window))
        with open(raw_path, 'wb') as f:
            torch.save(split_set, f)

In [None]:
'''

    Load the BAD dataset splits and wrap them in DataLoaders.
    
    Example:
    Load the processed data (or create it if missing) into an in-memory dataset.
    train_dataset = b.BadIMDS(_FOLDER, "train")
    
    Create a DataLoader over the loaded dataset.
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=True)

'''

# Create data loaders from dataset.
# PyG >= 2.0 moved DataLoader to torch_geometric.loader. The old
# torch_geometric.data.DataLoader alias was deprecated and later removed.
try:
    from torch_geometric.loader import DataLoader
except ImportError:  # older PyG releases
    from torch_geometric.data import DataLoader

# Load pre-processed .pt files from FOLDER/processed/, or create them there if missing
train_dataset = b.BadIMDS(_FOLDER + "/", "train")
val_dataset = b.BadIMDS(_FOLDER + "/", "valid")
test_dataset = b.BadIMDS(_FOLDER + "/", "test")

#####################PRINT################################
print("Total graphs:\t {}\n=========".format(len(train_dataset)+len(test_dataset)+len(val_dataset)))
print("Training: \t {}".format(len(train_dataset)))
print("Test: \t\t {}".format(len(test_dataset)))
print("Validation: \t {}\n=========".format(len(val_dataset)))
#####################PRINT################################

# drop_last=True keeps every batch at exactly cfg.batch_size graphs. The
# Predictor reshapes its input using cfg.batch_size, so partial batches would break it.
train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, drop_last=True)
valid_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False, drop_last=True)

#####################PRINT################################
print("Total batchs:\t {}\n=========".format(len(train_loader)+len(test_loader)+len(valid_loader)))
print("Training: \t {}".format(len(train_loader)))
print("Test: \t\t {}".format(len(test_loader)))
print("Validation: \t {}\n=========".format(len(valid_loader)))
#####################PRINT################################

# Largest graph (node count) across all splits. The decoder emits a
# max_num_nodes x max_num_nodes adjacency, so every graph must fit in it.
max_num_nodes_train = max(len(g.x) for g in train_dataset)
max_num_nodes_valid = max(len(g.x) for g in val_dataset)
max_num_nodes_test = max(len(g.x) for g in test_dataset)
max_num_nodes = max(max_num_nodes_test, max_num_nodes_train, max_num_nodes_valid)

#####################PRINT################################
print("Max number of nodes found:", max_num_nodes)
#####################PRINT################################

In [None]:
from torch.nn import Sequential, Linear, ReLU, ELU
from torch_geometric.nn import NNConv, BatchNorm
from torch_scatter import scatter_mean

def repackage_hidden(h):
    """Cut a recurrent hidden state loose from its autograd history.

    A tensor is returned detached. Any other value is treated as an iterable
    of hidden states, processed recursively, and rebuilt as a tuple (e.g. the
    (h, c) pair of an LSTM).
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()

'''

    Encoder with 2x NNConv, BN
    
    Outputs z (sampled latent while training, mu in eval), log variance, mu,
    and the pooled graph embedding used for action prediction.

'''

class Encoder(torch.nn.Module):
    """Graph encoder shared by the action-prediction and graph-reconstruction branches."""

    def __init__(self):
        super(Encoder, self).__init__()
        # Edge networks for NNConv. Each maps an edge's features (len(cfg.spatial_map) wide)
        # to a flattened in_channels * out_channels weight matrix.
        # Named edge_nn/edge_nn2 so they do not shadow the torch.nn module alias.
        edge_nn = Sequential(Linear(len(cfg.spatial_map), 512), ReLU(), Linear(512, cfg.channels * cfg.channels * 2))
        edge_nn2 = Sequential(Linear(len(cfg.spatial_map), 512), ReLU(), Linear(512, cfg.decoder_in * cfg.channels * 2))

        # Encoder
        self.lin = torch.nn.Linear(len(cfg.objects), cfg.channels)  # per-node FCL
        # NOTE(review): conv1 and conv2 share the same edge_nn instance, and mu/logvar
        # share edge_nn2, so their edge weights are tied. The sizes match
        # (channels * 2*channels in both directions), so it runs. Confirm the tying is intended.
        self.conv1 = NNConv(cfg.channels, cfg.channels * 2, edge_nn, aggr='mean')
        self.bn1 = BatchNorm(cfg.channels * 2)
        self.conv2 = NNConv(cfg.channels * 2, cfg.channels, edge_nn, aggr='mean')

        # Z-representation as mu and log variance
        self.mu = NNConv(cfg.channels * 2, cfg.decoder_in, edge_nn2, aggr='max')
        self.logvar = NNConv(cfg.channels * 2, cfg.decoder_in, edge_nn2, aggr='max')

    def forward(self, data):
        """Encode a batch of graphs.

        Args:
            data: PyG batch exposing x, edge_index, edge_attr and batch.

        Returns:
            (z, logvar, mu, p_x). z is mu + eps * std while training and mu in
            eval mode. logvar and mu are mean-pooled per graph. p_x is the
            per-graph mean of the conv2 node embeddings.
        """
        edge_index, edge_attr, batch = data.edge_index, data.edge_attr, data.batch

        # FCL
        out = F.relu(self.lin(data.x))

        # Used for ACTION PREDICTION
        hidden = F.relu(self.conv1(out, edge_index, edge_attr))      # Conv1
        hidden = self.bn1(hidden)                                      # BatchNorm
        conv2_out = F.relu(self.conv2(hidden, edge_index, edge_attr))  # Conv2
        p_x = scatter_mean(conv2_out, batch, dim=0)

        # Used for GRAPH RECONSTRUCTION (both heads read the post-BN conv1 features)
        mu = scatter_mean(self.mu(hidden, edge_index, edge_attr), batch, dim=0)
        logvar = scatter_mean(self.logvar(hidden, edge_index, edge_attr), batch, dim=0)

        # Reparameterisation trick only while training. Eval is deterministic and
        # no longer draws (and discards) random noise.
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return std * eps + mu, logvar, mu, p_x
        return mu, logvar, mu, p_x

'''

    Prediction Branch
    
    Input: input size, sequence length, hidden size of the LSTM and number of LSTM layers.
    Output: action prediction logits.

'''
class Predictor(torch.nn.Module):
    """LSTM head that turns pooled graph embeddings into action logits.

    The LSTM state is carried over between forward calls (detached each time),
    so consecutive batches are treated as a continuation of the same sequence.
    """

    def __init__(self, input_size, seq_len, hidden_size, n_layers):
        super(Predictor, self).__init__()

        self.prev_hidden = None          # (h, c) kept from the previous forward call
        self.bs          = cfg.batch_size
        self.input_size  = input_size
        self.seq_len     = seq_len
        self.hidden_size = hidden_size
        self.n_layers    = n_layers

        # Model
        self.lstm = torch.nn.LSTM(self.input_size, self.hidden_size, self.n_layers, dropout=cfg.dropout, batch_first=True)
        self.lin1 = torch.nn.Linear(self.hidden_size, len(cfg.action_map))

    def forward(self, p_x):
        """Return logits of shape (cfg.batch_size, len(cfg.action_map)).

        p_x is reshaped to (cfg.batch_size, seq_len, -1). This assumes the batch
        holds exactly cfg.batch_size graphs (drop_last=True upstream).
        """
        # Build or move the state on p_x's own device. The old .cuda(device) call
        # crashed on CPU-only machines and needed a global defined in a later cell.
        if self.prev_hidden is None:
            state_shape = (self.n_layers, self.bs, self.hidden_size)
            self.prev_hidden = (torch.zeros(state_shape, device=p_x.device),
                                torch.zeros(state_shape, device=p_x.device))
        elif self.prev_hidden[0].device != p_x.device:
            self.prev_hidden = tuple(t.to(p_x.device) for t in self.prev_hidden)

        # Reshape data to (batch, seq_len, features) for the batch_first LSTM
        input_reshape = p_x.reshape(self.bs, self.seq_len, -1)

        # Feed LSTM
        q, h = self.lstm(input_reshape, self.prev_hidden)

        # Detach the state so gradients do not flow into previous batches
        self.prev_hidden = repackage_hidden(h)

        # Classify from the LAST time step's output
        return self.lin1(q[:, -1, :])

'''

    Graph Reconstruction Branch
    
'''
class Decoder(torch.nn.Module):
    """Two-layer MLP mapping a latent vector to flattened adjacency probabilities.

    Output width is max_num_nodes ** 2 (taken from the dataset cell), with a
    sigmoid applied so every entry lies in (0, 1).
    """

    def __init__(self):
        super(Decoder, self).__init__()
        hidden_width = 64  # Experiment with more or larger FCL.
        self.fc1 = nn.Linear(cfg.decoder_in, hidden_width)
        self.fc2 = nn.Linear(hidden_width, max_num_nodes ** 2)

    def forward(self, z_x):
        hidden = F.relu(self.fc1(z_x))
        logits = self.fc2(hidden)
        return torch.sigmoid(logits)

'''

    djNetwork model.
    
                -> Action Prediction
    Encoder ->  -> Decoder -> Graph Reconstruction
    
'''
class djNet(torch.nn.Module):
    """Encoder shared by an LSTM action predictor and an adjacency decoder."""

    def __init__(self):
        super(djNet, self).__init__()
        # Sub-modules are created in this order: encoder, predictor, decoder
        self.encoder = Encoder()
        self.predictor = Predictor(input_size=8, seq_len=8, hidden_size=8, n_layers=2)
        self.decoder = Decoder()

    def forward(self, x):
        """Return (adjacency probs, logvar, mu, z, action logits) for a graph batch."""
        z, logvar, mu, p_x = self.encoder(x)
        action_logits = self.predictor(p_x)
        adjacency = self.decoder(z)
        return adjacency, logvar, mu, z, action_logits


In [None]:
'''
    Create model with device.
'''
# Prefer the GPU when one is visible. Later cells move batches to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = djNet().to(device)


banner = "#################"
print(banner)
print(model)
print(banner)

In [None]:
# OPTIMIZER USED
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
for line in ("Optimizer is set.", "------------------------"):
    print(line)
print(optimizer)

print("####################")

# Optional TensorBoard logging. `writer` is only defined when it is enabled.
if not cfg.summery_writer:
    print("SummerWriter is OFF.")
else:
    from torch.utils.tensorboard import SummaryWriter
    print("SummeryWriter is ON.")
    writer = SummaryWriter(comment="ENTER SOME COMMENT ABOUT MODEL")

print("####################")

# Loss function: cross entropy for the action-prediction branch (default mean reduction)
ap_criterion = nn.CrossEntropyLoss()

def loss_criterion(inputs, targets, logvar, mu, ap_inputs, ap_targets):
    """Compute the two training objectives.

    Returns:
        (recon, ap). recon is the summed binary cross entropy between the
        predicted and target adjacency plus the summed KL divergence of
        N(mu, exp(logvar)) from N(0, 1). ap is the mean cross entropy of the
        action logits against their integer class targets.
    """
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    reconstruction = F.binary_cross_entropy(inputs, targets, reduction="sum")
    classification = ap_criterion(ap_inputs, ap_targets)
    return reconstruction + kl_term, classification

In [None]:
'''
    
    Train the model for one epoch on a specific loader.
    
    Input: loader and model
    Output: Reconstruction Loss and Action Prediction Loss


'''

def train(loader, model, recon_weight=0.6):
    """Run one optimisation pass over ``loader``.

    Relies on the notebook-level ``optimizer``, ``device``, ``max_num_nodes``,
    ``loss_criterion`` and ``util`` defined in earlier cells.

    Args:
        loader: PyG DataLoader yielding graph batches.
        model: djNet instance.
        recon_weight: weight of the reconstruction+KL term in the total loss.

    Returns:
        (recon_loss, ap_loss). recon_loss is the summed reconstruction+KL loss
        averaged per graph. ap_loss is the mean action-prediction cross entropy
        averaged over batches (it is already a per-sample mean, so it is not
        divided by the batch size again).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()

    recon_loss_all = 0.0
    ap_loss_all = 0.0
    n_graphs = 0
    n_batches = 0

    for data in loader:
        optimizer.zero_grad()
        data = data.to(device)
        # Graphs in this batch (== cfg.batch_size with drop_last=True). The
        # original code read an undefined `_bs` here.
        bs = data.num_graphs

        y_hat, logvar, mu, _, y_ap = model(data)         # input model
        prediction = y_hat.view(bs, -1, max_num_nodes)   # (bs, max_n, max_n)

        # Creating targets
        target_adj = util.to_dense_adj_max_node(data.edge_index, data.x, data.edge_attr, data.batch, max_num_nodes)
        y_ap_true = data.y.view(bs, -1).argmax(dim=1)    # one-hot -> class index

        # Compute loss
        recon_loss, ap_loss = loss_criterion(prediction, target_adj, logvar, mu, y_ap, y_ap_true)
        net_loss = recon_loss * recon_weight + ap_loss

        # Compute gradients and update weights.
        net_loss.backward()
        optimizer.step()

        recon_loss_all += recon_loss.item()
        ap_loss_all += ap_loss.item()
        n_graphs += bs
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train(): loader yielded no batches (dataset smaller than batch_size with drop_last=True?)")

    return recon_loss_all / n_graphs, ap_loss_all / n_batches


'''
    
    Evaluate the model on a specific loader.
    
    Input: loader and model
    Output: Reconstruction Loss, Action Prediction Loss, Accuracy
    
'''

def test(loader, model):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Relies on the notebook-level ``device``, ``max_num_nodes``,
    ``loss_criterion`` and ``util`` defined in earlier cells.

    Returns:
        (recon_loss, ap_loss, accuracy). recon_loss is averaged per graph,
        ap_loss (already a per-sample mean) is averaged over batches, and
        accuracy is the fraction of graphs whose top-1 action is correct.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()

    recon_loss_all = 0.0
    ap_loss_all = 0.0
    correct = 0
    n_graphs = 0
    n_batches = 0

    # Evaluation needs no gradients. Skipping autograd avoids building graphs
    # for every batch and saves memory.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            bs = data.num_graphs  # replaces the undefined `_bs`

            y_hat, logvar, mu, _, y_ap = model(data)
            prediction = y_hat.view(bs, -1, max_num_nodes)

            # Creating targets
            target_adj = util.to_dense_adj_max_node(data.edge_index, data.x, data.edge_attr, data.batch, max_num_nodes)
            y_ap_true = data.y.view(bs, -1).argmax(dim=1)

            pred = y_ap.argmax(dim=1)

            # Compute loss
            recon_loss, ap_loss = loss_criterion(prediction, target_adj, logvar, mu, y_ap, y_ap_true)

            recon_loss_all += recon_loss.item()
            ap_loss_all += ap_loss.item()
            correct += pred.eq(y_ap_true).sum().item()
            n_graphs += bs
            n_batches += 1

    if n_batches == 0:
        raise ValueError("test(): loader yielded no batches (dataset smaller than batch_size with drop_last=True?)")

    return recon_loss_all / n_graphs, ap_loss_all / n_batches, correct / n_graphs

In [None]:
'''

    TRAINING

'''

# Epochs are numbered 1..cfg.epochs. The previous range(1, cfg.epochs) ran one epoch short.
for epoch in range(1, cfg.epochs + 1):

    train_recon_loss, train_ap_loss = train(train_loader, model)
    _, _, train_ap_acc = test(train_loader, model)
    validation_recon_loss, validation_ap_loss, validation_ap_acc = test(valid_loader, model)
    # test_ap_acc is reported below as TAcc but was never computed (NameError).
    # Monitoring only: do not select models on this number.
    _, _, test_ap_acc = test(test_loader, model)

    # Writes to tensorboard
    if cfg.summery_writer:
        writer.add_scalar('AP_Acc/train', train_ap_acc, epoch)
        writer.add_scalar('AP_Acc/validation', validation_ap_acc, epoch)
        writer.add_scalar('AP_Acc/test', test_ap_acc, epoch)

        writer.add_scalar('Recon_Loss/train', train_recon_loss, epoch)
        writer.add_scalar('Recon_Loss/validation', validation_recon_loss, epoch)

        writer.add_scalar('AP_Loss/train', train_ap_loss, epoch)
        writer.add_scalar('AP_Loss/validation', validation_ap_loss, epoch)

    print("Epoch {:02d}, [T] RLoss: {:.2f}, APLoss: {:.4f}, Acc: {:.2f}% [V] RLoss: {:.2f}, APLoss: {:.4f}, Acc: {:.2f}%, TAcc: {:.2f}%".format(
        epoch,
        train_recon_loss,
        train_ap_loss,
        train_ap_acc * 100,
        validation_recon_loss,
        validation_ap_loss,
        validation_ap_acc * 100,
        test_ap_acc * 100))


In [None]:
_SAVE = False
if _SAVE:
    save_path = "./MANIAC_final_models/MANIAC_final_4w_dim_64_increase_mu_64.pt"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save({
                'epoch': epoch,  # last epoch run by the training loop (was hard-coded to 99)
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': 4.110431,  # NOTE(review): hard-coded value, not computed from this run
                }, save_path)
    print("SAVED!")
else:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(review): absolute, machine-specific path.
    checkpoint = torch.load("/data/tmp/dj_data/runs/dreher_10.pt", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print("LOADED MODEL")

In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

# Test that returns graph reconstruction and action prediction
#
# Input: loader, model, max_num_nodes, device
# Output:
# cr_pred   - graph reconstruction prediction
# cr_gt     - graph reconstruction ground truth
# ap_pred   - action prediction
# ap_target - action ground truth
# top3      - list of top 3 predictions
# node_list - correct number of nodes of graphs.
cr_pred, cr_gt, ap_pred, ap_target, top3, node_list = b.test(test_loader, model, max_num_nodes, device)

# Pin the label set so the matrix is always len(action_map) x len(action_map).
# Otherwise, classes missing from both target and prediction shrink the matrix
# and its rows no longer line up with cfg.action_map.
action_labels = list(range(len(cfg.action_map)))

cm = confusion_matrix(ap_target, ap_pred, labels=action_labels)
print(cm)

print(classification_report(ap_target, ap_pred, target_names=cfg.action_map, labels=action_labels))

# Graph-reconstruction AUC/ROC. NOTE(review): 0.3 is passed through as-is; its meaning is defined in Utils.utils.
util.calc_auc_roc(cr_gt, cr_pred, node_list, 0.3)

In [None]:
# Top-3 accuracy: a sample counts as correct when its true action is among the
# three highest-ranked actions. Otherwise the top-1 guess is used as the prediction.
action_names = cfg.action_map  # `action_map` was never defined in this notebook (NameError)
ap_target_idx = ap_target.astype(int)

new_pred = []
for true_idx, candidates in zip(ap_target_idx, top3):
    if action_names[true_idx] in candidates:
        new_pred.append(true_idx)
    else:
        new_pred.append(action_names.index(candidates[0]))

print("top 3")

# Derive the label list from the action map instead of hard-coding 0..14.
print(classification_report(ap_target_idx, new_pred, target_names=action_names,
                            labels=list(range(len(action_names)))))

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt  # `plt` was used below but never imported
import seaborn as sns


def _plot_confusion(matrix, labels, fmt, title):
    """Draw ``matrix`` as an annotated heatmap with ``labels`` on both axes."""
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(matrix, annot=True, fmt=fmt, xticklabels=labels, yticklabels=labels, cmap="YlGnBu", ax=ax)
    ax.set_ylabel('Actual')
    ax.set_xlabel('Predicted')
    ax.set_title(title)
    plt.show(block=False)


# Normalise each row by its true-class count. Rows with no samples stay at 0
# instead of becoming NaN (0/0).
row_totals = cm.sum(axis=1, keepdims=True)
cmn = np.divide(cm.astype('float'), row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)

# `action_map` was undefined in this notebook; use the configured action names.
_plot_confusion(cmn, cfg.action_map, '.2f', 'Normalized')
_plot_confusion(cm, cfg.action_map, 'd', 'Number of predictions')

In [None]:
"""
 Example on how to save and load models
"""

_SAVE = False

if _SAVE:
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': 0,
                }, "./BAD_MODEL/bad_84.pt")
    print("SAVED!")
else:
    checkpoint = torch.load("./BAD_MODEL/bad_84.pt")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print("LOADED MODEL")