In [None]:
import torch
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torcheval.metrics.functional import binary_auprc

from util_final import (
    get_mlp,
    normalize_features,
    get_binary_cross_entropy,
    get_binary_accuracy,
    pbt,
    AdamW
)

from sampling_methods import (
    random_undersample,
    smote,
    knn_undersampling,
    tomek_links
)

In [None]:
# Pick the first available accelerator: Apple-silicon MPS (the original
# target), then CUDA, then CPU, so the notebook also runs on machines that
# have no MPS backend instead of failing on the first tensor allocation.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

config = {
    # Focal-loss parameters read by get_binary_focal_loss.
    "alpha": 0.25,
    "dataset_path": "creditcard.pt",
    "device": device,
    "ensemble_shape": (64,),
    "features_dtype": torch.float32,
    "gamma": 2.0,
    "labels_dtype": torch.float32,
    "float_dtype": torch.float32,
    # Initial distributions of the *raw* hyperparameter values; the actual
    # optimizer settings are obtained through "hyperparameter_transforms".
    "hyperparameter_raw_init_distributions": {
        "epsilon": torch.distributions.Uniform(
            torch.tensor(-10, device=device, dtype=torch.float32),
            torch.tensor(-5, device=device, dtype=torch.float32)
        ),
        "first_moment_decay": torch.distributions.Uniform(
            torch.tensor(-3, device=device, dtype=torch.float32),
            torch.tensor(0, device=device, dtype=torch.float32)
        ),
        "learning_rate": torch.distributions.Uniform(
            torch.tensor(-5, device=device, dtype=torch.float32),
            torch.tensor(-1, device=device, dtype=torch.float32)
        ),
        "second_moment_decay": torch.distributions.Uniform(
            torch.tensor(-5, device=device, dtype=torch.float32),
            torch.tensor(-1, device=device, dtype=torch.float32)
        ),
        "weight_decay": torch.distributions.Uniform(
            torch.tensor(-5, device=device, dtype=torch.float32),
            torch.tensor(-1, device=device, dtype=torch.float32)
        )
    },
    # Perturbation distributions for the raw hyperparameters: a standard
    # normal for every one of them (one independent object per name).
    "hyperparameter_raw_perturb": {
        name: torch.distributions.Normal(
            torch.tensor(0, device=device, dtype=torch.float32),
            torch.tensor(1, device=device, dtype=torch.float32)
        )
        for name in (
            "epsilon",
            "first_moment_decay",
            "learning_rate",
            "second_moment_decay",
            "weight_decay",
        )
    },
    # Raw -> actual value. epsilon / learning_rate / weight_decay are
    # sampled in log10 space; the moment decays use 1 - 10**x, so e.g. a raw
    # first_moment_decay in [-3, 0] maps to a beta in [0, 0.999].
    "hyperparameter_transforms": {
        "epsilon": lambda log10: 10 ** log10,
        "first_moment_decay": lambda x: (1 - 10 ** x).clamp(0, 1),
        "learning_rate": lambda log10: 10 ** log10,
        "second_moment_decay": lambda x: (1 - 10 ** x).clamp(0, 1),
        "weight_decay": lambda log10: 10 ** log10
    },
    "improvement_threshold": 1e-4,
    "minibatch_size": 128,
    "minibatch_size_eval": 1 << 8,
    "pbt": True,
    "seed": 0,
    "steps_num": 100_000,
    "steps_without_improvement": 1000,
    "valid_interval": 1000,
    "welch_confidence_level": .95,
    "welch_sample_size": 10,
}

