In [None]:
import networkx as nx
import torch_geometric as tg
import numpy as np
from pathlib import Path
from pathlib import PurePath
import os
import re
import glob
import math

from xml.dom import minidom

from numpy import argmax
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder

import json

In [None]:
def one_hot_string(map):
    """
    One-hot encode a sequence of strings.

    Classes are ordered lexicographically (the same order scikit-learn's
    ``LabelEncoder`` uses), so row ``i`` is the one-hot vector of ``map[i]`` and
    column ``j`` corresponds to the ``j``-th class in sorted order.

    Args:
        map: sequence of strings (duplicates allowed). The name shadows the
            builtin but is kept so existing keyword callers keep working.

    Returns:
        float64 ndarray of shape ``(len(map), number_of_distinct_strings)``.
    """
    values = np.asarray(map)

    # Equivalent to LabelEncoder().fit_transform + OneHotEncoder(sparse=False),
    # but without the `sparse` keyword that scikit-learn 1.2 deprecated and 1.4
    # removed (renamed to `sparse_output`).
    classes, codes = np.unique(values, return_inverse=True)
    # reshape(-1): some NumPy 2.0 releases return `codes` with the input's shape.
    return np.eye(len(classes), dtype=np.float64)[codes.reshape(-1)]

def get_action_from_encoding(enc):
    """
    Get the action name for a one-hot action encoding.

    Args:
        enc: one-hot action vector as a list, in the same form as the entries
            of the module-level ``_actions`` table.

    Returns:
        ``(action_name, position)`` where ``position`` indexes both ``_actions``
        and ``action_map``. ``list.index`` raises ValueError if ``enc`` is not
        one of the known encodings.
    """
    position = _actions.index(enc)
    name = action_map[position]
    return name, position

def createTemporalDict(in_dict):
    """
    Create a dict with temporal information.

    Args:
        in_dict: two-level mapping ``{key: {seq: graphs}}`` where ``graphs`` is
            the list of per-frame graphs passed to ``concatenateTemporal``.

    Returns:
        ``{key: {seq: [merged_graph]}}`` with the same keys, where each value is
        a one-element list holding ``concatenateTemporal(in_dict[key][seq])``.
    """
    return {
        key: {seq: [concatenateTemporal(in_dict[key][seq])] for seq in in_dict[key]}
        for key in in_dict
    }

def nodesMatch(n1, n_list):
    """
    See whether a node's attribute dict matches any node in a list.

    Args:
        n1: attribute dict of the node to look for.
        n_list: iterable of ``(node_id, attr_dict)`` pairs, e.g.
            ``graph.nodes(data=True)``.

    Returns:
        ``(True, position)`` for the first pair whose attribute dict equals
        ``n1`` (``position`` is the 0-based position in ``n_list``, not the node
        id), otherwise ``(False, 0)``.
    """
    hits = (pos for pos, (_, attrs) in enumerate(n_list) if attrs == n1)
    first = next(hits, None)
    if first is None:
        return False, 0
    return True, first

def calculateTemporalEdges(full_gr, nodes, index_list):
    """
    Add 'temporal' edges from the previous frame's nodes to the matching nodes
    of the next frame.

    The next frame's nodes are expected to be appended to ``full_gr`` directly
    after the previous frame: the k-th entry of ``nodes`` gets id
    ``index_list[-1] + 1 + k``. This is how ``concatenateTemporal`` numbers them,
    and this function may run before those nodes have been added.

    Args:
        full_gr: graph under construction; already contains the previous frame.
        nodes: the next frame's ``graph.nodes(data=True)``, i.e.
            ``(node_id, attr_dict)`` pairs.
        index_list: ids in ``full_gr`` of the previous frame's nodes, in order.
    """
    if not index_list:
        return

    temporal_rel = _relations[spatial_map.index("temporal")]
    # Id the first node of the next frame will receive.
    first_new_id = index_list[-1] + 1

    for index in index_list:
        # A previous-frame node matches the first next-frame node with identical
        # attributes (i.e. the same object one-hot).
        match, ind = nodesMatch(full_gr.nodes[index], nodes)
        if match:
            # Link to the matched node's future id. The previous version started
            # at index_list[-1] (an old node) and advanced once per match instead
            # of using `ind`, so the edges pointed at the wrong nodes.
            full_gr.add_edge(index, first_new_id + ind, edge_attr=temporal_rel)

def concatenateTemporal(graphs):
    """
    Merge a list of per-frame graphs into one graph with temporal edges.

    Nodes of every frame are renumbered consecutively (0, 1, 2, ...) in frame
    order and keep their ``x`` attribute. Each frame's own edges are copied with
    their ``edge_attr``. Before a frame (other than the first) is added,
    ``calculateTemporalEdges`` links the previous frame's nodes to it.

    Args:
        graphs: non-empty list of graphs whose nodes are labelled 0..n-1 in
            iteration order and carry an ``x`` attribute. The merged graph takes
            its ``graph['features']`` from ``graphs[0]``.

    Returns:
        ``nx.DiGraph``. Any node left without attributes (created only as an
        edge endpoint) is removed before returning.
    """
    merged = nx.DiGraph()
    merged.graph["features"] = graphs[0].graph['features']
    next_id = 0
    frame_ids = []

    for frame in graphs:
        if frame_ids:
            calculateTemporalEdges(merged, frame.nodes(data=True), frame_ids)

        # Ids in `merged` of this frame's nodes, in the frame's node order.
        frame_ids = []
        for _, attrs in frame.nodes(data=True):
            merged.add_node(next_id, x=attrs['x'])
            frame_ids.append(next_id)
            next_id += 1

        for src, dst in frame.edges():
            relation = frame.get_edge_data(src, dst)['edge_attr']
            merged.add_edge(frame_ids[src], frame_ids[dst], edge_attr=relation)

    # Drop placeholder nodes that exist only because an edge referenced them.
    orphans = [node_id for node_id, attrs in merged.nodes(data=True) if attrs == {}]
    merged.remove_nodes_from(orphans)

    return merged

In [None]:
'''
    WARNING WARNING WARNING 
    DO NOT CHANGE ORDER OF THESE ITEMS
'''
# Why the order matters: one_hot_string() assigns one-hot columns in sorted
# (lexicographic) order, while lookups elsewhere in the notebook
# (e.g. _relations[spatial_map.index(name)], _objects[objects.index(name)],
# _actions[action_map.index(...)]) use a name's position in these lists as the
# row index. The two only agree while each list stays sorted as written here.

