# Set up working directory, data paths, and timezone

In [None]:
import os
import pytz

# Timezone used when stamping dates (e.g. in plot titles).
MY_TIMEZONE = pytz.timezone("Asia/Ho_Chi_Minh")

# Filesystem locations, relative to the directory the notebook was started in.
WORKING_DIR = os.getcwd()
DATASET_DIR = "."
CHECKPOINT_DIR = "./checkpoints"

# Imports

In [None]:
#import tensorflow as tf
#tf.compat.v1.disable_v2_behavior()
import matplotlib.pyplot as plt
import numpy as np

import copy
from tqdm import trange
from sklearn import metrics
from sklearn.utils import class_weight

#import os
import dgl
import torch
import dgl.nn as dglnn
import pandas as pd
import numpy as np
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F


# Config

In [None]:
import yaml
from yaml.loader import SafeLoader
# Load experiment settings with PyYAML's safe loader, then pin the model name
# that later cells use for checkpoint filenames and plot titles.
with open("config2.yaml") as config_file:
    config = yaml.safe_load(config_file)
config['server']['model'] = "DeepTraLogmodel"

# Load and transform data

In [None]:
# CSV files live inside the "glove" sub-directory of DATASET_DIR. Join with a
# path separator: plain string concatenation of DATASET_PATH and the file name
# would produce "./glovecadets-e3-2-nodes.csv" instead of "./glove/cadets-...".
DATASET_PATH = os.path.join(DATASET_DIR, "glove")
NODES_PATH = os.path.join(DATASET_PATH, "cadets-e3-2-nodes.csv")
EDGES_PATH = os.path.join(DATASET_PATH, "cadets-e3-2-edges.csv")

# The first CSV column is used as the DataFrame index.
nodes_df = pd.read_csv(NODES_PATH, index_col=0)
edges_df = pd.read_csv(EDGES_PATH, index_col=0)

In [None]:
def transform_feat(string):
  """Parse a comma-separated string of numbers into a 1-D float32 array.

  Raises ValueError if any comma-separated token is not a number.
  """
  tokens = string.split(",")
  return np.array([np.float32(token) for token in tokens])


# Parse each node's comma-separated "feat_str" into a float32 vector and store
# it in a new "h" column; create_graph later stacks this column into ndata["h"].
nodes_df["h"] = nodes_df["feat_str"].apply(transform_feat)

In [None]:
def create_edges_df_with_node_ids(nodes_df, edges_df):
  """Attach integer endpoint ids to every edge by joining on node uuids.

  Args:
    nodes_df: node table with ``node_id`` and ``uuid`` columns
      (e.g. the output of ``create_node_ids``).
    edges_df: edge table with ``source_uuid`` and ``destination_uuid``
      columns; any other columns are discarded.

  Returns:
    DataFrame with columns ``source_uuid``, ``destination_uuid``,
    ``source_node_id`` and ``destination_node_id`` (the last two int32).
    Rows with any missing value are dropped. This includes edges whose source
    or destination uuid is not in ``nodes_df``, so passing a subset of nodes
    keeps only the edges between nodes of that subset.
  """
  id_table = nodes_df[["node_id", "uuid"]]
  linked = edges_df[["source_uuid", "destination_uuid"]]

  # One left join per endpoint; "node_uuid" is only a temporary join key.
  for end in ("source", "destination"):
    lookup = id_table.rename(columns={
          "node_id": f"{end}_node_id",
          "uuid": "node_uuid"
        })
    linked = pd.merge(linked, lookup, how="left",
                      left_on=f"{end}_uuid", right_on="node_uuid")
    linked = linked.drop(columns=["node_uuid"])

  # Unmatched endpoints surface as NaN (which also made the id columns float);
  # drop those rows, then restore integer ids.
  linked = linked.dropna(how="any")
  return linked.astype({"source_node_id": "int32", "destination_node_id": "int32"})


