In [None]:
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.model_selection import cross_validate, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
import seaborn as sns
from joblib import Parallel, delayed
import torch
import pandas as pd

from mothernet.priors.boolean_conjunctions import BooleanConjunctionPrior

In [None]:
def function_of_rank(rank=2, length=4):
    """Enumerate all {-1, 1}**length inputs and label them with random conjunctions.

    Random conjunctions of ``rank`` signed literals are OR-ed together until at
    least a third of the ``2**length`` rows are labelled True.

    Parameters
    ----------
    rank : int
        Number of literals in each conjunction.
    length : int
        Number of boolean input features.

    Returns
    -------
    features : ndarray of float, shape (2**length, length)
        Every input combination, rescaled from {-1, 1} to {0, 1}.
    labels : ndarray of bool, shape (2**length,)
        Target value for each row.
    """
    grids = np.meshgrid(*([[-1, 1]] * length))
    signed_inputs = np.array(grids).T.reshape(-1, length)
    n_rows = signed_inputs.shape[0]
    labels = np.zeros(2 ** length, dtype=bool)
    # Keep adding terms to the disjunction until >= 1/3 of the rows are positive.
    while not labels.sum() * 3 >= n_rows:
        literal_idx = np.random.choice(length, size=rank, replace=False)
        literal_signs = np.random.choice([-1, 1], size=rank)
        term_holds = (literal_signs * signed_inputs[:, literal_idx] == 1).all(axis=1)
        labels = labels | term_holds
    return (signed_inputs + 1) / 2, labels

In [None]:
def function_of_rank_random_data(rank=2, max_length=20, max_samples=1000):
    """Draw random {-1, 1} data of random size and label it with random conjunctions.

    The number of features is drawn from ``[rank, max(max_length, rank + 1))`` and
    the number of rows from ``[2, max_samples)``. Labels are the OR of random
    conjunctions of ``rank`` signed literals, added until at least a third of the
    rows are True.

    Returns
    -------
    features : ndarray of float, values in {0, 1}
    labels : ndarray of bool, one entry per row of ``features``
    """
    upper_length = max(max_length, rank + 1)
    n_features = np.random.randint(rank, upper_length)
    n_rows = np.random.randint(2, max_samples)
    signed_inputs = np.random.randint(0, 2, (n_rows, n_features)) * 2 - 1
    labels = np.zeros(n_rows, dtype=bool)
    # Grow the disjunction until >= 1/3 of the rows are positive.
    while not labels.sum() * 3 >= signed_inputs.shape[0]:
        literal_idx = np.random.choice(n_features, size=rank, replace=False)
        literal_signs = np.random.choice([-1, 1], size=rank)
        labels = labels | (literal_signs * signed_inputs[:, literal_idx] == 1).all(axis=1)
    return (signed_inputs + 1) / 2, labels

In [None]:
def sample_boolean_data(hyperparameters, n_samples, num_features, device):
    """Sample a torch dataset labelled by a random disjunction of conjunctions.

    A conjunction rank is drawn uniformly from ``1..min(max_rank, num_features)``
    (both ends inclusive). Random conjunctions of that many signed literals are
    OR-ed together until at least a third of the samples are labelled True.

    Parameters
    ----------
    hyperparameters : dict
        Only ``"max_rank"`` is read (default 10): the largest conjunction rank
        that may be drawn.
    n_samples : int
        Number of rows to generate.
    num_features : int
        Number of boolean features; must be >= 1.
    device : str or torch.device
        Device on which every tensor is created.

    Returns
    -------
    features : torch.Tensor of float, shape (n_samples, num_features), values in {0, 1}
    labels : torch.Tensor of bool, shape (n_samples,)

    Raises
    ------
    ValueError
        If ``max_rank`` or ``num_features`` is smaller than 1.
    """
    max_rank = hyperparameters.get("max_rank", 10)
    rank_cap = min(max_rank, num_features)
    if rank_cap < 1:
        raise ValueError(
            f"need max_rank >= 1 and num_features >= 1, "
            f"got max_rank={max_rank}, num_features={num_features}")
    # randint's upper bound is exclusive: +1 so that max_rank itself can be
    # drawn and max_rank == 1 / num_features == 1 are valid.
    rank = np.random.randint(1, rank_cap + 1)
    inputs = 2 * torch.randint(0, 2, (n_samples, num_features), device=device) - 1
    outputs = torch.zeros(n_samples, dtype=torch.bool, device=device)

    while 3 * torch.sum(outputs) < n_samples:
        # Literal indices and signs live on `device` so the product with
        # `inputs` never mixes CPU and accelerator tensors.
        selected_bits = torch.multinomial(
            torch.ones(num_features, device=device), rank, replacement=False)
        signs = torch.randint(0, 2, (rank,), device=device) * 2 - 1
        outputs = outputs | ((signs * inputs[:, selected_bits]) == 1).all(dim=1)
    return (inputs + 1) / 2, outputs