# Object class names; the node feature `x` is the one-hot row for the object.
objects = ['LeftHand', 'RightHand', 'banana', 'bottle', 'bowl', 'cereals', 'cup', 'cuttingboard', 
          'hammer', 'harddrive', 'knife', 'saw', 'screwdriver', 'sponge', 'whisk', 'woodenwedge']

# Spatial relation names; the edge feature `edge_attr` is the one-hot row for
# the relation. 'temporal' is the relation used for edges added between
# consecutive frames by calculateTemporalEdges().
spatial_map = ['above', 'behind of', 'below', 'contact', 'fixed moving together', 'getting close',
              'halting together', 'in front of', 'inside', 'left of', 'moving apart', 'moving together', 
                'right of', 'stable', 'surround', 'temporal']


# Action class names, sorted; graph.graph['features'] is the one-hot row.
action_map = ['approach', 'cut', 'drink', 'hammer', 'hold', 'idle', 'lift', 'place', 'pour', 'retreat', 
              'saw', 'screw', 'stir', 'wipe']

# Action names indexed by the integer label found in the ground truth
# (json_to_graph uses original_index[target] to map a label to its name).
# NOTE(review): presumably this is the label order of the dataset's annotation
# files — confirm against the dataset documentation.
original_index = ['idle', 'approach', 'retreat', 'lift', 'place', 'hold', 'pour', 'cut',
                 'hammer', 'saw', 'stir', 'screw', 'drink', 'wipe']

# One-hot lookup tables (lists of float lists), row i <-> i-th name above.
_relations = one_hot_string(spatial_map).tolist()
_actions = one_hot_string(action_map).tolist()
_objects = one_hot_string(objects).tolist()

In [None]:
def get_object_list(path):
    """
    Map object positions to class names for one frame.

    Args:
        path: a frame path already split into components, e.g.
            ['DREHER', 'bimacs_derived_data', 'subject_1', 'task_5_k_cereals',
             'take_3', 'spatial_relations', 'frame_0.json'].
            The matching file under ``bimacs_derived_data_3D/.../3d_objects/``
            is read, relative to the current working directory.

    Returns:
        dict ``{position_in_file: class_name}`` built from each entry's
        ``'class_name'``.
    """
    final_path = './' + path[0] + '/bimacs_derived_data_3D/' + path[2] + '/' + path[3] + '/' + path[4] + '/3d_objects/' + path[6]
    with open(final_path) as handle:
        objects_3d = json.load(handle)

    return {position: entry['class_name'] for position, entry in enumerate(objects_3d)}


# Takes a single JSON file and outputs a graph
def json_to_graph(input_path, target):
    """
    Build one spatial-relation graph from a single frame JSON file.

    Nodes come from the matching ``3d_objects`` file (see ``get_object_list``);
    each node's ``x`` attribute is the one-hot encoding of its object class.
    Edges come from the frame's relation list (``object_index`` ->
    ``subject_index``); each edge's ``edge_attr`` is the one-hot encoding of its
    relation name.

    Args:
        input_path: path to a ``.../spatial_relations/frame_<n>.json`` file.
        target: ground-truth label used as an index into ``original_index``
            (as produced by ``get_target_action``). ``None`` (null label) or any
            non-integer value, such as ``get_target_action``'s ``'Not found'``,
            means "no usable label".

    Returns:
        ``nx.DiGraph`` with ``graph['features']`` set to the one-hot action, or
        ``-1`` when the frame has no usable label, no nodes, or no edges.
    """
    # Skip unlabelled frames before touching the files. Previously the string
    # 'Not found' reached original_index[target] and raised TypeError.
    if target is None or isinstance(target, bool) or not isinstance(target, (int, np.integer)):
        return -1

    graph = nx.DiGraph()

    node_list_path = os.path.normpath(input_path).split(os.sep)
    node_list = get_object_list(node_list_path)

    # Load JSON Object to list
    with open(input_path) as f:
        data = json.load(f)

    for index, name in node_list.items():
        graph.add_node(index, x=_objects[objects.index(name)])

    # Populate the graph with edges (one per relation entry)
    for obj in data:
        relation_name = obj['relation_name']
        graph.add_edge(obj['object_index'], obj['subject_index'],
                       edge_attr=_relations[spatial_map.index(relation_name)])

    # Ground-truth label -> action name -> one-hot row in the sorted action table.
    graph.graph["features"] = _actions[action_map.index(original_index[target])]

    if len(graph.nodes()) == 0:
        print("------- NO NODES")
        return -1

    if len(graph.edges()) == 0:
        print("------- NO EDGES")
        return -1

    return graph


def get_target_action(cnt, ground_truth):
    """
    Look up the right-hand action label for frame ``cnt``.

    ``ground_truth['right_hand']`` is read as alternating entries: values at even
    positions >= 2 are compared with ``cnt``, and the value just before the first
    one that is ``>= cnt`` is returned.

    Returns:
        That preceding value (possibly ``None`` if the annotation holds null),
        or the string ``'Not found'`` if ``cnt`` is beyond every boundary.
    """
    labels = ground_truth['right_hand']
    for boundary_pos in range(2, len(labels), 2):
        if cnt <= labels[boundary_pos]:
            return labels[boundary_pos - 1]
    return 'Not found'

def take_to_graph_list(path):
    """
    Convert every frame of one recorded take into a graph.

    Args:
        path: take directory *with a trailing separator* (paths are built by
            string concatenation). It must contain a ground-truth ``*.json``
            file and a ``spatial_relations/`` folder of per-frame ``*.json`` files.

    Returns:
        list of graphs from ``json_to_graph``, in numeric frame order. Frames
        for which ``json_to_graph`` returns ``-1`` are skipped and reported with
        ``print``.

    Raises:
        FileNotFoundError: if ``path`` contains no ``.json`` ground-truth file.
    """
    # Ground-truth annotation directly inside the take directory. os.listdir
    # order is arbitrary, so sort to pick the same file on every run.
    gt_candidates = sorted(name for name in os.listdir(path) if name.endswith('.json'))
    if not gt_candidates:
        raise FileNotFoundError("No ground-truth .json file found in " + path)
    gt_path = path + gt_candidates[0]

    # Per-frame relation files, ordered numerically by the digits in the name
    # (so frame_2 precedes frame_10). The pattern is a raw string: '\D' in a
    # plain literal is an invalid escape sequence (SyntaxWarning on 3.12+).
    json_files = [name for name in os.listdir(path + 'spatial_relations') if name.endswith('.json')]
    json_files.sort(key=lambda f: int(re.sub(r'\D', '', f)))

    with open(gt_path) as f:
        ground_truth = json.load(f)

    graph_list = []
    for file in json_files:
        frame_cnt = int(re.search(r'\d+', file).group())
        graph = json_to_graph(path + 'spatial_relations/' + file, get_target_action(frame_cnt, ground_truth))

        if graph != -1:
            graph_list.append(graph)
        else:
            print("file:", file, "frame_cnt:", frame_cnt, "gt:", ground_truth)

    return graph_list


