In [33]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import sklearn.metrics
from sklearn.preprocessing import StandardScaler
from early_stopping import EarlyStopping

from datasets import CaliforniaHousingDataset, AdultDataset, TitanicDataset, AutoMpgDataset, WineDataset
from metrics import calculate_global_fidelity, calculate_global_neighborhood_fidelity
from models.base_model import BaseClassifier, BaseRegressor
from models.surrogate_model import SurrogateClassifier, SurrogateRegressor

device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [34]:
housing_train = CaliforniaHousingDataset(
    dataset_path="data/california_housing/cal_housing.data", normalize=True, train=True)
housing_test = CaliforniaHousingDataset(
    dataset_path="data/california_housing/cal_housing.data", normalize=True, train=False)

adult_train = AdultDataset(dataset_path="data/adult/adult.data", normalize=True, train=True)
adult_test = AdultDataset(dataset_path="data/adult/adult.data", normalize=True, train=False)

titanic_train = TitanicDataset(dataset_path="data/titanic/titanic.arff", normalize=True, train=True)
titanic_test = TitanicDataset(dataset_path="data/titanic/titanic.arff", normalize=True, train=False)

wine_train = WineDataset(dataset_path="data/wines/winequality-red.csv", normalize=True, train=True)
wine_test = WineDataset(dataset_path="data/wines/winequality-red.csv", normalize=True, train=False)

autompg_train = AutoMpgDataset(dataset_path="data/autompg/auto-mpg.data", normalize=True, train=True)
autompg_test = AutoMpgDataset(dataset_path="data/autompg/auto-mpg.data", normalize=True, train=False)

In [35]:
# target normalization just to check (the mse scores in the paper are low - what's the transformation / metric used?)
# scaler = StandardScaler()

# after seeing the results, it seems the authors just put the target through a StandardScaler
# housing_train.target = pd.Series(scaler.fit_transform(np.array(housing_train.target).reshape(-1, 1)).flatten())
# housing_test.target = pd.Series(scaler.fit_transform(np.array(housing_test.target).reshape(-1, 1)).flatten())
# wine_train.target = pd.Series(scaler.fit_transform(np.array(wine_train.target).reshape(-1, 1)).flatten())
# wine_test.target = pd.Series(scaler.fit_transform(np.array(wine_test.target).reshape(-1, 1)).flatten())
# autompg_train.target = pd.Series(scaler.fit_transform(np.array(autompg_train.target).reshape(-1, 1)).flatten())
# autompg_test.target = pd.Series(scaler.fit_transform(np.array(autompg_test.target).reshape(-1, 1)).flatten())

# TODO move that to datasets

In [36]:
lr = 0.001
batch_size = 128  # not from the paper
binary_classification_criterion = torch.nn.BCELoss()
regression_criterion = ... # "logarithm of the hyperbolic cosine" from the paper (?)
regression_criterion = torch.nn.MSELoss()
# TODO early stopping

In [37]:
def train(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        train_data: Dataset,
        criterion,
        epochs: int,
        alpha: float,
        early_stopping: EarlyStopping = None
):
    params = list(base_model.parameters()) + list(surrogate_model.parameters())
    optimizer = Adam(params, lr=lr)
    loader = DataLoader(train_data, batch_size=batch_size)
    for epoch in range(epochs):
        running_loss = 0
        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)
            labels = labels.reshape(-1, 1)
            optimizer.zero_grad()

            base_model_preds = base_model(data)
            surrogate_model_preds = surrogate_model(data)
            loss = criterion(base_model_preds, labels)
            point_fidelity = calculate_global_fidelity(base_model_preds, surrogate_model_preds)
            mtl_loss = alpha * loss + (1 - alpha) * point_fidelity
        
            mtl_loss.backward()
            optimizer.step()
            running_loss += mtl_loss
        
        train_loss = running_loss / len(loader)
        if early_stopping is not None:
            early_stopping(train_loss, base_model, surrogate_model)
            if early_stopping.early_stop:
                print("Early stopping")
                break
        print(f"epoch: {epoch + 1}, train loss: {train_loss:.3f}")


