In [None]:
import glob
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import random

import seaborn as sn
import boto3
import re
from io import BytesIO
import base64

import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
from torchvision import transforms


In [None]:

# Load an ImageNet-pretrained ResNet-50 to use as a frozen feature extractor.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=...`; kept for the version this notebook targets — confirm.
resnet50 = models.resnet50(pretrained=True)

# Switch to inference mode (BatchNorm uses running statistics); return value discarded.
_ = resnet50.eval()
# _ = resnet50.cuda()

# Drop the last child module (torchvision's `fc` classification layer) so the
# network outputs pooled features instead of class logits.
modules=list(resnet50.children())[:-1]
resnet50=nn.Sequential(*modules)
# Freeze every weight: the model is only used for feature extraction.
for p in resnet50.parameters():
    p.requires_grad = False

# ToTensor turns an HxWxC uint8 array into a CxHxW float tensor scaled to [0, 1].
# NOTE(review): no ImageNet mean/std normalization is applied; confirm this matches
# the preprocessing done in code/inference.py so local and endpoint features agree.
transform = transforms.Compose([transforms.ToTensor()])

# Use the first GPU when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

resnet50 = resnet50.to(device)

device


In [None]:
# List S3 objects as s3:// URIs
def get_all_s3_keys(bucket, filt=None, client=None):
    """Return ``s3://`` URIs for every object in an S3 bucket.

    Pages through ``list_objects_v2`` with continuation tokens, so buckets
    larger than a single listing page are fully enumerated.

    Args:
        bucket: Name of the bucket to list.
        filt: Optional substring; when given, only keys containing it are kept.
        client: boto3 S3 client to use. Defaults to the notebook-level ``s3``
            client, looked up at call time.

    Returns:
        List of ``'s3://<bucket>/<key>'`` strings in listing order. Empty when
        the bucket holds no (matching) objects.
    """
    if client is None:
        client = s3  # notebook-global client, created in a later cell

    keys = []
    kwargs = {'Bucket': bucket}
    while True:
        resp = client.list_objects_v2(**kwargs)
        # 'Contents' is omitted entirely when a page has no objects (e.g. empty bucket).
        for obj in resp.get('Contents', []):
            key = obj['Key']
            if filt is not None and filt not in key:
                continue
            keys.append(f's3://{bucket}/{key}')

        # Absent token means this was the last page.
        token = resp.get('NextContinuationToken')
        if token is None:
            break
        kwargs['ContinuationToken'] = token

    return keys

# Test model with a single image

In [None]:
s3 = boto3.client('s3')

# Bucket holding the geological_similarity image dataset.
# NOTE(review): hardcoded bucket / account id; consider moving it to a config cell.
bucket = 'sagemaker-us-east-2-333209439517'

In [None]:
# URIs of every object whose key contains 'jpg' (substring match, not extension check).
s3_uris = get_all_s3_keys(bucket, filt='jpg')

# NOTE(review): this tuple is not displayed — only a cell's last expression is shown.
s3_uris[0], len(s3_uris)

s3_uri = s3_uris[0]

s3_uri


In [None]:
# Pin a specific sample image; this overrides the `s3_uri` chosen in the previous cell.
s3_uri = 's3://sagemaker-us-east-2-333209439517/geological_similarity/andesite/012L6.jpg'

In [None]:
# Download the raw bytes of the sample image; the key is the URI minus the
# 's3://<bucket>/' prefix (this assumes `s3_uri` lives in the current `bucket`).
payload = s3.get_object(Bucket=bucket,Key=s3_uri.replace(f's3://{bucket}/', ''))['Body'].read()

im_file = BytesIO(payload)  # wrap the bytes in a file-like object
img = Image.open(im_file)   # img is now a PIL Image

img

In [None]:
im = np.asarray(img)# convert image to numpy array
img = transform(im) # convert to tensor
#img = img.reshape(1,3,28,28)
img = torch.unsqueeze(img, 0)
img = img.to(device)

with torch.no_grad():
    feature = resnet50(img)

feature = feature.cpu().detach().numpy().reshape(-1)

