# Imports

In [1]:
from platform import python_version
print(python_version())

3.7.12


In [2]:
import tensorflow as tf
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

# gpus = tf.config.experimental.list_physical_devices('GPU')
# for gpu in gpus:
#     tf.config.experimental.set_memory_growth(gpu, True)
    
# print(tf.config.list_physical_devices())

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 16116794567552834832
xla_global_id: -1
]


In [3]:
print(tf.__version__)

2.7.0


In [4]:
import os, sys
import pickle

import pandas as pd
# import modin.pandas as pd

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, concatenate
from tensorflow.keras.layers import Conv1D, MaxPool1D, Flatten
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.utils import plot_model

from sklearn import model_selection
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
from sklearn.metrics import precision_score, recall_score

from IPython.display import display

# import pkg_resources
# from packaging.version import Version
# if Version(pkg_resources.get_distribution('tables').version) < Version('3.6.1'):
!pip install --upgrade tables

try:
    import stellargraph as sg
except ImportError as e:
    !pip install stellargraph
    import stellargraph as sg
    
try:
    sg.utils.validate_notebook_version("1.2.1")
except AttributeError:
    raise ValueError(
        f"This notebook requires StellarGraph version 1.2.1, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>."
    ) from None

from stellargraph import StellarGraph
from stellargraph import IndexedArray
from stellargraph.mapper import PaddedGraphGenerator, FullBatchNodeGenerator
from stellargraph.layer import GCNSupervisedGraphClassification, DeepGraphCNN, GAT
import networkx as nx
from networkx import draw_networkx

Collecting tables
  Downloading tables-3.6.1-cp37-cp37m-manylinux1_x86_64.whl (4.3 MB)
