In [None]:
from src.game import classify_epochs, evaluate_nodes, check_until
# from src.core import REc

from itertools import combinations
from os import listdir, makedirs
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
import random as rd

In [None]:
from inspect import isfunction, ismethod, isgeneratorfunction, isgenerator, isroutine
from inspect import isabstract, isclass, ismodule, istraceback, isframe, iscode, isbuiltin
from inspect import ismethoddescriptor, isdatadescriptor, isgetsetdescriptor, ismemberdescriptor
from inspect import isawaitable, iscoroutinefunction, iscoroutine

from collections.abc import Iterable as iterable

from pickle import load, dump

def isfx(field):
    """ True when *field* is a plain function or a bound method (builtins excluded) """
    return isfunction(field) or ismethod(field)

class GhostSet:
    """ enhanced interface (ghost) to retrieve class fields:
        _meta returns a copy of the instance __dict__ without functions/bound methods,
        _set assigns keyword fields in bulk and then calls the _at_last hook """
    def _meta(data):
        return {name: value for name, value in data.__dict__.items() if not isfx(value)}
    def _at_last(_, sets):
        # hook for subclasses; receives the dict of fields just assigned by _set
        return None
    def _set(object, **sets):
        ''' use to fast initialize fields | needed to avoid initialization problems at copy by value '''
        for field, value in sets.items():
            setattr(object, field, value)
        object._at_last(sets)
GSet = GhostSet

def meta(object):
    ''' retrieves clonable object metadata (__dict__) as a copy;
        objects that are not GhostSet instances yield an empty dict '''
    return object._meta() if isinstance(object, GSet) else {}

class ClonableObjectGhost:
    """ enhanced interface (ghost) for clonable objects:
        subclasses override _by_val to return a by-value copy; this base returns None """
    def _by_val(_, depth=-1, _layer=0):
        return None
GCo = ClonableObjectGhost

class ClonableObject(GSet, GCo):
    """ base clonable object: fields are given as keywords (at construction or by calling
        the instance) and the object can be copied by value with _by_val """
    def __init__(this, **data):
        this._set(**data)
    def __call__(_, **options):
        _._set(**options)
    def _by_val(_, depth=-1, _layer=0):
        ''' copy built with a no-argument constructor and this object's _meta fields;
            nested clonable fields are copied recursively while depth < 0 (unlimited)
            or depth > current layer, otherwise they are shared with the original '''
        duplicate = type(_)()
        duplicate._set(**_._meta())
        if depth < 0 or depth > _layer:
            fields = duplicate.__dict__
            for name in fields:
                value = fields[name]
                if isinstance(value, ClonableObjectGhost):
                    # replacing the value of an existing key: the dict size does not change
                    fields[name] = value._by_val(depth, _layer + 1)
        return duplicate
COb = ClonableObject

def copy_by_val(object, depth=-1, _layer=0):
    ''' by-value copy of clonable objects (delegates to their _by_val);
        any other object is returned as is (no copy) '''
    if not isinstance(object, GCo):
        return object
    return object._by_val(depth, _layer)
copy = by_val = vof = copy_by_val

class ComparableGhost:
    """ enhanced interface (ghost) for comparing instances:
        equal when both have exactly the same type and equal __dict__.
        Defining __eq__ without __hash__ makes Python set __hash__ to None,
        so instances are unhashable """
    def _compare(a, b):
        same_kind = type(a) == type(b)
        return same_kind and a.__dict__ == b.__dict__
    def __eq__(a, b):
        return a._compare(b)
GEq = ComparableGhost

class IterableObjectGhost(GSet):
    """ enhanced interface (ghost) for iterables: exposes __dict__ through the mapping
        protocol (in, iteration over field names, items, [], pop),
        therefore Iterable Objects are like lua dictionaries """
    def __contains__(this, key):
        return key in vars(this)
    def __iter__(this):
        return iter(vars(this))
    def items(my):
        return vars(my).items()
    def __getitem__(by, field):
        return vars(by)[field]
    def __setitem__(by, field, value):
        vars(by)[field] = value
    def pop(by, field):
        return vars(by).pop(field)
GIo = IterableObjectGhost

