In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn

import networkx as nx
import matplotlib.pyplot as plt
import torch_geometric as tg
import numpy as np
from pathlib import Path
from pathlib import PurePath
import os
import scipy.optimize
from torch.autograd import Variable

from xml.dom import minidom

from numpy import argmax
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder

from torch.utils.tensorboard import SummaryWriter

In [None]:
def one_hot_string(map):
    """Create one-hot encodings for a list of string labels.

    Columns are assigned in sorted label order (the same order scikit-learn's
    LabelEncoder uses), so with an alphabetically sorted input list, row ``i``
    is the encoding of ``map[i]``.

    Implemented with numpy only: ``OneHotEncoder(sparse=False)`` raises a
    TypeError on scikit-learn >= 1.4, where ``sparse`` was removed.

    Input: array-like of strings (flattened to 1-D).
    Returns: float64 ndarray of shape (len(map), number of distinct labels).
    """
    values = np.asarray(map).reshape(-1)
    classes, integer_encoded = np.unique(values, return_inverse=True)
    # Selecting rows of the identity matrix by class index yields the one-hot rows.
    return np.eye(len(classes))[integer_encoded.reshape(-1)]

def get_action_from_encoding(enc):
    """Map a one-hot action encoding back to its action name.

    Returns (action name, index into action_map). Raises ValueError when
    ``enc`` is not one of the encodings in ``_actions``.
    """
    position = _actions.index(enc)
    name = action_map[position]
    return name, position

def createTemporalDict(in_dict):
    """Build ``{action: {seq: [temporal graph]}}`` from ``{action: {seq: [graphs]}}``.

    Each sequence's graph list is merged into a single graph by
    concatenateTemporal and wrapped in a one-element list.
    """
    temporal = {}
    for action in in_dict:
        sequences = in_dict[action]
        temporal[action] = {seq: [concatenateTemporal(sequences[seq])] for seq in sequences}
    return temporal

def nodesMatch(n1, n_list):
    """Find the first entry of ``n_list`` whose attribute dict equals ``n1``.

    ``n_list`` holds ``(node, attrs)`` pairs, e.g. ``graph.nodes(data=True)``.
    Returns ``(True, position)`` for the first match, otherwise ``(False, 0)``.
    """
    position = next((i for i, node in enumerate(n_list) if node[1] == n1), None)
    if position is None:
        return False, 0
    return True, position

def calculateTemporalEdges(full_gr, nodes, index_list, temporal_rel=None):
    """Add temporal edges from the previous graph's nodes to the next graph's nodes.

    The next graph's nodes are about to be appended to ``full_gr`` in the
    iteration order of ``nodes``, starting at id ``index_list[-1] + 1``. For
    every previous node (ids in ``index_list``), find the first next-graph node
    with an identical attribute dict. If there is one, add the edge
    ``previous id -> index_list[-1] + 1 + position`` with ``edge_attr=temporal_rel``.

    Args:
        full_gr: graph being built (needs ``.nodes[id]`` and ``.add_edge``).
        nodes: ``(node, attrs)`` pairs of the next graph, e.g. ``graph.nodes(data=True)``.
        index_list: ids in ``full_gr`` of the previous graph's nodes, in order.
        temporal_rel: edge feature for temporal edges; defaults to the
            one-hot encoding of "temporal".
    """
    if temporal_rel is None:
        temporal_rel = _relations[spatial_map.index("temporal")]
    if len(index_list) == 0:
        return

    # Ids of the next graph start right after the last previous node.
    first_new_id = index_list[-1] + 1
    new_attrs = [attrs for _, attrs in nodes]

    for index in index_list:
        old_attrs = full_gr.nodes[index]
        for position, attrs in enumerate(new_attrs):
            if attrs == old_attrs:
                # Link to the node that actually matched.
                full_gr.add_edge(index, first_new_id + position, edge_attr=temporal_rel)
                break

def concatenateTemporal(graphs):
    """Merge a list of keyframe graphs into one DiGraph with temporal edges.

    The nodes of all graphs are renumbered consecutively (graph 0 first) and
    keep their ``x`` feature. Intra-graph edges are copied with their
    ``edge_attr``. Before each graph after the first is added,
    calculateTemporalEdges links the previous graph's nodes to it. Nodes that
    end up with no attributes are removed. The merged graph takes its
    ``features`` from the first graph.

    Assumes each input graph's node labels are 0..n-1 in iteration order, since
    edges are remapped by indexing the per-graph id list with the node label.
    """
    merged = nx.DiGraph()
    merged.graph["features"] = graphs[0].graph['features']
    next_id = 0
    previous_ids = []

    for graph in graphs:
        if previous_ids:
            calculateTemporalEdges(merged, graph.nodes(data=True), previous_ids)

        current_ids = []
        for _, attrs in graph.nodes(data=True):
            merged.add_node(next_id, x=attrs['x'])
            current_ids.append(next_id)
            next_id += 1

        for src, dst in graph.edges():
            relation = graph.get_edge_data(src, dst)['edge_attr']
            merged.add_edge(current_ids[src], current_ids[dst], edge_attr=relation)

        previous_ids = current_ids

    # Edges may reference ids that never received attributes; drop those nodes.
    dangling = [node for node, attrs in merged.nodes(data=True) if attrs == {}]
    merged.remove_nodes_from(dangling)

    return merged

In [None]:
# This needs to be hard-coded and in alphabetic order: one_hot_string() assigns
# one-hot columns in sorted label order, so list index == one-hot column only
# when each list is already sorted.
action_map  = ["chopping", "cutting", "hiding", "pushing", "putontop", "stirring", "takedown", "uncover"]
spatial_map = ["noconnection", "temporal", "touching"]
objects     = ['Apple', 'Arm', 'Ball', 'Banana', 'Body', 'Bowl', 'Box', 'Bread', 
               'Carrot', 'Chopper', 'Cucumber', 'Cup', 'Hand', 
               'Knife', 'Liquid', 'Pepper', 'Plate', 'Sausage', 'Slice', 'Spoon', 'null']

# Guard the ordering invariant: an unsorted edit would silently misalign labels.
assert all(vocab == sorted(vocab) for vocab in (action_map, spatial_map, objects)), \
    "action_map, spatial_map and objects must be in sorted order"

# Creates one hot encodings for actions, relations and objects.
_relations = one_hot_string(spatial_map).tolist()
_actions = one_hot_string(action_map).tolist()
_objects = one_hot_string(objects).tolist()

