# ResNet v2 model 학습 및 평가 실습
- dataset: cifar10
- AI tool: **pytorch**
- Reference
  * [cifar10-resnet-keras](https://github.com/PacktPublishing/Advanced-Deep-Learning-with-Keras/blob/master/chapter2-deep-networks/resnet-cifar10-2.2.1.py)
  * [cifar10-tutorial](https://docs.pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
  * [cifar10-resnet-pytorch](https://www.kaggle.com/code/kannapat/cifar10-with-vgg-and-resnet-in-pytorch)
  * [torchvision.models.resnet](https://docs.pytorch.org/vision/0.9/_modules/torchvision/models/resnet.html)

### print out model summary 
- install torchinfo to print a summary of the model
```python

try:
  import torchinfo
except:
  !pip install torchinfo
  import torchinfo

from torchinfo import summary
```

In [None]:
# Import PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split

# Import torchvision
import torchvision
from torchvision import transforms
from torchvision.transforms import ToPILImage
from torchvision.transforms import v2
from torchvision.transforms import Normalize, Resize, ToTensor, Compose
from torchvision.io import decode_image
# Import dataset to load cifar10
from torchvision import datasets

from torchinfo import summary
# PIL Image
from PIL import Image
# Import for model evaluation
# for plot
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sn

# for confusion matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from datetime import datetime
from timeit import default_timer as timer

import random
import os

from tqdm.auto import tqdm

## 0. Hyperparameters

In [None]:
# Device selection: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device:{device}")

# Root directory passed to torchvision as the dataset download/cache location.
cache_dir = 'D:\\HF_cache'

In [None]:
# Training configuration.
# Number of passes over the training set.
epochs = 120
# Mini-batch size (candidates: 32, 64, 128, 256); the original paper trained
# all networks with batch_size=128.
batch_size = 128
# CIFAR-10 has a fixed set of 10 classes.
num_classes = 10

### Model Parameters

In [None]:
# Model version (fixed to the v2 / pre-activation architecture).
version = 2
# Number of residual blocks per stage.
n = 2
# ResNet v2 depth calculation: depth = n * 9 + 2.
depth = n * 9 + 2

# Human-readable model name built from depth and version, e.g. "ResNet20v2".
model_type = f'ResNet{depth}v{version}'

## 1. Dataset Preparation
- torch의 torch.utils.data.dataset, torch.utils.data.dataloader 사용
  * dataset: data를 일정한 포맷으로 정리해서 넣어 두고 필요할 때 하나씩 꺼낼 수 있도록 정의하고 있음 (x, y)
  * dataloader: batch 단위로 꺼내 주고, shuffle 또는 병렬로 꺼내주는 기능
- 참조 [data_tutorial](https://docs.pytorch.org/tutorials/beginner/basics/data_tutorial.html), 
[한국어](https://tutorials.pytorch.kr/beginner/basics/data_tutorial.html)
### Loading the Data
- dataset load시에 사용할 data augmentation 방법 지정
- Data Augmentation and preprocessing
  * Simple normalization ([0,255] --> [-1,1]) + alpha
  * 학습 데이터의 mean과 std를 구하여 normalize할 수도 있습니다.
      * mean = cifar10_mean
      * std = cifar10_std
      
### 1.1 GPU에 전체 데이터셋을 올림
- 속도가 빠름, GPU memory 사용
- 데이터셋이 GPU memory에 올라갈 정도로 작은 경우만 가능

In [None]:
# Enable data augmentation when building the training pipeline (default: True).
data_augmentation = True

# True : the whole dataset is copied to the GPU once and batches are sliced there.
# False: every mini-batch is copied from the CPU to the GPU.
# The dataset, dataloader and augmentation code differ between the two modes.
# (default: True)
gpu_to_all = True

# Zero-initialise the third convolution of each bottleneck unit in the ResNet
# model (default: True).
zero_init_residual = True

In [None]:
# Move an entire dataset onto the configured device in one shot.
def load_all_to_gpu(dataset):
    """Return ``(images, labels)`` for the whole ``dataset`` on ``device``.

    A DataLoader whose batch size equals the dataset length yields exactly one
    batch containing every sample; that batch is then moved to the global
    ``device``.
    """
    whole_set_loader = DataLoader(dataset, batch_size=len(dataset))
    batch_images, batch_labels = next(iter(whole_set_loader))
    moved = (batch_images.to(device), batch_labels.to(device))
    return moved
    
# Custom Dataset that applies an optional transform to pre-loaded tensors.
class TransformedTensorDataset(Dataset):
    """Dataset over in-memory image/label tensors with an optional transform.

    Args:
        images: indexable collection of images (indexed with ``images[idx]``).
        labels: indexable collection of labels; its length defines the dataset size.
        cifar10_class_names: class names, exposed as ``self.classes``.
        transform: optional callable applied to each image on access.
    """

    def __init__(self, images, labels, cifar10_class_names, transform=None):
        self.classes = cifar10_class_names  # class names [str]
        self.transform = transform
        self.labels = labels
        self.images = images

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.images[idx]
        target = self.labels[idx]
        # No transform configured: hand back the stored tensor untouched.
        if not self.transform:
            return sample, target
        return self.transform(sample), target

In [None]:
class RandomHorizontalShiftWithZeroPad(v2.Transform):
    """Randomly translate an image left/right and fill the exposed columns with zeros.

    Args:
        shift_range: maximum shift expressed as a fraction of the image width.

    NOTE(review): early torchvision v2 releases dispatch through the private
    hooks ``_get_params`` / ``_transform`` while newer releases document
    ``make_params`` / ``transform`` as the hooks to override. Both spellings
    are provided here so the class works with either - confirm against the
    installed torchvision version.
    """

    def __init__(self, shift_range):
        super().__init__()
        self.shift_range = shift_range  # fraction of the image width

    def _get_params(self, flat_inputs):
        """Draw the signed shift in pixels, shared by all inputs of one call."""
        img_width = flat_inputs[0].shape[-1]
        max_shift_pixels = int(img_width * self.shift_range)

        # Integer drawn uniformly from [-max_shift_pixels, max_shift_pixels].
        shift_pixels = torch.randint(-max_shift_pixels, max_shift_pixels + 1, (1,)).item()
        # The transform hooks exchange parameters as a dict.
        return {"shift_pixels": shift_pixels}

    def _transform(self, inpt, params):
        """Return ``inpt`` shifted along its last axis (positive shift = right)."""
        shift_pixels = params["shift_pixels"]
        img_width = inpt.shape[-1]

        # Output buffer with the input's shape, dtype and device; zeros act as padding.
        shifted = torch.zeros(inpt.shape, dtype=inpt.dtype, device=inpt.device)

        # Number of source columns that remain visible after the shift.
        visible = img_width - abs(shift_pixels)
        if visible > 0:
            if shift_pixels >= 0:
                shifted[..., shift_pixels:] = inpt[..., :visible]
            else:
                shifted[..., :visible] = inpt[..., -shift_pixels:]
        return shifted

    # Public hook names (see the class docstring note).
    make_params = _get_params
    transform = _transform

class RandomVerticalShiftWithZeroPad(v2.Transform):
    """Randomly translate an image up/down and fill the exposed rows with zeros.

    Args:
        shift_range: maximum shift expressed as a fraction of the image height.

    NOTE(review): early torchvision v2 releases dispatch through the private
    hooks ``_get_params`` / ``_transform`` while newer releases document
    ``make_params`` / ``transform`` as the hooks to override. Both spellings
    are provided here so the class works with either - confirm against the
    installed torchvision version.
    """

    def __init__(self, shift_range):
        super().__init__()
        self.shift_range = shift_range  # fraction of the image height

    def _get_params(self, flat_inputs):
        """Draw the signed shift in pixels, shared by all inputs of one call."""
        img_height = flat_inputs[0].shape[-2]
        max_shift_pixels = int(img_height * self.shift_range)

        # Integer drawn uniformly from [-max_shift_pixels, max_shift_pixels].
        shift_pixels = torch.randint(-max_shift_pixels, max_shift_pixels + 1, (1,)).item()
        # The transform hooks exchange parameters as a dict.
        return {"shift_pixels": shift_pixels}

    def _transform(self, inpt, params):
        """Return ``inpt`` shifted along its second-to-last axis (positive shift = down)."""
        shift_pixels = params["shift_pixels"]
        img_height = inpt.shape[-2]

        # Output buffer with the input's shape, dtype and device; zeros act as padding.
        shifted = torch.zeros(inpt.shape, dtype=inpt.dtype, device=inpt.device)

        # Number of source rows that remain visible after the shift.
        visible = img_height - abs(shift_pixels)
        if visible > 0:
            if shift_pixels >= 0:
                shifted[..., shift_pixels:, :] = inpt[..., :visible, :]
            else:
                shifted[..., :visible, :] = inpt[..., -shift_pixels:, :]
        return shifted

    # Public hook names (see the class docstring note).
    make_params = _get_params
    transform = _transform

In [None]:
# GPU-resident pipeline: batch_size_ for every loader, shuffle=True for the train loader only.
def create_dataloader_all_to_gpu(batch_size_, val_split_ratio = 0.2):
    """Build CIFAR-10 loaders whose tensors already live on ``device``.

    Args:
        batch_size_: batch size used by all three DataLoaders.
        val_split_ratio: fraction of the training set held out for validation;
            values <= 0 fall back to 0.1.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader,
        train_dataset, val_dataset, test_dataset)``.
    """
    # 1. Download CIFAR-10 and convert PIL images to tensors normalised with
    #    y = (x - mean) / std, mean = std = 0.5 per channel.
    #    datasets.CIFAR10 arguments used here:
    #      root      - where the data is downloaded to (cache_dir)
    #      train     - True for the training split, False for the test split
    #      download  - download the data if it is not on disk yet
    #      transform - applied to every image
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    cifar10_train_dataset = datasets.CIFAR10(root=cache_dir, train=True, download=True,
                                             transform=to_normalized_tensor)
    cifar10_test_dataset = datasets.CIFAR10(root=cache_dir, train=False, download=True,
                                            transform=to_normalized_tensor)

    # Split the training set into train / validation parts with a fixed seed.
    val_ratio = 0.1 if val_split_ratio <= 0 else val_split_ratio
    n_total = len(cifar10_train_dataset)
    n_val = int(n_total * val_ratio)
    n_train = n_total - n_val
    split_rng = torch.Generator().manual_seed(1234)
    train_dataset, val_dataset = random_split(cifar10_train_dataset,
                                              [n_train, n_val],
                                              generator=split_rng)
    print (f'train dataset: {len(train_dataset)}')
    print (f'val dataset: {len(val_dataset)}')

    # 2. Load each split in one batch and move it to the device.
    train_images, train_labels = load_all_to_gpu(train_dataset)
    val_images, val_labels = load_all_to_gpu(val_dataset)
    test_images, test_labels = load_all_to_gpu(cifar10_test_dataset)

    # Training-time transform: the horizontal flip is always applied, the
    # zero-padded shifts (up to 10% of width / height) only with augmentation on.
    augment_steps = [v2.RandomHorizontalFlip()]
    if data_augmentation:
        augment_steps.append(RandomHorizontalShiftWithZeroPad(shift_range=0.1))
        augment_steps.append(RandomVerticalShiftWithZeroPad(shift_range=0.1))
    train_transform_a = v2.Compose(augment_steps)

    class_names_train = cifar10_train_dataset.classes
    print(class_names_train)
    print(val_dataset)

    # 3. Wrap the device tensors; only the training set gets a transform.
    train_dataset_a = TransformedTensorDataset(train_images, train_labels, class_names_train,
                                               transform=train_transform_a)
    val_dataset_a = TransformedTensorDataset(val_images, val_labels, class_names_train,
                                             transform=None)
    test_dataset_a = TransformedTensorDataset(test_images, test_labels, cifar10_test_dataset.classes,
                                              transform=None)

    # 4. DataLoaders; only the training loader shuffles.
    train_dataloader_a = DataLoader(train_dataset_a, batch_size=batch_size_, shuffle=True)
    val_dataloader_a = DataLoader(val_dataset_a, batch_size=batch_size_)
    test_dataloader_a = DataLoader(test_dataset_a, batch_size=batch_size_)

    loaders = (train_dataloader_a, val_dataloader_a, test_dataloader_a)
    wrapped_sets = (train_dataset_a, val_dataset_a, test_dataset_a)
    return loaders + wrapped_sets

### 1.2 Standard pytorch pattern
- dataloader에서 cpu 통해 dataset 매번 mini-batch 단위로 copy해서 gpu에 올려서 사용
- 많은 데이터에도 대응 가능
- 느림

In [None]:
# Standard pipeline: batch_size_ for every loader, worker count capped by max_num_workers.
def create_dataloader_batch(batch_size_, val_split_ratio=0.1, max_num_workers=2):
    """Build CIFAR-10 loaders that transform on the CPU and yield one batch at a time.

    Args:
        batch_size_: batch size used by all three DataLoaders.
        val_split_ratio: fraction of the training set held out for validation;
            values <= 0 fall back to 0.1.
        max_num_workers: upper bound on the number of DataLoader worker processes.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader,
        train_dataset, val_dataset, test_dataset, num_workers)``.
    """
    # Simple normalisation of [0, 1] tensors to [-1, 1].
    mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

    # The horizontal flip is always applied; the 10% horizontal/vertical shift
    # only when data_augmentation is enabled.
    train_steps = [transforms.RandomHorizontalFlip()]
    if data_augmentation:
        train_steps.append(transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)))
    train_steps += [
        transforms.ToTensor(),  # PIL image -> torch.Tensor
        transforms.Normalize(mean, std),
    ]
    train_transform_s = transforms.Compose(train_steps)

    # Deterministic pipeline used for validation and test images.
    test_transform_s = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Training split with the (random) training transform.
    # Usage: tensor_image, label = dataset[index]
    cifar10_train_dataset_s = datasets.CIFAR10(
        root=cache_dir,  # where the data is downloaded to
        train=True,
        download=True,  # download if the data is not on disk yet
        transform=train_transform_s,
        target_transform=None,
    )
    # Second view of the same training files with the deterministic transform,
    # so validation samples are not randomly augmented (the GPU-resident path
    # also validates without a transform).
    cifar10_val_view_s = datasets.CIFAR10(
        root=cache_dir,
        train=True,
        download=False,  # files were fetched by the call above
        transform=test_transform_s,
    )
    cifar10_test_dataset_s = datasets.CIFAR10(
        root=cache_dir,
        train=False,
        download=True,
        transform=test_transform_s,
    )

    # Split the training set into train / validation parts with a fixed seed.
    val_ratio = 0.1 if val_split_ratio <= 0 else val_split_ratio
    val_dataset_count = int(len(cifar10_train_dataset_s) * val_ratio)
    train_dataset_count = len(cifar10_train_dataset_s) - val_dataset_count
    train_dataset_s, val_split = random_split(dataset=cifar10_train_dataset_s,
                                              lengths=[train_dataset_count, val_dataset_count],
                                              generator=torch.Generator().manual_seed(1234))
    # Same held-out indices, read through the non-augmenting view.
    val_dataset_s = torch.utils.data.Subset(cifar10_val_view_s, val_split.indices)
    print (f'train dataset: {len(train_dataset_s)}')
    print (f'val dataset: {len(val_dataset_s)}')

    test_dataset_s = cifar10_test_dataset_s

    # os.cpu_count() may return None when the count cannot be determined.
    num_workers = min(os.cpu_count() or 1, max_num_workers)

    # Samples stay on the CPU here; pin_memory=True requests page-locked
    # host batches from the loaders.
    train_dataloader_s = DataLoader(dataset=train_dataset_s,
                                    batch_size=batch_size_,
                                    num_workers=num_workers,
                                    shuffle=True,
                                    pin_memory=True)

    val_dataloader_s = DataLoader(dataset=val_dataset_s,
                                  batch_size=batch_size_,
                                  num_workers=num_workers,
                                  shuffle=False,
                                  pin_memory=True)

    test_dataloader_s = DataLoader(dataset=test_dataset_s,
                                   batch_size=batch_size_,
                                   num_workers=num_workers,
                                   shuffle=False,
                                   pin_memory=True)

    # The fifth element is the validation *dataset* (it used to be the
    # validation dataloader by mistake).
    return (train_dataloader_s, val_dataloader_s, test_dataloader_s,
            train_dataset_s, val_dataset_s, test_dataset_s,
            num_workers)

In [None]:
# Build the GPU-resident loaders and datasets; 20% of the training set is requested for validation.
(
    train_dataloader_a,
    val_dataloader_a,
    test_dataloader_a,
    train_dataset_a,
    val_dataset_a,
    test_dataset_a,
) = create_dataloader_all_to_gpu(batch_size_=batch_size, val_split_ratio=0.2)

In [None]:
# Build the mini-batch loaders; unpack the seven values returned by create_dataloader_batch.
(
    train_dataloader_s,
    val_dataloader_s,
    test_dataloader_s,
    train_dataset_s,
    val_dataset_s,
    test_dataset_s,
    num_workers,
) = create_dataloader_batch(batch_size_=batch_size, val_split_ratio=0.1, max_num_workers=10)

### load된 정보 일부 확인

In [None]:
# len(dataset) is the number of samples, len(dataloader) the number of batches per epoch.
print (f'num_train_samples = {len(train_dataset_a)}, steps/epoch = {len(train_dataloader_a)}')
print (f'num_test_samples = {len(test_dataset_a)}, test steps = {len(test_dataloader_a)}')

In [None]:
# Example of using the dataloader: inspect only the first batch.
_first_batch = next(iter(train_dataloader_a), None)
if _first_batch is not None:
    x, y = _first_batch
    # Shape, device and value range (checks the normalisation).
    print(x.shape, x.device, x.min().item(), x.max().item())
    # Devices of images and labels.
    print(x.device, y.device)

In [None]:
# Ways to read a value from tensor x (which may live on the GPU):
#   x.min()              -> 0-d tensor (its repr includes the device when not on the CPU)
#   x.min().item()       -> plain Python number
#   x.min().cpu().item() -> same number, with an explicit copy to the CPU first
#   type(x.min().item()) -> the Python type returned by .item()
x.min(), x.min().item(), x.min().cpu().item(), type(x.min().item())

In [None]:
# len(dataset) is the number of samples, len(dataloader) the number of batches per epoch.
print (f'num_train_samples = {len(train_dataset_s)}, steps/epoch = {len(train_dataloader_s)}')
print (f'num_test_samples = {len(test_dataset_s)}, test steps = {len(test_dataloader_s)}')

In [None]:
# Shape and device of the tensors held by the GPU-resident datasets.
for _tensor in (train_dataset_a.images, train_dataset_a.labels):
    print(_tensor.shape, _tensor.device)
print(
    test_dataset_a.images.shape,
    test_dataset_a.images.device,
    test_dataset_a.labels.shape,
    test_dataset_a.labels.device,
)

## 1.3 학습에 사용할 dataset 선택 (gpu_to_all 설정에 따라)

In [None]:
# Pick the dataset pair used by the rest of the notebook.
if gpu_to_all:
    # GPU-resident tensors.
    train_dataset = train_dataset_a
    test_dataset = test_dataset_a
else:
    # Standard CPU-side datasets; also report the loader configuration.
    train_dataset = train_dataset_s
    test_dataset = test_dataset_s
    print(num_workers, batch_size, len(train_dataset_s), len(train_dataset_s)/batch_size)

In [None]:
# Class names taken from the selected test dataset's `classes` attribute;
# the bare expression on the last line makes the notebook echo the value.
class_names = test_dataset.classes
class_names

In [None]:
# Echo the selected training dataset object in the notebook output.
train_dataset

In [None]:
# Echo the selected test dataset object in the notebook output.
test_dataset

In [None]:
# Fetch each sample once: every index access re-runs the dataset's transform
# pipeline, so repeated indexing does redundant work and, for the randomly
# transformed training set, may not even describe the same tensor.
_train_image, _train_label = train_dataset_s[0]
_test_image, _test_label = test_dataset_s[0]
print('train image:', _train_image.shape, _train_image.device) # image
print('train label:', _train_label) # label
print('test image:', _test_image.shape, _test_image.device)
print('test label:', _test_label)

### load한 dataset에서 sampling해서 보기 
-  look into some of the image data from the dataset we've downloaded

In [None]:
# Show a (normalized) torch tensor image.
def imdisplay_tensor(img, tensor_range=0):
    """Display a (C, H, W) image tensor with matplotlib.

    Args:
        img: image tensor; it may live on any device and may track gradients.
        tensor_range: 0 if values are in [-1, 1] (they are mapped back to
            [0, 1] and clamped), any other value if they are already in [0, 1].
    """
    # .numpy() only works for CPU tensors that do not require grad.
    img = img.detach().cpu()
    if tensor_range == 0:
        img = (img / 2 + 0.5).clamp(min=0, max=1)  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) ---> (H, W, C)
    plt.show()
# Alternative without numpy: plt.imshow(image_tensor.permute(1,2,0).clamp(min=0, max=1))

def imdisplay(im, tensor_range=0):
    """Display an image given as a torch tensor, a numpy array or a PIL image.

    Args:
        im: ``torch.Tensor`` or ``np.ndarray`` in (C, H, W) layout, or a
            ``PIL.Image.Image``.
        tensor_range: forwarded to ``imdisplay_tensor`` for tensors
            (0: values in [-1, 1], otherwise values in [0, 1]).

    Raises:
        TypeError: if ``im`` is none of the supported types.
    """
    if isinstance(im, torch.Tensor):
        imdisplay_tensor(im, tensor_range)
    elif isinstance(im, np.ndarray):
        # numpy is imported as `np`; transpose the array that was passed in.
        plt.imshow(np.transpose(im, (1, 2, 0)))  # (C, H, W) ---> (H, W, C)
        plt.show()
    elif isinstance(im, Image.Image):
        # `Image` is the PIL module imported by the file; the class is Image.Image.
        plt.imshow(im)
        plt.show()
    else:
        raise TypeError(f"unsupported image type: {type(im).__name__}")
        

In [None]:
# display 

def display_cifar10_sample(dataset, rand_idx):
    """Plot the samples of ``dataset`` selected by ``rand_idx`` in a square grid.

    Images are assumed to be normalised to [-1, 1] and are mapped back to
    [0, 1] for display. Titles come from the global ``class_names`` list.

    Args:
        dataset: indexable dataset returning ``(image_tensor, label)``.
        rand_idx: iterable of sample indices to show.
    """
    import math

    indices = list(rand_idx)
    # At least 4x4 (the original layout); grows when more than 16 indices are given.
    grid = max(4, math.isqrt(max(len(indices) - 1, 0)) + 1)

    plt.figure(figsize=(7,7))

    for i, idx in enumerate(indices):
        img, label = dataset[idx]  # torch.Tensor, label index

        img = img.cpu()  # gpu to cpu
        img = img/2 + 0.5  # unnormalize from [-1, 1] to [0, 1]

        # The label may be a 0-d tensor; convert it to a plain int index.
        img_class = class_names[int(label)]

        plt.subplot(grid, grid, i+1)
        plt.imshow(img.permute(1,2,0).clamp(min=0, max=1))
        plt.title(f"Class : {img_class}",fontsize=10)
        plt.axis(False)

In [None]:
# Draw 16 distinct indices into the test set and plot the corresponding images.
_num_test_samples = len(test_dataset)
rand_idx = random.sample(range(_num_test_samples), 16)
display_cifar10_sample(test_dataset, rand_idx)

In [None]:
## Display augmented images taken from a dataloader.

def display_cifar10_transformed_sample(x_dataloader, num_images=16):
    """Plot the first ``num_images`` images yielded by ``x_dataloader``.

    Images are assumed to be normalised to [-1, 1] and are mapped back to
    [0, 1] for display. Titles come from the global ``class_names`` list.

    Args:
        x_dataloader: iterable of ``(inputs, targets)`` batches.
        num_images: number of images to show (default 16, a 4x4 grid).
    """
    import math

    # At least 4x4 (the original layout); grows for larger requests.
    grid = max(4, math.isqrt(max(num_images - 1, 0)) + 1)

    plt.figure(figsize=(7,7))
    shown = 0
    for inputs, targets in x_dataloader:
        for img, target in zip(inputs, targets):
            if shown >= num_images:
                return
            img = img.cpu()  # gpu to cpu
            img = img/2 + 0.5  # unnormalize from [-1, 1] to [0, 1]
            img_class = class_names[int(target)]

            plt.subplot(grid, grid, shown+1)
            plt.imshow(img.permute(1,2,0).clamp(min=0, max=1))
            plt.title(f"Class : {img_class}",fontsize=10)
            plt.axis(False)
            shown += 1
        if shown >= num_images:
            # Stop before asking the loader for another batch.
            return

In [None]:
# Show images from the GPU-resident training loader, i.e. after its training transform.
display_cifar10_transformed_sample(train_dataloader_a)

## 2.Modeling

### ResNet v2 model

In [None]:
# ResNet Layer
class ResNetLayer(nn.Module):
    """Conv2d combined with optional BatchNorm2d and ReLU.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size: convolution kernel size (int or tuple).
        stride: convolution stride.
        activation: include a ReLU.
        batch_norm: include a BatchNorm2d.
        conv_first: True  -> Conv, BN, ReLU order;
                    False -> BN, ReLU, Conv order (pre-activation).

    The layers are stored in ``self.block`` (an ``nn.Sequential``) in
    execution order.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, activation=True, batch_norm=True, conv_first=True):
        super().__init__()
        layers = []

        # "Same"-style padding: 0 for kernel 1, 1 for kernel 3, k // 2 in general.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

        if conv_first:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=False))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_channels))
            if activation:
                # Operates on a tensor produced inside this block, so in-place is safe.
                layers.append(nn.ReLU(inplace=True))
        else:
            if batch_norm:
                layers.append(nn.BatchNorm2d(in_channels))
            if activation:
                # Without a preceding BN the ReLU would receive the caller's
                # tensor directly; an in-place ReLU would then overwrite the
                # input (e.g. the value a residual shortcut still needs).
                layers.append(nn.ReLU(inplace=batch_norm))
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=False))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

# Bottleneck Residual Unit
class BottleneckResidualUnit(nn.Module):
    """Pre-activation bottleneck: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        in_channels: channels of the input tensor.
        mid_channels: channels inside the bottleneck (output of conv1 and conv2).
        out_channels: channels of the output tensor.
        stride: stride of the first 1x1 convolution and of the projection shortcut.
        activation: whether conv1 is preceded by a ReLU.
        batch_norm: whether conv1 is preceded by a BatchNorm2d.
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride=1, activation=True, batch_norm=True):
        super().__init__()

        # A 1x1 projection is needed whenever the shortcut cannot be the
        # identity, i.e. the spatial size or the channel count changes.
        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            if needs_projection
            else None
        )

        # Bottleneck layers, all in BN-ReLU-Conv (conv_first=False) order.
        self.conv1 = ResNetLayer(in_channels, mid_channels, kernel_size=1, stride=stride,
                                 activation=activation, batch_norm=batch_norm, conv_first=False)
        self.conv2 = ResNetLayer(mid_channels, mid_channels, kernel_size=3, stride=1,
                                 activation=True, batch_norm=True, conv_first=False)
        self.conv3 = ResNetLayer(mid_channels, out_channels, kernel_size=1, stride=1,
                                 activation=True, batch_norm=True, conv_first=False)

    def forward(self, x):
        branch = self.conv3(self.conv2(self.conv1(x)))
        skip = x if self.shortcut is None else self.shortcut(x)
        branch += skip
        return branch

# ResNet v2 model
class ResNetV2(nn.Module):
    """ResNet v2 (pre-activation bottleneck) classifier.

    Structure: an initial Conv-BN-ReLU layer with 16 filters, three stages of
    ``(depth - 2) // 9`` bottleneck units each, then BN-ReLU, global average
    pooling and a linear classifier.

    Args:
        depth: network depth; must be at least 11 so every stage has a unit.
        num_classes: number of output classes.
        loss_fn: criterion stored as ``self.loss_fn``; ``None`` (default)
            creates a fresh ``nn.CrossEntropyLoss``.
        zero_init_residual: zero the weights of the last convolution of every
            bottleneck unit so each unit initially adds nothing to its shortcut.
        debug_mode: print diagnostic information while building the model.

    Raises:
        ValueError: if ``depth`` yields no residual unit per stage.
    """

    def __init__(self, depth, num_classes=10, loss_fn=None, zero_init_residual=True, debug_mode=False):
        super().__init__()
        self.debug = debug_mode

        # Honour the caller-supplied criterion (it used to be ignored).
        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()

        self.num_filters_in = 16
        num_res_blocks = (depth - 2) // 9
        if num_res_blocks < 1:
            raise ValueError(f"depth must be >= 11 to build at least one unit per stage, got {depth}")

        # first convolution
        self.conv1 = ResNetLayer(3, self.num_filters_in, conv_first=True)

        # Each stage returns its output channel count, which feeds the next stage.
        self.stage0, num_filters_in = self._make_stage(0, num_res_blocks, self.num_filters_in)
        self.stage1, num_filters_in = self._make_stage(1, num_res_blocks, num_filters_in)
        self.stage2, num_filters_in = self._make_stage(2, num_res_blocks, num_filters_in)

        # final layers
        self.bn = nn.BatchNorm2d(num_filters_in)
        self.relu = nn.ReLU(inplace=True)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(num_filters_in, num_classes)

        # parameter initialization
        self.zero_init_residual = zero_init_residual
        self.init_weights()

        # check model
        if self.debug:
            self.test_print()

    def _make_stage(self, stage, num_res_blocks, num_filters_in=0):
        """Build one stage and return ``(nn.Sequential, output_channels)``.

        Stage 0 multiplies the channel count by 4, later stages by 2. Only the
        first unit of stages > 0 downsamples (stride 2). The first unit of
        stage 0 has no BN/ReLU in front of its first convolution.
        """
        if num_filters_in == 0:
            num_filters_in = self.num_filters_in

        expansion = 4 if stage == 0 else 2
        num_filters_out = num_filters_in * expansion

        layers = []
        for res_block in range(num_res_blocks):
            if res_block == 0:
                # First unit: changes the channel count (and resolution for stage > 0).
                first_pre_act = stage > 0
                layers.append(BottleneckResidualUnit(
                    num_filters_in, num_filters_in, num_filters_out,
                    stride=2 if stage > 0 else 1,
                    activation=first_pre_act, batch_norm=first_pre_act
                ))
            else:
                # Remaining units keep the shape; bottleneck width is
                # out // 4 in stage 0 and out // 2 in the other stages.
                layers.append(BottleneckResidualUnit(
                    num_filters_out, num_filters_out // expansion, num_filters_out,
                    stride=1,
                    activation=True, batch_norm=True
                ))

        return nn.Sequential(*layers), num_filters_out

    def init_weights(self):
        """Initialise Conv2d / BatchNorm2d / Linear parameters."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif self.debug and next(m.parameters(recurse=False), None) is not None:
                # Report only modules that own parameters this loop did not
                # initialise; containers and parameter-free layers are skipped.
                print ("uninitialized module: ", m)

        # Zero-initialize the last convolution of each residual branch so that
        # the branch starts at zero and each residual unit initially behaves
        # like its shortcut.
        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, BottleneckResidualUnit) and isinstance(m.conv3, ResNetLayer):
                    if len(m.conv3.block) == 0:
                        continue
                    if self.debug:
                        print((len(m.conv3.block)))
                    # In BN-ReLU-Conv order the convolution is the last layer of the block.
                    last_layer = m.conv3.block[-1]
                    if isinstance(last_layer, nn.Conv2d):
                        nn.init.constant_(last_layer.weight, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.stage0(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.gap(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def test_print(self):
        """Push one random 32x32 input through the network and print tensor shapes."""
        # Use the device the parameters are on: this may run inside __init__,
        # before the caller has moved the model anywhere.
        param_device = next(self.parameters()).device
        was_training = self.training
        # Eval mode + no_grad keep this dry run from touching BatchNorm
        # running statistics or building an autograd graph.
        self.eval()
        try:
            with torch.no_grad():
                x = torch.randn(1, 3, 32, 32, device=param_device)
                print('input:', x.shape)
                x = self.conv1(x)
                print('after conv1:', x.shape)
                x = self.stage0(x)
                print('after stage0:', x.shape)
                x = self.stage1(x)
                print('after stage1:', x.shape)
                x = self.stage2(x)
                print('after stage2:', x.shape)
                x = self.bn(x)
                x = self.relu(x)
                x = self.gap(x)
                x = torch.flatten(x, 1)
                x = self.fc(x)
                print('output:', x.shape)
        finally:
            self.train(was_training)
        

In [None]:
# Build the network with the configured depth and initialisation flag, then
# move it to the selected device.
model = ResNetV2(
    depth=depth,
    num_classes=num_classes,
    zero_init_residual=zero_init_residual,
)
model = model.to(device)
print(model)

In [None]:
# (Re)initialize the model. Pass the configured zero_init_residual flag so
# that re-running this cell does not silently fall back to the class default.
model = ResNetV2(depth=depth, num_classes=num_classes, zero_init_residual=zero_init_residual).to(device)
print(model)

In [None]:
# torchinfo summary for a single 3x32x32 input (batch size 1); the notebook echoes the result.
summary(model, input_size=[1, 3, 32, 32])

In [None]:
#model.test_print()

## 3.Training
### Let's train
- optimizer, loss function(loss_fn), scheduler


In [None]:
## Learning-rate schedule
def lr_schedule(epoch):
    """Return the learning rate for ``epoch``.

    The rate starts at 1e-3 and is scaled by 1e-1, 1e-2, 1e-3 and 0.5e-3 once
    the epoch exceeds 80, 130, 160 and 180 respectively.
    """
    base_lr = 1e-3  # initial learning rate
    # (epoch boundary, scale factor), checked from the latest boundary down.
    decay_steps = (
        (180, 0.5e-3),
        (160, 1e-3),
        (130, 1e-2),
        (80, 1e-1),
    )
    for boundary, factor in decay_steps:
        if epoch > boundary:
            return base_lr * factor
    return base_lr

In [None]:
# Optimizer / scheduler setup for training, with or without data augmentation.

# The criterion lives on the model (model.loss_fn) instead of a separate variable.
print ("Loss function:", model.loss_fn, "== nn.CrossEntropyLoss()")

# Adam starts from the schedule's learning rate at epoch 0.
optimizer = optim.Adam(model.parameters(), lr=lr_schedule(0))

# LambdaLR expects a multiplicative factor, so the schedule is expressed
# relative to its epoch-0 value.
scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: lr_schedule(epoch) / lr_schedule(0),
)

In [None]:
# One training epoch
def train_epoch(
                model,
                train_dataloader,
                optimizer,
                scheduler,
                epoch=-1,
                verbose=0):
    """Train ``model`` for one pass over ``train_dataloader``.

    Args:
        model: network exposing ``loss_fn``.
        train_dataloader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once at the end of the epoch.
        epoch: zero-based epoch index used in progress messages; negative
            values omit the epoch from the message.
        verbose: when truthy, print progress every 100 batches.

    Returns:
        ``(mean batch loss, accuracy)`` over the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for step, (inputs, targets) in enumerate(train_dataloader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(inputs)
        loss = model.loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        # Running statistics for the epoch.
        loss_sum += loss.item()
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        if not (verbose and step % 100 == 0):
            continue
        if epoch >=0:
            print(f'Epoch: {epoch+1} | Batch: {step+1}/{len(train_dataloader)} | '
                  f'Loss: {loss.item():.4f} | Accuracy: {n_correct/n_seen:.2f}')
        else:
            print(f'Batch: {step+1}/{len(train_dataloader)} | '
                  f'Loss: {loss.item():.4f} | Accuracy: {n_correct/n_seen:.2f}')

    scheduler.step()

    return loss_sum / len(train_dataloader), n_correct / n_seen

# Evaluate a dataset to calculate loss and accuracy.
def evaluate_val(model, val_dataloader):
    """Return ``(mean batch loss, accuracy)`` of ``model`` on ``val_dataloader``.

    The model is switched to eval mode and gradients are disabled while the
    batches are processed.
    """
    model.eval()
    loss_total = 0.0
    hits = 0
    seen = 0

    with torch.no_grad():
        for inputs, targets in val_dataloader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            logits = model(inputs)
            loss_total += model.loss_fn(logits, targets).item()

            predicted = logits.max(1)[1]
            seen += targets.size(0)
            hits += predicted.eq(targets).sum().item()

    return loss_total / len(val_dataloader), hits / seen

In [None]:
# Calculate accuracy
def evaluate(model, loader):
    """Print and return the accuracy of ``model`` on ``loader``.

    The model is evaluated in eval mode with gradients disabled; afterwards
    its previous train/eval mode is restored.

    Args:
        model: network to evaluate.
        loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        Fraction of correctly classified samples as a float
        (0.0 when the loader yields no samples).
    """
    was_training = model.training
    model.eval()
    num_correct = 0
    num_samples = 0

    try:
        with torch.no_grad():
            for x, t in loader:
                x = x.to(device=device, non_blocking=True)
                t = t.to(device=device, non_blocking=True)

                y = model(x)
                _, predictions = y.max(1)
                # .item() keeps the counter a plain int, so the message below
                # shows a number rather than a tensor repr.
                num_correct += (predictions == t).sum().item()
                num_samples += predictions.size(0)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)

    accuracy = num_correct / num_samples if num_samples else 0.0
    print(f'Got {num_correct} / {num_samples} with accuracy {accuracy*100:.2f}')

    return accuracy

### Training Loop


In [None]:
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          val_dataloader: torch.utils.data.DataLoader,
          # test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          scheduler:torch.optim.lr_scheduler,
          #grad_clip:float=None,
          epochs: int = 10):
    """Train `model` for `epochs` epochs and collect per-epoch metrics.

    Each epoch calls `train_epoch` (which returns the training loss and
    accuracy) and, when a validation dataloader is given, `evaluate_val`.

    Args:
        model: network to train.
        train_dataloader: batches used for training.
        val_dataloader: batches used for validation after every epoch;
            pass None to skip validation.
        optimizer: optimizer forwarded to `train_epoch`.
        scheduler: learning-rate scheduler forwarded to `train_epoch`;
            its `get_last_lr()` is recorded after each epoch.
        epochs: number of epochs to run.

    Returns:
        dict with the lists "train_loss", "train_accuracy", "val_loss",
        "val_accuracy" and "lr". The two validation lists stay empty when
        `val_dataloader` is None.
    """
    history = {"train_loss": [],
               "train_accuracy": [],
               "val_loss": [],
               "val_accuracy": [],
               "lr": []
              }

    # Decide on the parameter that was actually passed in, not on a
    # notebook-global dataloader that happens to exist.
    run_validation = val_dataloader is not None

    start_time = datetime.now()

    for epoch in tqdm(range(epochs)):
        start_time1 = datetime.now()
        train_loss, train_acc = train_epoch(model, train_dataloader, optimizer=optimizer, scheduler=scheduler, epoch=epoch)
        current_lr = scheduler.get_last_lr()[0]

        history['train_loss'].append(train_loss)
        history['train_accuracy'].append(train_acc)
        history['lr'].append(current_lr)
        if run_validation:
            val_loss, val_acc = evaluate_val(model, val_dataloader)
            history['val_loss'].append(val_loss)
            history['val_accuracy'].append(val_acc)
        # the reported epoch time includes the validation pass
        end_time1 = datetime.now()

        if run_validation:
            print(f'Epoch: {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f} |'
                  f'Val. Loss: {val_loss:.4f} | Val. Acc: {val_acc:.2f} | '
                  f'lr: {current_lr:.4e}, {end_time1-start_time1} sec/epoch')
        else:
            print(f'Epoch: {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f} |'
                  f'lr: {current_lr:.4e}, {end_time1-start_time1} sec/epoch')
    end_time = datetime.now()
    print(f'Training completed in: {end_time - start_time}')

    return history

