In [7]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import sklearn.metrics

from datasets import CaliforniaHousingDataset, AdultDataset, TitanicDataset, AutoMpgDataset, WineDataset
from metrics import calculate_global_fidelity, calculate_global_neighborhood_fidelity
from models.base_model import BaseClassifier, BaseRegressor
from models.surrogate_model import SurrogateClassifier, SurrogateRegressor

# Run on the first CUDA device when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [8]:
def _load_train_test(dataset_cls, dataset_path):
    """Build the (train, test) pair of ``dataset_cls`` from a single data file.

    Both splits are constructed with ``normalize=True``; the train split is
    created first, then the test split.
    """
    train_split = dataset_cls(dataset_path=dataset_path, normalize=True, train=True)
    test_split = dataset_cls(dataset_path=dataset_path, normalize=True, train=False)
    return train_split, test_split


housing_train, housing_test = _load_train_test(
    CaliforniaHousingDataset, "data/california_housing/cal_housing.data")

adult_train, adult_test = _load_train_test(AdultDataset, "data/adult/adult.data")

titanic_train, titanic_test = _load_train_test(TitanicDataset, "data/titanic/titanic.arff")

wine_train, wine_test = _load_train_test(WineDataset, "data/wines/winequality-red.csv")

autompg_train, autompg_test = _load_train_test(AutoMpgDataset, "data/autompg/auto-mpg.data")

In [9]:
# Optimisation hyperparameters shared by every experiment in this notebook.
lr = 0.001
batch_size = 128  # not from the paper

# Fixed seed so that model initialisation is repeatable after a kernel restart.
seed = 0
torch.manual_seed(seed)

binary_classification_criterion = torch.nn.BCELoss()
# TODO: "logarithm of the hyperbolic cosine" from the paper (?) -- MSE is used
# until that loss is implemented. (The former `regression_criterion = ...`
# placeholder was a dead assignment of Ellipsis and has been removed.)
regression_criterion = torch.nn.MSELoss()
# TODO early stopping

In [10]:
def train(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        train_data: Dataset,
        test_data: Dataset,
        criterion,
        epochs: int,
        alpha: float
):
    """Jointly train a base model and its surrogate with a weighted two-term loss.

    For every batch the optimised quantity is
    ``alpha * criterion(base_preds, labels) + (1 - alpha) * fidelity`` where
    ``fidelity = calculate_global_fidelity(base_preds, surrogate_preds)``.
    A single Adam optimiser updates the parameters of both models, using the
    notebook-level ``lr``, ``batch_size`` and ``device``. The mean batch loss
    is printed after each epoch.

    Args:
        base_model: Model whose predictions are compared against the labels.
        surrogate_model: Model whose predictions are compared against the
            base model's predictions.
        train_data: Dataset yielding ``(data, labels)`` pairs.
        test_data: Currently unused; kept so existing callers keep working.
        criterion: Callable ``criterion(predictions, labels)`` returning a
            scalar loss tensor.
        epochs: Number of passes over ``train_data``.
        alpha: Weight of the supervised loss; ``1 - alpha`` weights fidelity.

    Returns:
        None. Both models are updated in place and left in training mode.
    """
    params = list(base_model.parameters()) + list(surrogate_model.parameters())
    optimizer = Adam(params, lr=lr)
    loader = DataLoader(train_data, batch_size=batch_size)
    # Guard the per-epoch average against an empty loader (division by zero).
    n_batches = max(len(loader), 1)

    # Make sure train-time layer behaviour is active even if the models were
    # switched to eval mode earlier in the session.
    base_model.train()
    surrogate_model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)
            # Reshape labels to a column so they line up with (batch, 1) predictions.
            labels = labels.reshape(-1, 1)
            optimizer.zero_grad()

            base_model_preds = base_model(data)
            surrogate_model_preds = surrogate_model(data)
            loss = criterion(base_model_preds, labels)
            point_fidelity = calculate_global_fidelity(base_model_preds, surrogate_model_preds)
            mtl_loss = alpha * loss + (1 - alpha) * point_fidelity

            mtl_loss.backward()
            optimizer.step()
            # Accumulate a plain float: adding the tensor itself would keep a
            # reference to every batch's autograd history for the whole epoch.
            running_loss += mtl_loss.item()
        print(f"epoch: {epoch}, train loss: {running_loss / n_batches:.4f}")