# Folder with MANIAC keyframes
_FOLDER = os.getcwd() + "/maniac/"

# If you have MANIAC data in a different folder, set the MANIAC_DATA_FOLDER
# environment variable (keep the trailing slash; split names are appended).
_MANIAC_DATA_FOLDER = os.environ.get("MANIAC_DATA_FOLDER", "/data/tmp/master_thesis/MANIAC/")

In [None]:
# Sanity check: show a few object encodings and the 'noconnection' relation.
for label, name in (("BOWL:", 'Bowl'), ("Box:", 'Box'), ("Hand:", 'Hand')):
    print(label, _objects[objects.index(name)])

print("no connect:", _relations[spatial_map.index("noconnection")])

In [None]:
def find_xmls(path=_FOLDER):
    """Return every ``*.xml`` file below ``path`` (searched recursively).

    Input: folder as a string or Path; defaults to ``_FOLDER``.
    Returns: list of pathlib.Path objects.
    """
    return list(Path(path).rglob('*.xml'))

def parse_xml_info(xml):
    """Split a MANIAC GraphML path into (action, sequence number).

    The parent folder is expected to be named ``<Action parts>_<seq>``. The
    action parts are joined and lower-cased, e.g.
    ``.../Put_On_Top_12/GraphML.xml`` -> ``("putontop", 12)``.

    Input: pathlib.Path, str or os.PathLike pointing to an ``.xml`` file.
    Raises ValueError if the input is not a path, does not end in ``.xml``,
    has no parent folder, or if the sequence suffix is not an integer.
    """
    if not isinstance(xml, (str, os.PathLike)):
        raise ValueError("Must be XML! recieved {}".format(xml))

    xml = Path(xml)
    if xml.suffix != '.xml':
        raise ValueError("Must be XML! recieved {}".format(str(PurePath(xml))))

    path_split = xml.parts
    if len(path_split) < 2:
        raise ValueError("Expected '<action>_<seq>/<file>.xml', received {}".format(xml))

    split = path_split[-2].split('_')
    action = "".join(split[:-1]).lower()
    seq = split[-1]

    return action, int(seq)
# USAGE
#all_xmls = find_xmls()
#parse_xml_info(all_xmls[0])

In [None]:
import warnings
import xml.etree.ElementTree as ET
from glob import glob

'''
Boolean flag to skip edge connections.

If set to false:
    All spatial relations will be used.

If set to true:
    Will _ONLY_ use the spatial relation 'touching'
'''
_SKIP_CONNECTIONS = False


class SECParser():
    """Create networkx graphs from MANIAC's GraphML.xml files (SEC parser).

    Each call parses one file and appends its keyframe graphs to
    ``self.graph_dict``, laid out as ``{action_id: {seq: [DiGraph, ...]}}``.
    The call returns that accumulated dict. State persists across calls, so one
    instance can be fed every file of a split.
    """

    def __init__(self):
        # xml.etree.ElementTree is stdlib and imported at module level as ET.
        self.graph_dict = {}
        self.action_dict = {}
        self.actions = {}

    def __call__(self, path=None, action=None):
        """Parse the GraphML file at ``path`` and return the accumulated graph dict.

        ``action`` is unused; the action is derived from the parent folder name
        by parse_xml_info. Raises ValueError if ``path`` is None.
        """
        if path is None:
            raise ValueError("Must specify path!")
        self.xml = ET.parse(path)
        self.path = path

        # Get all XML with action and sequence
        self.action, self.seq = parse_xml_info(self.path)

        # Create the graphs
        self.create_graph(self.xml)

        # Returns all keyframes (despite the name, this is the dict)
        return self.get_keyframes_as_list()

    def create_graph(self, graph_element):
        """Convert every <KeyFrame> of a parsed GraphML tree into a DiGraph.

        Node feature ``x`` is the one-hot encoding of the object label from
        <ReplaceObject>. Nodes labelled 'null' are dropped along with their edges.
        Nodes with no label get the 'null' encoding and a warning. Edge feature
        ``edge_attr`` is the one-hot relation. Keyframes with no <Edge> are discarded.
        """
        root = graph_element.getroot()
        node_count = 0
        nodes_map = {}
        replace_object_dict = {}

        # Objects in MANIAC GraphML are numbered as they appear in the frame;
        # <ReplaceObject> maps those ids to the real object label.
        for root_replace_object in root.iter('ReplaceObject'):
            for obj in root_replace_object.findall('Object'):
                replace_object_dict[int(obj.attrib['id'])] = obj.attrib['obj']

        action_index = action_map.index(self.action)
        action_features = _actions[action_index]
        # Gets the action's identifier (position of the hot bit)
        action_id = np.argmax(action_features)
        null_features = _objects[objects.index("null")]

        for keyframe in root.iter('KeyFrame'):
            graph = nx.DiGraph()
            graph.graph["features"] = action_features
            graph.graph['seq'] = self.seq
            self.actions[action_id] = action_index

            keyframe_id = keyframe.attrib['ID']
            nodes = keyframe.findall('Node')
            edges = keyframe.findall('Edge')

            skip_nodes = []

            for node in nodes:
                # Position and type are parsed (so malformed files still fail
                # loudly) but not used yet; kept for future development.
                pos_X = float(node.attrib['pos_X'])
                pos_Y = float(node.attrib['pos_Y'])
                pos_Z = float(node.attrib['pos_Z'])
                n_type = node.attrib['type']
                n_id = node.attrib['id']
                label = replace_object_dict.get(int(n_id))

                # 'null' objects are not part of the graph; remember them so
                # their edges are skipped below.
                if label == 'null':
                    skip_nodes.append(n_id)
                    continue

                # Node ids are re-numbered incrementally: [0, 1, 2, ..., n].
                if n_id not in nodes_map:
                    nodes_map[n_id] = node_count
                    node_count += 1

                if label is None:
                    # Object not defined in <ReplaceObject>: keep the node with
                    # the 'null' encoding, which loses its object feature.
                    graph.add_node(nodes_map[n_id], x=null_features)
                    print("This is very bad, missing features! Please look at:")
                    print("Action:", self.action, "seq:", self.seq, "keyframe:", keyframe_id, "node id:", n_id)
                else:
                    graph.add_node(nodes_map[n_id], x=_objects[objects.index(label)])

            # A graph without edges is discarded.
            keep_graph = len(edges) > 0
            if not keep_graph:
                print("[SECParser] NO edge found. Need at least 1 edge for GraphNet. Added dummy value")

            for edge in edges:
                target = edge.attrib['target']
                relation = (edge.attrib['relation']).lower()

                # decide if noconnection should be an edge or not.
                if relation == 'noconnection' and _SKIP_CONNECTIONS:
                    continue
                source = edge.attrib['source']

                # Removed nodes must not get edges.
                if target in skip_nodes or source in skip_nodes:
                    continue

                graph.add_edge(nodes_map[target], nodes_map[source], edge_attr=_relations[spatial_map.index(relation)])

            # Make node labels match their iteration position (0..n-1).
            _check_graph_remapping_nodes = check_graph(graph)
            if _check_graph_remapping_nodes:
                graph = nx.relabel_nodes(graph, _check_graph_remapping_nodes)

            # Example: HIDING_SECs\Hiding_01\GraphML.xml ends up in
            # graph_dict[<hiding id>][1][...].
            if keep_graph:
                self.graph_dict.setdefault(action_id, {}).setdefault(self.seq, []).append(graph)
            elif self.seq in self.graph_dict.get(action_id, {}):
                print("[SECParser] Did not not add graph to list, due to 1 node. GraphNet requires at least 2 nodes.")

    def get_action(self, action):
        return self.actions[action]

    def get_keyframes_as_list(self):
        return self.graph_dict

    def get_graph_dict(self):
        return self.graph_dict


