# ViT_pretrained

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional   # 변수에 None이 들어올 수도 있으면 Optional을 선언하여 사용함
from torch import Tensor
from torchvision.transforms import Compose, Resize, ToTensor
# from einops import rearrange, reduce, repeat
import einops
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
from PIL import Image
import matplotlib.pyplot as plt
# from torchsummaryX import summary
import numpy as np
from torch.utils import model_zoo
from scipy.ndimage import zoom
import torch.optim as optim
import glob
import os
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from einops import rearrange, reduce, repeat
from torchvision.datasets import ImageFolder
from matplotlib.pyplot import imshow

### configs - ViT model configurations

In [2]:
"""configs.py - ViT model configurations, based on:
https://github.com/google-research/vision_transformer/blob/master/vit_jax/configs.py
"""

def get_base_config():
    """Base ViT config ViT"""
    return dict(
      dim=768,
      ff_dim=3072,
      num_heads=12,
      num_layers=12,
      attention_dropout_rate=0.0,
      dropout_rate=0.1,
      representation_size=768,
      classifier='token'
    )

def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = get_base_config()
    config.update(dict(patches=(16, 16)))
    return config

def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    config = get_b16_config()
    config.update(dict(patches=(32, 32)))
    return config

def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = get_base_config()
    config.update(dict(
        patches=(16, 16),
        dim=1024,
        ff_dim=4096,
        num_heads=16,
        num_layers=24,
        attention_dropout_rate=0.0,
        dropout_rate=0.1,
        representation_size=1024
    ))
    return config

def get_l32_config():
    """Returns the ViT-L/32 configuration."""
    config = get_l16_config()
    config.update(dict(patches=(32, 32)))
    return config

def drop_head_variant(config):
    config.update(dict(representation_size=None))
    return config


PRETRAINED_MODELS = {
    'B_16': {
      'config': get_b16_config(),
      'num_classes': 21843,
      'image_size': (224, 224),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16.pth"
    },
    'B_32': {
      'config': get_b32_config(),
      'num_classes': 21843,
      'image_size': (224, 224),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32.pth"
    },
    'L_16': {
      'config': get_l16_config(),
      'num_classes': 21843,
      'image_size': (224, 224),
      'url': None
    },
    'L_32': {
      'config': get_l32_config(),
      'num_classes': 21843,
      'image_size': (224, 224),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32.pth"
    },
    'B_16_imagenet1k': {
      'config': drop_head_variant(get_b16_config()),
      'num_classes': 1000,
      'image_size': (384, 384),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16_imagenet1k.pth"
    },
    'B_32_imagenet1k': {
      'config': drop_head_variant(get_b32_config()),
      'num_classes': 1000,
      'image_size': (384, 384),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32_imagenet1k.pth"
    },
    'L_16_imagenet1k': {
      'config': drop_head_variant(get_l16_config()),
      'num_classes': 1000,
      'image_size': (384, 384),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_16_imagenet1k.pth"
    },
    'L_32_imagenet1k': {
      'config': drop_head_variant(get_l32_config()),
      'num_classes': 1000,
      'image_size': (384, 384),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32_imagenet1k.pth"
    },
}

### model

