In [None]:
import mlflow.fastai
from fastai.vision.all import (
    CategoryBlock,
    DataBlock,
    GrandparentSplitter,
    ImageBlock,
    PILImage,
    URLs,
)
from fastai.vision.all import cnn_learner, get_image_files, parent_label, resnet18, untar_data

In [None]:
# Split items by the name of their grandparent folder: files under "training"
# form the training set, files under "testing" form the validation set.
splitter = GrandparentSplitter(train_name="training", valid_name="testing")

# Inputs are images opened as PILImage; targets are categories.
mnist_blocks = (ImageBlock(cls=PILImage), CategoryBlock)

# The DataBlock is the recipe from which Datasets and DataLoaders are built:
# items are image files, each labelled with the name of its parent folder.
mnist = DataBlock(
    blocks=mnist_blocks,
    get_items=get_image_files,
    get_y=parent_label,
    splitter=splitter,
)

In [None]:
# Fetch and extract the MNIST archive, then build the DataLoaders from it.
mnist_path = untar_data(URLs.MNIST)
data = mnist.dataloaders(mnist_path, bs=256, num_workers=0)

In [None]:
# Optional: uncomment the next line to enable MLflow autologging for fastai
# before training.
# mlflow.fastai.autolog()

In [None]:
# Build the Learner: a resnet18 architecture on top of the MNIST DataLoaders
learn = cnn_learner(dls=data, arch=resnet18)

In [None]:

with mlflow.start_run() as run:
    # Train for a single epoch using the one-cycle policy, max learning rate 0.1
    learn.fit_one_cycle(n_epoch=1, lr_max=0.1)
    # Log the trained learner to the active run under the "model" artifact path
    mlflow.fastai.log_model(learn, "model")

In [None]:
# Build the URI of the model logged under the run above, then load it back.
run_id = run.info.run_id
model_uri = f"runs:/{run_id}/model"
loaded_model = mlflow.fastai.load_model(model_uri)

In [112]:
# Take the input (first element) of the first training item and run it
# through the reloaded model; the prediction is displayed as the cell output.
sample_image = data.train_ds[0][0]
loaded_model.predict(sample_image)

('3',
 TensorBase(3),
 TensorBase([1.3418e-13, 5.8279e-14, 5.2054e-12, 1.0000e+00, 3.8687e-13, 1.6166e-07,
         3.1506e-11, 4.8758e-13, 2.6322e-06, 4.4112e-10]))

In [None]:
# Display the URI of the logged model so it can be copied for later use.
model_uri