def _check_key(node, key):
    return node != key

def check_graph(graph_nx):
    """Build a relabel mapping so node labels equal their iteration position.

    Returns ``{label: position}`` for every node whose label differs from its
    position and whose ``x`` feature is not None, or None when no node needs
    relabelling. Usually only called internally from SECParser.
    """
    map_dict = {
        key: position
        for position, (key, data) in enumerate(graph_nx.nodes(data=True))
        if _check_key(position, key) and data['x'] is not None
    }
    return map_dict or None

def create_big_list(input_dict):
    """Flatten ``{action: {variation: [graphs]}}`` into one list of graphs.

    Actions and variations are visited in sorted key order; graphs keep their
    order within each variation.
    """
    return [
        graph
        for action in sorted(input_dict)
        for variation in sorted(input_dict[action])
        for graph in input_dict[action][variation]
    ]

In [None]:
from torch.utils.data import Dataset

class MANIAC(Dataset):
    """MANIAC keyframe dataset (list of networkx graphs) for PyTorch Geometric.

    Args:
        root_dir: folder searched recursively for GraphML ``*.xml`` files.
        window: 0 -> each item is one graph; > 0 -> each item is a list of
            ``window`` consecutive graphs, shifted back so that all share the
            action of the graph at ``idx`` where possible.
        temporal: if True, windows are merged by concatenateTemporal.
    """

    def __init__(self, root_dir, window, temporal=False):
        self.window = window
        self.root_dir = root_dir
        self.all_xmls = find_xmls(self.root_dir)
        self.sp = SECParser()
        self.temporal = temporal

        # SECParser accumulates all parsed files in one dict. Start from it so
        # an empty folder yields an empty dataset instead of AttributeError.
        self.dict_with_graphs = self.sp.get_graph_dict()
        for xml in self.all_xmls:
            self.dict_with_graphs = self.sp(xml)

        self.samples = create_big_list(self.dict_with_graphs)

    def __len__(self):
        # Never negative, even with fewer samples than the window.
        return max(0, len(self.samples) - self.window)

    def __getitem__(self, idx):
        if self.window <= 0:
            x = self.samples[idx]
        else:
            current_action = self.samples[idx].graph['features']
            # Count graphs in the window that belong to another action. Indexing
            # past the end raises IndexError, which ends iteration.
            step_back = sum(
                1 for i in range(self.window)
                if self.samples[idx + i].graph['features'] != current_action
            )
            # Shift the window back to stay within one action; clamp at 0 so a
            # short first action does not produce a negative (wrapping) slice.
            start = max(0, idx - step_back)
            x = self.samples[start:start + self.window]

        if self.temporal:
            return concatenateTemporal(x)
        else:
            return x

In [None]:
# Create (and optionally save) new raw datasets.
#
# Example:
#     train_set = MANIAC("FOLDER_TO_MANIAC_GRAPHML", window, temporal)
#
# window = 0  -> each sample is one graph
# window > 0  -> each sample groups `window` consecutive graphs
# temporal = False -> no temporal edges between nodes
# temporal = True  -> temporal edges are added between matching nodes
_SAVE_RAW_DATA = False
_CREATE_DATASET = False

_TIME_WINDOW = 4
_TEMPORAL = True

if _CREATE_DATASET:
    train_set = MANIAC(_MANIAC_DATA_FOLDER + "training/", _TIME_WINDOW, _TEMPORAL)
    val_set = MANIAC(_MANIAC_DATA_FOLDER + "validation/", _TIME_WINDOW, _TEMPORAL)
    test_set = MANIAC(_MANIAC_DATA_FOLDER + "testing/", _TIME_WINDOW, _TEMPORAL)


# USED TO SAVE RAW TRAINING DATA (file names must match ManiacDS.raw_file_names)
if _SAVE_RAW_DATA:
    # Saving without building would reuse stale sets from an earlier kernel
    # state (or fail with NameError on a fresh kernel).
    if not _CREATE_DATASET:
        raise RuntimeError("_SAVE_RAW_DATA requires _CREATE_DATASET = True in the same run.")

    raw_dir = os.path.join(_FOLDER, "raw")
    os.makedirs(raw_dir, exist_ok=True)
    for split_name, split_set in (("training", train_set), ("validation", val_set), ("test", test_set)):
        with open(os.path.join(raw_dir, split_name + "_" + str(_TIME_WINDOW) + "w.pt"), 'wb') as f:
            torch.save(split_set, f)

In [None]:
import torch
from torch_geometric.data import DataLoader
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_tar)

