In [None]:
from itertools import combinations
from os import environ, listdir, makedirs

import matplotlib.pyplot as plt
from joblib import Parallel, delayed

from src.eeg import struct
from src.core import REc
from src.tools import exclude_from_sz_cm
from src.game import classify_epochs, evaluate_nodes, check_until

In [None]:
# Root folder of the project data. Set the EPIGAME_FOLDER environment variable to use a
# different location; the original path is kept as the default. The value is normalized
# to end with "/" because the paths below (and in later cells) are built by concatenation.
main_folder = environ.get("EPIGAME_FOLDER", "/home/kivi/gdrive/epigame-folder/").rstrip("/") + "/"

# Folder containing the connectivity-matrix files processed in the main loop.
path_cm = main_folder + "connectivity_matrices/"

In [None]:
# Window-of-interest (WOI) code, keyed by the number typed at the prompt below.
# The code must match the second "-"-separated field of the connectivity-matrix file names.
woi_code = {'1':"baseline", '2':"preseizure5", '3':"preseizure4", '4':"preseizure3", '5':"preseizure2", '6':"preseizure1", '7':"transition1", '8':"transition2", '9':"transition60", '10':"seizure"}

woi = input(
    "Time window:\n"
    " 1. Non-seizure (baseline)\n"
    " 2. Pre-seizure (5 min prior to seizure)\n"
    " 3. Pre-seizure (4 min prior to seizure)\n"
    " 4. Pre-seizure (3 min prior to seizure)\n"
    " 5. Pre-seizure (2 min prior to seizure)\n"
    " 6. Pre-seizure (1 min prior to seizure)\n"
    " 7. Transition to seizure (1 min interval)\n"
    " 8. Transition to seizure (2 min interval)\n"
    " 9. Transition to seizure (60% seizure length interval)\n"
    " 10. Seizure\n"
    " Indicate a number: "
).strip()  # tolerate stray whitespace around the typed number

# Fail here with a clear message instead of a KeyError deep inside the processing loop.
if woi not in woi_code:
    raise ValueError(f"Invalid time window {woi!r}; expected one of: {', '.join(woi_code)}")

In [None]:
# Largest network (number of nodes) the greedy search may grow to; growth stops past this size.
max_net_size: int = 20

In [None]:
# Output folder for the selected-network result files (reused if it already exists).
path_net = f"{main_folder}selected_network/"
makedirs(path_net, exist_ok=True)

In [None]:
# For every connectivity-matrix file of the selected WOI:
#   1. load the matrices and show a quick sanity check (epoch count, shapes, last epoch);
#   2. score every node pair (classify_epochs + evaluate_nodes);
#   3. greedily grow the top-scoring node group(s) one node at a time, up to
#      max_net_size nodes, while the best score does not decrease;
#   4. save all tested groups and the union of nodes in the best group(s) as a .res file.
target_woi = woi_code[woi]  # loop-invariant WOI label, e.g. "preseizure5"