In [None]:
def get_binary_auprc(
    config: dict,
    logits: torch.Tensor,
    labels: torch.Tensor
) -> torch.Tensor:
    """
    Get the binary area under the precision-recall curve (AUPRC) between
    a label and a logit tensor, one value per ensemble member.
    It can handle arbitrary ensemble shapes, including no ensemble at all.

    Parameters
    ----------
    config : dict
        Not read; kept so the signature matches the other metric helpers.
    logits : torch.Tensor
        The logit tensor. We assume it has shape
        `ensemble_shape + (dataset_size, 1)`.
    labels : torch.Tensor
        The tensor of true labels (nonzero = positive). We assume it has
        shape `(dataset_size,)` or `ensemble_shape + (dataset_size,)`.

    Returns
    -------
    torch.Tensor
        The tensor of AUPRC values per ensemble member, of shape
        `ensemble_shape` (0-d when `ensemble_shape` is empty).
    """
    prob_positive = torch.sigmoid(logits[..., 0])
    true_positives = labels.broadcast_to(prob_positive.shape).to(torch.bool)

    ensemble_shape = prob_positive.shape[:-1]
    dataset_size = prob_positive.shape[-1]
    # torcheval's binary_auprc only accepts a 1-D input (num_tasks == 1) or a
    # 2-D (num_tasks, n_samples) input, so all ensemble dims are flattened
    # into a single task axis. The task count must come from the probability
    # tensor: the raw logits carry an extra trailing dim of size 1.
    num_tasks = ensemble_shape.numel()
    if num_tasks == 1:
        flat_probs = prob_positive.reshape(dataset_size)
        flat_targets = true_positives.reshape(dataset_size)
    else:
        flat_probs = prob_positive.reshape(num_tasks, dataset_size)
        flat_targets = true_positives.reshape(num_tasks, dataset_size)

    auprc = binary_auprc(flat_probs, flat_targets, num_tasks=num_tasks)
    return auprc.reshape(ensemble_shape)

In [None]:
# Seed torch's global RNG (PyTorch seeds all devices with this call). The
# trailing `;` keeps the notebook from echoing the returned Generator object.
torch.manual_seed(config["seed"]);

In [None]:
def load_data():
    """
    Load the pre-split dataset stored at ``config['dataset_path']``.

    Every tensor is cast to the dtype configured for its role
    (``features_dtype`` / ``labels_dtype``) and moved to ``config['device']``.

    Returns
    -------
    tuple of torch.Tensor
        ``(train_features, train_labels, test_features, test_labels)``.
    """
    dataset = torch.load(config['dataset_path'], weights_only=True)
    target_device = config['device']

    def _prepare(key, dtype):
        # Cast and move one stored tensor in a single .to() call.
        return dataset[key].to(dtype=dtype, device=target_device)

    feature_dtype = config['features_dtype']
    label_dtype = config['labels_dtype']
    return (
        _prepare('train_features', feature_dtype),
        _prepare('train_labels', label_dtype),
        _prepare('test_features', feature_dtype),
        _prepare('test_labels', label_dtype),
    )

def train_valid_split(train_features, train_labels, test_size=0.1):
    """
    Stratified hold-out split of the training data into train / validation.

    The split is done with scikit-learn on CPU NumPy copies (seeded with
    ``config['seed']`` and stratified on the labels), and the results are
    converted back to tensors on ``config['device']``.

    Parameters
    ----------
    train_features, train_labels : torch.Tensor
        Data to split.
    test_size : float or int
        Size of the validation part, forwarded to ``train_test_split``
        (a fraction, or an absolute number of rows when an int).

    Returns
    -------
    tuple of torch.Tensor
        ``(train_features, train_labels, valid_features, valid_labels)``.
    """
    features_np = train_features.detach().cpu().numpy()
    labels_np = train_labels.detach().cpu().numpy()

    fit_x, holdout_x, fit_y, holdout_y = train_test_split(
        features_np,
        labels_np,
        test_size=test_size,
        stratify=labels_np,
        random_state=config['seed'],
    )

    def _to_device(array, dtype):
        return torch.tensor(array, device=config['device'], dtype=dtype)

    feature_dtype = config['features_dtype']
    label_dtype = config['labels_dtype']
    return (
        _to_device(fit_x, feature_dtype),
        _to_device(fit_y, label_dtype),
        _to_device(holdout_x, feature_dtype),
        _to_device(holdout_y, label_dtype),
    )

