In [None]:
import os
import re
import json
import numpy as np
import xarray as xr
import collections
import networkx as nx
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from ipywidgets import interact
%matplotlib inline

In [None]:
# Choose chemical mechanism (model_name) and experiment (exp_name);
# together they select the equation JSON and the processed .nc files read below.
model_name, exp_name = 'MCM_C1_C2_ss', 'init'

In [None]:
# Load reaction descriptions and processed model output.
# All inputs live under the parent directory of the current working directory.
base_dir = os.path.dirname(os.getcwd())
eqs_json_path = os.path.join(base_dir, 'MCM', model_name, 'data', model_name, 'eqnjson')
with open(os.path.join(eqs_json_path, model_name+'.json'), 'r', encoding='utf-8') as f:
    all_mcm_eqs = json.load(f)
# Read time series (ts) and reaction fluxes (fl). Data are loaded into memory inside the
# `with` blocks so the files are closed right away. Negative values are then clipped to 0
# on a new Dataset, instead of writing into `.values`, which relies on xarray caching internals.
# NaNs are left untouched.
# TODO(review): the negatives are presumably numerical undershoot of the solver — confirm.
with xr.open_dataset(os.path.join(base_dir, 'processed', model_name+'_'+exp_name+'.nc')) as ds:
    ts = ds.load().clip(min=0.)
with xr.open_dataset(os.path.join(base_dir, 'processed', 'flux', 'flux_'+model_name+'_'+exp_name+'.nc')) as ds:
    fl = ds.load().clip(min=0.)

In [None]:
# Convert equation's info to nodes, edges and edge labels needed for networkx.
# Every species of every equation becomes a node. A directed edge reactant -> product is
# created only when the reaction's flux variable exists in `fl`.
nodes = ['N2'] # nodes, in first-seen order
links = [] # directed edges (reactant, product); repeats are kept for the MultiDiGraph
link_labels = collections.OrderedDict() # edge -> label: 'hv', '' or the bimolecular co-reactant
flux_labels = collections.OrderedDict() # flux variable name in `fl` -> list of edges it contributes
# Species preferred as the edge label of a bimolecular reaction (they still get nodes and edges)
major_reactants = ['CL', 'H2', 'HO2', 'NO', 'NO2', 'NO3', 'OH', 'SO2', 'SO3']
node_set = set(nodes) # O(1) membership test mirroring `nodes`
for eq in all_mcm_eqs:
    reacs, prods = eq['reac'], eq['prod']
    if len(reacs) == 1:
        # Unimolecular: treated as photolysis when the rate coefficient
        # (presumably a rate-expression string) contains 'J'
        is_photolysis = 'J' in eq['coef']
        flux_label = reacs[0]+('+hv' if is_photolysis else '')+'='+'+'.join(prods)
        label = 'hv' if is_photolysis else ''
    elif len(reacs) == 2:
        # Bimolecular: the label is the major co-reactant when exactly one reactant is major,
        # otherwise the second reactant as listed
        reac1, reac2 = reacs
        if reac1 in major_reactants and reac2 not in major_reactants:
            reac1, reac2 = reac2, reac1
        flux_label = '+'.join(reacs)+'='+'+'.join(prods)
        label = reac2
    else:
        raise ValueError('Expected 1 or 2 reactants, got {}: {}'.format(len(reacs), eq))
    if flux_label in fl.data_vars:
        # One edge per (reactant, product) pair, in reactant-major order
        for reac in reacs:
            for prod in prods:
                link = (reac, prod)
                links.append(link)
                link_labels[link] = label
                flux_labels.setdefault(flux_label, []).append(link)
    for species in [*reacs, *prods]:
        if species not in node_set:
            node_set.add(species)
            nodes.append(species)

In [None]:
# Find min and max number densities and fluxes over all variables; they set the range of
# the log-spaced bins below. nanmin/nanmax ignore missing samples, so a NaN cannot
# turn a bound into NaN (Python's min/max over a list containing NaN is order-dependent).
ts_mins = [np.nanmin(ts[var].values) for var in ts.data_vars]
ts_maxs = [np.nanmax(ts[var].values) for var in ts.data_vars]
fl_mins = [np.nanmin(fl[var].values) for var in fl.data_vars]
fl_maxs = [np.nanmax(fl[var].values) for var in fl.data_vars]
ts_min = np.nanmin(ts_mins)
ts_max = np.nanmax(ts_maxs)
fl_min = np.nanmin(fl_mins)
fl_max = np.nanmax(fl_maxs)

