In [1]:
!pip install python-box

[0m

In [2]:
import numpy as np
import yaml
from box import Box

import torch
import torch.nn as nn
import torch.optim as optim

import simmim
from swin_v2 import SwinTransformerV2

from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
import time

from timm.data import Mixup
import transformers

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
simmim_config = yaml.load(open('config/pretrain.yaml'), Loader=yaml.FullLoader)
simmim_config

{'MODEL': {'TYPE': 'swinv2',
  'NAME': 'simmim_pretrain',
  'DROP_PATH_RATE': 0.1,
  'SIMMIM': {'NORM_TARGET': {'ENABLE': True, 'PATCH_SIZE': 47}},
  'SWIN': {'EMBED_DIM': 96,
   'DEPTHS': [2, 2, 6, 2],
   'NUM_HEADS': [3, 6, 12, 24],
   'WINDOW_SIZE': 7,
   'PATCH_SIZE': 4}},
 'DATA': {'IMG_SIZE': 224,
  'MASK_PATCH_SIZE': 32,
  'MASK_RATIO': 0.6,
  'BATCH_SIZE': 512,
  'NUM_WORKERS': 24,
  'DATA_PATH': '../../data/sports/train'},
 'TRAIN': {'EPOCHS': 100,
  'WARMUP_EPOCHS': 10,
  'BASE_LR': '1e-4',
  'WARMUP_LR': '5e-7',
  'WEIGHT_DECAY': 0.05,
  'CLIP_GRAD': 5,
  'LR_SCHEDULER': {'NAME': 'multistep', 'GAMMA': 0.1, 'MULTISTEPS': [70]}},
 'PRINT_FREQ': 100,
 'SAVE_FREQ': 5,
 'TAG': 'simmim_pretrain__swinv2_tiny__img224_window7__800ep'}

In [4]:
encoder_config = {'img_size':simmim_config['DATA']['IMG_SIZE'], 
                'patch_size':simmim_config['MODEL']['SWIN']['PATCH_SIZE'], 
                'in_chans':3, 
                'num_classes':100,
                'embed_dim':simmim_config['MODEL']['SWIN']['EMBED_DIM'], 
                'depths':simmim_config['MODEL']['SWIN']['DEPTHS'], 
                'num_heads':simmim_config['MODEL']['SWIN']['NUM_HEADS'],           
                'window_size':simmim_config['MODEL']['SWIN']['WINDOW_SIZE'], 
                'mlp_ratio':4., 
                'qkv_bias':True, 
                'qk_scale':None,
                'drop_rate':0., 
                'attn_drop_rate':0., 
                'drop_path_rate':simmim_config['MODEL']['DROP_PATH_RATE'],
                'norm_layer':nn.LayerNorm, 
                'patch_norm':True, 
                'pretrained_window_sizes':[0,0,0,0],
                'ape':True}

encoder_stride = 32
in_chans = encoder_config['in_chans']
patch_size = encoder_config['patch_size']

In [5]:
encoder = simmim.SwinTransformerV2ForSimMIM(**encoder_config)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [6]:
model = simmim.SimMIM( encoder=encoder, 
                       encoder_stride=encoder_stride, 
                       in_chans=in_chans, 
                       patch_size=patch_size)

In [7]:
simmim_config = Box(simmim_config)
dataloader = simmim.build_loader_simmim(simmim_config)

samples = next(iter(dataloader))
len(samples)

3

In [8]:
samples[0].shape, samples[1].shape, samples[2].shape 

(torch.Size([512, 3, 224, 224]), torch.Size([512, 56, 56]), torch.Size([512]))

In [9]:
base_lr = float(simmim_config.TRAIN.BASE_LR)
weight_decay = simmim_config.TRAIN.WEIGHT_DECAY
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
warmup_epochs = simmim_config.TRAIN.WARMUP_EPOCHS
train_epochs = simmim_config.TRAIN.EPOCHS

multisteps = simmim_config.TRAIN.LR_SCHEDULER.MULTISTEPS
gamma = simmim_config.TRAIN.LR_SCHEDULER.GAMMA
# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=multisteps, gamma=gamma)

# LambdaLR 스케줄러 설정
lambda1 = lambda epoch: epoch / warmup_epochs if epoch < warmup_epochs else 1 # Warmup을 위한 Lambda 함수 정의
scheduler_warmup = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

# MultiStepLR 스케줄러 설정
scheduler_multistep = optim.lr_scheduler.MultiStepLR(optimizer, milestones=multisteps, gamma=gamma)

In [10]:
device = 'cuda:2'
model.to(device)

model_save = False
model_path = '../../models/swin2/simmim.pth'

In [11]:
training_time = 0
losses = []
val_losses = []
lrs = []
best_loss = float('inf')

# GradScaler 초기화
scaler = GradScaler()

for epoch in range(train_epochs):
    model.train()
    start_time = time.time()
    running_loss = 0.0
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch + 1}")
    
    for _, data in pbar:
        image, mask = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        # AutoCast 적용
        with autocast():
            loss = model(image, mask)
            
        # 스케일링된 그라디언트 계산
        scaler.scale(loss).backward()

        # 그라디언트 클리핑 전에 스케일링 제거
        scaler.unscale_(optimizer)
        if simmim_config.TRAIN.CLIP_GRAD:
            clip_grad_norm_(model.parameters(), max_norm=simmim_config.TRAIN.CLIP_GRAD)
        else:
            clip_grad_norm_(model.parameters())

        # 옵티마이저 스텝 및 스케일러 업데이트
        scaler.step(optimizer)
        scaler.update()
        if epoch <= warmup_epochs:
            scheduler_warmup.step()
        else:
            scheduler_multistep.step()
        # scheduler.step()
            
        lr = optimizer.param_groups[0]["lr"]
        lrs.append(lr)
        running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    losses.append(epoch_loss)

    # 모델 저장
    if epoch_loss < best_loss:
        
        best_loss = epoch_loss
        vit_save = model_save
        if vit_save:
            torch.save(model.state_dict(), model_path)
        
    epoch_duration = time.time() - start_time
    training_time += epoch_duration
    
    text = f'\tLoss: {epoch_loss:.4f}, LR: {lr}, Duration: {epoch_duration:.2f} sec'
    
    if vit_save:
        text += f' - model saved!'
        vit_save = False    
        
    print(text)


