In [1]:
from pathlib import Path

import graphviz
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [2]:
def create_grey_to_black_colormap():
  """Build a linear colormap running from light grey to pure black.

  The gradient starts at RGB (0.75, 0.75, 0.75) for the lowest normalised
  value and ends at (0, 0, 0) for the highest, interpolating the red, green
  and blue channels identically.

  Returns
  -------
  matplotlib.colors.LinearSegmentedColormap
    A colormap named 'GreyToBlack'.
  """
  grey_level, black_level = 0.75, 0.0
  # All three channels share the same two anchors, so every colour is a neutral grey.
  segments = {
    channel: [(0.0, grey_level, grey_level), (1.0, black_level, black_level)]
    for channel in ('red', 'green', 'blue')
  }
  return mcolors.LinearSegmentedColormap('GreyToBlack', segments)

def create_yellow_colormap():
  """Build a linear colormap from pale cream (#FFF9E2) to saturated yellow (#FFD500).

  Returns
  -------
  matplotlib.colors.LinearSegmentedColormap
    A colormap named 'YellowGradient'.
  """
  start_rgb = mcolors.hex2color('#FFF9E2')
  end_rgb = mcolors.hex2color('#FFD500')

  # One two-anchor segment per channel: start colour at 0.0, end colour at 1.0.
  segments = {
    channel: [(0.0, start, start), (1.0, end, end)]
    for channel, start, end in zip(('red', 'green', 'blue'), start_rgb, end_rgb)
  }
  return mcolors.LinearSegmentedColormap('YellowGradient', segments)

def create_color_gradient(values, colormap):
  """Map numeric values to hex colour strings sampled from ``colormap``.

  Values are min-max scaled to [0, 1] before sampling. The smallest value
  gets the colormap's start colour and the largest gets its end colour. When
  all values are equal, matplotlib's Normalize maps every value to 0, which is
  the start colour.

  Parameters
  ----------
  values : array-like of numbers
    One-dimensional sequence (list, ndarray or pandas Series). NaN entries
    are ignored when finding the range. They themselves receive the
    colormap's "bad" colour.
  colormap : matplotlib.colors.Colormap
    Colormap to sample, e.g. the result of ``create_grey_to_black_colormap()``.

  Returns
  -------
  list of str
    ``'#rrggbb'`` strings, one per value and in the same order. Empty input
    gives an empty list.
  """
  values = np.asarray(values, dtype = float)
  if values.size == 0:
    # np.min/np.max raise on zero-size arrays; there is nothing to colour anyway.
    return []
  # nanmin/nanmax stop one missing value from turning the whole range into NaN.
  # A NaN range would give every entry the colormap's "bad" colour.
  norm = plt.Normalize(np.nanmin(values), np.nanmax(values))
  rgba_colors = colormap(norm(values))
  return [mcolors.to_hex(rgba) for rgba in rgba_colors]

In [3]:
# Build the colormaps once. grey_to_black_cmap shades edges/nodes in the plotting cells below.
# NOTE(review): light_to_dark_yellow_cmap is not referenced by any visible cell.
grey_to_black_cmap, light_to_dark_yellow_cmap = (
  create_grey_to_black_colormap(),
  create_yellow_colormap(),
)

