In [1]:
from pytorch_tabnet.tab_model import TabNetClassifier

import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
np.random.seed(0)


import os
import wget
from pathlib import Path
import shutil
import gzip

from matplotlib import pyplot as plt
%matplotlib inline

# Download ForestCoverType dataset

In [2]:
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz"
dataset_name = 'forest-cover-type'
tmp_out = Path('./data/'+dataset_name+'.gz')
out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')

In [3]:
out.parent.mkdir(parents=True, exist_ok=True)
if out.exists():
    print("File already exists.")
else:
    print("Downloading file...")
    wget.download(url, tmp_out.as_posix())
    with gzip.open(tmp_out, 'rb') as f_in:
        with open(out, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    


File already exists.


# Load data and split
Same split as in original paper

In [4]:
target = "Covertype"

bool_columns = [
    "Wilderness_Area1", "Wilderness_Area2", "Wilderness_Area3",
    "Wilderness_Area4", "Soil_Type1", "Soil_Type2", "Soil_Type3", "Soil_Type4",
    "Soil_Type5", "Soil_Type6", "Soil_Type7", "Soil_Type8", "Soil_Type9",
    "Soil_Type10", "Soil_Type11", "Soil_Type12", "Soil_Type13", "Soil_Type14",
    "Soil_Type15", "Soil_Type16", "Soil_Type17", "Soil_Type18", "Soil_Type19",
    "Soil_Type20", "Soil_Type21", "Soil_Type22", "Soil_Type23", "Soil_Type24",
    "Soil_Type25", "Soil_Type26", "Soil_Type27", "Soil_Type28", "Soil_Type29",
    "Soil_Type30", "Soil_Type31", "Soil_Type32", "Soil_Type33", "Soil_Type34",
    "Soil_Type35", "Soil_Type36", "Soil_Type37", "Soil_Type38", "Soil_Type39",
    "Soil_Type40"
]

int_columns = [
    "Elevation", "Aspect", "Slope", "Horizontal_Distance_To_Hydrology",
    "Vertical_Distance_To_Hydrology", "Horizontal_Distance_To_Roadways",
    "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm",
    "Horizontal_Distance_To_Fire_Points"
]

feature_columns = (
    int_columns + bool_columns + [target])


In [5]:
train = pd.read_csv(out, header=None, names=feature_columns)

n_total = len(train)

# Train, val and test split follows
# Rory Mitchell, Andrey Adinets, Thejaswi Rao, and Eibe Frank.
# Xgboost: Scalable GPU accelerated learning. arXiv:1806.11248, 2018.

train_val_indices, test_indices = train_test_split(
    range(n_total), test_size=0.2, random_state=0)
train_indices, valid_indices = train_test_split(
    train_val_indices, test_size=0.2 / 0.6, random_state=0)


# Simple preprocessing

Label encode categorical features and fill empty cells.

In [6]:
categorical_columns = []
categorical_dims =  {}
for col in train.columns[train.dtypes == object]:
    print(col, train[col].nunique())
    l_enc = LabelEncoder()
    train[col] = train[col].fillna("VV_likely")
    train[col] = l_enc.fit_transform(train[col].values)
    categorical_columns.append(col)
    categorical_dims[col] = len(l_enc.classes_)

for col in train.columns[train.dtypes == 'float64']:
    train.fillna(train.loc[train_indices, col].mean(), inplace=True)

# Define categorical features for categorical embeddings

In [7]:
unused_feat = []

features = [ col for col in train.columns if col not in unused_feat+[target]] 

cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]

cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]


# Network parameters

In [8]:
clf = TabNetClassifier(
    n_d=64, n_a=64, n_steps=5,
    gamma=1.5, n_independent=2, n_shared=2,
    cat_idxs=cat_idxs,
    cat_dims=cat_dims,
    cat_emb_dim=1,
    lambda_sparse=1e-4, momentum=0.3, clip_value=2.,
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_params = {"gamma": 0.95,
                     "step_size": 1},
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    epsilon=1e-15,
)



# Training

In [9]:
if os.getenv("CI", False):
# Take only a subsample to run CI
    X_train = train[features].values[train_indices][:1000,:]
    y_train = train[target].values[train_indices][:1000]
else:
    X_train = train[features].values[train_indices]
    y_train = train[target].values[train_indices]

X_valid = train[features].values[valid_indices]
y_valid = train[target].values[valid_indices]

