### Install PyTorch Lightning and Lightning Bolts first (pinned versions)

In [9]:
!pip3 install lightning-bolts==0.6.0 --quiet
!pip3 install pytorch-lightning==1.8.5 --quiet

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-gpu 2.9.1 requires absl-py>=1.0.0, which is not installed.
tensorflow-gpu 2.9.1 requires astunparse>=1.6.0, which is not installed.
tensorflow-gpu 2.9.1 requires flatbuffers<2,>=1.12, which is not installed.
tensorflow-gpu 2.9.1 requires gast<=0.4.0,>=0.2.1, which is not installed.
tensorflow-gpu 2.9.1 requires google-pasta>=0.1.1, which is not installed.
tensorflow-gpu 2.9.1 requires grpcio<2.0,>=1.24.3, which is not installed.
tensorflow-gpu 2.9.1 requires h5py>=2.9.0, which is not installed.
tensorflow-gpu 2.9.1 requires keras<2.10.0,>=2.9.0rc0, which is not installed.
tensorflow-gpu 2.9.1 requires keras-preprocessing>=1.1.1, which is not installed.
tensorflow-gpu 2.9.1 requires libclang>=13.0.0, which is not installed.
tensorflow-gpu 2.9.1 requires opt-einsum>=2.3.2, which is not installed.
tensorfl

In [12]:
from torchvision import transforms
import pytorch_lightning as pl
#from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torchvision.models.resnet import resnet18
from pytorch_lightning import Trainer, LightningModule
import torch.nn as nn
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchmetrics.functional import accuracy
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint


In [30]:
# Training hyperparameters.
EPOCHS = 5
LR = 0.1              # SGD learning rate
MOMENTUM = 0.9        # SGD momentum
WEIGHT_DECAY = 5e-4   # SGD weight decay
PRINT_FREQ = 50       # print step metrics every PRINT_FREQ batches
TRAIN_BATCH = 128
VAL_BATCH = 128

# Seed Python, NumPy and torch RNGs (including dataloader workers) so weight
# initialisation, shuffling and augmentation draw from a fixed seed.
SEED = 42
pl.seed_everything(SEED, workers=True)

In [15]:
GPU: int = 0  # CUDA device index used by this notebook

# Per-channel (R, G, B) mean / std handed to transforms.Normalize.
# NOTE(review): these std values are the widely copied ones; the per-pixel std of
# the CIFAR-10 training set is usually quoted as ~[0.247, 0.243, 0.261] (as in
# pl_bolts' cifar10_normalization) — confirm which is intended.
cifar_mean_RGB = [0.4914, 0.4822, 0.4465]
cifar_std_RGB = [0.2023, 0.1994, 0.2010]

### Fill in the transform statements below (CIFAR-10 data module)

In [16]:
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule for CIFAR-10 with separate train and eval transforms.

    Training images get random-crop (padding 4) and horizontal-flip augmentation;
    validation/test images are only converted to tensors. Both pipelines normalize
    with ``cifar_mean_RGB`` / ``cifar_std_RGB``.

    NOTE(review): the validation set is the CIFAR-10 *test* split (train=False),
    and the same split backs ``test_dataloader``. Early stopping and checkpointing
    monitor it, so it is effectively used for model selection; hold out part of
    the training split if an unbiased test score is needed.

    Args:
        train_batch_size: batch size of the training DataLoader.
        val_batch_size: batch size of the validation and test DataLoaders.
        data_dir: directory the dataset is downloaded to / read from.
        num_workers: worker processes used by each DataLoader.
    """

    def __init__(self, train_batch_size, val_batch_size, data_dir: str = './', num_workers: int = 2):
        super().__init__()
        self.data_dir = data_dir
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers

        self.transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean_RGB, cifar_std_RGB),
        ])
        self.transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean_RGB, cifar_std_RGB),
        ])

        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        """Download the CIFAR-10 train and test splits into ``data_dir``."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets required for ``stage`` ('fit', 'validate', 'test' or None for all)."""
        if stage in ('fit', None):
            self.cifar_train = CIFAR10(self.data_dir, train=True, transform=self.transform_train)
        # 'validate' (Trainer.validate) needs only the validation set.
        if stage in ('fit', 'validate', None):
            self.cifar_val = CIFAR10(self.data_dir, train=False, transform=self.transform_val)
        if stage in ('test', None):
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform_val)

    def _loader(self, dataset, batch_size, shuffle=False):
        """Wrap ``dataset`` in a DataLoader using this module's worker count."""
        return DataLoader(dataset, batch_size=batch_size, num_workers=self.num_workers, shuffle=shuffle)

    def train_dataloader(self):
        return self._loader(self.cifar_train, self.train_batch_size, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.cifar_val, self.val_batch_size)

    def test_dataloader(self):
        # The test split uses the evaluation batch size (there is no separate test batch size).
        return self._loader(self.cifar_test, self.val_batch_size)

