In [47]:
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, classification_report
import pandas as pd
from sklearn import datasets


In [48]:
# MNIST arrives as ((train_images, train_labels), (test_images, test_labels)).
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
X_train, y_train = mnist_train
X_test, y_test = mnist_test


In [49]:
# Scale raw unsigned-integer pixel intensities [0, 255] to float32 in [0, 1].
# The dtype guard keeps this cell idempotent: without it, re-running the cell
# would divide the already-scaled data by 255 a second time. float32 halves
# memory versus the float64 that plain `/ 255.0` produces.
if X_train.dtype.kind == 'u':
    X_train = X_train.astype('float32') / 255.0
    X_test = X_test.astype('float32') / 255.0

# Flatten each image into a single 784-element feature vector for the Dense
# layers (reshape is a no-op on data that is already flat).
X_train = X_train.reshape((-1, 784))
X_test = X_test.reshape((-1, 784))

In [None]:
# Training configuration.
SEED = 42
EPOCHS = 3
# Keras' default mini-batch size. batch_size=1 performs one optimizer step per
# training example, which makes every epoch extremely slow.
BATCH_SIZE = 32

# Seed TensorFlow's global RNG so weight initialisation is repeatable.
tf.random.set_seed(SEED)

model = tf.keras.Sequential([
    # Inputs are the flattened 784-pixel vectors from the preprocessing cell.
    # The shape must be a tuple: Keras 3 rejects a bare int.
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(20, activation='relu'),
    tf.keras.layers.Dense(20, activation='relu'),
    # 10 softmax units, one probability per class; labels are integer class
    # ids, hence the sparse categorical cross-entropy below.
    tf.keras.layers.Dense(10, activation='softmax'),
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(X_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=1)

loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f'Test accuracy: {accuracy:.4f}')


Epoch 1/3

In [None]:
import networkx as nx

weights = model.weights

# Only the 2-D kernels carry node-to-node connections. Select them by rank
# instead of assuming model.weights alternates [kernel, bias, kernel, bias, ...].
kernels = [w for w in weights if len(w.shape) == 2]

# Tallest column (+1). Both kernel dimensions are checked so the output
# column's width is included; every column is spread over this same height.
max_rows = max((max(int(k.shape[0]), int(k.shape[1])) for k in kernels), default=0) + 1

# Directed graph: one node per neuron, named 'l_<row>_<column>', and one edge
# per kernel entry, carrying that entry (a scalar tensor) as its 'weight'.
G = nx.DiGraph()
pos = {}  # node name -> (column, y) drawing coordinates

for col, kernel in enumerate(kernels):
  n_src, n_dst = int(kernel.shape[0]), int(kernel.shape[1])
  # Vertical spacing between nodes of each column; constant within the column.
  src_step = max_rows // (n_src + 1)
  dst_step = max_rows // (n_dst + 1)

  for src_row in range(n_src):
    source_node = f'l_{src_row}_{col}'
    pos[source_node] = (col, src_step * (src_row + 1) * -5)

    for dst_row in range(n_dst):
      dest_node = f'l_{dst_row}_{col + 1}'
      G.add_edge(source_node, dest_node, weight=kernel[src_row, dst_row])
      pos[dest_node] = (col + 1, dst_step * (dst_row + 1) * -5)

In [None]:
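# NOTE(review): disabled experiment. `iris` is not defined in any visible cell and the
# model here classifies MNIST digits, so this relabelling cannot be re-enabled as written.
# mapping = {}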
# target_names = iris["target_names"]
# target_index = 0

# for node in G.nodes():
#     if node.endswith('_3'):
#         mapping[node] = node.replace(node, f'{target_names[target_index]}')
#         target_index += 1  # Move to the next target_name
# # Perform the renaming
# G = nx.relabel_nodes(G, mapping)

# print(G.nodes)

In [None]:
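# NOTE(review): disabled matplotlib draft. If revived: math.log(0) raises for a zero
# weight, |weight| < 1 yields a negative log (i.e. a negative edge width), and
# `value` / `alpha` are computed but never passed to the draw calls.
# import matplotlib.pyplot as plt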
# import networkx as nx
# import math

# edge_colors = []
# alpha = []
# width = []
# for u, v, d in G.edges(data=True):
#     # print("Source is", u, "destination is", v, "and weight is", d["weight"])
#     if d['weight'] > 0:
#         edge_colors.append('black')
#         value = (d['weight'] **(3/2))
#     else:
#         edge_colors.append('red')
#         value = (d['weight'] **(3/2))
#         # print("stuff happened")
#     alpha.append( d['weight'])

#     cubed_value = d['weight'] ** 3
#     if cubed_value > 0:
#         log_of_cubed_value = math.log(cubed_value)  # natural logarithm (base e)
#     else:
#         cubed_value *= -1
#         log_of_cubed_value = math.log(cubed_value)
#     width.append(log_of_cubed_value)

# plt.figure(figsize=(35, 10))

# nx.draw(G, pos=pos, with_labels=True, node_size=300, node_color='skyblue', font_size=10, font_weight='bold', arrows=True)

# nx.draw_networkx_edges(G, pos=pos, edge_color=edge_colors, arrows=True, width = width)


# edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in G.edges(data=True)}
# nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color='black', font_size=8)

# plt.title('Neural Network Visualization')

# plt.axis('on')
# plt.margins(0.1)

# plt.show()

In [None]:
import plotly.graph_objects as go
import networkx as nx
import random
import math

# Build one Sankey link per graph edge, in a stable (source, target) order.
# Edges leaving the input column (nodes named 'l_<row>_0') are skipped,
# presumably because that 784-pixel column is far wider than the others.
edges = sorted(G.edges(data=True), key=lambda edge: (edge[0], edge[1]))

source = []
target = []
weights_list = []  # link thickness: |weight| as a plain Python float
edge_colors = []   # black for positive weights, red for zero or negative ones
for u, v, d in edges:
  if u.endswith("_0"):
    continue
  # float() turns the TensorFlow scalar into a plain number plotly can serialize.
  weight = float(d["weight"])
  source.append(u)
  target.append(v)
  weights_list.append(abs(weight))
  edge_colors.append("black" if weight > 0 else "red")

# Index the nodes once so link indices and node labels share one ordering.
node_labels = sorted(G.nodes())
node_to_index = {node: i for i, node in enumerate(node_labels)}
# Sanity check: the link arrays must line up one-to-one.
print(len(source), len(target), len(weights_list))
source_index = [node_to_index[node] for node in source]
target_index = [node_to_index[node] for node in target]

link_properties = {
    'source': source_index,
    'target': target_index,
    'value': weights_list,
    'color': edge_colors
}

node_properties = {
    'pad': 15,
    'thickness': 20,
    'line': dict(color="black", width=0.5),
    'label': node_labels
}

sankey_diagram = go.Sankey(node=node_properties, link=link_properties)

fig = go.Figure(data=[sankey_diagram])
fig.update_layout(title_text="Sankey Diagram for Neural Networks", font_size=10)
fig.show()
