In [1]:
import sys
import warnings
from pathlib import Path
warnings.filterwarnings('ignore')

import torch
import torchvision as tv
import pytorch_lightning as pl
import webdataset as wds

In [2]:
class ResNet(pl.LightningModule):
    """ResNet image classifier packaged as a LightningModule.

    Instantiates one of the torchvision ResNet architectures (18, 34, 50,
    101 or 152 layers), replaces its final fully connected layer with one
    producing ``num_classes`` outputs, and reads training / validation
    samples through WebDataset pipelines built from ``train_path`` and
    ``val_path``.
    """

    def __init__(self, num_classes, resnet_version,
                 train_path, val_path, optimizer='adamw',
                 lr=1e-3, batch_size=64,
                 dataloader_workers=4,
                 *args, **kwargs):
        """Build the network, loss and optimizer factory.

        Args:
            num_classes: Output size of the replacement ``fc`` layer.
            resnet_version: One of 18, 34, 50, 101, 152; an unknown value
                raises ``KeyError``.
            train_path: Source passed to ``wds.WebDataset`` for training.
            val_path: Source passed to ``wds.WebDataset`` for validation.
            optimizer: ``'adam'``, ``'adamw'`` or ``'sgd'``; an unknown
                value raises ``KeyError``.
            lr: Learning rate handed to the optimizer.
            batch_size: Batch size used by both dataloaders.
            dataloader_workers: ``num_workers`` used by both dataloaders.
            *args, **kwargs: Accepted and stored, otherwise unused here.
        """
        super().__init__()

        # Explicit assignments instead of ``self.__dict__.update(locals())``:
        # that idiom also stored ``self`` and ``__class__`` on the instance
        # and bypassed ``nn.Module.__setattr__``.
        self.num_classes = num_classes
        self.resnet_version = resnet_version
        self.train_path = train_path
        self.val_path = val_path
        self.lr = lr
        self.batch_size = batch_size
        self.dataloader_workers = dataloader_workers
        # Kept so any code that read these attributes keeps working.
        self.args = args
        self.kwargs = kwargs

        resnets = {
            18: tv.models.resnet18,
            34: tv.models.resnet34,
            50: tv.models.resnet50,
            101: tv.models.resnet101,
            152: tv.models.resnet152,
        }

        optimizers = {
            'adam': torch.optim.Adam,
            'adamw': torch.optim.AdamW,
            'sgd': torch.optim.SGD,
        }

        # ``self.optimizer`` holds the optimizer *class*; it is instantiated
        # later in ``configure_optimizers``.
        self.optimizer = optimizers[optimizer]
        self.criterion = torch.nn.CrossEntropyLoss()

        self.model = resnets[resnet_version]()
        # Swap the classification head for one sized to ``num_classes``.
        linear_size = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(linear_size, num_classes)

    def forward(self, X):
        """Return the wrapped ResNet's output for input ``X``."""
        return self.model(X)

    def configure_optimizers(self):
        """Instantiate the selected optimizer over all parameters with ``self.lr``."""
        return self.optimizer(self.parameters(), lr=self.lr)

    def _make_dataloader(self, path, shuffle_buffer=1024):
        """Build a DataLoader over the WebDataset at ``path``.

        Each sample is read as a (``"jpeg"``, ``"cls"``) pair; the image is
        decoded to PIL, converted to a tensor, normalized and resized to
        224x224, while the label is passed through unchanged.
        """
        preproc = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.485, 0.456, 0.406),
                                    (0.229, 0.224, 0.225)),
            tv.transforms.Resize((224, 224)),
        ])
        dataset = (
            wds.WebDataset(path)
            .shuffle(shuffle_buffer)
            .decode("pil")
            .to_tuple("jpeg", "cls")
            .map_tuple(preproc, lambda x: x)
        )
        return torch.utils.data.DataLoader(dataset,
                                           num_workers=self.dataloader_workers,
                                           batch_size=self.batch_size)

    def train_dataloader(self):
        """Return the DataLoader over ``self.train_path``."""
        return self._make_dataloader(self.train_path)

    def val_dataloader(self):
        """Return the DataLoader over ``self.val_path``."""
        return self._make_dataloader(self.val_path)

    def _shared_step(self, batch):
        """Run the model on ``batch`` and return ``(loss, accuracy)``."""
        x, y = batch
        preds = self(x)
        loss = self.criterion(preds, y)
        # ``.float()`` keeps the tensor on its current device, whereas
        # ``.type(torch.FloatTensor)`` forced a copy to the CPU every step.
        acc = (y == torch.argmax(preds, 1)).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute and log training loss / accuracy; return the loss."""
        loss, acc = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation loss / accuracy and log them per epoch."""
        loss, acc = self._shared_step(batch)
        # Logged under "val_*" keys: the previous "train_*" names collided
        # with the metrics logged by ``training_step``.
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=False, logger=True)
    

