## Data preprocessing

In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt
from pathlib import Path

# Parse data file
def file_to_matrix(filename):
    """Parse a space-separated numeric text file into a NumPy array.

    Each line of the file becomes one row. A line is split on single spaces
    and every token that parses as a float is kept; tokens that do not parse
    (for example the empty strings produced by repeated spaces) are ignored.
    Lines without any numeric token, such as blank lines, are skipped so that
    they do not add an empty row and make the result ragged.

    Args:
        filename: Path of the text file to read.

    Returns:
        np.ndarray built from the list of parsed rows.
    """
    rows = []
    with open(filename, 'r') as f:
        # Iterate the file object directly instead of loading every line
        # up front with readlines().
        for line in f:
            row = []
            for token in line.split(' '):
                try:
                    row.append(float(token))
                except ValueError:
                    # Only "not a number" is expected here; anything else
                    # (e.g. KeyboardInterrupt) must propagate.
                    continue
            if row:
                rows.append(row)
    return np.asarray(rows)

# Scan dataset dir and parse files
def dataset_to_matrices(folder):
    """Load every ``.out`` file located directly inside ``folder``.

    Directory entries are visited in ascending modification-time order, so
    the position of a heatmap in the result follows the order in which the
    files were last written. Sub-directories and entries whose name does not
    end in ``.out`` are ignored.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        np.ndarray stacking the matrices parsed by ``file_to_matrix``, one
        per matching file.
    """
    heatmaps = []
    # Oldest file first.
    paths = sorted(Path(folder).iterdir(), key=os.path.getmtime)

    for path in paths:
        # The original tested the same ".out" suffix twice; once is enough.
        if str(path).endswith(".out") and os.path.isfile(path):
            heatmaps.append(file_to_matrix(str(path)))
    return np.asarray(heatmaps)

In [2]:
# Parse every .out file under dataset/ (oldest file first) into one array.
heatmaps = dataset_to_matrices('dataset/')

In [3]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

def matrix_to_graph(m):
    """Build a 4-neighbour grid graph with one node per matrix cell.

    Nodes are numbered in row-major order (``id = row * n_cols + col``).
    Every node gets a directed edge to its left, right, upper and lower
    neighbour whenever that neighbour exists, so each adjacency is present in
    both directions. Edges are emitted node by node, in LEFT, RIGHT, UP, DOWN
    order. Only the shape of ``m`` is used; its values are ignored.

    Args:
        m: 2-D array exposing ``shape``.

    Returns:
        A ``dgl.graph`` with ``n_rows * n_cols`` nodes. The caller moves it
        to another device if needed.
    """
    n_rows, n_cols = int(m.shape[0]), int(m.shape[1])

    # Row / column index and node id of every cell, all shaped (n_rows, n_cols).
    rows = torch.arange(n_rows).unsqueeze(1).expand(n_rows, n_cols)
    cols = torch.arange(n_cols).unsqueeze(0).expand(n_rows, n_cols)
    node_ids = rows * n_cols + cols

    # Candidate neighbour ids, last axis ordered LEFT, RIGHT, UP, DOWN.
    candidates = torch.stack(
        [node_ids - 1, node_ids + 1, node_ids - n_cols, node_ids + n_cols],
        dim=-1,
    )
    # A candidate is real only if it stays inside the grid. The DOWN test
    # must use the number of rows (the column count is only correct for
    # square matrices).
    valid = torch.stack(
        [cols > 0, cols < n_cols - 1, rows > 0, rows < n_rows - 1],
        dim=-1,
    )

    # Boolean masking flattens in row-major order, i.e. node by node and,
    # within a node, direction by direction.
    u = node_ids.unsqueeze(-1).expand_as(candidates)[valid]
    v = candidates[valid]

    # Graph creation. num_nodes is passed explicitly so the node count does
    # not depend on the largest id that happens to appear in an edge.
    return dgl.graph((u, v), num_nodes=n_rows * n_cols) #.to('cuda:0')

def matrix_to_node_features(m):
    """Flatten a matrix into a float32 column of per-node features.

    Values are read in row-major order, giving one scalar feature per cell.

    Args:
        m: Numeric array (e.g. a NumPy heatmap).

    Returns:
        torch.Tensor of dtype float32 with one row per cell and one column.
    """
    # Convert in one step instead of building a Python list with one
    # single-element list per cell; torch.tensor always copies its input.
    return torch.tensor(m, dtype=torch.float32).reshape(-1, 1)

Using backend: pytorch


In [4]:
# Build the grid graph of the first heatmap and display its summary.
matrix_to_graph(heatmaps[0])

Graph(num_nodes=260100, num_edges=1038360,
      ndata_schemes={}
      edata_schemes={})

In [5]:
# Keep the grid graph of the first heatmap in G.
G = matrix_to_graph(heatmaps[0])

In [6]:
# One scalar feature per node, taken from the first heatmap.
node_features = matrix_to_node_features(heatmaps[0])

In [7]:
# Display the feature tensor.
node_features

tensor([[1.0000],
        [1.0000],
        [1.0000],
        ...,
        [0.9600],
        [0.9700],
        [0.9900]])

## GNN Construction

In [12]:
# Creating the GCN functions

# Message function: every edge carries its source node's feature 'h' as the
# message 'm'. fn.copy_u is the current name of this builtin; fn.copy_src is
# its deprecated alias and is not available in recent DGL releases.
gcn_msg = fn.copy_u('h', 'm')

# Aggregation function (reduce): each node sums its incoming messages 'm'
# and stores the result in 'h'.
gcn_reduce = fn.sum(msg='m', out='h')

