fix #36, update vignette
dandls committed May 14, 2024
1 parent f7a4f18 commit 50424da
Showing 2 changed files with 120 additions and 112 deletions.
39 changes: 21 additions & 18 deletions vignettes/introduction.Rmd
@@ -40,13 +40,14 @@ credit application is rejected.

### Data: German Credit Dataset

-As training data, we use the German Credit Data from the `rchallenge`` package.
+As training data, we use the German Credit Data from the `rchallenge` package.
The data set contains 1000 observations with 21 features and the binary target variable `credit_risk`.
For illustrative purposes, we only consider 8 of the 21 features in the following:

```{r}
data(german, package = "rchallenge")
-credit = german[, c("duration", "amount", "purpose", "age", "employment_duration", "housing", "number_credits", "credit_risk")]
+credit = german[, c("duration", "amount", "purpose", "age",
+"employment_duration", "housing", "number_credits", "credit_risk")]
```

```{r, echo=FALSE}
@@ -64,7 +65,8 @@ column_descr = data.frame(
)
)
names(column_descr) <- c("Variable", "Description")
-knitr::kable(column_descr, escape = FALSE, format = "html", table.attr = "style='width:100%;'")
+knitr::kable(column_descr, escape = FALSE, format = "html",
+table.attr = "style='width:100%;'")
```

### Fitting a model
@@ -120,7 +122,8 @@ if (!file.exists("introduction-res/cfactuals_credit.RDS")) {
predictor, epsilon = 0, fixed_features = c("age", "employment_duration"),
termination_crit = "genstag", n_generations = 10L, quiet = TRUE
)
-cfactuals = moc_classif$find_counterfactuals(x_interest, desired_class = "good", desired_prob = c(0.6, 1))
+cfactuals = moc_classif$find_counterfactuals(x_interest,
+desired_class = "good", desired_prob = c(0.6, 1))
dir.create("introduction-res")
saveRDS(moc_classif, file = "introduction-res/moc_classif_credit.RDS")
saveRDS(cfactuals, file = "introduction-res/cfactuals_credit.RDS")
@@ -134,11 +137,6 @@ cfactuals = readRDS("introduction-res/cfactuals_credit.RDS")

The resulting `Counterfactuals` object holds the counterfactuals in the `data` field and possesses several methods for their
evaluation and visualization.
-
-```{r}
-class(cfactuals)
-```
-
Printing a `Counterfactuals` object, gives an overview of the results.
```{r}
print(cfactuals)
@@ -158,15 +156,17 @@ and negative values indicate a decrease; for factors, the counterfactual feature
differs from `x_interest.`; `NA` means "no difference" in both cases.

```{r}
-head(cfactuals$evaluate(show_diff = TRUE, measures = c("dist_x_interest", "dist_target", "no_changed", "dist_train")), 3L)
+head(cfactuals$evaluate(show_diff = TRUE,
+measures = c("dist_x_interest", "dist_target", "no_changed", "dist_train")), 3L)
```
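
The difference convention described above can be reproduced on a toy example. The following is a minimal base-R sketch, not the package's implementation: the helper `diff_row()` and the toy rows are hypothetical, and it assumes that a changed factor is reported by its new level.

```r
# Hypothetical helper, for illustration only: difference between one
# counterfactual row and x_interest, following the convention above.
diff_row = function(cf, x_interest) {
  out = lapply(names(cf), function(nm) {
    if (is.numeric(cf[[nm]])) {
      d = cf[[nm]] - x_interest[[nm]]
      if (d == 0) NA_real_ else d               # NA means "no difference"
    } else {
      new = as.character(cf[[nm]])
      old = as.character(x_interest[[nm]])
      if (new == old) NA_character_ else new    # report the new level
    }
  })
  names(out) = names(cf)
  as.data.frame(out, stringsAsFactors = FALSE)
}

# Toy rows with made-up values (not taken from the credit data).
lv = c("rent", "own")
x_toy = data.frame(duration = 48, housing = factor("rent", levels = lv))
cf_toy = data.frame(duration = 24, housing = factor("own", levels = lv))

diff_toy = diff_row(cf_toy, x_toy)
diff_toy
```

Here `duration` decreases, so its entry is negative, while `housing` shows the level the counterfactual switched to.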

By design, not all counterfactuals generated with MOC have a prediction equal to the desired
prediction. We can use `subset_to_valid()` to omit all counterfactuals that do not achieve
-the desired predicition.
+the desired prediction. This step can be reverted with `revert_subset_to_valid()`.

```{r}
cfactuals$subset_to_valid()
-nrow(cfactuals$data)
+nrow(cfactuals$data)
```
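
Conceptually, restricting to valid counterfactuals is a row filter on the predicted probability. A minimal base-R sketch with made-up probabilities, assuming the bounds of `desired_prob` are inclusive (the column names `id` and `prob_good` are hypothetical and not part of the package):

```r
# Toy predicted probabilities for the desired class (made-up numbers).
cf_probs = data.frame(id = 1:4, prob_good = c(0.45, 0.62, 0.80, 0.59))
desired_prob = c(0.6, 1)

# Keep only rows whose prediction lies inside the desired interval.
is_valid = cf_probs$prob_good >= desired_prob[1] &
  cf_probs$prob_good <= desired_prob[2]
valid_probs = cf_probs[is_valid, ]
nrow(valid_probs)
```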

The `plot_freq_of_feature_changes()` method plots the frequency of feature changes across all
@@ -182,7 +182,7 @@ We specify `feature_names` to order the features according to their frequency of

```{r, message=FALSE, fig.height=2.5}
cfactuals$plot_parallel(feature_names = names(
-cfactuals$get_freq_of_feature_changes()), digits_min_max = 2L)
+cfactuals$get_freq_of_feature_changes()), digits_min_max = 2L)
```

@@ -269,7 +269,8 @@ First, we train a model to predict `plasma_retinol`, again omitting `x_interest`
This time we use a regression tree trained with the `mlr3` and `rpart` package.

```{r}
-tsk = mlr3::TaskRegr$new(id = "plasma", backend = plasma[-100L,], target = "retplasma")
+tsk = TaskRegr$new(id = "plasma", backend = plasma[-100L,],
+target = "retplasma")
tree = lrn("regr.rpart")
model = tree$train(tsk)
```
@@ -300,10 +301,11 @@ nice_regr = NICERegr$new(predictor, optimization = "proximity",
```

Then, we use the `find_counterfactuals()` method to find counterfactuals for `x_interest` with a predicted
-plasma concentration in the interval [500, Inf].
+plasma concentration in the interval [500, Inf).

```{r}
-cfactuals = nice_regr$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
+cfactuals = nice_regr$find_counterfactuals(x_interest,
+desired_outcome = c(500, Inf))
```

### The counterfactuals object
@@ -317,7 +319,7 @@ cfactuals
To inspect the counterfactual, we can use the same tools as before.
For example, in the surface plot, we see that increasing betaplasma helps while changing the age alone has no impact on the prediction.

-```{r, fig.height=3, fig.width=9}
+```{r, fig.height=3}
cfactuals$plot_surface(feature_names = c("betaplasma", "age"), grid_size = 200)
```

@@ -353,7 +355,8 @@ Replacing the distance function is fairly easy:
nice_regr = NICERegr$new(predictor, optimization = "proximity",
margin_correct = 0.5, return_multiple = FALSE,
distance_function = l0_norm)
-cfactuals = nice_regr$find_counterfactuals(x_interest, desired_outcome = c(500, 1000))
+cfactuals = nice_regr$find_counterfactuals(x_interest,
+desired_outcome = c(500, 1000))
cfactuals
```
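
The `distance_function` argument above takes a user-defined function whose definition falls in the collapsed part of this diff. As a rough sketch of what an L0-style distance could look like, assume the function receives two data frames with the same columns and returns a matrix with one row per row of the first and one column per row of the second. The name `l0_sketch`, the `(x, y, data)` signature and the toy data are assumptions for illustration and are not confirmed by the lines shown here.

```r
# Hypothetical L0-style distance: counts, for every pair of rows,
# how many features differ. Not the vignette's own definition.
l0_sketch = function(x, y, data = NULL) {
  res = matrix(NA_integer_, nrow = nrow(x), ncol = nrow(y))
  for (i in seq_len(nrow(x))) {
    for (j in seq_len(nrow(y))) {
      # Compare as character so numeric and factor columns are handled alike.
      xi = vapply(x[i, , drop = FALSE], as.character, character(1))
      yj = vapply(y[j, , drop = FALSE], as.character, character(1))
      res[i, j] = sum(xi != yj)
    }
  }
  res
}

# Toy data with made-up values.
a = data.frame(age = c(30, 40), housing = c("rent", "own"))
b = data.frame(age = c(30, 40), housing = c("own", "own"))
d_toy = l0_sketch(a, b)
d_toy
```

Each entry of `d_toy` is the number of changed features between a row of `a` and a row of `b`, so a distance of this kind favours counterfactuals that alter few features regardless of how large each change is.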
