# 1.Library Import

In [1]:
import yaml
from box import Box

import torch
import torch.nn as nn
import torch.optim as optim

import simmim
from swin_v2 import SwinTransformerV2

from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
import time

from timm.data import Mixup
import transformers

  from .autonotebook import tqdm as notebook_tqdm


# 2.Configuration

In [2]:
simmim_config = yaml.load(open('config/pretrain.yaml'), Loader=yaml.FullLoader)
simmim_config

{'MODEL': {'TYPE': 'swinv2',
  'NAME': 'simmim_pretrain',
  'DROP_PATH_RATE': 0.0,
  'SWIN': {'EMBED_DIM': 96,
   'DEPTHS': [2, 2, 6, 2],
   'NUM_HEADS': [3, 6, 12, 24],
   'WINDOW_SIZE': 6,
   'PATCH_SIZE': 4}},
 'DATA': {'IMG_SIZE': 192,
  'MASK_PATCH_SIZE': 32,
  'MASK_RATIO': 0.6,
  'BATCH_SIZE': 1024,
  'NUM_WORKERS': 24,
  'DATA_PATH': '../../data/sports'},
 'TRAIN': {'EPOCHS': 100,
  'WARMUP_EPOCHS': 10,
  'BASE_LR': 0.0014,
  'WEIGHT_DECAY': 0.05,
  'CLIP_GRAD': 5}}

In [3]:
encoder_config = {'img_size':simmim_config['DATA']['IMG_SIZE'], 
                'patch_size':simmim_config['MODEL']['SWIN']['PATCH_SIZE'], 
                'in_chans':3, 
                'num_classes':100,
                'embed_dim':192, 
                'depths':[2,2,18,4], 
                'num_heads':[6,12,24,48],           
                'window_size':12, 
                'mlp_ratio':4., 
                'qkv_bias':True, 
                'qk_scale':None,
                'drop_rate':0., 
                'attn_drop_rate':0., 
                'drop_path_rate':.2,
                'norm_layer':nn.LayerNorm, 
                'patch_norm':True, 
                'pretrained_window_sizes':[0,0,0,0],
                'ape':True}

encoder_stride = 32
in_chans = encoder_config['in_chans']
patch_size = encoder_config['patch_size']

# 3.Load SimMIM

In [4]:
encoder = simmim.SwinTransformerV2ForSimMIM(**encoder_config)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
model = simmim.SimMIM( encoder=encoder, 
                       encoder_stride=encoder_stride, 
                       in_chans=in_chans, 
                       patch_size=patch_size)

## Mask Generator Test

In [6]:
mask_generator = simmim.MaskGenerator(input_size=224,
                                      mask_patch_size=28,
                                      model_patch_size=28,
                                      mask_ratio=.6)
mask = mask_generator()
mask

array([[1, 1, 1, 0, 1, 0, 0, 1],
       [1, 1, 0, 1, 0, 1, 1, 0],
       [1, 1, 1, 1, 1, 0, 1, 0],
       [1, 0, 1, 1, 0, 0, 1, 0],
       [0, 1, 1, 1, 1, 1, 1, 1],
       [0, 0, 0, 1, 1, 0, 0, 0],
       [1, 1, 1, 1, 0, 0, 0, 0],
       [0, 1, 1, 1, 1, 1, 0, 1]])

In [7]:
print(f"생성된 mask의 비율은 {mask.sum() / (mask.shape[0]*mask.shape[1])*100}%")

생성된 mask의 비율은 60.9375%


## SimMIM DataLoader

In [8]:
simmim_config = Box(simmim_config)
simmim_config.DATA.BATCH_SIZE = 256
simmim_config.DATA.BATCH_SIZE

256

In [9]:
dataloader = simmim.build_loader_simmim(simmim_config)

samples = next(iter(dataloader))
len(samples)

3

In [10]:
samples[0].shape, samples[1].shape, samples[2].shape 

(torch.Size([256, 3, 192, 192]), torch.Size([256, 48, 48]), torch.Size([256]))

## Hyper Parameters and etc.

In [11]:
base_lr = float(simmim_config.TRAIN.BASE_LR)
weight_decay = .1
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
warmup_epochs = 20
train_epochs = 100

scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, 
                                                        num_warmup_steps=warmup_epochs*len(dataloader), 
                                                        num_training_steps=train_epochs*len(dataloader),
                                                        num_cycles=0.5)

2024-01-19 14:05:20.294512: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-19 14:05:20.294587: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-19 14:05:20.295445: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-19 14:05:20.300970: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [12]:
device = 'cuda:2'
model.to(device)
torch.backends.cudnn.benchmark = True

In [13]:
model_save = True
simmim_path = '../../models/swin2/simmim_large.pth'

# 4.Train SimMIM

In [14]:
training_time = 0
losses = []
val_losses = []
lrs = []
best_loss = float('inf')

# GradScaler 초기화
scaler = GradScaler()

for epoch in range(train_epochs):
    model.train()
    start_time = time.time()
    running_loss = 0.0
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch + 1}")
    
    for _, data in pbar:
        image, mask = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        # AutoCast 적용
        with autocast():
            loss = model(image, mask)
            
        # 스케일링된 그라디언트 계산
        scaler.scale(loss).backward()

        # 그라디언트 클리핑 전에 스케일링 제거
        scaler.unscale_(optimizer)
        if simmim_config.TRAIN.CLIP_GRAD:
            clip_grad_norm_(model.parameters(), max_norm=100)
        else:
            clip_grad_norm_(model.parameters())

        # 옵티마이저 스텝 및 스케일러 업데이트
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
            
        lr = optimizer.param_groups[0]["lr"]
        lrs.append(lr)
        running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    losses.append(epoch_loss)

    # 모델 저장
    if epoch_loss < best_loss:
        
        best_loss = epoch_loss
        vit_save = model_save
        if vit_save:
            torch.save(model.state_dict(), simmim_path)
        
    epoch_duration = time.time() - start_time
    training_time += epoch_duration
    
    text = f'\tLoss: {epoch_loss:.4f}, LR: {lr}, Duration: {epoch_duration:.2f} sec'
    
    if vit_save:
        text += f' - model saved!'
        vit_save = False    
        
    print(text)


Epoch 1: 100%|██████████| 56/56 [01:05<00:00,  1.16s/it]


	Loss: 1.1199, LR: 7.000000000000001e-05, Duration: 67.11 sec - model saved!


Epoch 2: 100%|██████████| 56/56 [01:01<00:00,  1.09s/it]


	Loss: 0.9498, LR: 0.00014000000000000001, Duration: 64.58 sec - model saved!


Epoch 3: 100%|██████████| 56/56 [01:01<00:00,  1.09s/it]


	Loss: 0.7927, LR: 0.00020999999999999998, Duration: 64.60 sec - model saved!


Epoch 4: 100%|██████████| 56/56 [01:02<00:00,  1.11s/it]


	Loss: 0.7450, LR: 0.00028000000000000003, Duration: 65.73 sec - model saved!


Epoch 5: 100%|██████████| 56/56 [01:00<00:00,  1.09s/it]


	Loss: 0.7209, LR: 0.00035, Duration: 64.21 sec - model saved!


Epoch 6: 100%|██████████| 56/56 [01:01<00:00,  1.09s/it]


	Loss: 0.6993, LR: 0.00041999999999999996, Duration: 64.63 sec - model saved!


Epoch 7: 100%|██████████| 56/56 [01:01<00:00,  1.10s/it]


	Loss: 0.6814, LR: 0.00049, Duration: 65.06 sec - model saved!


Epoch 8: 100%|██████████| 56/56 [01:00<00:00,  1.09s/it]


	Loss: 0.6649, LR: 0.0005600000000000001, Duration: 64.20 sec - model saved!


Epoch 9: 100%|██████████| 56/56 [01:01<00:00,  1.11s/it]


	Loss: 0.6505, LR: 0.00063, Duration: 65.32 sec - model saved!