def create_node_ids(nodes_df):
  """Return a copy of ``nodes_df`` with a dense ``node_id`` column (0..n-1).

  ``node_id`` is the row position and becomes the DGL node id in
  ``create_graph``. The incoming index is kept as an ordinary column. It is
  named after the index, or "index"/"level_0" when the index is unnamed, as
  ``DataFrame.reset_index`` decides. The result has a fresh RangeIndex. The
  input DataFrame is not modified.
  """
  # reset_index() returns a new DataFrame, so the caller's frame is untouched.
  nodes_df = nodes_df.reset_index()
  # Insert the id column explicitly instead of relying on a second
  # reset_index() producing a column literally named "level_0". That only
  # happens when the incoming index is unnamed; a named index would leave no
  # "node_id" column at all.
  nodes_df.insert(0, "node_id", range(len(nodes_df)))
  return nodes_df


def create_graph(nodes_df, edges_df):
  """Build a DGL graph from a node table and an edge table.

  Nodes are the rows of ``nodes_df`` (ids 0..n-1 in row order). Edges are the
  rows of ``edges_df`` whose source and destination uuids both appear in
  ``nodes_df``.

  Node data:
    h: float32 features stacked from the ``h`` column.
    label: int32 labels from the ``label`` column.
    train_mask: bool, True for every node.
  """
  numbered_nodes = create_node_ids(nodes_df)
  numbered_edges = create_edges_df_with_node_ids(numbered_nodes, edges_df)

  def _as_tensor(column, dtype):
    # Build a fresh numpy array from the column's values and wrap it as a tensor.
    return torch.from_numpy(np.array(column.values.tolist(), dtype=dtype))

  src_ids = _as_tensor(numbered_edges["source_node_id"], np.int32)
  dst_ids = _as_tensor(numbered_edges["destination_node_id"], np.int32)
  graph = dgl.graph((src_ids, dst_ids), num_nodes=len(numbered_nodes))

  graph.ndata["h"] = _as_tensor(numbered_nodes["h"], np.float32)
  graph.ndata["label"] = _as_tensor(numbered_nodes["label"], np.int32)
  # Every node in this graph is selectable; callers split nodes into separate
  # DataFrames before building graphs.
  graph.ndata["train_mask"] = torch.ones(len(graph.ndata['h']), dtype=torch.bool)

  return graph

In [None]:
from mdlp2.data.pipe.fed import fed_pipe
from datetime import datetime
import pickle


TEST_SIZE = config['dataset']['test_size']
NUM_CLIENTS = config['server']['n_client']

# The test/client split is cached as pickles so reruns reuse the same partition.
# NOTE(review): the cache directory name is hard-coded and does not depend on
# TEST_SIZE or NUM_CLIENTS (its "n10" suggests 10 clients); delete the cache
# after changing either setting. Only load pickles this notebook wrote itself.
data_dir = CHECKPOINT_DIR + "/data/deeptralog-n10-dcadets-111323"
test_path = f"{data_dir}/test_data.pkl"
client_path = f"{data_dir}/client_data.pkl"

if not os.path.exists(data_dir):
  os.makedirs(data_dir)

have_cache = os.path.exists(test_path) and os.path.exists(client_path)
if have_cache:
  with open(test_path, "rb") as f:
    test_data = pickle.load(f)
  with open(client_path, "rb") as f:
    client_data = pickle.load(f)
else:
  fed_pipeline = fed_pipe(source_type="DataFrame", split_radio=TEST_SIZE, n_client=NUM_CLIENTS, keep_radio_feature='label', random_state=123, verbose=False)
  test_data, client_data = fed_pipeline(nodes_df)
  with open(test_path, "wb") as f:
    pickle.dump(test_data, f)
  with open(client_path, "wb") as f:
    pickle.dump(client_data, f)


In [None]:
# Build the held-out test graph and one graph per client. Each graph keeps only
# edges whose endpoints both lie in that node subset (see create_graph), and
# dgl.add_self_loop adds a self-loop edge to every node.
node_test_data = test_data
G_test = dgl.add_self_loop(create_graph(node_test_data, edges_df))

client_G_data = [
    dgl.add_self_loop(create_graph(client_data[i], edges_df))
    for i in range(len(client_data))
]