def generate_graphs(_seperate=True):
    """
    Build graphs for every take under ``MAIN_PATH/<subject>/<task>/<take>/``.

    Args:
        _seperate: if True, return a dict
            ``{"take_<subject>_<task>_<take>": [graphs]}``; otherwise return one
            flat list of all graphs, in traversal order.

    Returns:
        dict or list as described above. Directories are visited in sorted
        order so the result is deterministic; glob's own order is arbitrary.
    """
    if _seperate:
        print("Creating graphs to dict.")
        all_data = {}
    else:
        print("Appending graphs to array.")
        all_data = []

    def _dir_number(dir_path):
        # First integer in the directory's own name ('task_5_k_cereals/' -> 5).
        # Parsing only the last component means digits elsewhere in MAIN_PATH
        # cannot be picked up by mistake.
        return int(re.search(r'\d+', os.path.basename(os.path.normpath(dir_path))).group())

    # Iterate subjects
    for subject_dir in sorted(glob.glob(MAIN_PATH + '/*/')):
        sub = _dir_number(subject_dir)
        # Iterate tasks
        for task_dir in sorted(glob.glob(subject_dir + '/*/')):
            task = _dir_number(task_dir)
            # Iterate takes
            for take_dir in sorted(glob.glob(task_dir + '/*/')):
                take = _dir_number(take_dir)
                name = "take_" + str(sub) + "_" + str(task) + "_" + str(take)

                if _seperate:
                    all_data[name] = take_to_graph_list(take_dir)
                else:
                    all_data += take_to_graph_list(take_dir)

    return all_data

# Root walked by generate_graphs(): <MAIN_PATH>/<subject>/<task>/<take>/.
MAIN_PATH = "./DREHER/bimacs_derived_data"
# 3D-derived counterpart of MAIN_PATH.
# NOTE(review): not referenced in the cells shown; get_object_list() builds
# this location itself from the frame path.
SECOND_PATH = "./DREHER/bimacs_derived_data_3D"

#test = generate_graphs()

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from torchvision import transforms

