In [1]:
%load_ext autoreload
%autoreload 2
import warnings
from IPython.core.interactiveshell import InteractiveShell

warnings.filterwarnings('ignore')
InteractiveShell.ast_node_interactivity = 'last_expr'

In [2]:
import configparser

config = configparser.ConfigParser()
config.read('config.ini')
mlflow_uri = config['mlflow']['mlflow_uri']
mlflow_artifact_uri = config['mlflow']['target_mlflow_artifact_uri']
linux_mlflow_artifact_uri = config['mlflow']['linux_target_mlflow_artifact_uri']

In [3]:
import flash
from flash.core.data.utils import download_data

download_data(config['classification-example']['download_url'], 'data/')

In [4]:
from flash.tabular import TabularClassifier, TabularData

cat_cols, num_cols = ['Sex','Age','SibSp','Parch','Ticket','Cabin','Embarked'], ['Fare']
target_fields, val_split = 'Survived', 0.25
datamodule = TabularData.from_csv(
    categorical_fields=cat_cols,
    numerical_fields=num_cols,
    target_fields=target_fields,
    train_file=config['classification-example']['train_file'],
    test_file=config['classification-example']['test_file'],
    val_split=val_split
)

NumExpr defaulting to 4 threads.


In [5]:
from torchmetrics.classification import Accuracy,Precision,Recall
metrics = [Accuracy(),Precision(),Recall()]
model = TabularClassifier.from_data(datamodule=datamodule, metrics=metrics)

In [7]:
import os
from pytorch_lightning.loggers import MLFlowLogger

username = os.getenv('USER_UID')
mlf_logger = MLFlowLogger(experiment_name=f'lf-classification-{username}', tracking_uri=mlflow_uri)

In [8]:
trainer = flash.Trainer(logger=mlf_logger, max_epochs=10)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [9]:
trainer.fit(model, datamodule=datamodule)

Experiment with name lf-classification-chengb not found. Creating it.

  | Name    | Type       | Params
---------------------------------------
0 | model   | TabNet     | 28.2 K
1 | metrics | ModuleDict | 0     
---------------------------------------
28.2 K    Trainable params
0         Non-trainable params
28.2 K    Total params
0.113     Total estimated model params size (MB)


Epoch 9: 100%|██████████| 200/200 [00:15<00:00, 12.96it/s, loss=0.583, v_num=427e, val_accuracy=0.710, val_precision=0.710, val_recall=0.710, val_cross_entropy=0.598, train_accuracy_step=0.500, train_precision_step=0.500, train_recall_step=0.500, train_cross_entropy_step=0.968, train_accuracy_epoch=0.708, train_precision_epoch=0.708, train_recall_epoch=0.708, train_cross_entropy_epoch=0.584]


In [10]:
trainer.test()

Testing: 100%|██████████| 23/23 [00:00<00:00, 62.33it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.6666666865348816,
 'test_cross_entropy': 0.7112483382225037,
 'test_precision': 0.6666666865348816,
 'test_recall': 0.6666666865348816}
--------------------------------------------------------------------------------


[{'test_accuracy': 0.6666666865348816,
  'test_precision': 0.6666666865348816,
  'test_recall': 0.6666666865348816,
  'test_cross_entropy': 0.7112483382225037}]

In [12]:
# If one is training on windows laptop or training on linux blade
import shutil
import platform

exp_num = '147'
if platform.system() in set(['Windows','Linux']):
    file_names = os.listdir()
    current_artifact_base_fpath = os.getcwd()
    source_dir = f'{current_artifact_base_fpath}'
    target_dict = {'Windows': mlflow_artifact_uri, 'Linux': linux_mlflow_artifact_uri}
    target_dir = target_dict[platform.system()]
    for file_name in file_names:
        if file_name  == exp_num:
            for subfolder in os.listdir(file_name):
                for artifacts in os.listdir(os.path.join(file_name,subfolder)):
                    if artifacts == 'checkpoints':
                        orig_artifacts_loc = os.path.join(os.path.join(os.path.join(source_dir, file_name), subfolder), artifacts)
                        # print(orig_artifacts_loc)
                        new_artifacts_loc = os.path.join(os.path.join(os.path.join(source_dir, file_name), subfolder), 'artifacts')
                        # print(new_artifacts_loc)
                        shutil.move(orig_artifacts_loc, new_artifacts_loc)
            shutil.move(os.path.join(source_dir, file_name), target_dir)

In [13]:
run_id = '41d7ba3ba4b8461ebb4fbe96b014427e'

win_model_path = f'{mlflow_artifact_uri}\\{exp_num}\\{run_id}\\artifacts'
linux_model_path = f'{linux_mlflow_artifact_uri}/{exp_num}/{run_id}/artifacts'
model_path_dict = {'Windows': win_model_path, 'Linux': linux_model_path}

if platform.system() in set(list(model_path_dict.keys())):
    model_path = model_path_dict[platform.system()]
for pt in os.listdir(model_path):
    if pt.endswith('.ckpt'):
        win_ckpt_path, linux_ckpt_path = f'{model_path}\\{pt}', f'{model_path}/{pt}'
        ckpt_path_dict = {'Windows': win_ckpt_path, 'Linux': linux_ckpt_path}
        if platform.system() in set(list(ckpt_path_dict.keys())):
            ckpt_path = ckpt_path_dict[platform.system()]

In [14]:
model = TabularClassifier.load_from_checkpoint(ckpt_path)

In [15]:
predictions = model.predict(config['classification-example']['predict_file'])
print(predictions)

[[0.5276029706001282, 0.47239696979522705], [0.3114148676395416, 0.6885851621627808], [0.774603545665741, 0.22539640963077545], [0.7591143250465393, 0.24088571965694427], [0.8354960083961487, 0.16450397670269012], [0.7833905816078186, 0.2166094183921814], [0.7704339623451233, 0.22956611216068268], [0.6087574362754822, 0.3912425637245178], [0.6843942403793335, 0.31560570001602173], [0.72352534532547, 0.27647462487220764], [0.4071795642375946, 0.5928204655647278], [0.6966503858566284, 0.30334964394569397], [0.7569880485534668, 0.24301189184188843], [0.7583250403404236, 0.24167494475841522], [0.7710530161857605, 0.2289469987154007], [0.7673313021659851, 0.2326686978340149], [0.7481235265731812, 0.25187647342681885], [0.5453111529350281, 0.4546888768672943], [0.7842716574668884, 0.21572835743427277], [0.8562350273132324, 0.14376500248908997], [0.849664568901062, 0.150335431098938], [0.3537507951259613, 0.6462491750717163], [0.8288909792900085, 0.17110903561115265], [0.6757763028144836, 0.3