class ManiacDS(InMemoryDataset):
    """In-memory PyTorch Geometric dataset built from the raw MANIAC ``.pt`` files.

    ``dset`` selects the split: "train", "valid", anything else -> test.
    Raw and processed file names depend on the module-level ``_TIME_WINDOW``.
    See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets
    """

    # Split prefixes, in processed_paths order (train, valid, test).
    _SPLIT_PREFIXES = ('training', 'validation', 'test')

    def __init__(self,
                root,
                dset="train",
                transform=None):
        super(ManiacDS, self).__init__(root, transform)

        if dset == "train":
            path = self.processed_paths[0]
        elif dset == "valid":
            path = self.processed_paths[1]
        else:
            path = self.processed_paths[2]

        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which may reject these pickled objects — confirm the torch version.
        self.data, self.slices = torch.load(path)

    @classmethod
    def _split_file_names(cls):
        """File names shared by the raw and processed folders, one per split."""
        return [prefix + '_' + str(_TIME_WINDOW) + 'w.pt' for prefix in cls._SPLIT_PREFIXES]

    @property
    def raw_file_names(self):
        return self._split_file_names()

    @property
    def processed_file_names(self):
        return self._split_file_names()

    def download(self):
        return

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

    @staticmethod
    def _graph_to_data(graph):
        """Convert one networkx graph into a torch_geometric Data object.

        Node attributes and edge attributes are stacked per key (``x``,
        ``edge_attr``, ...). ``y`` is the graph-level ``features`` entry.
        """
        G = nx.convert_node_labels_to_integers(graph)
        G = G.to_directed() if not nx.is_directed(G) else G
        edge_index = torch.tensor(list(G.edges)).t().contiguous()

        # Collect attribute columns in O(n) (append instead of list + list).
        data = {}
        for _, feat_dict in G.nodes(data=True):
            for key, value in feat_dict.items():
                data.setdefault(key, []).append(value)
        for _, _, feat_dict in G.edges(data=True):
            for key, value in feat_dict.items():
                data.setdefault(key, []).append(value)

        for key, item in data.items():
            try:
                data[key] = torch.tensor(item)
            except ValueError:
                # Ragged attributes cannot be stacked; keep them as Python lists.
                pass

        data['edge_index'] = edge_index.view(2, -1)
        data = tg.data.Data.from_dict(data)
        data.y = torch.tensor(graph.graph['features'])

        # Not used; could be useful if the sequence id is needed later.
        #data.seq = torch.tensor([graph.graph['seq']])
        return data

    def process(self):
        for raw_path, path in zip(self.raw_paths, self.processed_paths):
            big_data = []
            graphs = torch.load(raw_path)

            for graph in graphs:
                data = self._graph_to_data(graph)
                # With _SKIP_CONNECTIONS, graphs that lost all edges are dropped.
                if _SKIP_CONNECTIONS and data.edge_attr is None:
                    continue
                big_data.append(data)

            # Diagnostic: report the first graph without edge attributes.
            for graph in big_data:
                if graph.edge_attr is None:
                    print(graph)
                    print(graph.edge_attr)
                    print(action_map[graph.y.argmax().item()])
                    break

            torch.save(self.collate(big_data), path)

In [None]:
train_dataset = ManiacDS(_FOLDER, "train")
test_ds = ManiacDS(_FOLDER, "test")
valid_ds = ManiacDS(_FOLDER, "valid")

split_sizes = [len(train_dataset), len(test_ds), len(valid_ds)]
print(sum(split_sizes))  # total graphs over all splits
print(split_sizes[1])    # test split
print(split_sizes[2])    # validation split

In [None]:
count = 0
max_count = 1
_bs = 1  # batch size; the model and train/test loops reshape by this value

# drop_last keeps every batch exactly _bs graphs, as the reshapes require.
_loader_kwargs = dict(batch_size=_bs, shuffle=True, drop_last=True)
train_loader = DataLoader(train_dataset, **_loader_kwargs)
test_loader = DataLoader(test_ds, **_loader_kwargs)
valid_loader = DataLoader(valid_ds, **_loader_kwargs)

In [None]:
from torch_geometric.utils import to_dense_batch, to_dense_adj, to_scipy_sparse_matrix
from torch_scatter import scatter_add

def to_dense_adj_max_node(edge_index, x, edge_attr, batch=None, max_node=None,
                          num_relations=None, num_objects=None):
    """Create a padded dense adjacency tensor with normalised values.

    Off-diagonal ``[b, i, j]`` = (relation index + 1) / num_relations for each
    edge i -> j of graph b, where the relation index is ``edge_attr.argmax``.
    Diagonal ``[b, i, i]`` = object index / num_objects for node i of graph b,
    where the object index is ``x.argmax``. Padding stays 0.

    Args:
        edge_index: (2, E) long tensor of batch-global node indices.
        x: (N, C) node features whose argmax is the object index.
        edge_attr: (E, R) edge features whose argmax is the relation index;
            may be None when E == 0.
        batch: (N,) graph id per node, assumed sorted; None = single graph.
        max_node: side length of each padded matrix; defaults to the largest graph.
        num_relations: edge normaliser; defaults to len(spatial_map).
        num_objects: node normaliser; defaults to len(_objects).

    Returns:
        (batch_size, max_node, max_node) float tensor on ``edge_index.device``.

    Raises:
        ValueError if a graph has more nodes than ``max_node``.
    """
    if num_relations is None:
        num_relations = len(spatial_map)
    if num_objects is None:
        num_objects = len(_objects)

    device = edge_index.device
    dtype = torch.float

    if batch is None:
        # Every node of x belongs to graph 0 (x also covers isolated nodes).
        batch = torch.zeros(x.size(0), dtype=torch.long, device=device)

    batch_size = int(batch[-1].item()) + 1
    num_nodes = torch.bincount(batch, minlength=batch_size)
    cum_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)])
    largest_graph = int(num_nodes.max().item())
    max_num_nodes = largest_graph if max_node is None else max_node
    if largest_graph > max_num_nodes:
        raise ValueError("graph with {} nodes does not fit max_node={}".format(largest_graph, max_num_nodes))

    # Allocate on the inputs' device (a bare .cuda() may pick another GPU).
    adj = torch.zeros(batch_size, max_num_nodes, max_num_nodes, dtype=dtype, device=device)

    if edge_index.numel() > 0:
        graph_of_edge = batch[edge_index[0]]
        src = edge_index[0] - cum_nodes[graph_of_edge]
        dst = edge_index[1] - cum_nodes[batch[edge_index[1]]]
        # Normalise the relation index (not a single one-hot bit) by the
        # number of pre-defined spatial relations.
        relation = edge_attr.argmax(dim=1).to(dtype)
        adj[graph_of_edge, src, dst] = (relation + 1) / num_relations

    # Normalised object index on the diagonal (overwrites self-loop edges).
    node_values = x.argmax(dim=1).to(dtype) / num_objects
    diagonal = torch.diagonal(adj, dim1=1, dim2=2)  # writable view into adj
    offset = 0
    for b, n in enumerate(num_nodes.tolist()):
        diagonal[b, :n] = node_values[offset:offset + n]
        offset += n

    # Error check to see that the diagonal is not zero.
    for i in range(batch_size):
        if torch.diagonal(adj[i]).sum().item() == 0:
            print("ERROR! ZERO DIAGONAL!", i)

    return adj