X_test = train[features].values[test_indices]
y_test = train[target].values[test_indices]

In [10]:
max_epochs = 100 if not os.getenv("CI", False) else 2

In [11]:
clf.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_train, y_train), (X_valid, y_valid)],
    eval_name=['train', 'valid'],
    max_epochs=max_epochs, patience=100,
    batch_size=16384, virtual_batch_size=2048,
    mixed_precision=False
) 

epoch 0  | loss: 1.18026 | train_accuracy: 0.05225 | valid_accuracy: 0.05184 |  0:00:05s
epoch 1  | loss: 0.72212 | train_accuracy: 0.05526 | valid_accuracy: 0.05463 |  0:00:11s
epoch 2  | loss: 0.66789 | train_accuracy: 0.06352 | valid_accuracy: 0.06296 |  0:00:17s
epoch 3  | loss: 0.64399 | train_accuracy: 0.06687 | valid_accuracy: 0.06616 |  0:00:23s
epoch 4  | loss: 0.62573 | train_accuracy: 0.07494 | valid_accuracy: 0.07431 |  0:00:29s
epoch 5  | loss: 0.61111 | train_accuracy: 0.10072 | valid_accuracy: 0.09982 |  0:00:34s
epoch 6  | loss: 0.60135 | train_accuracy: 0.12704 | valid_accuracy: 0.12655 |  0:00:40s
epoch 7  | loss: 0.59482 | train_accuracy: 0.14916 | valid_accuracy: 0.14848 |  0:00:46s
epoch 8  | loss: 0.57626 | train_accuracy: 0.1892  | valid_accuracy: 0.1895  |  0:00:51s
epoch 9  | loss: 0.5652  | train_accuracy: 0.24534 | valid_accuracy: 0.24436 |  0:00:57s
epoch 10 | loss: 0.55783 | train_accuracy: 0.30329 | valid_accuracy: 0.30262 |  0:01:03s
epoch 11 | loss: 0.54

epoch 93 | loss: 0.34567 | train_accuracy: 0.86745 | valid_accuracy: 0.86118 |  0:08:59s
epoch 94 | loss: 0.34509 | train_accuracy: 0.86751 | valid_accuracy: 0.86151 |  0:09:05s
epoch 95 | loss: 0.34447 | train_accuracy: 0.86766 | valid_accuracy: 0.86178 |  0:09:10s
epoch 96 | loss: 0.34503 | train_accuracy: 0.86831 | valid_accuracy: 0.86187 |  0:09:16s
epoch 97 | loss: 0.34385 | train_accuracy: 0.86809 | valid_accuracy: 0.86161 |  0:09:22s
epoch 98 | loss: 0.34399 | train_accuracy: 0.86796 | valid_accuracy: 0.86178 |  0:09:27s
epoch 99 | loss: 0.34412 | train_accuracy: 0.86842 | valid_accuracy: 0.86223 |  0:09:33s
Stop training because you reached max_epochs = 100 with best_epoch = 92 and best_valid_accuracy = 0.86227




In [12]:
no mixed precision
epoch 0  | loss: 1.18026 | train_accuracy: 0.05225 | valid_accuracy: 0.05184 |  0:00:05s
epoch 1  | loss: 0.72212 | train_accuracy: 0.05526 | valid_accuracy: 0.05463 |  0:00:11s
epoch 2  | loss: 0.66789 | train_accuracy: 0.06352 | valid_accuracy: 0.06296 |  0:00:16s
epoch 3  | loss: 0.64399 | train_accuracy: 0.06687 | valid_accuracy: 0.06616 |  0:00:22s
epoch 4  | loss: 0.62573 | train_accuracy: 0.07494 | valid_accuracy: 0.07431 |  0:00:27s
epoch 5  | loss: 0.61111 | train_accuracy: 0.10072 | valid_accuracy: 0.09982 |  0:00:33s
epoch 6  | loss: 0.60135 | train_accuracy: 0.12704 | valid_accuracy: 0.12655 |  0:00:38s
epoch 7  | loss: 0.59482 | train_accuracy: 0.14916 | valid_accuracy: 0.14848 |  0:00:44s
epoch 8  | loss: 0.57626 | train_accuracy: 0.1892  | valid_accuracy: 0.1895  |  0:00:49s
epoch 9  | loss: 0.5652  | train_accuracy: 0.24534 | valid_accuracy: 0.24436 |  0:00:55s
Stop training because you reached max_epochs = 10 with best_epoch = 9 and best_valid_accuracy = 0.24436