Epoch 1: 100%|██████████| 26/26 [00:23<00:00,  1.09it/s]

	Loss: 1.1252, LR: 0.0001, Duration: 24.44 sec



Epoch 2: 100%|██████████| 26/26 [00:22<00:00,  1.18it/s]

	Loss: 1.0885, LR: 0.0001, Duration: 22.79 sec



Epoch 3: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.9282, LR: 0.0001, Duration: 22.18 sec



Epoch 4: 100%|██████████| 26/26 [00:21<00:00,  1.18it/s]

	Loss: 0.8354, LR: 0.0001, Duration: 22.70 sec



Epoch 5: 100%|██████████| 26/26 [00:21<00:00,  1.19it/s]

	Loss: 0.7654, LR: 0.0001, Duration: 22.65 sec



Epoch 6: 100%|██████████| 26/26 [00:20<00:00,  1.27it/s]

	Loss: 0.7284, LR: 0.0001, Duration: 21.31 sec



Epoch 7: 100%|██████████| 26/26 [00:21<00:00,  1.22it/s]

	Loss: 0.7041, LR: 0.0001, Duration: 22.10 sec



Epoch 8: 100%|██████████| 26/26 [00:20<00:00,  1.25it/s]

	Loss: 0.6786, LR: 0.0001, Duration: 21.62 sec



Epoch 9: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

	Loss: 0.6711, LR: 0.0001, Duration: 22.42 sec



Epoch 10: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.6566, LR: 0.0001, Duration: 21.89 sec



Epoch 11: 100%|██████████| 26/26 [00:21<00:00,  1.18it/s]

	Loss: 0.6451, LR: 0.0001, Duration: 22.73 sec



Epoch 12: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.6471, LR: 0.0001, Duration: 21.84 sec



Epoch 13: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.6455, LR: 0.0001, Duration: 23.08 sec



Epoch 14: 100%|██████████| 26/26 [00:21<00:00,  1.18it/s]

	Loss: 0.6337, LR: 1e-05, Duration: 22.75 sec



Epoch 15: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.6191, LR: 1e-05, Duration: 21.93 sec



Epoch 16: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

	Loss: 0.6167, LR: 1e-05, Duration: 22.41 sec



Epoch 17: 100%|██████████| 26/26 [00:21<00:00,  1.22it/s]

	Loss: 0.6149, LR: 1e-05, Duration: 22.09 sec



Epoch 18: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.6132, LR: 1e-05, Duration: 23.00 sec



Epoch 19: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.6121, LR: 1e-05, Duration: 23.00 sec



Epoch 20: 100%|██████████| 26/26 [00:20<00:00,  1.27it/s]

	Loss: 0.6124, LR: 1e-05, Duration: 21.26 sec



Epoch 21: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.6103, LR: 1e-05, Duration: 21.94 sec



Epoch 22: 100%|██████████| 26/26 [00:21<00:00,  1.19it/s]

	Loss: 0.6090, LR: 1e-05, Duration: 22.61 sec



Epoch 23: 100%|██████████| 26/26 [00:22<00:00,  1.16it/s]

	Loss: 0.6075, LR: 1e-05, Duration: 23.16 sec



Epoch 24: 100%|██████████| 26/26 [00:22<00:00,  1.16it/s]

	Loss: 0.6064, LR: 1e-05, Duration: 23.14 sec



Epoch 25: 100%|██████████| 26/26 [00:21<00:00,  1.19it/s]

	Loss: 0.6050, LR: 1e-05, Duration: 22.59 sec



Epoch 26: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.6043, LR: 1e-05, Duration: 22.32 sec



Epoch 27: 100%|██████████| 26/26 [00:21<00:00,  1.22it/s]

	Loss: 0.6029, LR: 1e-05, Duration: 22.02 sec



Epoch 28: 100%|██████████| 26/26 [00:21<00:00,  1.19it/s]

	Loss: 0.6009, LR: 1e-05, Duration: 22.57 sec



Epoch 29: 100%|██████████| 26/26 [00:21<00:00,  1.24it/s]

	Loss: 0.6017, LR: 1e-05, Duration: 21.78 sec



Epoch 30: 100%|██████████| 26/26 [00:21<00:00,  1.18it/s]

	Loss: 0.6002, LR: 1e-05, Duration: 22.73 sec



Epoch 31: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

	Loss: 0.5986, LR: 1e-05, Duration: 22.47 sec



Epoch 32: 100%|██████████| 26/26 [00:21<00:00,  1.22it/s]

	Loss: 0.5971, LR: 1e-05, Duration: 22.10 sec



Epoch 33: 100%|██████████| 26/26 [00:21<00:00,  1.18it/s]

	Loss: 0.5962, LR: 1e-05, Duration: 22.73 sec



Epoch 34: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5949, LR: 1e-05, Duration: 22.28 sec



Epoch 35: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5943, LR: 1e-05, Duration: 23.03 sec



Epoch 36: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5937, LR: 1e-05, Duration: 23.07 sec



Epoch 37: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.5935, LR: 1e-05, Duration: 21.84 sec