In [None]:
# gpu_to_all = False
# Bind the generic dataloader/dataset names used by the following cells:
# the "_a" variants when gpu_to_all is set, the "_s" variants otherwise.
if gpu_to_all == True:
    train_dataloader, val_dataloader, test_dataloader = (
        train_dataloader_a, val_dataloader_a, test_dataloader_a)
    train_dataset, val_dataset, test_dataset = (
        train_dataset_a, val_dataset_a, test_dataset_a)
else:
    train_dataloader, val_dataloader, test_dataloader = (
        train_dataloader_s, val_dataloader_s, test_dataloader_s)
    train_dataset, val_dataset, test_dataset = (
        train_dataset_s, val_dataset_s, test_dataset_s)

In [None]:
# model.init_weights()

In [None]:
### Try initial run
# print information
# len(dataset) is the number of samples; len(dataloader) is the number of
# batches, i.e. the steps per epoch. Report each under the matching label.
print (f'batch_size={batch_size}')
print (f'num_train_samples = {len(train_dataset)}, steps/epoch = {len(train_dataloader)}')
print (f'num_test_samples = {len(test_dataset)}, steps/epoch = {len(test_dataloader)}')
print ('all data on gpu' if gpu_to_all == True else 'moving by mini-batch from host to gpu')

In [None]:
# test run
# model_history1 = train(model, train_dataloader,  val_dataloader, optimizer, scheduler,  epochs=2)