## Define the GGNN model

In [None]:
class GGNN(nn.Module):
    """Two stacked gated graph convolutions followed by a two-layer MLP head.

    Args:
        in_feats: width of the input node features. Both GatedGraphConv layers
            keep this width (``out_feats=in_feats``), so the MLP input is
            ``in_feats`` wide.
        out_feats: number of output logits per node.
        dropout: dropout probability applied after each graph convolution.
    """

    def __init__(self, in_feats, out_feats, dropout):
        super().__init__()

        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        self.conv1 = dglnn.GatedGraphConv(
            in_feats=in_feats, out_feats=in_feats, n_steps=3, n_etypes=1
        )

        self.conv2 = dglnn.GatedGraphConv(
            in_feats=in_feats, out_feats=in_feats, n_steps=3, n_etypes=1
        )

        self.lin1 = nn.Linear(in_feats, 64)
        self.lin2 = nn.Linear(64, out_feats)

    def forward(self, graph, feat, eweight=None):
        """Return per-node logits whose last dimension is ``out_feats``.

        ``eweight`` is accepted for call-site compatibility but is not used.
        """
        h = self.conv1(graph, feat)
        h = self.activation(h)
        h = self.dropout(h)

        # Feed conv1's output into conv2 so the two layers are actually stacked;
        # passing the raw ``feat`` here would discard conv1's output entirely.
        h = self.conv2(graph, h)
        h = self.activation(h)
        h = self.dropout(h)

        h = self.lin1(h)
        h = self.activation(h)

        return self.lin2(h)

In [None]:
def train(G_train, model, optim, criterion, config):
  """Train ``model`` full-batch on one graph, then restore its lowest-loss weights.

  Args:
    G_train: graph whose ``ndata`` holds "h" (node features), "label"
      (integer labels) and "train_mask" (bool mask of nodes to train on).
    model: module called as ``model(graph, features)`` that returns per-node logits.
    optim: optimizer over ``model``'s parameters.
    criterion: loss called as ``criterion(logits, long_labels)``.
    config: dict providing ``config["epochs"]``.

  Returns:
    One ``[loss, accuracy, precision, recall, f1]`` row per epoch, each value
    rounded to 4 decimals. Precision/recall/F1 use ``average='binary'``. The
    model is left holding the weights that produced the smallest training
    loss. It is left unchanged if ``epochs`` is 0.
  """
  node_label = G_train.ndata["label"]
  train_mask = G_train.ndata["train_mask"]
  # Labels and mask do not change across epochs; compute the targets once.
  target = node_label[train_mask].long()
  y_true = node_label[train_mask].cpu().numpy()

  best_model_weights = None
  min_loss = None

  history = []

  # Make sure dropout is active during training.
  model.train()
  t_epochs = trange(config["epochs"])

  for _ in t_epochs:
    pred = model(G_train, G_train.ndata["h"])
    loss = criterion(pred[train_mask], target)

    # Snapshot BEFORE optim.step(): these are the weights that produced `loss`.
    # deepcopy is required because state_dict() returns references to the live
    # parameter tensors, which optim.step() updates in place.
    loss_value = loss.item()
    if min_loss is None or loss_value < min_loss:
      min_loss = loss_value
      best_model_weights = copy.deepcopy(model.state_dict())

    optim.zero_grad()
    loss.backward()
    optim.step()

    y_pred = pred[train_mask].argmax(1).detach().cpu().numpy()

    local_history = [loss_value] + [metrics.accuracy_score(y_true, y_pred)] + list(
          metrics.precision_recall_fscore_support(y_true, y_pred, average='binary'))[:-1]
    local_history = [round(x, 4) for x in local_history]
    history.append(local_history)
    t_epochs.set_postfix_str(
          f"loss: {local_history[0]} - Accuracy: {local_history[1]} - precision: {local_history[2]} - recall: {local_history[3]} - f1_score: {local_history[4]}")

  # With zero epochs there is nothing to restore.
  if best_model_weights is not None:
    model.load_state_dict(best_model_weights)

  return history