Epoch 38: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5915, LR: 1e-05, Duration: 22.21 sec



Epoch 39: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5918, LR: 1e-05, Duration: 23.08 sec



Epoch 40: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5910, LR: 1e-05, Duration: 22.19 sec



Epoch 41: 100%|██████████| 26/26 [00:22<00:00,  1.16it/s]

	Loss: 0.5917, LR: 1e-05, Duration: 23.13 sec



Epoch 42: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5883, LR: 1e-05, Duration: 22.20 sec



Epoch 43: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5884, LR: 1e-05, Duration: 22.26 sec



Epoch 44: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

	Loss: 0.5881, LR: 1e-05, Duration: 22.39 sec



Epoch 45: 100%|██████████| 26/26 [00:22<00:00,  1.16it/s]

	Loss: 0.5882, LR: 1e-05, Duration: 23.17 sec



Epoch 46: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5863, LR: 1e-05, Duration: 22.91 sec



Epoch 47: 100%|██████████| 26/26 [00:21<00:00,  1.18it/s]

	Loss: 0.5865, LR: 1e-05, Duration: 22.77 sec



Epoch 48: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5859, LR: 1e-05, Duration: 22.25 sec



Epoch 49: 100%|██████████| 26/26 [00:21<00:00,  1.19it/s]

	Loss: 0.5849, LR: 1e-05, Duration: 22.54 sec



Epoch 50: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5855, LR: 1e-05, Duration: 22.95 sec



Epoch 51: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5853, LR: 1e-05, Duration: 22.20 sec



Epoch 52: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5845, LR: 1e-05, Duration: 22.93 sec



Epoch 53: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5830, LR: 1e-05, Duration: 22.92 sec



Epoch 54: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5829, LR: 1e-05, Duration: 22.21 sec



Epoch 55: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

	Loss: 0.5821, LR: 1e-05, Duration: 22.51 sec



Epoch 56: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5818, LR: 1e-05, Duration: 23.08 sec



Epoch 57: 100%|██████████| 26/26 [00:22<00:00,  1.18it/s]

	Loss: 0.5812, LR: 1e-05, Duration: 22.79 sec



Epoch 58: 100%|██████████| 26/26 [00:20<00:00,  1.26it/s]

	Loss: 0.5813, LR: 1e-05, Duration: 21.46 sec



Epoch 59: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

	Loss: 0.5811, LR: 1e-05, Duration: 22.38 sec



Epoch 60: 100%|██████████| 26/26 [00:21<00:00,  1.22it/s]

	Loss: 0.5794, LR: 1e-05, Duration: 21.99 sec



Epoch 61: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5799, LR: 1e-05, Duration: 22.95 sec



Epoch 62: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5787, LR: 1e-05, Duration: 22.99 sec



Epoch 63: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.5782, LR: 1e-05, Duration: 21.93 sec



Epoch 64: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

	Loss: 0.5783, LR: 1e-05, Duration: 22.52 sec



Epoch 65: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5794, LR: 1e-05, Duration: 22.29 sec



Epoch 66: 100%|██████████| 26/26 [00:20<00:00,  1.26it/s]

	Loss: 0.5774, LR: 1e-05, Duration: 21.55 sec



Epoch 67: 100%|██████████| 26/26 [00:21<00:00,  1.22it/s]

	Loss: 0.5760, LR: 1e-05, Duration: 22.16 sec



Epoch 68: 100%|██████████| 26/26 [00:21<00:00,  1.19it/s]

	Loss: 0.5762, LR: 1e-05, Duration: 22.58 sec



Epoch 69: 100%|██████████| 26/26 [00:20<00:00,  1.27it/s]

	Loss: 0.5760, LR: 1e-05, Duration: 21.19 sec



Epoch 70: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.5764, LR: 1e-05, Duration: 21.83 sec



Epoch 71: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.5753, LR: 1e-05, Duration: 21.91 sec



Epoch 72: 100%|██████████| 26/26 [00:22<00:00,  1.16it/s]

	Loss: 0.5743, LR: 1e-05, Duration: 23.25 sec



Epoch 73: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5748, LR: 1e-05, Duration: 22.34 sec



Epoch 74: 100%|██████████| 26/26 [00:21<00:00,  1.22it/s]

	Loss: 0.5740, LR: 1e-05, Duration: 22.05 sec



Epoch 75: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5738, LR: 1e-05, Duration: 22.34 sec



Epoch 76: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5728, LR: 1e-05, Duration: 23.04 sec



Epoch 77: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5732, LR: 1e-05, Duration: 23.02 sec



Epoch 78: 100%|██████████| 26/26 [00:21<00:00,  1.19it/s]

	Loss: 0.5725, LR: 1e-05, Duration: 22.62 sec



Epoch 79: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5732, LR: 1e-05, Duration: 22.20 sec



Epoch 80: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.5729, LR: 1e-05, Duration: 21.90 sec



Epoch 81: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5727, LR: 1e-05, Duration: 22.29 sec



Epoch 82: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

	Loss: 0.5718, LR: 1e-05, Duration: 22.38 sec



Epoch 83: 100%|██████████| 26/26 [00:20<00:00,  1.27it/s]

	Loss: 0.5715, LR: 1e-05, Duration: 21.26 sec



Epoch 84: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5709, LR: 1e-05, Duration: 22.90 sec



Epoch 85: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.5701, LR: 1e-05, Duration: 21.83 sec



Epoch 86: 100%|██████████| 26/26 [00:22<00:00,  1.18it/s]

	Loss: 0.5698, LR: 1e-05, Duration: 22.83 sec



Epoch 87: 100%|██████████| 26/26 [00:22<00:00,  1.16it/s]

	Loss: 0.5703, LR: 1e-05, Duration: 23.13 sec



