# Train self-supervised BERT

### Set up

#### 1. Set up account and roles

In [1]:
import sagemaker
import boto3
from uuid import uuid4

# AWS session, account id and region for this notebook.
sagemaker_session = sagemaker.Session()
caller_identity = boto3.client('sts').get_caller_identity()
account_id = caller_identity.get('Account')
region = boto3.session.Session().region_name

# IAM roles, hard-coded for this account: `role` runs the training jobs and
# `step_func_role` runs the Step Functions workflow.
# Inside a SageMaker notebook instance, `sagemaker.get_execution_role()` could
# be used for `role` instead.
role = f"arn:aws:iam::{account_id}:role/service-role/AmazonSageMaker-ExecutionRole-20190118T115449"
step_func_role = f"arn:aws:iam::{account_id}:role/AmazonSageMaker-StepFunctionsWorkflowExecutionRole"

# NOTE(review): not referenced anywhere else in this notebook.
max_runs = 1

#### 2. Set up image and instance type

In [2]:
# Custom training image (not used: the estimator below relies on the stock
# PyTorch framework image):
# pytorch_custom_image_name="ppi-extractor:gpu-1.0.0-201910130520"
instance_type = "ml.p3.2xlarge"
# GPUs per supported training instance type; used to size the training batch.
instance_type_gpu_map = {
    "ml.p3.8xlarge": 4,
    "ml.p3.2xlarge": 1,
    "ml.p3.16xlarge": 8,
}

In [3]:
# docker_repo = "{}.dkr.ecr.{}.amazonaws.com/{}".format(account_id, region, pytorch_custom_image_name)

#### 3. Configure train/test and validation datasets

In [4]:
# S3 bucket that holds the datasets, pretrained model, uploaded code and outputs.
bucket: str = "aegovan-data"

In [5]:
# Pretrained BERT weights (fed to the training jobs as the PRETRAINED_MODEL channel).
pretrained_bert = f"s3://{bucket}/embeddings/bert/"

# Where SageMaker writes model artifacts and uploads the training source code.
s3_output_path = f"s3://{bucket}/selfsupervised_results/"
s3_code_path = f"s3://{bucket}/selfsupervised_code"

# Unique checkpoint folder per notebook run. Currently only referenced by the
# commented-out checkpoint options of the estimator.
s3_checkpoint = f"s3://{bucket}/selfsupervised_bert_checkpoint/{uuid4()}"

In [6]:
# Dataset folders for the self-supervised experiments. Every `train*.json`
# object under a prefix becomes one experiment, validated on `<prefix>val.json`.
# The fake dataset is pinned to a timestamped fixed-size snapshot; built from
# `bucket` like every other S3 path in this notebook.
fake_prefix = "s3://{}/self-supervised-fake-fixed-size/202211232148/".format(bucket)

real_prefix = "s3://{}/self-supervised-real/".format(bucket)


### Start training

In [7]:
# Recorded with every training job via the `commit_id` hyperparameter.
# NOTE(review): the jobs upload code from ../src (source_dir) rather than
# checking out this commit, and the local HEAD printed further down is a
# different commit — confirm this id matches the code actually uploaded.
commit_id: str = "a6211b46f5940b9ac48fd3bde9274734ec3605a5"

In [8]:
# Checkpoint directory inside the training container; handed to the training
# script through the `checkpointdir` hyperparameter.
sm_localcheckpoint_dir: str = "/opt/ml/checkpoints/"

In [9]:
# Base hyperparameters for main_train_pipeline.py.
BertNetworkFactoryhyperparameters = {
    "datasetfactory": "datasets.chemprot_selfsupervised_dataset_factory.ChemprotSelfsupervisedDatasetFactory",
    "modelfactory": "models.bert_model_factory.BertModelFactory",
    "tokenisor_lower_case": 0,
    "uselosseval": 1,
    # 8 samples per GPU. Multiply numerically and then stringify: `"8" * n`
    # repeats the string ("8888" on a 4-GPU instance). Instance types missing
    # from the GPU map (e.g. "local" mode) fall back to a single device instead
    # of raising KeyError.
    "batch": str(8 * instance_type_gpu_map.get(instance_type, 1)),
    "gradientaccumulationsteps": "8",
    "epochs": "200",
    "log-level": "INFO",
    "learningrate": .00001,
    "earlystoppingpatience": 50,
    "checkpointdir": sm_localcheckpoint_dir,
    # Checkpoints once every n epochs
    "checkpointfreq": 2,
    "weight_decay": 0.01,
    "commit_id": commit_id,
}

In [10]:
# Same settings with `uselosseval` switched off. Presumably this makes the
# training script select/early-stop on F1 instead of validation loss (hence the
# `_max_f1` suffix) — confirm in main_train_pipeline.py.
BertNetworkFactoryhyperparameters_max_f1 = {**BertNetworkFactoryhyperparameters, "uselosseval": 0}

