In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as F

import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
import time

from timm.data import Mixup
from timm.utils import ModelEmaV3
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
import transformers

from sklearn.metrics import confusion_matrix
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

from collections import OrderedDict
from model.convnextv2 import load_convNext
import math
import warnings
from torch.optim.lr_scheduler import _LRScheduler

class CosineWarmupScheduler(_LRScheduler):
    """Linear warmup followed by cosine annealing, advanced once per ``step()`` call.

    For every param group the learning rate rises linearly from ``min_lr`` to the
    group's base learning rate during the first ``num_warmup_steps`` steps, then
    follows a cosine curve back towards ``min_lr`` over the remaining
    ``num_training_steps - num_warmup_steps`` steps. Steps taken after
    ``num_training_steps`` keep the final value of the schedule instead of
    continuing along the cosine (which would make the rate climb again).

    Args:
        optimizer: Optimizer whose param-group learning rates are scheduled.
        num_warmup_steps: Number of steps spent in the linear warmup phase.
        num_training_steps: Total number of steps the schedule is laid out for.
        num_cycles: Number of cosine periods in the decay phase; the default
            ``0.5`` is a single half wave from the base rate down to ``min_lr``.
        min_lr: Learning rate at step 0 and at the bottom of the cosine.
        last_epoch: Index of the last step taken; ``-1`` starts a fresh schedule.
        verbose: Forwarded to the base scheduler only when truthy, because the
            argument is deprecated in recent PyTorch releases.
    """

    def __init__(self, optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, min_lr=1e-6, last_epoch=-1, verbose=False):
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        self.num_cycles = num_cycles
        self.min_lr = min_lr
        # `base_lrs` is filled in by the base-class constructor (from each group's
        # `initial_lr`), so it is not assigned here.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for the current step."""
        if not getattr(self, '_get_lr_called_within_step', False):
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        step = self.last_epoch
        in_warmup = step < self.num_warmup_steps
        if not in_warmup:
            decay_steps = max(1, self.num_training_steps - self.num_warmup_steps)
            # Clamp so that extra steps beyond the planned horizon hold the final rate.
            progress = min(1.0, (step - self.num_warmup_steps) / decay_steps)
            cosine_scale = 0.5 * (1 + math.cos(math.pi * self.num_cycles * 2.0 * progress))

        lrs = []
        for base_lr in self.base_lrs:
            if in_warmup:
                # Linear warmup
                lr = (base_lr - self.min_lr) * step / max(1, self.num_warmup_steps) + self.min_lr
            else:
                # Cosine annealing
                lr = self.min_lr + (base_lr - self.min_lr) * cosine_scale
            lrs.append(lr)
        return lrs



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the checkpoint file onto the CPU; its entries are filtered and renamed
# in later cells before being copied into the model.
pretrain_path = '../../model/convnext/fcmae.pt'
checkpoint_model = torch.load(
    pretrain_path,
    map_location='cpu',
)

# Build the network through the project helper.
model = load_convNext()

