In [2]:
# Re-import edited project modules (e.g. `counterfactuals`) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
# Re-running this cell in a live kernel prints a harmless "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from counterfactuals.cf_methods import PPCEFR
from counterfactuals.datasets.regression_file_dataset import RegressionFileDataset
from counterfactuals.datasets.method_dataset import MethodDataset
from counterfactuals.preprocessing import (
    MinMaxScalingStep,
    PreprocessingPipeline,
    TorchDataTypeStep,
)
from counterfactuals.datasets.preprocess_datasets.preprocess_moons import (
    transform_moons,
)
from counterfactuals.metrics.metrics import evaluate_cf
from counterfactuals.models import MLPRegressor, MaskedAutoregressiveFlow

In [22]:
# Load the toy regression dataset, min-max scale it, and wrap both splits in DataLoaders.
CONFIG_PATH = "../config/datasets/toy_regression.yaml"  # relative to the notebook's working directory
BATCH_SIZE = 1024
SEED = 42

# Seed both RNGs before building the dataset so any randomness in loading/splitting is reproducible.
# NOTE(review): confirm whether MethodDataset / the pipeline actually draw random numbers.
np.random.seed(SEED)
torch.manual_seed(SEED)

file_dataset = RegressionFileDataset(config_path=CONFIG_PATH)
preprocessing_pipeline = PreprocessingPipeline(
    [
        ("minmax", MinMaxScalingStep()),
        ("torch_dtype", TorchDataTypeStep()),
    ]
)
dataset = MethodDataset(file_dataset, preprocessing_pipeline)


def _to_tensor_dataset(features, targets) -> TensorDataset:
    """Wrap one (features, targets) split as a float32 TensorDataset.

    ``torch.as_tensor`` accepts NumPy arrays as well as existing tensors. Unlike
    ``torch.tensor``, it does not warn or make a redundant copy when the input is
    already a float32 tensor (as ``TorchDataTypeStep`` presumably produces). It may
    share memory with the input; the training loops below only read it.
    """
    return TensorDataset(
        torch.as_tensor(features, dtype=torch.float32),
        torch.as_tensor(targets, dtype=torch.float32),
    )


train_dataset = _to_tensor_dataset(dataset.X_train, dataset.y_train)
test_dataset = _to_tensor_dataset(dataset.X_test, dataset.y_test)

# Shuffle only the training split; keep evaluation order deterministic.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [26]:
# Fit an MLP regressor on the training split, validating on the test split.
# The logged output suggests fit() stops early on a patience counter, so this is an upper bound.
DISC_EPOCHS = 1000
# NOTE(review): the logged losses are in the hundreds even though the inputs are min-max scaled,
# and the train loss exceeds the test loss; presumably the targets are left unscaled. Confirm this is intended.

# Reseed so that re-running this cell on its own reproduces the same initialisation and shuffle order.
torch.manual_seed(42)
disc_model = MLPRegressor(num_inputs=dataset.X_train.shape[1], num_targets=1)
disc_model.fit(train_dataloader, test_dataloader, epochs=DISC_EPOCHS)


Epoch 367, Train Loss: 207.2553, Test Loss: 101.0486, Patience: 19:  37%|███▋      | 368/1000 [00:02<00:04, 135.07it/s] 


In [30]:
# Fit a masked autoregressive flow on the training split, validating on the test split.
FLOW_HIDDEN_FEATURES = 16
# Upper bound; the logged output suggests fit() stops early on a patience counter.
FLOW_EPOCHS = 1000

# Reseed so that re-running this cell on its own reproduces the same initialisation and shuffle order.
torch.manual_seed(42)
gen_model = MaskedAutoregressiveFlow(
    features=dataset.X_train.shape[1],
    hidden_features=FLOW_HIDDEN_FEATURES,
)
gen_model.fit(train_dataloader, test_dataloader, epochs=FLOW_EPOCHS)

Epoch 473, Train: -0.8418, test: -0.7440, patience: 20:  47%|████▋     | 473/1000 [00:05<00:06, 86.17it/s]
