# ChemProt: BERT NER on SageMaker using PyTorch

This notebook trains and deploys a BERT named-entity recognition (NER) model on the ChemProt corpus of chemical and protein mentions (BioCreative VI): https://biocreative.bioinformatics.udel.edu/news/corpora/chemprot-corpus-biocreative-vi/






In [1]:
import sys, os
import logging

# Make the project's local packages under ./src (e.g. `datasets`) importable
# from this notebook.
sys.path.append("src")

# Send INFO-and-above log records to stdout with a timestamped format.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    level="INFO",
)

In [2]:
# Local scratch directory for downloaded test data and prediction output.
local_temp = "temp"

In [3]:
# Start from an empty local scratch directory (removes files from any previous run).
!rm -rf $local_temp
!mkdir -p $local_temp

### Bucket and role set up

In [4]:
import boto3

import sagemaker.session

sm_session = sagemaker.session.Session()

# The execution role ARN is built from the account of the current AWS
# credentials. NOTE(review): assumes a role with this exact name exists in that
# account; adjust it for your setup. Inside a SageMaker notebook instance,
# `sagemaker.get_execution_role()` can be used instead.
account_id = boto3.client('sts').get_caller_identity().get('Account')
role = "arn:aws:iam::{}:role/service-role/AmazonSageMaker-ExecutionRole-20181222T162635".format(
    account_id
)


2022-07-23 14:56:34,348 - botocore.credentials - INFO - Found credentials in shared credentials file: ~/.aws/credentials
2022-07-23 14:56:34,455 - botocore.credentials - INFO - Found credentials in shared credentials file: ~/.aws/credentials


In [35]:
data_bucket = "aegovan-data"
data_bucket_prefix = "chemprotner"

# Root S3 location for this project; every URI below is nested under it.
s3_uri_data = "s3://{}/{}".format(data_bucket, data_bucket_prefix)

# Training split: abstracts and their entity annotations.
s3_uri_train = s3_uri_data + "/train/chemprot_training_abstracts.tsv"
s3_uri_train_classes = s3_uri_data + "/train/chemprot_training_entities.tsv"

# Validation (development) split.
s3_uri_val = s3_uri_data + "/val/chemprot_development_abstracts.tsv"
s3_uri_val_classes = s3_uri_data + "/val/chemprot_development_entities.tsv"

# Test split.
s3_uri_test = s3_uri_data + "/test/chemprot_test_abstracts_gs.tsv"
s3_uri_test_classes = s3_uri_data + "/test/chemprot_test_entities_gs.tsv"

# Locations for training outputs, uploaded code and checkpoints.
s3_output_path = s3_uri_data + "/output"
s3_code_path = s3_uri_data + "/code"
s3_checkpoint = s3_uri_data + "/checkpoint"

## Train

This section shows how to train BERT on SageMaker, optionally on managed Spot instances (set `use_spot = True` below).

In [6]:
# SageMaker input channels -> S3 locations. NOTE(review): the channel names
# ("train", "class", "val", "valclass") presumably match what src/main.py
# reads; verify there before renaming.
inputs_full = dict(
    [
        ("train", s3_uri_train),
        ("class", s3_uri_train_classes),
        ("val", s3_uri_val),
        ("valclass", s3_uri_val_classes),
    ]
)

# Channels actually used for training.
inputs = inputs_full

In [7]:
# Checkpoint directory inside the training container; passed both to the
# estimator (`checkpoint_local_path`) and to the script (`checkpointdir`).
sm_localcheckpoint_dir = "/opt/ml/checkpoints/"

In [8]:
instance_type = "ml.p3.2xlarge"

# GPUs per supported training instance type; used to scale the batch size.
instance_type_gpu_map = {
    "ml.p3.8xlarge": 4,
    "ml.p3.2xlarge": 1,
    "ml.p3.16xlarge": 8,
}