feature.shape

# Save model to s3

In [None]:
img.shape

In [None]:
import tarfile

In [None]:
import copy

# Example input for tracing; presumably matches the 28x28 RGB dataset images — confirm.
input_shape = [1, 3, 28, 28]
# Trace a CPU copy of the model. The example input is created on the CPU, and the live
# `resnet50` was moved to `device` earlier, so tracing it directly fails with a device
# mismatch on GPU machines. Copying also leaves the notebook's model on `device`.
trace_model = copy.deepcopy(resnet50).cpu().float().eval()
trace = torch.jit.trace(trace_model, torch.zeros(input_shape).float())

In [None]:
# Serialize the TorchScript module to disk.
trace.save("model.pth")

In [None]:
# Package the TorchScript file into model.tar.gz; this archive is uploaded below and
# passed to PyTorchModel as `model_data`.
with tarfile.open("model.tar.gz", "w:gz") as f:
    f.add("model.pth")

In [None]:
import boto3
import sagemaker
import time
from sagemaker.utils import name_from_base

In [None]:
role = sagemaker.get_execution_role()
sess = sagemaker.Session()
region = sess.boto_region_name
# NOTE(review): this overwrites the dataset `bucket` set earlier with the session's
# default bucket; every later cell that uses `bucket` sees this value instead.
bucket = sess.default_bucket()

In [None]:
bucket

In [None]:
# Timestamped name used as the S3 key prefix for the model artifact.
# NOTE(review): no compilation job is run in this notebook; the name is only a prefix.
compilation_job_name = name_from_base("TorchVision-ResNet50")
prefix = compilation_job_name + "/model"

In [None]:
compilation_job_name, prefix

In [None]:
# Upload the archive to the session's default bucket under `prefix`; returns its S3 URI.
model_path = sess.upload_data(path="model.tar.gz", key_prefix=prefix)

In [None]:
# NOTE(review): hardcoded URI from a previous run overrides the upload above; remove
# this cell to deploy the freshly uploaded artifact.
model_path = 's3://sagemaker-us-east-1-333209439517/TorchVision-ResNet50-2021-09-13-00-30-10-117/model/model.tar.gz'

# Deploy model

In [None]:
import json
import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role, Session

In [None]:
role

In [None]:
predictor.delete_endpoint()

In [None]:
# Wrap the uploaded TorchScript archive in a SageMaker PyTorch model whose request
# handling lives in code/inference.py.
# NOTE(review): the model was traced with the notebook's local torch version; confirm
# the artifact loads under framework_version 1.5.0 (TorchScript files are not always
# loadable by older runtimes).
model = PyTorchModel(
    entry_point="inference.py",
    source_dir="code",
    role=role,
    model_data=model_path,
    framework_version="1.5.0",
    py_version="py3",
)

In [None]:
# The execution role needs SageMaker permissions; the AmazonSageMakerFullAccess managed policy includes most of what is required for the actions in this notebook.

In [None]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Set local_mode to False to deploy on a remote SageMaker instance instead of a
# local container.
# NOTE(review): a local-mode endpoint is only reachable through the predictor's own
# session runtime client, not through a fresh boto3 'sagemaker-runtime' client.

local_mode = True

if local_mode:
    instance_type = "local"
else:
    instance_type = "ml.c4.xlarge"

# Requests are JSON-serialized and responses JSON-deserialized by the predictor.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

In [None]:
# Debug introspection of the predictor; consider removing from the final notebook.
predictor.__dict__

In [None]:
predictor.endpoint_name

In [None]:
predictor.serializer

In [None]:
predictor.deserializer

In [None]:
predictor.content_type

In [None]:
# Base64-encode the raw image bytes so they can travel inside a JSON request body.
encoded_image = base64.b64encode(payload).decode('utf-8')

In [None]:
# Sanity check: decoding the base64 string recovers a valid image.
im_bytes = base64.b64decode(encoded_image)   # im_bytes is the binary image
im_file = BytesIO(im_bytes)  # convert image to file-like object
image = Image.open(im_file)   # image is now a PIL Image object
im = np.asarray(image)# convert image to numpy array
print(im.shape)
image