class DreherDataset(Dataset):
    """Dreher Dataset dataset.

    Holds every frame graph returned by ``generate_graphs(_seperate=False)`` in
    one flat list.

    Args:
        window: when > 1, each item is a list of ``window`` consecutive graphs;
            otherwise (0 or 1) each item is a single graph.
        root_dir: stored on the instance; not otherwise used by this class.
        transform: stored on the instance; not applied by ``__getitem__``.
    """

    def __init__(self, window=0, root_dir=None, transform=None):
        self.to_tensor = transforms.ToTensor()  # NOTE(review): unused by this class
        self.root_dir = root_dir
        self.samples = generate_graphs(_seperate=False)
        self.transform = transform
        self.window = window

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return sample ``idx`` (window <= 1) or a window of consecutive graphs.

        The window normally starts at ``idx``. It is shifted back by the number
        of graphs in ``samples[idx:idx + window]`` whose action
        (``graph['features']``) differs from that of ``samples[idx]``. It is then
        clamped so it always lies inside ``[0, len(samples))``. Previously an
        ``idx`` near the end raised IndexError, and a shift past 0 produced a
        negative slice start that wrapped around, returning a short or empty list.
        """
        if self.window <= 1:
            return self.samples[idx]

        n_samples = len(self.samples)
        current_action = self.samples[idx].graph['features']

        # Slicing stops at the end of the list instead of raising IndexError.
        lookahead = self.samples[idx:idx + self.window]
        step_back = sum(1 for g in lookahead if g.graph['features'] != current_action)

        start = max(0, min(idx - step_back, n_samples - self.window))
        return self.samples[start:start + self.window]

In [None]:
# Dataset creation switches.
#
# Example:
#     train_set = DreherDataset(window=n)
#
# If window is 0 or 1, each sample is a single graph.
# If window > 1, each sample is a list of `window` consecutive graphs.
#
# DreherDataset has no `temporal` argument; temporal edges between frames are
# added separately by createTemporalDict() / concatenateTemporal().
_SAVE_RAW_DATA = False
_CREATE_DATASET = False

_TIME_WINDOW = 8

# Dataset root. Raw files live in <_FOLDER>/raw/, which is where the
# InMemoryDataset-based DreherDS looks for them. `_FOLDER` used to be undefined,
# so enabling _SAVE_RAW_DATA raised NameError.
_FOLDER = "/data/tmp/dj_data/"

if _CREATE_DATASET:
    train_set = DreherDataset(window=_TIME_WINDOW)


# USED TO SAVE RAW TRAINING DATA
if _SAVE_RAW_DATA:
    if not _CREATE_DATASET:
        raise RuntimeError("_SAVE_RAW_DATA requires _CREATE_DATASET = True (train_set is not built otherwise).")
    # NOTE(review): DreherDS.raw_file_names expects 'dreher_*_disc_4w.pt';
    # rename/split this file accordingly before processing.
    raw_path = os.path.join(_FOLDER, "raw", "training_" + str(_TIME_WINDOW) + "w.pt")
    os.makedirs(os.path.dirname(raw_path), exist_ok=True)
    with open(raw_path, 'wb') as f:
        torch.save(train_set, f)

In [None]:
import torch
from torch_geometric.data import DataLoader
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_tar)

class DreherDS(InMemoryDataset):
    """
    PyTorch Geometric in-memory dataset over pre-built Dreher graph files.

    ``process`` loads each raw file (``raw_file_names``) as an iterable of
    networkx graphs, converts every graph to a ``Data`` object, and saves the
    collated result to the processed file at the same position in
    ``processed_file_names``.

    Args:
        root: dataset root. ``InMemoryDataset`` reads ``<root>/raw`` and
            writes ``<root>/processed``.
        dset: split to load: ``"train"``, ``"valid"`` or ``"test"``.
        transform: optional per-sample transform forwarded to ``InMemoryDataset``.

    Raises:
        ValueError: if ``dset`` is not one of the split names. Previously any
            unknown value silently loaded the test split.
    """

    # dset name -> position in processed_file_names / processed_paths.
    _SPLIT_INDEX = {"train": 0, "valid": 1, "test": 2}

    def __init__(self,
                root,
                dset="train",
                transform=None):
        # Validate first: super().__init__ may trigger a slow process() run.
        if dset not in self._SPLIT_INDEX:
            raise ValueError("dset must be one of {}, got {!r}".format(sorted(self._SPLIT_INDEX), dset))
        super(DreherDS, self).__init__(root, transform)

        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        self.data, self.slices = torch.load(self.processed_paths[self._SPLIT_INDEX[dset]])

    @property
    def raw_file_names(self):
        return ['dreher_train_disc_4w.pt', 
                'dreher_val_disc_4w.pt',
                'dreher_test_disc_4w.pt']

    @property
    def processed_file_names(self):
        # NOTE(review): position 1 (the "valid" split) is named 'dreher_test_disc_10w.pt',
        # and the window suffixes differ from raw_file_names. Names are kept
        # unchanged so existing processed caches are still found — confirm intent.
        return ['dreher_train_disc_10w.pt', 
                'dreher_test_disc_10w.pt',
                'dreher_test_disc_4w.pt']

    def download(self):
        # Raw files are produced locally; nothing to download.
        return

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

    def process(self):
        # raw_paths[i] is converted and written to processed_paths[i].
        for raw_path, path in zip(self.raw_paths, self.processed_paths):
            graphs = torch.load(raw_path)
            data_list = [self._graph_to_data(graph) for graph in graphs]
            torch.save(self.collate(data_list), path)

    @staticmethod
    def _graph_to_data(graph):
        """Convert one networkx graph into a ``Data`` with node/edge attributes, edge_index and y."""
        G = nx.convert_node_labels_to_integers(graph)
        G = G.to_directed() if not nx.is_directed(G) else G
        edge_index = torch.tensor(list(G.edges)).t().contiguous()

        # Collect each attribute into one list per key (append: linear time).
        node_feats = {}
        for _, feat_dict in G.nodes(data=True):
            for key, value in feat_dict.items():
                node_feats.setdefault(key, []).append(value)

        edge_feats = {}
        for _, _, feat_dict in G.edges(data=True):
            for key, value in feat_dict.items():
                edge_feats.setdefault(key, []).append(value)

        # As before, an edge attribute sharing a node attribute's name wins.
        data = {**node_feats, **edge_feats}

        for key, item in data.items():
            try:
                data[key] = torch.tensor(item)
            except ValueError:
                # Ragged / non-tensorisable attributes are kept as Python lists.
                pass

        data['edge_index'] = edge_index.view(2, -1)
        data = tg.data.Data.from_dict(data)
        data.num_nodes = len(graph)
        data.y = torch.tensor(graph.graph['features'])
        return data

In [None]:
# Load the three splits from the same root; each split is its own processed
# file (see DreherDS.processed_file_names).
train_dataset, val_dataset, test_dataset = (
    DreherDS('/data/tmp/dj_data/', split) for split in ("train", "valid", "test")
)

print("TRAIN DATASET:", len(train_dataset))
print("VALID DATASET:", len(val_dataset))
print("TEST DATASET:", len(test_dataset))

In [None]:
count = 0  # NOTE(review): count / max_count are not used in the cells shown
max_count = 1
# Fixed batch size: the model and loss code reshape with view(_bs, ...), so every
# batch must be full, hence drop_last=True on all loaders.
_bs = 32

train_loader = DataLoader(train_dataset, batch_size=_bs, shuffle=True, drop_last=True)
# Evaluation loaders are not shuffled. With drop_last and the Predictor's hidden
# state carried across batches, shuffling made validation/test metrics differ
# between runs (a different subset dropped, a different batch order).
test_loader = DataLoader(test_dataset, batch_size=_bs, shuffle=False, drop_last=True)
valid_loader = DataLoader(val_dataset, batch_size=_bs, shuffle=False, drop_last=True)

In [None]:
from torch_geometric.utils import to_dense_batch, to_dense_adj, to_scipy_sparse_matrix
from torch_scatter import scatter_add

def to_dense_adj_max_node(edge_index, x, edge_attr, batch=None, max_node=None):
    """
    Create a dense, padded adjacency tensor from a PyG batch, with normalised values.

    * ``adj[b, i, j]`` for an edge i -> j of graph b is
      ``(relation index + 1) / len(spatial_map)``. The relation index is the
      argmax of the edge's one-hot ``edge_attr`` row.
    * ``adj[b, i, i]`` for each real node is ``object index / len(_objects)``,
      where the object index is the argmax of the node's one-hot ``x`` row.
    * Padding rows and columns stay 0.

    NOTE: the notebook defines ``to_dense_adj_max_node`` twice. When cells run
    in order, the later (un-normalised) definition replaces this one.

    Args:
        edge_index: ``[2, E]`` batch-global node indices.
        x: ``[N, len(objects)]`` one-hot node features.
        edge_attr: ``[E, len(spatial_map)]`` one-hot edge features.
        batch: ``[N]`` graph id per node (nodes of a graph contiguous, as in PyG
            batches). ``None`` means a single graph.
        max_node: padded matrix size; defaults to the largest graph in the batch.

    Returns:
        float tensor ``[batch_size, max_num_nodes, max_num_nodes]`` on
        ``edge_index.device``.
    """
    if batch is None:
        batch = edge_index.new_zeros(edge_index.max().item() + 1)
    device = edge_index.device
    dtype = torch.float

    batch_size = batch[-1].item() + 1

    # Nodes per graph and the batch-global index of each graph's first node.
    # These are computed before max_num_nodes: the old code read num_nodes before
    # assigning it, so max_node=None raised UnboundLocalError.
    num_nodes = torch.bincount(batch, minlength=batch_size)
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
    max_num_nodes = num_nodes.max().item() if max_node is None else max_node

    adj = torch.zeros([batch_size, max_num_nodes, max_num_nodes], dtype=dtype, device=device)

    # Edge endpoints as (graph id, local source, local target).
    edge_graph = batch[edge_index[0]]
    edge_src = edge_index[0] - cum_nodes[edge_graph]
    edge_dst = edge_index[1] - cum_nodes[batch[edge_index[1]]]

    # Relation id normalised to (0, 1]. The old code used ea[0] (the first one-hot
    # component, 1 only for 'above'), so all other relations got the same value.
    edge_values = (edge_attr.argmax(dim=1).to(dtype) + 1) / len(spatial_map)
    adj[edge_graph, edge_src, edge_dst] = edge_values.to(device)

    # Diagonal: object class of each real node, normalised by the class count.
    node_values = x.argmax(dim=1).to(dtype) / len(_objects)
    node_local = torch.arange(batch.size(0), device=batch.device) - cum_nodes[batch]
    adj[batch, node_local, node_local] = node_values.to(device)

    # Sanity check: report graphs whose diagonal is entirely zero.
    diag_sums = adj.diagonal(dim1=1, dim2=2).sum(dim=1)
    for i in torch.nonzero(diag_sums == 0).flatten().tolist():
        print("ERROR! ZERO DIAGONAL!", i)

    return adj

In [None]:
from torch_geometric.utils import to_dense_batch, to_dense_adj, to_scipy_sparse_matrix
from torch_scatter import scatter_add

def to_dense_adj_max_node(edge_index, x, edge_attr, batch=None, max_node=None):
    """
    Create a dense, padded, binary adjacency tensor from a PyG batch (no normalisation).

    * ``adj[b, i, j] = 1`` for every edge i -> j of graph b.
    * ``adj[b, i, i] = 1`` for every real node i of graph b, so the number of
      non-zero diagonal entries equals the graph's node count. The previous
      version derived the diagonal from ``x.argmax`` and left class 0
      ('LeftHand') at 0, which under-counted nodes.
    * Padding rows and columns stay 0.

    NOTE: the notebook defines ``to_dense_adj_max_node`` twice; this is the
    later definition, which wins when cells are run in order.

    Args:
        edge_index: ``[2, E]`` batch-global node indices.
        x: node features. Accepted for interface compatibility; their values
            are not used.
        edge_attr: edge features. Accepted for interface compatibility; their
            values are not used.
        batch: ``[N]`` graph id per node (nodes of a graph contiguous, as in PyG
            batches). ``None`` means a single graph.
        max_node: padded matrix size; defaults to the largest graph in the batch.

    Returns:
        float tensor ``[batch_size, max_num_nodes, max_num_nodes]`` on
        ``edge_index.device``. (The old code forced ``.cuda(device)``, which
        failed on CPU-only machines.)
    """
    if batch is None:
        batch = edge_index.new_zeros(edge_index.max().item() + 1)
    device = edge_index.device
    dtype = torch.float

    batch_size = batch[-1].item() + 1

    # Nodes per graph and the batch-global index of each graph's first node.
    # These are computed before max_num_nodes: the old code read num_nodes before
    # assigning it, so max_node=None raised UnboundLocalError.
    num_nodes = torch.bincount(batch, minlength=batch_size)
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
    max_num_nodes = num_nodes.max().item() if max_node is None else max_node

    adj = torch.zeros([batch_size, max_num_nodes, max_num_nodes], dtype=dtype, device=device)

    # Edge endpoints as (graph id, local source, local target).
    edge_graph = batch[edge_index[0]]
    edge_src = edge_index[0] - cum_nodes[edge_graph]
    edge_dst = edge_index[1] - cum_nodes[batch[edge_index[1]]]
    adj[edge_graph, edge_src, edge_dst] = 1.0

    # Mark every real node on the diagonal.
    node_local = torch.arange(batch.size(0), device=batch.device) - cum_nodes[batch]
    adj[batch, node_local, node_local] = 1.0

    # Sanity check: report graphs whose diagonal is entirely zero.
    diag_sums = adj.diagonal(dim1=1, dim2=2).sum(dim=1)
    for i in torch.nonzero(diag_sums == 0).flatten().tolist():
        print("ERROR! ZERO DIAGONAL!", i)

    return adj

In [None]:
# Largest graph (node count) in each split. Dense targets and the decoder
# output are padded to max_num_nodes x max_num_nodes.
max_num_nodes_train = max(len(i.x) for i in train_dataset)
print("max in train:", max_num_nodes_train)

max_num_nodes_valid = max(len(i.x) for i in val_dataset)
print("max in valid:", max_num_nodes_valid)

max_num_nodes_test = max(len(i.x) for i in test_dataset)
print("max in test:", max_num_nodes_test)

# Only train/valid set the model size (Decoder.fc2 is max_num_nodes**2 wide),
# so saved checkpoints stay loadable. Test graphs larger than this cannot be
# padded into the dense target and will fail in evaluation, so warn early.
max_num_nodes = max(max_num_nodes_train, max_num_nodes_valid)
if max_num_nodes_test > max_num_nodes:
    print("WARNING: test set contains graphs with", max_num_nodes_test, "nodes, more than MAX =", max_num_nodes)
print("MAX:", max_num_nodes)

def repackage_hidden(h):
    """
    Detach a hidden state from its autograd history.

    Args:
        h: a tensor, or an arbitrarily nested iterable of tensors
            (e.g. an LSTM ``(h_n, c_n)`` pair).

    Returns:
        The detached tensor, or tuples with the same nesting.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()