class ReprGhost:
    """ enhanced interface (ghost) for the skeleton method _repr,
        see implementation of Struct for a working example;
        Record __repr__ override uses _lines_ for max lines display """
    _lines_ = 31          # line budget used by Record.__repr__
    _chars_ = 13          # longest value text _repr shows before eliding its middle
    _msgsz_ = 62          # longest message _resize keeps before eliding
    _ellipsis_ = ' ... '
    def _repr(my, value):
        ''' "<type> value", eliding the middle of the value text when longer than _chars_ '''
        # str(type(5)) == "<class 'int'>"  ->  "<int>"
        _type = ''.join(''.join(str(type(value)).split('class ')).split("'"))
        _value = '{}'.format(value)
        if len(_value) > my._chars_:
            show = int(my._chars_/2)
            # slice the tail from an explicit start: _value[-0:] would be the whole string
            _value = _value[:show] + my._ellipsis_ + _value[len(_value)-show:]
        return '{} {}'.format(_type, _value)
    def _resize(this, message, at=.7):
        ''' shorten *message* to _msgsz_ characters (plus the ellipsis) when longer:
            the first int(at*_msgsz_) characters and the remaining share from the end '''
        if len(message) > this._msgsz_:
            start = int(at*this._msgsz_)
            end = this._msgsz_ - start
            # explicit start index: message[-0:] (end == 0, e.g. at=1) would keep everything
            return message[:start] + this._ellipsis_ + message[len(message)-end:]
        return message
GRe = ReprGhost

def set_repr_to(lines):
    ''' change the class-wide line budget (_lines_) shared by every ReprGhost subclass '''
    setattr(GRe, '_lines_', lines)

class Struct(COb, GEq, GIo, GRe):
    """ structured autoprintable object, behaves like a lua dictionary;
        printed as one "field:<TAB><type> value" line per field """
    def __repr__(_):
        rows = ('{}:\t{}'.format(key, _._repr(value)) for key, value in _.items())
        return '\n'.join(rows)
struct = Struct

class RecordableGhost:
    """ enhanced interface (ghost) for type recording through pickle,
        see Record for a working example """
    @staticmethod
    def load(filename):
        ''' unpickle and return the object stored in *filename*.
            WARNING: unpickling can execute arbitrary code; only load trusted files '''
        with open(filename, 'rb') as handle:
            # `load` here is the module-level pickle.load, not this staticmethod
            return load(handle)
    def save(data, filename):
        ''' pickle *data* (the instance itself when called as a method) into *filename* '''
        with open(filename, 'wb') as handle:
            dump(data, handle)

GRec = RecordableGhost