def validate_base_classifier(
        model: nn.Module,
        test_data: Dataset,
):
    loader = DataLoader(test_data, batch_size=len(test_data))
    with torch.no_grad():
        data, labels = next(iter(loader))
        data, labels = data.to(device), labels.to(device)
        labels = labels.reshape(-1, 1)
        preds_proba = model(data)
        preds = torch.where(preds_proba >= 0.5, 1, 0)
        accuracy = sklearn.metrics.accuracy_score(labels.cpu(), preds.cpu())
        f1_score = sklearn.metrics.f1_score(labels.cpu(), preds.cpu())
        print(f"test accuracy: {accuracy:.3f}, f1 score: {f1_score:.3f}")


def validate_base_regressor(
        model: nn.Module,
        test_data: Dataset
):
    loader = DataLoader(test_data, batch_size=len(test_data))
    with torch.no_grad():
        data, labels = next(iter(loader))
        data, labels = data.to(device), labels.to(device)
        labels = labels.reshape(-1, 1)
        preds = model(data)
        mse = sklearn.metrics.mean_squared_error(labels.cpu(), preds.cpu())
        print(f"test mse: {mse:.3f}")


def validate_surrogate_model(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        test_data: Dataset
):
    loader = DataLoader(test_data, batch_size=len(test_data))
    with torch.no_grad():
        data, _ = next(iter(loader))
        data = data.to(device)
        base_model_preds = base_model(data)
        surrogate_model_preds = surrogate_model(data)
        global_fidelity = calculate_global_fidelity(base_model_preds, surrogate_model_preds)
        global_neighborhood_fidelity = calculate_global_neighborhood_fidelity(base_model, surrogate_model, data)
        print(f"global fidelity: {global_fidelity:.3f}, global neighborhood fidelity: {global_neighborhood_fidelity:.3f}")


def validate_regressors(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        test_data: Dataset
):
    validate_base_regressor(base_model, test_data)
    validate_surrogate_model(base_model, surrogate_model, test_data)


def validate_classifiers(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        test_data: Dataset
):
    validate_base_classifier(base_model, test_data)
    validate_surrogate_model(base_model, surrogate_model, test_data)

# TODO local explainability evaluation

In [39]:
epochs = 100
alpha = 0.5
patience = 10 # num of epochs to early stop if loss doesn't decrease
save_dir = "checkpoints"

classification_data = {
    "adult": (adult_train, adult_test),
    "titanic": (titanic_train, titanic_test)
}
classifiers = []

for dataset in classification_data.keys():
    train_data, test_data = classification_data[dataset]
    base_model = BaseClassifier(
        input_dim=train_data.features.shape[1], output_dim=1, n_hidden_layers=4, layer_size=128).to(device)
    surrogate_model = SurrogateClassifier(input_dim=train_data.features.shape[1], output_dim=1).to(device)
    print(f"{dataset}:")
    early_stopping = EarlyStopping(dir=save_dir, dataset_name=dataset, patience=patience, verbose=False)
    train(base_model, surrogate_model, train_data, binary_classification_criterion, epochs, alpha, early_stopping)
    classifiers.append((base_model, surrogate_model))
    print()

adult:
Saving models ...
epoch: 1, train loss: 0.216
Saving models ...
epoch: 2, train loss: 0.178
Saving models ...
epoch: 3, train loss: 0.169
Saving models ...
epoch: 4, train loss: 0.163
Saving models ...
epoch: 5, train loss: 0.159
Saving models ...
epoch: 6, train loss: 0.156
Saving models ...
epoch: 7, train loss: 0.154
Saving models ...
epoch: 8, train loss: 0.151
Saving models ...
epoch: 9, train loss: 0.150
Saving models ...
epoch: 10, train loss: 0.148
Saving models ...
epoch: 11, train loss: 0.147
Saving models ...
epoch: 12, train loss: 0.145
Saving models ...
epoch: 13, train loss: 0.144
Saving models ...
epoch: 14, train loss: 0.142
EarlyStopping counter: 1 out of 10
epoch: 15, train loss: 0.143
Saving models ...
epoch: 16, train loss: 0.140
Saving models ...
epoch: 17, train loss: 0.139
EarlyStopping counter: 1 out of 10
epoch: 18, train loss: 0.141
EarlyStopping counter: 2 out of 10
epoch: 19, train loss: 0.139
EarlyStopping counter: 3 out of 10
epoch: 20, train loss: 