In [3]:
# shape 주석은 img size 224 기준
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size= 16, emb_size : int = 768, img_size = 384):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(nn.Conv2d(in_channels, emb_size, kernel_size = patch_size, stride = patch_size),Rearrange('b e (h) (w) -> b (h w) e'),)
        self.cls_token = nn.Parameter(torch.randn(1,1,emb_size))   # torch.randn(1,1,emb_size) : 1행 1열, emb_size(768, 채널) 짜리 평균이 0이고 표준편차가 1인 가우시안 정규분포의 난수로 채워진 텐서 반환 
        self.positions = nn.Parameter(torch.randn((img_size // patch_size[0])**2 +1, emb_size))  # (img_size // patch_size)**2는 패치 갯수, +1은 cls_token 으로 늘어난 크기에 맞춰 더한 것
        
    def forward(self, x):
#         print('patch emb x:', x)
        b = x.shape[0]
        # print('b:',b)
#         print('patch emb in:',x)
        x = self.projection(x)   # [8,768,14,14] -> [8,196,768] 
        # print('after projection:',x.shape)
        
        # cls_token을 반복하여 batch size와 크기 맞춰줌(배치 크기 만큼 확장)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)   # [1,1,768] -> [8,1,768]
        
        # cls token을 input에 추가(concat) 
        x = torch.cat([cls_tokens,x], dim=1)    # [8,197,768]
        
         # position embedding 더해줌
        x += self.positions   # [197,768](positions) + [8,197,768](x)
        # print('positions:',x)

        return x

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size : int = 768, num_heads: int = 8, dropout: float=0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self,x ,mask: None): 
        # queries, keys, values -- [8,8,197,96]
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)  # batch, heads, sequence_len, embedding_size(여기서 embedding_size는 embedding_size / n_heads 한 값 인듯) 모양으로 변경
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values  = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        
        
        # einsum 연산을 통해 행렬,벡터의 내적(dot product),외적(outer products), 전치(transpose), 행렬곱 등을 일관성있게 표현 가능
        # 쿼리와 키 곱(einops 이용해 자동으로 transpose 후 내적이 진행됨)
        energy = torch.einsum('bhqd,bhkd -> bhqk', queries, keys)   # 결과 벡터 모양은 batch, heads, query_len, key_len -- [8,8,197,197]
        
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill()
            
        scaling = self.emb_size**(1/2)
        att = F.softmax(energy, dim=-1) / scaling  # [8,8,197,197]
        att = self.att_drop(att)
        
        # scaling해준 후 얻어진 attention score와 value를 내적
        out = torch.einsum('bhal, bhlv -> bhav', att, values)  # [8,8,197,96]
        
        # emb_size로 rearrange하면 MHA의 output나옴
        out = rearrange(out, 'b h n d -> b n (h d)')  #  [8,197,768]
        
        # 최종 output은 linear layer거쳐서 나오게 됨
        out = self.projection(out)   
      
        return out

In [5]:
# class ResidualAdd(nn.Module):   # 이 클래스 사용하면 forward() takes 1 positional argument but 2 were given -> 이 오류 계속 남
#     def __init__(self,fn):
#         super().__init__()
#         self.fn = fn
        
#     def forward(self, **kwargs):  # **kwargs는 keyword argument의 줄임말로 키워드를 제공.딕셔너리 형태로 {'키워드':'특정 값'} 이렇게 함수내부로 전달됨
#         res = x
#         x = self.fn(x, **kwargs)
#         x +=res
        
#         return x

In [6]:
# MHA 이후 진행되는 MLP 부분
class FeedForwardBlock(nn.Sequential):  
    def __init__(self, emb_size: int, expansion: int=4, drop_p: float= 0.):
        super().__init__(nn.Linear(emb_size, expansion * emb_size), nn.GELU(), nn.Dropout(drop_p), nn.Linear(expansion*emb_size, emb_size),)    # emb_size확장하고 두번째 linear layer에서 다시 원래의 emb_size로 축소

In [7]:
# class TransformerEncoderBlock(nn.Sequential):     # 이 클래스 사용하면 forward() takes 1 positional argument but 2 were given -> 이 오류 계속 남
#     def __init__(self, emb_size: int=768, drop_p: float=0., forward_expansion: int=4, forward_drop_p: float=0., **kwargs):
#         super().__init__(ResidualAdd(nn.Sequential(nn.LayerNorm(emb_size), MultiHeadAttention(emb_size, **kwargs), nn.Dropout(drop_p))),
#                          ResidualAdd(nn.Sequential(nn.LayerNorm(emb_size), FeedForwardBlock(emb_size, expansion = forward_expansion, drop_p = forward_drop_p), nn.Dropout(drop_p))))

In [8]:
# class TransformerEncoder(nn.Sequential):     # 이 클래스 사용하면 forward() takes 1 positional argument but 2 were given -> 이 오류 계속 남
#     def __init__(self, depth: int=12, **kwargs):
#         super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])  # 앞에 *가 붙은 이유는 인자를 리스트 형식으로 보내는게 아니라 각각 나눠서 보내줘야하기 때문. 인자를 [1,2,3] 넣을 경우 함수에서는 [1,2,3]으로 받지만 *[1,2,3]일 경우 1,2,3으로 각각 나눠진 후 들어감

In [9]:
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self, emb_size: int=768, drop_p: float=0., forward_expansion: int=4, forward_drop_p: float=0., **kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_size)
        self.attention = MultiHeadAttention(emb_size, **kwargs)
        self.dropout = nn.Dropout(drop_p)
        self.norm2 = nn.LayerNorm(emb_size)
        self.mlp = FeedForwardBlock(emb_size, expansion = forward_expansion, drop_p = forward_drop_p)
        
    def forward(self,x,mask=None):
        h = self.norm1(x)
        h = self.attention(h, mask=None)
        h = self.dropout(h)
        x = x+h
        h = self.norm2(x)
        h = self.mlp(h)
        h = self.dropout(h)
        x = x+h

        return x