In [None]:
# Request payload; presumably the shape code/inference.py expects — confirm.
req = {'inputs':encoded_image}

In [None]:
req

In [None]:
# Build and send the request by hand (mirroring what Predictor.predict appears to do
# internally) so the raw response can be inspected.
# NOTE(review): `_create_request_args` is a private SageMaker SDK helper whose signature
# has changed between releases (e.g. the inference_id argument); may break on upgrade.
data = req
initial_args=None
target_model=None
target_variant=None
inference_id=None

request_args = predictor._create_request_args(data, initial_args, target_model, target_variant, inference_id)
response = predictor.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)

In [None]:
response

In [None]:
print(response['Body'].read().decode('utf-8'))

In [None]:
out = predictor._handle_response(response)

In [None]:
data = json.loads(req)['inputs']

In [None]:
data

In [None]:
res = predictor.predict(req)

In [None]:

# define a function to extract image features
from time import sleep

# Reuse the predictor's own session runtime client, as the manual-invoke cell does:
# it reaches the endpoint in both local mode (local container) and remote mode,
# whereas a fresh boto3.client('sagemaker-runtime') cannot see local-mode endpoints.
sm_client = predictor.sagemaker_session.sagemaker_runtime_client
# `endpoint_name` is the supported attribute; `endpoint` is a deprecated alias.
ENDPOINT_NAME = predictor.endpoint_name

def get_predictions(payload):
    """Send raw image bytes to the deployed endpoint and return the raw response.

    Uses the notebook-level ``sm_client`` and ``ENDPOINT_NAME``.

    NOTE(review): the request is sent as 'application/x-image', while the model was
    deployed with a JSONSerializer and tested with {'inputs': <base64>} payloads;
    confirm code/inference.py accepts this content type.
    """
    request = {
        'EndpointName': ENDPOINT_NAME,
        'ContentType': 'application/x-image',
        'Body': payload,
    }
    return sm_client.invoke_endpoint(**request)

def extract_features(s3_uri):
    """Fetch an image from S3 and return the feature vector the endpoint computes for it.

    Args:
        s3_uri: ``'s3://<bucket>/<key>'`` URI of the image. A bare key (no
            ``s3://`` prefix) is looked up in the notebook-level ``bucket``.

    Returns:
        Tuple ``(s3_uri, feature_lst)`` where ``feature_lst`` is the first entry of
        the endpoint response's ``'predictions'`` list.
        NOTE(review): the response schema is defined by code/inference.py — confirm.
    """
    # Take the bucket from the URI itself. The notebook-level `bucket` is rebound to the
    # SageMaker default bucket before this runs, so stripping f's3://{bucket}/' would
    # leave dataset URIs unchanged and request a non-existent key.
    if s3_uri.startswith('s3://'):
        src_bucket, _, key = s3_uri[len('s3://'):].partition('/')
    else:
        src_bucket, key = bucket, s3_uri
    payload = s3.get_object(Bucket=src_bucket, Key=key)['Body'].read()

    sleep(0.1)  # brief pause between endpoint calls (presumably to avoid throttling)
    response = get_predictions(payload)

    response_body = json.loads(response['Body'].read())

    feature_lst = response_body['predictions'][0]

    return s3_uri, feature_lst









#Connect to Elasticsearch
# NOTE(review): AWS4Auth, Elasticsearch, RequestsHttpConnection and `host` are not
# imported or defined anywhere in this notebook; they must be provided (e.g. from the
# requests_aws4auth and elasticsearch packages, plus the domain endpoint) first.
service = 'es'
credentials = boto3.Session().get_credentials()
# Sign requests with the current boto3 credentials for the 'es' service in `region`.
awsauth = AWS4Auth(credentials.access_key, credentials.secret_key,
                   region, service, session_token=credentials.token)

# NOTE(review): `headers` is not used by any code below.
headers = {"Content-Type": "application/json"}