In [9]:
# Hyperparameters for the training entry point (src/main.py).
hp = {
    "epochs": 50,
    "earlystoppingpatience": 10,
    # Per-step batch size, scaled by the number of GPUs on the instance.
    # A larger batch may cause CUDA OOM errors; increase "gradaccumulation" instead.
    # Instance types missing from `instance_type_gpu_map` (e.g. 'local') are
    # treated as single-GPU instead of raising a KeyError.
    "batch": 8 * instance_type_gpu_map.get(instance_type, 1),
    # The script receives file names; the files arrive via the input channels.
    "trainfile": s3_uri_train.split("/")[-1],
    "classfile": s3_uri_train_classes.split("/")[-1],
    "valfile": s3_uri_val.split("/")[-1],
    "valclassfile": s3_uri_val_classes.split("/")[-1],
    "datasetfactory": "datasets.chemprot_dataset_factory.ChemprotDatasetFactory",
    # Number of steps to accumulate gradients over.
    "gradaccumulation": 4,
    "log-level": "INFO",
    # Must not exceed the model's maximum position-embedding size; large values
    # may also cause CUDA OOM errors.
    "maxseqlen": 512,
    # Keep the learning rate small, since this fine-tunes a pretrained model.
    "lr": 0.00001,
    # Set to 1 to train only the weights of the final classification layer.
    "finetune": 0,
    "checkpointdir": sm_localcheckpoint_dir,
    # Save a checkpoint once every n epochs.
    "checkpointfreq": 2,
}



In [10]:
# Display the hyperparameters for review before launching the job.
hp

{'epochs': 50,
 'earlystoppingpatience': 10,
 'batch': 8,
 'trainfile': 'chemprot_training_abstracts.tsv',
 'classfile': 'chemprot_training_entities.tsv',
 'valfile': 'chemprot_development_abstracts.tsv',
 'valclassfile': 'chemprot_development_entities.tsv',
 'datasetfactory': 'datasets.chemprot_dataset_factory.ChemprotDatasetFactory',
 'gradaccumulation': 4,
 'log-level': 'INFO',
 'maxseqlen': 512,
 'lr': 1e-05,
 'finetune': 0,
 'checkpointdir': '/opt/ml/checkpoints/',
 'checkpointfreq': 2}

In [11]:
# Display the S3 input channels for review.
inputs

{'train': 's3://aegovan-data/chemprotner/train/chemprot_training_abstracts.tsv',
 'class': 's3://aegovan-data/chemprotner/train/chemprot_training_entities.tsv',
 'val': 's3://aegovan-data/chemprotner/val/chemprot_development_abstracts.tsv',
 'valclass': 's3://aegovan-data/chemprotner/val/chemprot_development_entities.tsv'}

In [12]:
# Regexes SageMaker uses to extract training metrics from the job's logs.
# NOTE(review): assumes the training script logs lines such as
# "###score: train_loss### 0.123"; the capture group holds the value.
# Raw strings keep `\d` literal; plain strings trigger an "invalid escape
# sequence" warning on recent Python versions.
metric_definitions = [
    {"Name": "TrainLoss", "Regex": r"###score: train_loss### (\d*[.]?\d*)"},
    {"Name": "ValidationLoss", "Regex": r"###score: val_loss### (\d*[.]?\d*)"},
    {"Name": "TrainScore", "Regex": r"###score: train_score### (\d*[.]?\d*)"},
    {"Name": "ValidationScore", "Regex": r"###score: val_score### (\d*[.]?\d*)"},
]

In [13]:
# Set True to train on managed Spot instances (cheaper, but interruptible).
use_spot = False
# Maximum training time: 2 days.
train_max_run_secs = 2 * 24 * 60 * 60
# Extra time allowed to wait for Spot capacity, on top of the training time.
spot_wait_sec = 5 * 60

# Spot does not apply in local mode.
# TODO: also switch to a smaller dataset in local mode (not implemented yet).
if instance_type == 'local':
    use_spot = False
    # Not read by the `estimator.fit(..., wait=True)` call; kept for compatibility.
    wait = True

# Computed only after the local-mode override, so it always agrees with
# `use_spot`: a wait budget for Spot training, None (unset) otherwise.
max_wait_time_secs = train_max_run_secs + spot_wait_sec if use_spot else None
   

In [14]:
job_type = "chemprot-ner-bert"
# Prefix for training job names; SageMaker appends a timestamp
# (e.g. "chemprot-ner-bert-2022-07-23-21-56-34-969").
base_name = job_type

