Commit

refactor: remove deep_transformer module (#92)
chrislemke authored Mar 9, 2023
1 parent 0cb3f5a commit 27687e8
Showing 11 changed files with 538 additions and 2,910 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -112,13 +112,13 @@ repos:
         args: ["--profile=black"]

   - repo: https://github.com/PyCQA/pylint
-    rev: v2.16.2
+    rev: v2.17.0
     hooks:
       - id: pylint
         args: ["--rcfile=pyproject.toml"]

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.0.1
+    rev: v1.1.1
     hooks:
       - id: mypy
         args:
@@ -139,7 +139,7 @@ repos:
       - id: nbstripout

   - repo: https://github.com/python-poetry/poetry
-    rev: 1.3.0
+    rev: 1.4.0
     hooks:
       - id: poetry-check
       - id: poetry-lock
1 change: 0 additions & 1 deletion docs/API-reference/transformer/deep_transformer.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/README.md
@@ -51,7 +51,6 @@ poetry install
| ------ | ----------- | ----------- |
|[`Datetime transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/datetime_transformer/)|[`DateColumnsTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/datetime_transformer/#sk_transformers.datetime_transformer.DateColumnsTransformer)|Splits a date column into multiple columns.|
|[`Datetime transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/datetime_transformer/)|[`DurationCalculatorTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/datetime_transformer/#sk_transformers.datetime_transformer.DurationCalculatorTransformer)|Calculates the duration between two given dates.|
|[`Deep transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/deep_transformer/)|[`ToVecTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/deep_transformer/#sk_transformers.deep_transformer.ToVecTransformer)|This transformer trains an [FT-Transformer](https://paperswithcode.com/method/ft-transformer) using the [pytorch-widedeep package](https://github.com/jrzaurin/pytorch-widedeep) and extracts the embeddings from its embedding layer.|
|[`Encoder transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/encoder_transformer/)|[`MeanEncoderTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/encoder_transformer/#sk_transformers.encoder_transformer.MeanEncoderTransformer)|Scikit-learn API for the [feature-engine MeanEncoder](https://feature-engine.readthedocs.io/en/latest/api_doc/encoding/MeanEncoder.html).|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`AggregateTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.AggregateTransformer)|This transformer uses Pandas groupby method and aggregate to apply function on a column grouped by another column.|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`AllowedValuesTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.AllowedValuesTransformer)|This transformer replaces values that are *not* in a list with another value.|
53 changes: 0 additions & 53 deletions examples/playground.ipynb
@@ -116,59 +116,6 @@
"transformer.fit_transform(X).to_numpy()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Deep transformer](https://chrislemke.github.io/sk-transformers/API-reference/transformer/deep_transformer/)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### [`ToVecTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/deep_transformer/#sk_transformers.deep_transformer.ToVecTransformer)\n",
"\n",
"This transformer trains an [FT-Transformer](https://paperswithcode.com/method/ft-transformer)\n",
"using the [pytorch-widedeep package](https://github.com/jrzaurin/pytorch-widedeep) and extracts the embeddings\n",
"from its embedding layer. The output shape of the transformer is (number of rows,(`input_dim` * number of columns)).\n",
"Please refer to [this example](https://pytorch-widedeep.readthedocs.io/en/latest/examples/09_extracting_embeddings.html)\n",
"for pytorch_widedeep example on how to extract embeddings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from pytorch_widedeep.datasets import load_adult\n",
"from sk_transformers import ToVecTransformer\n",
"\n",
"df = load_adult(as_frame=True)\n",
"df[\"target\"] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n",
"df = df.drop([\"income\", \"educational-num\"], axis=1)\n",
"\n",
"cat_cols, cont_cols = [], []\n",
"for col in df.columns:\n",
" if df[col].dtype == \"O\" or df[col].nunique() < 50 and col != \"target\":\n",
" cat_cols.append(col)\n",
" elif col != \"target\":\n",
" cont_cols.append(col)\n",
"\n",
"target_col = \"target\"\n",
"target = df[target_col].to_numpy()\n",
"\n",
"transformer = ToVecTransformer(\n",
" cat_cols, cont_cols, verbose=0, training_objective=\"binary\"\n",
")\n",
"transformer.fit_transform(df, target).shape"
]
},
{
"attachments": {},
"cell_type": "markdown",