In [None]:
# to compare the processing time with and without test_dataloader
# If it takes too long, set test_dataloader to None or evaluate less frequently.
# model_history2 = train(model, train_dataloader,  test_dataloader=None, optimizer=optimizer, scheduler=scheduler,  epochs=2)

In [None]:
# Train: run the full schedule and keep the per-epoch metrics history.
model_history = train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=epochs,
)

## Model Evaluation

In [None]:
# type
from typing import List , Dict , Tuple

# Plot the loss and accuracy curves stored in a training history dict.
def plot_history(history: Dict[str, List[float]]):
    """Draw loss (left) and accuracy (right) curves side by side.

    `history` must contain 'train_loss' and 'train_accuracy'. The validation
    curves ('val_loss' / 'val_accuracy') are added only when 'val_loss' is
    present and has one entry per training epoch.
    """
    train_loss = history['train_loss']
    train_accuracy = history['train_accuracy']

    # One x position per recorded epoch (0-based)
    x_axis = range(len(train_loss))

    # Validation curves are optional
    has_val = 'val_loss' in history and len(x_axis) == len(history['val_loss'])
    if has_val:
        val_loss = history['val_loss']
        val_accuracy = history['val_accuracy']
    else:
        val_loss = val_accuracy = None

    # (title, train label, train series, validation label, validation series)
    panels = (
        ('Loss', 'train_loss', train_loss, 'val_loss', val_loss),
        ('Accuracy', 'train_accuracy', train_accuracy, 'val_accuracy', val_accuracy),
    )

    plt.figure(figsize=(15, 7))

    for position, (title, train_label, train_series, val_label, val_series) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(x_axis, train_series, label=train_label)
        if has_val:
            plt.plot(x_axis, val_series, label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.grid()
        plt.legend()

In [None]:
# Plot the loss and accuracy curves recorded by train().
plot_history(model_history)

In [None]:
# Plot the learning rate recorded after each epoch.
# Use a dedicated name for the x-axis: rebinding `epochs` here would replace
# the integer hyperparameter that the training cell passes to train(),
# breaking that cell if it is re-run afterwards.
lr_epochs = range(len(model_history['lr']))
plt.plot(lr_epochs, model_history['lr'])

In [None]:
# test accuracy
# evaluate() prints its own percentage summary and returns the accuracy as a
# fraction, which is what gets formatted below.
test_acc = evaluate(model, test_dataloader)
print(f'Final evaluation accuracy for test dataset: {test_acc:.2f}')

In [None]:
#model_history

In [None]:
# Display model_type; it is embedded in the checkpoint file names below.
model_type

In [None]:
# Save model
# 1. Save and load using model.state_dict()
# Only the parameters/buffers are stored; the model class is needed again to reload.
model_filename = f"cifar10_{model_type}_state_dict.pth"
torch.save(model.state_dict(), model_filename)

In [None]:
# Load model for evaluation
# Rebuild the architecture first, then copy the saved weights into it.
reloaded_model = ResNetV2(depth, num_classes=num_classes).to(device)
# weights_only=True restricts unpickling to tensors and plain containers;
# map_location places the tensors on the current device.
reloaded_model.load_state_dict(torch.load(model_filename, weights_only=True, map_location=device))
reloaded_model.eval()

# test accuracy
test_acc = evaluate(reloaded_model, test_dataloader)
print(f'Final evaluation accuracy for test dataset: {test_acc:.2f}')

In [None]:
# 2.Save and load a model

In [None]:
# Pickle the whole module object (architecture + weights) rather than only the state_dict.
full_model_filename = f"cifar10_{model_type}.pth"
torch.save(model, full_model_filename)

In [None]:
# NOTE: weights_only=False unpickles arbitrary Python objects, which can execute
# code on load. Only use it for checkpoint files you created yourself.
full_model = torch.load(full_model_filename, weights_only=False, map_location=device)

In [None]:
# Print the module tree of the reloaded full model.
print(full_model)

In [None]:
# torchinfo summary for a single 3x32x32 input (batch size 1).
summary(full_model, input_size=(1, 3, 32, 32))

In [None]:
#outputs = evaluate(reloaded_model, train_dataloader)
# NOTE: evaluate() returns the accuracy as a float, so `outputs` holds an
# accuracy value, not the raw model outputs.
outputs = evaluate(model, test_dataloader)

In [None]:
# Loss and accuracy of the reloaded (state_dict) model on the test dataloader.
test_loss, test_acc = evaluate_val(reloaded_model, test_dataloader)
test_loss, test_acc

In [None]:
# Accuracy (fraction) of the in-memory trained model on the test dataloader.
test_accuracy = evaluate(model, test_dataloader)
test_accuracy

### Confusion Matrix

In [None]:
# Collect the predicted and ground-truth class indices for the whole test set.
all_preds = []
all_labels = []

# Make sure layers such as BatchNorm/Dropout run in inference mode,
# regardless of the mode earlier cells left the model in.
model.eval()
with torch.no_grad():
    for inputs, targets in test_dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(targets.cpu().numpy())

# Rows are ground-truth classes, columns are predicted classes.
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sn.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('Ground Truth')
plt.title('Confusion Matrix')
plt.show()

### Inference Test

In [None]:
import torchvision.transforms.v2.functional as F

In [None]:
# save an image from cifar10 test dataset and load it
import PIL

image_id = 0
image_file = f'./images/cifar10_test{image_id}.png'
# Image.save() does not create missing directories, so create it up front.
os.makedirs(os.path.dirname(image_file), exist_ok=True)

# Fetch the sample once: element 0 is the image tensor, element 1 the label.
image_tensor, label = test_dataset[image_id]
# Map the tensor back to [0, 1] for display/saving.
# Assumes the dataset normalizes with mean 0.5 / std 0.5 — TODO confirm against the dataset transform.
image_tensor = (image_tensor / 2 + 0.5).clamp(min=0, max=1).cpu()
# imshow expects (H, W, C); the tensor is (C, H, W)
plt.imshow(image_tensor.permute(1,2,0))
image_pil = transforms.ToPILImage()(image_tensor)
image_pil.save(image_file)
#image_pil

In [None]:
# load the image
img = PIL.Image.open(image_file) # img: PIL image [0, 255], HxW, C

img_tensor = transforms.ToTensor()(img) # torch tensor image [0, 1], (C, H, W)
imdisplay(img_tensor, 1)

# Apply the same normalization used when un-normalizing the saved image,
# then add the batch dimension expected by the model.
img_tensor = transforms.Normalize(0.5, 0.5)(img_tensor)
inputs = img_tensor.unsqueeze(0).to(device)

# Single-image inference: eval mode so BatchNorm uses its running statistics
# instead of the statistics of this one image, and no autograd graph is built.
model.eval()
with torch.no_grad():
    outputs = model(inputs)
_, preds = torch.max(outputs, 1)
pred_idx = preds[0].item()
print (f'gnd = {class_names[int(label)]}, prediction = {class_names[pred_idx]}')

### Qualitative Evaluation



In [None]:
def evaluate_model(x_dataloader, model):
    """Show predictions for the first batch of `x_dataloader` as an image grid.

    Prints the accuracy over the whole batch, then plots up to 24 images
    (6 per row, at most 4 rows) titled "<truth> : <prediction>", green when
    they match and red otherwise, and prints the accuracy over the
    displayed images.

    The model is run in eval mode without gradients and is put back into
    its previous train/eval mode afterwards.

    Args:
        x_dataloader: iterable yielding (inputs, targets) batches.
        model: module mapping an input batch to per-class scores.
    """
    max_rows = 4
    num_cols = 6

    # Retrieve one batch of images from the dataset and predict on it.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            (inputs, targets) = next(iter(x_dataloader))
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            _, predicted = outputs.max(1)
    finally:
        model.train(was_training)

    total = targets.size(0)
    correct = predicted.eq(targets).sum().item()
    print(f'batch accuracy = {correct/total * 100:.1f}(N={total})')

    # Show as many images as the batch provides, capped at the grid size;
    # a short batch gets a partially filled last row instead of nothing.
    num_shown = min(total, max_rows * num_cols)
    num_rows = -(-num_shown // num_cols)  # ceiling division

    # Move everything needed for plotting to the host once.
    data_batch = inputs.cpu()
    pred_labels = predicted.cpu().tolist()
    true_labels = targets.cpu().tolist()
    resize = transforms.Resize((32,32), interpolation=transforms.InterpolationMode.NEAREST)

    plt.figure(figsize=(20, 8))
    num_matches = 0

    for idx in range(num_shown):
        plt.subplot(num_rows, num_cols, idx + 1)
        plt.axis("off")
        img = resize(data_batch[idx])
        # unnormalize; assumes inputs were normalized with mean 0.5 / std 0.5 — TODO confirm
        img = img/2 + 0.5
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

        pred_idx = pred_labels[idx]
        truth_idx = true_labels[idx]

        title = str(class_names[truth_idx]) + " : " + str(class_names[pred_idx])
        title_obj = plt.title(title, fontdict={'fontsize':13})

        if pred_idx == truth_idx:
            num_matches += 1
            plt.setp(title_obj, color='g')
        else:
            plt.setp(title_obj, color='r')

    # Computed once after the loop; defined even when nothing was displayed.
    acc = num_matches / num_shown if num_shown else 0.0

    print("Prediction accuracy (for data in display): ", int(100*acc)/100)

    return


In [None]:
# Qualitative check on the first batch of the test dataloader.
evaluate_model(test_dataloader, model)