In [17]:
dm = CIFAR10DataModule(train_batch_size=TRAIN_BATCH, val_batch_size=VAL_BATCH)
# Download and build the datasets eagerly so the dataloaders can be used before trainer.fit.
dm.prepare_data()
dm.setup()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


In [18]:
# Keep the 3 checkpoints with the lowest val_loss. Passing dirpath explicitly writes them
# to MODEL_CKPT_PATH instead of a location derived from the Trainer's root dir / logger.
MODEL_CKPT_PATH = 'model/'
MODEL_CKPT = 'model-{epoch:02d}-{val_loss:.2f}'

checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_CKPT_PATH,
    filename=MODEL_CKPT,
    monitor='val_loss',
    save_top_k=3,
    mode='min',
)

In [19]:
# Pull one validation batch to sanity-check tensor shapes before training.
val_samples = next(iter(dm.val_dataloader()))
val_imgs = val_samples[0]
val_labels = val_samples[1]
(val_imgs.shape, val_labels.shape)

(torch.Size([128, 3, 32, 32]), torch.Size([128]))

In [20]:
# Stop when val_loss has not decreased for 3 consecutive validation checks.
early_stop_callback = EarlyStopping(monitor='val_loss', mode='min', patience=3, verbose=False)

### Complete the training, validation, and optimizer methods below

In [21]:
class LitResnet18(LightningModule):
    """Randomly initialised ResNet-18 classifier for CIFAR-10, trained with SGD.

    Args:
        learning_rate: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        num_classes: output size of the replacement final linear layer.
    """

    def __init__(self, learning_rate, momentum, weight_decay, num_classes: int = 10):
        super().__init__()
        # Store the init args in checkpoints so LitResnet18.load_from_checkpoint
        # can rebuild the module from files written by ModelCheckpoint.
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.nn = resnet18(pretrained=False, progress=True)
        # Swap the default classification head for a num_classes-way one.
        self.nn.fc = nn.Linear(self.nn.fc.in_features, num_classes)
        self.lr = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        # Lightning places the module on the target device; the loss itself has
        # no parameters, so no explicit .cuda() call is needed.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of images."""
        # Call the submodule (not .forward) so registered hooks still run.
        return self.nn(x)

    def _shared_step(self, batch):
        """Return (loss, accuracy) for one (images, labels) batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        # logger=False keeps the per-step/epoch training metrics out of any attached logger.
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=False)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=False)
        if batch_idx % PRINT_FREQ == 0:
            print(f"train step! {batch_idx} train loss: {loss.item()} train acc {acc.item()}")
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        # val_loss is the quantity monitored by the EarlyStopping and ModelCheckpoint callbacks.
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        if batch_idx % PRINT_FREQ == 0:
            print(f"val step! {batch_idx} val loss: {loss.item()} val acc {acc.item()}")
        return loss

    def configure_optimizers(self):
        """SGD over this module's own parameters (not a global `model`)."""
        return torch.optim.SGD(
            self.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )

In [22]:
# Lightning wrapper around a ResNet-18 built with pretrained=False, using the SGD settings above.
model = LitResnet18(learning_rate=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)