In [10]:
class TransformerEncoder(nn.Module):
    def __init__(self, depth: int=12, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList([TransformerEncoderBlock(**kwargs) for _ in range(depth)])

    def forward(self,x, mask=None):
        for block in self.blocks:
            x = block(x,mask)

        return x

In [11]:
class ViT(nn.Module):
    def __init__(self, name = None, pretrained: bool=False, patch_size: int=16, emb_size: int=768, ff_dim = 3072, num_heads: int=12, 
                 attention_dropout_rate: float=0., dropout_rate = 0.1, in_channels: int=3,image_size: int=384, depth: int=12, n_classes: int=2, **kwargs):
        super().__init__()
        
        if name is None:
            check_msg = 'must specify name of pretrained model'
            assert not pretrained, check_msg    # assert - 가정 설정문, 뒤의 조건이 true가 아니면 AssertError발생시킴
            assert not resize_positional_embedding, check_msg  
            if n_classes is None:
                n_classes = 2
            if image_size is None:
                image_size = 384
        
        else:  # load pretrained model
            assert name in PRETRAINED_MODELS.keys(), \
                'name should be in: ' + ', '.join(PRETRAINED_MODELS.keys())
            config = PRETRAINED_MODELS[name]['config']
            patch_size = config['patches']
            dim = config['dim']
            ff_dim = config['ff_dim']
            num_heads = config['num_heads']
            depth = config['num_layers']
            attention_dropout_rate = config['attention_dropout_rate']
            dropout_rate = config['dropout_rate']
#             representation_size = config['representation_size']
#             classifier = config['classifier']
            
            if image_size is None:
                image_size = PRETRAINED_MODELS[name]['image_size']
            
            if n_classes is None:
                n_classes = PRETRAINED_MODELS[name]['num_classes']
                
#         self.image_size = image_size
        
        # patch embedding
        print(patch_size)
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size, image_size) 
                   
        # print('self.transformer input:',TransformerEncoder(depth, emb_size=emb_size, **kwargs))
        # transformer
        self.transformer = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        
        
        # classifier head
#         self.classification = ClassificationHead(emb_size, n_classes)
        self.norm = nn.LayerNorm(emb_size,eps=1e-6)
        self.fc = nn.Linear(emb_size ,n_classes)
        
        # initialize weights
        self.init_weights()
        
        # load pretrained model
        if pretrained:
            pretrained_num_channels = 3
            pretrained_num_classes = PRETRAINED_MODELS[name]['num_classes']
            pretrained_image_size = PRETRAINED_MODELS[name]['image_size']
#             load_pretrained_weights(name, None)
            
            url = PRETRAINED_MODELS[name]['url']
            if url:
                state_dict = model_zoo.load_url(url)   # model_zoo.load_url - 주어진 url에서 torch 직렬화된 개체를 로드

            else:
                raise ValueError(f'pretrained model for {name} has not yet been released')   # raise - 예외 에러 발생 시키기, ValueError은 키워드

            expected_missing_keys = []
            if pretrained_num_channels != in_channels and 'patch_embedding.projection.weight' in state_dict:
                expected_missing_keys +=['patch_embedding.projection.weight', 'projection.bias']

            if pretrained_num_classes != n_classes and 'fc.weight' in state_dict:
                expected_missing_keys +=['fc.weight', 'fc.bias']

            for key in expected_missing_keys:
#                 print(state_dict)
                state_dict.pop(key)    # pop - 리스트의 마지막 요소 꺼내고 삭제

            # resize positional embedding
            if pretrained_image_size != image_size in state_dict:
                posemb = state_dict['patch_embedding.positions']   # static_dict은 torch.save처럼 모델 저장, 각 layer마다 텐서로 매핑되는 매개변수(예를 들어 가중치, 편향 등)를 python dictionary 타입으로 저장한 객체(한마디로 모델 구조에 맞게 각 레이어마다 매개변수를 텐서형태로 매핑해서 dictionary 형태로 저장하는 것)
                posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]  # posemb에 어떤형태로 값이 들어간지 모르겠음

                # get old and new grid sizes
                grid_size_old = int(np.sqrt(len(posemb_grid)))
                grid_size_new = int(np.sqrt(ntok_new))
                posemb_grid = posemb_grid.reshape(grid_size_old, grid_size_old, -1)   # 불러온 그리드로 reshape해주고?

                # rescale grid
                zoom_factor = (grid_size_new / grid_size_old, grid_size_new / grid_size_old,1)
                posemb_grid = zoom(posemb_grid, zoom_factor, order =1)   # zoom은 보간해주는거? 줌 패턴으로 그리드를 늘림
                posemb_grid = posemb_grid.reshape(1, grid_size_new * grid_size_new, -1)  # 그리드 사이즈에 맞게 reshape
                posemb_grid = torch.from_numpy(posemb_grid)

                # deal with class token and return 
                posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

                state_dict['patch_embedding.positions'] = posemb
                
            self.load_state_dict(state_dict, strict=False)

    def init_weights(self):
        def _init(m):
            if isinstance(m, nn.Linear):   # m이 nn.linear인지 확인..?
                nn.init.xavier_uniform_(m.weight)  # _trunc_normal(m.weight, std=0.02)  # from .initialization import _trunc_normal
                if hasattr(m, 'bias') and m.bias is not None:   # hasattr- 변수가 있는지 확인함. ex) hasattr(cls,'b') -> cls에 b라는 멤버가 있는지 확인
                    nn.init.normal_(m.bias, std=1e-6)  # nn.init.constant(m.bias, 0)
        self.apply(_init)
        nn.init.constant_(self.fc.weight, 0)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.normal_(self.patch_embedding.positions, std=0.02)  # _trunc_normal(self.positional_embedding.pos_embedding, std=0.02)