[K     |████████████████████████████████| 4.3 MB 5.4 MB/s 
Installing collected packages: tables
  Attempting uninstall: tables
    Found existing installation: tables 3.4.4
    Uninstalling tables-3.4.4:
      Successfully uninstalled tables-3.4.4
Successfully installed tables-3.6.1
Collecting stellargraph
  Downloading stellargraph-1.2.1-py3-none-any.whl (435 kB)
[K     |████████████████████████████████| 435 kB 5.1 MB/s 
Installing collected packages: stellargraph
Successfully installed stellargraph-1.2.1


# Constants

In [5]:
# We ignore relationships with no connections
# NOTE(review): min_connections is not referenced in the visible part of this
# notebook -- confirm where (or whether) the threshold is actually applied.
min_connections = 1

# Storage Paths
# On Google Colab the user's Drive is mounted and the project folder inside it is
# used; otherwise a machine-specific local Google Drive folder is used.
# NOTE(review): the local path below is hard-coded to one workstation.
if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount('/content/drive')
    storage_path = "/content/drive/MyDrive/ms_proj/"
else:
    storage_path = "/Users/ivan/Google Drive/ms_proj/"

# Path of the `graph_model_stats.pickle` file inside the storage folder.
graph_model_stats_file = os.path.join(storage_path, 'graph_model_stats.pickle')

def defineGlobalConstants(num_friends_local = 3, num_tweets_local = 10):
    """Publish the data-file locations for one (friends, tweets) setting as globals.

    All paths are rooted at the module-level ``storage_path`` and are nested under
    ``<num_friends_local>/<num_tweets_local>``. Nothing is returned; the function
    only (re)assigns the module-level names declared ``global`` below.

    Args:
        num_friends_local: Friend-count setting; used as a directory name.
        num_tweets_local: Tweet-count setting; used as a directory name.
    """
    global hdf_stats_reply_path, hdf_stats_mention_path, hdf_stats_quote_tweets_path
    global hdf_stats_key
    global combined_users_key, combined_train_users_path, combined_test_users_path
    global hdf_encoded_tweets_path, hdf_encoded_tweets_key

    friends_dir = str(num_friends_local)
    tweets_dir = str(num_tweets_local)

    # <storage>/tweets_dfs/<friends>/<tweets>: interaction stats and encoded tweets
    tweets_root = os.path.join(storage_path, "tweets_dfs", friends_dir, tweets_dir)
    # <storage>/combined_train_test_users/<friends>/<tweets>: saved train/test split
    split_root = os.path.join(storage_path, "combined_train_test_users", friends_dir, tweets_dir)

    # Interaction statistics, one HDF5 file per relation, all read with the same key
    hdf_stats_key = "stats"
    hdf_stats_reply_path = os.path.join(tweets_root, "stats_reply.h5")
    hdf_stats_mention_path = os.path.join(tweets_root, "stats_mention.h5")
    hdf_stats_quote_tweets_path = os.path.join(tweets_root, "stats_quote_tweets.h5")

    # Train/Test Users
    combined_users_key = "combined_users"
    combined_train_users_path = os.path.join(split_root, "train_users.h5")
    combined_test_users_path = os.path.join(split_root, "test_users.h5")

    # Encoded tweets
    hdf_encoded_tweets_key = "tweets"
    hdf_encoded_tweets_path = os.path.join(tweets_root, "encoded")

Mounted at /content/drive


# Helper Functions

In [6]:
# Combine multiple dataframes
def combineDfs(dfs: list, random = True):
    """Stack several dataframes into one with a fresh 0..n-1 index.

    Args:
        dfs: Dataframes to concatenate, in order.
        random: When truthy, return the rows in shuffled order
            (``sample(frac=1)``); otherwise keep the concatenation order.

    Returns:
        The concatenated (and optionally shuffled) dataframe.
    """
    stacked = pd.concat(dfs, ignore_index=True)
    return stacked.sample(frac = 1) if random else stacked
    
# def normalizeDict(d, normalized_low=0, normalized_high=1):
#     data_low = min(d.values())
#     data_high = max(d.values())
    
#     if data_low == data_high:
#         return d
    
#     return {key:((value - data_low) / (data_high - data_low)) * (normalized_high - normalized_low) + normalized_low for key,value in d.items()}

def normalizeDict(d, target=1.0):
    """Rescale the values of ``d`` so that they sum to ``target``.

    Args:
        d: Mapping of key -> numeric value.
        target: Desired sum of the returned values (default 1.0).

    Returns:
        A new dict with the same keys and each value multiplied by
        ``target / sum(d.values())``. When the values sum to zero (including
        the empty dict) no scaling factor exists, so an unscaled copy of ``d``
        is returned instead of raising ``ZeroDivisionError``.
    """
    raw = sum(d.values())
    if raw == 0:
        # No scaling factor exists: hand back a copy so the caller still gets
        # a fresh dict, as in the normal path.
        return dict(d)
    factor = target/raw
    return {key:value*factor for key,value in d.items()}
       
def loadCombinedUsers():
    """Load the saved train/test user dataframes from their HDF5 files.

    Reads the module-level ``combined_train_users_path``,
    ``combined_test_users_path`` and ``combined_users_key``.

    Returns:
        ``(x_train_users, x_test_users)`` when both files exist. Otherwise a
        message is printed and ``None`` is returned.
    """
    both_saved = os.path.isfile(combined_train_users_path) and os.path.isfile(combined_test_users_path)
    if not both_saved:
        print("No train/test users found in path: {}. Run the main notebook".format(combined_train_users_path))
        return None

    print("Found saved train/test users in path: {}. Loading them...".format(combined_train_users_path))
    train_users = pd.read_hdf(combined_train_users_path, key=combined_users_key)
    test_users = pd.read_hdf(combined_test_users_path, key=combined_users_key)
    return train_users, test_users
    
def genEncodedTweets(userOrFriend, lstm_layers = 1, word_embeddings = 100, user_embeddings = 200):
    """Load a pickled dict of encoded tweets for the given configuration.

    The file is looked up under the module-level ``hdf_encoded_tweets_path`` as
    ``<userOrFriend>_<lstm_layers>_<word_embeddings>_<user_embeddings>.pickle``.

    Args:
        userOrFriend: File-name prefix selecting which pickle to load.
        lstm_layers: Second component of the file name.
        word_embeddings: Third component of the file name.
        user_embeddings: Fourth component of the file name.

    Returns:
        The unpickled object when the file exists. Otherwise a message is
        printed and ``None`` is returned.
    """
    file_name = "{}_{}_{}_{}.pickle".format(userOrFriend, lstm_layers, word_embeddings, user_embeddings)
    saved_path = os.path.join(hdf_encoded_tweets_path, file_name)

    if not os.path.isfile(saved_path):
        print("No encoded tweets found in path: {}. Run the main notebook".format(saved_path))
        return None

    print("Found saved {} encoded tweets in path: {}. Loading them...".format(userOrFriend, saved_path))
    # SECURITY: pickle can execute arbitrary code while loading; only point this
    # at files produced by the project's own notebooks.
    # The context manager closes the handle even if unpickling fails.
    with open(saved_path, "rb") as pickle_file:
        return pickle.load(pickle_file)

# Plot an ROC. pred - the predictions, y - the expected output.
def plot_roc(pred,y):
    """Draw the ROC curve for a set of predictions and return its area.

    Args:
        pred: Predicted scores.
        y: Expected (true) labels.

    Returns:
        The area under the ROC curve computed from ``roc_curve`` / ``auc``.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y, pred)
    roc_auc = auc(false_pos_rate, true_pos_rate)
    curve_label = 'ROC curve (area = %0.2f)' % roc_auc

    plt.figure()
    plt.plot(false_pos_rate, true_pos_rate, label=curve_label)
    # Dashed diagonal drawn as a reference line
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC)')
    plt.legend(loc="lower right")
    plt.show()

    return roc_auc

# Plot a confusion matrix.
# cm is the confusion matrix, names are the names of the classes.
def plot_confusion_matrix(cm, names, title='Confusion matrix', cmap=plt.cm.Blues):
    """Render a confusion matrix as a colour-mapped image on the current figure.

    Args:
        cm: The confusion matrix to show.
        names: Class names, used as tick labels on both axes.
        title: Plot title.
        cmap: Matplotlib colormap for the image.
    """
    tick_positions = np.arange(len(names))

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    plt.xticks(tick_positions, names, rotation=45)
    plt.yticks(tick_positions, names)
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

def predToBinPred(pred, threshold = 0.5):
    """Turn scores into 0/1 labels: 1 when strictly above ``threshold``, else 0.

    Args:
        pred: Iterable of scores.
        threshold: Cut-off; a score equal to it maps to 0.

    Returns:
        A list of ints, one per element of ``pred``.
    """
    binary_labels = []
    for score in pred:
        binary_labels.append(1 if score > threshold else 0)
    return binary_labels

# Function to evaluate
def evaluate_preds(true, pred):
    """Print and return ROC AUC, PR AUC and F1 for a set of predictions.

    F1 and the printed confusion matrix use labels from ``predToBinPred`` at
    its default threshold. The confusion matrix is normalised over the true
    labels (``normalize='true'``).

    Args:
        true: Ground-truth labels.
        pred: Predicted scores.

    Returns:
        ``(roc_auc, pr_auc, f1)``.
    """
    roc_auc_value = roc_auc_score(true, pred)
    pr_auc_value = average_precision_score(true, pred)
    hard_labels = predToBinPred(pred)
    f1_value = f1_score(true, hard_labels)

    print('ROC AUC:', roc_auc_value)
    print('PR AUC:', pr_auc_value)
    print('F1 score:', f1_value)
    print(confusion_matrix(true, hard_labels, normalize='true'))

    return roc_auc_value, pr_auc_value, f1_value

# Helper Graph Functions

In [7]:
def genNodesForGraph(user_id, graph_type):
    """Collect a user's per-friend interaction weights for one relation type.

    Looks the user up in the module-level stats dataframe matching
    ``graph_type`` and reads the list stored in that row's ``graph_type``
    column. Each list entry supplies ``friend_id`` and a ``graph_type`` value.

    Args:
        user_id: User whose row is read.
        graph_type: ``'replies'``, ``'mentions'`` or ``'quote_tweets'``.

    Returns:
        ``({user_id: weights}, {user_id: normalized_weights})`` where the
        second mapping is ``normalizeDict(weights)``. For an unknown
        ``graph_type`` a message is printed and ``None`` is returned.
    """
    if graph_type == 'replies':
        stats_df = df_replies_stats
    elif graph_type == 'mentions':
        stats_df = df_mentions_stats
    elif graph_type == 'quote_tweets':
        stats_df = df_quote_tweets_stats
    else:
        print('Invalid graphType specified: ' + graph_type)
        return

    user_rows = stats_df.loc[stats_df['user_id'] == user_id]
    interactions = user_rows[graph_type].values[0]

    # friend_id -> raw interaction value for this relation
    friend_weights = {entry.get('friend_id'): entry.get(graph_type) for entry in interactions}

    return {user_id: friend_weights}, {user_id: normalizeDict(friend_weights)}

def genUserGraph(user_id, graph_type):
    """Build the StellarGraph of one user and their friends for one relation.

    Edges run from the user to each friend, weighted with the normalized
    values from ``genNodesForGraph``. Node features come from the module-level
    ``all_encoded_user_tweets`` (user) and ``all_encoded_friend_tweets``
    (friends).

    Args:
        user_id: The central user.
        graph_type: ``'replies'``, ``'mentions'`` or ``'quote_tweets'``; also
            used as the default edge type.

    Returns:
        A ``StellarGraph``. For an unknown ``graph_type`` a message is printed
        and ``None`` is returned.
    """
    if graph_type not in ('replies', 'mentions', 'quote_tweets'):
        print('Invalid graphType specified: ' + graph_type)
        return

    normalized_nodes = genNodesForGraph(user_id, graph_type)[1]

    ego_graph = nx.Graph()
    node_roles = {}
    for friend_id, weight in normalized_nodes[user_id].items():
        # remember each endpoint's role, then connect them
        node_roles[user_id] = 'user'
        node_roles[friend_id] = 'friend'
        ego_graph.add_edge(user_id, friend_id, weight=weight)

    nx.set_node_attributes(ego_graph, node_roles, "label")

    # Attach the encoded-tweet feature vector to every node
    for node_id, attrs in ego_graph.nodes(data=True):
        role = attrs.get('label')
        if role == 'user':
            node_features = all_encoded_user_tweets[node_id]
            attrs['type'] = 'user'
        elif role == 'friend':
            node_features = all_encoded_friend_tweets.get(node_id)
            # Friends are relabelled as 'user' because GCN and GAT only work
            # with a single node label; the real role is kept under 'type'.
            attrs['label'] = 'user'
            attrs['type'] = 'friend'

        attrs["feature"] = node_features

    return StellarGraph.from_networkx(
        ego_graph,
        node_features="feature",
        node_type_default="user",
        edge_type_default=graph_type)


def genMultiGraph(user_id):
    """Build one StellarGraph combining a user's replies, mentions and quote tweets.

    A ``networkx.MultiGraph`` gets one user-friend edge per relation. Each edge
    carries the normalized weight from ``genNodesForGraph``, a ``label`` naming
    the relation and a ``color`` ("r", "g" or "b"). Node features come from the
    module-level ``all_encoded_user_tweets`` and ``all_encoded_friend_tweets``.

    Args:
        user_id: The central user.

    Returns:
        A ``StellarGraph`` built from the multigraph.
    """
    # (relation name, edge colour), processed in this order
    relation_specs = (('replies', 'r'), ('mentions', 'g'), ('quote_tweets', 'b'))

    # Look up every relation before any graph is built
    normalized_by_relation = [
        (relation, color, genNodesForGraph(user_id, relation)[1])
        for relation, color in relation_specs
    ]

    multigraph = nx.MultiGraph()

    role_maps = []
    for relation, color, normalized_nodes in normalized_by_relation:
        roles = {}
        for friend_id, weight in normalized_nodes[user_id].items():
            # remember each endpoint's role, then connect them
            roles[user_id] = 'user'
            roles[friend_id] = 'friend'
            multigraph.add_edge(user_id, friend_id, weight=weight, label=relation, color=color)
        role_maps.append(roles)

    # Label the nodes once all edges exist, one relation after another
    for roles in role_maps:
        nx.set_node_attributes(multigraph, roles, "label")

    # Attach the encoded-tweet feature vector to every node
    for node_id, attrs in multigraph.nodes(data=True):
        role = attrs.get('label')
        if role == 'user':
            node_features = all_encoded_user_tweets[node_id]
            attrs['type'] = 'user'
        elif role == 'friend':
            node_features = all_encoded_friend_tweets.get(node_id)
            # Friends are relabelled as 'user' because GCN and GAT only work
            # with a single node label; the real role is kept under 'type'.
            attrs['label'] = 'user'
            attrs['type'] = 'friend'

        attrs["feature"] = node_features

    return StellarGraph.from_networkx(
        multigraph,
        node_features="feature",
        node_type_default="user",
        edge_type_default="replies")

def getUsersGraphsDict():
    """Build every graph for all train and test users and store them in globals.

    Reads the module-level ``graph_train_users`` / ``graph_test_users``
    dataframes (columns ``user_id`` and ``depressed``) and (re)creates:

    * ``train_users_graphs`` / ``test_users_graphs``: ``user_id`` -> dict with
      keys ``depressed``, ``reply_graph``, ``mention_graph``,
      ``quote_tweet_graph`` and ``combined``;
    * ``train_users_graph_labels`` / ``test_users_graph_labels``: the
      ``depressed`` values in dataframe order;
    * ``combined_graph_labels``: train labels followed by test labels.
    """
    global train_users_graphs, test_users_graphs
    global train_users_graph_labels, test_users_graph_labels
    global combined_graph_labels

    train_users_graphs, test_users_graphs = {}, {}
    train_users_graph_labels, test_users_graph_labels = [], []
    combined_graph_labels = []

    def _populate(users_df, graphs_out, labels_out):
        # Fill the given containers in place, one user at a time.
        for user_id in users_df['user_id'].values:
            graphs_out[user_id] = {}
            depressed = users_df.loc[users_df['user_id'] == user_id]['depressed'].values[0]
            labels_out.append(depressed)

            per_user = graphs_out[user_id]
            per_user['depressed'] = depressed
            per_user['reply_graph'] = genUserGraph(user_id, 'replies')
            per_user['mention_graph'] = genUserGraph(user_id, 'mentions')
            per_user['quote_tweet_graph'] = genUserGraph(user_id, 'quote_tweets')
            per_user['combined'] = genMultiGraph(user_id)

    # Train user graphs
    _populate(graph_train_users, train_users_graphs, train_users_graph_labels)
    # Test user graphs
    _populate(graph_test_users, test_users_graphs, test_users_graph_labels)

    # Combine Labels
    combined_graph_labels = train_users_graph_labels + test_users_graph_labels

def getUsersGraphsCombined():
    """Flatten the per-user graph dicts into per-relation lists and print size stats.

    Reads the module-level ``train_users_graphs`` / ``test_users_graphs`` and
    (re)creates these globals:

    * ``train_*_graphs`` / ``test_*_graphs``: graphs per relation, in dict order;
    * ``combined_*_graphs`` and ``combined_graphs``: train list followed by test list;
    * ``*_graph_node_num``: the largest node count found in each combined list.

    A ``describe()`` summary of node and edge counts is printed for each
    relation and for the combined graphs.
    """
    global train_replies_graphs, train_mentions_graphs
    global train_quote_tweets_graphs, train_combined_graphs
    global test_replies_graphs, test_mentions_graphs
    global test_quote_tweets_graphs, test_combined_graphs
    global combined_replies_graphs, combined_mentions_graphs
    global combined_quote_tweets_graphs, combined_graphs
    global replies_graph_node_num, mentions_graph_node_num
    global quote_tweets_graph_node_num, combined_graph_node_num

    def _pick(per_user_graphs, graph_key):
        # One graph per user, in the dict's iteration order.
        return [user_entry[graph_key] for user_entry in per_user_graphs.values()]

    def _size_table(graphs):
        # One row per graph: (number of nodes, number of edges).
        return pd.DataFrame(
            [(g.number_of_nodes(), g.number_of_edges()) for g in graphs],
            columns=["nodes", "edges"],
        )

    train_replies_graphs = _pick(train_users_graphs, 'reply_graph')
    train_mentions_graphs = _pick(train_users_graphs, 'mention_graph')
    train_quote_tweets_graphs = _pick(train_users_graphs, 'quote_tweet_graph')
    train_combined_graphs = _pick(train_users_graphs, 'combined')

    test_replies_graphs = _pick(test_users_graphs, 'reply_graph')
    test_mentions_graphs = _pick(test_users_graphs, 'mention_graph')
    test_quote_tweets_graphs = _pick(test_users_graphs, 'quote_tweet_graph')
    test_combined_graphs = _pick(test_users_graphs, 'combined')

    # Combine: train graphs first, then test graphs
    combined_replies_graphs = train_replies_graphs + test_replies_graphs
    combined_mentions_graphs = train_mentions_graphs + test_mentions_graphs
    combined_quote_tweets_graphs = train_quote_tweets_graphs + test_quote_tweets_graphs
    combined_graphs = train_combined_graphs + test_combined_graphs

    # Summaries
    replies_summary = _size_table(combined_replies_graphs)
    mentions_summary = _size_table(combined_mentions_graphs)
    quote_tweets_summary = _size_table(combined_quote_tweets_graphs)
    combined_summary = _size_table(combined_graphs)

    replies_graph_node_num = np.max(replies_summary.nodes)
    mentions_graph_node_num = np.max(mentions_summary.nodes)
    quote_tweets_graph_node_num = np.max(quote_tweets_summary.nodes)
    combined_graph_node_num = np.max(combined_summary.nodes)

    print("Replies Graph Stats")
    print(replies_summary.describe().round(1))
    print("\nMentions Graph Stats")
    print(mentions_summary.describe().round(1))
    print("\nQuote Tweets Graph Stats")
    print(quote_tweets_summary.describe().round(1))
    print("\nCombined Graph Stats")
    print(combined_summary.describe().round(1))

# Graph model functions

### Prepare graph generator

To feed data to the `tf.Keras` model that we will create later, we need a data generator. For supervised graph classification, we create an instance of `StellarGraph`'s `PaddedGraphGenerator` class. Note that `graphs` is a list of `StellarGraph` graph objects.

In [8]:
def prepGraphGens():
    """Wrap each combined graph list in a ``PaddedGraphGenerator``.

    Reads the module-level ``combined_replies_graphs``,
    ``combined_mentions_graphs``, ``combined_quote_tweets_graphs`` and
    ``combined_graphs``.

    Returns:
        ``(replies_generator, mentions_generator, quote_tweets_generator,
        combined_generator)``.
    """
    graph_lists = (
        combined_replies_graphs,
        combined_mentions_graphs,
        combined_quote_tweets_graphs,
        combined_graphs,
    )
    return tuple(PaddedGraphGenerator(graphs=graph_list) for graph_list in graph_lists)

### Misc

In [9]:
def get_generators_flow(generator, batch_size):
    """Create the train and test data flows from one graph generator.

    The graphs behind ``generator`` are expected in "train first, then test"
    order: positions ``0 .. len(graph_train_users) - 1`` are used for training
    and the following ``len(graph_test_users)`` positions for testing. Targets
    are taken from the module-level ``combined_graph_labels`` at the same
    positions.

    Args:
        generator: Object exposing ``flow(indices, targets=..., batch_size=...)``.
        batch_size: Batch size passed to both flows.

    Returns:
        ``(train_gen_flow, test_gen_flow)``.
    """
    train_count = len(graph_train_users)
    test_count = len(graph_test_users)

    train_index = list(range(train_count))
    test_index = list(range(train_count, train_count + test_count))

    # Single-column frame, so `.values` yields one row per graph
    label_frame = pd.DataFrame(combined_graph_labels)

    train_gen_flow = generator.flow(
        train_index,
        targets=label_frame.iloc[train_index].values,
        batch_size=batch_size,
    )
    test_gen_flow = generator.flow(
        test_index,
        targets=label_frame.iloc[test_index].values,
        batch_size=batch_size,
    )

    return train_gen_flow, test_gen_flow

### Create the Keras graph classification model

We are now ready to create a `tf.Keras` graph classification model using `StellarGraph`'s `GCNSupervisedGraphClassification` class (or `DeepGraphCNN` for the DGCNN variant) together with standard `tf.Keras` layers, e.g., `Dense`.

The input is the graph represented by its adjacency and node features matrices. The first two layers are Graph Convolutional with each layer having 64 units and `relu` activations. The next layer is a mean pooling layer where the learned node representations are summarized to create a graph representation. The graph representation is input to two fully connected layers with 32 and 16 units respectively and `relu` activations. The last layer is the output layer with a single unit and `sigmoid` activation.

In [28]:
def create_graph_classification_model_gcn(generator, k):
    """Build and compile the GCN graph-classification model.

    Two 64-unit ``relu`` graph-convolution layers (dropout 0.5) are followed
    by Dense layers of 32 and 16 ``relu`` units and a single ``sigmoid``
    output unit. The model is compiled with Adam (learning rate 0.0001),
    binary cross-entropy loss and the ``acc`` metric.

    Args:
        generator: Graph generator handed to ``GCNSupervisedGraphClassification``.
        k: Not used by this function.

    Returns:
        ``(model, predictions, x_inp, x_out)``: the compiled Keras model, its
        output tensor, and the input/output tensors of the graph layers.
    """
    graph_layers = GCNSupervisedGraphClassification(
        layer_sizes=[64, 64],
        activations=["relu", "relu"],
        generator=generator,
        dropout=0.5,
    )
    x_inp, x_out = graph_layers.in_out_tensors()

    # Fully connected head: 32 -> 16 -> 1
    predictions = x_out
    for units, activation in ((32, "relu"), (16, "relu"), (1, "sigmoid")):
        predictions = Dense(units=units, activation=activation)(predictions)

    model = Model(inputs=x_inp, outputs=predictions)
    model.compile(optimizer=Adam(0.0001), loss=binary_crossentropy, metrics=["acc"])

    return model, predictions, x_inp, x_out

def create_graph_classification_model_dgcnn(generator, k, layer_sizes, activations):
    """Build and compile the DeepGraphCNN graph-classification model.

    The ``DeepGraphCNN`` output (dropout 0.5) goes through a ``Conv1D`` whose
    kernel size and stride both equal ``sum(layer_sizes)``, then max pooling,
    flattening, a 128-unit ``relu`` Dense layer with 0.5 dropout and a single
    ``sigmoid`` output unit. The model is compiled with Adam (learning rate
    0.0001), binary cross-entropy loss and the ``acc`` metric.

    Args:
        generator: Graph generator handed to ``DeepGraphCNN``.
        k: ``k`` argument of ``DeepGraphCNN``.
        layer_sizes: Sizes of the graph-convolution layers.
        activations: Activations of the graph-convolution layers.

    Returns:
        ``(model, predictions, x_inp, x_out)``. Note that ``x_out`` is the
        tensor after the final ``Dropout`` layer, not the raw ``DeepGraphCNN``
        output.
    """
    graph_layers = DeepGraphCNN(
        layer_sizes=layer_sizes,
        activations=activations,
        k=k,
        generator=generator,
        dropout=0.5,
    )
    x_inp, graph_output = graph_layers.in_out_tensors()

    total_width = sum(layer_sizes)
    hidden = Conv1D(filters=16, kernel_size=total_width, strides=total_width)(graph_output)
    hidden = MaxPool1D(pool_size=2)(hidden)
    hidden = Flatten()(hidden)
    hidden = Dense(units=128, activation="relu")(hidden)
    hidden = Dropout(rate=0.5)(hidden)
    predictions = Dense(units=1, activation="sigmoid")(hidden)

    model = Model(inputs=x_inp, outputs=predictions)
    model.compile(optimizer=Adam(0.0001), loss=binary_crossentropy, metrics=["acc"])

    return model, predictions, x_inp, hidden

def train_fold(model, train_gen, test_gen):
    """Fit ``model`` for up to 300 epochs with early stopping on validation loss.

    Training stops after 20 epochs without a ``val_loss`` improvement of at
    least 0.0001, and the best weights are restored. ``test_gen`` is used as
    the validation data.

    Args:
        model: Compiled Keras model.
        train_gen: Training data flow.
        test_gen: Validation data flow.

    Returns:
        ``(model, history)`` where ``history`` is the value returned by ``model.fit``.
    """
    early_stopping = EarlyStopping(
        monitor='val_loss',
        min_delta=0.0001,
        patience=20,
        verbose=1,
        restore_best_weights=True,
        mode='auto',
    )
    history = model.fit(
        train_gen,
        epochs=300,
        validation_data=test_gen,
        verbose=0,
        callbacks=[early_stopping],
    )
    return model, history

def trainModel(generator, k, layer_sizes, activations, batch_size=15):
    """Build, summarise and train a DGCNN model on the given generator.

    Args:
        generator: Graph generator used both to build the model and to create
            the train/test flows.
        k: ``k`` value forwarded to the DGCNN model builder.
        layer_sizes: Graph-convolution layer sizes forwarded to the builder.
        activations: Graph-convolution activations forwarded to the builder.
        batch_size: Batch size of the train/test flows (default 15, the value
            that used to be hard-coded).

    Returns:
        ``(model, history, train_gen_flow, test_gen_flow, predictions, x_inp, x_out)``.
    """
    print("K value: {}".format(k))

    train_gen_flow, test_gen_flow = get_generators_flow(generator=generator, batch_size=batch_size)

    model, predictions, x_inp, x_out = create_graph_classification_model_dgcnn(generator, k, layer_sizes, activations)
    # summary() prints the table itself and returns None; wrapping it in
    # display() only added a stray "None" to the cell output.
    model.summary()

    model, history = train_fold(model, train_gen_flow, test_gen_flow)

    return model, history, train_gen_flow, test_gen_flow, predictions, x_inp, x_out

### Stats

In [11]:
def show_acc_stats(title, model, history, test_gen):
    """Evaluate ``model`` on ``test_gen`` and print/return its metrics.

    Args:
        title: Label used in the printed header.
        model: Object exposing ``evaluate(test_gen, verbose=0)`` and ``metrics_names``.
        history: Not used by this function.
        test_gen: Test data flow passed to ``model.evaluate``.

    Returns:
        Dict mapping each name in ``model.metrics_names`` to its evaluated value.
    """
    metric_values = model.evaluate(test_gen, verbose=0)

    print("\n{} Test Set Metrics:".format(title))
    collected = {}
    for metric_name, metric_value in zip(model.metrics_names, metric_values):
        print("\t{}: {:0.4f}".format(metric_name, metric_value))
        collected[metric_name] = metric_value

    return collected

#### AU-ROC
Let's use the trained model to make a prediction for each graph.

Then, select only the predictions for the graphs in the test set and calculate the AU-ROC as another performance metric in addition to the accuracy shown above.

In [12]:
def get_auroc_stats(name, pred):
    """Plot the test-set ROC curve for ``pred`` and report its AUC.

    Predictions are compared with the module-level ``test_users_graph_labels``.

    Args:
        name: Model name used in the printed message.
        pred: Predicted scores for the test users.

    Returns:
        The AUC returned by ``plot_roc``.
    """
    test_auc = plot_roc(pred, test_users_graph_labels)
    print("The {} AUC on the test set: {}".format(name, test_auc))
    return test_auc

#### Confusion Matrix

In [13]:
def getConfusionMatrix(name, pred_bin):
    """Print and plot the test-set confusion matrix, then a classification report.

    Binary predictions are compared with the module-level
    ``test_users_graph_labels``. Classes are shown as "healthy" and "depressed".

    Args:
        name: Model name used in the printed message.
        pred_bin: Binary predictions for the test users.
    """
    true_labels = np.array(test_users_graph_labels)

    matrix = confusion_matrix(true_labels, pred_bin)
    print(matrix)

    print('Plotting {} confusion matrix'.format(name))
    plt.figure()
    plot_confusion_matrix(matrix, ["healthy", "depressed"])
    plt.show()

    print(classification_report(true_labels, pred_bin))

In [14]:
def getStats(name, pred_bin):
  """Compute, print and return precision, recall and F1 for `pred_bin`.

  Predictions are compared with the notebook-level
  `test_users_graph_labels`. `name` prefixes each printed line.

  Returns:
    Tuple `(precision, recall, f1)`.
  """
  true_labels = np.array(test_users_graph_labels)

  # Same scorers, same order as before: precision, recall, then F1.
  precision_val, recall_val, f1_val = [
      scorer(true_labels, pred_bin)
      for scorer in (precision_score, recall_score, f1_score)
  ]

  print("{} Precision score: {}".format(name, precision_val))
  print("{} Recall score: {}".format(name, recall_val))
  print("{} F1 Score: {}".format(name, f1_val))

  return precision_val, recall_val, f1_val

# Run model loop

In [31]:
# Hyper-parameter sweep: for every configuration, load the data, train one
# DGCNN per graph type (replies, mentions, quote tweets, combined), evaluate
# each model, and collect the scores in `stats`, which is pickled at the end.

# The total number of friends 0 = all friends
num_friends_list = [0]
# The number of tweets - most recent. 0 means no limit.
num_tweets_list = [50]


# lstm_layers = [1,2]
lstm_layers = [1]
# word_embedding_size_list: 100,200,300
word_embedding_size_list = [100]
# user_embedding_size_list: 50,100,200
user_embedding_size_list = [200]

# DGCNN Params
# k
k_list = [5,10,15,35]
# l
layer_sizes_list = [
  [32, 32],
]
# j (indexed in parallel with layer_sizes_list)
activations_list = [
  ["tanh", "tanh"],
]

# Running count of configurations actually executed (skipped ones excluded).
itteration = 0

# Nested results:
# stats[lstm_num][word_embedding_size][user_embedding_size][num_friends][num_tweets][graph_type]
#   -> {'auc', 'precission', 'recall', 'f1'}
stats = {}

for lstm_num in lstm_layers:
    for word_embedding_size in word_embedding_size_list:
        for user_embedding_size in user_embedding_size_list:
            for num_friends in num_friends_list:
                for num_tweets in num_tweets_list:
                    for layer_sizes_idx, layer_sizes in enumerate(layer_sizes_list):
                        for k in k_list:

                            # Unlimited tweets (num_tweets == 0) are only run for
                            # num_friends = 3 with word/user embedding sizes 100/200;
                            # every other unlimited-tweet combination is skipped.
                            if (num_friends != 3 or word_embedding_size != 100 or user_embedding_size != 200) and num_tweets == 0:
                                continue

                            itteration += 1
                            print("\n\nItteration: {}".format(itteration))
                            print("#####################################################")
                            print("Friends Count: {}".format(num_friends))
                            print("Tweet Count: {}".format(num_tweets))

                            print("LSTM count: {}".format(lstm_num))
                            print("Word embeddings size: {}".format(word_embedding_size))
                            print("User embeddings size: {}".format(user_embedding_size))

                            print("\nDGCNN Params:")
                            print("Layer Sizes (l): {}".format(layer_sizes))
                            # Activations paired with the current layer sizes.
                            layer_activations = activations_list[layer_sizes_idx]
                            print("Activations (j): {}".format(layer_activations))
                            print("Nodes after SortPooling (k): {}".format(k))
                            print("#####################################################\n")

                            # Define the constants
                            defineGlobalConstants(num_friends_local = num_friends, num_tweets_local = num_tweets)

                            # Get the stats
                            df_replies_stats = pd.read_hdf(hdf_stats_reply_path, key=hdf_stats_key)
                            df_mentions_stats = pd.read_hdf(hdf_stats_mention_path, key=hdf_stats_key)
                            df_quote_tweets_stats = pd.read_hdf(hdf_stats_quote_tweets_path, key=hdf_stats_key)
                            print("Replies Stats Shape: {}".format(df_replies_stats.shape))
                            print("Mentions Stats Shape: {}".format(df_mentions_stats.shape))
                            print("Quote Tweets Stats Shape: {}".format(df_quote_tweets_stats.shape))

                            # Get the encoded tweets
                            all_encoded_user_tweets = genEncodedTweets('user', lstm_num, word_embedding_size, user_embedding_size)
                            all_encoded_friend_tweets = genEncodedTweets('friend', lstm_num, word_embedding_size, user_embedding_size)
                            print("User Node Count: {}".format(len(all_encoded_user_tweets)))
                            print("Friends Unique Node Count: {}".format(len(all_encoded_friend_tweets)))

                            # k = len(all_encoded_friend_tweets)+1
                            # print("TEST K: {}".format(k))

                            # Get the train/test users
                            graph_train_users, graph_test_users = loadCombinedUsers()
                            print("Total users: {}".format(len(graph_train_users)+len(graph_test_users)))
                            print("Total of {} train users".format(len(graph_train_users)))
                            print("Total of {} test users".format(len(graph_test_users)))

                            # Get a dictionary for each train/test user containing their graphs
                            getUsersGraphsDict()
                            # Train graph labels stats
                            print("Train labels")
                            print(pd.Series(train_users_graph_labels).value_counts().to_frame())
                            # Test graph labels stats
                            print("Test labels")
                            print(pd.Series(test_users_graph_labels).value_counts().to_frame())

                            # Get the combined user graphs
                            getUsersGraphsCombined()

                            # Draw a sample from the train replies graph
                            print("Sample reply graph")
                            plt.figure(figsize=(10,10))
                            draw_networkx(train_replies_graphs[0].to_networkx(), node_size=2000, edge_color='grey')

                            # Get the generators
                            replies_generator_gcn, mentions_generator_gcn, quote_tweets_generator_gcn, combined_generator_gcn = prepGraphGens()

                            ##### TRAIN
                            # Replies model
                            replies_model, replies_history, replies_train_gen, replies_test_gen, replies_predictions, replies_x_inp, replies_x_out = trainModel(replies_generator_gcn, k, layer_sizes, layer_activations)
                            # Mentions Model
                            mentions_model, mentions_history, mentions_train_gen, mentions_test_gen, mentions_predictions, mentions_x_inp, mentions_x_out = trainModel(mentions_generator_gcn, k, layer_sizes, layer_activations)
                            # Quote Tweets Model
                            quote_tweets_model, quote_tweets_history, quote_tweets_train_gen, quote_tweets_test_gen, quote_tweets_predictions, quote_tweets_x_inp, quote_tweets_x_out = trainModel(quote_tweets_generator_gcn, k, layer_sizes, layer_activations)
                            # Combined Model
                            combined_model, combined_history, combined_train_gen, combined_test_gen, combined_predictions, combined_x_inp, combined_x_out = trainModel(combined_generator_gcn, k, layer_sizes, layer_activations)

                            ##### Predictions
                            replies_pred = replies_model.predict(replies_test_gen)
                            mentions_pred = mentions_model.predict(mentions_test_gen)
                            quote_tweets_pred = quote_tweets_model.predict(quote_tweets_test_gen)
                            combined_pred = combined_model.predict(combined_test_gen)

                            # Reduce each prediction row to a single score (its maximum).
                            replies_pred_list = [np.max(row) for row in replies_pred]
                            mentions_pred_list = [np.max(row) for row in mentions_pred]
                            quote_tweets_pred_list = [np.max(row) for row in quote_tweets_pred]
                            combined_pred_list = [np.max(row) for row in combined_pred]

                            replies_pred_bin = predToBinPred(replies_pred_list, threshold=0.5)
                            mentions_pred_bin = predToBinPred(mentions_pred_list, threshold=0.5)
                            quote_tweets_pred_bin = predToBinPred(quote_tweets_pred_list, threshold=0.5)
                            combined_pred_bin = predToBinPred(combined_pred_list, threshold=0.5)

                            ##### EVAL
                            # Accuracy
                            stats_replies = show_acc_stats("Replies", replies_model, replies_history, replies_test_gen)
                            stats_mentions = show_acc_stats("Mentions", mentions_model, mentions_history, mentions_test_gen)
                            stats_quote_tweets = show_acc_stats("Quote Tweets", quote_tweets_model, quote_tweets_history, quote_tweets_test_gen)
                            stats_combined = show_acc_stats("Combined", combined_model, combined_history, combined_test_gen)
                            print('Replies Stats Graph')
                            sg.utils.plot_history(replies_history)
                            print('Mentions Stats Graph')
                            sg.utils.plot_history(mentions_history)
                            print('Quote Tweets Stats Graph')
                            sg.utils.plot_history(quote_tweets_history)
                            print('Combined Stats Graph')
                            sg.utils.plot_history(combined_history)

                            # AU-ROC
                            replies_auc = get_auroc_stats("Replies", replies_pred_list)
                            mentions_auc = get_auroc_stats("Mentions", mentions_pred_list)
                            quote_tweets_auc = get_auroc_stats("Quote-Tweets", quote_tweets_pred_list)
                            combined_auc = get_auroc_stats("Combined", combined_pred_list)

                            # Precision, Recall, F1
                            replies_precission, replies_recall, replies_f1 = getStats("Replies", replies_pred_bin)
                            mentions_precission, mentions_recall, mentions_f1 = getStats("Mentions", mentions_pred_bin)
                            quote_tweets_precission, quote_tweets_recall, quote_tweets_f1 = getStats("Quote-Tweets", quote_tweets_pred_bin)
                            combined_precission, combined_recall, combined_f1 = getStats("Combined", combined_pred_bin)

                            # Save all stats
                            # NOTE(review): the key path below does not include `k` or
                            # `layer_sizes`, and a graph type is only written when absent.
                            # With several values in `k_list`, only the FIRST k's scores
                            # per (lstm, word, user, friends, tweets) combination are kept;
                            # later k values are silently dropped. Behaviour is preserved
                            # here because changing it alters the pickled schema - confirm
                            # whether `k` should become part of the key.
                            run_stats = (
                                stats.setdefault(lstm_num, {})
                                .setdefault(word_embedding_size, {})
                                .setdefault(user_embedding_size, {})
                                .setdefault(num_friends, {})
                                .setdefault(num_tweets, {})
                            )
                            for graph_type, graph_auc, graph_precission, graph_recall, graph_f1 in (
                                ('replies', replies_auc, replies_precission, replies_recall, replies_f1),
                                ('mentions', mentions_auc, mentions_precission, mentions_recall, mentions_f1),
                                ('quote_tweets', quote_tweets_auc, quote_tweets_precission, quote_tweets_recall, quote_tweets_f1),
                                ('combined', combined_auc, combined_precission, combined_recall, combined_f1),
                            ):
                                if graph_type not in run_stats:
                                    run_stats[graph_type] = {
                                        'auc': graph_auc,
                                        'precission': graph_precission,
                                        'recall': graph_recall,
                                        'f1': graph_f1,
                                    }

                            # Confusion Matrix
                            getConfusionMatrix("Replies", replies_pred_bin)
                            getConfusionMatrix("Mentions", mentions_pred_bin)
                            getConfusionMatrix("Quote-Tweets", quote_tweets_pred_bin)
                            getConfusionMatrix("Combined", combined_pred_bin)

# Persist the collected scores; the context manager guarantees the file is
# flushed and closed (the previous bare open() leaked the handle).
with open(graph_model_stats_file, "wb") as stats_file:
    pickle.dump(stats, stats_file)


Output hidden; open in https://colab.research.google.com to view.