In [11]:
# Metrics SageMaker extracts from the training job log. Each regex captures the
# number printed after a "###score: <name>###" marker.
# Raw strings: "\d" inside a normal string literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning on Python 3.12+). The pattern text itself is
# unchanged.
# NOTE: the capture does not handle a leading minus sign or exponent notation
# such as "1e-05".
metric_definitions = [
    {"Name": "TrainLoss",
     "Regex": r"###score: train_loss### (\d*[.]?\d*)"},
    {"Name": "ValidationLoss",
     "Regex": r"###score: val_loss### (\d*[.]?\d*)"},
    {"Name": "TrainF1BinaryScore",
     "Regex": r"###score: train_ResultScorerF1Binary_score### (\d*[.]?\d*)"},
    {"Name": "ValidationF1BinaryScore",
     "Regex": r"###score: val_ResultScorerF1Binary_score### (\d*[.]?\d*)"},
]

In [12]:
# Show the local HEAD commit line and its subject (5th line of `git log -1`).
# NOTE(review): HEAD here differs from the `commit_id` hyperparameter set above,
# and the jobs upload ../src rather than a specific commit — confirm which
# revision the training jobs actually run.
!git log -1 | head -1
!git log -1 | head -5 | tail -1

commit 9208e130d8f1878e5e477219487e886ac6b9189e
    Update notebook


In [13]:
# set True if you need spot instance
use_spot = False
train_max_run_secs =   5 *24 * 60 * 60
spot_wait_sec =  5 * 60
max_wait_time_secs = train_max_run_secs +  spot_wait_sec

if not use_spot:
    max_wait_time_secs = None
    
# During local mode, no spot.., use smaller dataset
if instance_type == 'local':
    use_spot = False
    max_wait_time_secs = 0
    wait = True
    # Use smaller dataset to run locally
    # TODO:
    #  inputs = inputs_sample

In [14]:
def create_fake_hp(fake_prefix, prefix="fake", hp=None, pretrained_model=None, list_uris=None):
    """Build one training experiment per ``train*.json`` file under an S3 prefix.

    Despite its name this is used for both the "fake" and the "real"
    self-supervised datasets; ``prefix`` distinguishes them in experiment names.

    Args:
        fake_prefix: S3 URI of the dataset folder (a trailing "/" is added if
            missing). Every object ending in "json" whose *file name* contains
            "train" becomes one experiment; ``val.json`` directly under this
            prefix is the validation set for all of them.
        prefix: Tag embedded in the experiment name ``selfsup-{prefix}-{name}-bert-f1``.
        hp: Hyperparameter dict used for every experiment; defaults to
            ``BertNetworkFactoryhyperparameters_max_f1``. Each experiment gets
            its own shallow copy.
        pretrained_model: URI for the ``PRETRAINED_MODEL`` channel; defaults to
            ``pretrained_bert``.
        list_uris: Callable returning the object URIs under a prefix; defaults
            to ``sagemaker.s3.S3Downloader.list``. Injectable for testing.

    Returns:
        dict mapping experiment name -> {"hp": dict, "inputs": {channel: URI}},
        ordered by train file URI in descending order.
    """
    if list_uris is None:
        from sagemaker.s3 import S3Downloader
        list_uris = S3Downloader.list
    if hp is None:
        hp = BertNetworkFactoryhyperparameters_max_f1
    if pretrained_model is None:
        pretrained_model = pretrained_bert

    # Channel URIs are built by concatenation, so the prefix must end in "/".
    base_prefix = fake_prefix if fake_prefix.endswith("/") else fake_prefix + "/"

    def _file_name(uri):
        # Last path component of an S3 URI.
        return uri.rsplit("/", 1)[-1]

    # Match "train" on the file name only, so a bucket/folder name containing
    # "train" cannot pull in val.json or other unrelated files.
    train_files = sorted(
        (uri for uri in list_uris(base_prefix)
         if uri.endswith("json") and "train" in _file_name(uri)),
        reverse=True,
    )

    experiments = {}
    for train_uri in train_files:
        base_name = _file_name(train_uri).replace("train_", "").replace(".json", "")
        experiments[f"selfsup-{prefix}-{base_name}-bert-f1"] = {
            # Copy so that tweaking one experiment's hyperparameters cannot
            # silently change all of the others.
            "hp": dict(hp),
            "inputs": {
                "train": train_uri,
                "val": f"{base_prefix}val.json",
                "PRETRAINED_MODEL": pretrained_model,
            },
        }
    return experiments


# One experiment set per dataset variant; the second argument tags the names.
fake_experiments_hp, real_experiments_hp = (
    create_fake_hp(fake_prefix, "fake"),
    create_fake_hp(real_prefix, "real"),
)

