In [225]:
import transformers
import sagemaker
import torch
import json

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path

from sagemaker.pytorch.model import PyTorchModel
from sagemaker.predictor import Predictor
from datetime import datetime
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.utils import name_from_base

import pandas as pd
from utils import save_to_s3
import tarfile

In [226]:
# S3 bucket used for model artifacts. Set this to None to use the account's
# default SageMaker bucket instead.
sagemaker_session_bucket = 'sagemaker-godeltech'

# Build the session once. With default_bucket=None the SDK falls back to its
# own default bucket, so no separate "is None" branch is needed.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

# Resolve the bucket actually in use, so later cells that build s3:// URIs
# from `sagemaker_session_bucket` work whether or not an override was given.
sagemaker_session_bucket = sess.default_bucket()

#put SageMaker role here if you're running this notebook locally
role = ''
if not role:
    # An empty role cannot be used to create a model or a compilation job, so
    # look up the execution role attached to the current environment instead.
    # Outside SageMaker this raises ValueError when the active AWS identity is
    # not a role; in that case fill in `role` above with a role ARN.
    role = sagemaker.get_execution_role(sagemaker_session=sess)

print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

sagemaker bucket: sagemaker-godeltech
sagemaker session region: eu-west-1


In [227]:
# Model identifier passed to the `from_pretrained` calls that load the
# tokenizer and the sequence-classification model.
MODEL_NAME = 'unitary/toxic-bert'

In [183]:
# Load the pretrained tokenizer and classifier, caching the downloaded files
# under ../tmp so repeated runs do not fetch them again.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    cache_dir='../tmp/AutoTokenizer',
)

# return_dict=False makes the model return plain tuples instead of a
# ModelOutput object; the model is traced with torch.jit.trace further down.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    return_dict=False,
    cache_dir='../tmp/AutoModel',
)

In [228]:
# Build a sample input for tracing: one sentence, padded/truncated to
# `max_length` tokens and returned as PyTorch tensors.
seq = "Godel technologies: Sage Maker: this is just an example for PyTorch"
max_length = 512

tokenized_sequence_pair = tokenizer.encode_plus(
    seq,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=max_length,
)

# The traced model receives (input_ids, attention_mask) as positional inputs.
example = (
    tokenized_sequence_pair["input_ids"],
    tokenized_sequence_pair["attention_mask"],
)

# Switch to inference mode before tracing, then persist the TorchScript module.
model.eval()
traced_model = torch.jit.trace(model, example)
traced_model.save("../tmp/model.pth")

In [229]:
# Package the traced model as a gzipped tarball for upload to S3.
# `arcname` stores the file at the archive root as "model.pth"; without it the
# member would be recorded under the relative path "../tmp/model.pth", which
# does not unpack into the model directory root.
# NOTE(review): confirm the inference scripts under aux/ load "model.pth"
# from the root of the model directory.
with tarfile.open("../tmp/model.tar.gz", "w:gz") as f:
    f.add("../tmp/model.pth", arcname="model.pth")

In [188]:
# Naming components reused for S3 keys, model names and endpoint names.
prefix = "neuron-experiments"
flavour = "normal"
# Today's date as YYYYMMDD.
date_string = f"{datetime.now():%Y%m%d}"

In [261]:
# Upload the packaged model to the session's bucket and keep the resulting
# S3 URI; the bare expression on the last line displays it.
traced_model_key_prefix = f"{prefix}/toxicbert/traced_model"
traced_model_url = sess.upload_data(
    path="../tmp/model.tar.gz", key_prefix=traced_model_key_prefix
)
traced_model_url

's3://sagemaker-godeltech/neuron-experiments/toxicbert/traced_model/model.tar.gz'

In [191]:
# SageMaker model definition for the plain (non-compiled) TorchScript model,
# served by the handler in aux/inference_normal.py.
normal_sm_model = PyTorchModel(
    model_data=traced_model_url,
    predictor_cls=Predictor,
    # The SDK needs framework_version (together with py_version) to pick the
    # serving container when no image_uri is given. Pinned to the same version
    # the compiled model definition in this notebook uses.
    # NOTE(review): this should match the torch version used to trace the model.
    framework_version="1.12.1",
    role=role,
    sagemaker_session=sess,
    entry_point="inference_normal.py",
    source_dir="aux",
    py_version="py3",
    name=f"{flavour}-toxic-{date_string}",
    # Log level 10 corresponds to DEBUG in Python's logging module.
    env={"SAGEMAKER_CONTAINER_LOG_LEVEL": "10"},
)

In [192]:
%%time

hardware = "g4dn"

