In [1]:
import warnings


def warn(*args, **kwargs):
    """No-op stand-in for ``warnings.warn``: accepts any arguments, returns None."""
    return None


# Swap the stdlib entry point for the no-op above so that calls made through
# ``warnings.warn(...)`` after this point are silently discarded.
warnings.warn = warn

from genraweb.lib.genrapy import GenRAPredClassHybrid, GenRAPredBinary, GenRAPredBinaryHybrid, GenRAPredValueHybrid
from genra.rax.skl.cls import GenRAPredClass
from genra.rax.skl.reg import GenRAPredValue
import numpy as np
import pandas as pd

In [2]:
# Set up the data

# Observed outcomes for the five source chemicals, one single-column frame per
# endpoint type. Rows are labelled Chem_0 .. Chem_4.

# Binary endpoint; "Yes" is positive here.
pos_label = "Yes"
Y_binary = pd.DataFrame(
    ["Yes", "No", "Yes", "No", "Yes"],  # Chem_0 .. Chem_4
    index=[f"Chem_{i}" for i in range(5)],
)

# Multiclass endpoint with classes "A", "B" and "C".
Y_multiclass = pd.DataFrame(
    ["A", "B", "A", "C", "A"],  # Chem_0 .. Chem_4
    index=[f"Chem_{i}" for i in range(5)],
)

# Continuous endpoint.
Y_continuous = pd.DataFrame(
    [70, -10, 45, 90, 33],  # Chem_0 .. Chem_4
    index=[f"Chem_{i}" for i in range(5)],
)

# Per-fingerprint test fixtures. Each entry holds:
#   "fp"            - fingerprint name
#   "target"        - fingerprint of the single target chemical
#   "X"             - fingerprints of the source chemicals; an all-None row
#                     marks a chemical for which this fingerprint does not
#                     exist (DNE)
#   "binary" / "multiclass" / "continuous"
#                   - hand-computed "expected" predictions: outcome averages
#                     weighted by Jaccard similarity to the target
#   "hybrid_weight" - weight given to this fingerprint in the hybrid checks
#
# The fraction noted beside each row of "X" is that row's Jaccard similarity
# to "target".

# Sum of the similarities over the chemicals that have each fingerprint; this
# is the denominator of every similarity-weighted average below.
fp1_total = 2/4 + 1/3 + 1/4 + 3/4 + 2/4
fp2_total = 1/3 + 1/2 + 0/2 + 2/3
fp3_total = 2/3 + 0/4 + 3/4 + 1/4