In [15]:


# All experiments, fake first; on a name clash the real entry wins.
# NOTE(review): the launch cell below iterates over `fake_experiments_hp` only,
# so this merged dict is currently unused.
experiments = dict(fake_experiments_hp)
experiments.update(real_experiments_hp)

In [16]:
import datetime
# Local-time YYYYMMDDHHMM suffix that makes job and workflow names unique per run.
date_fmt = f"{datetime.datetime.today():%Y%m%d%H%M}"

In [17]:
from sagemaker.pytorch import PyTorch
from stepfunctions.steps import *
from stepfunctions.workflow import Workflow
import random

# The star import above also rebinds the name `sagemaker` to the
# `stepfunctions.steps.sagemaker` submodule, which breaks the SDK calls in the
# setup cell if it is re-run. Restore the SDK binding and refer to the step
# classes by their own names instead.
import sagemaker
from stepfunctions.steps import Chain, Parallel, TrainingStep, Wait

# Each experiment is launched `variations` times with identical settings; only
# the job name differs.
variations = 5
# Training steps are grouped into Parallel states of at most `max_parallel`
# branches; the groups run one after another.
max_parallel = 5
# SageMaker training job names are limited to 63 characters.
MAX_JOB_NAME_LEN = 63


def _make_estimator(experiment_name, hp):
    """Build the PyTorch estimator for one experiment run."""
    return PyTorch(
        entry_point='main_train_pipeline.py',
        source_dir='../src',
        dependencies=['../src/datasets', '../src/models', '../src/utils', '../src/scorers'],
        role=role,
        framework_version="1.4.0",
        py_version='py3',
        instance_count=1,
        instance_type=instance_type,
        hyperparameters=hp,
        output_path=s3_output_path,
        metric_definitions=metric_definitions,
        volume_size=30,
        code_location=s3_code_path,
        debugger_hook_config=False,
        base_job_name=experiment_name.replace("_", "-"),
        use_spot_instances=use_spot,
        max_run=train_max_run_secs,
        max_wait=max_wait_time_secs,
        # S3 checkpointing is disabled; to enable it pass
        # checkpoint_s3_uri=s3_checkpoint, checkpoint_local_path=sm_localcheckpoint_dir
    )


train_steps = []
# NOTE(review): only the "fake" experiments are launched; iterate over
# `experiments.items()` instead to include the real-data experiments too.
for exp_name, exp in fake_experiments_hp.items():
    print(f"Running {exp_name}")
    for i in range(variations):
        job_name = exp_name.replace("_", "-") + f"-{i:02d}-" + date_fmt
        # Fail here rather than when the workflow executes.
        assert len(job_name) <= MAX_JOB_NAME_LEN, f"Training job name too long: {job_name}"
        train_steps.append(
            TrainingStep(f"Train-{job_name}",
                         _make_estimator(exp_name, exp["hp"]),
                         job_name,
                         data=exp["inputs"])
        )

parallel_steps = []
for start in range(0, len(train_steps), max_parallel):
    group = Parallel(f"train-p-{start}")
    for offset, step in enumerate(train_steps[start:start + max_parallel]):
        # Random delay staggers the job launches inside a group. The state name
        # carries the step's global index so it stays unique even when two
        # branches draw the same delay (state names must be unique).
        delay = random.randint(10, 300)
        group.add_branch(Chain([Wait(f"wait-{start + offset}-{delay}", seconds=delay), step]))
    parallel_steps.append(group)

basic_path = Chain(parallel_steps)

basic_workflow = Workflow(
    name=f"selfsup-adver-train-{date_fmt}", definition=basic_path, role=step_func_role
)



Running selfsupervised-fake-2000_500_50-bert-f1
Running selfsupervised-fake-2000_500_400-bert-f1
Running selfsupervised-fake-2000_500_350-bert-f1
Running selfsupervised-fake-2000_500_300-bert-f1
Running selfsupervised-fake-2000_500_250-bert-f1
Running selfsupervised-fake-2000_500_200-bert-f1
Running selfsupervised-fake-2000_500_150-bert-f1
Running selfsupervised-fake-2000_500_100-bert-f1
Running selfsupervised-fake-2000_500_0-bert-f1


In [18]:
# Render the workflow: draws the Step Functions graph (sequential Parallel
# groups whose branches are Wait -> Train) inline in the notebook.
basic_workflow.render_graph()




In [19]:
# Create the state machine in AWS Step Functions; returns its ARN. The name
# embeds `date_fmt`, so each notebook run creates a new state machine.
basic_workflow.create()

'arn:aws:states:us-east-2:324346001917:stateMachine:selfsup-adver-train-202211232215'

In [20]:
# Start an execution of the state machine created above; this launches the
# SageMaker training jobs.
basic_workflow.execute()