DHTV is a framework for learning continuous and piecewise-linear (CPWL) functions in an interpretable manner.
In this repository, we aim to:
Solve regression problems with CPWL functions.
Reproduce the results of the paper [Pourya2022].
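For orientation, a CPWL function on a triangulation is fully determined by its values at the vertices. Assume that DHTV parameterizes its model this way, as the names `tri.H` and `tri.update_values` suggest. Write $\mathbf{v}_1, \dots, \mathbf{v}_K$ for the vertices, $c_k$ for the value at $\mathbf{v}_k$, and $\varphi_k$ for the hat function that equals 1 at $\mathbf{v}_k$, equals 0 at every other vertex, and is linear on each simplex. Then

```latex
f(\mathbf{x}) = \sum_{k=1}^{K} c_k \, \varphi_k(\mathbf{x})
             = \sum_{j=0}^{d} \lambda_j(\mathbf{x}) \, c_{i_j}
\quad \text{for } \mathbf{x} \in \operatorname{conv}\{\mathbf{v}_{i_0}, \dots, \mathbf{v}_{i_d}\},
\qquad
H_{mk} = \varphi_k(\mathbf{x}_m),
```

where $\lambda_j(\mathbf{x}) \ge 0$ are the barycentric coordinates of $\mathbf{x}$ in its simplex ($\sum_j \lambda_j = 1$, $\sum_j \lambda_j \mathbf{v}_{i_j} = \mathbf{x}$). The model values at the data points $\mathbf{x}_1, \dots, \mathbf{x}_M$ are then $H\mathbf{c}$.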
To install the package, first create an environment with Python 3.7 or later (the commands below use Python 3.9.7), then clone the repository and install the requirements:
>> conda create -n DHTV python==3.9.7
>> conda activate DHTV
>> git clone https://github.com/mehrsapo/DHTV.git
>> cd <repository_dir>/
>> conda install -n DHTV ipykernel --update-deps --force-reinstall
>> pip install --upgrade -r requirements.txt
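After activating the environment, a quick check that the interpreter is recent enough and that the main dependencies are importable can save debugging time later. The package list below is an assumption: NumPy, SciPy, and PyTorch are guessed from the code in this README, not read from requirements.txt. `check_environment` is a hypothetical helper.

```python
import importlib.util
import sys

# Assumed dependency list (guessed from this README, not from requirements.txt).
ASSUMED_PACKAGES = ("numpy", "scipy", "torch")


def check_environment(packages=ASSUMED_PACKAGES, min_python=(3, 7)):
    """Return a dict reporting the Python version check and package availability."""
    report = {"python_ok": sys.version_info[:2] >= tuple(min_python)}
    for name in packages:
        # find_spec locates a top-level package without importing it.
        report[name] = importlib.util.find_spec(name) is not None
    if report.get("torch"):
        import torch  # only imported when it is installed
        # double_fista is called with device='cuda:0' in this README.
        report["cuda"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, ok in check_environment().items():
        print(f"{key:10s} {'ok' if ok else 'missing'}")
```

Save it under any file name and run it with `python` inside the activated DHTV environment. If `cuda` is reported as missing, the `device='cuda:0'` argument used later will not work as written.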
First, build the DHTV model and compute the forward (H) and regularization (L) operators:
tri = MyDelaunay(X, y) # X: input variables, y: target values
tri.construct_forward_matrix() # constructing H
tri.construct_regularization_matrix() # constructing L
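The repository's `MyDelaunay` methods are not reproduced here, and they may work differently. The following is a minimal sketch for 2D inputs using SciPy's `Delaunay`. It assumes that the model is parameterized by its values at the triangulation vertices and that L encodes the Hessian total variation (HTV), as the name DHTV suggests. Each row of H holds the barycentric coordinates of one data point. Each row of L corresponds to one interior edge and is weighted by the edge length times the jump of the gradient across that edge, so `np.abs(L @ c).sum()` is the HTV of the CPWL model with vertex values `c`. The names `forward_matrix` and `htv_matrix_2d` are hypothetical.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.spatial import Delaunay


def forward_matrix(tri, X):
    """Sparse H with H[m, k] = barycentric weight of vertex k at point X[m]."""
    X = np.asarray(X, dtype=float)
    simplex = tri.find_simplex(X)
    if np.any(simplex < 0):
        raise ValueError("some points lie outside the convex hull")
    d = X.shape[1]
    # tri.transform[s] holds the affine map from x to barycentric coordinates.
    T = tri.transform[simplex]                                  # (M, d+1, d)
    b = np.einsum("mij,mj->mi", T[:, :d, :], X - T[:, d, :])    # first d coords
    bary = np.hstack([b, 1.0 - b.sum(axis=1, keepdims=True)])   # (M, d+1)
    rows = np.repeat(np.arange(len(X)), d + 1)
    cols = tri.simplices[simplex].ravel()
    return csr_matrix((bary.ravel(), (rows, cols)),
                      shape=(len(X), len(tri.points)))


def htv_matrix_2d(tri):
    """Sparse L with one row per interior edge; ||L @ c||_1 is the HTV (2D only)."""
    P = tri.points
    n_tri = len(tri.simplices)
    # Linear piece on triangle t: f(x) = a.x + b with a = G[t] @ c[simplices[t]].
    A = np.concatenate([P[tri.simplices], np.ones((n_tri, 3, 1))], axis=2)
    G = np.linalg.inv(A)[:, :2, :]                              # (n_tri, 2, 3)
    rows, cols, vals = [], [], []
    n_edges = 0
    for t in range(n_tri):
        for j in range(3):
            s = tri.neighbors[t, j]    # triangle across the edge opposite vertex j
            if s <= t:                 # boundary edge (-1) or edge already visited
                continue
            u, v = np.delete(tri.simplices[t], j)
            edge = P[v] - P[u]
            length = np.linalg.norm(edge)
            normal = np.array([-edge[1], edge[0]]) / length
            # Row = length * normal^T (G_t - G_s); the gradient jump is
            # parallel to the normal, so |row @ c| = length * ||a_t - a_s||.
            for k, sign in ((t, 1.0), (s, -1.0)):
                rows.extend([n_edges] * 3)
                cols.extend(tri.simplices[k])
                vals.extend(sign * length * (normal @ G[k]))
            n_edges += 1
    return csr_matrix((vals, (rows, cols)), shape=(n_edges, len(P)))
```

For a globally linear function, the gradients agree on both sides of every edge, so `L @ c` vanishes. Only kinks along edges are penalized, and the ℓ1 norm tends to favor models with few such kinks.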
Then we solve the learning task:
dhtv_sol, _ = double_fista(tri.data_values, tri.H, tri.L, tri.lip_H, tri.lip_L, lmbda, n_iter1, n_iter2, device='cuda:0')
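The call above does not spell out the objective. One plausible reading, assumed here rather than taken from the repository, is that `double_fista` minimizes ½‖Hc − y‖² + λ‖Lc‖₁. Under this reading, an outer FISTA loop handles the data term with step 1/`lip_H`, where `lip_H` would be the squared spectral norm of H. An inner FISTA loop on the dual of the ℓ1 term, with step 1/`lip_L`, evaluates the proximal operator. The NumPy sketch below implements that scheme on the CPU. `double_fista_sketch` and `prox_l1_analysis` are hypothetical names, and the real function runs on a PyTorch device and may differ in details such as stopping criteria.

```python
import numpy as np


def prox_l1_analysis(z, L, lip_L, tau, n_iter):
    """Approximate argmin_c 0.5*||c - z||^2 + tau*||L @ c||_1 by FISTA on the dual.

    Dual problem: min_{|u_i| <= tau} 0.5*||z - L^T u||^2, primal c = z - L^T u.
    """
    u = np.zeros(L.shape[0])
    w, t = u.copy(), 1.0
    for _ in range(n_iter):
        # Projected gradient step (step 1/lip_L) onto the box |u_i| <= tau.
        u_next = np.clip(w + (L @ (z - L.T @ w)) / lip_L, -tau, tau)
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        w = u_next + ((t - 1.0) / t_next) * (u_next - u)
        u, t = u_next, t_next
    return z - L.T @ u


def double_fista_sketch(y, H, L, lip_H, lip_L, lmbda, n_iter1, n_iter2):
    """Approximate argmin_c 0.5*||H @ c - y||^2 + lmbda*||L @ c||_1."""
    c = np.zeros(H.shape[1])
    w, t = c.copy(), 1.0
    for _ in range(n_iter1):
        # Gradient step on the data term (step 1/lip_H) ...
        z = w - (H.T @ (H @ w - y)) / lip_H
        # ... followed by the proximal step of (lmbda/lip_H)*||L @ c||_1.
        c_next = prox_l1_analysis(z, L, lip_L, lmbda / lip_H, n_iter2)
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        w = c_next + ((t - 1.0) / t_next) * (c_next - c)
        c, t = c_next, t_next
    return c
```

If `H` is the identity and `L` is a finite-difference matrix, the problem reduces to 1D total-variation denoising. This gives a convenient check of the solver before plugging in DHTV operators stored as NumPy arrays or SciPy sparse matrices.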
We can then use the learned values to compute the model predictions:
dhtv_predict = tri.evaluate(X, dhtv_sol.cpu().numpy())
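Suppose the vertices of the triangulation are exactly the input points `X`. This is an assumption about `MyDelaunay` that the snippet above does not confirm. In that case, the fitted CPWL model can also be evaluated at new points with SciPy's `LinearNDInterpolator`, which performs piecewise-linear barycentric interpolation on a Delaunay triangulation of its input points. The helper `cpwl_predict` and the variable `X_test` are hypothetical, and the helper is only meant as a cross-check.

```python
import numpy as np
from scipy.interpolate import LinearNDInterpolator


def cpwl_predict(vertices, vertex_values, X_new):
    """Evaluate the CPWL interpolant of (vertices, vertex_values) at X_new.

    Hypothetical helper: LinearNDInterpolator builds its own Delaunay
    triangulation of `vertices` and interpolates linearly on each simplex.
    Points outside the convex hull of `vertices` evaluate to NaN.
    """
    interp = LinearNDInterpolator(np.asarray(vertices), np.asarray(vertex_values))
    return interp(np.asarray(X_new))


# Example usage (assumes vertices == X; X_test is a hypothetical test set):
# dhtv_new = cpwl_predict(X, dhtv_sol.cpu().numpy(), X_test)
```

Points outside the convex hull get the default `fill_value` (NaN). When several input points are cocircular, the Delaunay triangulation is not unique, so the two implementations may disagree on those cells.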
In the two-dimensional case, we can also plot the model with:
tri.update_values(dhtv_sol.cpu().numpy())
plot_with_gradient_map(tri, 0.5, 1, 1, 1)
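This README does not document the numeric arguments of `plot_with_gradient_map`. The following is an illustrative sketch of a gradient map, not a reimplementation of that function. It colors each triangle of a 2D CPWL model by the norm of the gradient of its linear piece, using matplotlib's `tripcolor`. The helpers `triangle_gradients` and `plot_gradient_map` are hypothetical, and the triangulation is rebuilt with SciPy's `Delaunay` on the given points.

```python
import numpy as np
from scipy.spatial import Delaunay


def triangle_gradients(points, values, simplices):
    """Gradient (a_x, a_y) of the linear piece of a 2D CPWL function on each triangle."""
    P = points[simplices]                                        # (n_tri, 3, 2)
    # Solve [x_i, y_i, 1] @ [a_x, a_y, b] = value_i for the three vertices.
    A = np.concatenate([P, np.ones((len(simplices), 3, 1))], axis=2)
    coef = np.linalg.solve(A, values[simplices][..., None])[..., 0]
    return coef[:, :2]


def plot_gradient_map(points, values):
    """Color each Delaunay triangle by the norm of its gradient."""
    import matplotlib.pyplot as plt      # lazy import: only needed for plotting
    from matplotlib.tri import Triangulation

    points = np.asarray(points, dtype=float)
    values = np.asarray(values, dtype=float)
    simplices = Delaunay(points).simplices
    grad_norm = np.linalg.norm(triangle_gradients(points, values, simplices), axis=1)
    fig, ax = plt.subplots()
    tpc = ax.tripcolor(Triangulation(points[:, 0], points[:, 1], simplices),
                       facecolors=grad_norm)
    fig.colorbar(tpc, ax=ax, label="gradient norm")
    ax.set_aspect("equal")
    return fig
```

Assuming again that the vertices are the input points, `plot_gradient_map(X, dhtv_sol.cpu().numpy())` would draw such a map for the fitted model. Large flat regions of a single color then correspond to linear pieces that span several triangles.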
For more details, see <https://github.com/mehrsapo/DHTV/blob/main/intp_metric.ipynb>.
The results of the paper are available in the notebooks IV.A, IV.B, and IV.C; main_compare.py creates the data that is loaded in IV.C.
DHTV is developed by the Biomedical Imaging Group, École Polytechnique Fédérale de Lausanne (EPFL), Switzerland.
[Pourya2022] <https://arxiv.org/pdf/2208.07787.pdf>
This work was supported in part by the European Research Council (ERC Project FunLearn) under Grant 101020573 and in part by the Swiss National Science Foundation, Grant 200020 184646/1.