In [17]:
def visualize_graph_before_training(filepath, head = None):
  """Render the weighted token graph stored in ``filepath`` as a PNG.

  The CSV holds one edge per row, with these columns:
  ``from_ids``/``from`` for the source id and token,
  ``to_id``/``to`` for the target id and token, and a numeric ``weight``.
  Each distinct (id, token) pair becomes a yellow node. Each row becomes a
  directed edge whose colour runs from light grey (lowest weight) to black
  (highest weight).

  Parameters
  ----------
  filepath : str
    Path to the CSV file. Its stem (file name without extension) names the
    output directory, e.g. ``./1126-Grouped``.
  head : int, optional
    When truthy, only the first ``head`` rows (edges) are drawn.

  Side effects
  ------------
  Writes the Graphviz source and its PNG rendering under ``./<stem>/`` and
  opens the result with ``view = True``.
  """
  # Default NA parsing is disabled and token columns are read as str.
  # Otherwise tokens such as "None", "null" or "NA" become NaN, and the
  # label's ``.replace`` below fails on a float.
  df = pd.read_csv(filepath, dtype = {'from': str, 'to': str}, keep_default_na = False)
  if head:
    df = df.head(head)
  name = Path(filepath).stem
  # With NA parsing off, a blank weight arrives as ''. Coerce it to NaN so the
  # gradient is still computed from the valid weights.
  weights = pd.to_numeric(df['weight'], errors = 'coerce')
  nodes = pd.concat([
    df[['from_ids', 'from']].rename(columns = {'from_ids' : 'id', 'from' : 'token'}),
    df[['to_id', 'to']].rename(columns = {'to_id' : 'id', 'to' : 'token'})],
    axis = 0
  ).drop_duplicates()
  edge_colors = create_color_gradient(weights, grey_to_black_cmap)
  graph = graphviz.Digraph('3-gram', engine = 'circo')  # other layout engines: 'twopi', 'dot'
  graph.graph_attr['dpi'] = '300'

  for node_id, token in zip(nodes['id'], nodes['token']):
    # 'Ġ' (presumably the BPE word-boundary marker) is stripped from the label only.
    graph.node(str(node_id), label = token.replace('Ġ', ''), color = '#363636', style = 'filled', fillcolor = '#ffd53d', shape = 'oval')

  # Pair each row with its colour by position, not by index label, so the
  # lookup stays correct even if the frame's index is not 0..n-1.
  for source_id, target_id, edge_color in zip(df['from_ids'], df['to_id'], edge_colors):
    graph.edge(str(source_id), str(target_id), arrowsize = '0.5', color = edge_color)
  graph.render(directory = f'./{name}', format = 'png', view = True)

In [22]:
def visualize_graph_after_training(filepath, head = None):
  """Render the edge-importance graph stored in ``filepath`` as a PNG.

  The CSV holds one edge per row, with these columns:
  ``from_ids``/``from_token`` for the source id and token,
  ``to_ids``/``to_token`` for the target id and token, and a numeric
  ``weights``. Each distinct (id, token) pair becomes a yellow node. Each row
  becomes a directed edge whose colour runs from light grey (lowest weight)
  to black (highest weight).

  Parameters
  ----------
  filepath : str
    Path to the CSV file. Its stem (file name without extension) names the
    output directory.
  head : int, optional
    When truthy, only the first ``head`` rows (edges) are drawn.

  Side effects
  ------------
  Writes the Graphviz source and its PNG rendering under ``./<stem>/`` and
  opens the result with ``view = True``.
  """
  # Default NA parsing is disabled and token columns are read as str.
  # Otherwise tokens such as "None", "null" or "NA" become NaN, and the
  # label's ``.replace`` below fails on a float.
  df = pd.read_csv(filepath, dtype = {'from_token': str, 'to_token': str}, keep_default_na = False)
  if head:
    df = df.head(head)
  name = Path(filepath).stem
  # With NA parsing off, a blank weight arrives as ''. Coerce it to NaN so the
  # gradient is still computed from the valid weights.
  weights = pd.to_numeric(df['weights'], errors = 'coerce')
  nodes = pd.concat([
    df[['from_ids', 'from_token']].rename(columns = {'from_ids' : 'id', 'from_token' : 'token'}),
    df[['to_ids', 'to_token']].rename(columns = {'to_ids' : 'id', 'to_token' : 'token'})],
    axis = 0
  ).drop_duplicates()
  edge_colors = create_color_gradient(weights, grey_to_black_cmap)
  graph = graphviz.Digraph('3-gram', engine = 'circo')  # other layout engines: 'twopi', 'dot'
  graph.graph_attr['dpi'] = '300'

  for node_id, token in zip(nodes['id'], nodes['token']):
    # 'Ġ' (presumably the BPE word-boundary marker) is stripped from the label only.
    graph.node(str(node_id), label = token.replace('Ġ', ''), color = '#363636', style = 'filled', fillcolor = '#ffd53d', shape = 'oval')

  # Pair each row with its colour by position, not by index label, so the
  # lookup stays correct even if the frame's index is not 0..n-1.
  for source_id, target_id, edge_color in zip(df['from_ids'], df['to_ids'], edge_colors):
    graph.edge(str(source_id), str(target_id), arrowsize = '0.5', color = edge_color)
  graph.render(directory = f'./{name}', format = 'png', view = True)

