In [2]:
from dense_subgraph import sdp
import numpy as np
import json
import utils
import classification
from pipeline import Pipeline
from cs_transformer import ContrastSubgraphTransformer
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler

In [3]:
# Seed handed to the CV splitter and the grid search below. Keep it an int for
# repeatable runs; switch it to None to get a fresh shuffle every time.
random_state = 42

# Class labels attached to the two groups of graphs (group A and group B).
A_LABEL = "ASD"
B_LABEL = "TD"

# Root folder of the pre-processed graphs, and the tag used in output file names.
GRAPH_DIR_PREFIX = "./data/lanciano_datasets_corr_thresh_80/"
DATA_DESCRIPTOR = "Lanciano-Processed"

# Best percentiles reported by Lanciano et al., keyed by problem, then dataset.
# CSP1 stores one percentile per direction (TD-ASD and ASD-TD); CSP2 stores a
# single value per dataset.
best_params = {
    "CSP1": {
        dataset: {"percentile_TD_ASD": td_asd, "percentile_ASD_TD": asd_td}
        for dataset, td_asd, asd_td in (
            # (dataset, TD-ASD, ASD-TD)
            ("children", 70, 80),
            ("adolescents", 95, 70),
            ("eyesclosed", 75, 75),
            ("male", 75, 70),
        )
    },
    "CSP2": dict(children=70, adolescents=75, eyesclosed=70, male=70),
}

In [5]:
# Cross-validate the CSP1 and CSP2 SDP pipelines on every dataset and write one
# results file per (problem, dataset) pair under ./outputs/.
for dataset_name in ("children", "adolescents", "eyesclosed", "male"):
    dataset_dir = f"{GRAPH_DIR_PREFIX}{dataset_name}"
    A_GRAPH_DIR = f"{dataset_dir}/asd/"
    B_GRAPH_DIR = f"{dataset_dir}/td/"

    graphs_A = utils.get_graphs_from_files(A_GRAPH_DIR)
    graphs_B = utils.get_graphs_from_files(B_GRAPH_DIR)

    # Group sizes passed along to the results writer; which group supplies
    # which count depends on how the labels are configured.
    if A_LABEL == "ASD":
        asd_count = len(graphs_A)
    else:
        asd_count = len(graphs_B)
    if B_LABEL == "TD":
        td_count = len(graphs_B)
    else:
        td_count = len(graphs_A)

    graphs, labels = utils.label_and_concatenate_graphs(
        graphs_A=graphs_A,
        graphs_B=graphs_B,
        a_label=A_LABEL,
        b_label=B_LABEL,
    )

    pipe = [ContrastSubgraphTransformer, StandardScaler, SVC]
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)

    csp1_percentiles = best_params["CSP1"][dataset_name]
    # One row per experiment: (file tag, problem number, percentile, percentile2).
    sdp_runs = (
        # CSP1 SDP: percentile is the ASD - TD value, percentile2 the TD - ASD value.
        ("CSP1", 1, csp1_percentiles["percentile_ASD_TD"], csp1_percentiles["percentile_TD_ASD"]),
        # CSP2 SDP: a single percentile; percentile2 is left as None.
        ("CSP2", 2, best_params["CSP2"][dataset_name], None),
    )

    for problem_tag, problem_id, first_percentile, second_percentile in sdp_runs:
        # Every entry is a one-element list, so the "grid" is a single
        # fixed configuration per run.
        p_grid = {
            "SVC": {"C": [100], "gamma": [0.1]},
            "ContrastSubgraphTransformer": {
                "a_label": [A_LABEL],
                "b_label": [B_LABEL],
                "alpha": [None],
                "alpha2": [None],
                "percentile": [first_percentile],
                "percentile2": [second_percentile],
                "problem": [problem_id],
                "solver": [sdp],
                "num_cs": [1],
            },
        }

        results, _ = classification.grid_search_cv(
            X=graphs,
            y=labels,
            pipeline_steps=pipe,
            step_param_grids=p_grid,
            cv=cv,
            random_state=random_state,
        )
        classification.write_results_to_file(
            filename=f'./outputs/{DATA_DESCRIPTOR}-CV-{problem_tag}-SDP-N1-{dataset_name}.txt',
            summary=results["summary"],
            results=results["best_results"],
            parameter_grid=p_grid,
            asd_count=asd_count,
            td_count=td_count,
        )

grid_search_cv will take aproximately 0:00:41.199440
grid_search_cv will take aproximately 0:00:11.062035
grid_search_cv will take aproximately 0:00:35.270875
grid_search_cv will take aproximately 0:00:15.782120
grid_search_cv will take aproximately 0:00:35.029075
grid_search_cv will take aproximately 0:00:13.517770
grid_search_cv will take aproximately 0:00:40.386160
grid_search_cv will take aproximately 0:00:16.623225