In [32]:
# Train on the single CUDA device selected by GPU, with no experiment logger.
# accelerator/devices replace the `gpus` argument, which is deprecated since
# Lightning 1.7 and removed in 2.0.
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator='gpu',
    devices=[GPU],
    logger=False,
    callbacks=[early_stop_callback, checkpoint_callback],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [33]:
trainer.fit(model, datamodule=dm)

Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name      | Type             | Params
-----------------------------------------------
0 | nn        | ResNet           | 11.2 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

val step! 0 val loss: 2.372194290161133 val acc 0.1328125


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

train step! 0 train loss: 2.4847490787506104 train acc 0.140625
train step! 50 train loss: 2.7515175342559814 train acc 0.21875
train step! 100 train loss: 1.8652034997940063 train acc 0.3125
train step! 150 train loss: 1.8951433897018433 train acc 0.328125
train step! 200 train loss: 2.0577919483184814 train acc 0.296875
train step! 250 train loss: 1.7193142175674438 train acc 0.390625
train step! 300 train loss: 1.6305127143859863 train acc 0.46875
train step! 350 train loss: 1.6575597524642944 train acc 0.40625


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 1.5634247064590454 val acc 0.46875
val step! 50 val loss: 1.6393826007843018 val acc 0.4609375
train step! 0 train loss: 1.61439049243927 train acc 0.3828125
train step! 50 train loss: 1.5394006967544556 train acc 0.4296875
train step! 100 train loss: 1.6740747690200806 train acc 0.375
train step! 150 train loss: 1.492903709411621 train acc 0.4453125
train step! 200 train loss: 1.5820711851119995 train acc 0.4375
train step! 250 train loss: 1.4094233512878418 train acc 0.4765625
train step! 300 train loss: 1.2854712009429932 train acc 0.53125
train step! 350 train loss: 1.3089112043380737 train acc 0.5625


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 1.2204228639602661 val acc 0.5859375
val step! 50 val loss: 1.318739414215088 val acc 0.5546875
train step! 0 train loss: 1.3242756128311157 train acc 0.546875
train step! 50 train loss: 1.3336715698242188 train acc 0.5
train step! 100 train loss: 1.3738987445831299 train acc 0.46875
train step! 150 train loss: 1.2089847326278687 train acc 0.5546875
train step! 200 train loss: 1.1186171770095825 train acc 0.546875
train step! 250 train loss: 1.1720494031906128 train acc 0.578125
train step! 300 train loss: 1.1128743886947632 train acc 0.5390625
train step! 350 train loss: 1.1373581886291504 train acc 0.625


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 1.0680311918258667 val acc 0.609375
val step! 50 val loss: 1.0926347970962524 val acc 0.625
train step! 0 train loss: 1.08979332447052 train acc 0.5859375
train step! 50 train loss: 0.9772717356681824 train acc 0.65625
train step! 100 train loss: 1.1656326055526733 train acc 0.59375
train step! 150 train loss: 1.0862491130828857 train acc 0.578125
train step! 200 train loss: 1.189758539199829 train acc 0.5703125
train step! 250 train loss: 1.0378353595733643 train acc 0.640625
train step! 300 train loss: 1.2769482135772705 train acc 0.53125
train step! 350 train loss: 1.0350534915924072 train acc 0.65625


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 1.223374843597412 val acc 0.5703125
val step! 50 val loss: 1.3671236038208008 val acc 0.5390625
train step! 0 train loss: 1.3961710929870605 train acc 0.5
train step! 50 train loss: 1.001551628112793 train acc 0.65625
train step! 100 train loss: 0.9258986711502075 train acc 0.6328125
train step! 150 train loss: 1.1131962537765503 train acc 0.640625
train step! 200 train loss: 0.9204949736595154 train acc 0.6640625
train step! 250 train loss: 1.1459896564483643 train acc 0.59375
train step! 300 train loss: 1.0398625135421753 train acc 0.65625
train step! 350 train loss: 1.1226768493652344 train acc 0.625


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 0.9221568703651428 val acc 0.7421875
val step! 50 val loss: 0.9443749785423279 val acc 0.671875


`Trainer.fit` stopped: `max_epochs=5` reached.