In [None]:
from torch.nn import Sequential, Linear, ReLU, ELU
from torch_geometric.nn import NNConv, radius_graph, fps, global_mean_pool
from torch_scatter import scatter_mean
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.utils import softmax
from torch_geometric.nn import BatchNorm

# NOTE(review): input_size / output_size appear unused in the cells shown
# (djNet builds Predictor(8, 8, 8, 2) with literal sizes).
input_size=len(action_map) # static value
output_size=len(action_map) # static value
# Feature width of the node embedding and graph convolutions in Encoder.
channels = 64

# NOTE(review): _PRINT is not read anywhere in the cells shown.
_PRINT = False # Used for debugging

class Encoder(torch.nn.Module):
    """
    Graph encoder producing a VAE latent code and a pooled embedding for action prediction.

    ``forward(data)`` returns ``(z, logvar, mu, p_x)``. Each is averaged per
    graph with ``scatter_mean`` over ``data.batch``:
      * ``z``: ``mu + std * eps`` in training mode, ``mu`` in eval mode (64-d);
      * ``logvar``, ``mu``: 64-d Gaussian parameters, computed from the
        ``conv1``/``bn1`` features;
      * ``p_x``: ``channels``-d embedding from ``conv2``, used by the action predictor.
    """
    def __init__(self):
        super(Encoder, self).__init__()
        # Embed the one-hot object class (len(objects) wide) into `channels` features.
        self.lin = torch.nn.Linear(len(objects), channels)

        # Edge network: one-hot relation (len(spatial_map) wide) -> flattened
        # NNConv weight with channels * channels*2 entries.
        # NOTE(review): this same `nn` instance is passed to both conv1 and conv2,
        # so their edge networks share weights — confirm that is intended.
        # (The local name `nn` also shadows the torch.nn module inside __init__.)
        nn = Sequential(Linear(len(spatial_map), 512), ReLU(), Linear(512, channels * channels *2))

        self.conv1 = NNConv(channels, channels*2, nn, aggr='mean') 
        self.bn1 = BatchNorm(channels*2)
        self.conv2 = NNConv(channels*2, channels, nn, aggr='mean')
        
        # Separate edge network for the Gaussian heads (channels*2 -> 64).
        # NOTE(review): mu and logvar also share this edge network instance.
        nn2 = Sequential(Linear(len(spatial_map), 512), ReLU(), Linear(512, 2 * channels * 64  ))
        self.mu = NNConv(channels*2, 64, nn2, aggr='mean')
        self.logvar = NNConv(channels*2, 64, nn2, aggr='mean')
    
    def forward(self, data):
        # Only `batch` from this unpacking is used below; the rest read `data.*` directly.
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        out = F.relu(self.lin(data.x)) # Fully connected
        
        # Action prediction
        hidden = F.relu(self.conv1(out, data.edge_index, data.edge_attr)) # conv1
        hidden = self.bn1(hidden)
        conv2_out = F.relu(self.conv2(hidden, data.edge_index, data.edge_attr)) # conv2
        
        # Reconstruction (Gaussian parameters from the conv1/bn1 features, not conv2)
        mu = self.mu(hidden, data.edge_index, data.edge_attr)
        logvar = self.logvar(hidden, data.edge_index, data.edge_attr)
        
        # Pool node-level outputs to one vector per graph.
        mu = scatter_mean(mu, batch, dim=0)
        logvar = scatter_mean(logvar, batch, dim=0)
        
        p_x = scatter_mean(conv2_out, batch, dim=0)
        
        # Reparameterisation trick. eps is drawn even in eval mode, which still
        # advances the RNG state.
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        
        if self.training:
            return std * eps + mu, logvar, mu, p_x
        else:
            return mu, logvar, mu, p_x

