In [19]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import sklearn.metrics
from sklearn.preprocessing import StandardScaler
from early_stopping import EarlyStopping
from log_cosh_loss import LogCoshLoss

from datasets import CaliforniaHousingDataset, AdultDataset, TitanicDataset, AutoMpgDataset, WineDataset
from metrics import calculate_global_fidelity, calculate_global_neighborhood_fidelity
from models.base_model import BaseClassifier, BaseRegressor
from models.surrogate_model import SurrogateClassifier, SurrogateRegressor

device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [20]:
housing_train = CaliforniaHousingDataset(
    dataset_path="data/california_housing/cal_housing.data", normalize=True, train=True)
housing_test = CaliforniaHousingDataset(
    dataset_path="data/california_housing/cal_housing.data", normalize=True, train=False)

adult_train = AdultDataset(dataset_path="data/adult/adult.data", normalize=True, train=True)
adult_test = AdultDataset(dataset_path="data/adult/adult.data", normalize=True, train=False)

titanic_train = TitanicDataset(dataset_path="data/titanic/titanic.arff", normalize=True, train=True)
titanic_test = TitanicDataset(dataset_path="data/titanic/titanic.arff", normalize=True, train=False)

wine_train = WineDataset(dataset_path="data/wines/winequality-red.csv", normalize=True, train=True)
wine_test = WineDataset(dataset_path="data/wines/winequality-red.csv", normalize=True, train=False)

autompg_train = AutoMpgDataset(dataset_path="data/autompg/auto-mpg.data", normalize=True, train=True)
autompg_test = AutoMpgDataset(dataset_path="data/autompg/auto-mpg.data", normalize=True, train=False)

In [22]:
lr = 0.001
batch_size = 128  # not from the paper
binary_classification_criterion = torch.nn.BCELoss()
regression_criterion = LogCoshLoss() # "logarithm of the hyperbolic cosine" from the paper

In [23]:
def train(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        train_data: Dataset,
        criterion,
        epochs: int,
        alpha: float,
        early_stopping: EarlyStopping = None
):
    params = list(base_model.parameters()) + list(surrogate_model.parameters())
    optimizer = Adam(params, lr=lr)
    loader = DataLoader(train_data, batch_size=batch_size)
    for epoch in range(epochs):
        running_loss = 0
        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)
            labels = labels.reshape(-1, 1)
            optimizer.zero_grad()

            base_model_preds = base_model(data)
            surrogate_model_preds = surrogate_model(data)
            loss = criterion(base_model_preds, labels)
            point_fidelity = calculate_global_fidelity(base_model_preds, surrogate_model_preds)
            mtl_loss = alpha * loss + (1 - alpha) * point_fidelity
        
            mtl_loss.backward()
            optimizer.step()
            running_loss += mtl_loss
        
        train_loss = running_loss / len(loader)
        if early_stopping is not None:
            early_stopping(train_loss, base_model, surrogate_model)
            if early_stopping.early_stop:
                print("Early stopping")
                break
        print(f"epoch: {epoch + 1}, train loss: {train_loss:.3f}")


def validate_base_classifier(
        model: nn.Module,
        test_data: Dataset,
):
    loader = DataLoader(test_data, batch_size=len(test_data))
    with torch.no_grad():
        data, labels = next(iter(loader))
        data, labels = data.to(device), labels.to(device)
        labels = labels.reshape(-1, 1)
        preds_proba = model(data)
        preds = torch.where(preds_proba >= 0.5, 1, 0)
        accuracy = sklearn.metrics.accuracy_score(labels.cpu(), preds.cpu())
        f1_score = sklearn.metrics.f1_score(labels.cpu(), preds.cpu())
        print(f"test accuracy: {accuracy:.3f}, f1 score: {f1_score:.3f}")


