In [1]:
!pip install python-box

[0m

In [2]:
import numpy as np
import yaml
from box import Box

import torch
import torch.nn as nn
import torch.optim as optim

import simmim

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
simmim_config = yaml.load(open('config/pretrain.yaml'), Loader=yaml.FullLoader)
simmim_config

{'MODEL': {'TYPE': 'swinv2',
  'NAME': 'simmim_pretrain',
  'DROP_PATH_RATE': 0.1,
  'SIMMIM': {'NORM_TARGET': {'ENABLE': True, 'PATCH_SIZE': 47}},
  'SWIN': {'EMBED_DIM': 96,
   'DEPTHS': [2, 2, 6, 2],
   'NUM_HEADS': [3, 6, 12, 24],
   'WINDOW_SIZE': 6,
   'PATCH_SIZE': 4}},
 'DATA': {'IMG_SIZE': 192,
  'MASK_PATCH_SIZE': 32,
  'MASK_RATIO': 0.6,
  'BATCH_SIZE': 1024,
  'NUM_WORKERS': 24,
  'DATA_PATH': '../../data/sports/train'},
 'TRAIN': {'EPOCHS': 20,
  'WARMUP_EPOCHS': 10,
  'BASE_LR': '1e-4',
  'WARMUP_LR': '5e-7',
  'WEIGHT_DECAY': 0.05,
  'CLIP_GRAD': 5,
  'LR_SCHEDULER': {'NAME': 'multistep', 'GAMMA': 0.1, 'MULTISTEPS': [700]}},
 'PRINT_FREQ': 100,
 'SAVE_FREQ': 5,
 'TAG': 'simmim_pretrain__swinv2_tiny__img224_window7__800ep'}

In [4]:
encoder_config = {'img_size':simmim_config['DATA']['IMG_SIZE'], 
                'patch_size':simmim_config['MODEL']['SWIN']['PATCH_SIZE'], 
                'in_chans':3, 
                'num_classes':100,
                'embed_dim':simmim_config['MODEL']['SWIN']['EMBED_DIM'], 
                'depths':simmim_config['MODEL']['SWIN']['DEPTHS'], 
                'num_heads':simmim_config['MODEL']['SWIN']['NUM_HEADS'],           
                'window_size':simmim_config['MODEL']['SWIN']['WINDOW_SIZE'], 
                'mlp_ratio':4., 
                'qkv_bias':True, 
                'qk_scale':None,
                'drop_rate':0., 
                'attn_drop_rate':0., 
                'drop_path_rate':simmim_config['MODEL']['DROP_PATH_RATE'],
                'norm_layer':nn.LayerNorm, 
                'patch_norm':True, 
                'pretrained_window_sizes':[0,0,0,0],
                'ape':True}

encoder_stride = 32
in_chans = encoder_config['in_chans']
patch_size = encoder_config['patch_size']

In [5]:
encoder = simmim.SwinTransformerV2ForSimMIM(**encoder_config)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [6]:
model = simmim.SimMIM( encoder=encoder, 
                       encoder_stride=encoder_stride, 
                       in_chans=in_chans, 
                       patch_size=patch_size)

In [7]:
simmim_config = Box(simmim_config)
dataloader = simmim.build_loader_simmim(simmim_config)

samples = next(iter(dataloader))
len(samples)

3

In [8]:
samples[0].shape, samples[1].shape, samples[2].shape 

(torch.Size([1024, 3, 192, 192]),
 torch.Size([1024, 48, 48]),
 torch.Size([1024]))

In [13]:
base_lr = float(simmim_config.TRAIN.BASE_LR)
weight_decay = simmim_config.TRAIN.WEIGHT_DECAY
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
warmup_epochs = simmim_config.TRAIN.WARMUP_EPOCHS
train_epochs = simmim_config.TRAIN.EPOCHS

multisteps = simmim_config.TRAIN.LR_SCHEDULER.MULTISTEPS
gamma = simmim_config.TRAIN.LR_SCHEDULER.GAMMA
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=multisteps, gamma=gamma)

# LambdaLR 스케줄러 설정
lambda1 = lambda epoch: epoch / warmup_epochs if epoch < warmup_epochs else 1 # Warmup을 위한 Lambda 함수 정의
scheduler_warmup = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

# MultiStepLR 스케줄러 설정
scheduler_multistep = optim.lr_scheduler.MultiStepLR(optimizer, milestones=multisteps, gamma=gamma)

In [10]:
from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
import time

In [11]:
device = 'cuda:2'
model.to(device)

model_save = False
model_path = '../../models/swin2/simmim.pth'

In [12]:
training_time = 0
losses = []
val_losses = []
lrs = []
best_loss = float('inf')

# GradScaler 초기화
scaler = GradScaler()