# The next cell redefines to_dense_adj_max_node (binary variant); keep this
# normalised variant reachable under its own name.
to_dense_adj_max_node_normalized = to_dense_adj_max_node

In [None]:
from torch_geometric.utils import to_dense_batch, to_dense_adj, to_scipy_sparse_matrix
from torch_scatter import scatter_add

def to_dense_adj_max_node(edge_index, x, edge_attr, batch=None, max_node=None):
    """Create a padded dense binary adjacency tensor (no normalisation).

    Off-diagonal ``[b, i, j]`` = 1 for each edge i -> j of graph b. Diagonal
    ``[b, i, i]`` = 1 if ``x[i].argmax() > 0`` else 0. Padding stays 0.

    Args:
        edge_index: (2, E) long tensor of batch-global node indices.
        x: (N, C) node features; only their argmax is used.
        edge_attr: edge features; not needed for binary edges (may be None).
        batch: (N,) graph id per node, assumed sorted; None = single graph.
        max_node: side length of each padded matrix; defaults to the largest graph.

    Returns:
        (batch_size, max_node, max_node) float tensor on ``edge_index.device``.

    Raises:
        ValueError if a graph has more nodes than ``max_node``.
    """
    device = edge_index.device
    dtype = torch.float

    if batch is None:
        # Every node of x belongs to graph 0 (x also covers isolated nodes).
        batch = torch.zeros(x.size(0), dtype=torch.long, device=device)

    batch_size = int(batch[-1].item()) + 1
    num_nodes = torch.bincount(batch, minlength=batch_size)
    cum_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)])
    largest_graph = int(num_nodes.max().item())
    max_num_nodes = largest_graph if max_node is None else max_node
    if largest_graph > max_num_nodes:
        raise ValueError("graph with {} nodes does not fit max_node={}".format(largest_graph, max_num_nodes))

    # Allocate on the inputs' device (a bare .cuda() may pick another GPU).
    adj = torch.zeros(batch_size, max_num_nodes, max_num_nodes, dtype=dtype, device=device)

    if edge_index.numel() > 0:
        graph_of_edge = batch[edge_index[0]]
        src = edge_index[0] - cum_nodes[graph_of_edge]
        dst = edge_index[1] - cum_nodes[batch[edge_index[1]]]
        adj[graph_of_edge, src, dst] = 1.0

    # Diagonal marker per node (overwrites self-loop edges).
    # NOTE(review): argmax > 0 gives object index 0 ('Apple') a 0 diagonal,
    # indistinguishable from padding — confirm that is intended.
    node_values = (x.argmax(dim=1) > 0).to(dtype)
    diagonal = torch.diagonal(adj, dim1=1, dim2=2)  # writable view into adj
    offset = 0
    for b, n in enumerate(num_nodes.tolist()):
        diagonal[b, :n] = node_values[offset:offset + n]
        offset += n

    # Error check to see that the diagonal is not zero.
    for i in range(batch_size):
        if torch.diagonal(adj[i]).sum().item() == 0:
            print("ERROR! ZERO DIAGONAL!", i)

    return adj

In [None]:
# Largest graph per split; the decoder output is max_num_nodes x max_num_nodes.
max_num_nodes_train = max(len(i.x) for i in train_dataset)
print("max in train:", max_num_nodes_train)

max_num_nodes_valid = max(len(i.x) for i in valid_ds)
print("max in valid:", max_num_nodes_valid)

max_num_nodes_test = max(len(i.x) for i in test_ds)
print("max in test:", max_num_nodes_test)

max_num_nodes = max(max_num_nodes_test, max_num_nodes_train, max_num_nodes_valid)
print("MAX:", max_num_nodes)

def repackage_hidden(h):
    """Detach a tensor, or a (possibly nested) tuple of tensors, from autograd history.

    Used to carry LSTM hidden state across batches without backpropagating
    through earlier batches. Non-tensor containers come back as tuples.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()

In [None]:
from torch.nn import Sequential, Linear, ReLU, ELU
from torch_geometric.nn import NNConv, radius_graph, fps, global_mean_pool
from torch_scatter import scatter_mean
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.utils import softmax
from torch_geometric.nn import BatchNorm

# Sizes tied to the action vocabulary.
input_size = len(action_map)   # static value
output_size = len(action_map)  # static value

channels = 64      # node-embedding width inside the Encoder
_decoder_in = 32   # latent size of the VAE heads, input size of the Decoder

_PRINT = False  # Used for debugging

class Encoder(torch.nn.Module):
    """Graph encoder: NNConv trunk plus VAE heads.

    forward(data) returns ``(z, logvar, mu, p_x)``:
      - mu / logvar: per-graph means (scatter_mean over ``data.batch``) of the
        head outputs, width ``_decoder_in``;
      - z: reparameterised sample ``std * eps + mu`` in training mode, ``mu``
        in eval mode;
      - p_x: per-graph mean of the conv2 embedding (width ``channels``), used
        for action prediction.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.lin = torch.nn.Linear(len(objects), channels)

        # Edge network producing the NNConv weight from a relation one-hot.
        # NOTE(review): the same instance is given to conv1 and conv2, so both
        # layers share these weights — confirm the tying is intended.
        edge_net = Sequential(Linear(len(spatial_map), 64), ReLU(), Linear(64, channels * channels*2 ))

        self.conv1 = NNConv(channels, channels*2, edge_net, aggr='mean')
        self.bn1 = BatchNorm(channels*2)
        self.conv2 = NNConv(channels*2, channels, edge_net, aggr='mean')

        # Shared edge network for the mu / logvar heads.
        head_edge_net = Sequential(Linear(len(spatial_map), 64), ReLU(), Linear(64, _decoder_in * channels*2 ))
        self.mu = NNConv(channels*2, _decoder_in, head_edge_net, aggr='max')
        self.logvar = NNConv(channels*2, _decoder_in, head_edge_net, aggr='max')

    def forward(self, data):
        edge_index, edge_attr, batch = data.edge_index, data.edge_attr, data.batch

        embedded = F.relu(self.lin(data.x))

        # Action-prediction branch
        hidden = self.bn1(F.relu(self.conv1(embedded, edge_index, edge_attr)))
        action_features = F.relu(self.conv2(hidden, edge_index, edge_attr))

        # Reconstruction branch (per-graph latent statistics)
        mu = scatter_mean(self.mu(hidden, edge_index, edge_attr), batch, dim=0)
        logvar = scatter_mean(self.logvar(hidden, edge_index, edge_attr), batch, dim=0)

        p_x = scatter_mean(action_features, batch, dim=0)

        std = torch.exp(0.5*logvar)
        # Drawn in both modes so RNG consumption does not depend on train/eval.
        eps = torch.randn_like(std)

        z = std * eps + mu if self.training else mu
        return z, logvar, mu, p_x