Epoch 10: 100%|██████████| 56/56 [01:01<00:00,  1.11s/it]


	Loss: 0.6422, LR: 0.0007, Duration: 65.27 sec - model saved!


Epoch 11: 100%|██████████| 56/56 [01:01<00:00,  1.10s/it]


	Loss: 0.6378, LR: 0.0007700000000000001, Duration: 65.13 sec - model saved!


Epoch 12: 100%|██████████| 56/56 [01:02<00:00,  1.11s/it]


	Loss: 0.6111, LR: 0.0008399999999999999, Duration: 65.38 sec - model saved!


Epoch 13: 100%|██████████| 56/56 [01:01<00:00,  1.10s/it]

	Loss: 0.6135, LR: 0.00091, Duration: 62.78 sec



Epoch 14: 100%|██████████| 56/56 [01:01<00:00,  1.10s/it]

	Loss: 0.6209, LR: 0.00098, Duration: 62.91 sec



Epoch 15: 100%|██████████| 56/56 [01:02<00:00,  1.12s/it]

	Loss: 0.6176, LR: 0.00105, Duration: 63.75 sec



Epoch 16: 100%|██████████| 56/56 [01:02<00:00,  1.11s/it]

	Loss: 0.6133, LR: 0.0011200000000000001, Duration: 63.59 sec



Epoch 17: 100%|██████████| 56/56 [01:02<00:00,  1.12s/it]

	Loss: 0.6136, LR: 0.0011899999999999999, Duration: 63.89 sec



Epoch 18: 100%|██████████| 56/56 [01:02<00:00,  1.12s/it]


	Loss: 0.5983, LR: 0.00126, Duration: 65.96 sec - model saved!


Epoch 19: 100%|██████████| 56/56 [01:01<00:00,  1.10s/it]


	Loss: 0.5936, LR: 0.00133, Duration: 64.66 sec - model saved!


Epoch 20: 100%|██████████| 56/56 [01:02<00:00,  1.11s/it]

	Loss: 0.6053, LR: 0.0014, Duration: 63.45 sec



Epoch 21: 100%|██████████| 56/56 [01:01<00:00,  1.10s/it]


	Loss: 0.5890, LR: 0.0013994603253685062, Duration: 64.91 sec - model saved!


Epoch 22: 100%|██████████| 56/56 [01:01<00:00,  1.10s/it]


	Loss: 0.5863, LR: 0.0013978421336131896, Duration: 64.71 sec - model saved!


Epoch 23: 100%|██████████| 56/56 [01:01<00:00,  1.10s/it]

	Loss: 0.5983, LR: 0.0013951479198684486, Duration: 62.91 sec



Epoch 24: 100%|██████████| 56/56 [01:01<00:00,  1.10s/it]


	Loss: 0.5707, LR: 0.0013913818384165965, Duration: 64.81 sec - model saved!


Epoch 25: 100%|██████████| 56/56 [01:02<00:00,  1.12s/it]

	Loss: 0.5763, LR: 0.0013865496962822614, Duration: 63.69 sec



Epoch 26: 100%|██████████| 56/56 [01:02<00:00,  1.11s/it]

	Loss: 0.5774, LR: 0.0013806589442783735, Duration: 63.59 sec



Epoch 27: 100%|██████████| 56/56 [01:02<00:00,  1.11s/it]


	Loss: 0.5624, LR: 0.001373718665517553, Duration: 65.70 sec - model saved!


Epoch 28:  14%|█▍        | 8/56 [00:10<00:54,  1.13s/it]

## Del SimMIM Model

In [None]:
model.cpu()
del model
torch.cuda.empty_cache()

# 5.Load Swin V2 for stage-2 training

In [None]:
model = SwinTransformerV2(pretrained_window_sizes=[12,12,12,12], 
                          ape=True, 
                          drop_path_rate=0.3,
                          embed_dim=192,
                          depths=[2,2,18,2],
                          num_heads=[6,12,24,48],
                          window_size=12)
model.state_dict().keys()

