### Set up

#### 1. Set up accounts and role

In [1]:
import sagemaker
import boto3
from datetime import datetime

# Create a SageMaker SDK session.
sagemaker_session = sagemaker.Session()

# Account id of the calling credentials (looked up through STS) and the
# region configured for the default boto3 session.
account_id = (
    boto3.client("sts")
    .get_caller_identity()
    .get("Account")
)
region = boto3.session.Session().region_name

# IAM role handed to the estimator. The role name is fixed; only the account
# id is substituted in.
# Alternative: role = sagemaker.get_execution_role()
role = f"arn:aws:iam::{account_id}:role/service-role/AmazonSageMaker-ExecutionRole-20190118T115449"

# NOTE(review): not referenced by any other visible cell — confirm it is still needed.
max_runs = 1

#### 2. Set up image and instance type

In [2]:
# Custom image name, currently disabled:
# pytorch_custom_image_name="ppi-extractor:gpu-1.0.0-201910130520"

# Instance type passed to the estimator. The spot/wait configuration cell
# treats the value 'local' as local mode.
instance_type = "ml.p3.2xlarge"

In [3]:
# docker_repo = "{}.dkr.ecr.{}.amazonaws.com/{}".format(account_id, region, pytorch_custom_image_name)

#### 3. Configure train, test and validation datasets

In [4]:
# S3 bucket used to build the pretrained-model, output, code and checkpoint paths.
bucket = 'aegovan-data'

In [5]:
# Dataset snapshot to train on, relative to s3://<bucket>/counterfactuals/imdb/.
# Snapshots used by earlier runs of this notebook (swap one in to reproduce it):
#   202304021657/24_0_0
#   202304081856/9_0_0
#   202304091143/original
#   202304091342/0_0_0
#   202304091853/9_0_0
#   202304091926/0_0_0
dataset_snapshot = "202304092106/23_0_0"

# Every S3 location below is derived from `bucket`, so changing the bucket in
# one place moves the data, model, output, code and checkpoint paths together.
# The job-name cell takes its suffix from the second-to-last path segment of
# `train`, i.e. the last component of `dataset_snapshot`.
dataset_prefix = "s3://{}/counterfactuals/imdb/{}".format(bucket, dataset_snapshot)
train = "{}/train.json".format(dataset_prefix)
val = "{}/val.json".format(dataset_prefix)

# Pretrained BERT weights, passed to the job as the PRETRAINED_MODEL input.
pretrained_bert = "s3://{}/pretrained_models/bert-base-uncased/".format(bucket)

# Where the job writes its results and where the SDK uploads the source code.
s3_output_path = "s3://{}/sagemakerresults/".format(bucket)
s3_code_path = "s3://{}/counterfactuals_imdb_bert_code".format(bucket)

# Checkpoint prefix with a timestamp suffix (month-day-year-hour-minute-second),
# evaluated each time this cell runs, so every execution gets a new location.
s3_checkpoint = "s3://{}/counterfactuals_imdb_checkpoint/{}".format(
    bucket, datetime.now().strftime("%m%d%Y%H%M%S")
)

### Start training

In [6]:
# Mapping handed to estimator.fit(): input name -> S3 location.
inputs = dict(
    train=train,
    val=val,
    PRETRAINED_MODEL=pretrained_bert,
)

In [7]:
# Local checkpoint directory; used both as the `checkpointdir` hyperparameter
# and as the estimator's `checkpoint_local_path`.
sm_localcheckpoint_dir = '/opt/ml/checkpoints/'


In [8]:
# Hyperparameters for the BERT experiment. Value types are kept exactly as
# entered (a mix of strings and numbers).
BertNetworkFactoryhyperparameters = {
    # Dotted paths of the dataset and model factories
    # (presumably resolved by the training code, which is not shown here).
    "datasetfactory": "datasets.counterfact_imbd_dataset_factory.CounterfactImdbDatasetFactory",
    "modelfactory": "models.bert_model_factory.BertModelFactory",
    "tokenisor_lower_case": 1,
    "batch": "8",
    "gradientaccumulationsteps": "8",
    "epochs": "100",
    "log-level": "INFO",
    "learningrate": 1e-05,
    "earlystoppingpatience": 9,
    "checkpointdir": sm_localcheckpoint_dir,
    # Checkpoints once every n epochs
    "checkpointfreq": 2,
}

