# Visualizing the learning process as a GIF

In [13]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt 
import imageio
import networkx as nx  

In [80]:
import sys
# Extend the module search path with the parent directory so that the
# `causal_graphs` package imported below can be resolved (presumably the
# notebook lives one level below the repository root -- TODO confirm).
sys.path.append("../")
from causal_graphs.graph_definition import CausalDAG
from causal_graphs.graph_visualization import visualize_graph

In [82]:
# Checkpoint folder of the run to visualize. The trailing slash matters: the
# rest of the notebook builds file paths by plain string concatenation.
# Alternative run (swap the two assignments to switch):
# BASE_PATH = "checkpoints/2021_04_29__15_39_17/"
BASE_PATH = "checkpoints/2021_04_29__15_50_42/"

In [83]:
# All generated images (graph PDF, frames, GIF) go into a sub-folder of the
# checkpoint directory.
vis_folder = f"{BASE_PATH}visualizations/"
# exist_ok=True keeps a re-run of this cell from failing when the folder
# already exists.
os.makedirs(name=vis_folder, exist_ok=True)

In [84]:
# Restore the causal graph stored with the checkpoint and render it once as a
# PDF with a circular node layout.
graph_file = BASE_PATH + "graph_1.pt"
graph_pdf = vis_folder + "graph_1.pdf"
graph = CausalDAG.load_from_file(graph_file)
visualize_graph(
    graph,
    filename=graph_pdf,
    layout="circular",
    figsize=(5, 5),
)

In [85]:
# Load the logged gamma and theta parameters as float32 arrays.
# `np.load` on an .npz archive returns a lazy `NpzFile` that keeps the file
# open; the `with` blocks close it as soon as the array has been materialised
# (`astype` returns an independent in-memory copy).
# NOTE(review): presumably one parameter matrix per logged step, i.e. shape
# (steps, num_vars, num_vars) -- confirm against the training code.
with np.load(BASE_PATH + "gamma_log_1_GraphDiscoveryMatrix.npz") as gamma_archive:
    gamma = gamma_archive["arr_0"].astype(np.float32)
with np.load(BASE_PATH + "theta_matrix_log_1_GraphDiscoveryMatrix.npz") as theta_archive:
    theta = theta_archive["arr_0"].astype(np.float32)

In [86]:
# Element-wise sigmoid(gamma) * sigmoid(theta), written as the reciprocal of
# the product of both sigmoid denominators.
gamma_denominator = 1 + np.exp(-gamma)
theta_denominator = 1 + np.exp(-theta)
probs = 1 / (gamma_denominator * theta_denominator)

In [87]:
# Build a fully connected directed graph (no self-loops) over the variable
# names; every ordered pair of distinct variables becomes one edge.
G = nx.DiGraph()
variable_names = [variable.name for variable in graph.variables]
G.add_nodes_from(variable_names)
for target_idx in range(graph.num_vars):
    for source_idx in range(graph.num_vars):
        if source_idx == target_idx:
            continue
        G.add_edge(graph.variables[source_idx].name, graph.variables[target_idx].name)

In [88]:
# Render every 20th logged step as a PNG frame. Each edge is drawn in black
# with its probability as the alpha (opacity) channel.
filenames = []

# Loop invariants, computed once instead of once per frame.
node_layout = nx.circular_layout(G)
# True everywhere except on the diagonal (self-loops are not drawn).
off_diagonal = ~np.eye(probs.shape[1], dtype=bool)
# Edges listed explicitly in the same row-major order in which the boolean
# mask flattens a probability matrix, so colour k always belongs to edge k
# instead of relying on the iteration order of G.edges().
# Assumes the node order of G matches the row/column order of `probs`.
node_names = list(G.nodes())
edge_list = [(node_names[row], node_names[col])
             for row, col in zip(*np.nonzero(off_diagonal))]

for step in range(0, probs.shape[0], 20):
    edge_probs = probs[step][off_diagonal]
    # RGBA rows: RGB stays 0 (black), alpha carries the edge probability.
    edge_colors = np.zeros((edge_probs.shape[0], 4), dtype=np.float32)
    edge_colors[:, 3] = edge_probs

    fig = plt.figure(figsize=(5, 5), dpi=100)
    nx.draw(G, pos=node_layout, edgelist=edge_list,
            arrows=True, with_labels=True, font_weight='bold', node_color='lightgrey',
            edge_color=edge_colors, edgecolors='black', node_size=600, arrowstyle='-|>', arrowsize=16)
    # Zero-padded step index keeps the frame files sortable by name.
    frame_path = vis_folder + f"figure_{step:05d}.png"
    fig.savefig(frame_path)
    # Close this figure explicitly so figures do not pile up in memory.
    plt.close(fig)
    filenames.append(frame_path)

In [89]:
# Assemble the rendered frames, in rendering order, into one animated GIF.
gif_path = vis_folder + 'learning_process.gif'
# mode='I' asks imageio for a multi-image writer.
with imageio.get_writer(gif_path, mode='I') as gif_writer:
    for frame_path in filenames:
        gif_writer.append_data(imageio.imread(frame_path))

In [90]:
# The PNG frames were only intermediates for the GIF; delete them again.
# set() collapses duplicate paths so no file is removed twice.
unique_frame_paths = set(filenames)
for frame_path in unique_frame_paths:
    os.remove(frame_path)