class Predictor(torch.nn.Module):
    """Stateful LSTM that predicts action logits from per-graph embeddings.

    forward(p_x) reshapes ``p_x`` to ``(bs, seq_len, -1)`` and runs the LSTM
    from the hidden state left by the previous call. That state is detached, so
    no gradient flows across batches. The last time step is mapped to
    ``len(action_map)`` logits.
    """

    def __init__(self, input_size, seq_len, hidden_size, n_layers):
        super(Predictor, self).__init__()

        self.prev_hidden = None
        self.bs = _bs
        self.input_size = input_size
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.n_layers = n_layers

        self.lstm = torch.nn.LSTM(self.input_size, self.hidden_size, self.n_layers, dropout=0.1, batch_first=True)
        self.lin1 = torch.nn.Linear(self.hidden_size, len(action_map))

    def forward(self, p_x):
        input_reshape = p_x.reshape(self.bs, self.seq_len, -1)

        # Keep the hidden state on the input's device. `.cuda()` targets the
        # current default GPU, which differs from the model's when it runs on
        # e.g. cuda:1, and fails on CPU-only machines.
        if self.prev_hidden is None:
            zeros = torch.zeros(self.n_layers, self.bs, self.hidden_size, device=p_x.device, dtype=p_x.dtype)
            self.prev_hidden = (zeros, zeros.clone())
        else:
            self.prev_hidden = tuple(t.to(p_x.device) for t in self.prev_hidden)

        q, (h_n, c_n) = self.lstm(input_reshape, self.prev_hidden)

        # Detach so the next batch does not backpropagate into this one.
        self.prev_hidden = (h_n.detach(), c_n.detach())

        return self.lin1(q[:, -1, :])

class Decoder(torch.nn.Module):
    """Map a latent vector to a flattened max_num_nodes x max_num_nodes matrix.

    Two linear layers (ReLU in between) followed by a sigmoid, so every output
    lies in (0, 1).
    """

    def __init__(self):
        super(Decoder, self).__init__()

        self.fc1 = nn.Linear(_decoder_in, 64)
        self.fc2 = nn.Linear(64, max_num_nodes*max_num_nodes)

    def forward(self, z_x):
        hidden = F.relu(self.fc1(z_x))
        return torch.sigmoid(self.fc2(hidden))


class djNet(torch.nn.Module):
    """Encoder + LSTM action predictor + adjacency decoder.

    forward(x) returns ``(q_z, logvar, mu, z, p_z)``: the decoded (flattened)
    adjacency, the latent statistics, the latent sample and the action logits.
    """

    def __init__(self):
        super(djNet, self).__init__()

        self.encoder = Encoder()
        # Predictor(input_size=8, seq_len=8, hidden_size=8, n_layers=4): the
        # per-graph embedding is reshaped into 8 steps of 8 features.
        self.predictor = Predictor(8, 8, 8, 4)
        self.decoder = Decoder()

    def forward(self, x):
        z, logvar, mu, p_x = self.encoder(x)
        action_logits = self.predictor(p_x)
        reconstruction = self.decoder(z)
        return reconstruction, logvar, mu, z, action_logits


In [None]:
# Prefer the second GPU (as originally configured), but fall back to the first
# GPU or the CPU; 'cuda:1' is invalid on single-GPU machines.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model = djNet().to(device)

print(model)

In [None]:
ap_criterion = nn.CrossEntropyLoss()

def loss_criterion(inputs, targets, logvar, mu, ap_inputs, ap_targets):
    """Return ``(reconstruction + KL loss, action-prediction loss)``.

    - reconstruction: binary cross-entropy between predicted and target
      adjacency, summed over all entries;
    - KL: ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))`` (VAE regulariser);
    - action prediction: cross-entropy of ``ap_inputs`` logits vs class ids.
    """
    reconstruction = F.binary_cross_entropy(inputs, targets, reduction="sum")
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    action_loss = ap_criterion(ap_inputs, ap_targets)
    return reconstruction + kl_divergence, action_loss

In [None]:
from torch.utils.tensorboard import SummaryWriter

#writer = SummaryWriter(comment="MANIAC_4w_dropout02")

In [None]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)  # all djNet parameters

In [None]:
def train(loader, model):
    """Run one training epoch over ``loader``.

    Loss per batch = 0.7 * (summed BCE reconstruction + KL) + action CE.
    Uses the module-level ``optimizer``, ``device``, ``_bs`` and ``max_num_nodes``.

    Returns:
        (reconstruction loss, action-prediction loss), each summed over batches
        and divided by ``len(loader) * _bs``.
    """
    model.train()

    recon_loss_all = 0
    ap_loss_all = 0

    for data in loader:
        optimizer.zero_grad()
        data = data.to(device)

        y_hat, logvar, mu, _, y_ap = model(data)
        prediction = y_hat.view(_bs, -1, max_num_nodes)

        # Targets. `.to(device)` rather than `.cuda()`: the latter moves to the
        # current default GPU, which is not `device` when running on cuda:1.
        target_adj = to_dense_adj_max_node(data.edge_index, data.x, data.edge_attr, data.batch, max_num_nodes).to(device)
        y_ap_true = data.y.view(_bs, -1).argmax(axis=1)

        recon_loss, ap_loss = loss_criterion(prediction, target_adj, logvar, mu, y_ap, y_ap_true)
        net_loss = recon_loss * 0.7 + ap_loss

        # Compute gradients and update weights.
        net_loss.backward()
        optimizer.step()

        recon_loss_all += recon_loss.item()
        ap_loss_all += ap_loss.item()

    n_samples = len(loader) * _bs
    return recon_loss_all / n_samples, ap_loss_all / n_samples