class Predictor(torch.nn.Module):
    """
    LSTM head that predicts an action class from pooled graph embeddings.

    The LSTM hidden state is kept between successive ``forward`` calls
    (detached, so gradients do not flow across batches).

    Args:
        input_size: feature size of each LSTM time step.
        seq_len: number of time steps each sample is reshaped into.
        hidden_size: LSTM hidden size.
        n_layers: number of stacked LSTM layers.
    """
    def __init__(self, input_size, seq_len, hidden_size, n_layers):
        super(Predictor, self).__init__()
        
        self.prev_hidden = None
        # Fixed batch size: forward() reshapes its input to exactly this many samples.
        self.bs = _bs
        self.input_size = input_size
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        
        self.lstm = torch.nn.LSTM(self.input_size, self.hidden_size, self.n_layers, dropout=0.1, batch_first=True)
        self.lin1 = torch.nn.Linear(self.hidden_size, len(action_map))
    
    def forward(self, p_x):
        """
        Args:
            p_x: tensor with ``bs * seq_len * k`` elements, reshaped to
                ``(bs, seq_len, k)``.

        Returns:
            Action logits of shape ``(bs, len(action_map))`` from the last time step.
        """
        if self.prev_hidden is None:
            # Allocate on p_x's device. The old unconditional .cuda(device) relied
            # on a global `device` and failed on CPU-only machines.
            self.prev_hidden = (
                torch.zeros(self.n_layers, self.bs, self.hidden_size, device=p_x.device),
                torch.zeros(self.n_layers, self.bs, self.hidden_size, device=p_x.device),
            )

        input_reshape = p_x.reshape(self.bs, self.seq_len, -1)

        q, h = self.lstm(input_reshape, self.prev_hidden)
        
        # Keep the state for the next call, but cut its autograd history.
        self.prev_hidden = tuple(state.detach() for state in h)
        
        out = self.lin1(q[:, -1, :])
        
        return out