for epoch in range(train_epochs):
    model.train()
    start_time = time.time()
    running_loss = 0.0
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch + 1}")
    
    for _, data in pbar:
        image, mask = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        # AutoCast 적용
        with autocast():
            loss = model(image, mask)
            
        # 스케일링된 그라디언트 계산
        scaler.scale(loss).backward()

        # 그라디언트 클리핑 전에 스케일링 제거
        scaler.unscale_(optimizer)
        if simmim_config.TRAIN.CLIP_GRAD:
            clip_grad_norm_(model.parameters(), max_norm=simmim_config.TRAIN.CLIP_GRAD)
        else:
            clip_grad_norm_(model.parameters())

        # 옵티마이저 스텝 및 스케일러 업데이트
        scaler.step(optimizer)
        scaler.update()
        if epoch <= warmup_epochs:
            scheduler_warmup.step()
        else:
            scheduler_multistep.step()
        # scheduler.step()
            
        lr = optimizer.param_groups[0]["lr"]
        lrs.append(lr)
        running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    losses.append(epoch_loss)

    # 모델 저장
    if epoch_loss < best_loss:
        
        best_loss = epoch_loss
        vit_save = model_save
        if vit_save:
            torch.save(model.state_dict(), model_path)
        
    epoch_duration = time.time() - start_time
    training_time += epoch_duration
    
    text = f'\tLoss: {epoch_loss}, LR: {lr}, Duration: {epoch_duration:.2f} sec'
    
    if vit_save:
        text += f' - model saved!'
        vit_save = False    
        
    print(text)


Epoch 1: 100%|██████████| 13/13 [00:16<00:00,  1.30s/it]

	Loss: 1.1231169058726385, LR: 0.0001, Duration: 17.41 sec



Epoch 2: 100%|██████████| 13/13 [00:16<00:00,  1.29s/it]

	Loss: 1.0928216255628145, LR: 0.0001, Duration: 17.75 sec



Epoch 3: 100%|██████████| 13/13 [00:17<00:00,  1.35s/it]

	Loss: 1.0854026079177856, LR: 0.0001, Duration: 18.55 sec



Epoch 4: 100%|██████████| 13/13 [00:16<00:00,  1.28s/it]

	Loss: 1.015883390720074, LR: 0.0001, Duration: 17.66 sec



Epoch 5: 100%|██████████| 13/13 [00:16<00:00,  1.28s/it]

	Loss: 0.909022895189432, LR: 0.0001, Duration: 17.59 sec



Epoch 6: 100%|██████████| 13/13 [00:16<00:00,  1.27s/it]

	Loss: 0.8582831942118131, LR: 0.0001, Duration: 17.44 sec



Epoch 7: 100%|██████████| 13/13 [00:16<00:00,  1.28s/it]

	Loss: 0.8378349496768072, LR: 0.0001, Duration: 17.62 sec



Epoch 8: 100%|██████████| 13/13 [00:15<00:00,  1.22s/it]


	Loss: 0.8183376789093018, LR: 0.0001, Duration: 16.94 sec


Epoch 9: 100%|██████████| 13/13 [00:16<00:00,  1.31s/it]

	Loss: 0.7871220753743098, LR: 0.0001, Duration: 17.91 sec



Epoch 10: 100%|██████████| 13/13 [00:17<00:00,  1.31s/it]

	Loss: 0.7618823555799631, LR: 0.0001, Duration: 18.01 sec



Epoch 11: 100%|██████████| 13/13 [00:16<00:00,  1.28s/it]

	Loss: 0.7539997100830078, LR: 0.0001, Duration: 17.66 sec



Epoch 12: 100%|██████████| 13/13 [00:17<00:00,  1.31s/it]

	Loss: 0.7373715272316566, LR: 0.0001, Duration: 18.07 sec



Epoch 13: 100%|██████████| 13/13 [00:16<00:00,  1.28s/it]

	Loss: 0.7281529765862685, LR: 0.0001, Duration: 17.68 sec



Epoch 14: 100%|██████████| 13/13 [00:15<00:00,  1.23s/it]

	Loss: 0.7187970555745639, LR: 0.0001, Duration: 16.92 sec



Epoch 15: 100%|██████████| 13/13 [00:16<00:00,  1.27s/it]

	Loss: 0.6983148501469538, LR: 0.0001, Duration: 17.47 sec



Epoch 16: 100%|██████████| 13/13 [00:17<00:00,  1.34s/it]

	Loss: 0.6886927989812998, LR: 0.0001, Duration: 18.37 sec



Epoch 17: 100%|██████████| 13/13 [00:16<00:00,  1.25s/it]

	Loss: 0.6805060093219464, LR: 0.0001, Duration: 17.28 sec



Epoch 18: 100%|██████████| 13/13 [00:16<00:00,  1.23s/it]

	Loss: 0.6734302777510422, LR: 0.0001, Duration: 17.04 sec



Epoch 19: 100%|██████████| 13/13 [00:17<00:00,  1.32s/it]

	Loss: 0.671228670156919, LR: 0.0001, Duration: 18.20 sec



Epoch 20: 100%|██████████| 13/13 [00:17<00:00,  1.36s/it]

	Loss: 0.6655413279166589, LR: 0.0001, Duration: 18.71 sec