Epoch 88: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5694, LR: 1e-05, Duration: 23.04 sec



Epoch 89: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5699, LR: 1e-05, Duration: 22.96 sec



Epoch 90: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5682, LR: 1e-05, Duration: 23.01 sec



Epoch 91: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5681, LR: 1e-05, Duration: 22.90 sec



Epoch 92: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5687, LR: 1e-05, Duration: 22.87 sec



Epoch 93: 100%|██████████| 26/26 [00:21<00:00,  1.23it/s]

	Loss: 0.5691, LR: 1e-05, Duration: 21.88 sec



Epoch 94: 100%|██████████| 26/26 [00:22<00:00,  1.16it/s]

	Loss: 0.5679, LR: 1e-05, Duration: 23.12 sec



Epoch 95: 100%|██████████| 26/26 [00:20<00:00,  1.25it/s]

	Loss: 0.5673, LR: 1e-05, Duration: 21.59 sec



Epoch 96: 100%|██████████| 26/26 [00:22<00:00,  1.17it/s]

	Loss: 0.5684, LR: 1e-05, Duration: 22.88 sec



Epoch 97: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

	Loss: 0.5663, LR: 1e-05, Duration: 22.43 sec



Epoch 98: 100%|██████████| 26/26 [00:21<00:00,  1.19it/s]

	Loss: 0.5656, LR: 1e-05, Duration: 22.57 sec



Epoch 99: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s]

	Loss: 0.5651, LR: 1e-05, Duration: 22.56 sec



Epoch 100: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s]

	Loss: 0.5660, LR: 1e-05, Duration: 22.34 sec





In [12]:
torch.save(model.state_dict(), model_path)

In [None]:
device = 'cuda:2'

model_save = False
model_path = '../../models/swin2/simmim.pth'

In [15]:
swin = SwinTransformerV2(pretrained_window_sizes=[7,7,7,7], ape=True)
swin.load_state_dict(model.encoder.state_dict(), strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['mask_token'])

In [None]:
swin = SwinTransformerV2(pretrained_window_sizes=[6,6,6,6], ape=True)

In [None]:
swin_config = yaml.load(open('config/train.yaml'), Loader=yaml.FullLoader)
swin_config

In [None]:
config = Box(swin_config)
model = swin

print(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
state_dict = torch.load(config.MODEL.PRETRAINED, map_location='cpu')

# remain encoder only
not_encoder_keys = [k for k in state_dict.keys() if 'encoder' not in k]
for k in not_encoder_keys:
    del state_dict[k]

# delete relative_position_index since we always re-init it
relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
for k in relative_position_index_keys:
    del state_dict[k]

# delete relative_coords_table since we always re-init it
relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
for k in relative_position_index_keys:
    del state_dict[k]

# delete attn_mask since we always re-init it
attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
for k in attn_mask_keys:
    del state_dict[k]

# bicubic interpolate relative_position_bias_table if not match
relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
for k in relative_position_bias_table_keys:
    relative_position_bias_table_pretrained = state_dict[k]
    relative_position_bias_table_current = model.state_dict()[k]
    L1, nH1 = relative_position_bias_table_pretrained.size()
    L2, nH2 = relative_position_bias_table_current.size()
    if nH1 != nH2:
        print(f"Error in loading {k}, passing......")
    else:
        if L1 != L2:
            # bicubic interpolate relative_position_bias_table if not match
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                mode='bicubic')
            state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

# bicubic interpolate absolute_pos_embed if not match
absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
for k in absolute_pos_embed_keys:
    # dpe
    absolute_pos_embed_pretrained = state_dict[k]
    absolute_pos_embed_current = model.state_dict()[k.replace('encoder.','')]
    _, L1, C1 = absolute_pos_embed_pretrained.size()
    _, L2, C2 = absolute_pos_embed_current.size()
    if C1 != C1:
        print(f"Error in loading {k}, passing......")
    else:
        if L1 != L2:
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
            absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
            absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
            absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
            state_dict[k] = absolute_pos_embed_pretrained_resized

# check classifier, if not match, then re-init classifier to zero
head_bias_pretrained = state_dict['encoder.classifier.bias']
Nc1 = head_bias_pretrained.shape[0]
Nc2 = model.classifier.bias.shape[0]
if (Nc1 != Nc2):
    torch.nn.init.constant_(model.classifier.bias, 0.)
    torch.nn.init.constant_(model.classifier.weight, 0.)
    del state_dict['encoder.classifier.weight']
    del state_dict['encoder.classifier.bias']
    print(f"Error in loading classifier head, re-init classifier head to 0")

msg = model.load_state_dict(state_dict, strict=False)
print(msg)

print(f"=> loaded successfully '{config.MODEL.PRETRAINED}'")

torch.cuda.empty_cache()