In [None]:
def get_scores_rank(rank, models, mode='unused_features'):
    """Cross-validate every model on one synthetic boolean-conjunction dataset.

    Parameters
    ----------
    rank : int
        Conjunction rank for the data generator (passed as ``max_rank`` to the
        ``"unused_features"`` prior).
    models : dict[str, estimator]
        Estimators to evaluate, keyed by the column name used in the result.
    mode : {"unused_features", "random", "enumerate"}
        Which generator the dataset is drawn from.

    Returns
    -------
    dict
        ``{"rank": rank, <model name>: mean test ROC AUC, ...}``. All models
        are scored on the same stratified, shuffled folds.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    if mode == "random":
        X, y = function_of_rank_random_data(rank=rank)
    elif mode == "unused_features":
        X, y, _ = BooleanConjunctionPrior(hyperparameters={'max_rank': rank}).sample(n_samples=1000, num_features=20, device="cpu")
    elif mode == "enumerate":
        X, y = function_of_rank(rank=rank, length=10)
    else:
        raise ValueError(
            f"unknown mode {mode!r}; expected 'random', 'unused_features' or 'enumerate'")
    # Materialize the shuffled folds once so every model sees the same
    # train/test splits. A fresh unseeded StratifiedKFold per model would
    # reshuffle and make the comparison unpaired. It must be a list, not a
    # generator, because it is reused for every model.
    splits = list(StratifiedKFold(shuffle=True).split(X, y))
    result = {'rank': rank}
    for model_name, model in models.items():
        scores = cross_validate(model, X, y, cv=splits, scoring="roc_auc", error_score='raise')
        result[model_name] = np.mean(scores['test_score'])
    return result

In [None]:
from mothernet.prediction import EnsembleMeta, MotherNetClassifier, TabPFNClassifier
from sklearn.model_selection import cross_validate, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
import torch
import numpy as np
from joblib import Parallel, delayed
import pandas as pd

# Reproducibility: the data generators and the unseeded StratifiedKFold
# shuffling draw from the global NumPy RNG, as do sklearn estimators left with
# random_state=None. Whether the TabPFN/MotherNet wrappers use the global
# NumPy/torch RNGs is not visible here (TODO confirm). With n_jobs=1 joblib
# runs every job in this process, so seeding once here fixes the run.
SEED = 0
N_REPEATS = 20        # independent datasets per rank
RANKS = range(1, 11)  # conjunction ranks to evaluate
np.random.seed(SEED)
torch.manual_seed(SEED)

torch.set_num_threads(1)
device = "cpu"
tabpfn = TabPFNClassifier(device=device, model_string="tabpfn__emsize_512_nlayers_12_steps_2048_bs_32ada_lr_0.0001_1_gpu_07_24_2023_01_43_33_nooptimizer", epoch="1650", N_ensemble_configurations=3)
mothernet = EnsembleMeta(MotherNetClassifier(path="mn_d2048_H4096_L2_W32_P512_1_gpu_warm_08_25_2023_21_46_25_epoch_3940_no_optimizer.pickle", device=device), n_estimators=3)

models = {
    'MLP': MLPClassifier(max_iter=4000),
    'TabPFN': tabpfn,
    'RandomForest': RandomForestClassifier(),
    'MotherNet': mothernet,
}
res = Parallel(n_jobs=1)(
    delayed(get_scores_rank)(rank=r, models=models) for _ in range(N_REPEATS) for r in RANKS
)
# One row per (repeat, rank): a "rank" column plus one mean ROC AUC column per model.
rank = pd.DataFrame.from_dict(res)


In [None]:
import os

import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): the output file is named "random_data", but the experiment cell
# calls get_scores_rank with its default mode ("unused_features"). Confirm which
# experiment this figure is meant to show.
scores_long = rank.melt(id_vars="rank", var_name="model", value_name="score")
fig, ax = plt.subplots(figsize=(4, 3))
sns.lineplot(data=scores_long, x="rank", y="score", hue="model", ax=ax)
# Label before saving: anything set after savefig is missing from the PDF.
ax.set_ylabel("ROC AUC")
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/boolean_conjunction_random_data.pdf", dpi=300, bbox_inches="tight")