In [None]:
import collections
import math
import torch
import os, tarfile, json
import time, datetime
from io import StringIO
import numpy as np
import sagemaker
from sagemaker.pytorch import estimator, PyTorchModel, PyTorchPredictor, PyTorch
from sagemaker.utils import name_from_base
import boto3
from types import SimpleNamespace

# --- Notebook-wide configuration -------------------------------------------
# S3 bucket name and key prefix for this project. The bucket is hard-coded;
# swap in `sagemaker_session.default_bucket()` (after the session is created
# below) or your own bucket name if you are running this in another account.
bucket = 'privisaa-bucket-virginia'
prefix = 'object-tracker'

# SageMaker session plus the IAM execution role attached to this notebook.
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

# Low-level boto3 client for the SageMaker runtime service.
runtime_client = boto3.client('runtime.sagemaker')

## Build our tracking container

In [None]:
%%sh

# Build the training image from the Dockerfile in the current directory and
# push it to this account's ECR registry as "<algorithm_name>:latest".

# Stop at the first failing step so that a broken build is never tagged/pushed.
set -e

# The name of our algorithm (also used as the ECR repository name)
algorithm_name=tracker-train

# Make the `train` file executable before it is used by the image build.
chmod +x train

account=$(aws sts get-caller-identity --query Account --output text)

# Get the region defined in the current configuration (default to us-east-1 if
# none defined). `aws configure get` exits non-zero when the value is unset, so
# guard it to keep `set -e` from aborting before the default is applied.
region=$(aws configure get region || true)
region=${region:-us-east-1}

registry="${account}.dkr.ecr.${region}.amazonaws.com"
fullname="${registry}/${algorithm_name}:latest"

# If the repository doesn't exist in ECR, create it. The region is passed
# explicitly so the repository lives in the same region as ${fullname}.
if ! aws ecr describe-repositories --region "${region}" \
        --repository-names "${algorithm_name}" > /dev/null 2>&1
then
    aws ecr create-repository --region "${region}" \
        --repository-name "${algorithm_name}" > /dev/null
fi

# Log docker in to ECR. `get-login-password` is the supported command (and the
# only one available in AWS CLI v2); fall back to the legacy `get-login` for
# old v1 installations that do not have it.
if password=$(aws ecr get-login-password --region "${region}" 2> /dev/null)
then
    printf '%s' "${password}" \
        | docker login --username AWS --password-stdin "${registry}"
else
    $(aws ecr get-login --region "${region}" --no-include-email)
fi

# Build the docker image locally with the image name and then push it to ECR
# with the full name.
docker build -t "${algorithm_name}" .
docker tag "${algorithm_name}" "${fullname}"

# NOTE: if the push is rejected with "not authorized to perform
# ecr:InitiateLayerUpload", the IAM identity running this cell is presumably
# missing ECR push permissions -- check the role's policy.
docker push "${fullname}"

## Grab our tracking data

In [None]:

# S3 prefix that holds the game clips used as training input.
s3train = 's3://{}/nfl-data/game_clips'.format(bucket)

# Describe the S3 input for the training job. No content type is set here
# ('application/tfrecord' was noted as a possible alternative).
train = sagemaker.session.s3_input(
    s3train,
    distribution='FullyReplicated',
    content_type=None,
    s3_data_type='S3Prefix',
)

# Channel name -> input description, handed to the estimator's fit() call.
data_channels = dict(train=train)

In [None]:
# Hyperparameters passed to the estimator below.
# NOTE(review): num_gpus presumably mirrors the GPU count of the chosen
# instance type -- confirm if the instance type is changed.
hyperparameters = dict(
    num_gpus=8,
    epochs=20,
    lr=0.0003,
)

# Configure the PyTorch estimator that runs the training script inside our
# custom container image. TODO: try running in local mode.
torch_model = PyTorch(
    entry_point='/home/ec2-user/SageMaker/code/object-tracking-project/train_object_tracker.py',
    image_name='209419068016.dkr.ecr.us-east-1.amazonaws.com/tracker-train',
    framework_version='1.5.1',
    role=role,
    train_instance_type='ml.p3dn.24xlarge',
    train_instance_count=1,
    hyperparameters=hyperparameters,
)


In [None]:
# Launch the training job on the 'train' channel; wait=True (the default,
# spelled out here) keeps the cell running until the job has finished.
torch_model.fit(inputs=data_channels, wait=True)