In [None]:
def load_pretrained(config, model):
    print(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
    state_dict = torch.load(config.MODEL.PRETRAINED, map_location='cpu')

    # remain encoder only
    not_encoder_keys = [k for k in state_dict.keys() if 'encoder' not in k]
    for k in not_encoder_keys:
        del state_dict[k]

    # delete relative_position_index since we always re-init it
    relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete relative_coords_table since we always re-init it
    relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del state_dict[k]

    # bicubic interpolate relative_position_bias_table if not match
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = model.state_dict()[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            print(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                # bicubic interpolate relative_position_bias_table if not match
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                    mode='bicubic')
                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # bicubic interpolate absolute_pos_embed if not match
    absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
    for k in absolute_pos_embed_keys:
        # dpe
        absolute_pos_embed_pretrained = state_dict[k]
        absolute_pos_embed_current = model.state_dict()[k.replace('encoder.','')]
        _, L1, C1 = absolute_pos_embed_pretrained.size()
        _, L2, C2 = absolute_pos_embed_current.size()
        if C1 != C1:
            print(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
                state_dict[k] = absolute_pos_embed_pretrained_resized

    # check classifier, if not match, then re-init classifier to zero
    head_bias_pretrained = state_dict['encoder.classifier.bias']
    Nc1 = head_bias_pretrained.shape[0]
    Nc2 = model.classifier.bias.shape[0]
    if (Nc1 != Nc2):
        torch.nn.init.constant_(model.classifier.bias, 0.)
        torch.nn.init.constant_(model.classifier.weight, 0.)
        del state_dict['encoder.classifier.weight']
        del state_dict['encoder.classifier.bias']
        print(f"Error in loading classifier head, re-init classifier head to 0")

    msg = model.load_state_dict(state_dict, strict=False)
    print(msg)

    print(f"=> loaded successfully '{config.MODEL.PRETRAINED}'")

    torch.cuda.empty_cache()

In [None]:
swin_config = Box(swin_config)
load_pretrained(swin_config, swin)

In [None]:
swin.load_state_dict(model.encoder.state_dict(), strict=False)

In [16]:
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [17]:
# Transforms 정의하기
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8,1), interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.9, scale=(0.02, 0.33)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

data_dir = '../../data/sports'
batch_size = 960

train_path = data_dir+'/train'
valid_path = data_dir+'/valid'
test_path = data_dir+'/test'

# dataset load
train_data = ImageFolder(train_path, transform=train_transform)
valid_data = ImageFolder(valid_path, transform=test_transform)
test_data = ImageFolder(test_path, transform=test_transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [18]:
max_norm = 5.0

swin.to(device)

SwinTransformerV2(
  (embeddings): embeddings(
    (patch_embeddings): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (stages): ModuleList(
    (0): StageLayer(
      (blocks): ModuleList(
        (0-1): 2 x SwinTransformerBlock(
          (attn): WindowAttention(
            (crpb_mlp): Sequential(
              (0): Linear(in_features=2, out_features=384, bias=True)
              (1): ReLU(inplace=True)
              (2): Dropout(p=0.125, inplace=False)
              (3): Linear(in_features=384, out_features=3, bias=False)
            )
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
  

In [19]:
mixup_fn = Mixup(mixup_alpha=1., 
                cutmix_alpha=1., 
                prob=1., 
                switch_prob=0.5, 
                mode='batch',
                label_smoothing=.1,
                num_classes=100)

epochs = 100

criterion = nn.CrossEntropyLoss(label_smoothing=0.)
optimizer = optim.AdamW(swin.parameters(), lr=1e-3, weight_decay=5e-3)
warmup_steps = int(len(train_loader)*epochs*0.1)
train_steps = len(train_loader)*epochs
scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, 
                                                        num_warmup_steps=warmup_steps, 
                                                        num_training_steps=train_steps,
                                                        num_cycles=0.5)

2024-01-10 15:48:00.880244: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-10 15:48:00.880312: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-10 15:48:00.881078: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-10 15:48:00.886287: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [20]:
training_time = 0
losses = []
val_losses = []
lrs = []
best_loss = float('inf')

# GradScaler 초기화
scaler = GradScaler()

for epoch in range(epochs):
    swin.train()
    start_time = time.time()
    running_loss = 0.0
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")
    
    for _, data in pbar:
        inputs, labels = data[0].to(device), data[1].to(device)
        inputs, labels = mixup_fn(inputs, labels)
        optimizer.zero_grad()

        # AutoCast 적용
        with autocast():
            outputs = swin(inputs)
            loss = criterion(outputs, labels)
            
        # 스케일링된 그라디언트 계산
        scaler.scale(loss).backward()

        # 그라디언트 클리핑 전에 스케일링 제거
        scaler.unscale_(optimizer)
        clip_grad_norm_(swin.parameters(), max_norm=max_norm)

        # 옵티마이저 스텝 및 스케일러 업데이트
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
            
        lr = optimizer.param_groups[0]["lr"]
        lrs.append(lr)
        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    losses.append(epoch_loss)        

    swin.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in valid_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = swin(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
    val_loss /= len(valid_loader)
    val_losses.append(val_loss)
    
    # 모델 저장
    if val_loss < best_loss:
        best_loss = val_loss
        # vit_save = True
        # if vit_save:
        #     torch.save(swin.state_dict(), )

    epoch_duration = time.time() - start_time
    training_time += epoch_duration
    
    text = f'\tLoss: {epoch_loss}, Val Loss: {val_loss}, LR: {lr}, Duration: {epoch_duration:.2f} sec'
    
    # if vit_save:
    #     text += f' - swin saved!'
    #     vit_save = False

    print(text)
        
text = f"Epoch 당 평균 소요시간 : {training_time / epochs:.2f}초"      
print(text)

Epoch 1: 100%|██████████| 15/15 [00:51<00:00,  3.44s/it]


	Loss: 4.5992762565612795, Val Loss: 4.361097812652588, LR: 0.0001, Duration: 52.77 sec


Epoch 2: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 4.420850372314453, Val Loss: 4.1072845458984375, LR: 0.0002, Duration: 51.46 sec


Epoch 3: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 4.375983746846517, Val Loss: 4.015015125274658, LR: 0.0003, Duration: 51.04 sec


Epoch 4: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 4.330751005808512, Val Loss: 3.967747449874878, LR: 0.0004, Duration: 51.00 sec


Epoch 5: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 4.274534289042155, Val Loss: 3.9106011390686035, LR: 0.0005, Duration: 50.97 sec


Epoch 6: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 4.291929562886556, Val Loss: 3.8842830657958984, LR: 0.0006, Duration: 50.95 sec


Epoch 7: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 4.255568885803223, Val Loss: 3.8145711421966553, LR: 0.0007, Duration: 51.07 sec


Epoch 8: 100%|██████████| 15/15 [00:50<00:00,  3.35s/it]


	Loss: 4.206695016225179, Val Loss: 3.7762258052825928, LR: 0.0008, Duration: 51.15 sec


Epoch 9: 100%|██████████| 15/15 [00:50<00:00,  3.35s/it]


	Loss: 4.162663761774699, Val Loss: 3.6490609645843506, LR: 0.0009000000000000001, Duration: 51.12 sec


Epoch 10: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 4.142976347605387, Val Loss: 3.63301944732666, LR: 0.001, Duration: 50.84 sec


Epoch 11: 100%|██████████| 15/15 [00:50<00:00,  3.33s/it]


	Loss: 4.138050826390584, Val Loss: 3.5600311756134033, LR: 0.0009996954135095479, Duration: 50.92 sec


Epoch 12: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 4.068909454345703, Val Loss: 3.5125815868377686, LR: 0.0009987820251299122, Duration: 50.98 sec


Epoch 13: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 4.111667633056641, Val Loss: 3.3922088146209717, LR: 0.0009972609476841367, Duration: 50.97 sec


Epoch 14: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 4.033602094650268, Val Loss: 3.2706634998321533, LR: 0.0009951340343707852, Duration: 51.01 sec


Epoch 15: 100%|██████████| 15/15 [00:50<00:00,  3.36s/it]


	Loss: 3.997556479771932, Val Loss: 3.1662673950195312, LR: 0.000992403876506104, Duration: 51.23 sec


Epoch 16: 100%|██████████| 15/15 [00:50<00:00,  3.35s/it]


	Loss: 3.972299083073934, Val Loss: 3.2309482097625732, LR: 0.0009890738003669028, Duration: 51.06 sec


Epoch 17: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.935915470123291, Val Loss: 3.173598527908325, LR: 0.0009851478631379982, Duration: 50.85 sec


Epoch 18: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 3.876936690012614, Val Loss: 2.978503704071045, LR: 0.0009806308479691594, Duration: 50.98 sec


Epoch 19: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.782620573043823, Val Loss: 2.863264799118042, LR: 0.0009755282581475768, Duration: 50.81 sec


Epoch 20: 100%|██████████| 15/15 [00:50<00:00,  3.35s/it]


	Loss: 3.944172477722168, Val Loss: 2.951862335205078, LR: 0.0009698463103929542, Duration: 51.11 sec


Epoch 21: 100%|██████████| 15/15 [00:50<00:00,  3.36s/it]


	Loss: 3.8698577721913656, Val Loss: 2.8209657669067383, LR: 0.0009635919272833937, Duration: 51.24 sec


Epoch 22: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 3.7705761273701985, Val Loss: 2.71178936958313, LR: 0.0009567727288213005, Duration: 51.03 sec


Epoch 23: 100%|██████████| 15/15 [00:50<00:00,  3.35s/it]


	Loss: 3.781675084431966, Val Loss: 2.6807870864868164, LR: 0.0009493970231495835, Duration: 51.13 sec


Epoch 24: 100%|██████████| 15/15 [00:50<00:00,  3.35s/it]


	Loss: 3.7442699432373048, Val Loss: 2.5968668460845947, LR: 0.0009414737964294635, Duration: 51.17 sec


Epoch 25: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.7778147061665854, Val Loss: 2.5897953510284424, LR: 0.0009330127018922195, Duration: 50.87 sec


Epoch 26: 100%|██████████| 15/15 [00:50<00:00,  3.37s/it]


	Loss: 3.533814557393392, Val Loss: 2.4396042823791504, LR: 0.0009240240480782129, Duration: 51.40 sec


Epoch 27: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 3.796156120300293, Val Loss: 2.500469923019409, LR: 0.0009145187862775209, Duration: 51.01 sec


Epoch 28: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 3.5853291352589927, Val Loss: 2.383317470550537, LR: 0.0009045084971874737, Duration: 51.01 sec


Epoch 29: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 3.527343527475993, Val Loss: 2.3644769191741943, LR: 0.0008940053768033609, Duration: 51.05 sec


Epoch 30: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.6358637968699137, Val Loss: 2.351447105407715, LR: 0.000883022221559489, Duration: 50.73 sec


Epoch 31: 100%|██████████| 15/15 [00:50<00:00,  3.35s/it]


	Loss: 3.5587464332580567, Val Loss: 2.213412046432495, LR: 0.0008715724127386971, Duration: 51.13 sec


Epoch 32: 100%|██████████| 15/15 [00:50<00:00,  3.36s/it]


	Loss: 3.4356503168741863, Val Loss: 2.1071174144744873, LR: 0.0008596699001693256, Duration: 51.26 sec


Epoch 33: 100%|██████████| 15/15 [00:50<00:00,  3.36s/it]


	Loss: 3.4244800408681235, Val Loss: 2.0860495567321777, LR: 0.0008473291852294987, Duration: 51.30 sec


Epoch 34: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.4971884727478026, Val Loss: 2.093473196029663, LR: 0.0008345653031794292, Duration: 50.78 sec


Epoch 35: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.511097224553426, Val Loss: 2.1360254287719727, LR: 0.0008213938048432696, Duration: 50.80 sec


Epoch 36: 100%|██████████| 15/15 [00:50<00:00,  3.35s/it]


	Loss: 3.545870129267375, Val Loss: 2.066307544708252, LR: 0.0008078307376628291, Duration: 51.09 sec


Epoch 37: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.3102293491363524, Val Loss: 1.9978548288345337, LR: 0.0007938926261462366, Duration: 50.93 sec


Epoch 38: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.414239184061686, Val Loss: 1.9859976768493652, LR: 0.0007795964517353734, Duration: 50.74 sec


Epoch 39: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.416082827250163, Val Loss: 1.8684029579162598, LR: 0.0007649596321166025, Duration: 50.82 sec


Epoch 40: 100%|██████████| 15/15 [00:50<00:00,  3.33s/it]


	Loss: 3.138517888387044, Val Loss: 1.8220951557159424, LR: 0.00075, Duration: 50.94 sec


Epoch 41: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.254134941101074, Val Loss: 1.8229365348815918, LR: 0.0007347357813929454, Duration: 50.84 sec


Epoch 42: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.2163746198018393, Val Loss: 1.7288466691970825, LR: 0.0007191855733945387, Duration: 50.85 sec


Epoch 43: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 3.4814995765686034, Val Loss: 1.825628638267517, LR: 0.0007033683215379002, Duration: 50.97 sec


Epoch 44: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.434501886367798, Val Loss: 1.8478795289993286, LR: 0.0006873032967079561, Duration: 50.83 sec


Epoch 45: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.3214625835418703, Val Loss: 1.8126704692840576, LR: 0.0006710100716628344, Duration: 50.83 sec


Epoch 46: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 3.305809164047241, Val Loss: 1.7101682424545288, LR: 0.0006545084971874737, Duration: 50.92 sec


Epoch 47: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.2320428530375165, Val Loss: 1.6829944849014282, LR: 0.0006378186779084996, Duration: 50.89 sec


Epoch 48: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.2073685487111407, Val Loss: 1.6750632524490356, LR: 0.0006209609477998338, Duration: 50.78 sec


Epoch 49: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.3867226282755536, Val Loss: 1.6697527170181274, LR: 0.0006039558454088796, Duration: 50.83 sec


Epoch 50: 100%|██████████| 15/15 [00:49<00:00,  3.31s/it]


	Loss: 3.123853158950806, Val Loss: 1.5660886764526367, LR: 0.0005868240888334653, Duration: 50.62 sec


Epoch 51: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.252556816736857, Val Loss: 1.5544921159744263, LR: 0.0005695865504800327, Duration: 50.69 sec


Epoch 52: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 3.0271260738372803, Val Loss: 1.5278981924057007, LR: 0.0005522642316338268, Duration: 50.94 sec


Epoch 53: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.2504122893015546, Val Loss: 1.4976816177368164, LR: 0.0005348782368720626, Duration: 50.83 sec


Epoch 54: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 2.974802859624227, Val Loss: 1.4685289859771729, LR: 0.0005174497483512506, Duration: 50.98 sec


Epoch 55: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.232813851038615, Val Loss: 1.4910255670547485, LR: 0.0005, Duration: 50.64 sec


Epoch 56: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.3418434143066404, Val Loss: 1.5499149560928345, LR: 0.0004825502516487497, Duration: 50.86 sec


Epoch 57: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.102729876836141, Val Loss: 1.4518831968307495, LR: 0.00046512176312793734, Duration: 50.71 sec


Epoch 58: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.984872738520304, Val Loss: 1.3703336715698242, LR: 0.00044773576836617336, Duration: 50.73 sec


Epoch 59: 100%|██████████| 15/15 [00:49<00:00,  3.31s/it]


	Loss: 2.9487711906433107, Val Loss: 1.3900600671768188, LR: 0.0004304134495199674, Duration: 50.58 sec


Epoch 60: 100%|██████████| 15/15 [00:49<00:00,  3.31s/it]


	Loss: 3.1685937881469726, Val Loss: 1.360853910446167, LR: 0.00041317591116653486, Duration: 50.62 sec


Epoch 61: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.048887332280477, Val Loss: 1.3899701833724976, LR: 0.0003960441545911204, Duration: 50.71 sec


Epoch 62: 100%|██████████| 15/15 [00:49<00:00,  3.31s/it]


	Loss: 3.2343202590942384, Val Loss: 1.3511807918548584, LR: 0.0003790390522001662, Duration: 50.58 sec


Epoch 63: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.0106656551361084, Val Loss: 1.366744875907898, LR: 0.00036218132209150044, Duration: 50.88 sec


Epoch 64: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 3.1047253608703613, Val Loss: 1.3298567533493042, LR: 0.00034549150281252633, Duration: 51.02 sec


Epoch 65: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.005611276626587, Val Loss: 1.3468254804611206, LR: 0.0003289899283371657, Duration: 50.78 sec


Epoch 66: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.8786781152089436, Val Loss: 1.2808141708374023, LR: 0.00031269670329204396, Duration: 50.68 sec


Epoch 67: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.9180047512054443, Val Loss: 1.2425788640975952, LR: 0.0002966316784621, Duration: 50.82 sec


Epoch 68: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.0212397893269856, Val Loss: 1.2673382759094238, LR: 0.00028081442660546124, Duration: 50.87 sec


Epoch 69: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.740767256418864, Val Loss: 1.206329584121704, LR: 0.00026526421860705474, Duration: 50.90 sec


Epoch 70: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.8472370862960816, Val Loss: 1.1802430152893066, LR: 0.0002500000000000001, Duration: 50.73 sec


Epoch 71: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.8219935417175295, Val Loss: 1.1751976013183594, LR: 0.0002350403678833976, Duration: 50.86 sec


Epoch 72: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.0673156420389813, Val Loss: 1.2240365743637085, LR: 0.00022040354826462666, Duration: 50.61 sec


Epoch 73: 100%|██████████| 15/15 [00:49<00:00,  3.31s/it]


	Loss: 3.1022398153940838, Val Loss: 1.2352367639541626, LR: 0.00020610737385376348, Duration: 50.53 sec


Epoch 74: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.8730401833852133, Val Loss: 1.2071330547332764, LR: 0.00019216926233717085, Duration: 50.84 sec


Epoch 75: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.073003562291463, Val Loss: 1.1854075193405151, LR: 0.0001786061951567303, Duration: 50.79 sec


Epoch 76: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 3.0296328544616697, Val Loss: 1.1801347732543945, LR: 0.00016543469682057105, Duration: 50.73 sec


Epoch 77: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 3.0158047835032145, Val Loss: 1.1497093439102173, LR: 0.00015267081477050133, Duration: 50.82 sec


Epoch 78: 100%|██████████| 15/15 [00:50<00:00,  3.33s/it]


	Loss: 2.862226947148641, Val Loss: 1.1495566368103027, LR: 0.00014033009983067452, Duration: 50.87 sec


Epoch 79: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.8607118050257365, Val Loss: 1.148755669593811, LR: 0.00012842758726130281, Duration: 50.76 sec


Epoch 80: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.7824679454167685, Val Loss: 1.1178703308105469, LR: 0.00011697777844051105, Duration: 50.64 sec


Epoch 81: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.873526938756307, Val Loss: 1.1135333776474, LR: 0.00010599462319663906, Duration: 50.67 sec


Epoch 82: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.6985460917154946, Val Loss: 1.110559344291687, LR: 9.549150281252633e-05, Duration: 50.73 sec


Epoch 83: 100%|██████████| 15/15 [00:50<00:00,  3.34s/it]


	Loss: 2.9345927238464355, Val Loss: 1.0977905988693237, LR: 8.548121372247918e-05, Duration: 50.98 sec


Epoch 84: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.941496499379476, Val Loss: 1.1104034185409546, LR: 7.597595192178702e-05, Duration: 50.85 sec


Epoch 85: 100%|██████████| 15/15 [00:50<00:00,  3.33s/it]


	Loss: 2.693594749768575, Val Loss: 1.1019744873046875, LR: 6.698729810778065e-05, Duration: 50.95 sec


Epoch 86: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.7669259945551556, Val Loss: 1.0928895473480225, LR: 5.852620357053651e-05, Duration: 50.89 sec


Epoch 87: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.9359314997990924, Val Loss: 1.0912108421325684, LR: 5.060297685041659e-05, Duration: 50.64 sec


Epoch 88: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.550270660718282, Val Loss: 1.0734676122665405, LR: 4.322727117869951e-05, Duration: 50.66 sec


Epoch 89: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.666748531659444, Val Loss: 1.0755447149276733, LR: 3.6408072716606344e-05, Duration: 50.83 sec


Epoch 90: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.943815541267395, Val Loss: 1.0754739046096802, LR: 3.0153689607045842e-05, Duration: 50.75 sec


Epoch 91: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.8767631530761717, Val Loss: 1.0660046339035034, LR: 2.4471741852423235e-05, Duration: 50.73 sec


Epoch 92: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.745632561047872, Val Loss: 1.0670671463012695, LR: 1.9369152030840554e-05, Duration: 50.73 sec


Epoch 93: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.8098102172215778, Val Loss: 1.0644277334213257, LR: 1.4852136862001764e-05, Duration: 50.88 sec


Epoch 94: 100%|██████████| 15/15 [00:50<00:00,  3.35s/it]


	Loss: 2.6137861172358194, Val Loss: 1.0602092742919922, LR: 1.0926199633097156e-05, Duration: 51.17 sec


Epoch 95: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.5384429693222046, Val Loss: 1.0594589710235596, LR: 7.59612349389599e-06, Duration: 50.75 sec


Epoch 96: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.9381473700205487, Val Loss: 1.0573269128799438, LR: 4.865965629214819e-06, Duration: 50.82 sec


Epoch 97: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.698636611302694, Val Loss: 1.0562412738800049, LR: 2.739052315863355e-06, Duration: 50.73 sec


Epoch 98: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.9519651254018147, Val Loss: 1.0567539930343628, LR: 1.2179748700879012e-06, Duration: 50.76 sec


Epoch 99: 100%|██████████| 15/15 [00:49<00:00,  3.33s/it]


	Loss: 2.5671448548634848, Val Loss: 1.0566096305847168, LR: 3.0458649045211895e-07, Duration: 50.84 sec


Epoch 100: 100%|██████████| 15/15 [00:49<00:00,  3.32s/it]


	Loss: 2.682647705078125, Val Loss: 1.0565671920776367, LR: 0.0, Duration: 50.70 sec
Epoch 당 평균 소요시간 : 50.90초


In [21]:
from sklearn.metrics import confusion_matrix
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [23]:
# 예측 수행 및 레이블 저장
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = swin(images)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# 혼동 행렬 생성
cm = confusion_matrix(all_labels, all_preds)

# 예측과 실제 레이블
y_true = all_labels  # 실제 레이블
y_pred = all_preds  # 모델에 의해 예측된 레이블

# 전체 데이터셋에 대한 정확도
accuracy = accuracy_score(y_true, y_pred)

# 평균 정밀도, 리콜, F1-Score ('weighted')
precision, recall, f1_score, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')

# 판다스 데이터프레임으로 결과 정리
performance_metrics = pd.DataFrame({
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score'],
    'Value': [accuracy, precision, recall, f1_score]
})

# 데이터프레임 출력
performance_metrics

Unnamed: 0,Metric,Value
0,Accuracy,0.794
1,Precision,0.827087
2,Recall,0.794
3,F1 Score,0.784236