for file_cm in listdir(path_cm):
    # File names are split on "-": field 1 is the WOI label and field 2 the
    # connectivity method (as used below). Skip files that do not follow this pattern.
    name_parts = file_cm.split("-")
    if len(name_parts) < 2 or name_parts[1] != target_woi:
        continue

    print("\n--------------------------------------------------------------")
    print("\nProcessing...")

    subject_id = file_cm[:3]  # the first three characters of the file name identify the subject
    print("Connectivity matrices of", file_cm)

    cm = REc.load(path_cm + file_cm).data
    # Read the node labels before anything uses them (the shape check below does).
    nodes = cm.nodes

    # For subject ASJ, with CC (any WOI) or SCR (original note: SCR delta, WOI 2-5), the
    # matrices keep an extra channel ("J9-J10") that was already removed from the labels
    # list (nodes). Drop it, using the index that "J10-J11" has in the current labels list.
    # The WOI check looks at the last character of the WOI label ("preseizure5" -> "5");
    # a set test is used instead of int() so labels ending in a letter cannot crash it.
    # NOTE(review): this check also matches "transition2"; confirm that is intended.
    woi_label = name_parts[1]
    method = name_parts[2] if len(name_parts) > 2 else None
    is_asj_mismatch = subject_id == "ASJ" and (
        method == "CC" or (method == "SCR" and woi_label[-1] in {"2", "3", "4", "5"})
    )
    if is_asj_mismatch:
        mismatch_channel_id = nodes.index("J10-J11")
        print("Mismatched node index:", mismatch_channel_id)
        cm.X = exclude_from_sz_cm(cm.X, mismatch_channel_id)
        print("Check PREP file sets:", list(cm.__dict__))

    n_nodes = len(nodes)
    print("\nTotal number of epochs =", len(cm.X))
    print("Connectivity matrix shape =", cm.X[0].shape)
    print("All matrices have the same shape:", all(m.shape == (n_nodes, n_nodes) for m in cm.X))

    # Visual check of the connectivity matrix of the last epoch.
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(cm.X[-1], cmap='Blues', interpolation='nearest')
    ax.set_title(f"{file_cm}: last epoch")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across files
    print(cm.X[-1])

    node_ids = list(range(n_nodes))
    print("Number of nodes =", n_nodes)
    print("\nNodes:", nodes)

    print("\nProcessing node combinations...")

    # NOTE(review): classify_epochs(cm, pair) is evaluated in this process while the
    # generator feeds Parallel, so only evaluate_nodes runs in the workers. Moving the
    # classification inside the delayed call would parallelize what is presumably the
    # expensive step, at the cost of sending `cm` to every worker.
    base = list(Parallel(n_jobs=-1)(
        delayed(evaluate_nodes)(pair, nodes, classify_epochs(cm, pair))
        for pair in combinations(node_ids, 2)
    ))

    print(f"{len(base)} finished")

    if not base:
        print("Fewer than two nodes: no node pairs to evaluate, skipping.")
        continue

    # Each evaluation is indexed as used below: [0] the tuple of node ids,
    # [1] the "a<->b<->..." network label, [-1] the score.
    base.sort(key=lambda x: x[-1], reverse=True)
    best_pair = base[0]
    best_net = [best_pair]
    print(f"Best node pair: {best_net}")

    best_score = best_pair[-1]
    net_size = 3
    possible_node_groups = base[:]
    test_nets = []
    print("Best score =", best_score)

    # All tested node groups, keyed by group size (number of grouped nodes).
    all_node_groups = {2: base}

    while net_size <= max_net_size:
        all_node_groups[net_size] = []
        print(f"\nChecking networks with {net_size} nodes...")

        # Expand only the top-scoring groups; there may be more than one (ties).
        head = check_until(possible_node_groups, fall=best_score)
        count_node_groups = 0

        for node_group in possible_node_groups[:head if head > 0 else 1]:
            # Try adding every node that is not already in the group.
            for node in node_ids:
                if node in node_group[0]:
                    continue
                test_group = node_group[0] + (node,)

                # Classify baseline vs WOI epochs (per the original notes: SVM with
                # K-Fold cross-validation) and score the cross-validation results.
                evaluation = evaluate_nodes(test_group, nodes, classify_epochs(cm, test_group))

                test_nets.append(evaluation)
                all_node_groups[net_size].append(evaluation)
                count_node_groups += 1  # count only groups that were actually tested

        print(f"Tested {count_node_groups} node groups.")

        if not test_nets:
            # Every node is already in the best group(s): no larger network can be formed.
            print("A better network not found.")
            break

        # Sort this size's networks by score (index -1); the first one holds the best score.
        test_nets.sort(key=lambda x: x[-1], reverse=True)
        all_node_groups[net_size].sort(key=lambda x: x[-1], reverse=True)

        evaluation_score = test_nets[0][-1]
        print(f"Best score for networks of size {net_size} =", evaluation_score)
        print(f"Best network of size {net_size}: {test_nets[0][1]}")

        if evaluation_score < best_score:
            print("A better network not found.")
            break

        # A score at least as good as the previous best (ties included) is accepted:
        # the top groups of this size become the candidates for the next size.
        best_score = evaluation_score
        print("\nNew best score =", evaluation_score)

        # NOTE(review): unlike the expansion step above, this slice has no guard for
        # check_until returning <= 0; confirm check_until's return contract.
        best_net = test_nets[:check_until(test_nets, fall=evaluation_score)]
        print("\nNew best network =", best_net)

        possible_node_groups = best_net
        test_nets = []
        net_size += 1

    # Union of the node labels of the best group(s), sorted.
    selected_net = sorted({label for net in best_net for label in net[1].split('<->')})

    print(f"\nSelected network: {selected_net} ({len(selected_net)} nodes in total)")

    file_net = file_cm.split(".")[0]
    REc(struct(test_nets=all_node_groups, nodes=selected_net)).save(path_net + f"{file_net}.res")