def validate_base_classifier(
        model: nn.Module,
        test_data: Dataset,
):
    """Print accuracy and F1 score of a binary classifier on ``test_data``.

    The whole dataset is loaded as one batch, the model outputs are
    thresholded at 0.5 to obtain class predictions, and both metrics are
    computed with ``sklearn.metrics``.

    Args:
        model: Classifier whose output is compared against 0.5.
        test_data: Dataset yielding ``(data, labels)`` pairs.

    Returns:
        None. The metrics are printed.
    """
    loader = DataLoader(test_data, batch_size=len(test_data))
    # Score in eval mode so train-only layer behaviour (e.g. dropout, if the
    # model has any) does not perturb the metrics; restore the caller's mode after.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # batch_size == len(test_data), so the first batch is the full set.
            data, labels = next(iter(loader))
            data, labels = data.to(device), labels.to(device)
            labels = labels.reshape(-1, 1)
            preds_proba = model(data)
            preds = torch.where(preds_proba >= 0.5, 1, 0)
            # Move to host memory once; sklearn works on CPU data.
            labels_cpu, preds_cpu = labels.cpu(), preds.cpu()
            accuracy = sklearn.metrics.accuracy_score(labels_cpu, preds_cpu)
            f1_score = sklearn.metrics.f1_score(labels_cpu, preds_cpu)
            print(f"test accuracy: {accuracy:.4f}, f1 score: {f1_score:.4f}")
    finally:
        model.train(was_training)


def validate_base_regressor(
        model: nn.Module,
        test_data: Dataset
):
    """Print the mean squared error of a regressor on ``test_data``.

    The whole dataset is loaded as one batch and the error is computed with
    ``sklearn.metrics.mean_squared_error``.

    Args:
        model: Regressor to evaluate.
        test_data: Dataset yielding ``(data, labels)`` pairs.

    Returns:
        None. The metric is printed.
    """
    loader = DataLoader(test_data, batch_size=len(test_data))
    # Score in eval mode so train-only layer behaviour (e.g. dropout, if the
    # model has any) does not perturb the metric; restore the caller's mode after.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # batch_size == len(test_data), so the first batch is the full set.
            data, labels = next(iter(loader))
            data, labels = data.to(device), labels.to(device)
            labels = labels.reshape(-1, 1)
            preds = model(data)
            mse = sklearn.metrics.mean_squared_error(labels.cpu(), preds.cpu())
            print(f"test mse: {mse:.4f}")
    finally:
        model.train(was_training)


def validate_surrogate_model(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        test_data: Dataset
):
    """Print how closely the surrogate follows the base model on ``test_data``.

    Two values are reported: ``calculate_global_fidelity`` applied to both
    models' predictions, and ``calculate_global_neighborhood_fidelity``
    applied to the two models and the input batch. Labels are ignored.

    Args:
        base_model: Reference model.
        surrogate_model: Model that is compared against ``base_model``.
        test_data: Dataset yielding ``(data, labels)`` pairs.

    Returns:
        None. The metrics are printed.
    """
    loader = DataLoader(test_data, batch_size=len(test_data))
    # Evaluate both models in eval mode so train-only layer behaviour (if any)
    # does not add noise to the comparison; restore each model's mode afterwards.
    models = (base_model, surrogate_model)
    previous_modes = [m.training for m in models]
    for m in models:
        m.eval()
    try:
        with torch.no_grad():
            # batch_size == len(test_data), so the first batch is the full set.
            data, _ = next(iter(loader))
            data = data.to(device)
            base_model_preds = base_model(data)
            surrogate_model_preds = surrogate_model(data)
            global_fidelity = calculate_global_fidelity(base_model_preds, surrogate_model_preds)
            global_neighborhood_fidelity = calculate_global_neighborhood_fidelity(base_model, surrogate_model, data)
            print(f"global fidelity: {global_fidelity}, global neighborhood fidelity: {global_neighborhood_fidelity}")
    finally:
        for m, mode in zip(models, previous_modes):
            m.train(mode)


