### 모델 로드

In [20]:
import timm
model = timm.create_model('mixer_b16_224', pretrained=True)

In [21]:
import torch
import torch.nn as nn
checkpoint = torch.load("C:/Users/USER/Downloads/mlp_mixer_img21_b16.pth", map_location='cpu') # GPU 환경이 아닌 경우 'cpu'를 사용합니다.
model.load_state_dict(checkpoint)

# 모델의 마지막 분류기 부분을 새로운 클래스 수에 맞게 변경
num_classes = 3  # 새로운 클래스 수
model.head = nn.Linear(model.head.in_features, num_classes)

### 출력층만 학습

In [None]:
# 모델의 모든 파라미터를 동결
for param in model.parameters():
    param.requires_grad = False

# 출력층 파라미터만 학습 가능하도록 설정
model.head.weight.requires_grad = True
model.head.bias.requires_grad = True

In [3]:
model

MlpMixer(
  (stem): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (blocks): Sequential(
    (0): MixerBlock(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp_tokens): Mlp(
        (fc1): Linear(in_features=196, out_features=384, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=384, out_features=196, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp_channels): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      

### 데이터 셋

In [4]:
import logging
import os
import json
from PIL import Image
from torchvision.transforms import ToTensor
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir,annotations_dir,  transform=None):
        """
        annotation_dir (string): 메타데이터가 있는 JSON 파일의 경로
        img_dir (string): 모든 이미지가 있는 디렉토리의 경로
        transform (callable, optional): 샘플에 적용될 선택적 변환
        """
        self.img_dir = img_dir
        self.transform = transform
        
        self.annotation_dir= annotations_dir
        

    def __len__(self):
        label_list= os.listdir(self.annotation_dir)
        return len(label_list)

    def __getitem__(self, idx):
        label_list= os.listdir(self.annotation_dir)
        
        img_path = os.path.join(self.img_dir, label_list[idx].split('.')[0]+'.'+label_list[idx].split('.')[1])

        try:
            image = Image.open(img_path)
        except (IOError, OSError) as e:
            print(f"Error loading image {img_path}: {e}")
            return self.__getitem__((idx + 1) % len(self))
        
        
        
        # faceExp_uploader 부분만 라벨로 사용
        with open(self.annotation_dir+'/'+label_list[idx],'r', encoding='utf-8') as f:
            self.image_labels=json.load(f)
        label = self.image_labels['faceExp_uploader']
        label_to_int = {'기쁨': 0, '당황': 1, '중립': 2}

        # 문자열 라벨을 정수로 매핑
        label_int = label_to_int[label]
        label_tensor = torch.tensor(label_int, dtype=torch.long)
        
        if self.transform:
            image_tensor = self.transform(image)
        else:
            # 기본적으로 이미지를 Tensor로 변환
            transform = ToTensor()
            image_tensor = transform(image)
        
        return image_tensor, label_tensor

### 데이터 전처리

In [5]:
# 모폴로지 데이터 전처리
import cv2
import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms import functional as F
from PIL import Image

class MorphologyTransform:
    def __call__(self, img):
        # img는 PIL 이미지
        img = np.array(img)
        kernel = np.ones((5,5), np.uint8)
        img = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel)
        return Image.fromarray(img)


In [6]:
# edge 감지
class EdgeDetectionTransform:
    def __call__(self, img):
        # img는 PIL 이미지
        img = np.array(img)
        img = cv2.Canny(img, 100, 200)
        return Image.fromarray(img)

In [6]:
def get_loader(img_size, train_img_dir,train_annotation_dir, test_img_dir, test_annotation_dir, train_batch_size, eval_batch_size):

    transform_train = transforms.Compose([
        #주어진 이미지에서 임의의 크기 및 비율로 크롭한 후, 지정된 크기로 이미지를 리사이즈합니다, 5%~100% 사이의 영역을 crop 하여 resizing, crop되는 부분은 랜덤, ex) 중앙하단, 왼쪽 상단 등
        transforms.RandomResizedCrop((img_size, img_size), scale=(0.05, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # 픽셀에 normalize 적용, 각 채널의 평균과 표준 편차, pre trained된 데이터의 통계를 기반으로 설정
    ])
    transform_test = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    trainset = CustomDataset(img_dir=train_img_dir,
                             annotations_dir=train_annotation_dir,
                             transform=transform_train)
    testset = CustomDataset(img_dir=test_img_dir,
                            annotations_dir=test_annotation_dir,
                            transform=transform_test)
    

    train_sampler = RandomSampler(trainset) # 데이터셋에서 무작위로 샘플을 선택,  데이터셋의 인덱스를 무작위로 섞어서 데이터의 순서를 랜덤하게 배치, 모델이 순서에 의존하지 않고 학습함
    test_sampler = SequentialSampler(testset) # 데이터셋에서 순차적으로 샘플을 선택, 데이터를 처음부터 끝까지 순서대로 샘플링
    train_loader = DataLoader(trainset,
                              sampler=train_sampler,
                              batch_size=train_batch_size,
                              num_workers=0,
                              pin_memory=True)
    test_loader = DataLoader(testset,
                             sampler=test_sampler,
                             batch_size=eval_batch_size,
                             num_workers=0,
                             pin_memory=True) if testset is not None else None

    return train_loader, test_loader

### 학습

In [7]:
#학습 시 하이퍼 파라미터 지정
num_epochs=15
train_batch_size = 64  # 훈련 배치 크기
eval_batch_size = 64  # 평가 배치 크기
train_img_dir= 'C:/Users/USER/Desktop/data/train_img'
train_annotation_dir= 'C:/Users/USER/Desktop/data/train_label'
test_img_dir='C:/Users/USER/Desktop/data/valid_img'
test_annotaion_dir='C:/Users/USER/Desktop/data/valid_label'
name= 'mlp_mixer_class_3_final'
output_dir= 'output'
eval_every = 113  # 몇 스텝마다 평가를 할 것인지
learning_rate = 3e-2  # 초기 학습률
weight_decay = 0  # 가중치 감소율
sweight_decay = 0  # 가중치 감소율
num_steps = eval_every * num_epochs  # 총 훈련 스텝
seed = 42  # 초기화를 위한 랜덤 시드
n_gpu=1
gradient_accumulation_steps = 1  # 업데이트를 위해 누적할 스텝 수
warmup_steps = 500  # 웜업을 위한 스텝 수
max_grad_norm = 1.0  # 최대 그래디언트 노름

In [16]:
cd Desktop

C:\Users\USER\Desktop


In [17]:
pwd

'C:\\Users\\USER\\Desktop'

In [22]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import math
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_
import random
import numpy as np
import logging
from PIL import Image, ImageFile

from torch.utils.tensorboard import SummaryWriter
ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

#스케줄러
class WarmupCosineSchedule(LambdaLR):
    """ 학습 초기에는 학습률을 점진적으로 증가시키는 warmup기간을 갖고 그 이후에는 코사인 함수를 따라 학습률을 감소시킴
        스케줄러는 optimizer에 설정된 학습률을 기준으로 하여, 특정 시점에서의 학습률을 조절하는 비율(factor)을 계산
        warmup 기간 내: step / warmup_steps 비율로 학습률을 증가, 초기 학습률 x step/warmup_steps
        warmup 기간 이후: 초기 학습률 x 0.5 * (1 + cos(π * cycles * 2 * progress)) 공식에 따라 학습률이 감소
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def valid(model, writer, test_loader,eval_batch_size, global_step,device):
    # Validation!
    eval_losses = AverageMeter()

    logger.info("***** Running Validation *****")
    logger.info("  Num steps = %d", len(test_loader))
    logger.info("  Batch size = %d", eval_batch_size)

    model.eval()
    all_preds, all_label = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True)
    loss_fct = torch.nn.CrossEntropyLoss()
    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(device) for t in batch)
        x, y = batch
        with torch.no_grad():
            
            logits = model(x)
            eval_loss = loss_fct(logits, y)
            eval_losses.update(eval_loss.item())

            preds = torch.argmax(logits, dim=-1)

        if len(all_preds) == 0:
            all_preds.append(preds.detach().cpu().numpy())
            all_label.append(y.detach().cpu().numpy())
        else:
            all_preds[0] = np.append(
                all_preds[0], preds.detach().cpu().numpy(), axis=0
            )
            all_label[0] = np.append(
                all_label[0], y.detach().cpu().numpy(), axis=0
            )
        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

    all_preds, all_label = all_preds[0], all_label[0]
    accuracy = simple_accuracy(all_preds, all_label)
    print("Global Steps: %d" % global_step)
    print("Valid Loss: %2.5f" % eval_losses.avg)
    print("Valid Accuracy: %2.5f" % accuracy)

    logger.info("/n")
    logger.info("Validation Results")
    logger.info("Global Steps: %d" % global_step)
    logger.info("Valid Loss: %2.5f" % eval_losses.avg)
    logger.info("Valid Accuracy: %2.5f" % accuracy)

    writer.add_scalar("test/accuracy", scalar_value=accuracy, global_step=global_step)
    return accuracy

def save_model(model, output_dir, name):
    model_to_save = model.module if hasattr(model, 'module') else model
    model_checkpoint = os.path.join(output_dir, "%s_checkpoint.bin" % name)
    torch.save(model_to_save.state_dict(), model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", output_dir)

def simple_accuracy(preds, labels):
    return (preds == labels).mean()

 



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

os.makedirs(output_dir, exist_ok=True)
writer = SummaryWriter(log_dir=os.path.join("logs", name))

train_loader, test_loader = get_loader(224, train_img_dir, train_annotation_dir, test_img_dir, test_annotaion_dir, train_batch_size, eval_batch_size)

optimizer = torch.optim.SGD(model.parameters(),
                            lr=learning_rate,
                            momentum=0.9,
                            weight_decay=weight_decay)
t_total = num_steps

# 스케줄러
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)

model.zero_grad()
# set seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed)

losses = AverageMeter()
global_step, best_acc = 0, 0

# epoch 단위로 학습 시작
for epoch in range(num_epochs):
    model.train()
    epoch_iterator = tqdm(train_loader,
                          desc=f"Epoch {epoch + 1}/{num_epochs}",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(device) for t in batch)
        x, y = batch
        outputs = model(x)
        loss_fn = nn.CrossEntropyLoss()
        #loss 계산
        loss = loss_fn(outputs, y)
        #역전파 학습, 기울기 계산
        loss.backward()

        if (step + 1) % gradient_accumulation_steps == 0: # 여러 스텝에 걸쳐 gradient를 축적하고 모델에 업데이트, gpu 메모리가 작을 때 큰 배치사이즈를 사용 가능
            #loss update
            losses.update(loss.item() * gradient_accumulation_steps)
            #임계값이 넘는 gradient를 임계값으로 수정, gradient 폭발 방지
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            scheduler.step() # 새로운 학습률을 계산하고 optimizer에 지정되어 있는 학습률 update
            optimizer.step() # weight update
            optimizer.zero_grad() # 이전 gradient 초기화
            global_step += 1

            epoch_iterator.set_description(
                f"Epoch {epoch + 1}/{num_epochs} (Step {global_step} / {t_total}) (Loss={losses.val:.5f})"
            )
            writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
            writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)

            if global_step % eval_every == 0:
                accuracy = valid(model, writer, test_loader, eval_batch_size, global_step, device)

                if best_acc < accuracy:
                    save_model(model, output_dir, name)
                    best_acc = accuracy
                model.train()
    
    losses.reset() # 각 epoch의 끝에 loss 초기화

writer.close()



Validating... (loss=1.02996): 100%|| 29/29 [03:04<00:00,  6.37s/it]00:07,  7.58s/it]
Epoch 1/15 (Step 113 / 1695) (Loss=0.85477): 100%|| 113/113 [16:23<00:00, 61.92s/it]

Global Steps: 113
Valid Loss: 0.92807
Valid Accuracy: 0.53500


Epoch 1/15 (Step 113 / 1695) (Loss=0.85477): 100%|| 113/113 [16:23<00:00,  8.71s/it]
Validating... (loss=1.19834): 100%|| 29/29 [02:52<00:00,  5.93s/it]00:07,  7.27s/it]


Global Steps: 226
Valid Loss: 0.66047
Valid Accuracy: 0.72278


Epoch 2/15 (Step 226 / 1695) (Loss=0.60045): 100%|| 113/113 [16:01<00:00,  8.51s/it]
Validating... (loss=1.05755): 100%|| 29/29 [02:50<00:00,  5.86s/it]00:07,  7.17s/it]
Epoch 3/15 (Step 339 / 1695) (Loss=0.45560): 100%|| 113/113 [15:55<00:00, 57.12s/it]

Global Steps: 339
Valid Loss: 0.61099
Valid Accuracy: 0.76000


Epoch 3/15 (Step 339 / 1695) (Loss=0.45560): 100%|| 113/113 [15:55<00:00,  8.45s/it]
Validating... (loss=0.75156): 100%|| 29/29 [02:48<00:00,  5.80s/it]00:06,  6.77s/it]
Epoch 4/15 (Step 452 / 1695) (Loss=0.53893): 100%|| 113/113 [15:54<00:00, 56.35s/it]

Global Steps: 452
Valid Loss: 0.56051
Valid Accuracy: 0.76278


Epoch 4/15 (Step 452 / 1695) (Loss=0.53893): 100%|| 113/113 [15:54<00:00,  8.45s/it]
Validating... (loss=1.23321): 100%|| 29/29 [02:50<00:00,  5.88s/it]00:07,  7.10s/it]
Epoch 5/15 (Step 565 / 1695) (Loss=0.81489): 100%|| 113/113 [15:56<00:00,  8.46s/it]


Global Steps: 565
Valid Loss: 0.70866
Valid Accuracy: 0.68889


Validating... (loss=0.66739): 100%|| 29/29 [02:50<00:00,  5.87s/it]00:07,  7.15s/it]
Epoch 6/15 (Step 678 / 1695) (Loss=0.49435): 100%|| 113/113 [15:55<00:00, 57.21s/it]

Global Steps: 678
Valid Loss: 0.56712
Valid Accuracy: 0.77722


Epoch 6/15 (Step 678 / 1695) (Loss=0.49435): 100%|| 113/113 [15:55<00:00,  8.45s/it]
Validating... (loss=0.82345): 100%|| 29/29 [02:50<00:00,  5.86s/it]00:06,  6.78s/it]
Epoch 7/15 (Step 791 / 1695) (Loss=0.53851): 100%|| 113/113 [15:54<00:00, 56.84s/it]

Global Steps: 791
Valid Loss: 0.45838
Valid Accuracy: 0.82611


Epoch 7/15 (Step 791 / 1695) (Loss=0.53851): 100%|| 113/113 [15:54<00:00,  8.45s/it]
Validating... (loss=0.93572): 100%|| 29/29 [02:50<00:00,  5.87s/it]00:06,  6.93s/it]
Epoch 8/15 (Step 904 / 1695) (Loss=0.52666): 100%|| 113/113 [15:54<00:00, 57.06s/it]

Global Steps: 904
Valid Loss: 0.44099
Valid Accuracy: 0.82722


Epoch 8/15 (Step 904 / 1695) (Loss=0.52666): 100%|| 113/113 [15:54<00:00,  8.45s/it]
Validating... (loss=0.76807): 100%|| 29/29 [02:51<00:00,  5.90s/it]<00:07,  7.17s/it]
Epoch 9/15 (Step 1017 / 1695) (Loss=0.45519): 100%|| 113/113 [15:56<00:00, 57.41s/it]

Global Steps: 1017
Valid Loss: 0.43088
Valid Accuracy: 0.83444


Epoch 9/15 (Step 1017 / 1695) (Loss=0.45519): 100%|| 113/113 [15:56<00:00,  8.46s/it]
Validating... (loss=0.27453): 100%|| 29/29 [02:50<00:00,  5.86s/it]7<00:06,  6.87s/it]
Epoch 10/15 (Step 1130 / 1695) (Loss=0.61595): 100%|| 113/113 [15:57<00:00, 56.97s/it]

Global Steps: 1130
Valid Loss: 0.39474
Valid Accuracy: 0.83944


Epoch 10/15 (Step 1130 / 1695) (Loss=0.61595): 100%|| 113/113 [15:57<00:00,  8.47s/it]
Validating... (loss=0.31959): 100%|| 29/29 [02:48<00:00,  5.80s/it]5<00:06,  6.82s/it]
Epoch 11/15 (Step 1243 / 1695) (Loss=0.37431): 100%|| 113/113 [15:53<00:00, 56.32s/it]

Global Steps: 1243
Valid Loss: 0.38341
Valid Accuracy: 0.84667


Epoch 11/15 (Step 1243 / 1695) (Loss=0.37431): 100%|| 113/113 [15:53<00:00,  8.44s/it]
Validating... (loss=0.07008): 100%|| 29/29 [02:49<00:00,  5.86s/it]0<00:06,  6.81s/it]


Global Steps: 1356
Valid Loss: 0.36632
Valid Accuracy: 0.85278


Epoch 12/15 (Step 1356 / 1695) (Loss=0.50477): 100%|| 113/113 [15:50<00:00,  8.41s/it]
Validating... (loss=0.09151): 100%|| 29/29 [02:50<00:00,  5.86s/it]2<00:06,  6.95s/it]
Epoch 13/15 (Step 1469 / 1695) (Loss=0.49020): 100%|| 113/113 [15:52<00:00, 56.89s/it]

Global Steps: 1469
Valid Loss: 0.37080
Valid Accuracy: 0.85389


Epoch 13/15 (Step 1469 / 1695) (Loss=0.49020): 100%|| 113/113 [15:52<00:00,  8.43s/it]
Validating... (loss=0.08488): 100%|| 29/29 [02:49<00:00,  5.84s/it]0<00:06,  6.75s/it]
Epoch 14/15 (Step 1582 / 1695) (Loss=0.31477): 100%|| 113/113 [15:50<00:00, 56.76s/it]

Global Steps: 1582
Valid Loss: 0.36491
Valid Accuracy: 0.86167


Epoch 14/15 (Step 1582 / 1695) (Loss=0.31477): 100%|| 113/113 [15:50<00:00,  8.41s/it]
Validating... (loss=0.08366): 100%|| 29/29 [02:49<00:00,  5.84s/it]3<00:06,  6.97s/it]


Global Steps: 1695
Valid Loss: 0.36217
Valid Accuracy: 0.86333


Epoch 15/15 (Step 1695 / 1695) (Loss=0.32732): 100%|| 113/113 [15:53<00:00,  8.44s/it]


### 출력층만 학습

In [None]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import math
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_
import random
import numpy as np
import logging
from PIL import Image, ImageFile

from torch.utils.tensorboard import SummaryWriter
ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

#스케줄러
class WarmupCosineSchedule(LambdaLR):
    """ 학습 초기에는 학습률을 점진적으로 증가시키는 warmup기간을 갖고 그 이후에는 코사인 함수를 따라 학습률을 감소시킴
        스케줄러는 optimizer에 설정된 학습률을 기준으로 하여, 특정 시점에서의 학습률을 조절하는 비율(factor)을 계산
        warmup 기간 내: step / warmup_steps 비율로 학습률을 증가, 초기 학습률 x step/warmup_steps
        warmup 기간 이후: 초기 학습률 x 0.5 * (1 + cos(π * cycles * 2 * progress)) 공식에 따라 학습률이 감소
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def valid(model, writer, test_loader,eval_batch_size, global_step,device):
    # Validation!
    eval_losses = AverageMeter()

    logger.info("***** Running Validation *****")
    logger.info("  Num steps = %d", len(test_loader))
    logger.info("  Batch size = %d", eval_batch_size)

    model.eval()
    all_preds, all_label = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True)
    loss_fct = torch.nn.CrossEntropyLoss()
    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(device) for t in batch)
        x, y = batch
        with torch.no_grad():
            
            logits = model(x)
            eval_loss = loss_fct(logits, y)
            eval_losses.update(eval_loss.item())

            preds = torch.argmax(logits, dim=-1)

        if len(all_preds) == 0:
            all_preds.append(preds.detach().cpu().numpy())
            all_label.append(y.detach().cpu().numpy())
        else:
            all_preds[0] = np.append(
                all_preds[0], preds.detach().cpu().numpy(), axis=0
            )
            all_label[0] = np.append(
                all_label[0], y.detach().cpu().numpy(), axis=0
            )
        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

    all_preds, all_label = all_preds[0], all_label[0]
    accuracy = simple_accuracy(all_preds, all_label)
    print("Global Steps: %d" % global_step)
    print("Valid Loss: %2.5f" % eval_losses.avg)
    print("Valid Accuracy: %2.5f" % accuracy)

    logger.info("/n")
    logger.info("Validation Results")
    logger.info("Global Steps: %d" % global_step)
    logger.info("Valid Loss: %2.5f" % eval_losses.avg)
    logger.info("Valid Accuracy: %2.5f" % accuracy)

    writer.add_scalar("test/accuracy", scalar_value=accuracy, global_step=global_step)
    return accuracy

def save_model(model, output_dir, name):
    model_to_save = model.module if hasattr(model, 'module') else model
    model_checkpoint = os.path.join(output_dir, "%s_checkpoint.bin" % name)
    torch.save(model_to_save.state_dict(), model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", output_dir)

def simple_accuracy(preds, labels):
    return (preds == labels).mean()

 



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

os.makedirs(output_dir, exist_ok=True)
writer = SummaryWriter(log_dir=os.path.join("logs", name))

train_loader, test_loader = get_loader(224, train_img_dir, train_annotation_dir, test_img_dir, test_annotaion_dir, train_batch_size, eval_batch_size)

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=learning_rate,
                            momentum=0.9,
                            weight_decay=weight_decay)
t_total = num_steps

# 스케줄러
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)