data = [
    # fp1: five bits, available for all five chemicals.
    {
        "fp": "fp1",
        "target": pd.DataFrame([[0, 1, 1, 1, 0]]),
        "X": pd.DataFrame(
            [
                [1.0, 0, 1, 1, 0],  # Chem_0: 2/4 = 0.5
                [0, 1, 0, 0, 0],  # Chem_1: 1/3 ~= 0.3333
                [0, 0, 0, 1, 1],  # Chem_2: 1/4 = 0.25
                [1, 1, 1, 1, 0],  # Chem_3: 3/4 = 0.75
                [0, 0, 1, 1, 1],  # Chem_4: 2/4 = 0.5
            ],
            columns=[f"fp1_{i}" for i in range(5)],
            index=[f"Chem_{i}" for i in range(5)],
        ),
        # Probability of the positive class; this is also the SWA.
        "binary": {"expected": (2/4 + 1/4 + 2/4) / fp1_total},  # ~= 0.5357
        "multiclass": {
            "expected": pd.DataFrame(
                {
                    "A": (2/4 + 1/4 + 2/4) / fp1_total,  # ~= 0.5357
                    "B": (1/3) / fp1_total,  # ~= 0.1429
                    "C": (3/4) / fp1_total,  # ~= 0.3214
                },
                index=[0],
            ),
        },
        "continuous": {
            "expected": (70*(2/4) + -10*(1/3) + 45*(1/4) + 90*(3/4) + 33*(2/4)) / fp1_total,  # ~= 54.3929
        },
        "hybrid_weight": 1,
    },
    # fp2: three bits; Chem_4 has no fp2.
    {
        "fp": "fp2",
        "target": pd.DataFrame([[0, 1, 1]]),
        "X": pd.DataFrame(
            [
                [1, 0, 1],  # Chem_0: 1/3 ~= 0.3333
                [0, 1, 0],  # Chem_1: 1/2 = 0.5
                [0, 0, 0],  # Chem_2: 0/2 = 0
                [1, 1, 1],  # Chem_3: 2/3 ~= 0.6667
                [None, None, None],  # Chem_4: DNE so N/A
            ],
            columns=[f"fp2_{i}" for i in range(3)],
            index=[f"Chem_{i}" for i in range(5)],
        ),
        # Probability of the positive class; this is also the SWA.
        "binary": {"expected": (1/3 + 0/2) / fp2_total},  # ~= 0.2222
        "multiclass": {
            "expected": pd.DataFrame(
                {
                    "A": (1/3 + 0/2) / fp2_total,  # ~= 0.2222
                    "B": (1/2) / fp2_total,  # ~= 0.3333
                    "C": (2/3) / fp2_total,  # ~= 0.4444
                },
                index=[0],
            ),
        },
        "continuous": {
            "expected": (70*(1/3) + -10*(1/2) + 45*(0/2) + 90*(2/3)) / fp2_total,  # ~= 52.2222
        },
        "hybrid_weight": 4,
    },
    # fp3: four bits; Chem_1 has no fp3.
    {
        "fp": "fp3",
        "target": pd.DataFrame([[1, 0, 1, 1]]),
        "X": pd.DataFrame(
            [
                [1, 0, 1, 0],  # Chem_0: 2/3 ~= 0.6667
                [None, None, None, None],  # Chem_1: DNE so N/A
                [0, 1, 0, 0],  # Chem_2: 0/4 = 0
                [1, 1, 1, 1],  # Chem_3: 3/4 = 0.75
                [0, 1, 0, 1],  # Chem_4: 1/4 = 0.25
            ],
            columns=[f"fp3_{i}" for i in range(4)],
            index=[f"Chem_{i}" for i in range(5)],
        ),
        # Probability of the positive class; this is also the SWA.
        "binary": {"expected": (2/3 + 0/4 + 1/4) / fp3_total},  # = 0.55
        "multiclass": {
            "expected": pd.DataFrame(
                {
                    "A": (2/3 + 0/4 + 1/4) / fp3_total,  # = 0.55
                    # "B" DNE so N/A (Chem_1, the only "B" chemical, has no fp3)
                    "C": (3/4) / fp3_total,  # = 0.45
                },
                index=[0],
            ),
        },
        "continuous": {
            "expected": (70*(2/3) + 45*(0/4) + 90*(3/4) + 33*(1/4)) / fp3_total,  # = 73.45
        },
        "hybrid_weight": 3,
    },
]

In [3]:
# set up the models, predictions, and uncertainty calculations

# Accumulators for the hybrid models: one entry per fingerprint, in `data` order.
fit_X, predict_X, hybrid_weights = [], [], []
fit_Y_binary, fit_Y_multiclass, fit_Y_continuous = [], [], []

for component in data:
    fp = component["fp"]
    target = component["target"]

    # Rows containing None/NaN are removed from X here and from each outcome
    # frame below, so X and the outcomes stay aligned.
    nan_rows = component["X"].index[component["X"].isnull().any(axis=1)]
    X = component["X"].drop(index=nan_rows)

    # prepare lists for hybrid
    fit_X.append(X)
    predict_X.append(target)
    hybrid_weights.append(component["hybrid_weight"])

    # Settings shared by the three component models: Jaccard metric,
    # (1 - distance) as the neighbour weight, and as many neighbours as
    # there are remaining rows in X.
    model_kwargs = dict(
        metric='jaccard',
        weights=lambda distances: 1 - distances,
        n_neighbors=X.shape[0],
    )

    # binary (component)
    fp_Y = Y_binary.drop(index=nan_rows)
    model = GenRAPredBinary(**model_kwargs)
    model.fit(X, fp_Y)
    # Order matters: predict, then predict_proba, then calc_uncertainty.
    component["binary"].update(
        {
            "model": model,
            "pred": model.predict(target),
            "proba": pd.DataFrame(model.predict_proba(target), columns=model.classes_),
            "uncertainty": model.calc_uncertainty(pos_label=pos_label),
        }
    )
    fit_Y_binary.append(fp_Y)

    # multiclass (component)
    fp_Y = Y_multiclass.drop(index=nan_rows)
    model = GenRAPredClass(**model_kwargs)
    model.fit(X, fp_Y)
    component["multiclass"].update(
        {
            "model": model,
            "pred": model.predict(target),
            "proba": pd.DataFrame(model.predict_proba(target), columns=model.classes_),
        }
    )
    fit_Y_multiclass.append(fp_Y)

    # continuous (component)
    fp_Y = Y_continuous.drop(index=nan_rows)
    model = GenRAPredValue(**model_kwargs)
    model.fit(X, fp_Y)
    component["continuous"].update(
        {
            "model": model,
            "pred": model.predict(target),
        }
    )
    fit_Y_continuous.append(fp_Y)