#         nn.init.constant_(self.class_token, 0)

    def forward(self,x):
        # print('vit input:',x.shape)
        x = self.patch_embedding(x)   
        # print('after patch emb x:',x.shape)
        x = self.transformer(x)   ################오류
#         x = self.classification(x)
        x = self.norm(x)[:,0]  # class token만
        x = self.fc(x)
#             if hasattr(self, 'class_token'):
#                 x = torch.cat((self.class_token.expand(b,-1,-1),x), dim=1)
#             if hasattr(self,'positional_embedding'):
#                 x = self.positional_embedding(x)
#             x = self.transformer(x)
#             if hasattr(self,'pre_logits'):
#                 x = self.pre_logits(x)
#                 x = torch.tanh(x)
#             if hasattr(self,'fc'):
#                 x = self.norm(x)[:,0]
#                 x = self.fc(x)

        return x


### data

In [12]:
train_transform = transforms.Compose([transforms.Resize((384,384)),transforms.ToTensor()])
val_transform = transforms.Compose([transforms.Resize((384,384)),transforms.ToTensor()])
test_transform = transforms.Compose([transforms.Resize((384,384)),transforms.ToTensor()])

In [13]:
train_dataset = ImageFolder('F:\\mk\\chest_xray\\train\\',  train_transform)
val_dataset = ImageFolder('F:\\mk\\chest_xray\\val\\',  val_transform)
test_dataset = ImageFolder('F:\\mk\\chest_xray\\test\\', test_transform)

In [14]:
classes = train_dataset.classes
print(classes)

['NORMAL', 'PNEUMONIA']


In [15]:
for num, value in enumerate(train_dataset): 
    data, label = value 
    print('num ', num)
    print('data ', data)
    print('label ', label)

    break