In [3]:
def remap_checkpoint_keys(ckpt):
    """Rename and reshape checkpoint entries, returning a new ``OrderedDict``.

    Transformations applied to each ``(key, value)`` pair of ``ckpt``:

    * a key starting with ``encoder`` loses its first dotted component;
    * a key ending in ``kernel`` is stored as ``<parent>.weight``; a 3-D value
      ``(k*k, in, out)`` becomes ``(out, in, k, k)`` and a 2-D value
      ``(k*k, dim)`` becomes ``(dim, 1, k, k)``, with the last two axes
      swapped in both cases. Kernels of any other rank are not copied;
    * a key containing ``ln`` or ``linear`` loses its second-to-last component;
    * afterwards every ``*bias`` value that is not 1-D is flattened, and every
      other value whose key contains ``grn`` gets two leading singleton axes.

    Args:
        ckpt: Mapping from parameter name to a tensor-like value.

    Returns:
        OrderedDict with the remapped names and values, in insertion order.
    """
    remapped = OrderedDict()

    for name, tensor in ckpt.items():
        # Strip the leading "encoder" component when present.
        if name.startswith('encoder'):
            name = name.partition('.')[2]

        if name.endswith('kernel'):
            # "<parent>.kernel" -> "<parent>.weight"
            target = name.rpartition('.')[0] + '.weight'
            rank = len(tensor.shape)
            if rank == 3:
                # Standard convolution.
                taps, in_dim, out_dim = tensor.shape
                side = int(math.sqrt(taps))
                dense = tensor.permute(2, 1, 0).reshape(out_dim, in_dim, side, side)
                remapped[target] = dense.transpose(3, 2)
            elif rank == 2:
                # Depthwise convolution.
                taps, dim = tensor.shape
                side = int(math.sqrt(taps))
                dense = tensor.permute(1, 0).reshape(dim, 1, side, side)
                remapped[target] = dense.transpose(3, 2)
            continue

        if 'ln' in name or 'linear' in name:
            # Drop the "ln"/"linear" wrapper component from the name.
            parts = name.split('.')
            del parts[-2]
            name = '.'.join(parts)
        remapped[name] = tensor

    # Second pass: fix up bias shapes and GRN affine parameters.
    for name, tensor in list(remapped.items()):
        if name.endswith('bias') and len(tensor.shape) != 1:
            remapped[name] = tensor.reshape(-1)
        elif 'grn' in name:
            remapped[name] = tensor.unsqueeze(0).unsqueeze(1)

    return remapped