SyntaxError: invalid syntax (<ipython-input-12-97418754e051>, line 1)

In [None]:
epoch 0  | loss: 1.62946 | train_accuracy: 0.47121 | valid_accuracy: 0.46998 |  0:00:05s
epoch 1  | loss: 0.84275 | train_accuracy: 0.44343 | valid_accuracy: 0.44265 |  0:00:11s
epoch 2  | loss: 0.74182 | train_accuracy: 0.34433 | valid_accuracy: 0.34282 |  0:00:16s
epoch 3  | loss: 0.70406 | train_accuracy: 0.36308 | valid_accuracy: 0.36312 |  0:00:22s
epoch 4  | loss: 0.66785 | train_accuracy: 0.34933 | valid_accuracy: 0.35105 |  0:00:27s
epoch 5  | loss: 0.63612 | train_accuracy: 0.37187 | valid_accuracy: 0.37215 |  0:00:33s
epoch 6  | loss: 0.61824 | train_accuracy: 0.34983 | valid_accuracy: 0.34929 |  0:00:38s
epoch 7  | loss: 0.60716 | train_accuracy: 0.39124 | valid_accuracy: 0.39228 |  0:00:44s
epoch 8  | loss: 0.60422 | train_accuracy: 0.38846 | valid_accuracy: 0.3891  |  0:00:49s
epoch 9  | loss: 0.60199 | train_accuracy: 0.44437 | valid_accuracy: 0.44475 |  0:00:55s
epoch 10 | loss: 0.59844 | train_accuracy: 0.47141 | valid_accuracy: 0.47171 |  0:01:01s
epoch 11 | loss: 0.5888  | train_accuracy: 0.49226 | valid_accuracy: 0.49278 |  0:01:06s
epoch 12 | loss: 0.5784  | train_accuracy: 0.53464 | valid_accuracy: 0.53424 |  0:01:12s
epoch 13 | loss: 0.57251 | train_accuracy: 0.56823 | valid_accuracy: 0.56693 |  0:01:17s
epoch 14 | loss: 0.56808 | train_accuracy: 0.6078  | valid_accuracy: 0.60717 |  0:01:23s
epoch 15 | loss: 0.57061 | train_accuracy: 0.65449 | valid_accuracy: 0.65291 |  0:01:28s
epoch 16 | loss: 0.55801 | train_accuracy: 0.68168 | valid_accuracy: 0.68079 |  0:01:34s
epoch 17 | loss: 0.55119 | train_accuracy: 0.69452 | valid_accuracy: 0.69401 |  0:01:39s
epoch 18 | loss: 0.54191 | train_accuracy: 0.71516 | valid_accuracy: 0.71452 |  0:01:45s
epoch 19 | loss: 0.541   | train_accuracy: 0.72662 | valid_accuracy: 0.7252  |  0:01:51s
epoch 20 | loss: 0.5372  | train_accuracy: 0.74042 | valid_accuracy: 0.73939 |  0:01:56s
epoch 21 | loss: 0.52811 | train_accuracy: 0.74999 | valid_accuracy: 0.74764 |  0:02:02s
epoch 22 | loss: 0.5228  | train_accuracy: 0.75533 | valid_accuracy: 0.75354 |  0:02:07s
epoch 23 | loss: 0.51944 | train_accuracy: 0.76728 | valid_accuracy: 0.76503 |  0:02:13s
epoch 24 | loss: 0.51691 | train_accuracy: 0.76939 | valid_accuracy: 0.76745 |  0:02:18s
epoch 25 | loss: 0.51027 | train_accuracy: 0.77586 | valid_accuracy: 0.7734  |  0:02:24s
epoch 26 | loss: 0.50655 | train_accuracy: 0.77766 | valid_accuracy: 0.77435 |  0:02:29s
epoch 27 | loss: 0.50238 | train_accuracy: 0.78135 | valid_accuracy: 0.77854 |  0:02:35s
epoch 28 | loss: 0.49799 | train_accuracy: 0.78715 | valid_accuracy: 0.78467 |  0:02:40s
epoch 29 | loss: 0.49501 | train_accuracy: 0.78662 | valid_accuracy: 0.78479 |  0:02:46s
epoch 30 | loss: 0.49317 | train_accuracy: 0.78916 | valid_accuracy: 0.78605 |  0:02:52s
epoch 31 | loss: 0.49016 | train_accuracy: 0.79215 | valid_accuracy: 0.78833 |  0:02:57s
epoch 32 | loss: 0.4871  | train_accuracy: 0.79384 | valid_accuracy: 0.78982 |  0:03:03s
epoch 33 | loss: 0.48153 | train_accuracy: 0.79845 | valid_accuracy: 0.7948  |  0:03:08s
epoch 34 | loss: 0.47868 | train_accuracy: 0.79749 | valid_accuracy: 0.79417 |  0:03:14s
epoch 35 | loss: 0.47837 | train_accuracy: 0.79813 | valid_accuracy: 0.79485 |  0:03:20s
epoch 36 | loss: 0.48494 | train_accuracy: 0.79583 | valid_accuracy: 0.79246 |  0:03:25s
epoch 37 | loss: 0.47763 | train_accuracy: 0.80158 | valid_accuracy: 0.79738 |  0:03:31s
epoch 38 | loss: 0.46918 | train_accuracy: 0.80414 | valid_accuracy: 0.79956 |  0:03:36s
epoch 39 | loss: 0.46376 | train_accuracy: 0.80906 | valid_accuracy: 0.80508 |  0:03:42s
epoch 40 | loss: 0.45897 | train_accuracy: 0.81105 | valid_accuracy: 0.80619 |  0:03:48s
epoch 41 | loss: 0.45522 | train_accuracy: 0.81318 | valid_accuracy: 0.80877 |  0:03:53s
epoch 42 | loss: 0.45178 | train_accuracy: 0.81499 | valid_accuracy: 0.81054 |  0:03:59s
epoch 43 | loss: 0.44998 | train_accuracy: 0.817   | valid_accuracy: 0.81201 |  0:04:04s
epoch 44 | loss: 0.44837 | train_accuracy: 0.81578 | valid_accuracy: 0.81031 |  0:04:10s
epoch 45 | loss: 0.44559 | train_accuracy: 0.81898 | valid_accuracy: 0.81469 |  0:04:16s
epoch 46 | loss: 0.44294 | train_accuracy: 0.82016 | valid_accuracy: 0.81516 |  0:04:21s
epoch 47 | loss: 0.44021 | train_accuracy: 0.82089 | valid_accuracy: 0.81531 |  0:04:27s
epoch 48 | loss: 0.4374  | train_accuracy: 0.82238 | valid_accuracy: 0.81725 |  0:04:32s
epoch 49 | loss: 0.43614 | train_accuracy: 0.82236 | valid_accuracy: 0.81749 |  0:04:38s
epoch 50 | loss: 0.43641 | train_accuracy: 0.82131 | valid_accuracy: 0.81633 |  0:04:43s
epoch 51 | loss: 0.43356 | train_accuracy: 0.8231  | valid_accuracy: 0.81807 |  0:04:49s
epoch 52 | loss: 0.43491 | train_accuracy: 0.82482 | valid_accuracy: 0.81918 |  0:04:54s
epoch 53 | loss: 0.43073 | train_accuracy: 0.82402 | valid_accuracy: 0.81975 |  0:05:00s
epoch 54 | loss: 0.4311  | train_accuracy: 0.82558 | valid_accuracy: 0.82035 |  0:05:06s
epoch 55 | loss: 0.42898 | train_accuracy: 0.82576 | valid_accuracy: 0.82084 |  0:05:11s
epoch 56 | loss: 0.42605 | train_accuracy: 0.82747 | valid_accuracy: 0.82216 |  0:05:17s
epoch 57 | loss: 0.42523 | train_accuracy: 0.82792 | valid_accuracy: 0.82297 |  0:05:22s
epoch 58 | loss: 0.42374 | train_accuracy: 0.82854 | valid_accuracy: 0.82361 |  0:05:28s
epoch 59 | loss: 0.42286 | train_accuracy: 0.83027 | valid_accuracy: 0.8245  |  0:05:33s
epoch 60 | loss: 0.42026 | train_accuracy: 0.82984 | valid_accuracy: 0.82491 |  0:05:39s
epoch 61 | loss: 0.4193  | train_accuracy: 0.82945 | valid_accuracy: 0.82457 |  0:05:44s
epoch 62 | loss: 0.41782 | train_accuracy: 0.83085 | valid_accuracy: 0.82583 |  0:05:50s
epoch 63 | loss: 0.41707 | train_accuracy: 0.83164 | valid_accuracy: 0.82637 |  0:05:55s
epoch 64 | loss: 0.41711 | train_accuracy: 0.83287 | valid_accuracy: 0.82795 |  0:06:01s
epoch 65 | loss: 0.41542 | train_accuracy: 0.83276 | valid_accuracy: 0.82764 |  0:06:06s
epoch 66 | loss: 0.41468 | train_accuracy: 0.83267 | valid_accuracy: 0.82752 |  0:06:12s
epoch 67 | loss: 0.41376 | train_accuracy: 0.8329  | valid_accuracy: 0.82774 |  0:06:18s
epoch 68 | loss: 0.41365 | train_accuracy: 0.83185 | valid_accuracy: 0.82677 |  0:06:23s
epoch 69 | loss: 0.4129  | train_accuracy: 0.83356 | valid_accuracy: 0.82779 |  0:06:29s
epoch 70 | loss: 0.41196 | train_accuracy: 0.83321 | valid_accuracy: 0.82827 |  0:06:34s
epoch 71 | loss: 0.41163 | train_accuracy: 0.83358 | valid_accuracy: 0.82876 |  0:06:40s
epoch 72 | loss: 0.41223 | train_accuracy: 0.83365 | valid_accuracy: 0.82881 |  0:06:45s
epoch 73 | loss: 0.41025 | train_accuracy: 0.83525 | valid_accuracy: 0.83012 |  0:06:51s
epoch 74 | loss: 0.40895 | train_accuracy: 0.83477 | valid_accuracy: 0.82986 |  0:06:57s
epoch 75 | loss: 0.40942 | train_accuracy: 0.8363  | valid_accuracy: 0.83077 |  0:07:02s
epoch 76 | loss: 0.40804 | train_accuracy: 0.83549 | valid_accuracy: 0.83113 |  0:07:08s
epoch 77 | loss: 0.40784 | train_accuracy: 0.83659 | valid_accuracy: 0.83153 |  0:07:13s
epoch 78 | loss: 0.40672 | train_accuracy: 0.8369  | valid_accuracy: 0.83158 |  0:07:19s
epoch 79 | loss: 0.40761 | train_accuracy: 0.83454 | valid_accuracy: 0.82975 |  0:07:25s
epoch 80 | loss: 0.4083  | train_accuracy: 0.83638 | valid_accuracy: 0.83141 |  0:07:30s
epoch 81 | loss: 0.40877 | train_accuracy: 0.8334  | valid_accuracy: 0.82804 |  0:07:36s
epoch 82 | loss: 0.40934 | train_accuracy: 0.83634 | valid_accuracy: 0.83082 |  0:07:42s
epoch 83 | loss: 0.40639 | train_accuracy: 0.83599 | valid_accuracy: 0.83093 |  0:07:47s
epoch 84 | loss: 0.40534 | train_accuracy: 0.83705 | valid_accuracy: 0.83203 |  0:07:53s
epoch 85 | loss: 0.40572 | train_accuracy: 0.83578 | valid_accuracy: 0.8302  |  0:07:58s
epoch 86 | loss: 0.40767 | train_accuracy: 0.83523 | valid_accuracy: 0.83003 |  0:08:04s
epoch 87 | loss: 0.40642 | train_accuracy: 0.83631 | valid_accuracy: 0.83143 |  0:08:09s
epoch 88 | loss: 0.40331 | train_accuracy: 0.83703 | valid_accuracy: 0.83179 |  0:08:15s
epoch 89 | loss: 0.40398 | train_accuracy: 0.83662 | valid_accuracy: 0.83138 |  0:08:21s
epoch 90 | loss: 0.40426 | train_accuracy: 0.83821 | valid_accuracy: 0.83324 |  0:08:26s
epoch 91 | loss: 0.40181 | train_accuracy: 0.83872 | valid_accuracy: 0.83323 |  0:08:32s
epoch 92 | loss: 0.40314 | train_accuracy: 0.8388  | valid_accuracy: 0.83278 |  0:08:38s