In [None]:
def grid_search(
    hyperparameters,
    sampling_procedure,
    loss_func=get_binary_cross_entropy,
    metric_func=get_binary_accuracy,
    seed=None,
):
    """
    Train one PBT-tuned MLP ensemble per hyperparameter value and log its
    test-set metrics.

    Parameters
    ----------
    hyperparameters : iterable
        Values passed one at a time to `sampling_procedure`. ``str(value)``
        is used as the key of the returned log, so values with the same
        string form overwrite each other.
    sampling_procedure : callable
        ``sampling_procedure(param)`` must return ``(train_features,
        train_labels, valid_features, valid_labels, test_features,
        test_labels)``.
    loss_func : callable, optional
        Training loss forwarded to `pbt` (default: binary cross entropy).
    metric_func : callable, optional
        Validation metric forwarded to `pbt` (default: binary accuracy).
    seed : int, optional
        Passed to ``torch.manual_seed`` before every configuration; defaults
        to ``config["seed"]``. Reseeding makes each configuration's result
        independent of which cells ran earlier in the session.

    Returns
    -------
    dict
        Maps ``str(param)`` to a dict with:

        - ``"auprc"``, ``"accuracy"``: maximum over ensemble members on the
          test set. This selects on test data, so it is optimistic, and the
          two maxima may come from different members.
        - ``"auprc_valid_selected"``, ``"accuracy_valid_selected"``: test
          metrics of the single member with the highest validation AUPRC.
        - ``"output"``: whatever `pbt` returned.
    """
    base_seed = config["seed"] if seed is None else seed
    log = {}

    for param in hyperparameters:
        torch.manual_seed(base_seed)
        (
            train_features, train_labels,
            valid_features, valid_labels,
            test_features, test_labels,
        ) = sampling_procedure(param)

        # One output logit per example; 3 and 128 are forwarded to get_mlp
        # (presumably depth and hidden width - confirm in util_final).
        model = get_mlp(config, train_features.shape[-1], 1, 3, 128)
        optimizer = AdamW(model.parameters())

        output = pbt(
            config,
            loss_func,
            metric_func,
            model,
            optimizer,
            train_features,
            train_labels,
            valid_features,
            valid_labels
        )

        # Evaluation only: no autograd graph is needed for these passes.
        with torch.no_grad():
            valid_logits = model(valid_features)
            test_logits = model(test_features)

        valid_auprc = get_binary_auprc(config, valid_logits, valid_labels)
        test_auprc = get_binary_auprc(config, test_logits, test_labels)
        test_accuracy = get_binary_accuracy(config, test_logits, test_labels)

        # Choose the reported member on validation data, not on test data.
        best_member = valid_auprc.flatten().argmax()

        log[str(param)] = {
            "auprc": test_auprc.max().item(),
            "accuracy": test_accuracy.max().item(),
            "auprc_valid_selected": test_auprc.flatten()[best_member].item(),
            "accuracy_valid_selected": test_accuracy.flatten()[best_member].item(),
            "output": output,
        }

    return log

In [None]:
def baseline(param):
    """
    Reference pipeline with no resampling.

    `param` is ignored; it exists only so `baseline` fits the
    ``sampling_procedure(param)`` call made by grid_search. The validation
    split is given as many rows as the test set (an int `test_size` is an
    absolute row count for scikit-learn's ``train_test_split``).

    Returns
    -------
    tuple of torch.Tensor
        ``(train_features, train_labels, valid_features, valid_labels,
        test_features, test_labels)``, after `normalize_features` has been
        called with the training features as reference (its return value is
        discarded, so presumably it works in place - confirm in util_final).
    """
    full_train_x, full_train_y, test_features, test_labels = load_data()
    n_valid_rows = test_features.shape[0]
    train_features, train_labels, valid_features, valid_labels = train_valid_split(
        full_train_x, full_train_y, test_size=n_valid_rows
    )

    normalize_features(train_features, (valid_features, test_features), verbose=False)

    return (
        train_features, train_labels,
        valid_features, valid_labels,
        test_features, test_labels,
    )

