# Preparation

Download the [ants and bees image set](https://download.pytorch.org/tutorial/hymenoptera_data.zip) and upload it to the S3 bucket whose name is printed below.

In [None]:
# Upgrade the SageMaker Python SDK before it is imported in the next cell.
!pip install --upgrade sagemaker

In [None]:
%%time
import sagemaker
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch

# SageMaker session plus the IAM role this notebook runs under.
sess = sagemaker.Session()
role = get_execution_role()

# Fifth colon-separated field of the role ARN (the account id in a standard
# IAM role ARN); it forms the last part of the bucket name.
user_id = role.split(':')[4]
bucket = 'sagemaker-{}-{}'.format(sess.boto_region_name, user_id)
print(f'Bucket for images is {bucket}')

# Key layout inside the bucket: <prefix>/data holds the data set and
# <prefix>/data/val its validation split.
prefix = 'salmon_trout'
s3_prefix = f'{prefix}/data'
s3_validate_prefix = f'{s3_prefix}/val'

In [None]:
# Hyperparameters handed to ../train.py by the training job.
training_hyperparameters = {
    'epochs': 25,
    'learning-rate': 0.001,
    'gamma': 0.1,
    'momentum': 0.9,
    'step-size': 7,
    'action': 'final_layer',
}

# Single-instance PyTorch training job running ../train.py.
pytorch_estimator = PyTorch(
    entry_point='../train.py',
    role=role,
    instance_count=1,
    instance_type='ml.p3.2xlarge',
    framework_version='1.6.0',
    py_version='py3',
    hyperparameters=training_hyperparameters,
)

# Start training with the data set under s3_prefix as the 'training' channel.
training_data_uri = f's3://{bucket}/{s3_prefix}'
pytorch_estimator.fit({'training': training_data_uri})

In [None]:
# Host the trained model on one ml.m5.xlarge instance. IdentitySerializer
# passes the request payload through unchanged, so raw image bytes can be sent.
classifier = pytorch_estimator.deploy(
    instance_type='ml.m5.xlarge',
    initial_instance_count=1,
    serializer=sagemaker.serializers.IdentitySerializer(),
)

In [None]:
import boto3
import json
import numpy as np
def classify(bucket, image_path):
    """Classify one image stored in S3 with the deployed endpoint.

    Args:
        bucket: Name of the S3 bucket that holds the image.
        image_path: Object key of the image inside ``bucket``.

    Returns:
        A ``(label, value)`` tuple: the category with the highest value in
        the first row of the endpoint's response, and that value.
    """
    labels = ['salmon', 'trout']

    # Download the raw image bytes and send them to the endpoint as-is.
    image_bytes = boto3.resource('s3').Object(bucket, image_path).get()['Body'].read()
    scores = classifier.predict(
        image_bytes,
        initial_args={'ContentType': 'application/x-image'},
    )[0]

    best = np.argmax(scores)
    return labels[best], scores[best]

In [None]:
# Try the endpoint on a single validation image; the notebook then displays
# the tuple returned by classify().
sample_key = f'{s3_validate_prefix}/salmon/aug_3.jpg'
result = classify(bucket, sample_key)
result

In [None]:
class BucketPaginator:
    """Page through the keys under an S3 prefix, one request per call.

    Each call to :meth:`list_objects` issues one ``list_objects_v2`` request
    and returns the keys of that page. Once the listing is exhausted, every
    further call returns an empty list without contacting S3.
    """

    def __init__(self, bucket, prefix):
        """Store the listing target and create the S3 client.

        Args:
            bucket: Name of the S3 bucket to list.
            prefix: Key prefix passed to ``list_objects_v2``.
        """
        self.bucket = bucket
        self.prefix = prefix
        self.client = boto3.client('s3')
        # True until the first page has been requested.
        self.first = True
        # Token for the next page; None when no further page is known.
        self.continuation_token = None

    def list_objects(self):
        """Return the keys of the next page of the listing.

        Returns:
            A list of object keys. The list is empty when the prefix matches
            no object or when all pages have already been returned.
        """
        if not self.first and not self.continuation_token:
            return []
        self.first = False

        request = {'Bucket': self.bucket, 'Prefix': self.prefix}
        if self.continuation_token:
            request['ContinuationToken'] = self.continuation_token
        response = self.client.list_objects_v2(**request)

        # The token is absent on the last page, so .get() resets it to None.
        self.continuation_token = response.get('NextContinuationToken')
        # 'Contents' is missing from the response when no key matches.
        return [item['Key'] for item in response.get('Contents', [])]

In [None]:
def find_errors(bucket, prefix, expected):
    """Classify every image under ``{prefix}/{expected}`` and collect the misses.

    Args:
        bucket: Name of the S3 bucket that holds the images.
        prefix: Key prefix of the image folders.
        expected: Label every image under ``{prefix}/{expected}`` should
            receive; it is also the last component of the listed prefix.

    Returns:
        A list of ``(key, actual_label)`` tuples, one for each image whose
        predicted label differs from ``expected``.
    """
    paginator = BucketPaginator(bucket, f'{prefix}/{expected}')
    errors = []
    # list_objects() returns one page per call and [] once the listing is
    # exhausted; keep calling it so images beyond the first page are checked.
    for keys in iter(paginator.list_objects, []):
        for key in keys:
            # Keys ending in '/' are folder placeholders, not images.
            if key.endswith('/'):
                continue
            actual = classify(bucket, key)[0]
            if actual != expected:
                errors.append((key, actual))
    return errors

In [None]:
# Validation images stored under 'salmon' that the endpoint labels differently.
salmon_errors = find_errors(
    bucket=bucket,
    prefix=s3_validate_prefix,
    expected='salmon',
)

In [None]:
# Validation images stored under 'trout' that the endpoint labels differently.
trout_errors = find_errors(
    bucket=bucket,
    prefix=s3_validate_prefix,
    expected='trout',
)

In [None]:
# Number of entries find_errors() collected for the 'salmon' folder.
len(salmon_errors)

In [None]:
# Number of entries find_errors() collected for the 'trout' folder.
len(trout_errors)

In [None]:
# Display the (key, predicted label) pairs collected for the 'trout' folder.
trout_errors

In [None]:
import matplotlib.pyplot as plt
from matplotlib import style
from PIL import Image
import io
# Use matplotlib's dark theme for every figure drawn below.
style.use('dark_background')
# Shared S3 resource handle used by show_image() to download images.
s3 = boto3.resource('s3')

In [None]:
%matplotlib inline
def show_image(key, caption=''):
    """Draw the image stored under ``key`` on the current matplotlib axes.

    The image is downloaded from the module-level ``bucket`` through the
    module-level ``s3`` resource.

    Args:
        key: S3 object key of the image.
        caption: Optional title for the image; no title is set when empty.
    """
    s3_object = s3.Object(bucket, key)
    image = Image.open(io.BytesIO(s3_object.get()['Body'].read()))
    plt.axis('off')
    # The parameter used to be ignored; only set a title when one is given
    # so that calls without a caption render exactly as before.
    if caption:
        plt.title(caption)
    plt.imshow(np.asarray(image))

In [None]:
def show_images(keys, caption=''):
    """Print ``caption`` and plot the images at ``keys`` in a four-column grid.

    Args:
        keys: Iterable of S3 object keys, each drawn with ``show_image``.
        caption: Text printed above the figure.
    """
    print(caption)
    keys = list(keys)
    if not keys:
        # Nothing to plot; skip creating an empty figure.
        return

    columns = 4
    # Keep the original 4x4 layout for up to 16 images and add rows beyond
    # that; a fixed 4x4 grid makes add_subplot fail on the 17th image.
    rows = max(4, (len(keys) + columns - 1) // columns)
    figure = plt.figure(figsize=(24, 6 * rows))
    for position, key in enumerate(keys, start=1):
        figure.add_subplot(rows, columns, position)
        show_image(key)
    plt.show()

In [None]:
# Plot the images listed in trout_errors; the caption is Norwegian for
# "Misclassified trout".
show_images([item[0] for item in trout_errors], 'Feilklassifisert ørret')

In [None]:
# Plot the images listed in salmon_errors; the caption is Norwegian for
# "Misclassified salmon".
show_images([item[0] for item in salmon_errors], 'Feilklassifisert laks')

In [None]:
# Tear down the hosted model and endpoint created above.
# NOTE(review): `if True` looks like a manual switch for skipping the teardown
# when re-running the notebook - confirm before removing it.
if True:
    # Delete the model first. NOTE(review): delete_model() appears to look up
    # the model name through the endpoint configuration, which
    # delete_endpoint() removes by default - confirm against the SDK version
    # in use.
    classifier.delete_model()
    classifier.delete_endpoint()