normal_predictor = normal_sm_model.deploy(
    instance_type="ml.g4dn.xlarge",
    initial_instance_count=1,
    endpoint_name=f"toxicbert-{flavour}-{hardware}-{date_string}",
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

print('READY!')

----------!READY!
CPU times: user 31 s, sys: 5.02 s, total: 36 s
Wall time: 5min 39s


In [214]:
# Smoke-test the endpoint with one sentence. The slice drops the first and
# last element of the response (the recorded output below is a string).
normal_response = normal_predictor.predict("who are you? this is a test message")
normal_response[1:-1]

"toxicity': 0.0012372832279652357, 'severe_toxic': 0.00010001649934565648, 'obscene': 0.0001964759867405519, 'threat': 0.00010137527715414762, 'insult': 0.00019288001931272447, 'identity_hate': 0.0001452629658160731"

In [264]:
# SageMaker model definition used as the input to compilation; it is served
# by the handler in aux/inference_inf1.py.
compiled_model_env = {"MMS_DEFAULT_RESPONSE_TIMEOUT": "500"}

compiled_sm_model = PyTorchModel(
    model_data=traced_model_url,
    role=role,
    entry_point="inference_inf1.py",
    source_dir="aux",
    framework_version="1.12.1",
    py_version="py3",
    predictor_cls=Predictor,
    sagemaker_session=sess,
    env=compiled_model_env,
)

In [266]:
# Compile the traced model for the ml_inf1 target family.
# NOTE(review): the recorded run of this cell failed with
# "Unable to load PyTorch model ... file not found: model/version".
hardware = "inf1"
compilation_job_name = name_from_base("godel")
compiled_output_path = f"s3://{sagemaker_session_bucket}/{prefix}/compiled_model"

compiled_inf1_model = compiled_sm_model.compile(
    target_instance_family=f"ml_{hardware}",
    # Both inputs use batch size 1 and sequence length 512.
    input_shape={"input_ids": [1, 512], "attention_mask": [1, 512]},
    output_path=compiled_output_path,
    role=role,
    job_name=compilation_job_name,
    framework="pytorch",
    framework_version="1.12.1",
    # Options are passed here as a JSON-encoded string. For "normal" instance
    # types (CPU or GPU based) the original author noted the dict form
    # {'dtype': 'int64'} instead.
    compiler_options=json.dumps("--dtype int64"),
    compile_max_run=900,
)

?????????????????????????????....................................*

UnexpectedStatusException: Error for Compilation job godel-2022-09-27-13-24-28-285: Failed. Reason: ClientError: InputConfiguration: Unable to load PyTorch model:', '[enforce fail at inline_container.cc:222] . file not found: model/version')  For further troubleshooting common failures please visit: https://docs.aws.amazon.com/sagemaker/latest/dg/neo-troubleshooting-compilation.html

In [None]:
# Deploy the compiled model to a single ml.inf1.xlarge instance; requests and
# responses are (de)serialized as JSON.
flavour = 'inferentia'
inf1_endpoint_name = f"toxicbert-{flavour}-{hardware}-{date_string}"

compiled_inf1_predictor = compiled_inf1_model.deploy(
    initial_instance_count=1,
    instance_type="ml.inf1.xlarge",
    endpoint_name=inf1_endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

In [268]:
# Reload the saved TorchScript module as a sanity check, on GPU when one is
# available and on CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The file was written by a traced module's .save(), so it is a TorchScript
# archive: torch.jit.load is the matching loader. torch.load only handles
# such files through a deprecated redirect that newer PyTorch releases reject.
torch.jit.load('../tmp/model.pth', map_location=device)



RecursiveScriptModule(
  original_name=BertForSequenceClassification
  (bert): RecursiveScriptModule(
    original_name=BertModel
    (embeddings): RecursiveScriptModule(
      original_name=BertEmbeddings
      (word_embeddings): RecursiveScriptModule(original_name=Embedding)
      (position_embeddings): RecursiveScriptModule(original_name=Embedding)
      (token_type_embeddings): RecursiveScriptModule(original_name=Embedding)
      (LayerNorm): RecursiveScriptModule(original_name=LayerNorm)
      (dropout): RecursiveScriptModule(original_name=Dropout)
    )
    (encoder): RecursiveScriptModule(
      original_name=BertEncoder
      (layer): RecursiveScriptModule(
        original_name=ModuleList
        (0): RecursiveScriptModule(
          original_name=BertLayer
          (attention): RecursiveScriptModule(
            original_name=BertAttention
            (self): RecursiveScriptModule(
              original_name=BertSelfAttention
              (query): RecursiveScriptModule(ori