In [None]:
def undersample_random(size):
    """
    Build splits whose training set has its negatives randomly undersampled.

    The validation split is carved out of the original training data
    *before* any resampling. Validation is used for model selection in
    grid_search, so it must keep the natural class balance of the test set
    and must not be shaped by the resampling step.

    Parameters
    ----------
    size
        Forwarded to `random_undersample` together with the negative training
        examples (presumably the number of negatives to keep - confirm in
        sampling_methods).

    Returns
    -------
    tuple of torch.Tensor
        ``(train_features, train_labels, valid_features, valid_labels,
        test_features, test_labels)``, after `normalize_features` has been
        called with the training features as reference.
    """
    train_features, train_labels, test_features, test_labels = load_data()
    # Split first so that only the training portion gets resampled.
    train_features, train_labels, valid_features, valid_labels = train_valid_split(train_features, train_labels)

    positive_mask = train_labels > 0
    positive_features = train_features[positive_mask]
    negative_features = train_features[~positive_mask]
    positive_labels = train_labels[positive_mask]
    negative_labels = train_labels[~positive_mask]

    negative_features, negative_labels = random_undersample(negative_features, negative_labels, size)

    train_features = torch.cat((positive_features, negative_features), dim=0)
    train_labels = torch.cat((positive_labels, negative_labels), dim=0)
    # Interleave the classes; otherwise all positives precede all negatives.
    indices = torch.randperm(train_features.shape[0])
    train_features = train_features[indices]
    train_labels = train_labels[indices]

    normalize_features(
        train_features,
        (valid_features, test_features),
        verbose=False
    )

    return train_features, train_labels, valid_features, valid_labels, test_features, test_labels

def undersample_tomek(param):
    """
    Build splits whose training set is cleaned with Tomek links.

    `param` is ignored; it exists only so this function fits the
    ``sampling_procedure(param)`` call made by grid_search.

    The validation split is carved out of the original training data
    *before* undersampling. Otherwise the validation set would also lose the
    removed boundary points and no longer reflect the test distribution.

    Returns
    -------
    tuple of torch.Tensor
        ``(train_features, train_labels, valid_features, valid_labels,
        test_features, test_labels)``, after `normalize_features` has been
        called with the training features as reference.
    """
    train_features, train_labels, test_features, test_labels = load_data()
    # Split first so that only the training portion gets undersampled.
    train_features, train_labels, valid_features, valid_labels = train_valid_split(train_features, train_labels)
    # The literal 1 is forwarded to tomek_links (presumably the minority /
    # positive class label - confirm in sampling_methods).
    train_features, train_labels = tomek_links(config, train_features, train_labels, 1)

    normalize_features(
        train_features,
        (valid_features, test_features),
        verbose=False
    )

    return train_features, train_labels, valid_features, valid_labels, test_features, test_labels

def undersample_knn(k):
    """
    Build splits whose training set is undersampled with `knn_undersampling`.

    The validation split is carved out of the original training data
    *before* undersampling, so validation keeps the class distribution of
    the test set.

    Parameters
    ----------
    k
        Forwarded to `knn_undersampling` (presumably the number of nearest
        neighbours - confirm in sampling_methods).

    Returns
    -------
    tuple of torch.Tensor
        ``(train_features, train_labels, valid_features, valid_labels,
        test_features, test_labels)``, after `normalize_features` has been
        called with the training features as reference.
    """
    train_features, train_labels, test_features, test_labels = load_data()
    # Split first so that only the training portion gets undersampled.
    train_features, train_labels, valid_features, valid_labels = train_valid_split(train_features, train_labels)
    train_features, train_labels = knn_undersampling(config, train_features, train_labels, 1, k)

    normalize_features(
        train_features,
        (valid_features, test_features),
        verbose=False
    )

    return train_features, train_labels, valid_features, valid_labels, test_features, test_labels


