## 1. Import Library

In [24]:
from copy import deepcopy

import torch
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader

from torchvision import models, transforms
from torchvision.datasets import CIFAR10

from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from torchmetrics import Accuracy, MetricCollection

## 2. Define Dataloader

In [22]:
class CIFAR(LightningDataModule):
    def __init__(self, img_size=32, batch_size=32):
        super().__init__()
        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        self.batch_size = batch_size
        self.train_transforms = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.Pad(4, padding_mode='reflect'),
            transforms.RandomCrop(self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        self.test_transforms = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

    def prepare_data(self) -> None:
        CIFAR10(root='data', train=True, download=True)
        CIFAR10(root='data', train=False, download=True)
    
    def setup(self, stage=None):
        self.train_ds = CIFAR10(root='data', train=True, download=False, transform=self.train_transforms)
        self.valid_ds = CIFAR10(root='data', train=False, download=False, transform=self.test_transforms)

    def train_dataloader(self):
        return DataLoader(self.train_ds, num_workers=4, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.valid_ds, num_workers=4, batch_size=self.batch_size, shuffle=False)

## 3. Define Model & EMA

In [28]:
class EMA(nn.Module):
    """ Model Exponential Moving Average V2 from timm"""
    def __init__(self, model, decay=0.9999):
        super(EMA, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)
    

class BasicModule(LightningModule):
    def __init__(self, lr=0.01, use_ema=False):
        super().__init__()
        self.model = models.resnet18(pretrained=False)
        self.model_ema = EMA(self.model, decay=0.9) if use_ema else None
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr
        
        metric = MetricCollection({'top@1': Accuracy(top_k=1), 'top@5': Accuracy(top_k=5)})
        self.train_metric = metric.clone(prefix='train_')
        self.valid_metric = metric.clone(prefix='valid_')
    
    def training_step(self, batch, batch_idx, optimizer_idx=None):
        return self.shared_step(*batch, self.train_metric)

    def validation_step(self, batch, batch_idx):
        return self.shared_step(*batch, self.valid_metric)

    def shared_step(self, x, y, metric):
        y_hat = self.model(x) if self.training or self.model_ema is None else self.model_ema.module(x)
        loss = self.criterion(y_hat, y)
        self.log_dict(metric(y_hat, y), prog_bar=True)
        return loss

    def configure_optimizers(self):
        return SGD(self.model.parameters(), lr=self.lr)

    def on_before_backward(self, loss: torch.Tensor) -> None:
        if self.model_ema:
            self.model_ema.update(self.model)

## 4. Train without EMA

In [20]:
data = CIFAR(batch_size=512)
model = BasicModule(lr=0.01, use_ema=False)
trainer = Trainer(max_epochs=2, gpus='5,', precision=16)
trainer.fit(model, data)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7,8]

  | Name         | Type             | Params
--------------------------------------------------
0 | model        | ResNet           | 11.7 M
1 | criterion    | CrossEntropyLoss | 0     
2 | train_metric | MetricCollection | 0     
3 | valid_metric | MetricCollection | 0     
--------------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
23.379    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

## 5. Train with EMA

In [29]:
data = CIFAR(batch_size=512)
model = BasicModule(lr=0.01, use_ema=True)
trainer = Trainer(max_epochs=2, gpus='5,', precision=16)
trainer.fit(model, data)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7,8]

  | Name         | Type             | Params
--------------------------------------------------
0 | model        | ResNet           | 11.7 M
1 | model_ema    | EMA              | 11.7 M
2 | criterion    | CrossEntropyLoss | 0     
3 | train_metric | MetricCollection | 0     
4 | valid_metric | MetricCollection | 0     
--------------------------------------------------
23.4 M    Trainable params
0         Non-trainable params
23.4 M    Total params
46.758    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]