# FedAvg

## Aggregation function

In [None]:
def avg_aggr(client_param, lr_s, save_weight):
  """Sample-weighted FedAvg server update with a server learning rate.

  new[k] = save[k] + lr_s * sum_i(n_i * (w_i[k] - save[k])) / sum_i(n_i)

  Args:
    client_param: iterable of ``(state_dict, sample_count)`` pairs, one per client.
    lr_s: server learning rate; 1.0 gives the plain sample-weighted average.
    save_weight: current global state dict, the reference point for the deltas.
      It is not modified.

  Returns:
    A new dict with the same keys and dtypes as ``save_weight``. If no samples
    were reported (empty ``client_param`` or all counts 0), returns clones of
    ``save_weight`` instead of dividing by zero, which would produce NaN weights.
  """
  delta_sum = {k: torch.zeros_like(v) for k, v in save_weight.items()}
  total_client_data = 0

  for weight, sample_count in client_param:
    total_client_data += sample_count
    for k, v in weight.items():
      delta_sum[k] += sample_count * (v - save_weight[k])

  if total_client_data == 0:
    return {k: v.clone() for k, v in save_weight.items()}

  # Cast back so integer buffers keep their dtype; this is a no-op for float tensors.
  return {
      k: (save_weight[k] + v * lr_s / total_client_data).to(save_weight[k].dtype)
      for k, v in delta_sum.items()
  }

## Train

In [None]:
# Feature width = last axis of ndata["h"]. create_graph stacks one vector per
# node, so the matrix is presumably (num_nodes, feat_dim) and indexing axis 2
# would raise IndexError. shape[-1] is correct for that layout and would still
# pick the last axis if an extra leading axis were ever present.
in_feats = G_test.ndata["h"].shape[-1]

global_model = GGNN(in_feats=in_feats,
                    out_feats=2,
                    dropout=0.2).cpu()
client_model = [GGNN(in_feats=in_feats,
                     out_feats=2,
                     dropout=0.2).cpu() for _ in range(config['server']['n_client'])]
# Every client starts from the global model's initial weights.
for model in client_model:
  model.load_state_dict(global_model.state_dict())

# One evaluation-history row per federated round.
simulation_return = []


In [None]:
def evaluate(G_test, model):
  """Evaluate ``model`` on ``G_test`` and print a one-line report.

  Runs in eval mode (dropout disabled) and without gradient tracking. The
  model's previous train/eval mode is restored afterwards.

  Args:
    G_test: graph whose ``ndata`` holds "h", "label" and "train_mask"; the
      mask selects which nodes are scored.
    model: module called as ``model(graph, features)`` that returns per-node logits.

  Returns:
    ``[loss, accuracy, precision, recall, f1]`` rounded to 4 decimals; the
    loss is unweighted cross-entropy and P/R/F1 use ``average='binary'``.
  """
  criterion = nn.CrossEntropyLoss()
  node_label = G_test.ndata["label"]
  test_mask = G_test.ndata["train_mask"]

  was_training = model.training
  model.eval()  # without this, dropout would perturb the evaluation
  try:
    with torch.no_grad():
      pred = model(G_test, G_test.ndata["h"])
      loss = criterion(pred[test_mask], node_label[test_mask].long())
  finally:
    model.train(was_training)

  y_true = node_label[test_mask].cpu().numpy()
  y_pred = pred[test_mask].argmax(1).cpu().numpy()
  history = [loss.item()] + [metrics.accuracy_score(y_true, y_pred)] + list(metrics.precision_recall_fscore_support(y_true, y_pred, average = 'binary'))[:-1]
  history = [round(x,4) for x in history]
  print(f"Evaluation - loss: {history[0]} - Accuracy: {history[1]} - precision: {history[2]} - recall: {history[3]} - f1_score: {history[4]}")
  return history

