# Trees

## Set working directory
Warning: run the cell below only once per kernel session.

In [None]:
import os
from pathlib import Path

# Move from the current directory (normally the notebook's folder) to its
# parent, the project root, so that relative paths such as "config.toml"
# resolve in the cells below.
# The root is resolved only on the first run in a kernel session: re-running
# this cell used to climb one more directory level every time.
if "project_root" not in globals():
    project_root = Path.cwd().parent
os.chdir(project_root)
print(os.getcwd())

## Loading data

In [None]:
import tomllib

from data import Data

# Parse the project configuration; tomllib only accepts files opened in binary mode.
with open("config.toml", "rb") as config_file:
    config = tomllib.load(config_file)

# Wrap the parsed configuration in the project's Data object (passed to TreeBuilder below).
data = Data(config)

## Setup for TreeBuilder

In [None]:
from ml.tree import TreeBuilder
from rpy2 import robjects  # used later to build R vectors (robjects.r.c) as ctree arguments

# A single builder around the loaded data, reused by both ctree sections below.
treebuilder = TreeBuilder(data)

## ctree: single tree

All plots are saved in `ml/[predictor]/` or, when multiple predictors are defined, in a combined directory such as `ml/[predictor1]_[predictor3]/`.

In [None]:
%matplotlib inline
from IPython.display import Image

model = treebuilder.build_ctree(
    testtype= "Bonferroni",
    teststat="quad",
    splittest=False,
    predictors=[
        "."]
)

treebuilder.save_tree(model=model, type=["img", "model"], testtype="Bonferroni", teststat="quad", splittest=False, predictors= ["."]) 
# Image(image_path)


## ctree: multiprocessed

Note: this may take a while, depending on the number of parameter combinations and the size of the dataset.

All plots are saved in `ml/[predictor]/` or, when multiple predictors are defined, in a combined directory such as `ml/[predictor1]_[predictor3]/`.

### Prepare parameters

A tree is built for every combination of the parameter values defined below.

In [None]:
import itertools

# Parameter grids for ctree. Every combination of the values below becomes one
# call to treebuilder.build_ctree in the next cell.
teststats: list[str] = ["quad", "max"]
# Entries are either a single test-type string or an R vector combining several
# (built with robjects.r.c), which is not a Python list, hence `object`.
testtypes: list[object] = [
    "Teststatistic",
    "Univariate",
    "Bonferroni",
    "MonteCarlo",
    robjects.r.c("MonteCarlo", "Bonferroni"),
]
splitstats: list[str] = ["quad", "max"]
# splittests: list[bool] = [True, False]
alphas: list[float] = [0.1, 0.05, 0.01]
predictors: list[list[str]] = [
    ["consensus independent component 1"],
    ["consensus independent component 2"],
    ["consensus independent component 3"],
]

# Materialise the Cartesian product as a list. A bare itertools.product
# iterator is consumed by the first Pool.starmap call, so re-running the next
# cell would then silently build no trees.
# NOTE(review): starmap passes each tuple POSITIONALLY, as
# (teststat, testtype, splitstat, alpha, predictors); confirm this matches the
# positional parameter order of TreeBuilder.build_ctree.
arg_combos = list(
    itertools.product(teststats, testtypes, splitstats, alphas, predictors)
)
print(f"{len(arg_combos)} parameter combinations")


In [None]:
from datetime import datetime
import multiprocessing as mp

started_at = datetime.now()
print(f"starting time: {started_at.time()}")

# Spread the build_ctree calls over a pool sized from config; each tuple in
# arg_combos is unpacked into one call. The returned values are not kept.
# If arg_combos is a one-shot iterator (e.g. a bare itertools.product), re-run
# the parameter cell before re-running this one.
# NOTE(review): workers use rpy2/R; confirm this works under the platform's
# multiprocessing start method.
n_cores = config["execution"]["cores"]
with mp.Pool(n_cores) as pool:
    pool.starmap(treebuilder.build_ctree, arg_combos)

elapsed = datetime.now() - started_at
print("Time taken:", elapsed)