class Decoder(torch.nn.Module):
    """
    MLP mapping a 64-d latent vector to a flattened adjacency estimate.

    Output: ``max_num_nodes * max_num_nodes`` values per sample, squashed into
    (0, 1) with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        hidden_width = 256
        self.fc1 = nn.Linear(64, hidden_width)
        self.fc2 = nn.Linear(hidden_width, max_num_nodes * max_num_nodes)

    def forward(self, z_x):
        hidden = F.relu(self.fc1(z_x))
        return torch.sigmoid(self.fc2(hidden))


class djNet(torch.nn.Module):
    """
    Joint model: graph encoder, adjacency decoder and LSTM action predictor.

    ``forward(data)`` returns
    ``(reconstruction, logvar, mu, latent, action_logits)``.
    """

    def __init__(self):
        super().__init__()
        # Registration order (encoder, predictor, decoder) is kept as-is: it fixes
        # parameter initialisation order and the state_dict layout.
        self.encoder = Encoder()
        self.predictor = Predictor(8, 8, 8, 2)  # input_size, seq_len, hidden_size, n_layers
        self.decoder = Decoder()

    def forward(self, x):
        latent, logvar, mu, pooled = self.encoder(x)
        reconstruction = self.decoder(latent)
        action_logits = self.predictor(pooled)
        return reconstruction, logvar, mu, latent, action_logits


In [None]:
# Use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = djNet().to(device)

print(model)

In [None]:
ap_criterion = nn.CrossEntropyLoss()

def loss_criterion(inputs, targets, logvar, mu, ap_inputs, ap_targets):
    """
    Combined VAE + action-prediction losses.

    Returns:
        ``(reconstruction + KL, action_loss)`` where
          * reconstruction = summed binary cross-entropy of ``inputs`` vs ``targets``;
          * KL = ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``;
          * action_loss = mean cross-entropy of ``ap_inputs`` logits vs
            ``ap_targets`` class indices.
    """
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    reconstruction = F.binary_cross_entropy(inputs, targets, reduction="sum")
    action_loss = ap_criterion(ap_inputs, ap_targets)
    return reconstruction + kl_divergence, action_loss

In [None]:
from torch.utils.tensorboard import SummaryWriter

#writer = SummaryWriter(comment="MANIAC_4w_dropout02")

In [None]:
# Adam over all model parameters, learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def train(loader, model):
    """
    Run one training epoch over ``loader``.

    Each batch is trained on ``0.6 * (reconstruction + KL) + action_loss``.

    Returns:
        ``(recon_loss, ap_loss)``. ``recon_loss`` is the summed reconstruction +
        KL loss divided by the number of samples seen. ``ap_loss`` is the
        batch-mean cross-entropy averaged over batches. Previously it was also
        divided by ``_bs``, under-reporting it by a factor of the batch size.
        Both are NaN for an empty loader.
    """
    model.train()

    recon_loss_all = 0.0
    ap_loss_all = 0.0
    n_batches = 0

    for data in loader:
        optimizer.zero_grad()
        data = data.to(device)

        y_hat, logvar, mu, _, y_ap = model(data)
        prediction = y_hat.view(_bs, -1, max_num_nodes)

        # Creating targets
        target_adj = to_dense_adj_max_node(data.edge_index, data.x, data.edge_attr, data.batch, max_num_nodes)
        y_ap_true = data.y.view(_bs, -1).argmax(axis=1)

        # Compute loss; 0.6 weights the reconstruction term against the action loss.
        recon_loss, ap_loss = loss_criterion(prediction, target_adj, logvar, mu, y_ap, y_ap_true)
        net_loss = recon_loss * 0.6 + ap_loss

        # Compute gradients and update weights.
        net_loss.backward()
        optimizer.step()

        recon_loss_all += recon_loss.item()
        ap_loss_all += ap_loss.item()
        n_batches += 1

    if n_batches == 0:
        return float('nan'), float('nan')
    # recon_loss uses reduction="sum", so normalise per sample; CrossEntropyLoss
    # already averages over the batch, so only divide by the number of batches.
    return recon_loss_all / (n_batches * _bs), ap_loss_all / n_batches

def test(loader, model):
    """
    Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(recon_loss, ap_loss, accuracy)``. ``recon_loss`` is the summed
        reconstruction + KL loss per sample. ``ap_loss`` is the batch-mean
        cross-entropy averaged over batches (previously also divided by ``_bs``).
        ``accuracy`` is top-1 action accuracy over all samples. All are NaN for
        an empty loader.
    """
    model.eval()

    recon_loss_all = 0.0
    ap_loss_all = 0.0
    correct = 0
    n_batches = 0

    # no_grad: evaluation does not need an autograd graph (saves memory and time).
    with torch.no_grad():
        for data in loader:
            data = data.to(device)

            y_hat, logvar, mu, _, y_ap = model(data)
            prediction = y_hat.view(_bs, -1, max_num_nodes)

            # Creating targets
            target_adj = to_dense_adj_max_node(data.edge_index, data.x, data.edge_attr, data.batch, max_num_nodes)
            y_ap_true = data.y.view(_bs, -1).argmax(axis=1)

            pred = y_ap.max(1)[1]

            # Compute loss
            recon_loss, ap_loss = loss_criterion(prediction, target_adj, logvar, mu, y_ap, y_ap_true)

            recon_loss_all += recon_loss.item()
            ap_loss_all += ap_loss.item()
            correct += pred.eq(y_ap_true).sum().item()
            n_batches += 1

    if n_batches == 0:
        return float('nan'), float('nan'), float('nan')
    n_samples = n_batches * _bs
    return recon_loss_all / n_samples, ap_loss_all / n_batches, correct / n_samples

        
def super_test(loader, model):
    """
    Run the model over ``loader`` and collect everything needed for evaluation.

    Returns:
        ``(reconstruct_predict_list, reconstruct_gt_list, ap_predict_list,
        ap_gt_list, top3_list, num_nodes_list)``:
          * reconstruct_*: per-batch numpy arrays of predicted / target dense
            matrices, shape ``(_bs, max_num_nodes, max_num_nodes)``;
          * ap_predict_list / ap_gt_list: 1-D float64 arrays of predicted / true
            action indices across all batches;
          * top3_list: per sample, the three highest-scoring action names;
          * num_nodes_list: per-batch arrays of node counts per graph.
    """
    model.eval()

    ap_predictions = []
    ap_targets = []
    reconstruct_predict_list = []
    reconstruct_gt_list = []
    top3_list = []
    num_nodes_list = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            batch_size = data.batch[-1].item() + 1

            y_hat, logvar, mu, _, y_ap = model(data)
            pred = y_ap.max(1)[1]

            # Creating targets. The result is already on data's device; the old
            # extra .cuda() call broke CPU-only runs.
            target_adj = to_dense_adj_max_node(data.edge_index, data.x, data.edge_attr, data.batch, max_num_nodes)
            y_ap_true = data.y.view(_bs, -1).argmax(axis=1)

            prediction = y_hat.view(_bs, -1, max_num_nodes)
            ap_predictions.append(pred.cpu().numpy())
            ap_targets.append(y_ap_true.cpu().numpy())

            # Node count of every graph in the batch.
            one = data.batch.new_ones(data.batch.size(0))
            num_nodes = scatter_add(one, data.batch, dim=0, dim_size=batch_size)
            num_nodes_list.append(num_nodes.cpu().numpy())

            # Top-3 action names per sample, taken from topk's indices. The old
            # code looked values back up with list.index(), which returned the same
            # class repeatedly when scores tied.
            top3_idx = torch.topk(y_ap, 3).indices.cpu().tolist()
            top3_list.extend([action_map[k] for k in row] for row in top3_idx)

            reconstruct_predict_list.append(prediction.cpu().numpy())
            reconstruct_gt_list.append(target_adj.cpu().numpy())

    # Concatenate once (the old np.append in the loop was quadratic); keep float64
    # as before so downstream .astype(int) calls behave the same.
    ap_predict_list = np.concatenate(ap_predictions).astype(np.float64) if ap_predictions else np.array([])
    ap_gt_list = np.concatenate(ap_targets).astype(np.float64) if ap_targets else np.array([])

    return reconstruct_predict_list, reconstruct_gt_list, ap_predict_list, ap_gt_list, top3_list, num_nodes_list

In [None]:
# NOTE(review): save_counter is not used in the cells shown.
save_counter = 0
# Placeholder for test accuracy: never updated below, so the "TAcc" column
# always prints 1.00.
test_ap_acc = 1

# Train for 500 epochs, saving a checkpoint every epoch.
for epoch in range(1, 501):
    
    train_recon_loss, train_ap_loss = train(train_loader, model)
    # Second pass over the training set (eval mode) to measure training accuracy.
    _, _, train_ap_acc = test(train_loader, model)
    validation_recon_loss, validation_ap_loss, validation_ap_acc = test(valid_loader, model)

    # One checkpoint file per epoch: dreher_<epoch>.pt.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_ap_loss,
        }, "/data/tmp/dj_data/runs/dreher_" + str(epoch) + ".pt")
    
    # TensorBoard logging, disabled (the SummaryWriter creation is commented out).
    if False:
        writer.add_scalar('AP_Acc/train', train_ap_acc, epoch)
        writer.add_scalar('AP_Acc/validation', validation_ap_acc, epoch)

        writer.add_scalar('Recon_Loss/train', train_recon_loss, epoch)
        writer.add_scalar('Recon_Loss/validation', validation_recon_loss, epoch)

        writer.add_scalar('AP_Loss/train', train_ap_loss, epoch)
        writer.add_scalar('AP_Loss/validation', validation_ap_loss, epoch)

    
    print("Epoch {:02d}, [T] RLoss: {:.2f}, APLoss: {:.4f}, Acc: {:.2f}% [V] RLoss: {:.2f}, APLoss: {:.4f}, Acc: {:.2f}%, TAcc: {:.2f}".format( epoch, 
                                                                           train_recon_loss,
                                                                           train_ap_loss,
                                                                           train_ap_acc*100,
                                                                           validation_recon_loss,
                                                                           validation_ap_loss,
                                                                           validation_ap_acc*100,
                                                                           test_ap_acc))


In [None]:
# _SAVE = True writes the current model/optimizer to the final-model path;
# otherwise a per-epoch checkpoint written by the training loop is restored.
_SAVE = False
if _SAVE:
    # NOTE(review): 'epoch' and 'loss' are hard-coded metadata, not values from the run above.
    torch.save({
                'epoch': 99,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': 4.110431,
                }, "./MANIAC_final_models/MANIAC_final_4w_dim_64_increase_mu_64.pt")
    print("SAVED!")
else:
    # map_location=device lets a checkpoint written on a GPU be restored on a
    # CPU-only machine (and vice versa); without it torch.load fails there.
    checkpoint = torch.load("/data/tmp/dj_data/runs/dreher_10.pt", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print("LOADED MODEL")

In [None]:
# Collect reconstructions, action predictions, top-3 names and node counts on the test split.
cr_pred, cr_gt, ap_pred, ap_target, top3, node_list = super_test(loader=test_loader, model=model)

In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

# Pin the label set to every action class (0..len(action_map)-1). The matrix is
# then always square with one row/column per entry of action_map, even if a
# class is absent from this split, which the heatmaps below rely on. The report
# previously listed 15 labels for 14 target names, so one label had no name and
# it dragged down the macro average.
cm = confusion_matrix(ap_target.astype(int), ap_pred.astype(int), labels=list(range(len(action_map))))

print(classification_report(ap_target.astype(int), ap_pred.astype(int), target_names=action_map, labels=list(range(len(action_map)))))

In [None]:
# Top-3 view: a sample counts as correct when its true action is among the
# model's three highest-scoring actions; otherwise its top-1 guess is used.
new_pred = []
for i in range(len(top3)):
    true_idx = int(ap_target[i])
    if action_map[true_idx] in top3[i]:
        new_pred.append(true_idx)
    else:
        new_pred.append(action_map.index(top3[i][0]))

print("top 3")

# One label per action class (previously 15 labels for 14 target names).
print(classification_report(ap_target.astype(int), new_pred, target_names=action_map, labels=list(range(len(action_map)))))

In [None]:
from sklearn.metrics import roc_auc_score, average_precision_score

_PRED_TRESHOLD = 0.3  # predicted probability above which an adjacency entry counts as present

def calc_auc_roc(gt, pred, node_list):
    """
    Score reconstructed adjacency matrices against their targets and print a summary.

    For every graph:
      * the predicted node count is the number of non-zero diagonal entries of
        the rounded raw prediction, compared with the true count in ``node_list``;
      * the prediction is binarised at ``_PRED_TRESHOLD``, and ROC AUC and
        average precision are computed against ``ceil(gt)``.

    ``pred`` is no longer modified. Previously the whole batch was overwritten
    with 0/1 during its first graph, so later graphs' node counts came from
    thresholded rather than raw values and the caller's arrays changed. The
    "correct node length" ratio is now over graphs rather than batches.

    Args:
        gt: per-batch arrays of target matrices, indexed ``gt[batch][graph]``.
        pred: predicted matrices in the same layout.
        node_list: per-batch arrays of true node counts, ``node_list[batch][graph]``.

    Returns:
        dict with ``node_diff`` (histogram of predicted minus true node count),
        ``correct_node_length``, ``graph_count``, ``roc_auc`` and
        ``average_precision`` (averages over graphs).

    Raises:
        Exception: if ``gt`` and ``pred`` have different lengths.
    """
    if len(gt) != len(pred):
        raise Exception("Invalid length of ground truth and prediction! GT {:d} and pred {:d}".format(len(gt), len(pred))) 

    graph_count = 0
    correct_node_count = 0
    roc_auc_total = 0.0
    avg_precision_total = 0.0
    node_diff_counts = {0: 0}

    for batch_idx in range(len(gt)):
        for i in range(len(gt[batch_idx])):
            gt_mat = gt[batch_idx][i]
            pred_mat = pred[batch_idx][i]
            true_nodes = int(node_list[batch_idx][i])
            graph_count += 1

            # Predicted node count from the rounded (threshold 0.5) raw diagonal.
            pred_nodes = int(np.count_nonzero(np.round(np.diag(pred_mat))))
            diff = pred_nodes - true_nodes
            node_diff_counts[diff] = node_diff_counts.get(diff, 0) + 1
            if diff == 0:
                correct_node_count += 1

            # Binarise copies; the caller's arrays are left untouched.
            gt_flat = np.ceil(gt_mat.flatten())
            pred_flat = np.where(pred_mat.flatten() > _PRED_TRESHOLD, 1.0, 0.0)
            roc_auc_total += roc_auc_score(gt_flat, pred_flat)
            avg_precision_total += average_precision_score(gt_flat, pred_flat)

    denom = max(graph_count, 1)  # avoid ZeroDivisionError on empty input
    results = {
        'node_diff': node_diff_counts,
        'correct_node_length': correct_node_count,
        'graph_count': graph_count,
        'roc_auc': roc_auc_total / denom,
        'average_precision': avg_precision_total / denom,
    }

    print(node_diff_counts)
    print("Correct node length:", correct_node_count, "out of", graph_count, "graphs.", correct_node_count / denom)
    print("Average ROC AUC SCORE:", results['roc_auc'])
    print("Average precision score:", results['average_precision'])
    return results

# Trailing ';' keeps Jupyter from echoing the return value; the summary is printed.
calc_auc_roc(gt=cr_gt, pred=cr_pred, node_list=node_list);

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

# `plt` was used below but never imported anywhere in the notebook (NameError).
# matplotlib is already required, since seaborn draws with it.
import matplotlib.pyplot as plt

# Normalise each row (true class) to sum to 1. Rows with no samples become NaN
# without the divide-by-zero warning.
with np.errstate(divide='ignore', invalid='ignore'):
    cmn = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(cmn, annot=True, fmt='.2f', xticklabels=action_map, yticklabels=action_map, cmap="YlGnBu", ax=ax)
ax.set_ylabel('Actual')
ax.set_xlabel('Predicted')
ax.set_title('Normalized')
plt.show(block=False)

fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=action_map, yticklabels=action_map, cmap="YlGnBu", ax=ax)
ax.set_ylabel('Actual')
ax.set_xlabel('Predicted')
ax.set_title('Number of predictions')
plt.show(block=False)