def validate_base_regressor(
        model: nn.Module,
        test_data: Dataset
):
    loader = DataLoader(test_data, batch_size=len(test_data))
    with torch.no_grad():
        data, labels = next(iter(loader))
        data, labels = data.to(device), labels.to(device)
        labels = labels.reshape(-1, 1)
        preds = model(data)
        mse = sklearn.metrics.mean_squared_error(labels.cpu(), preds.cpu())
        print(f"test mse: {mse:.3f}")


def validate_surrogate_model(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        test_data: Dataset
):
    loader = DataLoader(test_data, batch_size=len(test_data))
    with torch.no_grad():
        data, _ = next(iter(loader))
        data = data.to(device)
        base_model_preds = base_model(data)
        surrogate_model_preds = surrogate_model(data)
        global_fidelity = calculate_global_fidelity(base_model_preds, surrogate_model_preds)
        global_neighborhood_fidelity = calculate_global_neighborhood_fidelity(base_model, surrogate_model, data)
        print(f"global fidelity: {global_fidelity:.3f}, global neighborhood fidelity: {global_neighborhood_fidelity:.3f}")


def validate_regressors(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        test_data: Dataset
):
    validate_base_regressor(base_model, test_data)
    validate_surrogate_model(base_model, surrogate_model, test_data)


def validate_classifiers(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        test_data: Dataset
):
    validate_base_classifier(base_model, test_data)
    validate_surrogate_model(base_model, surrogate_model, test_data)

# TODO local explainability evaluation

In [24]:
epochs = 100
alpha = 0.5
patience = 10  # for early stopping
save_dir = "checkpoints"

classification_data = {
    "adult": (adult_train, adult_test),
    "titanic": (titanic_train, titanic_test)
}
classifiers = []

for dataset in classification_data.keys():
    train_data, test_data = classification_data[dataset]
    base_model = BaseClassifier(
        input_dim=train_data.features.shape[1], output_dim=1, n_hidden_layers=4, layer_size=128).to(device)
    surrogate_model = SurrogateClassifier(input_dim=train_data.features.shape[1], output_dim=1).to(device)
    print(f"{dataset}:")
    early_stopping = EarlyStopping(dir=save_dir, dataset_name=dataset, patience=patience, verbose=False)
    train(base_model, surrogate_model, train_data, binary_classification_criterion, epochs, alpha, early_stopping)
    classifiers.append((base_model, surrogate_model))
    print()

adult:
Saving models ...
epoch: 1, train loss: 0.220
Saving models ...
epoch: 2, train loss: 0.181
Saving models ...
epoch: 3, train loss: 0.171
Saving models ...
epoch: 4, train loss: 0.165
Saving models ...
epoch: 5, train loss: 0.160
Saving models ...
epoch: 6, train loss: 0.156
Saving models ...
epoch: 7, train loss: 0.153
Saving models ...
epoch: 8, train loss: 0.151
Saving models ...
epoch: 9, train loss: 0.149
Saving models ...
epoch: 10, train loss: 0.147
Saving models ...
epoch: 11, train loss: 0.146
Saving models ...
epoch: 12, train loss: 0.144
Saving models ...
epoch: 13, train loss: 0.143
Saving models ...
epoch: 14, train loss: 0.141
Saving models ...
epoch: 15, train loss: 0.139
Saving models ...
epoch: 16, train loss: 0.139
Saving models ...
epoch: 17, train loss: 0.137
Saving models ...
epoch: 18, train loss: 0.135
Saving models ...
epoch: 19, train loss: 0.134
Saving models ...
epoch: 20, train loss: 0.133
Saving models ...
epoch: 21, train loss: 0.132
Saving models .

