### 1. Load the Environment Variables

In [13]:
# Deployment settings read by the later cells. IPython substitutes these Python
# names into `!` shell commands wherever `$NAME` appears.

# --- TPU target and container image ---
# Passed to xla_dist as --tpu.
TPU_POD_NAME = 'nyc-tpu-v3-32'
# Image is pulled from gcr.io/<PYTORCH_PROJ_NAME>/xla:<BUILD>.
PYTORCH_PROJ_NAME = 'pytorch-tpu-nfs'
BUILD = '5e452b42-a97c-40da-9a1a-5f2a5fc6ba34'

# --- Shared storage ---
# Volume source handed to `docker run --volume=<SHARED_FS>:<MOUNT_POINT>`.
SHARED_FS = 'tpushare'
# Root directory for the data, code and checkpoints used below.
MOUNT_POINT = '/mnt/common'
# Not referenced by the cells in this section.
NFS_IP = '10.224.68.26'

### 2. Load data from GCS bucket 

In [14]:
GCS_DATASET='gs://tpu-demo-eu/dataset/*'
!if [ -d "$MOUNT_POINT/data" ]; then echo "using existing $MOUNT_POINT/data directory"; else sudo mkdir -p $MOUNT_POINT/data && sudo gsutil -m cp -r $GCS_DATASET $MOUNT_POINT/data/; fi

### 3. Download RoBERTa code from repo

In [15]:
CODE_REPO='https://github.com/taylanbil/fairseq.git'
BRANCH='roberta-tpu'
!if [ -d "$MOUNT_POINT/code" ]; then echo "using existing $MOUNT_POINT/code directory"; else sudo mkdir -p $MOUNT_POINT/code && sudo git clone $CODE_REPO $MOUNT_POINT/code/; fi
!cd $MOUNT_POINT/code && git fetch && git checkout $BRANCH

### 4. Set the Execution Variables

In [16]:
# Literal string: Python does not expand the `$(date ...)` or `$1` parts here.
logfile = "$(date +%Y%m%d)-roberta-podrun-$1.txt"
nshards = 1
num_cores = 8
# Derived from MOUNT_POINT so the data locations follow the mount point setting;
# a "$MOUNT_POINT/..." string literal would not be interpolated by Python.
data_path = MOUNT_POINT + '/data'
# Binarized dataset directory handed to train.py as its positional argument.
DATABIN = data_path + '/shard0'
# NOTE(review): BUILD is appended directly after 'checkpoints-roberta' with no
# separator; kept as-is so existing checkpoint directories still match.
checkpoints_out = MOUNT_POINT + '/checkpoints/checkpoints-roberta' + BUILD

### 5. Execute torch_xla.distributed.xla_dist

In [None]:
!python -m torch_xla.distributed.xla_dist --tpu=$TPU_POD_NAME \
	--docker-run-flag=--shm-size=120GB \
	--docker-run-flag=--rm=true \
	--docker-run-flag=--volume=$SHARED_FS:$MOUNT_POINT \
	--env=XLA_USE_BF16=1  \
	--docker-image=gcr.io/$PYTORCH_PROJ_NAME/xla:$BUILD -- python $MOUNT_POINT/code/train.py \
	$DATABIN \
	--save-dir $checkpoints_out \
	--arch roberta_large \
	--optimizer adam \
	--adam-betas "(0.9, 0.98)" \
	--adam-eps 1e-06 \
	--clip-norm 1.0 \
	--lr-scheduler polynomial_decay \
	--lr 0.0004 \
	--warmup-updates 15000 \
	--max-update 1500000 \
	--log-format json \
	--log-interval 10 \
	--skip-invalid-size-inputs-valid-test \
	--task multilingual_masked_lm \
	--criterion masked_lm \
	--dropout 0.1 \
	--attention-dropout 0.1 \
	--weight-decay 0.01 \
	--sample-break-mode complete \
	--tokens-per-sample 512 \
	--total-num-update 1500000 \
	--multilang-sampling-alpha 0.7 \
	--no-epoch-checkpoints \
	--save-interval-updates 3000 \
	--validate-interval 5000 \
	--num-workers 1 \
	--update-freq `expr 4096 / 8` \
	--valid-subset=valid \
	--train-subset=train \
	--input_shapes 2x512 \
	--num_cores=8 \
	--metrics_debug \
	--suppress_loss_report \
	--log_steps=1