In [None]:
def oversample_smote(N):
    """
    Build splits whose training positives are oversampled with SMOTE.

    The validation split is carved out of the original training data
    *before* oversampling. If SMOTE ran first, synthetic points interpolated
    from training positives would land in the validation set and leak
    training information into PBT's model selection.

    Parameters
    ----------
    N
        Forwarded to `smote` (presumably the oversampling factor - confirm in
        sampling_methods).

    Returns
    -------
    tuple of torch.Tensor
        ``(train_features, train_labels, valid_features, valid_labels,
        test_features, test_labels)``, after `normalize_features` has been
        called with the training features as reference.
    """
    train_features, train_labels, test_features, test_labels = load_data()
    # Split first so that only the training portion gets oversampled.
    train_features, train_labels, valid_features, valid_labels = train_valid_split(train_features, train_labels)

    positive_mask = train_labels > 0
    positive_features = train_features[positive_mask]
    negative_features = train_features[~positive_mask]
    positive_labels = train_labels[positive_mask]
    negative_labels = train_labels[~positive_mask]

    positive_features, positive_labels = smote(config, positive_features, positive_labels, 1, N)

    train_features = torch.cat((positive_features, negative_features), dim=0)
    train_labels = torch.cat((positive_labels, negative_labels), dim=0)
    # Interleave the classes; otherwise all positives precede all negatives.
    indices = torch.randperm(train_features.shape[0])
    train_features = train_features[indices]
    train_labels = train_labels[indices]

    normalize_features(
        train_features,
        (valid_features, test_features),
        verbose=False
    )

    return train_features, train_labels, valid_features, valid_labels, test_features, test_labels

def smote_random_undersample(params):
    """
    Build splits with SMOTE-oversampled positives and randomly undersampled
    negatives in the training set.

    The validation split is carved out of the original training data
    *before* any resampling, so it contains no synthetic points and keeps
    the class distribution of the test set.

    Parameters
    ----------
    params : tuple
        ``(N, size)``: `N` is forwarded to `smote`, and `size` to
        `random_undersample` (presumably the number of negatives to keep).

    Returns
    -------
    tuple of torch.Tensor
        ``(train_features, train_labels, valid_features, valid_labels,
        test_features, test_labels)``, after `normalize_features` has been
        called with the training features as reference.
    """
    N, size = params
    train_features, train_labels, test_features, test_labels = load_data()
    # Split first so that only the training portion gets resampled.
    train_features, train_labels, valid_features, valid_labels = train_valid_split(train_features, train_labels)

    positive_mask = train_labels > 0
    positive_features = train_features[positive_mask]
    negative_features = train_features[~positive_mask]
    positive_labels = train_labels[positive_mask]
    negative_labels = train_labels[~positive_mask]

    positive_features, positive_labels = smote(config, positive_features, positive_labels, 1, N)
    negative_features, negative_labels = random_undersample(negative_features, negative_labels, size)

    train_features = torch.cat((positive_features, negative_features), dim=0)
    train_labels = torch.cat((positive_labels, negative_labels), dim=0)
    # Interleave the classes; otherwise all positives precede all negatives.
    indices = torch.randperm(train_features.shape[0])
    train_features = train_features[indices]
    train_labels = train_labels[indices]

    normalize_features(
        train_features,
        (valid_features, test_features),
        verbose=False
    )

    return train_features, train_labels, valid_features, valid_labels, test_features, test_labels