model.zero_grad()
# set seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed)

losses = AverageMeter()
global_step, best_acc = 0, 0

# epoch 단위로 학습 시작
for epoch in range(num_epochs):
    model.train()
    epoch_iterator = tqdm(train_loader,
                          desc=f"Epoch {epoch + 1}/{num_epochs}",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(device) for t in batch)
        x, y = batch
        outputs = model(x)
        loss_fn = nn.CrossEntropyLoss()
        #loss 계산
        loss = loss_fn(outputs, y)
        #역전파 학습, 기울기 계산
        loss.backward()

        if (step + 1) % gradient_accumulation_steps == 0: # 여러 스텝에 걸쳐 gradient를 축적하고 모델에 업데이트, gpu 메모리가 작을 때 큰 배치사이즈를 사용 가능
            #loss update
            losses.update(loss.item() * gradient_accumulation_steps)
            #임계값이 넘는 gradient를 임계값으로 수정, gradient 폭발 방지
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            scheduler.step() # 새로운 학습률을 계산하고 optimizer에 지정되어 있는 학습률 update
            optimizer.step() # weight update
            optimizer.zero_grad() # 이전 gradient 초기화
            global_step += 1

            epoch_iterator.set_description(
                f"Epoch {epoch + 1}/{num_epochs} (Step {global_step} / {t_total}) (Loss={losses.val:.5f})"
            )
            writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
            writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)

            if global_step % eval_every == 0:
                accuracy = valid(model, writer, test_loader, eval_batch_size, global_step, device)

                if best_acc < accuracy:
                    save_model(model, output_dir, name)
                    best_acc = accuracy
                model.train()
    
    losses.reset() # 각 epoch의 끝에 loss 초기화

writer.close()