num  0
data  tensor([[[0.0902, 0.0824, 0.0784,  ..., 0.3725, 0.3725, 0.3725],
         [0.0902, 0.0863, 0.0784,  ..., 0.3725, 0.3725, 0.3765],
         [0.0863, 0.0824, 0.0863,  ..., 0.3725, 0.3725, 0.3725],
         ...,
         [0.1333, 0.1373, 0.1373,  ..., 0.3216, 0.3255, 0.3255],
         [0.1608, 0.1608, 0.1608,  ..., 0.3882, 0.3843, 0.3922],
         [0.1922, 0.1922, 0.1882,  ..., 0.4588, 0.4549, 0.4588]],

        [[0.0902, 0.0824, 0.0784,  ..., 0.3725, 0.3725, 0.3725],
         [0.0902, 0.0863, 0.0784,  ..., 0.3725, 0.3725, 0.3765],
         [0.0863, 0.0824, 0.0863,  ..., 0.3725, 0.3725, 0.3725],
         ...,
         [0.1333, 0.1373, 0.1373,  ..., 0.3216, 0.3255, 0.3255],
         [0.1608, 0.1608, 0.1608,  ..., 0.3882, 0.3843, 0.3922],
         [0.1922, 0.1922, 0.1882,  ..., 0.4588, 0.4549, 0.4588]],

        [[0.0902, 0.0824, 0.0784,  ..., 0.3725, 0.3725, 0.3725],
         [0.0902, 0.0863, 0.0784,  ..., 0.3725, 0.3725, 0.3765],
         [0.0863, 0.0824, 0.0863,  ..., 0.372

In [16]:
train_data_loader = DataLoader(train_dataset,batch_size = 6, shuffle=True, num_workers=0)
val_data_loader = DataLoader(val_dataset,batch_size = 6, shuffle=True, num_workers=0)
test_data_loader = DataLoader(test_dataset,batch_size= 6, shuffle=False, num_workers=0)

### train

In [76]:
class Early_Stopping:
    """주어진 patience 이후로 validation loss가 개선되지 않으면 학습을 조기 중지"""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pth'):
        """
        Args:
            patience (int): validation loss가 개선된 후 기다리는 기간
                            Default: 7
            verbose (bool): True일 경우 각 validation loss의 개선 사항 메세지 출력
                            Default: False
            delta (float): 개선되었다고 인정되는 monitered quantity의 최소 변화
                            Default: 0
            path (str): checkpoint저장 경로
                            Default: 'checkpoint.pth'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model, path):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model,path):
        '''validation loss가 감소하면 모델을 저장한다.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
            
        # 모델 저장
        torch.save(model.state_dict(), path+"vit_model.pth")

        self.val_loss_min = val_loss

In [77]:
model_vit = ViT(name ='B_16_imagenet1k', pretrained=True, emb_size=768, ff_dim=3072, num_heads=12, 
                attention_dropout_rate=0.0, dropout_rate = 0.1, in_channels=3,image_size=384, depth=12, n_classes=2 ).cuda()
# USE_CUDA = torch.cuda.is_available()  
# DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
# print('device:',DEVICE) 
# batch_size = 64   
epochs = 100
learning_rate = 0.0001

# criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_vit.parameters(), lr=learning_rate)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)

summary(model_vit,(3,384,384))

(16, 16)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 24, 24]         590,592
         Rearrange-2             [-1, 576, 768]               0
    PatchEmbedding-3             [-1, 577, 768]               0
         LayerNorm-4             [-1, 577, 768]           1,536
            Linear-5             [-1, 577, 768]         590,592
            Linear-6             [-1, 577, 768]         590,592
            Linear-7             [-1, 577, 768]         590,592
           Dropout-8          [-1, 8, 577, 577]               0
            Linear-9             [-1, 577, 768]         590,592
MultiHeadAttention-10             [-1, 577, 768]               0
          Dropout-11             [-1, 577, 768]               0
        LayerNorm-12             [-1, 577, 768]           1,536
           Linear-13            [-1, 577, 3072]       2,362,368
             GELU-14         

In [99]:
earlystop_patient = 40
earlystopping = Early_Stopping(patience=earlystop_patient,verbose=True)

def train_val(model_vit, train_data_loader, val_data_loader, optimizer, epochs):
    # train
    val_loss = []
    for epoch in range(1,epochs+1):
        model_vit.train()
        for batch_idx,(img,target) in enumerate(train_data_loader):
            img, target = img.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model_vit(img)
            loss = F.cross_entropy(output,target)
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0 :
                print("[Epoch %d/%d] [Batch %d/%d] [loss: %.6f]" % (epoch, epochs, batch_idx, len(train_data_loader), loss.item()))
#                 print('train epoch : {}[{}/{} ({: .0f}%)]\tloss:{:.6f}'.format(epoch, len(img), len(train_data_loader.dataset),
#                                                                            100* batch_idx / len(train_data_loader.dataset),loss.item()))

        print("-----[Epoch %d/%d] [avg_train_loss: %f]-----" % (epoch, epochs, np.average(loss.item())))
     
        # validation          
        model_vit.eval()
        save_path = 'F:\\mk\\save_model\\'
        
        with torch.no_grad():
            for i,(v_imgs,v_targets) in enumerate(val_data_loader):
                img, target = v_imgs.cuda(), v_targets.cuda()
                output = model_vit(img)

                # 배치 오차를 합산
            
                validation_loss = F.cross_entropy(output, target, reduction='sum').item()
                validation_loss += validation_loss
                print("validation loss: {}".format(validation_loss))
                val_loss.append(validation_loss)

                # 가장 높은 값을 가진 인덱스가 바로 예측값
#                 pred = output.max(1, keepdim=True)[1]
#                 correct_val = pred.eq(target.view_as(pred)).sum().item()
#                 correct_val + = correct_val
        


        valid_loss=np.average(val_loss)
        print("-----[val_avg_loss %f]-----" % valid_loss)

        # loss 떨어지면 model 저장
        earlystopping(validation_loss,model_vit,path=save_path)

        if earlystopping.early_stop:
            print("Early stopping")
            break


In [100]:
## 5
train_val(model_vit, train_data_loader, val_data_loader, optimizer,epochs)

[Epoch 1/100] [Batch 0/870] [loss: 0.095607]
[Epoch 1/100] [Batch 100/870] [loss: 0.018003]
[Epoch 1/100] [Batch 200/870] [loss: 0.085457]
[Epoch 1/100] [Batch 300/870] [loss: 0.153195]
[Epoch 1/100] [Batch 400/870] [loss: 0.013461]
[Epoch 1/100] [Batch 500/870] [loss: 0.004137]
[Epoch 1/100] [Batch 600/870] [loss: 0.068757]
[Epoch 1/100] [Batch 700/870] [loss: 0.091549]
[Epoch 1/100] [Batch 800/870] [loss: 0.020403]
-----[Epoch 1/100] [avg_train_loss: 0.003539]-----
validation loss: 13.204612731933594
validation loss: 11.928382873535156
validation loss: 10.538419723510742
-----[val_avg_loss 11.890472]-----
Validation loss decreased (inf --> 10.538420).  Saving model ...
[Epoch 2/100] [Batch 0/870] [loss: 0.309153]
[Epoch 2/100] [Batch 100/870] [loss: 0.042092]
[Epoch 2/100] [Batch 200/870] [loss: 0.077056]
[Epoch 2/100] [Batch 300/870] [loss: 0.209545]
[Epoch 2/100] [Batch 400/870] [loss: 0.158199]
[Epoch 2/100] [Batch 500/870] [loss: 0.028819]
[Epoch 2/100] [Batch 600/870] [loss: 0.0

[Epoch 13/100] [Batch 600/870] [loss: 0.036212]
[Epoch 13/100] [Batch 700/870] [loss: 0.002649]
[Epoch 13/100] [Batch 800/870] [loss: 0.001798]
-----[Epoch 13/100] [avg_train_loss: 0.193531]-----
validation loss: 5.811218738555908
validation loss: 15.76420783996582
validation loss: 0.8315023183822632
-----[val_avg_loss 9.907715]-----
Validation loss decreased (1.272917 --> 0.831502).  Saving model ...
[Epoch 14/100] [Batch 0/870] [loss: 0.062413]
[Epoch 14/100] [Batch 100/870] [loss: 0.012028]
[Epoch 14/100] [Batch 200/870] [loss: 0.629920]
[Epoch 14/100] [Batch 300/870] [loss: 0.451897]
[Epoch 14/100] [Batch 400/870] [loss: 0.017680]
[Epoch 14/100] [Batch 500/870] [loss: 0.060402]
[Epoch 14/100] [Batch 600/870] [loss: 0.142645]
[Epoch 14/100] [Batch 700/870] [loss: 0.084459]
[Epoch 14/100] [Batch 800/870] [loss: 0.005104]
-----[Epoch 14/100] [avg_train_loss: 0.010890]-----
validation loss: 29.83350944519043
validation loss: 9.165514945983887
validation loss: 9.665331840515137
-----[va

validation loss: 45.08815002441406
validation loss: 15.458662986755371
-----[val_avg_loss 12.028366]-----
EarlyStopping counter: 1 out of 40
[Epoch 26/100] [Batch 0/870] [loss: 0.035109]
[Epoch 26/100] [Batch 100/870] [loss: 0.003978]
[Epoch 26/100] [Batch 200/870] [loss: 0.055177]
[Epoch 26/100] [Batch 300/870] [loss: 0.008137]
[Epoch 26/100] [Batch 400/870] [loss: 0.006588]
[Epoch 26/100] [Batch 500/870] [loss: 0.033039]
[Epoch 26/100] [Batch 600/870] [loss: 0.116938]
[Epoch 26/100] [Batch 700/870] [loss: 0.032737]
[Epoch 26/100] [Batch 800/870] [loss: 0.012317]
-----[Epoch 26/100] [avg_train_loss: 0.032919]-----
validation loss: 11.866551399230957
validation loss: 4.125335216522217
validation loss: 0.015078019350767136
-----[val_avg_loss 11.770954]-----
Validation loss decreased (0.321080 --> 0.015078).  Saving model ...
[Epoch 27/100] [Batch 0/870] [loss: 0.052498]
[Epoch 27/100] [Batch 100/870] [loss: 0.044015]
[Epoch 27/100] [Batch 200/870] [loss: 0.022291]
[Epoch 27/100] [Batch 

[Epoch 38/100] [Batch 300/870] [loss: 0.010182]
[Epoch 38/100] [Batch 400/870] [loss: 0.035171]
[Epoch 38/100] [Batch 500/870] [loss: 0.196857]
[Epoch 38/100] [Batch 600/870] [loss: 0.074751]
[Epoch 38/100] [Batch 700/870] [loss: 0.731541]
[Epoch 38/100] [Batch 800/870] [loss: 0.320486]
-----[Epoch 38/100] [avg_train_loss: 0.008586]-----
validation loss: 27.58959197998047
validation loss: 3.88928484916687
validation loss: 11.005369186401367
-----[val_avg_loss 11.965691]-----
EarlyStopping counter: 12 out of 40
[Epoch 39/100] [Batch 0/870] [loss: 0.033317]
[Epoch 39/100] [Batch 100/870] [loss: 0.001131]
[Epoch 39/100] [Batch 200/870] [loss: 0.002550]
[Epoch 39/100] [Batch 300/870] [loss: 0.152395]
[Epoch 39/100] [Batch 400/870] [loss: 0.241439]
[Epoch 39/100] [Batch 500/870] [loss: 0.003025]
[Epoch 39/100] [Batch 600/870] [loss: 0.009264]
[Epoch 39/100] [Batch 700/870] [loss: 0.001533]
[Epoch 39/100] [Batch 800/870] [loss: 0.003911]
-----[Epoch 39/100] [avg_train_loss: 0.003324]-----
va

-----[Epoch 50/100] [avg_train_loss: 0.015511]-----
validation loss: 1.4168199300765991
validation loss: 37.599876403808594
validation loss: 9.99209976196289
-----[val_avg_loss 12.561042]-----
EarlyStopping counter: 4 out of 40
[Epoch 51/100] [Batch 0/870] [loss: 0.005636]
[Epoch 51/100] [Batch 100/870] [loss: 0.010855]
[Epoch 51/100] [Batch 200/870] [loss: 0.022190]
[Epoch 51/100] [Batch 300/870] [loss: 0.007695]
[Epoch 51/100] [Batch 400/870] [loss: 0.001621]
[Epoch 51/100] [Batch 500/870] [loss: 0.038868]
[Epoch 51/100] [Batch 600/870] [loss: 0.005220]
[Epoch 51/100] [Batch 700/870] [loss: 0.032123]
[Epoch 51/100] [Batch 800/870] [loss: 0.000409]
-----[Epoch 51/100] [avg_train_loss: 0.119576]-----
validation loss: 0.7974138259887695
validation loss: 14.572258949279785
validation loss: 2.673143148422241
-----[val_avg_loss 12.432674]-----
EarlyStopping counter: 5 out of 40
[Epoch 52/100] [Batch 0/870] [loss: 0.419226]
[Epoch 52/100] [Batch 100/870] [loss: 0.000661]
[Epoch 52/100] [Bat

[Epoch 63/100] [Batch 200/870] [loss: 0.001999]
[Epoch 63/100] [Batch 300/870] [loss: 0.197921]
[Epoch 63/100] [Batch 400/870] [loss: 0.000082]
[Epoch 63/100] [Batch 500/870] [loss: 0.000945]
[Epoch 63/100] [Batch 600/870] [loss: 0.091088]
[Epoch 63/100] [Batch 700/870] [loss: 0.152571]
[Epoch 63/100] [Batch 800/870] [loss: 0.001362]
-----[Epoch 63/100] [avg_train_loss: 0.001566]-----
validation loss: 33.45314407348633
validation loss: 3.460793972015381
validation loss: 0.11449095606803894
-----[val_avg_loss 12.913098]-----
EarlyStopping counter: 17 out of 40
[Epoch 64/100] [Batch 0/870] [loss: 0.000640]
[Epoch 64/100] [Batch 100/870] [loss: 0.002322]
[Epoch 64/100] [Batch 200/870] [loss: 0.134725]
[Epoch 64/100] [Batch 300/870] [loss: 0.033613]
[Epoch 64/100] [Batch 400/870] [loss: 0.000338]
[Epoch 64/100] [Batch 500/870] [loss: 0.234537]
[Epoch 64/100] [Batch 600/870] [loss: 0.042622]
[Epoch 64/100] [Batch 700/870] [loss: 0.046744]
[Epoch 64/100] [Batch 800/870] [loss: 0.002868]
----

[Epoch 75/100] [Batch 800/870] [loss: 0.001419]
-----[Epoch 75/100] [avg_train_loss: 0.000107]-----
validation loss: 31.24398422241211
validation loss: 4.570513725280762
validation loss: 4.727726459503174
-----[val_avg_loss 13.460219]-----
EarlyStopping counter: 29 out of 40
[Epoch 76/100] [Batch 0/870] [loss: 0.005334]
[Epoch 76/100] [Batch 100/870] [loss: 0.000006]
[Epoch 76/100] [Batch 200/870] [loss: 0.000088]
[Epoch 76/100] [Batch 300/870] [loss: 0.721820]
[Epoch 76/100] [Batch 400/870] [loss: 0.059646]
[Epoch 76/100] [Batch 500/870] [loss: 0.003799]
[Epoch 76/100] [Batch 600/870] [loss: 0.000370]
[Epoch 76/100] [Batch 700/870] [loss: 0.000549]
[Epoch 76/100] [Batch 800/870] [loss: 0.378900]
-----[Epoch 76/100] [avg_train_loss: 0.002628]-----
validation loss: 34.82727813720703
validation loss: 18.985389709472656
validation loss: 9.5768404006958
-----[val_avg_loss 13.561135]-----
EarlyStopping counter: 30 out of 40
[Epoch 77/100] [Batch 0/870] [loss: 0.000573]
[Epoch 77/100] [Batch

[Epoch 88/100] [Batch 0/870] [loss: 0.005202]
[Epoch 88/100] [Batch 100/870] [loss: 0.010602]
[Epoch 88/100] [Batch 200/870] [loss: 0.192433]
[Epoch 88/100] [Batch 300/870] [loss: 0.000370]
[Epoch 88/100] [Batch 400/870] [loss: 0.094861]
[Epoch 88/100] [Batch 500/870] [loss: 0.000211]
[Epoch 88/100] [Batch 600/870] [loss: 0.000005]
[Epoch 88/100] [Batch 700/870] [loss: 0.008243]
[Epoch 88/100] [Batch 800/870] [loss: 0.003923]
-----[Epoch 88/100] [avg_train_loss: 0.004831]-----
validation loss: 2.0508503913879395
validation loss: 5.369868755340576
validation loss: 24.138805389404297
-----[val_avg_loss 13.680542]-----
EarlyStopping counter: 2 out of 40
[Epoch 89/100] [Batch 0/870] [loss: 0.003474]
[Epoch 89/100] [Batch 100/870] [loss: 0.002643]
[Epoch 89/100] [Batch 200/870] [loss: 0.006217]
[Epoch 89/100] [Batch 300/870] [loss: 0.013792]
[Epoch 89/100] [Batch 400/870] [loss: 0.002111]
[Epoch 89/100] [Batch 500/870] [loss: 0.000390]
[Epoch 89/100] [Batch 600/870] [loss: 0.238088]
[Epoch 

[Epoch 100/100] [Batch 600/870] [loss: 0.001841]
[Epoch 100/100] [Batch 700/870] [loss: 0.000851]
[Epoch 100/100] [Batch 800/870] [loss: 0.000163]
-----[Epoch 100/100] [avg_train_loss: 0.116199]-----
validation loss: 18.77013397216797
validation loss: 2.846834421157837
validation loss: 0.0693703144788742
-----[val_avg_loss 14.195218]-----
EarlyStopping counter: 14 out of 40


### test

In [105]:
def evaluate(model_vit, test_data_loader):
    model_vit.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx,(img,target) in enumerate(test_data_loader):
            img, target = img.cuda(), target.cuda()
            output = model_vit(img)

            # 배치 오차를 합산
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            print('test loss: ', test_loss)


            # 가장 높은 값을 가진 인덱스가 바로 예측값
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_data_loader.dataset)
    test_accuracy = 100. * correct / len(test_data_loader.dataset)

    print('Epoch: [{}] Test Loss: {:.4f}, Accuracy: {:.2f}%'.format(epochs, test_loss, test_accuracy))


In [106]:
evaluate(model_vit, test_data_loader)    

test loss:  14.175180435180664
test loss:  15.68332326412201
test loss:  22.245993971824646
test loss:  23.132773518562317
test loss:  29.793816685676575
test loss:  33.11590230464935
test loss:  34.82793712615967
test loss:  36.41619431972504
test loss:  37.4918372631073
test loss:  39.69332408905029
test loss:  40.05107420682907
test loss:  43.11215132474899
test loss:  55.95081442594528
test loss:  60.72455710172653
test loss:  85.6560909152031
test loss:  96.29649180173874
test loss:  112.26405066251755
test loss:  122.80559748411179
test loss:  144.50718516111374
test loss:  170.31282824277878
test loss:  184.78570955991745
test loss:  214.65224474668503
test loss:  236.30214709043503
test loss:  269.5859224200249
test loss:  283.73496836423874
test loss:  283.900126978755
test loss:  286.7518022507429
test loss:  287.3478495925665
test loss:  294.0965773910284
test loss:  294.1011264640838
test loss:  302.52941015549004
test loss:  304.30270949192345
test loss:  304.3813646901399

In [53]:
torch.cuda.empty_cache()