def smote_tomek(N):
    """
    Build splits whose training set is SMOTE-oversampled and then cleaned
    with Tomek links.

    The validation split is carved out of the original training data
    *before* any resampling, so it contains no synthetic points and keeps
    the class distribution of the test set.

    Parameters
    ----------
    N
        Forwarded to `smote` (presumably the oversampling factor - confirm in
        sampling_methods).

    Returns
    -------
    tuple of torch.Tensor
        ``(train_features, train_labels, valid_features, valid_labels,
        test_features, test_labels)``, after `normalize_features` has been
        called with the training features as reference.
    """
    train_features, train_labels, test_features, test_labels = load_data()
    # Split first so that only the training portion gets resampled.
    train_features, train_labels, valid_features, valid_labels = train_valid_split(train_features, train_labels)

    positive_mask = train_labels > 0
    positive_features = train_features[positive_mask]
    negative_features = train_features[~positive_mask]
    positive_labels = train_labels[positive_mask]
    negative_labels = train_labels[~positive_mask]

    positive_features, positive_labels = smote(config, positive_features, positive_labels, 1, N)

    train_features = torch.cat((positive_features, negative_features), dim=0)
    train_labels = torch.cat((positive_labels, negative_labels), dim=0)
    # Interleave the classes; otherwise all positives precede all negatives.
    indices = torch.randperm(train_features.shape[0])
    train_features = train_features[indices]
    train_labels = train_labels[indices]

    train_features, train_labels = tomek_links(config, train_features, train_labels, 1)

    normalize_features(
        train_features,
        (valid_features, test_features),
        verbose=False
    )

    return train_features, train_labels, valid_features, valid_labels, test_features, test_labels

def smote_knn(params):
    """
    Build splits whose training set is SMOTE-oversampled and then
    undersampled with `knn_undersampling`.

    The validation split is carved out of the original training data
    *before* any resampling, so it contains no synthetic points and keeps
    the class distribution of the test set.

    Parameters
    ----------
    params : tuple
        ``(N, k)``: `N` is forwarded to `smote`, and `k` to
        `knn_undersampling` (presumably the number of nearest neighbours).

    Returns
    -------
    tuple of torch.Tensor
        ``(train_features, train_labels, valid_features, valid_labels,
        test_features, test_labels)``, after `normalize_features` has been
        called with the training features as reference.
    """
    N, k = params
    train_features, train_labels, test_features, test_labels = load_data()
    # Split first so that only the training portion gets resampled.
    train_features, train_labels, valid_features, valid_labels = train_valid_split(train_features, train_labels)

    positive_mask = train_labels > 0
    positive_features = train_features[positive_mask]
    negative_features = train_features[~positive_mask]
    positive_labels = train_labels[positive_mask]
    negative_labels = train_labels[~positive_mask]

    positive_features, positive_labels = smote(config, positive_features, positive_labels, 1, N)

    train_features = torch.cat((positive_features, negative_features), dim=0)
    train_labels = torch.cat((positive_labels, negative_labels), dim=0)
    # Interleave the classes; otherwise all positives precede all negatives.
    indices = torch.randperm(train_features.shape[0])
    train_features = train_features[indices]
    train_labels = train_labels[indices]

    train_features, train_labels = knn_undersampling(config, train_features, train_labels, 1, k)

    normalize_features(
        train_features,
        (valid_features, test_features),
        verbose=False
    )

    return train_features, train_labels, valid_features, valid_labels, test_features, test_labels