In [None]:
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    # Training code: src/main.py plus the rest of src/, uploaded to S3.
    entry_point='main.py',
    source_dir='src',
    code_location=s3_code_path,
    hyperparameters=hp,
    # Framework container.
    framework_version="1.4.0",
    py_version='py3',
    # Compute and time limits.
    role=role,
    instance_type=instance_type,
    instance_count=1,
    volume_size=30,
    use_spot_instances=use_spot,
    max_run=train_max_run_secs,
    max_wait=max_wait_time_secs,
    # Outputs, metrics and checkpoints.
    output_path=s3_output_path,
    metric_definitions=metric_definitions,
    checkpoint_s3_uri=s3_checkpoint,
    checkpoint_local_path=sm_localcheckpoint_dir,
    debugger_hook_config=False,
    base_job_name=base_name,
)

# Launch the training job and wait for it to finish.
estimator.fit(inputs, wait=True)

## Deploy BERT model

#### Inference container
Ideally, the serving container should already have all the required dependencies installed, to reduce start-up time and ensure that the runtime environment is consistent. This can be implemented using a custom Docker image.

For this demo, to keep things simple, we let the PyTorch container (in script mode) install the dependencies during start-up. As a result, some of the initial ping requests will fail until all dependencies are installed.


In [25]:
import sagemaker

# Re-attach to a previously completed training job so its model can be
# deployed without re-training (e.g. after a kernel restart).
# Replace the job name with the one from your own run.
training_job = "chemprot-ner-bert-2022-07-23-21-56-34-969"
estimator = sagemaker.estimator.Estimator.attach(training_job_name=training_job)


2022-07-23 22:40:11 Starting - Preparing the instances for training
2022-07-23 22:40:11 Downloading - Downloading input data
2022-07-23 22:40:11 Training - Training image download completed. Training in progress.
2022-07-23 22:40:11 Uploading - Uploading generated training model
2022-07-23 22:40:11 Completed - Training job completed


In [26]:
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role

# Location of the model artefact produced by the training job.
model_uri = estimator.model_data

# Serve with src/serve.py on the same PyTorch 1.4.0 / py3 framework used for training.
model = PyTorchModel(
    entry_point='serve.py',
    source_dir='src',
    model_data=model_uri,
    role=role,
    framework_version='1.4.0',
    py_version="py3",
)

# Creates the model, an endpoint configuration and a single-instance endpoint.
predictor = model.deploy(instance_type='ml.p3.2xlarge', initial_instance_count=1)

2022-07-23 17:13:40,137 - sagemaker - INFO - Creating model with name: pytorch-inference-2022-07-24-00-13-40-137
2022-07-23 17:13:41,044 - sagemaker - INFO - Creating endpoint-config with name pytorch-inference-2022-07-24-00-13-41-043
2022-07-23 17:13:41,193 - sagemaker - INFO - Creating endpoint with name pytorch-inference-2022-07-24-00-13-41-043
--------!

### Invoke API

In [27]:
import json


class Predictor:
    """Pass-through serializer and JSON deserializer for a SageMaker predictor.

    Instances are assigned as ``predictor.serializer`` and
    ``predictor.deserializer``. Request payloads are sent unchanged (callers
    already encode them to bytes), and response bodies are parsed as UTF-8 JSON.
    """

    # Content types the SageMaker SDK can read from the (de)serializer when they
    # are not passed explicitly via `initial_args` in `predict(...)`.
    CONTENT_TYPE = "text/csv"
    ACCEPT = "text/json"

    def serialize(self, x):
        """Return the request payload ``x`` unchanged."""
        return x

    def deserialize(self, x, content_type):
        """Parse the JSON response stream ``x`` and close it.

        :param x: file-like response body exposing ``read()`` and ``close()``.
        :param content_type: response content type (unused; the body is always
            decoded as UTF-8 JSON).
        :return: the decoded JSON value.
        """
        try:
            payload = json.loads(x.read().decode("utf-8"))
        finally:
            # Close the stream even if decoding fails, so it is not left open.
            x.close()
        return payload