def validate_regressors(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        test_data: Dataset
):
    """Run the regression report: base-model MSE, then surrogate fidelity.

    Delegates to ``validate_base_regressor`` and ``validate_surrogate_model``
    in that order; both print their results.
    """
    validate_base_regressor(model=base_model, test_data=test_data)
    validate_surrogate_model(
        base_model=base_model,
        surrogate_model=surrogate_model,
        test_data=test_data,
    )


def validate_classifiers(
        base_model: nn.Module,
        surrogate_model: nn.Module,
        test_data: Dataset
):
    """Run the classification report: base-model accuracy/F1, then surrogate fidelity.

    Delegates to ``validate_base_classifier`` and ``validate_surrogate_model``
    in that order; both print their results.
    """
    validate_base_classifier(model=base_model, test_data=test_data)
    validate_surrogate_model(
        base_model=base_model,
        surrogate_model=surrogate_model,
        test_data=test_data,
    )

# TODO local explainability evaluation
# TODO wine regression (special case - ordinal regression)

In [11]:
# Titanic experiment: binary classification, 10 epochs, alpha = 0.9.
titanic_input_dim = titanic_train.features.shape[1]
base_model = BaseClassifier(
    input_dim=titanic_input_dim,
    output_dim=1,
    n_hidden_layers=4,
    layer_size=128,
).to(device)
surrogate_model = SurrogateClassifier(
    input_dim=titanic_input_dim,
    output_dim=1,
).to(device)

train(
    base_model,
    surrogate_model,
    titanic_train,
    titanic_test,
    criterion=binary_classification_criterion,
    epochs=10,
    alpha=0.9,
)

epoch: 0, train loss: 0.6087
epoch: 1, train loss: 0.5437
epoch: 2, train loss: 0.4664
epoch: 3, train loss: 0.4366
epoch: 4, train loss: 0.4171
epoch: 5, train loss: 0.4124
epoch: 6, train loss: 0.4044
epoch: 7, train loss: 0.3977
epoch: 8, train loss: 0.3931
epoch: 9, train loss: 0.3894


In [12]:
# Report the trained classifier's accuracy/F1 and the surrogate's fidelity
# metrics on the held-out Titanic split (the dataset built with train=False).
validate_classifiers(base_model, surrogate_model, titanic_test)

test accuracy: 0.8373, f1 score: 0.7848
global fidelity: 0.07478304952383041, global neighborhood fidelity: 0.07480886578559875


In [13]:
# California housing experiment: regression, 10 epochs, alpha = 0.5.
housing_input_dim = housing_train.features.shape[1]
base_regressor = BaseRegressor(
    input_dim=housing_input_dim,
    output_dim=1,
    n_hidden_layers=4,
    layer_size=128,
).to(device)
surrogate_regressor = SurrogateRegressor(
    input_dim=housing_input_dim,
    output_dim=1,
).to(device)

train(
    base_regressor,
    surrogate_regressor,
    housing_train,
    housing_test,
    criterion=regression_criterion,
    epochs=10,
    alpha=0.5,
)

epoch: 0, train loss: 26086823936.0000
epoch: 1, train loss: 17037743104.0000
epoch: 2, train loss: 16028775424.0000
epoch: 3, train loss: 15654859776.0000
epoch: 4, train loss: 15445307392.0000
epoch: 5, train loss: 15328828416.0000
epoch: 6, train loss: 15268144128.0000
epoch: 7, train loss: 15236384768.0000
epoch: 8, train loss: 15216320512.0000
epoch: 9, train loss: 15200400384.0000


In [14]:
# TODO normalize the regression target? The training loss and test MSE printed
# in this notebook are on the order of 1e10, which suggests the labels are
# still on their raw scale even though the dataset was built with
# normalize=True (assumption -- confirm in CaliforniaHousingDataset).

validate_regressors(base_regressor, surrogate_regressor, housing_test)

test mse: 17189199872.0000
global fidelity: 12885541888.0, global neighborhood fidelity: 13027083264.0