In [42]:
def visualize_graph_after_training_2(filepath, head = None):
  """Render per-node importance scores from ``filepath`` as a PNG (nodes only, no edges).

  The CSV holds one node per row, with columns ``ids``, ``tokens`` and a
  numeric ``aggregated_weights``. Node fill runs from light grey (lowest
  weight) to black (highest weight). The label text switches to white on dark
  fills so it stays legible.

  Parameters
  ----------
  filepath : str
    Path to the CSV file. Its stem (file name without extension) names the
    output directory.
  head : int, optional
    When truthy, only the first ``head`` rows (nodes) are drawn.

  Side effects
  ------------
  Writes the Graphviz source and its PNG rendering under ``./<stem>/`` and
  opens the result with ``view = True``.
  """
  # Default NA parsing is disabled and tokens are read as str.
  # Otherwise tokens such as "None", "null" or "NA" become NaN, and the
  # label's ``.replace`` below fails on a float.
  nodes = pd.read_csv(filepath, dtype = {'tokens': str}, keep_default_na = False)
  if head:
    nodes = nodes.head(head)
  name = Path(filepath).stem
  # With NA parsing off, a blank weight arrives as ''. Coerce it to NaN so the
  # gradient is still computed from the valid weights.
  weights = pd.to_numeric(nodes['aggregated_weights'], errors = 'coerce')
  node_colors = create_color_gradient(weights, grey_to_black_cmap)
  graph = graphviz.Digraph('3-gram', engine = 'circo')  # other layout engines: 'twopi', 'dot'
  graph.graph_attr['dpi'] = '300'

  # Colours are paired with rows by position, not by index label.
  for node_id, token, fill_color in zip(nodes['ids'], nodes['tokens'], node_colors):
    # Graphviz draws labels in black by default. That text is unreadable on the
    # near-black fill of the most important nodes, so use white text on dark fills.
    red, green, blue = mcolors.to_rgb(fill_color)
    luminance = 0.299 * red + 0.587 * green + 0.114 * blue
    font_color = 'white' if luminance < 0.5 else 'black'
    # 'Ġ' (presumably the BPE word-boundary marker) is stripped from the label only.
    graph.node(str(node_id), label = token.replace('Ġ', ''), color = '#363636', style = 'filled', fillcolor = fill_color, fontcolor = font_color, shape = 'oval')

  graph.render(directory = f'./{name}', format = 'png', view = True)

# Before Model Training (Method as an XAI Tool)

In [7]:
visualize_graph_before_training(filepath = './data/facebook-bart-large/SST-2/1126-Grouped.csv')

In [32]:
visualize_graph_before_training(filepath = './data/facebook-bart-large/SST-2/1126-Surrogate.csv')

In [35]:
# Only the first 50 edges are drawn.
visualize_graph_before_training(
  filepath = './data/facebook-bart-large/IMDb-top_1000/102-Surrogate.csv',
  head = 50,
)

# After Model Training (XAI Technique Applied to Method)

In [44]:
visualize_graph_after_training(filepath = './data/facebook-bart-large/SST-2/1126-Grouped-edge_importance.csv')

In [45]:
# Only the first 100 edges are drawn.
visualize_graph_after_training(
  filepath = './data/facebook-bart-large/IMDb-top_1000/102-Surrogate-edge_importance.csv',
  head = 100,
)

In [39]:
visualize_graph_after_training_2(filepath = './data/facebook-bart-large/SST-2/1126-Grouped-node_importance.csv')

In [43]:
# Only the first 100 nodes are drawn.
visualize_graph_after_training_2(
  filepath = './data/facebook-bart-large/IMDb-top_1000/102-Surrogate-node_importance.csv',
  head = 100,
)