# Results of the hybrid (multi-fingerprint) models, keyed by endpoint type.
hybrid_results = {"binary": {}, "multiclass": {}, "continuous": {}}

# Largest row count among the "X" frames stored in `data` (these still include
# any all-None rows).
n_neighbors = max(component["X"].shape[0] for component in data)

# Settings shared by the three hybrid models.
hybrid_kwargs = dict(
    metric='jaccard',
    weights=lambda distances: 1 - distances,
    n_neighbors=n_neighbors,
)

# hybrid binary
model = GenRAPredBinaryHybrid(**hybrid_kwargs)
model.fit(fit_X, fit_Y_binary)
# Order matters: predict, then predict_proba, then calc_uncertainty.
hybrid_results["binary"].update(
    {
        "model": model,
        "pred": model.predict(predict_X, hybrid_weights=hybrid_weights),
        "proba": pd.DataFrame(
            model.predict_proba(predict_X, hybrid_weights=hybrid_weights),
            columns=model.classes_,
        ),
        "uncertainty": model.calc_uncertainty(pos_label=pos_label, hybrid_weights=hybrid_weights),
    }
)

# hybrid multiclass
model = GenRAPredClassHybrid(**hybrid_kwargs)
model.fit(fit_X, fit_Y_multiclass)
hybrid_results["multiclass"].update(
    {
        "model": model,
        "pred": model.predict(predict_X, hybrid_weights=hybrid_weights),
        "proba": pd.DataFrame(
            model.predict_proba(predict_X, hybrid_weights=hybrid_weights),
            columns=model.classes_,
        ),
    }
)

# hybrid continuous
model = GenRAPredValueHybrid(**hybrid_kwargs)
model.fit(fit_X, fit_Y_continuous)
hybrid_results["continuous"].update(
    {
        "model": model,
        "pred": model.predict(predict_X, hybrid_weights=hybrid_weights),
    }
)

In [6]:
# check the results

def is_equal(a, b, tol=0.00001):
    """Return whether ``a`` and ``b`` differ by strictly less than ``tol``.

    Parameters
    ----------
    a, b : numbers
        Values to compare.
    tol : float, optional
        Absolute tolerance for the comparison. Defaults to 0.00001.

    Returns
    -------
    bool
        Result of ``abs(a - b) < tol`` (a NumPy boolean when the inputs are
        NumPy scalars).
    """
    return abs(a - b) < tol

# Component-level predictions, collected for the hybrid checks further down.
binary, multiclass, continuous = [], [], []
for component in data:
    # component level check

    # binary
    # (this is also the SWA, since it's the probability of positive observation)
    swa = component["binary"]["proba"][pos_label][0]
    assert is_equal(swa, component["binary"]["expected"])
    binary.append(swa)

    # multiclass: only classes the component model was fitted on are compared
    fitted_classes = component["multiclass"]["model"].classes_
    classed = {}
    for _class in ["A", "B", "C"]:
        if _class not in fitted_classes:
            continue
        proba = component["multiclass"]["proba"][_class][0]
        assert is_equal(proba, component["multiclass"]["expected"][_class][0])
        classed[_class] = proba
    multiclass.append(classed)

    # continuous
    pred = component["continuous"]["pred"][0][0]
    assert is_equal(pred, component["continuous"]["expected"])
    continuous.append(pred)

# Each hybrid result is compared with the average of the component results
# weighted by the per-fingerprint hybrid weights.
weights = np.array(hybrid_weights)
sum_weights = np.sum(weights)

# hybrid binary
assert is_equal(
    hybrid_results["binary"]["proba"][pos_label][0],
    np.dot(np.array(binary), weights)/sum_weights,
)

# hybrid multiclass (a class missing from a component contributes 0)
for _class in ["A", "B", "C"]:
    probas = np.array([classed.get(_class, 0) for classed in multiclass])
    assert is_equal(
        hybrid_results["multiclass"]["proba"][_class][0],
        np.dot(probas, weights)/sum_weights,
    )

# hybrid continuous
assert is_equal(
    hybrid_results["continuous"]["pred"][0][0],
    np.dot(np.array(continuous), weights)/sum_weights,
)
