# Train a Linear Learner model with MNIST using Amazon FSx for Lustre

This notebook adapts [An Introduction to Linear Learner with MNIST](https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/linear_learner_mnist/linear_learner_mnist.ipynb) to walk you through using Amazon FSx for Lustre (FSxLustre) as an input data source for training jobs.

Read and run the original notebook first to understand the ML use case and how it is solved. This notebook does not cover that material.

## Setup

To get started, we need to set up the environment with a few prerequisite steps, for permissions, configurations, and so on.

In [None]:
import boto3
import re
import sagemaker
from sagemaker import get_execution_role
from sagemaker.session import Session
from sagemaker.image_uris import retrieve

In [None]:
role = get_execution_role()
region = boto3.Session().region_name
container = retrieve("linear-learner", region)

bucket = Session().default_bucket()
prefix = "sagemaker/DEMO-linear-mnist"
output_location = f"s3://{bucket}/{prefix}/output"

print(f"sagemaker execution role: {role}")
print(f"training artifacts will be uploaded to: {output_location}")
print(f"container: {container}")

## Preprocessing

### Data ingestion

Next, we read the MNIST dataset [1] from an existing repository into memory, for preprocessing prior to training. It was downloaded from this [link](http://deeplearning.net/data/mnist/mnist.pkl.gz) and stored in the S3 bucket named by `downloaded_data_bucket`.

> [1] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278-2324, November 1998.

In [None]:
# S3 bucket where the original mnist data is downloaded and stored.
downloaded_data_bucket = "sagemaker-sample-files"
downloaded_data_prefix = "datasets/image/MNIST"

In [None]:
%%time
import pickle, gzip, json

# Load the dataset
s3 = boto3.client("s3")
s3.download_file(downloaded_data_bucket, f"{downloaded_data_prefix}/mnist.pkl.gz", "mnist.pkl.gz")
with gzip.open("mnist.pkl.gz", "rb") as f:
    train_set, valid_set, test_set = pickle.load(f, encoding="latin1")

### Data inspection

Once the dataset is imported, it's typical as part of the machine learning process to inspect the data, understand the distributions, and determine what type(s) of preprocessing might be needed. You can perform those tasks right here in the notebook. As an example, let's go ahead and look at one of the digits that is part of the dataset.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = (2, 10)


def show_digit(img, caption="", subplot=None):
    if subplot is None:
        _, (subplot) = plt.subplots(1, 1)
    imgr = img.reshape((28, 28))
    subplot.axis("off")
    subplot.imshow(imgr, cmap="gray")
    subplot.set_title(caption)  # Title the axes we drew on, not whichever axes is current


show_digit(train_set[0][30], f"This is a {train_set[1][30]}")

### Data conversion

Since algorithms have particular input and output requirements, converting the dataset is also part of the process that a data scientist goes through prior to initiating training. In this case, the Amazon SageMaker implementation of Linear Learner takes recordIO-wrapped protobuf, whereas the data we have is a pickled NumPy array on disk.

Most of the conversion effort is handled by the Amazon SageMaker Python SDK, through `sagemaker.amazon.common`, imported as `smac` below.

In [None]:
import io
import numpy as np
import sagemaker.amazon.common as smac

train_set_vectors = np.array([t.tolist() for t in train_set[0]]).astype("float32")
train_set_labels = np.where(np.array([t.tolist() for t in train_set[1]]) == 0, 1, 0).astype(
    "float32"
)

validation_set_vectors = np.array([t.tolist() for t in valid_set[0]]).astype("float32")
validation_set_labels = np.where(np.array([t.tolist() for t in valid_set[1]]) == 0, 1, 0).astype(
    "float32"
)

train_set_buf = io.BytesIO()
validation_set_buf = io.BytesIO()
smac.write_numpy_to_dense_tensor(train_set_buf, train_set_vectors, train_set_labels)
smac.write_numpy_to_dense_tensor(validation_set_buf, validation_set_vectors, validation_set_labels)

train_set_buf.seek(0)
validation_set_buf.seek(0)
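
The label conversion above turns the ten-class digit labels into a binary target: `1` when the digit is a `0`, and `0` otherwise. As a minimal sketch, here is the same mapping in plain Python, applied to a made-up list of digit labels (the values are illustrative and not taken from MNIST):

```python
# Hypothetical digit labels standing in for train_set[1]; not real MNIST data.
digit_labels = [5, 0, 4, 1, 9, 0, 2]

# Same rule as np.where(labels == 0, 1, 0): 1.0 marks "this image is a zero".
binary_labels = [1.0 if digit == 0 else 0.0 for digit in digit_labels]

print(binary_labels)
```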

### Upload training data

Now that we've created our recordIO-wrapped protobuf, we'll need to upload it to S3, so that Amazon SageMaker training can use it.

In [None]:
import boto3
import os

key = "recordio-pb-data"
boto3.resource("s3").Bucket(bucket).Object(os.path.join(prefix, "train", key)).upload_fileobj(
    train_set_buf
)
boto3.resource("s3").Bucket(bucket).Object(os.path.join(prefix, "validation", key)).upload_fileobj(
    validation_set_buf
)

s3_train_data = f"s3://{bucket}/{prefix}/train/{key}"
s3_validation_data = f"s3://{bucket}/{prefix}/validation/{key}"

print(f"uploaded training data location: {s3_train_data}")
print(f"uploaded validation data location: {s3_validation_data}")

## Prepare File System Input

Next, we specify the details of the file system as an input to the training job. Using a file system as a data source eliminates the time the training job would otherwise spend downloading data, because the data is streamed directly from the file system into the training algorithm.

### WARNING

Before specifying the `FileSystemInput`, make sure that your Amazon FSx for Lustre file system is linked to the Amazon S3 bucket that holds the training and validation data.
For more information, see [Amazon SageMaker - Configure Data Input Channel to Use Amazon FSx for Lustre](https://docs.aws.amazon.com/sagemaker/latest/dg/model-access-training-data.html#model-access-training-data-fsx).
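
One way to check the link is to inspect the file system description. The sketch below assumes a dictionary shaped like the response of the boto3 `fsx` client's `describe_file_systems` call. The helper name `summarize_lustre_file_system` and every value in the sample are hypothetical. File systems linked through data repository associations may not populate `DataRepositoryConfiguration`, in which case `import_path` comes back as `None`.

```python
def summarize_lustre_file_system(describe_response, file_system_id):
    """Extract the mount name, linked S3 path, and subnets for one file system."""
    for fs in describe_response.get("FileSystems", []):
        if fs.get("FileSystemId") != file_system_id:
            continue
        lustre = fs.get("LustreConfiguration", {})
        repo = lustre.get("DataRepositoryConfiguration", {})
        return {
            "mount_name": lustre.get("MountName"),
            "import_path": repo.get("ImportPath"),
            "subnet_ids": fs.get("SubnetIds", []),
        }
    raise ValueError(f"file system {file_system_id} not found in response")


# Hand-written sample in the assumed response shape; all values are made up.
sample_response = {
    "FileSystems": [
        {
            "FileSystemId": "fs-0123456789abcdef0",
            "SubnetIds": ["subnet-0example"],
            "LustreConfiguration": {
                "MountName": "1234abcd",
                "DataRepositoryConfiguration": {"ImportPath": "s3://example-bucket"},
            },
        }
    ]
}

summary = summarize_lustre_file_system(sample_response, "fs-0123456789abcdef0")
print(summary)
```

In the notebook, the real response would come from `boto3.client("fsx").describe_file_systems(FileSystemIds=[file_system_id])`. The mount name and subnet IDs it reports are the values needed for the directory path and the estimator's `subnets` argument.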

In [None]:
from sagemaker.inputs import FileSystemInput

# Specify file system id.
# e.g.) file_system_id = "fs-0123456789abcdef0"
file_system_id = "<your_file_system_id>"

# Specify the directory path on the file system. It must be a normalized, absolute path.
# Make sure the path starts with the Amazon FSx `MountName`.
# e.g.) file_system_directory_path = "/1234abcd/ns1/sagemaker/DEMO-linear-mnist/train"
file_system_directory_path = "/<mount_name>/<file_system_path>/<your_training_data_s3_location>"

# Specify your file system type: "FSxLustre".
file_system_type = "FSxLustre"
# Mount the file system read-only ("ro"); "rw" requests read-write access.
file_system_access_mode = "ro"

file_system_input = FileSystemInput(
    file_system_id=file_system_id,
    file_system_type=file_system_type,
    directory_path=file_system_directory_path,
    file_system_access_mode=file_system_access_mode,
)
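
The directory path must be absolute, normalized, and start with the mount name. As a minimal sketch, a small helper can assemble it from its parts. The helper name `build_fsx_directory_path` is hypothetical, and the components are the ones from the example comment above:

```python
import posixpath


def build_fsx_directory_path(mount_name, *parts):
    """Join a mount name and path components into a normalized absolute path."""
    if not mount_name or "/" in mount_name:
        raise ValueError("mount_name must be a single non-empty path component")
    # Strip stray slashes so a later component cannot reset the join to the root.
    cleaned = [part.strip("/") for part in parts]
    return posixpath.normpath(posixpath.join("/", mount_name, *cleaned))


example_path = build_fsx_directory_path(
    "1234abcd", "ns1", "sagemaker/DEMO-linear-mnist", "train"
)
print(example_path)
```

The result can be passed as `directory_path` to `FileSystemInput`.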

## Training the linear model

Once we have the file system provisioned and file system input ready for training, the next step is to actually train the model.

In [None]:
import boto3
import sagemaker

sagemaker_session = sagemaker.Session()

# Give Amazon SageMaker training jobs access to file system resources in your Amazon VPC.
security_groups_ids = ["<your_security_groups_ids>"]
subnets = ["<your_subnets>"]  # Should be the same as the subnet used for Amazon FSx

linear = sagemaker.estimator.Estimator(
    container,
    role,
    subnets=subnets,
    security_group_ids=security_groups_ids,
    instance_count=1,
    instance_type="ml.c5.xlarge",
    output_path=output_location,
    sagemaker_session=sagemaker_session
)

linear.set_hyperparameters(feature_dim=784, predictor_type="binary_classifier", mini_batch_size=200)

Toward the end of the job, you should see the model artifact generated and uploaded to `output_location`.

In [None]:
linear.fit({"train": file_system_input})

## Set up hosting for the model

Now that we've trained our model, we can deploy it behind an Amazon SageMaker real-time hosted endpoint. This will allow us to make predictions (or inferences) from the model dynamically.

In [None]:
training_job_name = linear.latest_training_job.name
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)

In [None]:
trained_model_location = desc["ModelArtifacts"]["S3ModelArtifacts"]
print(f"trained model location: {trained_model_location}")

In [None]:
from sagemaker.model import Model

model = Model(
    image_uri=container,
    model_data=trained_model_location,
    role=role
)

In [None]:
model.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")

In [None]:
print(f"endpoint name: {model.endpoint_name}")

## Validate the model for use

Finally, we can now validate the model for use.
Let's try getting a prediction for a single record.

In [None]:
from sagemaker.serializers import CSVSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor import Predictor

predictor = Predictor(
    model.endpoint_name, serializer=CSVSerializer(), deserializer=JSONDeserializer()
)

In [None]:
result = predictor.predict(train_set[0][30], initial_args={"ContentType": "text/csv"})
print(result)

OK, a single prediction works. For one record, our endpoint returned JSON that contains `predictions`, including the `score` and `predicted_label`. In this case, `score` is a continuous value in \[0, 1\] representing the estimated probability that the digit is a `0`. `predicted_label` takes a value of either `0` or `1`, where (somewhat counterintuitively) `1` denotes that we predict the image is a `0`, while `0` denotes that we predict the image is not a `0`.
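
As a minimal sketch of how to read such a response, the snippet below interprets a hand-written dictionary in the shape described above. The `score` values are made up, and the helper name `describe_prediction` is hypothetical:

```python
# Hypothetical response in the shape described above; the scores are made up.
sample_result = {
    "predictions": [
        {"score": 0.97, "predicted_label": 1},
        {"score": 0.02, "predicted_label": 0},
    ]
}


def describe_prediction(prediction):
    """Translate the binary label back into a statement about the digit."""
    verdict = "a 0" if prediction["predicted_label"] == 1 else "not a 0"
    return f"{verdict} (score={prediction['score']:.2f})"


summaries = [describe_prediction(p) for p in sample_result["predictions"]]
print(summaries)
```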

Let's do a whole batch of images and evaluate our predictive accuracy.

In [None]:
import numpy as np

predictions = []
for array in np.array_split(test_set[0], 100):
    result = predictor.predict(array)
    predictions += [r["predicted_label"] for r in result["predictions"]]

predictions = np.array(predictions)

In [None]:
import pandas as pd

pd.crosstab(
    np.where(test_set[1] == 0, 1, 0), predictions, rownames=["actuals"], colnames=["predictions"]
)
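
The cross-tabulation is a confusion matrix, from which accuracy, precision, and recall follow directly. The sketch below computes them in plain Python on two short made-up label lists. The helper name `binary_metrics` is hypothetical, and the numbers are not results from this model:

```python
def binary_metrics(actuals, predicted):
    """Compute accuracy, precision, and recall for 0/1 labels."""
    pairs = list(zip(actuals, predicted))
    if not pairs:
        raise ValueError("need at least one (actual, predicted) pair")
    tp = sum(1 for a, p in pairs if a == 1 and p == 1)
    tn = sum(1 for a, p in pairs if a == 0 and p == 0)
    fp = sum(1 for a, p in pairs if a == 0 and p == 1)
    fn = sum(1 for a, p in pairs if a == 1 and p == 0)
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
        "recall": tp / (tp + fn) if (tp + fn) else 0.0,
    }


# Made-up labels for illustration only: 1 means "the image is a 0".
actuals_example = [1, 0, 0, 1, 0, 0, 1, 0]
predicted_example = [1, 0, 0, 0, 0, 1, 1, 0]

metrics = binary_metrics(actuals_example, predicted_example)
print(metrics)
```

In the notebook, the same function could be called with `np.where(test_set[1] == 0, 1, 0)` and `predictions`.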

## (Optional) Clean Up

When you are done with this notebook, run the cells below. They delete the hosted endpoint and the model you created, which avoids charges from an instance left running.

In [None]:
predictor.delete_endpoint()

In [None]:
model.delete_model()