epoch 93 | loss: 0.40221 | train_accuracy: 0.83754 | valid_accuracy: 0.83235 |  0:08:43s
epoch 94 | loss: 0.40202 | train_accuracy: 0.83791 | valid_accuracy: 0.8324  |  0:08:49s
epoch 95 | loss: 0.40125 | train_accuracy: 0.83873 | valid_accuracy: 0.83332 |  0:08:54s
epoch 96 | loss: 0.40122 | train_accuracy: 0.83866 | valid_accuracy: 0.83323 |  0:09:00s
epoch 97 | loss: 0.4009  | train_accuracy: 0.83901 | valid_accuracy: 0.83336 |  0:09:06s
epoch 98 | loss: 0.40075 | train_accuracy: 0.83903 | valid_accuracy: 0.83343 |  0:09:11s
epoch 99 | loss: 0.40264 | train_accuracy: 0.83874 | valid_accuracy: 0.833   |  0:09:17s
Stop training because you reached max_epochs = 100 with best_epoch = 98 and best_valid_accuracy = 0.83343


In [None]:
# plot losses
plt.plot(clf.history['loss'])

In [None]:
# plot accuracy
plt.plot(clf.history['train_accuracy'])
plt.plot(clf.history['valid_accuracy'])

### Predictions


In [None]:
# To get final results you may need to use a mapping for classes 
# as you are allowed to use targets like ["yes", "no", "maybe", "I don't know"]

