Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3d02eb2
commit ffd7c28
Showing
3 changed files
with
334 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,309 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pytorch_tabnet.tab_model import TabNetRegressor\n", | ||
"\n", | ||
"import torch\n", | ||
"from sklearn.preprocessing import LabelEncoder\n", | ||
"from sklearn.metrics import mean_squared_error\n", | ||
"\n", | ||
"import pandas as pd\n", | ||
"import numpy as np\n", | ||
"np.random.seed(0)\n", | ||
"\n", | ||
"\n", | ||
"import os\n", | ||
"import wget\n", | ||
"from pathlib import Path" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Download census-income dataset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n", | ||
"dataset_name = 'census-income'\n", | ||
"out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"out.parent.mkdir(parents=True, exist_ok=True)\n", | ||
"if out.exists():\n", | ||
" print(\"File already exists.\")\n", | ||
"else:\n", | ||
" print(\"Downloading file...\")\n", | ||
" wget.download(url, out.as_posix())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Load data and split" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train = pd.read_csv(out)\n", | ||
"target = ' <=50K'\n", | ||
"if \"Set\" not in train.columns:\n", | ||
" train[\"Set\"] = np.random.choice([\"train\", \"valid\", \"test\"], p =[.8, .1, .1], size=(train.shape[0],))\n", | ||
"\n", | ||
"train_indices = train[train.Set==\"train\"].index\n", | ||
"valid_indices = train[train.Set==\"valid\"].index\n", | ||
"test_indices = train[train.Set==\"test\"].index" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Simple preprocessing\n", | ||
"\n", | ||
"Label encode categorical features and fill empty cells." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"categorical_columns = []\n", | ||
"categorical_dims = {}\n", | ||
"for col in train.columns[train.dtypes == object]:\n", | ||
" print(col, train[col].nunique())\n", | ||
" l_enc = LabelEncoder()\n", | ||
" train[col] = train[col].fillna(\"VV_likely\")\n", | ||
" train[col] = l_enc.fit_transform(train[col].values)\n", | ||
" categorical_columns.append(col)\n", | ||
" categorical_dims[col] = len(l_enc.classes_)\n", | ||
"\n", | ||
"for col in train.columns[train.dtypes == 'float64']:\n", | ||
" train.fillna(train.loc[train_indices, col].mean(), inplace=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Define categorical features for categorical embeddings" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"unused_feat = ['Set']\n", | ||
"\n", | ||
"features = [ col for col in train.columns if col not in unused_feat+[target]] \n", | ||
"\n", | ||
"cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n", | ||
"\n", | ||
"cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n", | ||
"\n", | ||
"# define your embedding sizes : here just a random choice\n", | ||
"cat_emb_dim = [5, 4, 3, 6, 2, 2, 1, 10]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Network parameters" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"clf = TabNetRegressor(cat_dims=cat_dims, cat_emb_dim=cat_emb_dim, cat_idxs=cat_idxs)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Training" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### We will simulate 8 targets here to perform multi-target regression without changing anything!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"n_targets = 8\n", | ||
"\n", | ||
"X_train = train[features].values[train_indices]\n", | ||
"y_train = train[target].values[train_indices]\n", | ||
"y_train = np.transpose(np.tile(y_train, (n_targets,1)))\n", | ||
"\n", | ||
"X_valid = train[features].values[valid_indices]\n", | ||
"y_valid = train[target].values[valid_indices]\n", | ||
"y_valid = np.transpose(np.tile(y_valid, (n_targets,1)))\n", | ||
"\n", | ||
"X_test = train[features].values[test_indices]\n", | ||
"y_test = train[target].values[test_indices]\n", | ||
"y_test = np.transpose(np.tile(y_test, (n_targets,1)))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"clf.fit(\n", | ||
" X_train=X_train, y_train=y_train,\n", | ||
" X_valid=X_valid, y_valid=y_valid,\n", | ||
" max_epochs=1000,\n", | ||
" patience=50,\n", | ||
" batch_size=1024, virtual_batch_size=128,\n", | ||
" num_workers=0,\n", | ||
" drop_last=False\n", | ||
") " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Deprecated : best model is automatically loaded at end of fit\n", | ||
"# clf.load_best_model()\n", | ||
"\n", | ||
"preds = clf.predict(X_test)\n", | ||
"\n", | ||
"y_true = y_test\n", | ||
"\n", | ||
"test_score = mean_squared_error(y_pred=preds, y_true=y_true)\n", | ||
"\n", | ||
"print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n", | ||
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_score}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Global explainability : feature importances summing to 1" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Feature importances exposed by the model; as the cell's last expression they are displayed as output.\n", | ||
"clf.feature_importances_" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Local explainability and masks" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# clf.explain is run on the test rows; its two return values are unpacked as explain_matrix and masks.\n", | ||
"explain_matrix, masks = clf.explain(X_test)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from matplotlib import pyplot as plt\n", | ||
"%matplotlib inline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"fig, axs = plt.subplots(1, 3, figsize=(20,20))\n", | ||
"\n", | ||
"for i in range(3):\n", | ||
" axs[i].imshow(masks[i][:50])\n", | ||
" axs[i].set_title(f\"mask {i}\")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# XGB : unfortunately, multi-target regression was still not possible with XGBoost at the time of writing\n", | ||
"\n", | ||
"https://github.com/dmlc/xgboost/issues/2087" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.