In [40]:
for i, dataset in enumerate(classification_data.keys()):
    test_data = classification_data[dataset][1]
    
    base_model = BaseClassifier(
        input_dim=test_data.features.shape[1], output_dim=1, n_hidden_layers=4, layer_size=128).to(device)
    surrogate_model = SurrogateClassifier(input_dim=test_data.features.shape[1], output_dim=1).to(device)
    
    base_model.load_state_dict(torch.load(f"{save_dir}/{dataset}/base_model_checkpoint.pt"))
    surrogate_model.load_state_dict(torch.load(f"{save_dir}/{dataset}/surrogate_model_checkpoint.pt"))
    
    base_model.eval()
    surrogate_model.eval()
    
    print(f"{dataset}:")
    validate_classifiers(base_model, surrogate_model, test_data)
    print()

adult:
test accuracy: 0.921, f1 score: 0.834
global fidelity: 0.048, global neighborhood fidelity: 0.050

titanic:
test accuracy: 0.837, f1 score: 0.770
global fidelity: 0.032, global neighborhood fidelity: 0.032



In [41]:
regression_data = {
    "wine": (wine_train, wine_test),
    "housing": (housing_train, housing_test),
    "autompg": (autompg_train, autompg_test)
}
regressors = []

for dataset in regression_data.keys():
    train_data, test_data = regression_data[dataset]
    base_regressor = BaseRegressor(
        input_dim=train_data.features.shape[1], output_dim=1, n_hidden_layers=4, layer_size=128).to(device)
    surrogate_regressor = SurrogateRegressor(input_dim=train_data.features.shape[1], output_dim=1).to(device)
    print(f"{dataset}:")
    early_stopping = EarlyStopping(dir=save_dir, dataset_name=dataset, patience=patience, verbose=False)
    train(base_regressor, surrogate_regressor, train_data, regression_criterion, epochs, alpha, early_stopping)
    regressors.append((base_regressor, surrogate_regressor))
    print()

wine:
Saving models ...
epoch: 1, train loss: 0.537
Saving models ...
epoch: 2, train loss: 0.459
Saving models ...
epoch: 3, train loss: 0.436
Saving models ...
epoch: 4, train loss: 0.422
Saving models ...
epoch: 5, train loss: 0.406
Saving models ...
epoch: 6, train loss: 0.392
Saving models ...
epoch: 7, train loss: 0.381
Saving models ...
epoch: 8, train loss: 0.370
Saving models ...
epoch: 9, train loss: 0.360
Saving models ...
epoch: 10, train loss: 0.350
Saving models ...
epoch: 11, train loss: 0.340
Saving models ...
epoch: 12, train loss: 0.330
Saving models ...
epoch: 13, train loss: 0.321
Saving models ...
epoch: 14, train loss: 0.312
Saving models ...
epoch: 15, train loss: 0.305
Saving models ...
epoch: 16, train loss: 0.298
Saving models ...
epoch: 17, train loss: 0.291
Saving models ...
epoch: 18, train loss: 0.285
Saving models ...
epoch: 19, train loss: 0.281
Saving models ...
epoch: 20, train loss: 0.276
Saving models ...
epoch: 21, train loss: 0.270
Saving models ..

In [42]:
for i, dataset in enumerate(regression_data.keys()):
    test_data = regression_data[dataset][1]
    
    base_model = BaseRegressor(
        input_dim=test_data.features.shape[1], output_dim=1, n_hidden_layers=4, layer_size=128).to(device)
    surrogate_model = SurrogateRegressor(input_dim=test_data.features.shape[1], output_dim=1).to(device)
    
    base_model.load_state_dict(torch.load(f"{save_dir}/{dataset}/base_model_checkpoint.pt"))
    surrogate_model.load_state_dict(torch.load(f"{save_dir}/{dataset}/surrogate_model_checkpoint.pt"))
    
    base_model.eval()
    surrogate_model.eval()
    
    print(f"{dataset}:")
    validate_regressors(base_model, surrogate_model, test_data)
    print()

wine:
test mse: 0.275
global fidelity: 0.170, global neighborhood fidelity: 0.160

housing:
test mse: 0.189
global fidelity: 0.065, global neighborhood fidelity: 0.079

autompg:
test mse: 0.135
global fidelity: 0.114, global neighborhood fidelity: 0.115