preds_mapper = { idx : class_name for idx, class_name in enumerate(clf.classes_)}

preds = clf.predict_proba(X_test)

y_pred = np.vectorize(preds_mapper.get)(np.argmax(preds, axis=1))

test_acc = accuracy_score(y_pred=y_pred, y_true=y_test)

print(f"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}")
print(f"FINAL TEST SCORE FOR {dataset_name} : {test_acc}")

In [None]:
# or you can simply use the predict method

y_pred = clf.predict(X_test)
test_acc = accuracy_score(y_pred=y_pred, y_true=y_test)
print(f"FINAL TEST SCORE FOR {dataset_name} : {test_acc}")

# Save and load Model

In [None]:
# save state dict
saved_filename = clf.save_model('test_model')

In [None]:
# define new model and load save parameters
loaded_clf = TabNetClassifier()
loaded_clf.load_model(saved_filename)

In [None]:
loaded_preds = loaded_clf.predict_proba(X_test)
loaded_y_pred = np.vectorize(preds_mapper.get)(np.argmax(loaded_preds, axis=1))

loaded_test_acc = accuracy_score(y_pred=loaded_y_pred, y_true=y_test)

print(f"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_acc}")

In [None]:
assert(test_acc == loaded_test_acc)

# Global explainability : feat importance summing to 1