In [9]:
# Metric definitions given to the estimator: each entry names a metric and the
# regex whose first capture group holds its value.
#
# The patterns are raw strings so that `\d` reaches the regex engine as written
# instead of being treated as an (invalid) string escape sequence.
#
# NOTE(review): the value pattern `\d*[.]?\d*` also matches an empty value and
# does not cover signs or exponent notation (e.g. 1e-05) — confirm the training
# script never prints scores in those forms.
metric_definitions = [
    {
        "Name": "TrainLoss",
        "Regex": r"###score: train_loss### (\d*[.]?\d*)",
    },
    {
        "Name": "ValidationLoss",
        "Regex": r"###score: val_loss### (\d*[.]?\d*)",
    },
    {
        "Name": "TrainAccuracyScore",
        "Regex": r"###score: train_ResultScorerAccuracy_score### (\d*[.]?\d*)",
    },
    {
        "Name": "ValidationAccuracyScore",
        "Regex": r"###score: val_ResultScorerAccuracy_score### (\d*[.]?\d*)",
    },
]

In [10]:
!git log -1 | head -1
!git log -1 | head -5 | tail -1

commit bdd410f4107aea31ab3ecb6c916190b4306be5ec
    Update notebook - repeat adv 0.10


In [11]:
# set True if you need spot instance
use_spot = False
train_max_run_secs = 5 * 24 * 60 * 60  # 5 days
spot_wait_sec = 5 * 60  # 5 minutes

# The wait limit is only set for spot training; otherwise it stays None.
max_wait_time_secs = (train_max_run_secs + spot_wait_sec) if use_spot else None

# Default so that `wait` is bound whichever branch runs below.
wait = False

# During local mode, no spot.., use smaller dataset
if instance_type == 'local':
    use_spot = False
    max_wait_time_secs = 0
    wait = True
    # Use smaller dataset to run locally. `inputs_sample` is not defined by any
    # cell visible in this notebook, so fall back to the full `inputs` with an
    # explicit message instead of failing with a NameError.
    if "inputs_sample" in globals():
        inputs = inputs_sample
    else:
        print("inputs_sample is not defined; local mode will use the full inputs")

In [12]:
# Experiment registry keyed by experiment name; each entry bundles the
# hyperparameters ("hp") with the input locations ("inputs").
experiments = {
    "counterfact-imdb": dict(
        hp=BertNetworkFactoryhyperparameters,
        inputs=inputs,
    ),
}

In [13]:

# Pick the experiment to launch and unpack its hyperparameters and inputs.
base_name = "counterfact-imdb"

hyperparameters, inputs = (
    experiments[base_name][field] for field in ("hp", "inputs")
)

In [14]:
# Display the hyperparameters selected for this run.
hyperparameters

{'datasetfactory': 'datasets.counterfact_imbd_dataset_factory.CounterfactImdbDatasetFactory',
 'modelfactory': 'models.bert_model_factory.BertModelFactory',
 'tokenisor_lower_case': 1,
 'batch': '8',
 'gradientaccumulationsteps': '8',
 'epochs': '100',
 'log-level': 'INFO',
 'learningrate': 1e-05,
 'earlystoppingpatience': 9,
 'checkpointdir': '/opt/ml/checkpoints/',
 'checkpointfreq': 2}

In [15]:
# Display the input locations selected for this run.
inputs

{'train': 's3://aegovan-data/counterfactuals/imdb/202304092106/23_0_0/train.json',
 'val': 's3://aegovan-data/counterfactuals/imdb/202304092106/23_0_0/val.json',
 'PRETRAINED_MODEL': 's3://aegovan-data/pretrained_models/bert-base-uncased/'}

In [16]:
# Job name prefix: the experiment name joined with the second-to-last path
# segment of the training data location, with underscores replaced by hyphens.
job_base_name = "-".join(
    (base_name, inputs["train"].rsplit("/", 2)[-2])
).replace("_", "-")

In [17]:


from sagemaker.pytorch import PyTorch

# Build the PyTorch estimator from the values configured in the cells above.
estimator = PyTorch(
    # --- training code ---
    entry_point="main_train_pipeline.py",
    source_dir="../src",
    dependencies=[
        "../src/datasets",
        "../src/models",
        "../src/utils",
        "../src/scorers",
    ],
    code_location=s3_code_path,
    # Disabled options, kept for reference:
    #   git_config=git_config,
    #   image_name=docker_repo,
    # --- runtime ---
    role=role,
    framework_version="1.4.0",
    py_version="py3",
    instance_count=1,
    instance_type=instance_type,
    volume_size=30,
    # --- job configuration ---
    base_job_name=job_base_name,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    output_path=s3_output_path,
    debugger_hook_config=False,
    # --- spot / time limits ---
    use_spot_instances=use_spot,
    max_run=train_max_run_secs,
    max_wait=max_wait_time_secs,
    # --- checkpointing ---
    checkpoint_s3_uri=s3_checkpoint,
    checkpoint_local_path=sm_localcheckpoint_dir,
)

# Submit the training job without blocking the notebook on its completion.
estimator.fit(inputs, wait=False)