In [1]:
%load_ext autoreload
%autoreload 2
import warnings
from IPython.core.interactiveshell import InteractiveShell

warnings.filterwarnings('ignore')
InteractiveShell.ast_node_interactivity = 'last_expr'

In [16]:
import configparser

config = configparser.ConfigParser()
config.read('config.ini')
mlflow_uri = config['mlflow']['mlflow_uri']
mlflow_artifact_uri = config['mlflow']['target_mlflow_artifact_uri']
linux_mlflow_artifact_uri = config['mlflow']['linux_target_mlflow_artifact_uri']

In [3]:
import flash
from flash.core.data.utils import download_data

download_data(config['classification-example']['download_url'], 'data/')

In [4]:
from flash.tabular import TabularClassifier, TabularData

cat_cols, num_cols = ['Sex','Age','SibSp','Parch','Ticket','Cabin','Embarked'], ['Fare']
target_fields, val_split = 'Survived', 0.25
datamodule = TabularData.from_csv(
    categorical_fields=cat_cols,
    numerical_fields=num_cols,
    target_fields=target_fields,
    train_file=config['classification-example']['train_file'],
    test_file=config['classification-example']['test_file'],
    val_split=val_split
)

NumExpr defaulting to 4 threads.


In [5]:
from torchmetrics.classification import Accuracy,Precision,Recall
metrics = [Accuracy(),Precision(),Recall()]
model = TabularClassifier.from_data(datamodule=datamodule, metrics=metrics)

In [7]:
import os
from pytorch_lightning.loggers import MLFlowLogger

username = os.getenv('USER_UID')
mlf_logger = MLFlowLogger(experiment_name=f'lf-class-{username}', tracking_uri=mlflow_uri)

In [8]:
trainer = flash.Trainer(logger=mlf_logger, max_epochs=10)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [9]:
trainer.fit(model, datamodule=datamodule)

Experiment with name lf-class-chengb not found. Creating it.

  | Name    | Type       | Params
---------------------------------------
0 | model   | TabNet     | 28.2 K
1 | metrics | ModuleDict | 0     
---------------------------------------
28.2 K    Trainable params
0         Non-trainable params
28.2 K    Total params
0.113     Total estimated model params size (MB)


Epoch 9: 100%|██████████| 200/200 [00:11<00:00, 16.69it/s, loss=0.555, v_num=c33f, val_accuracy=0.775, val_precision=0.775, val_recall=0.775, val_cross_entropy=0.520, train_accuracy_step=0.500, train_precision_step=0.500, train_recall_step=0.500, train_cross_entropy_step=0.710, train_accuracy_epoch=0.722, train_precision_epoch=0.722, train_recall_epoch=0.722, train_cross_entropy_epoch=0.574]


In [10]:
trainer.test()

Testing: 100%|██████████| 23/23 [00:00<00:00, 65.34it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.7111111283302307,
 'test_cross_entropy': 0.6236193776130676,
 'test_precision': 0.7111111283302307,
 'test_recall': 0.7111111283302307}
--------------------------------------------------------------------------------


[{'test_accuracy': 0.7111111283302307,
  'test_precision': 0.7111111283302307,
  'test_recall': 0.7111111283302307,
  'test_cross_entropy': 0.6236193776130676}]

In [12]:
# If one is training on windows laptop or training on linux blade
import shutil
import platform

exp_num = '146'
if platform.system() == 'Windows':
    file_names = os.listdir()
    current_artifact_base_fpath = os.getcwd()
    source_dir = f'{current_artifact_base_fpath}'
    target_dir = mlflow_artifact_uri
    for file_name in file_names:
        if file_name  == exp_num:
            for subfolder in os.listdir(file_name):
                for artifacts in os.listdir(os.path.join(file_name,subfolder)):
                    if artifacts == 'checkpoints':
                        orig_artifacts_loc = os.path.join(os.path.join(os.path.join(source_dir, file_name), subfolder), artifacts)
                        # print(orig_artifacts_loc)
                        new_artifacts_loc = os.path.join(os.path.join(os.path.join(source_dir, file_name), subfolder), 'artifacts')
                        # print(new_artifacts_loc)
                        shutil.move(orig_artifacts_loc, new_artifacts_loc)
            shutil.move(os.path.join(source_dir, file_name), target_dir)
elif platform.system() == 'Linux':
    file_names = os.listdir()
    current_artifact_base_fpath = os.getcwd()
    source_dir = f'{current_artifact_base_fpath}'
    target_dir = linux_mlflow_artifact_uri
    for file_name in file_names:
        if file_name  == exp_num:
            for subfolder in os.listdir(file_name):
                for artifacts in os.listdir(os.path.join(file_name,subfolder)):
                    if artifacts == 'checkpoints':
                        orig_artifacts_loc = os.path.join(os.path.join(os.path.join(source_dir, file_name), subfolder), artifacts)
                        # print(orig_artifacts_loc)
                        new_artifacts_loc = os.path.join(os.path.join(os.path.join(source_dir, file_name), subfolder), 'artifacts')
                        # print(new_artifacts_loc)
                        shutil.move(orig_artifacts_loc, new_artifacts_loc)
            shutil.move(os.path.join(source_dir, file_name), target_dir)

In [13]:
run_id = 'c4da22cf7718485c80fe2d29aa3cc33f'
if platform.system() == 'Windows':
    model_path = f'{mlflow_artifact_uri}\\{exp_num}\\{run_id}\\artifacts'
elif platform.system() == 'Linux':
    model_path = f'{linux_mlflow_artifact_uri}/{exp_num}/{run_id}/artifacts'
for pt in os.listdir(model_path):
    if pt.endswith('.ckpt'):
        if platform.system() == 'Windows':
            ckpt_path = f'{model_path}\\{pt}'
        elif platform.system() == 'Linux':
            ckpt_path = f'{model_path}/{pt}'

In [14]:
model = TabularClassifier.load_from_checkpoint(ckpt_path)

In [17]:
predictions = model.predict(config['classification-example']['predict_file'])
print(predictions)

[[0.2692548930644989, 0.7307450771331787], [0.20101088285446167, 0.7989890575408936], [0.7869592905044556, 0.2130407691001892], [0.7679380774497986, 0.23206186294555664], [0.7289539575576782, 0.2710460424423218], [0.8362942337989807, 0.1637057363986969], [0.7638676166534424, 0.23613230884075165], [0.761445939540863, 0.23855401575565338], [0.8224246501922607, 0.17757539451122284], [0.7843981981277466, 0.21560174226760864], [0.3727164566516876, 0.6272834539413452], [0.8926337957382202, 0.10736621171236038], [0.5770226120948792, 0.42297738790512085], [0.34723520278930664, 0.6527647972106934], [0.7843812704086304, 0.21561871469020844], [0.7846007943153381, 0.21539916098117828], [0.7235276699066162, 0.2764723598957062], [0.4098041355609894, 0.590195894241333], [0.8182392120361328, 0.18176080286502838], [0.796673059463501, 0.20332694053649902], [0.844999372959137, 0.15500059723854065], [0.382688969373703, 0.6173110008239746], [0.8152332901954651, 0.1847667247056961], [0.7858826518058777, 0.2