class Record(GSet, GCo, GRec, GEq, GRe):
    """ wrapper for any object or value, auto-inspects and provides load/save type structure:
        every `inspect` predicate in _check is evaluated on the wrapped token at construction
        and stored as a boolean attribute of the same name (isfunction, isgenerator, ...) """
    data = None
    _check = dict(
            isfunction=isfunction, ismethod=ismethod, isgeneratorfunction=isgeneratorfunction, isgenerator=isgenerator, isroutine=isroutine,
            isabstract=isabstract, isclass=isclass, ismodule=ismodule, istraceback=istraceback, isframe=isframe, iscode=iscode, isbuiltin=isbuiltin,
            ismethoddescriptor=ismethoddescriptor, isdatadescriptor=isdatadescriptor, isgetsetdescriptor=isgetsetdescriptor, ismemberdescriptor=ismemberdescriptor,
            isawaitable=isawaitable, iscoroutinefunction=iscoroutinefunction, iscoroutine=iscoroutine
                   )
    def __init__(this, token, **meta):
        ''' wrap *token*; extra keyword fields are assigned with _set after the inspections '''
        this.data = token
        this.__dict__.update({k:v(token) for k,v in this._check.items()})
        super()._set(**meta)
    @property
    def type(_):
        ''' type of the wrapped data '''
        return type(_.data)
    def inherits(_, *types):
        ''' True when the wrapped data's type is a subclass of any of *types* '''
        return issubclass(_.type, types)
    @property
    def isbaseiterable(_):
        ''' builtin container (tuple/list/dict/set), generator or generator function '''
        return _.inherits(tuple, list, dict, set) or _.isgenerator or _.isgeneratorfunction
    @property
    def isiterable(_):
        ''' instance of collections.abc.Iterable, strings excluded '''
        return isinstance(_.data, iterable) and _.type is not str
    def _clone_iterable(_):
        ''' shallow copy of the wrapped iterable, rebuilt with the same container type '''
        if _.inherits(dict): return _.data.copy()
        if _.isgeneratorfunction:
            # a function is not iterable (list() would raise); there is nothing to clone
            return _.data
        if _.isgenerator:
            # a generator can be read only once: materialize it and re-seed self.data,
            # so that cloning does not leave the original Record with an exhausted generator
            items = list(_.data)
            _.data = (i for i in items)
            return (i for i in items)
        return type(_.data)(list(_.data))
    def _meta(data):
        ''' field copy without the wrapped data and without functions/bound methods '''
        return {k:v for k,v in data.__dict__.items() if k != 'data' and not isfx(v)}
    def _by_val(_, depth=-1, layer=0):
        ''' copy by value: iterables are shallow-cloned, clonable objects copied with by_val,
            anything else is shared; the metadata fields are carried over to the copy '''
        data = _.data
        if _.isiterable: data = _._clone_iterable()
        elif _.inherits(ClonableObjectGhost): data = by_val(data, depth, layer)
        return type(_)(data, **meta(_))
    def __enter__(self): self._instance = self; return self
    def __exit__(self, type, value, traceback): self._instance = None
    def __repr__(self):
        ''' indented listing of the wrapped data, limited to _lines_ item lines overall '''
        if not hasattr(self, '_preprint'):
            # top-level call: re-wrap with an indentation prefix and a shared line budget
            # (a Record around an int counter, decremented by every nested level)
            data = self.data
            if self.isgenerator:
                # snapshot so printing does not exhaust the wrapped generator
                # (generators nested inside containers are still consumed)
                items = list(data)
                self.data = (i for i in items)
                data = (i for i in items)
            return Record(data, _preprint='', _lines=Record(Record._lines_)).__repr__()
        # generator functions count as "base iterable" but cannot be enumerated
        if self.isbaseiterable and not self.isgeneratorfunction:
            pre, text = self._preprint, ''
            for n, i in enumerate(self.data):
                if self._lines.data == 0: break   # line budget exhausted: truncate
                self._lines.data -= 1
                index, item = str(n), i
                if self.inherits(dict): index += ' ({})'.format(str(i)); item = self.data[i]
                text += pre + '{}: '.format(index)
                child = Record(item, _preprint=pre+'\t', _lines=self._lines)
                if child.isiterable: text += '\n'
                text += child.__repr__()
                text += '\n'
            return text
        # meta() returns {} for clonable objects that have no _meta (non-GhostSet ones)
        elif self.inherits(GCo): return Record(meta(self.data), _preprint=self._preprint, _lines=self._lines).__repr__()
        else: return self._repr(self.data)
REc = Record

class Bisect(list, COb):
    """ list sorted once at construction (optionally by *key* / *reverse*) that supports
        a binary search; copyable by value through the ClonableObject machinery """
    def __init__(set, *items, key=None, reverse=False):
        ordering = key if key else (lambda x: x)
        super().__init__(sorted(items, reverse=reverse, key=ordering))
    def _bisect(set, item, key, reverse, bottom, top):
        ''' binary search of *item* in the index range [bottom, top) of this sorted list.
            Returns (index, 0) on an exact match; otherwise (index, -1) or (index, 1) for the
            neighbouring index the item would go before / after; (0, -1) when the list is empty '''
        measure = key if key else (lambda value: value)
        low, high = bottom, top
        while True:
            at = (high - low) // 2 + low
            if len(set) == 0:
                return (0, -1)
            pivot = measure(set[at])
            if item == pivot:
                return (at, 0)
            # "left" means towards smaller keys, or towards larger ones when reversed
            if (item < pivot) != reverse:
                if at - low > 0:
                    high = at
                    continue
                return (at, -1)
            if high - at > 1:
                low = at
                continue
            return (at, 1)
    def search(_, item, key=None, reverse=False):
        ''' binary search over the whole list; see _bisect for the result format '''
        return _._bisect(item, key if key else (lambda x: x), reverse, 0, len(_))
    def _by_val(_, depth=-1, _layer=0):
        ''' clone: ClonableObject copies the fields onto an empty Bisect, then the items
            (already in order) are appended '''
        clone = super()._by_val(depth, _layer)
        clone.extend(_)
        return clone
