In [1]:
%load_ext autoreload
%autoreload 2

from shared import Db

from main.ai import ai_setups
from main.data import Mol, data
from main.tree import AiTree, DictAiTree
from main.types import AiInput, Setup

from IPython.display import display

import numpy as np
import pandas

# Both database handles stay open for the whole analysis; they are entered in
# the same order as before (scores_merged first, then db).
with Db("scores_merged", True) as db_scores_merged, Db("db", True) as db:

    def get_stats(mol: Mol, setup: Setup):
        """Return the ``stats()`` of the AiTree stored for ``mol`` under ``setup``.

        The tree is read from ``db`` under the ``"ai_postprocess"`` key for this
        molecule's SMILES and setup, then wrapped in an ``AiTree`` together with
        the ``db_scores_merged`` handle.  The result is indexed below by
        characteristic name, so it is presumably a mapping of name -> number.
        """
        ai_input: AiInput = {"smiles": mol.smiles, "setup": setup}
        stored_tree = db.read(["ai_postprocess", ai_input], DictAiTree)
        return AiTree((stored_tree, db_scores_merged)).stats()

    mols = data()

    # One entry per setup: the setup itself plus the stats of every molecule.
    setup_and_stats = [
        (setup, [get_stats(mol, setup) for mol in mols])
        for setup in ai_setups
    ]

    # Show one summary table per tree characteristic, one row per setup.
    for characteristic in ("max_depth", "max_width", "node_count", "not_solved_count"):
        display(characteristic)
        rows: list[tuple[str, float, float, float]] = []
        for cfg, per_mol_stats in setup_and_stats:
            values = [stats[characteristic] for stats in per_mol_stats]
            label = f"{cfg['score']}-{cfg['agg']}-{cfg['uw_multiplier']}-{cfg['normalize']}"
            # .item() converts the NumPy scalars to plain Python numbers.
            spread = np.std(values).item()
            mean = np.average(values).item()
            middle = np.median(values).item()
            rows.append((label, spread, mean, middle))
        display(pandas.DataFrame(rows, columns=["setup", "std", "avg", "median"]))

'max_depth'

Unnamed: 0,setup,std,avg,median
0,"sc-max-0.0-(2.5, 4.5, False)",3.15806,8.836735,9.0
1,"sc-max-0.15-(2.5, 4.5, False)",3.0817,8.816327,9.0
2,"sc-max-0.15-(3.0, 4.5, False)",3.028258,8.816327,9.0
3,"sc-max-0.15-(3.5, 4.5, False)",2.906713,8.714286,9.0
4,"sc-max-0.4-(2.5, 4.5, False)",3.063944,8.857143,9.0
5,"sc-max-0.4-(3.0, 4.5, False)",3.048409,8.816327,9.0
6,"sc-max-0.4-(3.5, 4.5, False)",3.04376,8.795918,9.0
7,"sc-max-0.6-(2.5, 4.5, False)",3.026056,8.836735,9.0
8,"sc-max-0.6-(3.0, 4.5, False)",3.119046,8.836735,9.0
9,"sc-max-0.6-(3.5, 4.5, False)",3.011985,8.77551,9.0


'max_width'

Unnamed: 0,setup,std,avg,median
0,"sc-max-0.0-(2.5, 4.5, False)",78.871872,114.591837,101.0
1,"sc-max-0.15-(2.5, 4.5, False)",77.302421,114.265306,101.0
2,"sc-max-0.15-(3.0, 4.5, False)",71.402281,113.571429,103.0
3,"sc-max-0.15-(3.5, 4.5, False)",80.725985,117.734694,104.0
4,"sc-max-0.4-(2.5, 4.5, False)",72.661941,114.55102,103.0
5,"sc-max-0.4-(3.0, 4.5, False)",78.814889,116.734694,103.0
6,"sc-max-0.4-(3.5, 4.5, False)",68.788111,113.306122,103.0
7,"sc-max-0.6-(2.5, 4.5, False)",91.918707,120.367347,105.0
8,"sc-max-0.6-(3.0, 4.5, False)",93.702041,120.734694,105.0
9,"sc-max-0.6-(3.5, 4.5, False)",79.533076,117.428571,104.0


'node_count'

Unnamed: 0,setup,std,avg,median
0,"sc-max-0.0-(2.5, 4.5, False)",373.255247,614.102041,622.0
1,"sc-max-0.15-(2.5, 4.5, False)",365.726489,614.367347,629.0
2,"sc-max-0.15-(3.0, 4.5, False)",365.461167,622.816327,632.0
3,"sc-max-0.15-(3.5, 4.5, False)",369.438513,626.959184,634.0
4,"sc-max-0.4-(2.5, 4.5, False)",357.939257,624.755102,632.0
5,"sc-max-0.4-(3.0, 4.5, False)",363.3189,627.77551,633.0
6,"sc-max-0.4-(3.5, 4.5, False)",358.565557,627.081633,629.0
7,"sc-max-0.6-(2.5, 4.5, False)",375.355114,638.387755,633.0
8,"sc-max-0.6-(3.0, 4.5, False)",366.142283,635.693878,646.0
9,"sc-max-0.6-(3.5, 4.5, False)",361.020515,630.938776,642.0


'not_solved_count'

Unnamed: 0,setup,std,avg,median
0,"sc-max-0.0-(2.5, 4.5, False)",23.872108,15.857143,1.0
1,"sc-max-0.15-(2.5, 4.5, False)",24.029951,15.897959,1.0
2,"sc-max-0.15-(3.0, 4.5, False)",23.925906,15.857143,1.0
3,"sc-max-0.15-(3.5, 4.5, False)",23.623928,15.693878,1.0
4,"sc-max-0.4-(2.5, 4.5, False)",23.540673,15.285714,1.0
5,"sc-max-0.4-(3.0, 4.5, False)",23.910548,15.571429,1.0
6,"sc-max-0.4-(3.5, 4.5, False)",23.156315,15.22449,1.0
7,"sc-max-0.6-(2.5, 4.5, False)",22.902597,14.959184,1.0
8,"sc-max-0.6-(3.0, 4.5, False)",22.597358,14.632653,1.0
9,"sc-max-0.6-(3.5, 4.5, False)",23.008392,14.959184,1.0