In [None]:
# Federated training. Each round, every client trains locally from its current
# weights (the global model's weights after the previous aggregation). The
# server then aggregates with avg_aggr, pushes the result to all models, and
# evaluates the global model.
for r in range(config['server']['r']):
  print(f"[o] Start Round: {r+1}")

  for c in range(config['server']['n_client']):
    print(f"[c{c+1}] Start Training")
    G_train = copy.deepcopy(client_G_data[c])

    # 'balanced' class weights counter label imbalance within this client's data.
    # NOTE(review): if a client holds only one class, this yields a 1-element
    # weight vector that will not match the model's 2 output logits.
    train_labels = G_train.ndata['label'].numpy()
    class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                      classes=np.unique(train_labels),
                                                      y=train_labels)
    class_weights = torch.FloatTensor(class_weights).cpu()
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # A fresh optimizer is created every round, so optimizer state (e.g. Adam
    # moments, SGD momentum) does not carry over between rounds.
    compile_cfg = config['client']['compile']
    if compile_cfg['optim'] == "SGD":
      opt = torch.optim.SGD(client_model[c].parameters(), lr=compile_cfg['lr'], momentum=compile_cfg['momentum'])
    else:
      opt = torch.optim.Adam(client_model[c].parameters(), lr=compile_cfg['lr'])

    train(G_train, client_model[c], opt, criterion, config['client']['train'])
    print(f"[c{c+1}] Finish Training")

  print("[s] calculate aggregate value")
  # Each client's update is weighted by its node count.
  client_param = []
  for idx, model in enumerate(client_model):
    client_param.append([copy.deepcopy(model.state_dict()), client_G_data[idx].ndata["h"].shape[0]])

  avg_weight = avg_aggr(client_param, config['server']['lr_s'], copy.deepcopy(global_model.state_dict()))
  print("[s] Update client weights")
  global_model.load_state_dict(avg_weight)
  for model in client_model:
    model.load_state_dict(avg_weight)
  print("[s] Evaluation")
  r_history = evaluate(G_test, global_model)

  simulation_return.append(copy.deepcopy(r_history))

  print(f"[o] End Round: {r+1}")

## Evaluation

In [None]:
# Final report: evaluate the aggregated global model on the held-out test graph
# (prints the metrics; in a notebook the returned list is also displayed).
evaluate(G_test, global_model)

# Save

In [None]:
import pickle

# Persist the final global weights together with the per-round evaluation
# history to <CHECKPOINT_DIR>/train/<model>.pkl. exist_ok makes the directory
# creation a no-op when it already exists.
train_dir = os.path.join(CHECKPOINT_DIR, "train")
os.makedirs(train_dir, exist_ok=True)

with open(os.path.join(train_dir, f"{config['server']['model']}.pkl"), 'wb') as f:
    pickle.dump([global_model.state_dict(), simulation_return], f)


# Visualization

In [None]:
# Transpose the per-round rows [loss, accuracy, precision, recall, f1] into one
# list of floats per metric.
_metric_columns = zip(*simulation_return)
losses, acc, prec, recall, f1 = [
    [float(value) for value in next(_metric_columns)] for _ in range(5)
]
# Upper bound for range(1, n_round), i.e. rounds 1..r on the x-axis.
n_round = config['server']['r'] + 1

In [None]:
# Per-round test metrics of the global model (loss is not plotted).
rounds = range(1, n_round)
for series, name in ((acc, 'Accuracy'), (prec, 'Precision'), (recall, 'Recall'), (f1, 'F1-score')):
    plt.plot(rounds, series, label=name)
plt.xticks(rounds)
# Title fields are "-"-separated: framework, client count, model, dataset, date.
plt.title(f"kb4-f:fedavg-n:{config['server']['n_client']}-m:{config['server']['model']}-d:{config['dataset']['name']}-t:{datetime.now(MY_TIMEZONE).strftime('%d/%m/%Y')}")
plt.xlabel('Rounds')
plt.ylabel('Metrics')
# A single legend call: each plt.legend() call replaces the previous legend.
plt.legend()
plt.show()

In [None]:
# Plain-text table: one line per round, columns in the header order.
print("Loss Accuracy Precision Recall F1-Score")
for row in zip(losses, acc, prec, recall, f1):
    print(" ".join(str(value) for value in row))