In [None]:
import os
import tempfile

import boto3
import torch
import numpy as np 
import cv2
from torch.utils import data
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import sagemaker
from sagemaker.pytorch import PyTorch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from PIL import Image

In [None]:
# SageMaker session (used below for the default bucket and region) and the IAM
# role the SageMaker training job runs under. Evaluated left to right.
sagemaker_session, role = sagemaker.Session(), sagemaker.get_execution_role()

# IAM role handed to the Step Functions TrainingPipeline (as `role=`) below;
# it was created specifically for running Step Functions workflows.
# The default ARN embeds one specific AWS account, so allow overriding it via
# the STEPFUNCTIONS_WORKFLOW_EXECUTION_ROLE environment variable to run the
# notebook from a different account without editing code.
workflow_execution_role = os.environ.get(
    "STEPFUNCTIONS_WORKFLOW_EXECUTION_ROLE",
    "arn:aws:iam::287222052256:role/StepFunctionsWorkflowExecutionRole",
)

In [None]:
# S3 prefixes for the preprocessed data; they become the 'train' and 'test'
# input channels of the training job.
train_source, test_source = (
    's3://kirit-processed/train',
    's3://kirit-processed/test',
)

In [None]:
# Distributed training on 3 CPU instances. The 'backend' hyperparameter is
# presumably handed to torch.distributed by code/trainer.py (not visible here);
# 'gloo' is the torch.distributed backend that works without GPUs.
TRAIN_INSTANCE_TYPE = 'ml.m4.xlarge'
TRAIN_INSTANCE_COUNT = 3


def _pytorch_1_6_training_image(region, instance_type):
    """Return the PyTorch 1.6.0 Deep Learning Container training image URI.

    The GPU (CUDA 10.1) build is chosen only for GPU instance families
    (ml.p*, ml.g*); every other instance type gets the CPU build, so a
    CPU-only instance such as ml.m4.xlarge no longer pulls the CUDA image.
    The registry is placed in `region` because SageMaker pulls the training
    image from the job's own region.

    NOTE(review): account 763104351884 hosts the DLCs in most commercial
    regions but not all (e.g. ap-east-1, me-south-1, China regions) —
    confirm for the region you run in.
    """
    if instance_type.startswith(('ml.p', 'ml.g')):
        tag = '1.6.0-gpu-py36-cu101-ubuntu16.04'
    else:
        tag = '1.6.0-cpu-py36-ubuntu16.04'
    return f'763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:{tag}'


estimator = PyTorch(entry_point='trainer.py',
                    source_dir='code',
                    role=role,
                    framework_version='1.6.0',
                    image_name=_pytorch_1_6_training_image(
                        sagemaker_session.boto_region_name, TRAIN_INSTANCE_TYPE),
                    train_instance_count=TRAIN_INSTANCE_COUNT,
                    train_instance_type=TRAIN_INSTANCE_TYPE,
                    hyperparameters={
                        'epochs': 6,
                        'backend': 'gloo'
                    })

### Optional: orchestrate training with AWS Step Functions (the same workflow approach can be extended for model monitoring)

In [None]:
import stepfunctions
from stepfunctions.template.pipeline import TrainingPipeline

In [None]:
# Wrap the estimator in a Step Functions training workflow. It reads the same
# 'train'/'test' S3 channels as the direct estimator.fit() path, and the
# session's default bucket holds the workflow's artifacts.
pipeline = TrainingPipeline(
    estimator=estimator,
    role=workflow_execution_role,
    inputs=dict(train=train_source, test=test_source),
    s3_bucket=sagemaker_session.default_bucket(),
)

In [None]:
# Show the generated workflow definition as pretty-printed JSON for review.
workflow_definition = pipeline.workflow.definition
print(workflow_definition.to_json(pretty=True))

In [None]:
# Render the pipeline's workflow as a graph. This is deliberately the bare last
# expression so Jupyter displays whatever render_graph() returns.
pipeline.render_graph()

In [None]:
# Create the state machine for this pipeline in AWS Step Functions; the cell
# output is whatever create() returns.
# NOTE(review): on re-run, a workflow with the same name may already exist —
# check how the installed stepfunctions SDK version reports/handles that.
pipeline.create()

### If using AWS Step Functions, run `pipeline.execute()` instead of `estimator.fit()` — running both launches two separate training jobs

In [None]:
# Executing the pipeline launches its own training job, and the workflow
# definition printed above contains further steps after training. The notebook
# also calls estimator.fit()/deploy() directly below, so unconditionally doing
# both would train twice under "Run All". Set this flag to True to use the
# Step Functions path, and skip the direct fit/deploy cells in that case.
RUN_STEP_FUNCTIONS_PIPELINE = False

pipeline_execution = pipeline.execute() if RUN_STEP_FUNCTIONS_PIPELINE else None
pipeline_execution

### Begin the training job directly with `estimator.fit()` (skip this if the Step Functions pipeline was executed)

In [None]:
# Launch the SageMaker training job directly, with the same two S3 channels.
estimator.fit(inputs=dict(train=train_source, test=test_source))

### Deploy model to an endpoint

In [None]:
# Deploy the trained model to a single-instance real-time endpoint.
# Keep the returned predictor: it is the handle for sending requests to the
# endpoint and for tearing it down with `predictor.delete_endpoint()` once
# finished, since a running endpoint keeps incurring cost until it is deleted.
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.p2.xlarge')
predictor