In [28]:
# Route requests and responses through the pass-through/JSON `Predictor` codec.
# To reuse an already-deployed endpoint instead, first create
# `predictor = sagemaker.predictor.Predictor("<endpoint name>")`.
predictor.serializer, predictor.deserializer = Predictor(), Predictor()


In [29]:
# Download the test abstracts into the local scratch directory.
test_local_dir = local_temp
sagemaker.s3.S3Downloader.download(s3_uri=s3_uri_test, local_path=test_local_dir)
test_local_file = os.path.join(test_local_dir, s3_uri_test.rsplit("/", 1)[-1])

In [36]:
# Also download the test entity annotation file next to the abstracts.
sagemaker.s3.S3Downloader.download(s3_uri=s3_uri_test_classes, local_path=test_local_dir)

In [30]:
import json, csv

from datasets.chemprot_dataset import ChemprotDataset
from datasets.chemprot_ner_label_mapper import ChemprotNerLabelMapper

def chunk(l, size=5):
    """Yield consecutive slices of ``l`` holding at most ``size`` items each."""
    for start in range(0, len(l), size):
        yield l[start:start + size]


def _read_abstracts(path):
    """Return ``(ids, texts)`` read from a tab-separated abstracts file, skipping blank lines."""
    ids, docs = [], []
    # newline="" is what the csv module expects; quotechar=None turns quote
    # handling off, so quote characters inside abstracts are kept verbatim.
    with open(path, "r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f, delimiter='\t', quotechar=None):
            if not row:
                continue
            # Columns: doc id, title, abstract.
            # NOTE(review): title and abstract are joined with no separator, so the
            # title's last word runs into the abstract's first word. Kept as-is in
            # case the training dataset builds its text the same way; confirm.
            ids.append(row[0])
            docs.append(row[1] + row[2])
    return ids, docs


def predict(test_local_file, output_file, batch_size=5, endpoint_predictor=None):
    """Run endpoint predictions for every abstract in a tab-separated file.

    Each row of ``test_local_file`` is ``doc id <TAB> title <TAB> abstract``;
    title and abstract are concatenated into one text. Texts are sent to the
    endpoint in batches of ``batch_size`` as newline-joined UTF-8 bytes, and one
    ``{"docid", "text", "entities_detected"}`` record per document is written to
    ``output_file`` as a JSON list.

    :param test_local_file: path of the tab-separated abstracts file.
    :param output_file: path of the JSON file to write.
    :param batch_size: number of documents per endpoint request.
    :param endpoint_predictor: object with a SageMaker-style
        ``predict(data, initial_args=...)`` method; defaults to the notebook's
        global ``predictor``.
    :raises AssertionError: if the endpoint returns a different number of
        results than documents sent.
    """
    if endpoint_predictor is None:
        endpoint_predictor = predictor

    ids, docs = _read_abstracts(test_local_file)

    result = []
    for batch_ids, batch_docs in zip(chunk(ids, batch_size), chunk(docs, batch_size)):
        data_bytes = "\n".join(batch_docs).encode("utf-8")
        response = endpoint_predictor.predict(
            data_bytes,
            initial_args={"Accept": "text/json", "ContentType": "text/csv"},
        )

        # The endpoint must answer with exactly one result per document sent.
        assert len(response) == len(batch_docs), "Data size {} doesnt match result size {}".format(
            len(batch_docs), len(response)
        )

        for doc_id, text, entities in zip(batch_ids, batch_docs, response):
            result.append({"docid": doc_id, "text": text, "entities_detected": entities})

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(result, f)
        


# Run the endpoint over all test abstracts and save the predictions locally.
results_json_file = os.path.join(local_temp, "result.json")
predict(test_local_file, output_file=results_json_file)

## Delete endpoint

In [31]:
# Clean up every resource created by `model.deploy(...)`. delete_model() runs
# first because it locates the model through the endpoint's configuration,
# which delete_endpoint() removes.
predictor.delete_model()
predictor.delete_endpoint()

2022-07-23 17:18:37,879 - sagemaker - INFO - Deleting endpoint configuration with name: pytorch-inference-2022-07-24-00-13-41-043
2022-07-23 17:18:38,049 - sagemaker - INFO - Deleting endpoint with name: pytorch-inference-2022-07-24-00-13-41-043