BSx = Bisect

In [None]:
# Root of the project data; every other path in this notebook is built from it
main_folder = "/home/kivi/gdrive/epigame-folder/"

# Input folder holding the pickled connectivity-matrix files
path_cm = f"{main_folder}connectivity_matrices/"

In [None]:
# Menu choice -> time-window code matched against the second '-'-separated field
# of the connectivity-matrix file names
woi_code = {'1':"baseline", '2':"preseizure5", '3':"preseizure4", '4':"preseizure3", '5':"preseizure2", '6':"preseizure1", '7':"transition1", '8':"transition2", '9':"transition60", '10':"seizure"}

woi = input("Time window:\n 1. Non-seizure (baseline)\n 2. Pre-seizure (5 min prior to seizure)\n 3. Pre-seizure (4 min prior to seizure)\n 4. Pre-seizure (3 min prior to seizure)\n 5. Pre-seizure (2 min prior to seizure)\n 6. Pre-seizure (1 min prior to seizure)\n 7. Transition to seizure (1 min interval)\n 8. Transition to seizure (2 min interval)\n 9. Transition to seizure (60% seizure length interval)\n 10. Seizure\n Indicate a number: ").strip()
# Re-prompt on an unknown choice instead of failing later with a KeyError on woi_code[woi]
while woi not in woi_code:
    woi = input(f"Invalid choice {woi!r}. Indicate a number from 1 to {len(woi_code)}: ").strip()

In [None]:
# Largest network size (number of nodes) explored by the search below; it also caps how
# many top-scoring networks are randomly kept for each iteration
max_net_size: int = 18

In [None]:
# Output folder for the selected networks (.res files); created when missing
path_net = f"{main_folder}selected_network/"
makedirs(path_net, exist_ok=True)

In [None]:
def _save_selected_network(file_cm, all_node_groups, selected_net):
    """Print the selected network and pickle it, together with every tested node group,
    to path_net + "<file name up to its first '.'>.res"."""
    print(f"\nSelected network: {selected_net} ({len(selected_net)} nodes in total)")
    file_net = file_cm.split(".")[0]
    REc(struct(test_nets=all_node_groups, nodes=selected_net)).save(path_net + f"{file_net}.res")


def _nodes_in(networks):
    """Return the sorted, de-duplicated node labels of *networks*, read from the
    '<->'-separated label string each evaluation holds at index 1."""
    return sorted({label for network in networks for label in network[1].split('<->')})