In [None]:
# Put sampled number densities into bins in order to show them by varying node size.
# Bounds are 10**e - 1 with e log-spaced from log10(ts_min + 1) to round(log10(ts_max) + 1).
# The lowest bound therefore equals ts_min (0 once negatives are clipped).
n_nden_bins = 10
node_size_nden_bnds = np.logspace(np.log10(ts_min+1), round(np.log10(ts_max)+1), n_nden_bins+1, endpoint=True)-1 # 11 bounds
node_size_nden_bins = np.array([*zip(node_size_nden_bnds[:-1], node_size_nden_bnds[1:])]) # 10 bins
node_size_plot_bins = collections.OrderedDict()
for var in ts.data_vars:
    values = ts[var].values
    # digitize gives i + 1 for bnds[i] <= v < bnds[i + 1], hence the -1. Values outside the
    # bounds go to the nearest end bin instead of an uninitialized index; NaN goes to bin 0.
    idx = np.clip(np.digitize(values, node_size_nden_bnds)-1, 0, n_nden_bins-1).astype(int)
    idx[np.isnan(values)] = 0
    node_size_plot_bins[var] = idx

In [None]:
# Put sampled fluxes into bins in order to show them by varying edge colour/width.
# Bounds are 10**e - 1 with e log-spaced from log10(fl_min + 1) to round(log10(fl_max) + 1).
# The lowest bound therefore equals fl_min (0 once negatives are clipped).
n_flux_bins = 10
edge_wdth_flux_bnds = np.logspace(np.log10(fl_min+1), round(np.log10(fl_max)+1), n_flux_bins+1, endpoint=True)-1 # 11 bounds
edge_wdth_flux_bins = np.array([*zip(edge_wdth_flux_bnds[:-1], edge_wdth_flux_bnds[1:])]) # 10 bins
edge_wdth_plot_bins = collections.OrderedDict()
for var in fl.data_vars:
    values = fl[var].values
    # digitize gives i + 1 for bnds[i] <= v < bnds[i + 1], hence the -1. Values outside the
    # bounds go to the nearest end bin instead of an uninitialized index; NaN goes to bin 0.
    idx = np.clip(np.digitize(values, edge_wdth_flux_bnds)-1, 0, n_flux_bins-1).astype(int)
    idx[np.isnan(values)] = 0
    edge_wdth_plot_bins[var] = idx

In [None]:
# Build the reaction network as a directed multigraph: one edge per entry of `links`
# (parallel edges kept), plus every species in `nodes`, including those without edges.
# Node positions come from Graphviz neato.
# Don't rerun this cell if you want to preserve the positions of nodes and edges.
scheme = nx.MultiDiGraph(links)
scheme.add_nodes_from(nodes)
pos = nx.nx_pydot.graphviz_layout(scheme, prog='neato')

In [None]:
# draw_networkx_edges colours edges in the order of scheme.edges(), so the per-edge flux
# info must follow that order. scheme.edges() is a view without .index on networkx >= 2,
# so first-occurrence positions are precomputed (parallel edges share one position).
edge_order = {}
for edge_position, edge in enumerate(scheme.edges()):
    edge_order.setdefault(edge, edge_position)
# Sort links in the order that network creates edges
links_sorted_as_edges = sorted(links, key=edge_order.__getitem__)
# Combine info about link name with flux intensity through the reaction where this link comes from
link_flux_info = [
    (name, link, arr)
    for name, arr in edge_wdth_plot_bins.items()
    for link in flux_labels.get(name, [])
]
# Sort this info in the order that network creates edges (stable, so ties keep their order)
link_flux_info_sorted_as_edges = sorted(link_flux_info, key=lambda info: edge_order[info[1]])
# Each edge of the multigraph needs exactly one colour entry
assert len(link_flux_info_sorted_as_edges) == scheme.number_of_edges()

In [None]:
#---Draw a single network-------------------------------------------------------------------------------------------------------------#
# Choose model run (x, y indices) and hour (time index)
x, y, t = 10, 10, 12
# Node size: scaled number-density bin of each species at (x, y, t); bin 0 gives an invisible node
node_size_scaling_factor = 500
node_sizes = [node_size_scaling_factor*node_size_plot_bins[key][x,y,t] for key in scheme.nodes()]
# Edge colour: flux bin of the reaction each edge comes from, in networkx edge order
edge_colors = [info[2][x,y,t] for info in link_flux_info_sorted_as_edges]
# Draw network. edge_vmin/edge_vmax pin the colour scale to the full bin range.
# Otherwise networkx rescales to this frame's min/max, and a colour would not map to a fixed flux bin.
fig, ax = plt.subplots(figsize=(15,15), facecolor='white')
nx.draw_networkx_edges(scheme, pos, ax=ax, width=2, edge_color=edge_colors, edge_cmap=plt.cm.Blues,
                       edge_vmin=0, edge_vmax=len(edge_wdth_flux_bins)-1, arrows=False)