odict_keys(['absolute_pos_embed', 'embeddings.patch_embeddings.weight', 'embeddings.patch_embeddings.bias', 'embeddings.norm.weight', 'embeddings.norm.bias', 'stages.0.blocks.0.attn_mask', 'stages.0.blocks.0.attn.t_scale', 'stages.0.blocks.0.attn.relative_coords_table', 'stages.0.blocks.0.attn.relative_position_index', 'stages.0.blocks.0.attn.crpb_mlp.0.weight', 'stages.0.blocks.0.attn.crpb_mlp.0.bias', 'stages.0.blocks.0.attn.crpb_mlp.3.weight', 'stages.0.blocks.0.attn.qkv.weight', 'stages.0.blocks.0.attn.qkv.bias', 'stages.0.blocks.0.attn.proj.weight', 'stages.0.blocks.0.attn.proj.bias', 'stages.0.blocks.0.norm1.weight', 'stages.0.blocks.0.norm1.bias', 'stages.0.blocks.0.mlp.fc1.weight', 'stages.0.blocks.0.mlp.fc1.bias', 'stages.0.blocks.0.mlp.fc2.weight', 'stages.0.blocks.0.mlp.fc2.bias', 'stages.0.blocks.0.norm2.weight', 'stages.0.blocks.0.norm2.bias', 'stages.0.blocks.1.attn_mask', 'stages.0.blocks.1.attn.t_scale', 'stages.0.blocks.1.attn.relative_coords_table', 'stages.0.blocks.1

In [None]:
from torchsummary import summary

summary(model.to('cuda'), (3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 56, 56]           4,704
         LayerNorm-2             [-1, 3136, 96]             192
        embeddings-3             [-1, 3136, 96]               0
           Dropout-4             [-1, 3136, 96]               0
            Linear-5              [-1, 49, 288]          27,936
            Linear-6          [-1, 13, 13, 384]           1,152
              ReLU-7          [-1, 13, 13, 384]               0
           Dropout-8          [-1, 13, 13, 384]               0
            Linear-9            [-1, 13, 13, 3]           1,152
          Softmax-10            [-1, 3, 49, 49]               0
          Dropout-11            [-1, 3, 49, 49]               0
           Linear-12               [-1, 49, 96]           9,312
          Dropout-13               [-1, 49, 96]               0
  WindowAttention-14               [-1,

## Default Parameter(weight) Check
- 추후 SimMIM 가중치가 제대로 불러와졌는지 확인용

In [None]:
model.state_dict()['embeddings.patch_embeddings.weight'][0]

tensor([[[-0.0129, -0.1077, -0.0841,  0.0873],
         [ 0.0951,  0.0538, -0.1324, -0.1053],
         [ 0.1133,  0.0925, -0.0437, -0.1434],
         [-0.0528, -0.1394, -0.0026,  0.0884]],

        [[-0.0144,  0.0523, -0.0949,  0.0638],
         [-0.0402,  0.0985,  0.0427,  0.0493],
         [-0.0610,  0.0717,  0.0322,  0.1303],
         [-0.0906, -0.0806,  0.0084,  0.1267]],

        [[-0.0225,  0.0842,  0.0438,  0.0556],
         [ 0.0726, -0.0971,  0.0465,  0.0960],
         [ 0.0120,  0.0960,  0.0975,  0.1130],
         [ 0.1074, -0.0488,  0.0056, -0.1428]]], device='cuda:0')

In [None]:
model.state_dict()['stages.3.blocks.1.attn.crpb_mlp.3.weight'][0]

tensor([ 4.5615e-03,  1.8758e-02,  1.1253e-02,  3.1161e-02, -2.9341e-02,
         1.6727e-02,  1.5575e-02, -6.6814e-03, -1.8727e-02,  3.3447e-02,
        -1.0352e-02, -3.8013e-03,  6.1451e-03,  1.2605e-02,  1.4256e-02,
         3.3592e-02, -1.0751e-02, -3.6015e-03,  2.7951e-02, -1.5929e-03,
        -3.5530e-02,  3.4795e-02,  1.0762e-02, -7.6697e-03,  3.6423e-03,
        -1.6669e-02,  7.6708e-03, -4.1405e-02,  6.0026e-04, -9.1808e-03,
         1.3536e-03,  3.4790e-02, -5.9329e-02, -5.1861e-03, -1.4435e-02,
         4.2382e-03,  2.3635e-03, -1.4914e-02,  1.2138e-02,  1.1349e-02,
         1.9588e-02,  5.6263e-04,  1.7838e-02,  2.1391e-02,  3.3820e-02,
        -1.7201e-02, -2.9766e-03, -3.0597e-03,  1.5437e-02, -2.3733e-02,
         3.9303e-02,  6.1092e-03,  1.2487e-02, -1.5646e-02, -8.8374e-03,
        -8.1489e-03, -3.6487e-03,  1.4805e-02,  3.5202e-02, -4.5491e-03,
        -1.2150e-03,  1.3884e-02, -4.7024e-03, -5.4003e-03, -8.8659e-03,
         6.6497e-03, -1.7442e-02, -1.0223e-02,  1.1

## Load Swin v2 config

In [17]:
swin_config = yaml.load(open('config/train.yaml'), Loader=yaml.FullLoader)
swin_config

{'MODEL': {'TYPE': 'swinv2',
  'NAME': 'simmim_train',
  'PRETRAINED': '../../models/swin2/simmim.pth',
  'DROP_PATH_RATE': 0.2,
  'SWIN': {'EMBED_DIM': 96,
   'DEPTHS': [2, 2, 6, 2],
   'NUM_HEADS': [3, 6, 12, 24],
   'WINDOW_SIZE': 7,
   'PATCH_SIZE': 4}},
 'DATA': {'IMG_SIZE': 224,
  'MASK_PATCH_SIZE': 32,
  'MASK_RATIO': 0.6,
  'BATCH_SIZE': 960,
  'NUM_WORKERS': 24,
  'DATA_PATH': '../../data/sports'},
 'TRAIN': {'EPOCHS': 20,
  'WARMUP_EPOCHS': 10,
  'BASE_LR': '1e-4',
  'WEIGHT_DECAY': 0.05,
  'CLIP_GRAD': 5}}

## Load weight from SimMIM Model
- Different Image/Window Size
- Image와 Window 사이즈의 비율은 맞춰야함
  ex) 192÷6 = 224÷7 = 32

In [18]:
def load_pretrained(config, model):
    print(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
    state_dict = torch.load(config.MODEL.PRETRAINED, map_location='cpu')

    # remain encoder only
    not_encoder_keys = [k for k in state_dict.keys() if 'encoder' not in k]
    for k in not_encoder_keys:
        del state_dict[k]
        
    # remove prefix encoder.
    state_dict = {k.replace('encoder.', ''):v for k, v in state_dict.items()}

    # delete relative_position_index since we always re-init it
    relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete relative_coords_table since we always re-init it
    relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del state_dict[k]

    # bicubic interpolate relative_position_bias_table if not match
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = model.state_dict()[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            print(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                # bicubic interpolate relative_position_bias_table if not match
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                    mode='bicubic')
                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # bicubic interpolate absolute_pos_embed if not match
    absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
    for k in absolute_pos_embed_keys:
        # dpe
        absolute_pos_embed_pretrained = state_dict[k]
        absolute_pos_embed_current = model.state_dict()[k.replace('encoder.','')]
        _, L1, C1 = absolute_pos_embed_pretrained.size()
        _, L2, C2 = absolute_pos_embed_current.size()
        if C1 != C1:
            print(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
                state_dict[k] = absolute_pos_embed_pretrained_resized

    # check classifier, if not match, then re-init classifier to zero
    head_bias_pretrained = state_dict['classifier.bias']
    Nc1 = head_bias_pretrained.shape[0]
    Nc2 = model.classifier.bias.shape[0]
    if (Nc1 != Nc2):
        torch.nn.init.constant_(model.classifier.bias, 0.)
        torch.nn.init.constant_(model.classifier.weight, 0.)
        del state_dict['classifier.weight']
        del state_dict['classifier.bias']
        print(f"Error in loading classifier head, re-init classifier head to 0")

    msg = model.load_state_dict(state_dict, strict=False)
    print(msg)

    print(f"=> loaded successfully '{config.MODEL.PRETRAINED}'")

    torch.cuda.empty_cache()

In [19]:
swin_config = Box(swin_config)
load_pretrained(swin_config, model)

_IncompatibleKeys(missing_keys=['stages.0.blocks.0.attn_mask', 'stages.0.blocks.0.attn.relative_coords_table', 'stages.0.blocks.0.attn.relative_position_index', 'stages.0.blocks.1.attn_mask', 'stages.0.blocks.1.attn.relative_coords_table', 'stages.0.blocks.1.attn.relative_position_index', 'stages.1.blocks.0.attn_mask', 'stages.1.blocks.0.attn.relative_coords_table', 'stages.1.blocks.0.attn.relative_position_index', 'stages.1.blocks.1.attn_mask', 'stages.1.blocks.1.attn.relative_coords_table', 'stages.1.blocks.1.attn.relative_position_index', 'stages.2.blocks.0.attn_mask', 'stages.2.blocks.0.attn.relative_coords_table', 'stages.2.blocks.0.attn.relative_position_index', 'stages.2.blocks.1.attn_mask', 'stages.2.blocks.1.attn.relative_coords_table', 'stages.2.blocks.1.attn.relative_position_index', 'stages.2.blocks.2.attn_mask', 'stages.2.blocks.2.attn.relative_coords_table', 'stages.2.blocks.2.attn.relative_position_index', 'stages.2.blocks.3.attn_mask', 'stages.2.blocks.3.attn.relative_c

## Check Loading Weight Results
- 정상적으로 불러와졌는지 확인

In [20]:
model.state_dict()['embeddings.patch_embeddings.weight'][0]

tensor([[[ 0.0460,  0.0211, -0.0460,  0.0576],
         [ 0.1058,  0.0408,  0.1015, -0.0615],
         [-0.0989, -0.1073, -0.0074, -0.0704],
         [ 0.0441,  0.0788, -0.1280, -0.1076]],

        [[ 0.0988, -0.0679, -0.0443,  0.0764],
         [ 0.0506,  0.0706, -0.0776,  0.0877],
         [ 0.0787, -0.0059, -0.0263,  0.0769],
         [ 0.0604, -0.0411,  0.0631,  0.0655]],

        [[ 0.0299,  0.0698,  0.0270, -0.0802],
         [ 0.1213, -0.0450, -0.1119, -0.0372],
         [-0.0303,  0.0959,  0.0342,  0.0791],
         [ 0.0050,  0.0420, -0.1352, -0.1343]]], device='cuda:0')

In [21]:
model.state_dict()['stages.3.blocks.1.attn.crpb_mlp.3.weight'][0]

tensor([ 0.4102,  0.0190, -0.1695, -0.0780,  0.1185,  0.0560, -0.1155, -0.1401,
        -0.1512, -0.1898,  0.1406, -0.1439, -0.0065, -0.0821,  0.2524, -0.1698,
        -0.2071, -0.1644, -0.0385,  0.2512, -0.1601, -0.0571,  0.0558, -0.1932,
        -0.1750, -0.0531, -0.0358, -0.0926,  0.0338,  0.2469, -0.1665,  0.0107,
        -0.1085, -0.0957, -0.1816, -0.2193, -0.1730,  0.1014, -0.1871,  0.1377,
        -0.0585,  0.3111,  0.0440, -0.0796, -0.1696, -0.0710, -0.0333, -0.0791,
         0.0836,  0.2392, -0.1803, -0.0753, -0.1639, -0.0884,  0.0031, -0.1653,
        -0.0262,  0.2262, -0.1484, -0.1127, -0.0729, -0.1282, -0.0657,  0.1217,
        -0.1038,  0.1075,  0.0501, -0.1895, -0.0205, -0.0921, -0.0430, -0.0467,
        -0.2041, -0.1932,  0.2660,  0.1162,  0.0723,  0.0066,  0.0251, -0.1999,
        -0.1775, -0.1588, -0.1684, -0.0295,  0.0416, -0.2120, -0.0950, -0.0661,
         0.0866, -0.0630, -0.0387, -0.2299, -0.1044,  0.0472, -0.0744, -0.1581,
        -0.1818,  0.0387,  0.1348, -0.20

# 6.Stage-2 Traing
- Supervised pre-training

## Define Transform, Loss, etc.

In [22]:
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [23]:
# Transforms 정의하기
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8,1), interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.9, scale=(0.02, 0.33)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

data_dir = '../../data/sports'
batch_size = 960

train_path = data_dir+'/train'
valid_path = data_dir+'/valid'
test_path = data_dir+'/test'

# dataset load
train_data = ImageFolder(train_path, transform=train_transform)
valid_data = ImageFolder(valid_path, transform=test_transform)
test_data = ImageFolder(test_path, transform=test_transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [24]:
max_norm = 1.0 # paper : 100 with G variants

model.to(device)
model_path = '../../models/swin2/model_w_simmim.pth'

In [25]:
mixup_fn = Mixup(mixup_alpha=.7, 
                cutmix_alpha=.7, 
                prob=.7, 
                switch_prob=0.5, 
                mode='batch',
                label_smoothing=.1,
                num_classes=100)

epochs = 200

criterion = nn.CrossEntropyLoss(label_smoothing=0.)

## Layer-Wise Learning Rate Decay ★

In [26]:
layer_names = []
for i, (name, params) in enumerate(model.named_parameters()):
    lr = base_lr
    print(f'{i}: {name}')
    layer_names.append(name)

0: absolute_pos_embed
1: embeddings.patch_embeddings.weight
2: embeddings.patch_embeddings.bias
3: embeddings.norm.weight
4: embeddings.norm.bias
5: stages.0.blocks.0.attn.t_scale
6: stages.0.blocks.0.attn.crpb_mlp.0.weight
7: stages.0.blocks.0.attn.crpb_mlp.0.bias
8: stages.0.blocks.0.attn.crpb_mlp.3.weight
9: stages.0.blocks.0.attn.qkv.weight
10: stages.0.blocks.0.attn.qkv.bias
11: stages.0.blocks.0.attn.proj.weight
12: stages.0.blocks.0.attn.proj.bias
13: stages.0.blocks.0.norm1.weight
14: stages.0.blocks.0.norm1.bias
15: stages.0.blocks.0.mlp.fc1.weight
16: stages.0.blocks.0.mlp.fc1.bias
17: stages.0.blocks.0.mlp.fc2.weight
18: stages.0.blocks.0.mlp.fc2.bias
19: stages.0.blocks.0.norm2.weight
20: stages.0.blocks.0.norm2.bias
21: stages.0.blocks.1.attn.t_scale
22: stages.0.blocks.1.attn.crpb_mlp.0.weight
23: stages.0.blocks.1.attn.crpb_mlp.0.bias
24: stages.0.blocks.1.attn.crpb_mlp.3.weight
25: stages.0.blocks.1.attn.qkv.weight
26: stages.0.blocks.1.attn.qkv.bias
27: stages.0.blocks

In [27]:
layer_names.reverse()
layer_names[:5]

['classifier.bias',
 'classifier.weight',
 'layernorm.bias',
 'layernorm.weight',
 'stages.3.blocks.1.norm2.bias']

In [28]:
lr      = 1.4e-3      # paper : 1.4e-3
lr_mult = 0.87  # paper : 0.87
weight_decay = 0.01 # paper : 0.1

param_groups = []
prev_group_name = layer_names[0].split('.')[0]

for idx, name in enumerate(layer_names):
    
    cur_group_name = name.split('.')[0]
    
    if cur_group_name != prev_group_name:
        lr *= lr_mult
    prev_group_name = cur_group_name
    weight_decay = 0.01 if ('weight' in name) and ('norm' not in name) else 0
    
    print(f"{idx}: {name}'s lr={lr}, weight_decay={weight_decay}")
    
    param_groups += [{'params': [ p for n, p in model.named_parameters() if n == name and p.requires_grad],
                      'lr' : lr,
                      'weight_decay': weight_decay}]

0: classifier.bias's lr=0.0014, weight_decay=0
1: classifier.weight's lr=0.0014, weight_decay=0.01
2: layernorm.bias's lr=0.001218, weight_decay=0
3: layernorm.weight's lr=0.001218, weight_decay=0
4: stages.3.blocks.1.norm2.bias's lr=0.0010596599999999998, weight_decay=0
5: stages.3.blocks.1.norm2.weight's lr=0.0010596599999999998, weight_decay=0
6: stages.3.blocks.1.mlp.fc2.bias's lr=0.0010596599999999998, weight_decay=0
7: stages.3.blocks.1.mlp.fc2.weight's lr=0.0010596599999999998, weight_decay=0.01
8: stages.3.blocks.1.mlp.fc1.bias's lr=0.0010596599999999998, weight_decay=0
9: stages.3.blocks.1.mlp.fc1.weight's lr=0.0010596599999999998, weight_decay=0.01
10: stages.3.blocks.1.norm1.bias's lr=0.0010596599999999998, weight_decay=0
11: stages.3.blocks.1.norm1.weight's lr=0.0010596599999999998, weight_decay=0
12: stages.3.blocks.1.attn.proj.bias's lr=0.0010596599999999998, weight_decay=0
13: stages.3.blocks.1.attn.proj.weight's lr=0.0010596599999999998, weight_decay=0.01
14: stages.3.b

In [29]:
# # 모델 레이어의 이름 추출
# layer_names = []
# for i, (name, params) in enumerate(model.named_parameters()):
#     lr = base_lr
#     print(f'{i}: {name}')
#     layer_names.append(name)

# # 뒷 레이어부터 시작하도록 뒤집기    
# layer_names.reverse()

# # 하이퍼 파라미터 정의
# lr      = 1.4e-3      # paper : 1.4e-3
# lr_mult = 0.87  # paper : 0.87
# weight_decay = 0.01 # paper : 0.1

# param_groups = []
# prev_group_name = layer_names[0].split('.')[0] # 그룹명 초기화

# for idx, name in enumerate(layer_names):    
#     cur_group_name = name.split('.')[0]    
#     if cur_group_name != prev_group_name: # 동일한 그룹에 속하면 동일한 학습율
#         lr *= lr_mult
#     prev_group_name = cur_group_name    
    
#     param_groups += [{'params': [ p for n, p in model.named_parameters() if n == name and p.requires_grad],
#                       'lr' : lr,
#                       'weight_decay': weight_decay}]

In [30]:
optimizer = optim.AdamW(param_groups)
warmup_steps = int(len(train_loader)*(epochs)*0.1)
train_steps = len(train_loader)*(epochs)
scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, 
                                                        num_warmup_steps=warmup_steps, 
                                                        num_training_steps=train_steps,
                                                        num_cycles=0.5)

## Model Train
- 100에포크 먼저 학습하며 결과 확인하고, 이후 10에포크 학습하며 결과 확인

In [31]:
training_time = 0
losses = []
val_losses = []
lrs = []
best_loss = float('inf')

# GradScaler 초기화
scaler = GradScaler()

for epoch in range(50):
    model.train()
    start_time = time.time()
    running_loss = 0.0
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")
    
    for _, data in pbar:
        inputs, labels = data[0].to(device), data[1].to(device)
        inputs, labels = mixup_fn(inputs, labels)
        optimizer.zero_grad()

        # AutoCast 적용
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
        # 스케일링된 그라디언트 계산
        scaler.scale(loss).backward()

        # 그라디언트 클리핑 전에 스케일링 제거
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_norm=max_norm)

        # 옵티마이저 스텝 및 스케일러 업데이트
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
            
        lr = optimizer.param_groups[0]["lr"]
        lrs.append(lr)
        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    losses.append(epoch_loss)        

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in valid_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
    val_loss /= len(valid_loader)
    val_losses.append(val_loss)
    
    total_loss = epoch_loss + val_loss
    
    # 모델 저장
    if total_loss < best_loss:
        best_loss = total_loss
        vit_save = True
        if vit_save:
            torch.save(model.state_dict(), model_path)

    epoch_duration = time.time() - start_time
    training_time += epoch_duration
    
    text = f'\tLoss: {epoch_loss}, Val Loss: {val_loss}, Total Loss: {total_loss}, LR: {lr}, Duration: {epoch_duration:.2f} sec'
    
    if vit_save:
        text += f' - model saved!'
        vit_save = False

    print(text)
        
text = f"Epoch 당 평균 소요시간 : {training_time / 50:.2f}초"      
print(text)

Epoch 1: 100%|██████████| 15/15 [00:53<00:00,  3.58s/it]


	Loss: 4.648767407735189, Val Loss: 4.551398277282715, Total Loss: 9.200165685017904, LR: 7.000000000000001e-05, Duration: 55.21 sec - model saved!


Epoch 2: 100%|██████████| 15/15 [00:51<00:00,  3.44s/it]


In [None]:
from sklearn.metrics import confusion_matrix
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [None]:
# 예측 수행 및 레이블 저장
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# 혼동 행렬 생성
cm = confusion_matrix(all_labels, all_preds)

# 예측과 실제 레이블
y_true = all_labels  # 실제 레이블
y_pred = all_preds  # 모델에 의해 예측된 레이블

# 전체 데이터셋에 대한 정확도
accuracy = accuracy_score(y_true, y_pred)

# 평균 정밀도, 리콜, F1-Score ('weighted')
precision, recall, f1_score, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')

# 판다스 데이터프레임으로 결과 정리
performance_metrics = pd.DataFrame({
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score'],
    'Value': [accuracy, precision, recall, f1_score]
})

# 데이터프레임 출력
performance_metrics

  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,Metric,Value
0,Accuracy,0.73
1,Precision,0.761444
2,Recall,0.73
3,F1 Score,0.720376


### 100 Epoch Result

In [None]:
for epoch in range(50):
    model.train()
    start_time = time.time()
    running_loss = 0.0
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")
    
    for _, data in pbar:
        inputs, labels = data[0].to(device), data[1].to(device)
        inputs, labels = mixup_fn(inputs, labels)
        optimizer.zero_grad()

        # AutoCast 적용
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
        # 스케일링된 그라디언트 계산
        scaler.scale(loss).backward()

        # 그라디언트 클리핑 전에 스케일링 제거
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_norm=max_norm)

        # 옵티마이저 스텝 및 스케일러 업데이트
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
            
        lr = optimizer.param_groups[0]["lr"]
        lrs.append(lr)
        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    losses.append(epoch_loss)        

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in valid_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
    val_loss /= len(valid_loader)
    val_losses.append(val_loss)
    
    total_loss = epoch_loss + val_loss
    
    # 모델 저장
    if total_loss < best_loss:
        best_loss = total_loss
        vit_save = True
        if vit_save:
            torch.save(model.state_dict(), model_path)

    epoch_duration = time.time() - start_time
    training_time += epoch_duration
    
    text = f'\tLoss: {epoch_loss}, Val Loss: {val_loss}, Total Loss: {total_loss}, LR: {lr}, Duration: {epoch_duration:.2f} sec'
    
    if vit_save:
        text += f' - model saved!'
        vit_save = False

    print(text)

# 예측 수행 및 레이블 저장
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# 혼동 행렬 생성
cm = confusion_matrix(all_labels, all_preds)

# 예측과 실제 레이블
y_true = all_labels  # 실제 레이블
y_pred = all_preds  # 모델에 의해 예측된 레이블

# 전체 데이터셋에 대한 정확도
accuracy = accuracy_score(y_true, y_pred)

# 평균 정밀도, 리콜, F1-Score ('weighted')
precision, recall, f1_score, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')

# 판다스 데이터프레임으로 결과 정리
performance_metrics = pd.DataFrame({
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score'],
    'Value': [accuracy, precision, recall, f1_score]
})

# 데이터프레임 출력
performance_metrics

Epoch 1: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.656868060429891, Val Loss: 1.0866397619247437, Total Loss: 3.7435078223546348, LR: 0.0013000171104914787, Duration: 52.05 sec - model saved!


Epoch 2: 100%|██████████| 15/15 [00:51<00:00,  3.41s/it]


	Loss: 2.7411521037419635, Val Loss: 1.1920576095581055, Total Loss: 3.933209713300069, LR: 0.001293633667309498, Duration: 52.05 sec


Epoch 3: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.6991294225056968, Val Loss: 1.190779685974121, Total Loss: 3.889909108479818, LR: 0.001287069397561797, Duration: 51.69 sec


Epoch 4: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.269495209058126, Val Loss: 1.1089175939559937, Total Loss: 3.3784128030141196, LR: 0.001280326300788529, Duration: 51.91 sec - model saved!


Epoch 5: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.632977509498596, Val Loss: 1.0065799951553345, Total Loss: 3.6395575046539306, LR: 0.0012734064310022943, Duration: 51.72 sec


Epoch 6: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.4295517921447756, Val Loss: 0.9872391819953918, Total Loss: 3.4167909741401674, LR: 0.0012663118960624632, Duration: 51.63 sec


Epoch 7: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.657050069173177, Val Loss: 1.0636014938354492, Total Loss: 3.720651563008626, LR: 0.001259044857033105, Duration: 51.78 sec


Epoch 8: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.793322213490804, Val Loss: 1.0805952548980713, Total Loss: 3.8739174683888753, LR: 0.0012516075275247052, Duration: 51.65 sec


Epoch 9: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.4584624767303467, Val Loss: 0.9559261202812195, Total Loss: 3.414388597011566, LR: 0.0012440021730198796, Duration: 51.61 sec


Epoch 10: 100%|██████████| 15/15 [00:51<00:00,  3.40s/it]


	Loss: 3.0136234680811564, Val Loss: 1.0421936511993408, Total Loss: 4.055817119280498, LR: 0.0012362311101832846, Duration: 52.00 sec


Epoch 11: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.4933032274246214, Val Loss: 0.9437807202339172, Total Loss: 3.4370839476585386, LR: 0.0012282967061559404, Duration: 51.65 sec


Epoch 12: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.4297483364741006, Val Loss: 1.0428630113601685, Total Loss: 3.472611347834269, LR: 0.001220201377834176, Duration: 51.58 sec


Epoch 13: 100%|██████████| 15/15 [00:51<00:00,  3.41s/it]


	Loss: 2.4320934375127155, Val Loss: 0.9670103192329407, Total Loss: 3.399103756745656, LR: 0.0012119475911334192, Duration: 52.07 sec


Epoch 14: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.5664584159851076, Val Loss: 0.9769777059555054, Total Loss: 3.543436121940613, LR: 0.0012035378602370558, Duration: 51.67 sec


Epoch 15: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 2.6650139888127646, Val Loss: 0.9653957486152649, Total Loss: 3.6304097374280295, LR: 0.0011949747468305832, Duration: 51.50 sec


Epoch 16: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.5750688632329304, Val Loss: 0.995498538017273, Total Loss: 3.5705674012502033, LR: 0.0011862608593212981, Duration: 51.75 sec


Epoch 17: 100%|██████████| 15/15 [00:51<00:00,  3.41s/it]


	Loss: 2.7154326597849527, Val Loss: 0.9883191585540771, Total Loss: 3.70375181833903, LR: 0.001177398852043749, Duration: 52.04 sec


Epoch 18: 100%|██████████| 15/15 [00:51<00:00,  3.41s/it]


	Loss: 2.2634711186091105, Val Loss: 0.904066801071167, Total Loss: 3.1675379196802775, LR: 0.0011683914244512007, Duration: 52.37 sec - model saved!


Epoch 19: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.256103507677714, Val Loss: 0.8322899341583252, Total Loss: 3.0883934418360393, LR: 0.001159241320293355, Duration: 52.02 sec - model saved!


Epoch 20: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.2764090061187745, Val Loss: 0.8936916589736938, Total Loss: 3.1701006650924684, LR: 0.0011499513267805774, Duration: 51.83 sec


Epoch 21: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.439480749766032, Val Loss: 0.8417661190032959, Total Loss: 3.281246868769328, LR: 0.0011405242737348863, Duration: 51.84 sec


Epoch 22: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.376136644681295, Val Loss: 0.8899009823799133, Total Loss: 3.266037627061208, LR: 0.0011309630327279608, Duration: 51.64 sec


Epoch 23: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.1310139020284016, Val Loss: 0.7905547022819519, Total Loss: 2.9215686043103535, LR: 0.0011212705162064339, Duration: 51.89 sec - model saved!


Epoch 24: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.1650134166081747, Val Loss: 0.8339634537696838, Total Loss: 2.9989768703778585, LR: 0.0011114496766047313, Duration: 51.78 sec


Epoch 25: 100%|██████████| 15/15 [00:51<00:00,  3.40s/it]


	Loss: 2.33322536945343, Val Loss: 0.9288654923439026, Total Loss: 3.2620908617973328, LR: 0.0011015035054457321, Duration: 51.91 sec


Epoch 26: 100%|██████████| 15/15 [00:50<00:00,  3.40s/it]


	Loss: 2.1799781719843545, Val Loss: 0.861397385597229, Total Loss: 3.0413755575815835, LR: 0.0010914350324295228, Duration: 51.87 sec


Epoch 27: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.6036650419235228, Val Loss: 0.94874107837677, Total Loss: 3.552406120300293, LR: 0.001081247324510519, Duration: 51.68 sec


Epoch 28: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.4560523907343548, Val Loss: 0.8359626531600952, Total Loss: 3.29201504389445, LR: 0.0010709434849632434, Duration: 51.67 sec


Epoch 29: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.9994967619578043, Val Loss: 0.7379708886146545, Total Loss: 2.7374676505724587, LR: 0.001060526652437038, Duration: 52.04 sec - model saved!


Epoch 30: 100%|██████████| 15/15 [00:50<00:00,  3.40s/it]


	Loss: 2.3076162656148274, Val Loss: 0.811880350112915, Total Loss: 3.1194966157277424, LR: 0.00105, Duration: 51.86 sec


Epoch 31: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.3585557381312054, Val Loss: 0.8148956298828125, Total Loss: 3.173451368014018, LR: 0.001039366734172436, Duration: 51.77 sec


Epoch 32: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.8906003713607789, Val Loss: 0.7121273279190063, Total Loss: 2.602727699279785, LR: 0.0010286300939501235, Duration: 51.96 sec - model saved!


Epoch 33: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.45534127553304, Val Loss: 0.784798800945282, Total Loss: 3.240140076478322, LR: 0.0010177933498176828, Duration: 51.65 sec


Epoch 34: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.142392118771871, Val Loss: 0.744280219078064, Total Loss: 2.886672337849935, LR: 0.001006859802752354, Duration: 51.76 sec


Epoch 35: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.0640012900034588, Val Loss: 0.7880378365516663, Total Loss: 2.852039126555125, LR: 0.0009958327832184897, Duration: 51.77 sec


Epoch 36: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.1263386329015095, Val Loss: 0.7130172252655029, Total Loss: 2.8393558581670124, LR: 0.0009847156501530602, Duration: 51.69 sec


Epoch 37: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.9871632655461628, Val Loss: 0.7495275735855103, Total Loss: 2.736690839131673, LR: 0.0009735117899424916, Duration: 51.67 sec


Epoch 38: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.319912830988566, Val Loss: 0.7525551319122314, Total Loss: 3.0724679629007974, LR: 0.0009622246153911386, Duration: 51.80 sec


Epoch 39: 100%|██████████| 15/15 [00:50<00:00,  3.40s/it]


	Loss: 1.8786864121754965, Val Loss: 0.7501755356788635, Total Loss: 2.62886194785436, LR: 0.0009508575646817101, Duration: 51.87 sec


Epoch 40: 100%|██████████| 15/15 [00:51<00:00,  3.41s/it]


	Loss: 2.2865551551183065, Val Loss: 0.7636071443557739, Total Loss: 3.0501622994740805, LR: 0.0009394141003279682, Duration: 52.01 sec


Epoch 41: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.3690157016118367, Val Loss: 0.7316689491271973, Total Loss: 3.100684650739034, LR: 0.0009278977081200097, Duration: 51.68 sec


Epoch 42: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.115553347269694, Val Loss: 0.7264184951782227, Total Loss: 2.841971842447917, LR: 0.0009163118960624632, Duration: 51.60 sec


Epoch 43: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.7769123713175456, Val Loss: 0.6899837255477905, Total Loss: 2.4668960968653364, LR: 0.0009046601933059157, Duration: 52.02 sec - model saved!


Epoch 44: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.16112121740977, Val Loss: 0.6956101655960083, Total Loss: 2.856731383005778, LR: 0.0008929461490718994, Duration: 51.79 sec


Epoch 45: 100%|██████████| 15/15 [00:51<00:00,  3.40s/it]


	Loss: 2.3510693470637003, Val Loss: 0.7207252979278564, Total Loss: 3.0717946449915567, LR: 0.0008811733315717645, Duration: 51.92 sec


Epoch 46: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.1274606625239056, Val Loss: 0.7758668065071106, Total Loss: 2.9033274690310162, LR: 0.0008693453269197673, Duration: 51.56 sec


Epoch 47: 100%|██████████| 15/15 [00:51<00:00,  3.40s/it]


	Loss: 2.0673088312149046, Val Loss: 0.7239354252815247, Total Loss: 2.7912442564964293, LR: 0.0008574657380407056, Duration: 51.97 sec


Epoch 48: 100%|██████████| 15/15 [00:50<00:00,  3.40s/it]


	Loss: 1.9367709716161092, Val Loss: 0.6858772039413452, Total Loss: 2.6226481755574547, LR: 0.0008455381835724314, Duration: 51.83 sec


Epoch 49: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.8003018697102864, Val Loss: 0.664739727973938, Total Loss: 2.465041597684224, LR: 0.0008335662967635814, Duration: 51.82 sec - model saved!


Epoch 50: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.802888774871826, Val Loss: 0.7137707471847534, Total Loss: 2.5166595220565795, LR: 0.0008215537243668514, Duration: 51.75 sec


  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,Metric,Value
0,Accuracy,0.838
1,Precision,0.862524
2,Recall,0.838
3,F1 Score,0.829496


In [None]:
for epoch in range(50):
    model.train()
    start_time = time.time()
    running_loss = 0.0
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")
    
    for _, data in pbar:
        inputs, labels = data[0].to(device), data[1].to(device)
        inputs, labels = mixup_fn(inputs, labels)
        optimizer.zero_grad()

        # AutoCast 적용
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
        # 스케일링된 그라디언트 계산
        scaler.scale(loss).backward()

        # 그라디언트 클리핑 전에 스케일링 제거
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_norm=max_norm)

        # 옵티마이저 스텝 및 스케일러 업데이트
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
            
        lr = optimizer.param_groups[0]["lr"]
        lrs.append(lr)
        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    losses.append(epoch_loss)        

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in valid_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
    val_loss /= len(valid_loader)
    val_losses.append(val_loss)
    
    total_loss = epoch_loss + val_loss
    
    # 모델 저장
    if total_loss < best_loss:
        best_loss = total_loss
        vit_save = True
        if vit_save:
            torch.save(model.state_dict(), model_path)

    epoch_duration = time.time() - start_time
    training_time += epoch_duration
    
    text = f'\tLoss: {epoch_loss}, Val Loss: {val_loss}, Total Loss: {total_loss}, LR: {lr}, Duration: {epoch_duration:.2f} sec'
    
    if vit_save:
        text += f' - model saved!'
        vit_save = False

    print(text)

# 예측 수행 및 레이블 저장
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# 혼동 행렬 생성
cm = confusion_matrix(all_labels, all_preds)

# 예측과 실제 레이블
y_true = all_labels  # 실제 레이블
y_pred = all_preds  # 모델에 의해 예측된 레이블

# 전체 데이터셋에 대한 정확도
accuracy = accuracy_score(y_true, y_pred)

# 평균 정밀도, 리콜, F1-Score ('weighted')
precision, recall, f1_score, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')

# 판다스 데이터프레임으로 결과 정리
performance_metrics = pd.DataFrame({
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score'],
    'Value': [accuracy, precision, recall, f1_score]
})

# 데이터프레임 출력
performance_metrics

Epoch 1: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.0989944458007814, Val Loss: 0.7974388599395752, Total Loss: 2.8964333057403566, LR: 0.0008095041255281617, Duration: 51.54 sec


Epoch 2: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.1859948714574178, Val Loss: 0.7359203696250916, Total Loss: 2.9219152410825093, LR: 0.0007974211706720458, Duration: 51.77 sec


Epoch 3: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.0909013986587524, Val Loss: 0.7298261523246765, Total Loss: 2.820727550983429, LR: 0.0007853085403836032, Duration: 51.68 sec


Epoch 4: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.980771509806315, Val Loss: 0.720022439956665, Total Loss: 2.7007939497629803, LR: 0.0007731699242873575, Duration: 51.68 sec


Epoch 5: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.8991553703943889, Val Loss: 0.7231766581535339, Total Loss: 2.622332028547923, LR: 0.0007610090199233608, Duration: 51.74 sec


Epoch 6: 100%|██████████| 15/15 [00:51<00:00,  3.40s/it]


	Loss: 2.0280269145965577, Val Loss: 0.6833068132400513, Total Loss: 2.711333727836609, LR: 0.0007488295316208876, Duration: 51.93 sec


Epoch 7: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.7484649022420247, Val Loss: 0.7237380146980286, Total Loss: 2.472202916940053, LR: 0.0007366351693700608, Duration: 51.66 sec


Epoch 8: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.7490078926086425, Val Loss: 0.6863779425621033, Total Loss: 2.435385835170746, LR: 0.0007244296476917508, Duration: 51.98 sec - model saved!


Epoch 9: 100%|██████████| 15/15 [00:50<00:00,  3.40s/it]


	Loss: 2.4133824904759726, Val Loss: 0.7695233225822449, Total Loss: 3.1829058130582175, LR: 0.0007122166845060985, Duration: 51.90 sec


Epoch 10: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.989539376894633, Val Loss: 0.7407202124595642, Total Loss: 2.730259589354197, LR: 0.0007, Duration: 51.72 sec


Epoch 11: 100%|██████████| 15/15 [00:51<00:00,  3.42s/it]


	Loss: 2.04110627969106, Val Loss: 0.7122923731803894, Total Loss: 2.7533986528714496, LR: 0.0006877833154939015, Duration: 52.19 sec


Epoch 12: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.9627915302912393, Val Loss: 0.7036316394805908, Total Loss: 2.6664231697718304, LR: 0.0006755703523082495, Duration: 51.79 sec


Epoch 13: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.215097181002299, Val Loss: 0.7650821805000305, Total Loss: 2.9801793615023295, LR: 0.0006633648306299393, Duration: 51.62 sec


Epoch 14: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.732415795326233, Val Loss: 0.6545288562774658, Total Loss: 2.3869446516036987, LR: 0.0006511704683791123, Duration: 51.93 sec - model saved!


Epoch 15: 100%|██████████| 15/15 [00:51<00:00,  3.40s/it]


	Loss: 2.1403245528539023, Val Loss: 0.7416877150535583, Total Loss: 2.8820122679074607, LR: 0.0006389909800766392, Duration: 51.93 sec


Epoch 16: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 2.3462406635284423, Val Loss: 0.6923199892044067, Total Loss: 3.038560652732849, LR: 0.0006268300757126426, Duration: 51.52 sec


Epoch 17: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.9328386584917705, Val Loss: 0.7024814486503601, Total Loss: 2.635320107142131, LR: 0.0006146914596163969, Duration: 51.69 sec


Epoch 18: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.8812280178070069, Val Loss: 0.7228487730026245, Total Loss: 2.604076790809631, LR: 0.0006025788293279544, Duration: 51.58 sec


Epoch 19: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.9970248063405356, Val Loss: 0.7064083814620972, Total Loss: 2.7034331878026325, LR: 0.0005904958744718383, Duration: 51.78 sec


Epoch 20: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.5427856683731078, Val Loss: 0.6887787580490112, Total Loss: 2.231564426422119, LR: 0.0005784462756331488, Duration: 51.86 sec - model saved!


Epoch 21: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.8379464387893676, Val Loss: 0.8101152181625366, Total Loss: 2.648061656951904, LR: 0.0005664337032364186, Duration: 51.80 sec


Epoch 22: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.21699279944102, Val Loss: 0.7475665211677551, Total Loss: 2.964559320608775, LR: 0.0005544618164275686, Duration: 51.82 sec


Epoch 23: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.005228034655253, Val Loss: 0.6987837553024292, Total Loss: 2.7040117899576823, LR: 0.0005425342619592945, Duration: 51.70 sec


Epoch 24: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.0032418886820476, Val Loss: 0.7424818873405457, Total Loss: 2.745723776022593, LR: 0.0005306546730802327, Duration: 51.70 sec


Epoch 25: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.8295075654983521, Val Loss: 0.6974217295646667, Total Loss: 2.526929295063019, LR: 0.0005188266684282354, Duration: 51.62 sec


Epoch 26: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.776107676823934, Val Loss: 0.7711319923400879, Total Loss: 2.547239669164022, LR: 0.0005070538509281006, Duration: 51.64 sec


Epoch 27: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.6546895662943522, Val Loss: 0.74345862865448, Total Loss: 2.398148194948832, LR: 0.0004953398066940844, Duration: 51.84 sec


Epoch 28: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 2.0837464809417723, Val Loss: 0.7239661812782288, Total Loss: 2.807712662220001, LR: 0.0004836881039375369, Duration: 51.46 sec


Epoch 29: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.1784119606018066, Val Loss: 0.6969752907752991, Total Loss: 2.8753872513771057, LR: 0.00047210229187999046, Duration: 51.80 sec


Epoch 30: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.9118168989817301, Val Loss: 0.7019824385643005, Total Loss: 2.6137993375460304, LR: 0.0004605858996720319, Duration: 51.73 sec


Epoch 31: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 2.0050169944763185, Val Loss: 0.7253676056861877, Total Loss: 2.7303846001625063, LR: 0.0004491424353182898, Duration: 51.71 sec


Epoch 32: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.8579985698064168, Val Loss: 0.7260017991065979, Total Loss: 2.5840003689130144, LR: 0.0004377753846088615, Duration: 51.76 sec


Epoch 33: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.6451035817464192, Val Loss: 0.6968458294868469, Total Loss: 2.341949411233266, LR: 0.0004264882100575085, Duration: 51.80 sec


Epoch 34: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.5046961069107057, Val Loss: 0.7231835126876831, Total Loss: 2.2278796195983888, LR: 0.00041528434984693997, Duration: 51.93 sec - model saved!


Epoch 35: 100%|██████████| 15/15 [00:50<00:00,  3.40s/it]


	Loss: 1.7769936362902323, Val Loss: 0.70062255859375, Total Loss: 2.477616194883982, LR: 0.00040416721678151053, Duration: 51.89 sec


Epoch 36: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.915289870897929, Val Loss: 0.7166441082954407, Total Loss: 2.6319339791933696, LR: 0.00039314019724764573, Duration: 51.78 sec


Epoch 37: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 2.111906611919403, Val Loss: 0.7452526092529297, Total Loss: 2.8571592211723327, LR: 0.0003822066501823173, Duration: 51.54 sec


Epoch 38: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.8990045229593913, Val Loss: 0.7047130465507507, Total Loss: 2.6037175695101418, LR: 0.00037136990604987665, Duration: 51.67 sec


Epoch 39: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 1.7554991563161215, Val Loss: 0.7184057235717773, Total Loss: 2.4739048798878986, LR: 0.0003606332658275641, Duration: 51.53 sec


Epoch 40: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.9059346437454223, Val Loss: 0.703572154045105, Total Loss: 2.6095067977905275, LR: 0.00035000000000000016, Duration: 51.83 sec


Epoch 41: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 1.6268874049186706, Val Loss: 0.7178221940994263, Total Loss: 2.344709599018097, LR: 0.000339473347562962, Duration: 51.55 sec


Epoch 42: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 1.893630361557007, Val Loss: 0.7163046598434448, Total Loss: 2.6099350214004517, LR: 0.00032905651503675667, Duration: 51.51 sec


Epoch 43: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 1.9072073419888815, Val Loss: 0.7510281801223755, Total Loss: 2.658235522111257, LR: 0.00031875267548948103, Duration: 51.39 sec


Epoch 44: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.6641966144243876, Val Loss: 0.7268722653388977, Total Loss: 2.391068879763285, LR: 0.0003085649675704773, Duration: 51.63 sec


Epoch 45: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.7722946961720785, Val Loss: 0.7361909747123718, Total Loss: 2.5084856708844505, LR: 0.0002984964945542679, Duration: 51.63 sec


Epoch 46: 100%|██████████| 15/15 [00:50<00:00,  3.40s/it]


	Loss: 1.9539666215578715, Val Loss: 0.6926761269569397, Total Loss: 2.646642748514811, LR: 0.0002885503233952689, Duration: 51.89 sec


Epoch 47: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.8854046662648518, Val Loss: 0.7984238862991333, Total Loss: 2.683828552563985, LR: 0.00027872948379356616, Duration: 51.78 sec


Epoch 48: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.3878305832544962, Val Loss: 0.7342281341552734, Total Loss: 2.12205871740977, LR: 0.0002690369672720392, Duration: 51.96 sec - model saved!


Epoch 49: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.8025549252827961, Val Loss: 0.7834681272506714, Total Loss: 2.5860230525334673, LR: 0.0002594757262651139, Duration: 51.59 sec


Epoch 50: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.3593415141105651, Val Loss: 0.7738597393035889, Total Loss: 2.1332012534141542, LR: 0.00025004867321942243, Duration: 51.58 sec


Unnamed: 0,Metric,Value
0,Accuracy,0.86
1,Precision,0.880127
2,Recall,0.86
3,F1 Score,0.852054


In [None]:
for epoch in range(50):
    model.train()
    start_time = time.time()
    running_loss = 0.0
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")
    
    for _, data in pbar:
        inputs, labels = data[0].to(device), data[1].to(device)
        inputs, labels = mixup_fn(inputs, labels)
        optimizer.zero_grad()

        # AutoCast 적용
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
        # 스케일링된 그라디언트 계산
        scaler.scale(loss).backward()

        # 그라디언트 클리핑 전에 스케일링 제거
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_norm=max_norm)

        # 옵티마이저 스텝 및 스케일러 업데이트
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
            
        lr = optimizer.param_groups[0]["lr"]
        lrs.append(lr)
        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    losses.append(epoch_loss)        

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in valid_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
    val_loss /= len(valid_loader)
    val_losses.append(val_loss)
    
    total_loss = epoch_loss + val_loss
    
    # 모델 저장
    if total_loss < best_loss:
        best_loss = total_loss
        vit_save = True
        if vit_save:
            torch.save(model.state_dict(), model_path)

    epoch_duration = time.time() - start_time
    training_time += epoch_duration
    
    text = f'\tLoss: {epoch_loss}, Val Loss: {val_loss}, Total Loss: {total_loss}, LR: {lr}, Duration: {epoch_duration:.2f} sec'
    
    if vit_save:
        text += f' - model saved!'
        vit_save = False

    print(text)

# 예측 수행 및 레이블 저장
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# 혼동 행렬 생성
cm = confusion_matrix(all_labels, all_preds)

# 예측과 실제 레이블
y_true = all_labels  # 실제 레이블
y_pred = all_preds  # 모델에 의해 예측된 레이블

# 전체 데이터셋에 대한 정확도
accuracy = accuracy_score(y_true, y_pred)

# 평균 정밀도, 리콜, F1-Score ('weighted')
precision, recall, f1_score, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')

# 판다스 데이터프레임으로 결과 정리
performance_metrics = pd.DataFrame({
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score'],
    'Value': [accuracy, precision, recall, f1_score]
})

# 데이터프레임 출력
performance_metrics

Epoch 1: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.8682246923446655, Val Loss: 0.7779513597488403, Total Loss: 2.646176052093506, LR: 0.000240758679706645, Duration: 51.68 sec


Epoch 2: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.9862423857053122, Val Loss: 0.7408369183540344, Total Loss: 2.7270793040593464, LR: 0.00023160857554879947, Duration: 51.68 sec


Epoch 3: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.7568724592526754, Val Loss: 0.765451192855835, Total Loss: 2.5223236521085104, LR: 0.00022260114795625115, Duration: 51.58 sec


Epoch 4: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.6125226020812988, Val Loss: 0.7139967083930969, Total Loss: 2.3265193104743958, LR: 0.00021373914067870185, Duration: 51.80 sec


Epoch 5: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.512954306602478, Val Loss: 0.7604970932006836, Total Loss: 2.273451399803162, LR: 0.00020502525316941678, Duration: 51.60 sec


Epoch 6: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.6402256488800049, Val Loss: 0.7476833462715149, Total Loss: 2.3879089951515198, LR: 0.00019646213976294433, Duration: 51.63 sec


Epoch 7: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.6622195839881897, Val Loss: 0.7394412159919739, Total Loss: 2.4016607999801636, LR: 0.00018805240886658067, Duration: 51.79 sec


Epoch 8: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 1.501959216594696, Val Loss: 0.7771209478378296, Total Loss: 2.279080164432526, LR: 0.00017979862216582396, Duration: 51.42 sec


Epoch 9: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.9299867312113443, Val Loss: 0.703926682472229, Total Loss: 2.6339134136835733, LR: 0.0001717032938440596, Duration: 51.64 sec


Epoch 10: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 2.096792455514272, Val Loss: 0.8021771907806396, Total Loss: 2.8989696462949115, LR: 0.00016376888981671546, Duration: 51.53 sec


Epoch 11: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.7682509620984395, Val Loss: 0.7559312582015991, Total Loss: 2.5241822203000384, LR: 0.00015599782698012037, Duration: 51.58 sec


Epoch 12: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 1.64539532661438, Val Loss: 0.7952052354812622, Total Loss: 2.440600562095642, LR: 0.00014839247247529466, Duration: 51.46 sec


Epoch 13: 100%|██████████| 15/15 [00:50<00:00,  3.39s/it]


	Loss: 1.8098185380299887, Val Loss: 0.7582208514213562, Total Loss: 2.5680393894513447, LR: 0.00014095514296689517, Duration: 51.83 sec


Epoch 14: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.7548200289408367, Val Loss: 0.7385174036026001, Total Loss: 2.493337432543437, LR: 0.00013368810393753685, Duration: 51.62 sec


Epoch 15: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 1.7015121698379516, Val Loss: 0.7559743523597717, Total Loss: 2.4574865221977236, LR: 0.00012659356899770565, Duration: 51.43 sec


Epoch 16: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 1.7704113443692526, Val Loss: 0.7706521153450012, Total Loss: 2.5410634597142536, LR: 0.00011967369921147086, Duration: 51.53 sec


Epoch 17: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.9385425845781963, Val Loss: 0.7434729337692261, Total Loss: 2.6820155183474226, LR: 0.00011293060243820324, Duration: 51.69 sec


Epoch 18: 100%|██████████| 15/15 [00:50<00:00,  3.38s/it]


	Loss: 1.601182222366333, Val Loss: 0.7555084824562073, Total Loss: 2.3566907048225403, LR: 0.00010636633269050183, Duration: 51.57 sec


Epoch 19:   0%|          | 0/15 [00:00<?, ?it/s]