for file_cm in listdir(path_cm):

    # Only process files whose second '-'-separated field is the chosen time window.
    # Names without a '-' (hidden or checkpoint files, e.g.) are skipped instead of raising IndexError.
    name_fields = file_cm.split("-")
    if len(name_fields) < 2 or name_fields[1] != woi_code[woi]:
        continue

    print("\n--------------------------------------------------------------")
    print("\nProcessing...")

    subject_id = file_cm.split("/")[-1][0:3]  # first three characters of the file name; not used in this cell
    print("Connectivity matrices of", file_cm)

    cm = REc.load(path_cm + file_cm).data  # pickle load: only open trusted files

    nodes = cm.nodes
    node_ids = list(range(len(nodes)))
    print("Number of nodes =", len(nodes))
    print("\nNodes:", nodes)

    print("\nTotal number of epochs =", len(cm.X))
    print("Connectivity matrix shape =", cm.X[0].shape)
    print("All matrices have the same shape:", all(m.shape == (len(nodes), len(nodes)) for m in cm.X))
    # plt.figure(figsize=(5,5))
    # plt.imshow(cm.X[-1], cmap='Blues', interpolation='nearest')
    # plt.show()
    # print(cm.X[-1])

    node_pairs = combinations(node_ids, 2)

    print("\nProcessing node combinations...")

    # NOTE(review): call arguments are evaluated before delayed(...) runs, so classify_epochs(cm, pair)
    # executes in this parent process while Parallel pulls tasks; only evaluate_nodes runs in the
    # joblib workers. Moving classify_epochs into the delayed call would parallelize it, at the cost
    # of sending `cm` to the workers -- confirm before changing.
    parallelize = Parallel(n_jobs=-1)(delayed(evaluate_nodes)(pair, nodes, classify_epochs(cm, pair)) for pair in node_pairs)
    base = list(parallelize)

    print(f"{len(base)} finished")

    if not base:
        # Fewer than two nodes: there is no pair to start the network search from.
        print("No node pairs to evaluate; skipping this file.")
        continue

    # Evaluations are sorted on their last element, the score (best first)
    base.sort(key=lambda x: x[-1], reverse=True)
    best_pair = base[0]
    best_net = [best_pair]
    print(f"Best node pair: {best_net}")

    best_score, net_size, possible_node_groups, test_nets = base[0][-1], 3, base[:], []
    print("Best score =", best_score)

    # All tested node groups, keyed by net_size (number of grouped nodes)
    all_node_groups = {2: base}

    while net_size <= max_net_size:

        all_node_groups[net_size] = []

        print(f"\nChecking networks with {net_size} nodes...")

        head = check_until(possible_node_groups, fall=best_score)

        count_node_groups = 0

        # If all candidate node groups have the same (best) score, stop and select all nodes as
        # the network. We predicted that this could occur in the seizure-propagation time window.
        # NOTE(review): after the first iteration possible_node_groups is best_net, which was already
        # cut with check_until(test_nets, fall=best_score); if check_until returns the number of leading
        # entries scoring >= fall, this test holds on every later iteration -- confirm against
        # src.game.check_until.
        if possible_node_groups[:head] == possible_node_groups:
            print("All possible networks present the best score.")
            selected_net = nodes
            _save_selected_network(file_cm, all_node_groups, selected_net)
            break

        # When many (but not all) networks share the best score, processing time could become
        # impractical, so at most max_net_size randomly picked top networks are carried over.
        # (A selected network much larger than the actual resection in good-outcome patients
        # would be useless anyway.)
        possible_node_groups = possible_node_groups[:head if head > 0 else 1]
        if len(possible_node_groups) >= max_net_size:
            print(f"At least {max_net_size} networks present the best score. Randomly selecting {max_net_size} networks from the pool.")
            possible_node_groups = rd.sample(possible_node_groups, max_net_size)

        # Grow every top-scoring group (there may be several) by each node it does not contain yet
        for node_group in possible_node_groups:
            for node in node_ids:
                if node in node_group[0]:
                    continue  # avoid duplicate nodes

                test_group = node_group[0] + (node,)

                # Classify baseline vs WOI epochs on this group (per the original notes: support vector
                # machine with K-Fold cross-validation) and apply the evaluation function to the scores
                evaluation = evaluate_nodes(test_group, nodes, classify_epochs(cm, test_group))

                # Keep the evaluation for this iteration (test_nets) and for the record (all_node_groups)
                test_nets.append(evaluation)
                all_node_groups[net_size].append(evaluation)
                count_node_groups += 1  # counts only groups actually tested

        print(f"Tested {count_node_groups} node groups.")

        if not test_nets:
            # Every candidate already holds all nodes (possible when there are fewer nodes than
            # max_net_size): nothing is left to add, so keep the current best network.
            print("No nodes left to add to the current networks.")
            selected_net = _nodes_in(best_net)
            _save_selected_network(file_cm, all_node_groups, selected_net)
            break

        # Sort the latest networks by their score (index -1), best first
        test_nets.sort(key=lambda x: x[-1], reverse=True)
        all_node_groups[net_size].sort(key=lambda x: x[-1], reverse=True)

        evaluation_score = test_nets[0][-1]

        print(f"Best score for networks of size {net_size} =", evaluation_score)
        print(f"Best network of size {net_size}: {test_nets[0][1]}")

        if evaluation_score >= best_score:
            # Score matched or improved: the top groups become the candidates for the next size
            best_score = evaluation_score
            print("\nNew best score =", best_score)

            head_i = check_until(test_nets, fall=best_score)
            best_net = test_nets[:head_i if head_i > 0 else 1]
            print("\nNew best network =", best_net)

            possible_node_groups = best_net
            test_nets = []
            net_size += 1
        else:
            print("A better network not found.")
            selected_net = _nodes_in(best_net)
            _save_selected_network(file_cm, all_node_groups, selected_net)
            break

        if net_size == max_net_size + 1:
            print("Reached the maximum network size.")
            selected_net = _nodes_in(best_net)
            _save_selected_network(file_cm, all_node_groups, selected_net)
            break