In [8]:
# Graph Convolutional Layer
class GCNLayer(nn.Module):
    """Graph-convolution layer: message passing followed by a linear map.

    The message and reduce functions are the module-level ``gcn_msg`` and
    ``gcn_reduce``.
    """

    def __init__(self, in_feats, out_feats):
        super().__init__()

        # Learnable projection applied to the aggregated node features.
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, feature):
        """Run one round of message passing on ``g`` and project the result.

        Args:
            g: Graph to propagate on.
            feature: Node feature tensor stored under ``'h'`` before the update.

        Returns:
            Output of the linear layer applied to the updated ``'h'`` features.
        """
        # local_scope keeps the temporary 'h' field from leaking into g.
        with g.local_scope():
            g.ndata['h'] = feature
            g.update_all(gcn_msg, gcn_reduce)
            aggregated = g.ndata['h']
            projected = self.linear(aggregated)
        return projected

In [9]:
class Net(nn.Module):
    """Two stacked GCN layers producing one continuous value per node."""

    def __init__(self):
        super().__init__()

        # 1 input feature per node -> 16 hidden features -> 1 output value.
        self.layer1 = GCNLayer(1, 16)
        self.layer2 = GCNLayer(16, 1)

    def forward(self, g, features):
        """Apply layer1 with a ReLU, then layer2 without an activation.

        Args:
            g: Graph passed to both layers.
            features: Node feature tensor fed to the first layer.

        Returns:
            Output of the second layer (continuous, no activation applied).
        """
        hidden = F.relu(self.layer1(g, features))
        return self.layer2(g, hidden)

# Use the first CUDA device when one is available; otherwise stay on the CPU
# so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = Net().to(device)

## It's training time!

In [13]:
import time

optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

# Run on whatever device the model already lives on instead of hard-coding it.
device = next(net.parameters()).device

net.train()

# The grid graph depends only on the heatmap shape, so it is rebuilt only
# when that shape changes instead of on every step.
g = None
g_shape = None

# 1 shot
for epoch in range(1):

    dur = []
    # Each step learns to predict heatmap i+1 from heatmap i.
    for i in range(heatmaps.shape[0] - 1):

        # The first three steps are excluded from the timing statistics.
        if i >= 3:
            t0 = time.time()

        h = heatmaps[i]
        h_next = heatmaps[i+1]

        # h -> input
        if g is None or h.shape != g_shape:
            g = matrix_to_graph(h).to(device)
            g_shape = h.shape
        features = matrix_to_node_features(h).to(device)

        y_hat = net(g, features)

        # h_next -> target, laid out exactly like the input features
        target = matrix_to_node_features(h_next).to(device)

        # Compute mse_loss
        loss = F.mse_loss(y_hat, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i >= 3:
            dur.append(time.time() - t0)

        # np.mean([]) emits a RuntimeWarning; report nan explicitly until the
        # first timed step has completed.
        mean_dur = np.mean(dur) if dur else float('nan')
        print("*************\nEpoch {:05d} | Heatmap {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
                epoch, i, loss.item(), mean_dur))

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
*************
Epoch 00000 | Heatmap 00000 | Loss 0.9388 | Time(s) nan
*************
Epoch 00000 | Heatmap 00001 | Loss 0.3585 | Time(s) nan
*************
Epoch 00000 | Heatmap 00002 | Loss 0.0637 | Time(s) nan
*************
Epoch 00000 | Heatmap 00003 | Loss 0.0222 | Time(s) 0.8010
*************
Epoch 00000 | Heatmap 00004 | Loss 0.1308 | Time(s) 0.8048
*************
Epoch 00000 | Heatmap 00005 | Loss 0.2447 | Time(s) 0.7932
*************
Epoch 00000 | Heatmap 00006 | Loss 0.2698 | Time(s) 0.8005
*************
Epoch 00000 | Heatmap 00007 | Loss 0.2580 | Time(s) 0.8007
*************
Epoch 00000 | Heatmap 00008 | Loss 0.1422 | Time(s) 0.8078
*************
Epoch 00000 | Heatmap 00009 | Loss 0.0679 | Time(s) 0.8058
*************
Epoch 00000 | Heatmap 00010 | Loss 0.0156 | Time(s) 0.8099
*************
Epoch 00000 | Heatmap 00011 | Loss 0.0049 | Time(s) 0.8065
*************
Epoch 00000 | Heatmap 00012 | L

## Plotting the results (real vs. prediction)

In [22]:
# Save one "real" and one "prediction" frame for every consecutive pair of
# heatmaps: frame x compares heatmaps[x+1] with the model output for heatmaps[x].

# plt.imsave does not create missing directories.
os.makedirs("animation/real", exist_ok=True)
os.makedirs("animation/prediction", exist_ok=True)

# Run on whatever device the model already lives on.
device = next(net.parameters()).device
net.eval()

# The grid graph depends only on the heatmap shape; rebuild it only on change.
g = None
g_shape = None

# Inference only: no gradients are needed.
with torch.no_grad():
    # range(N - 1) covers every valid (x, x+1) pair. The previous manual
    # counter stopped one pair early and raised IndexError when there were
    # two or fewer heatmaps.
    for x in range(heatmaps.shape[0] - 1):
        h = heatmaps[x]

        # Real
        plt.imsave(f"animation/real/heat{x:03d}.png", heatmaps[x+1], cmap='hot')

        # Prediction
        if g is None or h.shape != g_shape:
            g = matrix_to_graph(h).to(device)
            g_shape = h.shape
        features = matrix_to_node_features(h).to(device)
        h_hat = net(g, features).cpu().numpy()
        # Back to the layout of the input heatmap (was hard-coded to width 510).
        h_hat = h_hat.reshape(h.shape)

        plt.imsave(f"animation/prediction/heat{x:03d}.png", h_hat, cmap='hot')