nx.draw_networkx_nodes(scheme, pos, ax=ax, node_size=node_sizes, node_color='grey', alpha=0.3)
nx.draw_networkx_labels(scheme, pos, ax=ax, font_size=9)
# TODO: edge labels (link_labels) are not corrected yet:
# nx.draw_networkx_edge_labels(scheme, pos, edge_labels=link_labels, font_size=7)
ax.axis('off')
ax.set_xlim(-200, 200)
ax.set_ylim(-200, 200);

In [None]:
#---Draw an interactive network-------------------------------------------------------------------------------------------------------#
# Slider maxima (inclusive) are derived from the data so they can never index past the
# binned arrays, which are indexed as [x, y, t] for both ts (nodes) and fl (edges).
@interact(x=(0, min(len(ts.xrun), len(fl.xrun))-1, 1),
          y=(0, min(len(ts.yrun), len(fl.yrun))-1, 1),
          t=(0, min(len(ts.time), len(fl.time))-1, 1))
def view_exp(x=0, y=0, t=0):
    """Draw the reaction network for model run (x, y) at time index t."""
    # Node size: scaled number-density bin of each species; bin 0 gives an invisible node
    node_size_scaling_factor = 500
    node_sizes = [node_size_scaling_factor*node_size_plot_bins[key][x,y,t] for key in scheme.nodes()]
    # Edge colour: flux bin of the reaction each edge comes from, in networkx edge order
    edge_colors = [info[2][x,y,t] for info in link_flux_info_sorted_as_edges]
    # Plot network. A fixed edge_vmin/edge_vmax keeps colours comparable while moving the sliders.
    fig, ax = plt.subplots(figsize=(15,15), facecolor='white')
    nx.draw_networkx_edges(scheme, pos, ax=ax, width=2, edge_color=edge_colors, edge_cmap=plt.cm.Blues,
                           edge_vmin=0, edge_vmax=len(edge_wdth_flux_bins)-1, arrows=False)
    nx.draw_networkx_nodes(scheme, pos, ax=ax, node_size=node_sizes, node_color='grey', alpha=0.3)
    nx.draw_networkx_labels(scheme, pos, ax=ax, font_size=9)
    # TODO: edge labels (link_labels) are not corrected yet:
    # nx.draw_networkx_edge_labels(scheme, pos, edge_labels=link_labels, font_size=7)
    ax.axis('off')
    ax.set_xlim(-300, 300)
    ax.set_ylim(-300, 300)
    ax.set_title(str(t), fontsize=20)
    plt.show()

In [None]:
#---Draw and save network-------------------------------------------------------------------------------------------------------------#
# Output directory; set the PLOT_NETWORK_DIR environment variable to override the default
pics_path = os.environ.get('PLOT_NETWORK_DIR', '/local/mwe14avu/UEA/PhD/results/plot_network/')
os.makedirs(pics_path, exist_ok=True)
xruns = [0] # e.g. [0, 10]
yruns = [0] # e.g. [0, 10]
times = np.arange(1, 2)
node_size_scaling_factor = 500
for x in xruns:
    for y in yruns:
        for t in times:
            # Node size: scaled number-density bin of each species at (x, y, t)
            node_sizes = [node_size_scaling_factor*node_size_plot_bins[key][x,y,t] for key in scheme.nodes()]
            # Edge colour: flux bin of the reaction each edge comes from, in networkx edge order
            edge_colors = [info[2][x,y,t] for info in link_flux_info_sorted_as_edges]
            # Plot network. A fixed edge_vmin/edge_vmax keeps colours comparable between saved frames.
            fig, ax = plt.subplots(figsize=(7,7), facecolor='white')
            nx.draw_networkx_edges(scheme, pos, ax=ax, width=2, edge_color=edge_colors, edge_cmap=plt.cm.Blues,
                                   edge_vmin=0, edge_vmax=len(edge_wdth_flux_bins)-1, arrows=False)
            nx.draw_networkx_nodes(scheme, pos, ax=ax, node_size=node_sizes, node_color='grey', alpha=0.3)
            nx.draw_networkx_labels(scheme, pos, ax=ax, font_size=9)
            # TODO: edge labels (link_labels) are not corrected yet:
            # nx.draw_networkx_edge_labels(scheme, pos, edge_labels=link_labels, font_size=7)
            ax.axis('off')
            ax.set_xlim(-200, 200)
            ax.set_ylim(-200, 200)
            ax.set_title('x '+str(x)+'; y '+str(y)+'; hr '+str(t), fontsize=20)
            # Separators keep file names unique: without them (t=1, x=10, y=0) and
            # (t=11, x=0, y=0) both produced '_1100.svg'
            fig.savefig(os.path.join(pics_path, '_{}_{}_{}.svg'.format(t, x, y)), format='svg', dpi=100, bbox_inches='tight')
            plt.close(fig)