def test(loader, model):
    """Evaluate ``model`` on ``loader`` without updating weights.

    Runs under ``torch.no_grad()`` (no autograd graphs are built). Note that the
    Predictor's LSTM hidden state still carries over between batches.

    Returns:
        (reconstruction loss, action-prediction loss, action accuracy), each
        divided by ``len(loader) * _bs``.
    """
    model.eval()

    recon_loss_all = 0
    ap_loss_all = 0
    correct = 0

    with torch.no_grad():
        for data in loader:
            data = data.to(device)

            y_hat, logvar, mu, _, y_ap = model(data)
            prediction = y_hat.view(_bs, -1, max_num_nodes)

            # Targets built on `device` (see train(): `.cuda()` may pick another GPU).
            target_adj = to_dense_adj_max_node(data.edge_index, data.x, data.edge_attr, data.batch, max_num_nodes).to(device)
            y_ap_true = data.y.view(_bs, -1).argmax(axis=1)

            pred = y_ap.max(1)[1]

            recon_loss, ap_loss = loss_criterion(prediction, target_adj, logvar, mu, y_ap, y_ap_true)

            recon_loss_all += recon_loss.item()
            ap_loss_all += ap_loss.item()
            correct += pred.eq(y_ap_true).sum().item()

    n_samples = len(loader) * _bs
    return recon_loss_all / n_samples, ap_loss_all / n_samples, correct / n_samples

        
def super_test(loader, model):
    """Collect reconstructions and action predictions for offline analysis.

    Relies on notebook-level names defined elsewhere: ``device``, ``scatter_add``,
    ``to_dense_adj_max_node`` and ``max_num_nodes``.

    Args:
        loader: iterable of batches exposing ``edge_index``, ``x``, ``edge_attr``,
            ``batch`` and ``y``.
        model: network returning ``(y_hat, logvar, mu, _, y_ap)``.

    Returns:
        reconstruct_predict_list: list with one numpy array per batch, the model
            output ``y_hat`` reshaped to ``(batch_size, -1, max_num_nodes)``.
        reconstruct_gt_list: list with one numpy array per batch, the dense target
            adjacency produced by ``to_dense_adj_max_node``.
        ap_predict_list: float64 numpy array of predicted action indices, one per
            graph in loader order.
        ap_gt_list: float64 numpy array of target action indices, same order.
        num_nodes_list: list with one numpy array per batch holding the number of
            nodes of each graph.
    """
    model.eval()

    ap_predictions = []
    ap_targets = []
    reconstruct_predict_list = []
    reconstruct_gt_list = []
    num_nodes_list = []

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            # Real number of graphs in this batch (batch ids are 0..B-1); used for
            # every reshape below instead of the global ``_bs``.
            batch_size = data.batch[-1].item() + 1

            y_hat, logvar, mu, _, y_ap = model(data)
            pred = y_ap.max(1)[1]

            # Node count per graph: sum a 1 for every node into its graph's slot.
            one = data.batch.new_ones(data.batch.size(0))
            num_nodes = scatter_add(one, data.batch, dim=0, dim_size=batch_size)
            num_nodes_list.append(num_nodes.cpu().numpy())

            # Targets: dense adjacency on the batch's device and the action index.
            target_adj = to_dense_adj_max_node(data.edge_index, data.x, data.edge_attr,
                                               data.batch, max_num_nodes).to(device)
            target = data.y.view(batch_size, -1)
            y_ap_true = target.argmax(axis=1)

            prediction = y_hat.view(batch_size, -1, max_num_nodes)
            ap_predictions.append(pred.cpu().numpy().ravel())
            ap_targets.append(y_ap_true.cpu().numpy().ravel())

            reconstruct_predict_list.append(prediction.cpu().numpy())
            reconstruct_gt_list.append(target_adj.cpu().numpy())

    # Concatenate once instead of np.append per batch (quadratic). Cast to
    # float64 because that is the dtype np.append onto np.array([]) produced.
    if ap_predictions:
        ap_predict_list = np.concatenate(ap_predictions).astype(np.float64)
        ap_gt_list = np.concatenate(ap_targets).astype(np.float64)
    else:
        ap_predict_list = np.array([])
        ap_gt_list = np.array([])

    return reconstruct_predict_list, reconstruct_gt_list, ap_predict_list, ap_gt_list, num_nodes_list

In [None]:
# Number of training epochs.
_NUM_EPOCHS = 5000
# TensorBoard logging needs a SummaryWriter bound to ``writer``. The line that
# creates it is commented out above, so re-enable it before setting this True.
_LOG_TENSORBOARD = False

save_counter = 0  # NOTE(review): unused in this cell; kept in case later cells read it
for epoch in range(1, _NUM_EPOCHS + 1):
    # One optimisation pass, then a no-update evaluation on train and validation.
    train_recon_loss, train_ap_loss = train(train_loader, model)
    _, _, train_ap_acc = test(train_loader, model)
    validation_recon_loss, validation_ap_loss, validation_ap_acc = test(valid_loader, model)

    if _LOG_TENSORBOARD:
        writer.add_scalar('AP_Acc/train', train_ap_acc, epoch)
        writer.add_scalar('AP_Acc/validation', validation_ap_acc, epoch)

        writer.add_scalar('Recon_Loss/train', train_recon_loss, epoch)
        writer.add_scalar('Recon_Loss/validation', validation_recon_loss, epoch)

        writer.add_scalar('AP_Loss/train', train_ap_loss, epoch)
        writer.add_scalar('AP_Loss/validation', validation_ap_loss, epoch)

    print("Epoch {:02d}, [T] RLoss: {:.2f}, APLoss: {:.4f}, Acc: {:.2f}% [V] RLoss: {:.2f}, APLoss: {:.4f}, Acc: {:.2f}%".format( epoch, 
                                                                           train_recon_loss,
                                                                           train_ap_loss,
                                                                           train_ap_acc*100,
                                                                           validation_recon_loss,
                                                                           validation_ap_loss,
                                                                           validation_ap_acc*100,
                                                                            ))


