In [None]:
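# Uncomment to install the local `sm_resnet` package (editable mode) into this kernel's environment:
# %pip install -e ./src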

In [None]:
import os
import json
import argparse
import sys
import warnings
from pathlib import Path
from ast import literal_eval
warnings.filterwarnings('ignore')

import torch
import torchvision as tv
import pytorch_lightning as pl
import webdataset as wds
from sm_resnet.models import ResNet
from sm_resnet.callbacks import PlSageMakerLogger, ProfilerCallback, SMDebugCallback
from sm_resnet import callbacks
from smdebug.core.reduction_config import ReductionConfig
from smdebug.core.save_config import SaveConfig
from smdebug.core.collection import CollectionKeys

# Distributed-launch coordinates read from the environment. When the variables
# are absent (e.g. a plain single-process kernel) fall back to a 1-process world
# where this process is rank 0 on local device 0.
world_size, rank, local_rank = (
    int(os.environ.get(env_var, fallback))
    for env_var, fallback in (("WORLD_SIZE", 1), ("RANK", 0), ("LOCAL_RANK", 0))
)

In [None]:
# Root S3 location of the training data. Override with the S3_BUCKET
# environment variable to point the notebook at a different bucket.
s3_bucket = os.environ.get("S3_BUCKET", "s3://jbsnyder-sagemaker-us-east/")

# S3 URIs always use "/" as the separator. os.path.join uses the host OS
# separator ("\\" on Windows), which would yield an invalid URI, so the
# paths are built with explicit "/" instead.
imagenet_root = s3_bucket.rstrip("/") + "/data/imagenet"

# Keyword arguments passed to sm_resnet.models.ResNet(**model_params).
model_params = {'num_classes': 1000,
                'resnet_version': 50,
                'train_path': f"{imagenet_root}/train",
                'val_path': f"{imagenet_root}/val",
                'optimizer': 'adamw',
                'lr': 0.004,
                'batch_size': 64,
                # NOTE(review): presumably the DataLoader worker count; 0 loads in the main process — confirm in ResNet.
                'dataloader_workers': 0,
                'max_epochs': 2,
                'warmup_epochs': 1,
                'mixup_alpha': 0.1
                }

# Keyword arguments passed to pl.Trainer(**trainer_params).
# NOTE(review): `gpus`, `progress_bar_refresh_rate` and `replace_sampler_ddp` are
# PyTorch Lightning 1.x arguments that were renamed/removed in later releases —
# confirm against the pinned Lightning version.
trainer_params = {'gpus': [local_rank],       # this process uses only the device given by LOCAL_RANK
                  # Derived from model_params so the trainer's epoch count and the
                  # model's configured epoch count cannot drift apart.
                  'max_epochs': model_params['max_epochs'],
                  'precision': 16,
                  'progress_bar_refresh_rate': 0,
                  # Lightning must not swap in its own sampler; presumably the
                  # webdataset pipeline handles sharding itself — confirm in ResNet.
                  'replace_sampler_ddp': False,
                  'callbacks': [PlSageMakerLogger()]
                  }

# SageMaker Debugger (smdebug) settings: save the loss and gradient collections
# as mean/std reductions (save_interval=25), write them under
# ./debugger_output, and also export TensorBoard data to ./tensorboard.
debugger_params = dict(
    out_dir=os.path.join(os.getcwd(), 'debugger_output'),
    export_tensorboard=True,
    tensorboard_dir=os.path.join(os.getcwd(), 'tensorboard'),
    reduction_config=ReductionConfig(reductions=['mean', 'std']),
    save_config=SaveConfig(save_interval=25),
    include_collections=[CollectionKeys.LOSSES, CollectionKeys.GRADIENTS],
    save_all=False,
)

# Register the debugger hook after the callbacks already in trainer_params.
trainer_params['callbacks'] += [SMDebugCallback(**debugger_params)]

In [None]:
# Instantiate the ResNet model and the Lightning trainer from the configuration
# dictionaries built in the previous cell (model first, then trainer).
model, trainer = ResNet(**model_params), pl.Trainer(**trainer_params)

In [None]:
# Run training of `model` under the trainer configured above.
trainer.fit(model=model)