In [None]:
def get_binary_focal_loss(
    config: dict,
    logits: torch.Tensor,
    labels: torch.Tensor
) -> torch.Tensor:
    """
    Compute the sigmoid focal loss between a label and a logit tensor.
    It can handle arbitrary ensemble shapes.

    Per example: ``alpha_t * (1 - p_t) ** gamma * BCE(logit, label)``, where
    ``p_t`` is the predicted probability of the true class and ``alpha_t``
    is ``alpha`` for positives and ``1 - alpha`` for negatives. Both are
    interpolated linearly in the label, so soft labels in [0, 1] are
    supported; for hard 0/1 labels this is exactly the usual definition.

    Parameters
    ----------
    config : dict
        Reads 'alpha' (positive-class weight, default 0.25) and 'gamma'
        (focusing exponent, default 2.0).
    logits : torch.Tensor
        The logit tensor. We assume it has shape
        `ensemble_shape + (dataset_size, 1)`.
    labels : torch.Tensor
        The tensor of true labels. We assume it has shape
        `(dataset_size,)` or `ensemble_shape + (dataset_size,)`.

    Returns
    -------
    torch.Tensor
        The tensor of mean focal losses per ensemble member
        of shape `ensemble_shape`.
    """
    alpha = config.get("alpha", 0.25)
    gamma = config.get("gamma", 2.0)

    logits = logits[..., 0]
    labels = labels.broadcast_to(logits.shape)

    # BCE straight from the logits (numerically stable form).
    bce_loss = F.binary_cross_entropy_with_logits(
        logits,
        labels,
        reduction='none'
    )

    # Keep probabilities strictly inside (0, 1) so (1 - p_t) ** gamma never
    # evaluates 0 ** gamma.
    probs = torch.sigmoid(logits).clamp(1e-7, 1 - 1e-7)

    # Linear in the label: identical to a labels == 1 switch for hard labels,
    # and correct for soft labels (which a == 1 test would treat as negatives).
    p_t = probs * labels + (1 - probs) * (1 - labels)
    alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
    focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss

    return focal_loss.mean(dim=-1)

In [None]:
# Baseline (no resampling) trained with focal loss; '0' is a placeholder
# because `baseline` ignores its parameter.
output = grid_search(
    hyperparameters=['0'],
    sampling_procedure=baseline,
    loss_func=get_binary_focal_loss,
)

In [None]:
# Show only the headline metrics per configuration; the complete log
# (including whatever pbt returned) stays available in `output`.
{
    param: {"auprc": entry["auprc"], "accuracy": entry["accuracy"]}
    for param, entry in output.items()
}

In [None]:
train_features, train_labels, test_features, test_labels = load_data()

# Negative budgets of 1x .. 8x the number of positive training labels.
output = grid_search(
    [int(multiple * train_labels.sum().item()) for multiple in range(1, 9)],
    undersample_random,
)

In [None]:
# Tomek-link undersampling has no tunable parameter; '0' is a placeholder.
output = grid_search(hyperparameters=['0'], sampling_procedure=undersample_tomek)

In [None]:
# k in {50, 100, 150, 200}.
output = grid_search(list(range(50, 250, 50)), undersample_knn)

In [None]:
# SMOTE parameter N in {2, ..., 10}.
output = grid_search(list(range(2, 11)), oversample_smote)

In [None]:
train_features, train_labels, test_features, test_labels = load_data()

# (N, size) pairs: N goes to SMOTE, and the second entry is passed on to
# random_undersample as size * #positives * N.
params = [
    (N, int(size * train_labels.sum().item() * N))
    for N in range(9, 11)
    for size in range(5, 9)
]
output = grid_search(params, smote_random_undersample)
for key, entry in output.items():
    print(key, entry['auprc'])

In [None]:
# Re-print the AUPRC of every configuration in the current `output`.
for key, entry in output.items():
    print(key, entry['auprc'])

In [None]:
# SMOTE parameter N in {2, ..., 10}, followed by Tomek-link cleaning.
output = grid_search(hyperparameters=list(range(2, 11)), sampling_procedure=smote_tomek)

In [None]:
params = [(N, k) for N in range(9, 11) for k in range(50, 250, 50)]
output = grid_search(params, smote_knn)
# Print each (N, k) configuration next to its AUPRC; a bare list of numbers
# cannot be mapped back to the configuration that produced it.
for key, entry in output.items():
    print(key, entry['auprc'])

In [None]:
# Repeats the baseline focal-loss experiment with exactly the same call as
# the earlier cell (presumably to check run-to-run variability).
output = grid_search(
    hyperparameters=['0'],
    sampling_procedure=baseline,
    loss_func=get_binary_focal_loss,
)