In [None]:
# Either save the current model/optimizer state (_SAVE = True) or restore a
# previously trained checkpoint into ``model`` and ``optimizer``.
_SAVE = False
if _SAVE:
    # NOTE(review): 'epoch' and 'loss' are hard-coded rather than taken from the
    # training loop; update them before saving so the checkpoint metadata is accurate.
    torch.save({
                'epoch': 99,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': 4.110431,
                }, "./MANIAC_final_models/MANIAC_final_4w_dim_64_increase_mu_64.pt")
    print("SAVED!")
else:
    # map_location restores tensors onto ``device`` even if the checkpoint was
    # written on a different device. torch.load unpickles the file: only load
    # checkpoints from a trusted source.
    checkpoint = torch.load("./FINAL_RESULTS/2w/maniac_416.pt", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print("LOADED MODEL")

In [None]:
# Evaluate the model on the test split; later cells reuse _CURRENT_LOADER.
_CURRENT_LOADER = test_loader
test_recon_loss, test_ap_loss, test_ap_acc = test(_CURRENT_LOADER, model)
for _label, _value in (("test recon loss:", test_recon_loss),
                       ("test ap loss:", test_ap_loss),
                       ("test ap acc:", test_ap_acc)):
    print(_label, _value)

In [None]:
# Gather reconstructions, action predictions/targets and node counts for analysis.
cr_pred, cr_gt, ap_pred, ap_target, node_list = super_test(loader=_CURRENT_LOADER, model=model)

In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

cr_pred, cr_gt, ap_pred, ap_target, node_list = super_test(_CURRENT_LOADER, model)

# One label index per entry of action_map. Passing the labels explicitly keeps
# the matrix len(action_map) x len(action_map) even when a class never appears
# in this split; the heatmaps later label its rows/columns with action_map.
class_labels = list(range(len(action_map)))

cm = confusion_matrix(ap_target, ap_pred, labels=class_labels)
print(cm)

print(classification_report(ap_target, ap_pred, target_names=action_map, labels=class_labels))

In [None]:
from sklearn.metrics import roc_auc_score, average_precision_score

_PRED_TRESHOLD = 0.3

def calc_auc_roc(gt, pred, node_list):
    """Print node-count accuracy and edge-reconstruction scores.

    Args:
        gt: sequence with one entry per batch; ``gt[b][i]`` is the dense
            ground-truth adjacency matrix of graph ``i`` in batch ``b``.
        pred: predicted matrices with the same nesting as ``gt``. Not modified.
        node_list: sequence with one entry per batch; ``node_list[b][i]`` is the
            true node count of graph ``i`` in batch ``b``.

    The predicted node count of a graph is the number of diagonal entries of its
    prediction that round to a non-zero value. The function prints:
    - a dict mapping (predicted - true node count) to how many graphs have that
      difference;
    - the fraction of graphs whose node count is exact;
    - the mean ROC AUC and average precision between the ceil'ed ground truth
      and the prediction binarised at ``_PRED_TRESHOLD``.

    Raises:
        Exception: if ``gt``, ``pred`` and ``node_list`` differ in number of batches.
    """
    if len(gt) != len(pred):
        raise Exception("Invalid length of ground truth and prediction! GT {:d} and pred {:d}".format(len(gt), len(pred)))
    if len(node_list) != len(gt):
        raise Exception("Invalid length of ground truth and node list! GT {:d} and node list {:d}".format(len(gt), len(node_list)))

    graph_count = 0
    correct_node_count = 0
    # Object-placement metrics are not computed here; they remain 0 and are
    # printed only to keep the report format unchanged.
    correct_objects = 0
    total_objects = 0
    roc_auc_total = 0.0
    avg_precision_total = 0.0
    node_diff_dict = {0: 0}

    for gt_batch, pred_batch, nodes_batch in zip(gt, pred, node_list):
        for i in range(len(gt_batch)):
            graph_count += 1
            pred_adj = pred_batch[i]

            # Node count from the raw (un-thresholded) prediction, rounded at 0.5,
            # for every graph. Previously the in-place thresholding below made all
            # graphs after the first in a batch use _PRED_TRESHOLD here instead.
            cr_nodes = np.count_nonzero(np.round(np.diag(pred_adj)))
            diff = cr_nodes - nodes_batch[i]
            node_diff_dict[diff] = node_diff_dict.get(diff, 0) + 1
            if diff == 0:
                correct_node_count += 1

            gt_flatten = np.ceil(gt_batch[i].flatten())
            # Binarise a copy so the caller's predictions are left intact.
            cr_flatten = (pred_adj > _PRED_TRESHOLD).astype(pred_adj.dtype).flatten()
            # NOTE(review): ROC AUC / AP on hard 0/1 predictions reflect a single
            # operating point; pred_adj.flatten() would give threshold-free scores.
            # Kept binarised to stay comparable with earlier results.
            roc_auc_total += roc_auc_score(gt_flatten, cr_flatten)
            avg_precision_total += average_precision_score(gt_flatten, cr_flatten)

    print(node_diff_dict)
    if graph_count == 0:
        print("No graphs to evaluate.")
        return
    # Fraction over graphs (was divided by the number of batches before).
    print("Correct node length:", correct_node_count, "out of", graph_count, "graphs.", correct_node_count / graph_count)
    print("Correct objects in place:", correct_objects, "out of", total_objects, "objects")
    print("Average ROC AUC SCORE:", roc_auc_total / graph_count)
    print("Average precision score:", avg_precision_total / graph_count)

calc_auc_roc(cr_gt, cr_pred, node_list)

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Row-normalise the confusion matrix so each actual class sums to 1. Rows with
# no samples stay 0 instead of turning into NaN from a 0/0 division.
row_totals = cm.sum(axis=1)[:, np.newaxis]
cmn = np.divide(cm.astype('float'), row_totals,
                out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)

fig, ax = plt.subplots(figsize=(10,10))
sns.heatmap(cmn, annot=True, fmt='.2f', xticklabels=action_map, yticklabels=action_map, cmap="YlGnBu", ax=ax)
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.yticks(rotation=0)
plt.title('Normalized')
plt.show(block=False)

# Same matrix with raw prediction counts.
fig, ax = plt.subplots(figsize=(10,10))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=action_map, yticklabels=action_map, cmap="YlGnBu", ax=ax)
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.yticks(rotation=0)
plt.title('Number of predictions')
plt.show(block=False)