In [None]:
import torch
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torcheval.metrics.functional import binary_auprc

from util_final import (
    get_mlp,
    normalize_features,
    get_binary_cross_entropy,
    get_binary_accuracy,
    pbt,
    AdamW
)

from sampling_methods import (
    random_undersample,
    smote,
    knn_undersampling,
    tomek_links
)

In [None]:
# Select the compute device. A hard-coded 'mps' raises a RuntimeError as soon as
# a tensor is created on a machine without Apple-silicon GPU support, so prefer
# MPS when present, then CUDA, and otherwise fall back to CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def _uniform(low, high):
    """Return Uniform(low, high) with float32 bounds placed on `device`."""
    return torch.distributions.Uniform(
        torch.tensor(low, device=device, dtype=torch.float32),
        torch.tensor(high, device=device, dtype=torch.float32),
    )


def _standard_normal():
    """Return a fresh Normal(0, 1) with float32 parameters placed on `device`."""
    return torch.distributions.Normal(
        torch.tensor(0, device=device, dtype=torch.float32),
        torch.tensor(1, device=device, dtype=torch.float32),
    )


# Names of the optimizer hyperparameters that get raw init/perturb distributions
# and transforms below (order matches the dict keys used throughout config).
_HYPERPARAMETER_NAMES = (
    "epsilon",
    "first_moment_decay",
    "learning_rate",
    "second_moment_decay",
    "weight_decay",
)

config = {
    # NOTE(review): alpha/gamma are presumably focal-loss parameters consumed
    # inside util_final — confirm.
    "alpha": 0.25,
    "dataset_path": "creditcard.pt",
    "device": device,
    "ensemble_shape": (64,),
    "features_dtype": torch.float32,
    "gamma": 2.0,
    "labels_dtype": torch.float32,
    "float_dtype": torch.float32,
    # Initial distributions in *raw* space (see hyperparameter_transforms):
    # epsilon, learning_rate and weight_decay are sampled as log10 values;
    # the moment decays are sampled as x with beta = 1 - 10**x, so
    # first_moment_decay covers beta1 in ~[0, 0.999] and
    # second_moment_decay covers beta2 in ~[0.9, 0.99999].
    "hyperparameter_raw_init_distributions": {
        "epsilon": _uniform(-10, -5),
        "first_moment_decay": _uniform(-3, 0),
        "learning_rate": _uniform(-5, -1),
        "second_moment_decay": _uniform(-5, -1),
        "weight_decay": _uniform(-5, -1),
    },
    # One independent N(0, 1) per hyperparameter in raw space — presumably
    # sampled by pbt to perturb hyperparameters during exploration; confirm.
    "hyperparameter_raw_perturb": {
        name: _standard_normal() for name in _HYPERPARAMETER_NAMES
    },
    # Map raw values to the actual optimizer hyperparameter values.
    "hyperparameter_transforms": {
        "epsilon": lambda log10: 10 ** log10,
        "first_moment_decay": lambda x: (1 - 10 ** x).clamp(0, 1),
        "learning_rate": lambda log10: 10 ** log10,
        "second_moment_decay": lambda x: (1 - 10 ** x).clamp(0, 1),
        "weight_decay": lambda log10: 10 ** log10,
    },
    "improvement_threshold": 1e-4,
    "minibatch_size": 128,
    "minibatch_size_eval": 1 << 8,
    "pbt": True,
    "seed": 0,
    "steps_num": 100_000,
    "steps_without_improvement": 1000,
    "valid_interval": 1000,
    "welch_confidence_level": .95,
    "welch_sample_size": 10,
}

# Seed torch's global RNG; the trailing ';' hides the returned Generator repr.
# NOTE(review): if train_test_split or the imported sampling methods draw from
# NumPy/Python RNGs, seed those as well (e.g. random_state=config["seed"]).
torch.manual_seed(config["seed"]);

In [None]:
# TODO: `sampling_procedure` and `param` are not defined anywhere in this
# notebook, so this cell raises NameError on a fresh kernel. Presumably it should
# load config["dataset_path"], split it (e.g. train_test_split), normalize it
# (normalize_features), and resample the training split with one of the imported
# sampling methods — confirm and implement.
train_features, train_labels, valid_features, valid_labels, test_features, test_labels = sampling_procedure(param)  # TODO

# In-features come from the data; 1 output logit. The positional args 3 and 128
# are passed through to get_mlp (presumably depth and hidden width — confirm).
model = get_mlp(config, train_features.shape[-1], 1, 3, 128)
optimizer = AdamW(model.parameters())

# Population-based training; the return value is kept in `output` for inspection.
output = pbt(
    config,
    get_binary_cross_entropy,
    get_binary_accuracy,
    model,
    optimizer,
    train_features,
    train_labels,
    valid_features,
    valid_labels
)

# Inference only: disable autograd so no computation graph is recorded.
with torch.no_grad():
    pred_logits = model(test_features)
    accuracy = get_binary_accuracy(
        config,
        pred_logits,
        test_labels
    )

# NOTE(review): `accuracy` presumably holds one value per ensemble member
# (config["ensemble_shape"]). Taking the max on the *test* set picks the best
# member using test data, which overstates generalization — select the member on
# the validation set instead. If classes are imbalanced (the resampling imports
# suggest so), also report binary_auprc, which is already imported.
print(f'Accuracy: {accuracy.max().item()}')