# HTTPS connection on port 443 with certificate verification and a 60 s timeout.
es = Elasticsearch(
    hosts=[{'host': host, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    timeout=60
)

def create_index(index, dimension=2048, client=None):
    """Create a k-NN enabled index with an ``embeddings`` vector field, if it is absent.

    Args:
        index: Name of the index to create.
        dimension: Length of the ``embeddings`` vectors. Defaults to 2048, matching
            the pooled ResNet-50 features produced by this notebook's model.
        client: Elasticsearch client to use; defaults to the notebook-level ``es``.

    Prints whether the index was created or already existed; returns None.
    """
    if client is None:
        client = es

    if client.indices.exists(index=index):
        print("elasticsearch index already exists")
        return

    index_settings = {
        "settings": {
            "index.knn": True,
            "index.mapping.total_fields.limit": "2000"
        },
        "mappings": {
            "properties": {
                "embeddings": {
                    "type": "knn_vector",
                    "dimension": dimension
                }
            }
        }
    }

    client.indices.create(index=index, body=json.dumps(index_settings))
    print("Created the elasticsearch index successufly ")


#Create the index using knn settings
# NOTE(review): `es_index` is not defined anywhere in this notebook — set it (e.g. in a
# config cell) before running this.
create_index(es_index)


# You can check if the index is created within your es cluster
es.indices.get_alias("*")

def ingest_data_into_es(event, client=None, es_client=None, index=None):
    """Load a JSON array of records from S3 and index each one into Elasticsearch.

    Args:
        event: Mapping with ``'bucket'`` and ``'key'`` naming the S3 object. The object
            must hold a JSON list of records, each with an ``'id'`` field; an
            ``'embeddings'`` field, if present, may be a list or its string form.
        client: boto3 S3 client; defaults to the notebook-level ``s3``.
        es_client: Elasticsearch client; defaults to the notebook-level ``es``.
        index: Target index name; defaults to the notebook-level ``es_index``.

    Returns:
        Lambda-style dict with ``statusCode`` 200 and a JSON body reporting how many
        records were indexed. Records that fail (missing id, unparsable embeddings,
        indexing error) are logged and counted as lost rather than aborting the run.
    """
    import ast
    import logging

    logger = logging.getLogger(__name__)
    if client is None:
        client = s3
    if es_client is None:
        es_client = es
    if index is None:
        index = es_index

    obj = client.get_object(Bucket=event['bucket'], Key=event['key'])
    records = json.loads(obj['Body'].read().decode('utf-8'))

    count = 0
    lost_records = 0

    for record in records:
        try:
            # The primary key doubles as the Elasticsearch document ID.
            record_id = record['id']
            # Embeddings may arrive serialized as a string such as "[0.1, 0.2]";
            # literal_eval parses it without executing code. Lists pass through as-is.
            if isinstance(record.get('embeddings'), str):
                record['embeddings'] = ast.literal_eval(record['embeddings'])

            # NOTE(review): `doc_type` is deprecated in Elasticsearch 7 and removed in 8.
            es_client.index(index=index, id=record_id, doc_type='_doc', body=record)
            count += 1
        except Exception as error:
            logger.error(f"An error {error} for record {record}")
            lost_records += 1

    logger.info(
        f'{lost_records} out of {len(records)} are lost records')

    logger.info(
        f'{count} out of {len(records)} records has been processed')

    return {
        'statusCode': 200,
        'body': json.dumps(str(count) + ' records processed.')
    }


#Check that data is indeed in ES
# Fetch up to 10 arbitrary documents from the index.
res = es.search(index=es_index, body={
                    "query": {
                            "match_all": {}
                        }},
           size=10)



# k-NN query: the 5 nearest neighbours of `query_embeddings` on the `embeddings` field.
# NOTE(review): `query_embeddings` and `page_size` are not defined in this notebook —
# presumably a feature vector (e.g. from extract_features) and a result count; define
# them before running this cell.
es_query ={
            "query": {
                "knn": {
                    "embeddings": {
                        "vector": query_embeddings,
                        "k": 5
                    }
                }
            }
    }

res = es.search(index=es_index, body=es_query, size=page_size)

