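### SageMaker self-supervised prediction

Runs batch prediction for each self-supervised training job's model as a SageMaker Processing job, orchestrated by an AWS Step Functions workflow.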

In [1]:
import sagemaker
import boto3

# SageMaker/AWS session context for the current account and region.
sagemaker_session = sagemaker.Session()
account_id = boto3.client("sts").get_caller_identity().get("Account")
region = boto3.session.Session().region_name

# IAM roles are built from fixed role names in the current account.
# sagemaker.get_execution_role() is an alternative when an execution role is attached.
role = f"arn:aws:iam::{account_id}:role/service-role/AmazonSageMaker-ExecutionRole-20190118T115449"
step_func_role = f"arn:aws:iam::{account_id}:role/AmazonSageMaker-StepFunctionsWorkflowExecutionRole"
step_func_role = "arn:aws:iam::{}:role/AmazonSageMaker-StepFunctionsWorkflowExecutionRole".format(account_id)

In [2]:
# Custom inference image, as "<repository>:<tag>" in ECR.
version_tag = "202211100421"
pytorch_custom_image_name = "large-scale-ptm-ppi:gpu-" + version_tag

In [3]:
# Fully-qualified ECR image URI for the account/region resolved in the first cell.
docker_repo = f"{account_id}.dkr.ecr.{region}.amazonaws.com/{pytorch_custom_image_name}"

In [4]:
# S3 bucket holding the input data, models, vocab and prediction outputs.
bucket = 'aegovan-data'

In [5]:

# Candidate evaluation corpora (S3 prefixes).
abstract_fake = f"s3://{bucket}/self-supervised-fake/8217_767_1593/"
abstract_largescale = f"s3://{bucket}/selfsupervisedlargescale/pubmedabstracts/"

# Corpus to predict on, and the value passed verbatim to the script's --filepattern.
eval_file = abstract_largescale
filepattern = "{}/*.tsv"  # alternative: "{}/*.json"

# Processing cluster; "ml.p3.2xlarge" is an alternative instance type.
instance_type = "ml.g4dn.2xlarge"
instance_count = 5

In [6]:
import datetime

# Minute-resolution local timestamp (YYYYMMDDHHMM); it is embedded in the output
# prefix, job names and workflow name so each run gets distinct names.
date_fmt = f"{datetime.datetime.now():%Y%m%d%H%M}"

In [7]:
# Names of the training jobs whose models are scored: m_val in 0, 50, ..., 400 and
# rep in 00..04 (presumably replicate runs — confirm), ordered m_val-major.
training_jobs = [
    f"selfsup-fake-2000-500-{m_val}-bert-f1-{rep:02d}-202211241448"
    for m_val in range(0, 401, 50)
    for rep in range(5)
]

### Run prediction

In [8]:
# Data to score: an S3 prefix (all objects under it are inputs).
s3_input_data = eval_file
s3_data_type = "S3Prefix"

# Script options, passed as --filter and --filterstdthreshold.
usefilter = 0
filter_threshold_std = 1.0

# Mounted at the vocab path and passed as --tokenisor_data_dir.
s3_input_vocab = f"s3://{bucket}/embeddings/bert/"

In [9]:
# Display the input location and its S3 data type for a quick sanity check.
(s3_input_data, s3_data_type)

('s3://aegovan-data/selfsupervisedlargescale/pubmedabstracts/', 'S3Prefix')

In [10]:
from sagemaker.network import NetworkConfig
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.processing import ScriptProcessor
# Explicit imports: `from stepfunctions.steps import *` also binds the submodule
# `stepfunctions.steps.sagemaker` as `sagemaker`, silently shadowing the SageMaker
# SDK module imported in the first cell.
from stepfunctions.steps import Chain, Parallel, ProcessingStep, Wait
from stepfunctions.workflow import Workflow
import random

processing_steps = []

PREPROCESSING_SCRIPT_LOCATION = "../src/inference/selfsupervised_batch_predict.py"

# Upload the inference script once; every step mounts it as its "code" input.
input_code = sagemaker_session.upload_data(
    PREPROCESSING_SCRIPT_LOCATION,
    bucket=sagemaker_session.default_bucket(),
    key_prefix="code/chemprot-adverserial/code",
)

# Container-local mount points; identical for every job.
sm_local_input_models = "/opt/ml/processing/input/data/models"
sm_local_input_data = "/opt/ml/processing/input/data/jsondata"
sm_local_input_vocab = "/opt/ml/processing/input/data/vocab"
sm_local_output = "/opt/ml/processing/output"

# Step/job names keep only the last 31 chars of each training job name, so
# "selfsup-infer-<31>-<12-digit date_fmt>" is 58 chars, within SageMaker's 63-char
# job-name limit. Step Functions state names must be unique, so fail early if the
# truncation ever makes two names collide.
name_suffixes = [j[-31:] for j in training_jobs]
assert len(set(name_suffixes)) == len(name_suffixes), "truncated training job names collide"

for j, name_suffix in zip(training_jobs, name_suffixes):
    s3_input_models = f"s3://{bucket}/selfsupervised_results/{j}/output/model.tar.gz"
    s3_output_predictions = f"s3://{bucket}/selfsupervised_chemprot/predictions_{j}/{date_fmt}"

    # One processor per step, as before.
    script_processor = ScriptProcessor(
        image_uri=docker_repo,
        command=["python"],
        env={"mode": "python", "PYTHONPATH": "/opt/ml/code"},
        role=role,
        instance_type=instance_type,
        instance_count=instance_count,
        max_runtime_in_seconds=5 * 24 * 60 * 60,  # 5 days
        volume_size_in_gb=200,
        network_config=NetworkConfig(enable_network_isolation=False),
    )

    step_processing = ProcessingStep(
        f"selfsup-infer-{name_suffix}",
        script_processor,
        f"selfsup-infer-{name_suffix}-{date_fmt}",
        container_entrypoint=["python", "/opt/ml/processing/input/code/selfsupervised_batch_predict.py"],
        container_arguments=[
            sm_local_input_data,
            sm_local_input_models,
            sm_local_output,
            "--ensemble", "0",
            "--tokenisor_data_dir", sm_local_input_vocab,
            "--filter", str(usefilter),
            "--batch", "32",
            "--filterstdthreshold", str(filter_threshold_std),
            "--filepattern", filepattern,
        ],
        inputs=[
            # Input objects are split across instances by S3 key.
            ProcessingInput(
                source=s3_input_data,
                input_name="input-1",
                s3_data_type=s3_data_type,
                destination=sm_local_input_data,
                s3_data_distribution_type="ShardedByS3Key",
            ),
            # Model and vocab are copied in full to every instance.
            ProcessingInput(
                source=s3_input_models,
                input_name="model",
                destination=sm_local_input_models,
                s3_data_distribution_type="FullyReplicated",
            ),
            ProcessingInput(
                source=s3_input_vocab,
                input_name="vocab",
                destination=sm_local_input_vocab,
                s3_data_distribution_type="FullyReplicated",
            ),
            ProcessingInput(
                source=input_code,
                destination="/opt/ml/processing/input/code",
                input_name="code",
            ),
        ],
        outputs=[
            ProcessingOutput(
                source=sm_local_output,
                destination=s3_output_predictions,
                output_name="predictions",
            )
        ],
    )
    processing_steps.append(step_processing)

# Run the steps in sequential waves of `max_parallel`. Each wave is a Parallel state
# whose branches start after staggered waits (30 s, 60 s, ...).
parallel_steps = []
max_parallel = 2
for i in range(0, len(processing_steps), max_parallel):
    p = Parallel(f"predict-p-{i}")
    for si, s in enumerate(processing_steps[i: i + max_parallel]):
        w = (si + 1) * 30
        p.add_branch(Chain([Wait(f"wait-{i}-{w}", seconds=w), s]))
    parallel_steps.append(p)

basic_path = Chain(parallel_steps)

basic_workflow = Workflow(
    name=f"selfsup-adver-infer-{date_fmt}", definition=basic_path, role=step_func_role
)

# Render the workflow
basic_workflow.render_graph()



In [11]:
# Register the state machine with AWS Step Functions; the call returns its ARN.
basic_workflow.create()

'arn:aws:states:us-east-2:324346001917:stateMachine:selfsup-adver-infer-202211261427'

In [12]:
# Start an execution of the workflow, which runs the processing steps defined above.
basic_workflow.execute()