# SageMaker Implementation of dis-background-removal

This notebook implements the [DIS background removal tool](https://github.com/xuebinqin/DIS) in Amazon SageMaker.

The first part of this notebook walks through how to package a pre-trained DIS model for inference. The second part implements model training and fine-tuning.


In [28]:
!pip install gdown

Keyring is skipped due to an exception: 'keyring.backends'


In [150]:
import os, gdown
import boto3
import sagemaker
from sagemaker.pytorch import PyTorch, PyTorchModel
import time
from time import gmtime, strftime
from datetime import datetime
import urllib.parse
import numpy as np

In [68]:
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()  ### Replace with your own bucket if needed
role = sagemaker.get_execution_role(sagemaker_session)
prefix = "dis-background-removal"  ### Replace with the S3 prefix desired
region = boto3.Session().region_name
boto_session = boto3.session.Session()
sm_client = boto_session.client("sagemaker")
sm_runtime = boto_session.client("sagemaker-runtime")
sns_client = boto3.client('sns')
print(f"S3 bucket: {bucket}")
print(f"Role: {role}")
print(f"Region: {region}")

S3 bucket: sagemaker-us-west-2-686650353599
Role: arn:aws:iam::686650353599:role/service-role/AmazonSageMaker-ExecutionRole-20221207T104487
Region: us-west-2


In [10]:
# Download the official pre-trained weights unless they are already present
MODEL_PATH_URL = "https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn"
os.makedirs("saved_models", exist_ok=True)
if not os.path.exists("saved_models/model.pth"):
    gdown.download(MODEL_PATH_URL, "saved_models/model.pth", use_cookies=False)

Downloading...
From (uriginal): https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn
From (redirected): https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn&confirm=t&uuid=af0bf7bc-c527-42d9-a542-e9d5dea27240
To: /root/dis-background-removal-sagemaker/saved_models/model.pth
100%|██████████| 177M/177M [00:03<00:00, 52.2MB/s] 


## Package our model and code into model.tar.gz and upload it to S3

In [53]:
model_artifact = f"s3://{bucket}/{prefix}/model.tar.gz"
model_artifact

's3://sagemaker-us-west-2-686650353599/dis-background-removal/model.tar.gz'

## Package code from scripts folder

Here we package our code so that it can be deployed to SageMaker for inference:

1. Download the DIS library from its Git repository (https://github.com/xuebinqin/DIS).
2. Move the DIS model code into our scripts folder, alongside our own script.

Anatomy of the scripts folder:
- `inference_script.py`: our SageMaker inference script. It overrides the default PyTorch inference handlers with our own implementation.
- `models`: folder containing the DIS model definition that we will use.
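
The SageMaker PyTorch serving container looks for four optional hooks in the entry-point script: `model_fn`, `input_fn`, `predict_fn`, and `output_fn`. The skeleton below is a minimal sketch of that structure, not the actual script in `scripts/`. The placeholder model and the `SUPPORTED_CONTENT_TYPES` tuple are assumptions made for illustration; a real `model_fn` would build the DIS network and load `model.pth` with PyTorch.

```python
import json

# Hypothetical list of accepted MIME types; adjust it to match your clients.
SUPPORTED_CONTENT_TYPES = ("image/jpeg", "image/jpg", "image/png", "application/x-image")


def model_fn(model_dir):
    """Load the model once per worker.

    Placeholder only: a real script would build the DIS network and load
    the weights stored under model_dir with torch.
    """
    def placeholder_model(image_bytes):
        # Stand-in for a forward pass: always returns a 2x2 all-background mask.
        return [[0, 0], [0, 0]]

    return placeholder_model


def input_fn(request_body, content_type):
    """Validate the request and pass the raw image bytes on to predict_fn."""
    if content_type not in SUPPORTED_CONTENT_TYPES:
        raise ValueError(f"Unsupported content type: {content_type}")
    return bytes(request_body)


def predict_fn(input_data, model):
    """Run the model returned by model_fn on the output of input_fn."""
    return model(input_data)


def output_fn(prediction, accept):
    """Serialize the mask as a JSON array (nested lists of pixel values)."""
    return json.dumps(prediction)
```

Serializing to JSON matches how this notebook later reads the endpoint output with `json.load`.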

In [119]:
!mkdir -p model_and_code/code
!cp ./saved_models/model.pth ./model_and_code


In [115]:
!git clone https://github.com/xuebinqin/DIS

Cloning into 'DIS'...
remote: Enumerating objects: 328, done.[K
remote: Counting objects: 100% (75/75), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 328 (delta 64), reused 56 (delta 55), pack-reused 253[K
Receiving objects: 100% (328/328), 49.59 MiB | 19.09 MiB/s, done.
Resolving deltas: 100% (153/153), done.
Checking out files: 100% (44/44), done.


### Copy the model to the package folder

In [120]:
!mv DIS/IS-Net/* ./model_and_code/code
!rm ./model_and_code/code/Inference.py
!rm ./model_and_code/code/requirements.txt

### Copy the inference code to the package folder

In [194]:
!cp -r ./scripts/* model_and_code/code

### Package code and model into model.tar.gz

In [195]:
!tar cvzf model.tar.gz -C model_and_code/ . 

./
./model.pth
./code/
./code/models/
./code/models/isnet.py
./code/models/__pycache__/
./code/models/__pycache__/u2netfast.cpython-37.pyc
./code/models/__pycache__/__init__.cpython-37.pyc
./code/models/__init__.py
./code/basics.py
./code/pytorch18.yml
./code/requirements.txt
./code/data_loader_cache.py
./code/.ipynb_checkpoints/
./code/.ipynb_checkpoints/requirements-checkpoint.txt
./code/inference.py
./code/train_valid_inference_main.py
./code/hce_metric_main.py
./code/__pycache__/
./code/__pycache__/basics.cpython-37.pyc
./code/__pycache__/data_loader_cache.cpython-37.pyc


## Upload model and code to S3

In [196]:
!aws s3 cp model.tar.gz $model_artifact

upload: ./model.tar.gz to s3://sagemaker-us-west-2-686650353599/dis-background-removal/model.tar.gz


## Create async endpoint

Reference: https://github.com/aws-samples/amazon-sagemaker-asynchronous-inference-computer-vision/blob/main/mask-rcnn-async-inference.ipynb

### Create SNS Error and Success topics

**Prerequisite**
Make sure the notebook's IAM execution role has permission to create SNS topics.
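
As a rough guide, the sketch below builds an identity-based policy covering the SNS calls made in this section (`create_topic`, `list_topics`) plus `sns:Publish`, which the endpoint's execution role is assumed to need in order to post notifications. The helper name `build_sns_policy`, the wildcard topic pattern, and the placeholder account ID are assumptions for illustration; scope the resource to your own account and Region before attaching anything like it.

```python
import json


def build_sns_policy(region, account_id, topic_prefix):
    """Return a policy document (as a dict) for topics named '<topic_prefix>-*'."""
    topic_arn_pattern = f"arn:aws:sns:{region}:{account_id}:{topic_prefix}-*"
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Create the topics and publish notifications to them.
                "Effect": "Allow",
                "Action": ["sns:CreateTopic", "sns:Publish"],
                "Resource": topic_arn_pattern,
            },
            {
                # Listing topics is not scoped to a single topic ARN.
                "Effect": "Allow",
                "Action": "sns:ListTopics",
                "Resource": "*",
            },
        ],
    }


# Placeholder account ID; replace it with your own.
policy = build_sns_policy("us-west-2", "123456789012", "dis-background-removal")
print(json.dumps(policy, indent=2))
```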

In [73]:
response = sns_client.create_topic(Name=f"{prefix}-Async-ErrorTopic")
error_topic= response['TopicArn']
print(error_topic)

arn:aws:sns:us-west-2:686650353599:dis-background-removal-Async-ErrorTopic


In [74]:
response = sns_client.create_topic(Name=f"{prefix}-Async-SuccessTopic")
success_topic = response['TopicArn']
print(success_topic)

arn:aws:sns:us-west-2:686650353599:dis-background-removal-Async-SuccessTopic


In [77]:
response = sns_client.list_topics()
topics = response["Topics"]
print(topics)

[{'TopicArn': 'arn:aws:sns:us-west-2:686650353599:DUSStack-dusstackjobcompletiontopic3jncj5mznfvxeahw8vnzazA02368F4-ESfqwypWqKrP'}, {'TopicArn': 'arn:aws:sns:us-west-2:686650353599:aws-controltower-SecurityNotifications'}, {'TopicArn': 'arn:aws:sns:us-west-2:686650353599:dis-background-removal-Async-ErrorTopic'}, {'TopicArn': 'arn:aws:sns:us-west-2:686650353599:dis-background-removal-Async-SuccessTopic'}]


### Create SageMaker Model with PyTorch inference container

In [197]:
# Retrieve inference container
from sagemaker.image_uris import retrieve

deploy_instance_type = 'ml.m5.4xlarge'
pytorch_inference_image_uri = retrieve('pytorch',
                                       region,
                                       version='1.12',
                                       py_version='py38',
                                       instance_type = deploy_instance_type,
                                       accelerator_type=None,
                                       image_scope='inference')
print(pytorch_inference_image_uri)

763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12-cpu-py38


In [198]:
container = pytorch_inference_image_uri
model_name = 'sagemaker-maskrcnn-{0}'.format(str(int(time.time())))
print(container)
print(model_name)

create_model_response = sm_client.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    PrimaryContainer = {
        'Image': container,
        'ModelDataUrl': model_artifact,
        'Environment': {
            'TS_MAX_REQUEST_SIZE': '100000000',  # TorchServe's default max request size is about 6 MB; raise it to support the ~70 MB input payload
            'TS_MAX_RESPONSE_SIZE': '100000000',
            'TS_DEFAULT_RESPONSE_TIMEOUT': '1000'
        }
    },    
)

763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12-cpu-py38
sagemaker-maskrcnn-1685937575
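
The `TS_MAX_REQUEST_SIZE` value above is expressed in bytes, so 100000000 is 100 MB (roughly 95.4 MiB). A small sketch of a pre-flight check against that limit is shown below; the helper name `fits_request_limit` is hypothetical and the constant simply repeats the value configured in the model environment.

```python
import os

# Bytes; mirrors the TS_MAX_REQUEST_SIZE value set in the model environment.
TS_MAX_REQUEST_SIZE = 100_000_000


def fits_request_limit(path, limit_bytes=TS_MAX_REQUEST_SIZE):
    """Return True when the file at path is no larger than limit_bytes."""
    return os.path.getsize(path) <= limit_bytes


# 100,000,000 bytes expressed in MiB (1 MiB = 1024 * 1024 bytes).
limit_mib = TS_MAX_REQUEST_SIZE / (1024 * 1024)
print(f"Request limit: {limit_mib:.1f} MiB")
```

Running a check like this before uploading an input can surface oversized payloads without waiting for the endpoint to reject them.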


### Create Asynchronous inference endpoint

In [199]:
resource_name = "DISSegmentation-{}-{}"

endpoint_config_name = resource_name.format(
    "EndpointConfig", datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
)
create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": model_name,
            "InstanceType": deploy_instance_type,
            "InitialInstanceCount": 1,
        }
    ],
    AsyncInferenceConfig={
        "OutputConfig": {
            "S3OutputPath": f"s3://{bucket}/{prefix}/output",
            # Optionally specify Amazon SNS topics
            "NotificationConfig": {
              "SuccessTopic": success_topic,
              "ErrorTopic": error_topic,
            }
        },
        "ClientConfig": {"MaxConcurrentInvocationsPerInstance": 4},
    },
)
print(f"Created EndpointConfig: {create_endpoint_config_response['EndpointConfigArn']}")

Created EndpointConfig: arn:aws:sagemaker:us-west-2:686650353599:endpoint-config/DISSegmentation-EndpointConfig-2023-06-05-03-59-35


In [200]:
endpoint_name = f"sm-{prefix}-{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
create_endpoint_response = sm_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name)
print(f"Creating Endpoint: {create_endpoint_response['EndpointArn']}")

Creating Endpoint: arn:aws:sagemaker:us-west-2:686650353599:endpoint/sm-dis-background-removal-2023-06-05-03-59-36


In [None]:
waiter = sm_client.get_waiter('endpoint_in_service')
print("Waiting for endpoint to create...")
waiter.wait(EndpointName=endpoint_name)
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
print(f"Endpoint Status: {resp['EndpointStatus']}")

Waiting for endpoint to create...


## Run Inference

In [None]:
def upload_file(input_location, prefix):
    prefix = f"{prefix}/input"
    return sagemaker_session.upload_data(
        input_location, 
        bucket=sagemaker_session.default_bucket(),
        key_prefix=prefix, 
        extra_args={"ContentType": "image/jpg"})

In [None]:
input_1_location = "data/sample.jpg"
input_1_s3_location = upload_file(input_1_location, prefix)

In [None]:
print(input_1_s3_location)

In [None]:
print(endpoint_name)

In [None]:
response = sm_runtime.invoke_endpoint_async(
    EndpointName=endpoint_name, 
    InputLocation=input_1_s3_location)
output_location = response['OutputLocation']
print(f"OutputLocation: {output_location}")

In [None]:
from botocore.exceptions import ClientError

def get_output(output_location):
    # Split s3://bucket/key into its bucket and key parts
    output_url = urllib.parse.urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]
    while True:
        try:
            return sagemaker_session.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            # The output object only appears once inference has finished
            if e.response['Error']['Code'] == 'NoSuchKey':
                print("waiting for output...")
                time.sleep(2)
                continue
            raise
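
The loop above polls indefinitely: if the invocation fails, the output object is never written and the cell never returns. A minimal sketch of a bounded variant with exponential backoff follows. `poll_until_ready`, its parameters, and the injected `fetch` callable are hypothetical names, not part of the SageMaker SDK; in this notebook `fetch` would wrap `read_s3_file` and return `None` on `NoSuchKey`.

```python
import time


def poll_until_ready(fetch, timeout_s=300.0, initial_delay_s=1.0, max_delay_s=30.0):
    """Call fetch() until it returns a non-None value or timeout_s elapses.

    fetch is any zero-argument callable that returns None while the result
    is not yet available.
    """
    deadline = time.monotonic() + timeout_s
    delay = initial_delay_s
    while True:
        result = fetch()
        if result is not None:
            return result
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"No result after {timeout_s} seconds")
        # Exponential backoff, capped by max_delay_s and by the time left.
        time.sleep(min(delay, max_delay_s, remaining))
        delay *= 2
```

Injecting `fetch` keeps the retry logic independent of S3, which also makes it easy to exercise without an endpoint.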

In [None]:
import sys

output = get_output(output_location)
print(f"Output size in bytes: {sys.getsizeof(output)}")

In [None]:
# Download the result into the local ./output/ directory
!aws s3 cp $output_location ./output/

In [None]:
print(os.path.basename(output_location))

In [None]:
import json
from PIL import Image

with open(os.path.join("output", os.path.basename(output_location)), "rb") as f:
    data = json.load(f)

# Convert the JSON array into an 8-bit grayscale mask
payload = np.asarray(data)
pil_mask = Image.fromarray(payload.astype('uint8')).convert('L')

In [None]:
pil_mask

In [193]:
sm_client.delete_endpoint(EndpointName=endpoint_name)

{'ResponseMetadata': {'RequestId': '15d405ea-d3c8-442d-aeab-f49841fab3be',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '15d405ea-d3c8-442d-aeab-f49841fab3be',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '0',
   'date': 'Mon, 05 Jun 2023 03:57:49 GMT'},
  'RetryAttempts': 0}}

## Model inference script

In [None]:
model = PyTorchModel(
    model_data=model_artifact,  # S3 URI of the model.tar.gz uploaded earlier
    entry_point='inference_script.py',
    framework_version='1.12',
    py_version='py38',
    role=role,
    sagemaker_session=sagemaker_session,
    source_dir="scripts"  
)

In [None]:
predictor = model.deploy(initial_instance_count=1, instance_type='ml.m5.4xlarge')

-----!

In [45]:
# Using a test image
# img_path = './data/test/img/img100.tif'
# mask_path = './data/test/mask/img100.tif'

# Using a train image
img_path = './sample2.jpg'
with open(img_path, "rb") as f:
    payload = f.read()

sm_runtime = boto3.Session().client("sagemaker-runtime")

response = sm_runtime.invoke_endpoint(
    EndpointName=predictor.endpoint_name, ContentType="application/x-image", Body=payload
)
print(response)

ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received server error (500) from primary with message "{
  "code": 500,
  "type": "InternalServerException",
  "message": "Worker died."
}
". See https://us-west-2.console.aws.amazon.com/cloudwatch/home?region=us-west-2#logEventViewer:group=/aws/sagemaker/Endpoints/pytorch-inference-2023-06-04-23-22-40-885 in account 686650353599 for more information.

In [None]:
import matplotlib.pyplot as plt

# Assumption: the endpoint returns the mask as a JSON array, as in the asynchronous example
result = json.loads(response["Body"].read())

image = Image.open(img_path).convert("RGB")

# Show the input image next to the thresholded mask
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(image)
ax[0].set_title('Original image')
ax[1].imshow(np.array(result) > 0.5, cmap='Blues')
ax[1].set_title('Predicted mask')
plt.show()

In [46]:
predictor.delete_endpoint()