In [25]:
for i, dataset in enumerate(classification_data.keys()):
    test_data = classification_data[dataset][1]
    
    base_model = BaseClassifier(
        input_dim=test_data.features.shape[1], output_dim=1, n_hidden_layers=4, layer_size=128).to(device)
    surrogate_model = SurrogateClassifier(input_dim=test_data.features.shape[1], output_dim=1).to(device)
    
    base_model.load_state_dict(torch.load(f"{save_dir}/{dataset}/base_model_checkpoint.pt"))
    surrogate_model.load_state_dict(torch.load(f"{save_dir}/{dataset}/surrogate_model_checkpoint.pt"))
    
    base_model.eval()
    surrogate_model.eval()
    
    print(f"{dataset}:")
    validate_classifiers(base_model, surrogate_model, test_data)
    print()

adult:
test accuracy: 0.922, f1 score: 0.833
global fidelity: 0.052, global neighborhood fidelity: 0.054

titanic:
test accuracy: 0.804, f1 score: 0.725
global fidelity: 0.021, global neighborhood fidelity: 0.021



In [26]:
regression_data = {
    "wine": (wine_train, wine_test),
    "housing": (housing_train, housing_test),
    "autompg": (autompg_train, autompg_test)
}
regressors = []

for dataset in regression_data.keys():
    train_data, test_data = regression_data[dataset]
    base_regressor = BaseRegressor(
        input_dim=train_data.features.shape[1], output_dim=1, n_hidden_layers=4, layer_size=128).to(device)
    surrogate_regressor = SurrogateRegressor(input_dim=train_data.features.shape[1], output_dim=1).to(device)
    print(f"{dataset}:")
    early_stopping = EarlyStopping(dir=save_dir, dataset_name=dataset, patience=patience, verbose=False)
    train(base_regressor, surrogate_regressor, train_data, regression_criterion, epochs, alpha, early_stopping)
    regressors.append((base_regressor, surrogate_regressor))
    print()

wine:
Saving models ...
epoch: 1, train loss: 0.339
Saving models ...
epoch: 2, train loss: 0.214
Saving models ...
epoch: 3, train loss: 0.194
Saving models ...
epoch: 4, train loss: 0.184
Saving models ...
epoch: 5, train loss: 0.180
Saving models ...
epoch: 6, train loss: 0.175
Saving models ...
epoch: 7, train loss: 0.171
Saving models ...
epoch: 8, train loss: 0.168
Saving models ...
epoch: 9, train loss: 0.164
Saving models ...
epoch: 10, train loss: 0.161
Saving models ...
epoch: 11, train loss: 0.158
Saving models ...
epoch: 12, train loss: 0.155
Saving models ...
epoch: 13, train loss: 0.153
Saving models ...
epoch: 14, train loss: 0.150
Saving models ...
epoch: 15, train loss: 0.148
Saving models ...
epoch: 16, train loss: 0.145
Saving models ...
epoch: 17, train loss: 0.143
Saving models ...
epoch: 18, train loss: 0.141
Saving models ...
epoch: 19, train loss: 0.139
Saving models ...
epoch: 20, train loss: 0.137
Saving models ...
epoch: 21, train loss: 0.135
Saving models ..

In [27]:
for i, dataset in enumerate(regression_data.keys()):
    test_data = regression_data[dataset][1]
    
    base_model = BaseRegressor(
        input_dim=test_data.features.shape[1], output_dim=1, n_hidden_layers=4, layer_size=128).to(device)
    surrogate_model = SurrogateRegressor(input_dim=test_data.features.shape[1], output_dim=1).to(device)
    
    base_model.load_state_dict(torch.load(f"{save_dir}/{dataset}/base_model_checkpoint.pt"))
    surrogate_model.load_state_dict(torch.load(f"{save_dir}/{dataset}/surrogate_model_checkpoint.pt"))
    
    base_model.eval()
    surrogate_model.eval()
    
    print(f"{dataset}:")
    validate_regressors(base_model, surrogate_model, test_data)
    print()

wine:
test mse: 0.415
global fidelity: 0.042, global neighborhood fidelity: 0.041

housing:
test mse: 0.256
global fidelity: 0.022, global neighborhood fidelity: 0.027

autompg:
test mse: 0.118
global fidelity: 0.022, global neighborhood fidelity: 0.022