In [4]:
def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    """Copy ``state_dict`` into ``model`` module by module and print a report.

    The function walks the module tree, calling ``_load_from_state_dict`` on
    every module, and collects missing keys, unexpected keys and error
    messages instead of raising. Each non-empty category is printed.

    Args:
        model: Root module that receives the weights.
        state_dict: Mapping of parameter/buffer names to values. A shallow copy
            is used, so the caller's mapping is not modified.
        prefix: Prefix prepended to every key looked up in ``state_dict``.
        ignore_missing: ``|``-separated substrings. A missing key containing any
            of them is reported under "Ignored weights" rather than as a
            warning. Empty substrings are skipped, so ``''`` (or ``None``)
            ignores nothing.

    Returns:
        None. The outcome is reported on stdout only.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        # Metadata is keyed by the module path without the trailing dot.
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(model, prefix=prefix)

    # An empty pattern is a substring of every key and would silently classify
    # all missing keys as "ignored", so empty patterns are dropped.
    ignore_patterns = [pattern for pattern in (ignore_missing or '').split('|') if pattern]

    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        if any(pattern in key for pattern in ignore_patterns):
            ignore_missing_keys.append(key)
        else:
            warn_missing_keys.append(key)

    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        print("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(ignore_missing_keys) > 0:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, ignore_missing_keys))
    if len(error_msgs) > 0:
        print('\n'.join(error_msgs))

In [5]:
# Drop classifier weights from the checkpoint when the model cannot take them:
# either the model has no parameter under that name, or the shapes differ.
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
    if k not in checkpoint_model:
        continue
    # `.get` avoids a KeyError when the checkpoint has the key but the model does not.
    target = state_dict.get(k)
    if target is None or checkpoint_model[k].shape != target.shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]


# remove decoder weights
# Matching is by substring, so any key containing one of these markers is dropped.
decoder_markers = ('decoder', 'mask_token', 'proj', 'pred')
checkpoint_model_keys = list(checkpoint_model.keys())
for k in checkpoint_model_keys:
    if any(marker in k for marker in decoder_markers):
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# Rename/reshape the remaining entries and copy them into the model.
# NOTE(review): `checkpoint_model` is rebound to the remapped dict, so running
# this cell twice remaps already-remapped entries — re-run the loading cell first.
checkpoint_model = remap_checkpoint_keys(checkpoint_model)
load_state_dict(model, checkpoint_model, prefix='')

Removing key mask_token from pretrained checkpoint
Removing key proj.weight from pretrained checkpoint
Removing key proj.bias from pretrained checkpoint
Removing key decoder.0.dwconv.weight from pretrained checkpoint
Removing key decoder.0.dwconv.bias from pretrained checkpoint
Removing key decoder.0.layernorm.weight from pretrained checkpoint
Removing key decoder.0.layernorm.bias from pretrained checkpoint
Removing key decoder.0.pwconv1.weight from pretrained checkpoint
Removing key decoder.0.pwconv1.bias from pretrained checkpoint
Removing key decoder.0.grn.gamma from pretrained checkpoint
Removing key decoder.0.grn.beta from pretrained checkpoint
Removing key decoder.0.pwconv2.weight from pretrained checkpoint
Removing key decoder.0.pwconv2.bias from pretrained checkpoint
Removing key pred.weight from pretrained checkpoint
Removing key pred.bias from pretrained checkpoint
Weights of ConvNeXtV2 not initialized from pretrained model: ['stem.stem_ln.weight', 'stem.stem_ln.bias', 'downs

In [6]:
# Count every parameter in the model
total_params = sum(p.numel() for p in model.parameters())

# Count only the parameters that require gradients
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print('='*80)
print(f"\nTotal Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}\n")
print('='*80)

# model_summary = summary(model.cuda(), (3, 224, 224))

# Define the transforms
train_transform = transforms.Compose([
    transforms.TrivialAugmentWide(interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomResizedCrop(224, scale=(0.6,1), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# No resize/crop here — NOTE(review): this presumes the validation and test
# images already have the size the model expects; confirm against the dataset.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

data_dir = '../../data/sports'
batch_size = 256

train_path = data_dir+'/train'
valid_path = data_dir+'/valid'
test_path = data_dir+'/test'

# dataset load
train_data = ImageFolder(train_path, transform=train_transform)
valid_data = ImageFolder(valid_path, transform=test_transform)
test_data = ImageFolder(test_path, transform=test_transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

device = 'cuda:4'
# Upper bound on the gradient norm used for clipping in the training loop
max_norm = 3.0

model.to(device)

# Optional exponential moving average of the model weights
model_ema = None
ema_active = True
if ema_active:
    ema_decay = 0.9998
    model_ema = ModelEmaV3(
        model,
        decay=ema_decay,
    )
    print(f"Using EMA with decay = {ema_decay}")

# Destination for saved weights; empty means nothing is written
model_path = ''

mixup = True
# Bind the name unconditionally so it exists even when mixup is switched off.
mixup_fn = None
if mixup:
    mixup_fn = Mixup(mixup_alpha=.8,
                     cutmix_alpha=1.,
                     prob=1.,
                     switch_prob=0.5,
                     mode='batch',
                     label_smoothing=.1,
                     num_classes=100)

# A single criterion serves both loops: the training loop passes it the targets
# returned by `mixup_fn`, the validation loop passes the integer labels produced
# by ImageFolder. The criteria previously assigned inside the mixup if/else were
# always overwritten by this line, so only this assignment is kept. Label
# smoothing is therefore applied solely through `Mixup(label_smoothing=.1)`.
criterion = nn.CrossEntropyLoss(label_smoothing=0.)


Total Parameters: 27,943,396
Trainable Parameters: 27,943,396

Using EMA with decay = 0.9998


In [7]:
# Parameter names in reverse registration order, so the walk below starts at
# the last-registered parameters and ends at the first-registered ones.
layer_names = [name for name, _ in model.named_parameters()]
layer_names.reverse()

lr      = 8e-4  # learning rate of the first group in the reversed order
lr_mult = 0.9   # multiplier applied each time the top-level module name changes
weight_decay = 0.05

# Build the name -> parameter table once instead of re-scanning
# model.named_parameters() for every single name.
params_by_name = dict(model.named_parameters())

param_groups = []
prev_group_name = layer_names[0].split('.')[0]

for idx, name in enumerate(layer_names):

    # Only the first dotted component is compared, so every parameter under the
    # same top-level module shares one learning rate.
    cur_group_name = name.split('.')[0]

    if cur_group_name != prev_group_name:
        lr *= lr_mult
    prev_group_name = cur_group_name

    print(f"{idx}: {name}'s lr={lr}")

    param = params_by_name[name]
    # One param group per parameter; frozen parameters yield an empty list.
    param_groups.append({'params': [param] if param.requires_grad else [],
                         'lr' : lr,
                         'weight_decay': weight_decay})

0: fc.bias's lr=0.0008
1: fc.weight's lr=0.0008
2: layernorm.bias's lr=0.00072
3: layernorm.weight's lr=0.00072
4: stages.3.2.pwconv2.bias's lr=0.000648
5: stages.3.2.pwconv2.weight's lr=0.000648
6: stages.3.2.grn.beta's lr=0.000648
7: stages.3.2.grn.gamma's lr=0.000648
8: stages.3.2.pwconv1.bias's lr=0.000648
9: stages.3.2.pwconv1.weight's lr=0.000648
10: stages.3.2.layernorm.bias's lr=0.000648
11: stages.3.2.layernorm.weight's lr=0.000648
12: stages.3.2.dwconv.bias's lr=0.000648
13: stages.3.2.dwconv.weight's lr=0.000648
14: stages.3.1.pwconv2.bias's lr=0.000648
15: stages.3.1.pwconv2.weight's lr=0.000648
16: stages.3.1.grn.beta's lr=0.000648
17: stages.3.1.grn.gamma's lr=0.000648
18: stages.3.1.pwconv1.bias's lr=0.000648
19: stages.3.1.pwconv1.weight's lr=0.000648
20: stages.3.1.layernorm.bias's lr=0.000648
21: stages.3.1.layernorm.weight's lr=0.000648
22: stages.3.1.dwconv.bias's lr=0.000648
23: stages.3.1.dwconv.weight's lr=0.000648
24: stages.3.0.pwconv2.bias's lr=0.000648
25: st

In [8]:
epochs = 500
# Epochs between two test-set evaluations. Only whole blocks of this size are
# run, so `epochs` should be a multiple of it.
eval_interval = 100

optimizer = optim.AdamW(param_groups)
# The scheduler is stepped once per batch, so its horizon is counted in batches.
warmup_steps = int(len(train_loader)*(epochs)*0.1)
train_steps = len(train_loader)*(epochs)
scheduler = CosineWarmupScheduler(optimizer, 
                                num_warmup_steps=warmup_steps, 
                                num_training_steps=train_steps,
                                num_cycles=0.5,
                                min_lr=1e-7)
# scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, 
#                                                         num_warmup_steps=warmup_steps, 
#                                                         num_training_steps=train_steps,
#                                                         num_cycles=0.5)

training_time = 0
losses = []
val_losses = []
lrs = []
best_loss = float('inf')
model_save = False

for i in range(epochs // eval_interval):
    for epoch in range(eval_interval):
        model.train()
        start_time = time.time()
        running_loss = 0.0
        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1 + i*eval_interval}")

        for _, data in pbar:
            inputs, labels = data[0].to(device), data[1].to(device)
            # Mix the batch only when mixup is enabled; `mixup_fn` is defined
            # only in that case.
            if mixup:
                inputs, labels = mixup_fn(inputs, labels)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            # Apply gradient clipping
            clip_grad_norm_(model.parameters(), max_norm=max_norm)
            optimizer.step()

            # Update the EMA model when one is in use
            if model_ema is not None:
                model_ema.update(model)

            scheduler.step()

            # Track the learning rate of the first param group only
            lr = optimizer.param_groups[0]["lr"]
            lrs.append(lr)
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        losses.append(epoch_loss)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data in valid_loader:
                inputs, labels = data[0].to(device), data[1].to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        val_loss /= len(valid_loader)
        val_losses.append(val_loss)

        # Checkpoint criterion: the sum of training and validation loss must improve
        total_loss = val_loss + epoch_loss
        if total_loss < best_loss:
            best_loss = total_loss
            model_save = True
            # Write the weights only when a destination is configured, and
            # report a save only when one actually happened.
            if model_path:
                torch.save(model.state_dict(), model_path)
                save_text = ' - model saved!'
            else:
                save_text = ' - new best (not saved: model_path is empty)'
        else:
            save_text = ''

        epoch_duration = time.time() - start_time
        training_time += epoch_duration

        text = f'\tLoss: {epoch_loss:.4f}, Val_Loss: {val_loss:.4f}, Total Mean Loss: {total_loss/2:.4f}, LR: {lr}, Duration: {epoch_duration:.2f} sec{save_text}'
        print(text)

    # Run predictions on the test set and collect the labels
    model.eval()  # make the evaluation mode explicit rather than inherited from validation
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Build the confusion matrix
    cm = confusion_matrix(all_labels, all_preds)

    # Predictions and ground truth
    y_true = all_labels  # ground-truth labels
    y_pred = all_preds  # labels predicted by the model

    # Accuracy over the whole test set
    accuracy = accuracy_score(y_true, y_pred)

    # Weighted-average precision, recall and F1 score; classes that are never
    # predicted score 0 without emitting an undefined-metric warning.
    precision, recall, f1_score, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0)

    # Collect the results in a pandas DataFrame
    performance_metrics = pd.DataFrame({
        'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score'],
        'Value': [accuracy, precision, recall, f1_score]
    })

    # Print the DataFrame
    print(f"\n[{i*eval_interval+eval_interval} epoch result]\n", performance_metrics)


Epoch 1: 100%|██████████| 53/53 [01:40<00:00,  1.89s/it]


	Loss: 4.7468, Val_Loss: 4.6995, Total Mean Loss: 4.7232, LR: 1.6098e-05, Duration: 101.37 sec - model saved!


Epoch 2: 100%|██████████| 53/53 [01:38<00:00,  1.86s/it]


	Loss: 4.6487, Val_Loss: 4.6234, Total Mean Loss: 4.6360, LR: 3.2096000000000006e-05, Duration: 99.94 sec - model saved!


Epoch 3: 100%|██████████| 53/53 [01:39<00:00,  1.87s/it]


	Loss: 4.6034, Val_Loss: 4.6151, Total Mean Loss: 4.6093, LR: 4.809400000000001e-05, Duration: 100.48 sec - model saved!


Epoch 4: 100%|██████████| 53/53 [01:39<00:00,  1.87s/it]


	Loss: 4.6015, Val_Loss: 4.6175, Total Mean Loss: 4.6095, LR: 6.409200000000001e-05, Duration: 100.47 sec


Epoch 5: 100%|██████████| 53/53 [01:39<00:00,  1.87s/it]


	Loss: 4.6035, Val_Loss: 4.6189, Total Mean Loss: 4.6112, LR: 8.009000000000001e-05, Duration: 100.31 sec


Epoch 6: 100%|██████████| 53/53 [01:38<00:00,  1.87s/it]


	Loss: 4.6051, Val_Loss: 4.6212, Total Mean Loss: 4.6132, LR: 9.608800000000002e-05, Duration: 100.13 sec


Epoch 7: 100%|██████████| 53/53 [01:39<00:00,  1.87s/it]


	Loss: 4.6051, Val_Loss: 4.6230, Total Mean Loss: 4.6141, LR: 0.00011208600000000002, Duration: 100.38 sec


Epoch 8:  11%|█▏        | 6/53 [00:11<01:28,  1.89s/it]