In [None]:
clf.feature_importances_

# Local explainability and masks

In [None]:
explain_matrix, masks = clf.explain(X_test)

In [None]:
fig, axs = plt.subplots(1, 5, figsize=(20,20))

for i in range(5):
    axs[i].imshow(masks[i][:50])
    axs[i].set_title(f"mask {i}")

# XGB

In [None]:
n_estimators = 1000 if not os.getenv("CI", False) else 20

In [None]:
from xgboost import XGBClassifier

clf_xgb = XGBClassifier(max_depth=8,
    learning_rate=0.1,
    n_estimators=n_estimators,
    verbosity=0,
    silent=None,
    objective="multi:softmax",
    booster='gbtree',
    n_jobs=-1,
    nthread=None,
    gamma=0,
    min_child_weight=1,
    max_delta_step=0,
    subsample=0.7,
    colsample_bytree=1,
    colsample_bylevel=1,
    colsample_bynode=1,
    reg_alpha=0,
    reg_lambda=1,
    scale_pos_weight=1,
    base_score=0.5,
    random_state=0,
    seed=None,)

clf_xgb.fit(X_train, y_train,
            eval_set=[(X_valid, y_valid)],
            early_stopping_rounds=40,
            verbose=10)

In [None]:
preds_valid = np.array(clf_xgb.predict_proba(X_valid, ))
valid_acc = accuracy_score(y_pred=np.argmax(preds_valid, axis=1) + 1, y_true=y_valid)
print(valid_acc)

preds_test = np.array(clf_xgb.predict_proba(X_test))
test_acc = accuracy_score(y_pred=np.argmax(preds_test, axis=1) + 1, y_true=y_test)
print(test_acc)