In [1]:
import os
import sys
from pathlib import Path

# Put the parent of the working directory on the import path so the project-local
# `baard` package imported in the next cell can be resolved without installing it.
# NOTE(review): the entry is relative, so this assumes the kernel's working directory
# is the notebook's own folder directly under the project root -- confirm.
sys.path.append('..')

In [2]:
import pytorch_lightning as pl
import torch
from torchinfo import summary

from baard.classifiers.cifar10_resnet18 import CIFAR10_ResNet18

In [3]:
# The project root is taken to be the parent of the current working directory
# (the notebook is expected to be run from a folder directly below the root).
PATH_ROOT = Path(os.getcwd()).absolute().parent
print('PATH_ROOT:', PATH_ROOT)

# Location of the pre-trained classifier checkpoint inside the project tree.
PATH_CHECKPOINT = os.path.join(PATH_ROOT, 'pretrained_clf', 'cifar10_resnet18.ckpt')

# Fail fast when the checkpoint is absent. A missing file is reported with
# FileNotFoundError; FileExistsError is reserved for "file already exists" and
# would mislead anyone reading the traceback.
if not os.path.isfile(PATH_CHECKPOINT):
    raise FileNotFoundError('Cannot find PyTorch lightning checkpoint. Check file name!')
print('PATH_CHECKPOINT:', PATH_CHECKPOINT)


PATH_ROOT: /home/lukec/workspace/baard_v4
PATH_CHECKPOINT: /home/lukec/workspace/baard_v4/pretrained_clf/cifar10_resnet18.ckpt


In [4]:
# Restore the classifier from the checkpoint file whose existence was verified above.
# `model` is reused by every following cell (summary, print, and test run).
model = CIFAR10_ResNet18.load_from_checkpoint(PATH_CHECKPOINT)

In [6]:
# `summary` is imported together with the other third-party packages in the import
# cell at the top of the notebook, so that all dependencies are visible in one place.

# Reuse the batch size of the model's own training loader so that the output shapes
# reported by the summary match what the network sees during training.
# NOTE(review): building the loader also triggers dataset preparation (this is
# presumably what prints "Files already downloaded and verified") -- confirm.
batch_size = model.train_dataloader().batch_size

# Per-sample input is 3 channels x 32 x 32 pixels, consistent with the first
# convolution (3 input channels) and the 32x32 feature maps in the summary table.
summary(model, input_size=(batch_size, 3, 32, 32))

Files already downloaded and verified


Layer (type:depth-idx)                        Output Shape              Param #
CIFAR10_ResNet18                              [256, 10]                 --
├─ResNet: 1-1                                 [256, 10]                 --
│    └─Conv2d: 2-1                            [256, 64, 32, 32]         1,728
│    └─BatchNorm2d: 2-2                       [256, 64, 32, 32]         128
│    └─ReLU: 2-3                              [256, 64, 32, 32]         --
│    └─Identity: 2-4                          [256, 64, 32, 32]         --
│    └─Sequential: 2-5                        [256, 64, 32, 32]         --
│    │    └─BasicBlock: 3-1                   [256, 64, 32, 32]         73,984
│    │    └─BasicBlock: 3-2                   [256, 64, 32, 32]         73,984
│    └─Sequential: 2-6                        [256, 128, 16, 16]        --
│    │    └─BasicBlock: 3-3                   [256, 128, 16, 16]        230,144
│    │    └─BasicBlock: 3-4                   [256, 128, 16, 16]        295,424

In [7]:
# Dump the module tree of the restored classifier (layer names, types, and their
# constructor arguments) as plain text for inspection.
print(model)

CIFAR10_ResNet18(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): Identity()
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
     

In [None]:
# Run a test pass of the restored model over the loader returned by
# `model.val_dataloader()`. Logging is switched off, and the accelerator is left
# for Lightning to choose automatically.
trainer = pl.Trainer(logger=False, accelerator='auto')
val_loader = model.val_dataloader()

# Keep the call as the last expression so the returned metrics are displayed.
trainer.test(model, val_loader)