In [3]:
class PlSageMakerLogger(pl.Callback):
    """Callback that prints the trainer's logged metrics to stdout.

    Metrics are printed on every ``frequency``-th training batch counted
    from the most recent epoch start, formatted as ``name: value`` pairs
    with four decimal places.
    """

    def __init__(self, frequency=10):
        """Store the print interval and zero both step counters."""
        self.frequency = frequency
        # Total number of training batches seen since construction.
        self.step = 0
        # Batches seen since the last epoch start. Initialised here so that
        # ``on_train_batch_end`` works even if ``on_epoch_start`` has not
        # fired yet (previously that raised AttributeError).
        self.inner_step = 0

    def on_epoch_start(self, trainer, module, *args, **kwargs):
        """Reset the per-epoch batch counter."""
        self.inner_step = 0

    def on_train_batch_end(self, trainer, module, *args, **kwargs):
        """Print the logged metrics every ``frequency`` batches, then advance the counters."""
        if self.inner_step % self.frequency == 0:
            print(' '.join(["{0}: {1:.4f}".format(i, float(j)) for i, j in trainer.logged_metrics.items()]))
        self.inner_step += 1
        self.step += 1

In [4]:
# Hyper-parameters forwarded to ResNet(**model_params).
model_params = {'num_classes': 1000,
                'resnet_version': 50,
                'train_path': 'pipe:aws s3 cp s3://jbsnyder-sagemaker-us-east/data/imagenet/train/train_{0000..2047}.tar -',
                'val_path': 'pipe:aws s3 cp s3://jbsnyder-sagemaker-us-east/data/imagenet/val/val_{0000..0127}.tar -',
                'optimizer': 'adamw',
                'lr': 1e-3,
                'batch_size': 64,
                'dataloader_workers': 0}

# Query the visible GPU count once so 'gpus' and 'strategy' always agree.
num_gpus = torch.cuda.device_count()

# Arguments forwarded to pl.Trainer(**trainer_params).
trainer_params = {'gpus': num_gpus,
                  # A run always spans at least one node; 0 is not a valid
                  # node count and would yield a zero world size under DDP.
                  'num_nodes': 1,
                  'strategy': 'ddp' if num_gpus > 1 else None,
                  'max_epochs': 4,
                  'amp_backend': 'apex',
                  'amp_level': 'O2',
                  'precision': 16,
                  # 0 disables the progress bar; PlSageMakerLogger prints metrics instead.
                  'progress_bar_refresh_rate': 0,
                  'callbacks': [PlSageMakerLogger()]
                  }

In [5]:
# Instantiate the LightningModule and the Trainer from the parameter
# dictionaries (model_params / trainer_params) defined earlier.
model = ResNet(**model_params)
trainer = pl.Trainer(**trainer_params)

Using 16bit apex Automatic Mixed Precision (AMP)
  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [6]:
# Start training. No dataloaders are passed here; the model defines its own
# train_dataloader() / val_dataloader() methods.
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | model     | ResNet           | 25.6 M
-----------------------------------------------
25.6 M    Trainable params
0         Non-trainable params
25.6 M    Total params
102.228   Total estimated model params size (MB)


Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.

Defaults for this optimization level are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
[2022-01-04 00:32:06.148 pytorch-1-8-gpu-py36-ml-p3-2xlarge-84493874ea1d5c2b56c14072735a:17560 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2022-01-04 00:32:06.177 pytorch-1-8-gpu-py36-ml-p3-2xlarge-84493874ea1d5c2b56c14072735a:17560 INFO profiler_config_parser.py:102] Unable to find config 

In [7]:
# Build the training DataLoader directly so it can be inspected outside the Trainer.
dataloader = model.train_dataloader()

In [8]:
# NOTE: this call fails. The recorded output is
# "TypeError: object of type 'Processor' has no len()", i.e. the object
# backing this loader defines no length, so the number of batches cannot be
# obtained with len(). Presumably this is because the WebDataset pipeline is
# a streaming dataset - TODO confirm.
len(dataloader)

TypeError: object of type 'Processor' has no len()