In [1]:
import os
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm.notebook import tqdm

from pytz import timezone
import datetime as dt

from sklearn.metrics import f1_score

import wandb

In [2]:
from models.model_bj import resnetbase3 as MaskModel
from models.model_bj import MergeFreezeModel as ClassifierModel
from datasets.dataset_bj import SplitLabelsDatasetA as MaskDataset
from trans.trans_kj import A_random_trans_no_cut as TrainTrans
from trans.trans_kj import A_just_tensor as TestTrans

MASK_CLASS_NUM = 3
AGE_CLASS_NUM = 3
GENDER_CLASS_NUM = 2

CLASS_NUM = 18

NUM_WORKERS = 4
BATCH_SIZE = 32
NUM_EPOCH = 10
SAVE_INTERVAL = 3

wandb_run_name = 'RN18_splitL_frz_stepL'
wandb_project_name = 'lv1_p'
wandb_entity = 'presto105'

load_path = ''

comment = ''

In [3]:
c = ''
log = []

test_dir = '/opt/ml/input/data/train'
eval_dir = '/opt/ml/input/data/eval'
save_dir = '/opt/ml/image-classification-level1-25/save/'
now = (dt.datetime.now().astimezone(timezone("Asia/Seoul")).strftime("%Y%m%d_%H%M%S"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

mask_model = MaskModel(MASK_CLASS_NUM)
gender_model = MaskModel(GENDER_CLASS_NUM)
age_model = MaskModel(AGE_CLASS_NUM)

model_dict = {'mask' : mask_model, 'gender' : gender_model, 'age' : age_model}

if load_path :
    for label_class in model_dict.keys() :
        model_dict[label_class].load_state_dict(torch.load(load_path))
for label_class in model_dict.keys() :
    model_dict[label_class].to(device)

mask_loss_fn = torch.nn.CrossEntropyLoss()
mask_optm = torch.optim.Adam(mask_model.parameters())

gender_loss_fn = torch.nn.CrossEntropyLoss()
gender_optm = torch.optim.Adam(gender_model.parameters())

age_loss_fn = torch.nn.CrossEntropyLoss()
age_optm = torch.optim.Adam(age_model.parameters())

loss_dict = {'mask' : mask_loss_fn, 'gender' : gender_loss_fn, 'age' : age_loss_fn}
optm_dict = {'mask' : mask_optm, 'gender' : gender_optm, 'age' : age_optm}

In [4]:
CONCAT_NUM = MASK_CLASS_NUM + AGE_CLASS_NUM + GENDER_CLASS_NUM
merged_model = ClassifierModel(mask_model, gender_model, age_model,
                                concatclasses= CONCAT_NUM, numclasses=CLASS_NUM)
merged_model.to(device)

merged_loss_fn = torch.nn.CrossEntropyLoss()
merged_optm = torch.optim.Adam(merged_model.parameters())

scheduler = torch.optim.lr_scheduler.MultiStepLR(merged_optm, milestones=[6,8,9], gamma=0.1)
lrs = []

In [5]:
TrainTransform = TrainTrans()
TestTransfrom = TestTrans()

dataset_train_mask = MaskDataset(test_dir, train='train', transform=TrainTransform)
dataset_test_mask = MaskDataset(test_dir, train='test', transform=TestTransfrom)

dataloader_train_mask = DataLoader(dataset=dataset_train_mask,
                                      batch_size=BATCH_SIZE,
                                      num_workers=NUM_WORKERS,
                                      )
dataloader_test_mask = DataLoader(dataset=dataset_test_mask,
                                      batch_size=BATCH_SIZE,
                                      num_workers=NUM_WORKERS,
                                      )

dataloaders = {
        "train": dataloader_train_mask,
        "test": dataloader_test_mask
}

In [6]:
log.append(f'{c:#^80}')
log.append(f'  [Comment]')
log.append(f'{comment}')
log.append(f'{c:#^80}')
log.append(c); log.append(c); log.append(c)

log.append(f'Model         : {merged_model.__class__.__name__}')
log.append(f'  load_state  : {load_path}')
log.append(f'Dataset       : {dataset_train_mask.__class__.__name__}')
log.append(f'  train_len    {len(dataset_train_mask):>10}')
log.append(f'  test_len     {len(dataset_test_mask):>10}')
log.append(f'Train_trans   : {TrainTrans.__name__}')
log.append(f'Test_trans    : {TestTrans.__name__}')
log.append(f'Start_Date    : {now}')
log.append(f'Device        : {device}')
log.append(f'CLASS_NUM     : {CLASS_NUM}')
log.append(f'NUM_WORKERS   : {NUM_WORKERS}')
log.append(f'BATCH_SIZE    : {BATCH_SIZE}')
log.append(f'NUM_EPOCH     : {NUM_EPOCH}')
log.append(f'SAVE_INTERVAL : {SAVE_INTERVAL}')


for line in log:
    print(line)
    
log.append(c); log.append(c); log.append(c)

################################################################################
  [Comment]

################################################################################



Model         : MergeFreezeModel
  load_state  : 
Dataset       : SplitLabelsDatasetA
  train_len         17010
  test_len           1890
Train_trans   : A_random_trans_no_cut
Test_trans    : A_just_tensor
Start_Date    : 20210830_003637
Device        : cuda:0
CLASS_NUM     : 18
NUM_WORKERS   : 4
BATCH_SIZE    : 32
NUM_EPOCH     : 10
SAVE_INTERVAL : 3


In [8]:
config={"epochs": NUM_EPOCH, "batch_size": BATCH_SIZE}

best_test_accuracy = 0.
best_test_loss = float('inf')
best_f1 = 0.

for label_class in ['mask', 'gender', 'age'] :    
    print(label_class)
    log.append(f"{label_class}")
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optm_dict[label_class], milestones=[6,8,9], gamma=0.1)
    # lrs = []
    for epoch in range(NUM_EPOCH):
        print(f'-'*80)
        print(f'epoch:{epoch}')
        for phase in ["train", "test"]:
            running_loss = 0.
            running_acc = 0.
            running_f1 = 0.
            
            if phase == "train":
                model_dict[label_class].train()
            elif phase == "test":
                model_dict[label_class].eval() 

            for idx, (images, labels) in enumerate(pbar := tqdm(dataloaders[phase]), start = 1):
                images, labels = images.to(device), labels[label_class].to(device)

                optm_dict[label_class].zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    logits = model_dict[label_class](images)
                    _, preds = torch.max(logits, 1)
                    loss = loss_dict[label_class](logits, labels)
                    if phase == "train":
                        loss.backward()  # 모델의 예측 값과 실제 값의 CrossEntropy 차이를 통해 gradient 계산
                        optm_dict[label_class].step()  # 계산된 gradient를 가지고 모델 업데이트
                        # lrs.append(optm_dict[label_class].param_groups[0]["lr"])
                        # scheduler.step()
                           
                running_loss += loss.item() * images.size(0)
                running_acc += torch.sum(preds == labels.data)
                running_f1 += f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average='macro')
                pbar.set_description(f"loss : {running_loss/(idx*BATCH_SIZE):.3f}, acc : {running_acc/(idx*BATCH_SIZE):.3f}, f1 : {running_f1/(idx):.3f}")

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_acc / len(dataloaders[phase].dataset)
            epoch_f1 = running_f1 / len(dataloaders[phase])

            log.append(f"[{phase.upper():<5}] Epoch {epoch:0>3d} // (avg) Loss : {epoch_loss:.3f}, Accuracy : {epoch_acc*100:.3f}, F1 : {epoch_f1:.3f}")
            print(log[-1])
            
            if phase == "test":
                wandb.init(project=wandb_project_name, entity=wandb_entity, config=config)
                wandb.run.name = wandb_run_name + '_' + label_class
                wandb.log({'accuracy': epoch_acc, 'loss': epoch_loss, 'F1': epoch_f1})
                if best_test_accuracy < epoch_acc:
                    best_test_accuracy = epoch_acc
                if best_test_loss > epoch_loss:
                    best_test_loss = epoch_loss
                if best_f1 < epoch_f1:
                    best_f1 = epoch_f1
                if epoch % SAVE_INTERVAL == 0:
                    torch.save(model_dict[label_class].state_dict(), os.path.join(save_dir, f'{now}_{label_class}_{model_dict[label_class].__class__.__name__}_epoch_{epoch:0>3d}.pt'))
    torch.save(model_dict[label_class].state_dict(), os.path.join(save_dir, f'{now}_{label_class}_{model_dict[label_class].__class__.__name__}_finish_{NUM_EPOCH:0>3d}.pt'))


log.append(c)
print(log[-1])   
log.append(c)
print(log[-1])  
log.append(c)
print(log[-1])  
log.append(f'{c:#^80}')
print(log[-1])            
log.append(f':::학습종료:::')
print(log[-1])
log.append(f"최고 accuracy : {best_test_accuracy:.5f}, 최저 loss : {best_test_loss:.5f}, 최고 F1 : {best_f1:.5f}")
print(log[-1])
log.append(f'{c:#^80}')
print(log[-1])

mask
--------------------------------------------------------------------------------
epoch:0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 000 // (avg) Loss : 0.271, Accuracy : 91.552, F1 : 0.851


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 000 // (avg) Loss : 0.031, Accuracy : 99.630, F1 : 0.994
--------------------------------------------------------------------------------
epoch:1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 001 // (avg) Loss : 0.194, Accuracy : 93.380, F1 : 0.885


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 001 // (avg) Loss : 0.022, Accuracy : 99.630, F1 : 0.994
--------------------------------------------------------------------------------
epoch:2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 002 // (avg) Loss : 0.168, Accuracy : 94.350, F1 : 0.904


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 002 // (avg) Loss : 0.019, Accuracy : 99.524, F1 : 0.993
--------------------------------------------------------------------------------
epoch:3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 003 // (avg) Loss : 0.157, Accuracy : 94.627, F1 : 0.910


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 003 // (avg) Loss : 0.015, Accuracy : 99.524, F1 : 0.993
--------------------------------------------------------------------------------
epoch:4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 004 // (avg) Loss : 0.145, Accuracy : 94.950, F1 : 0.916


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 004 // (avg) Loss : 0.013, Accuracy : 99.577, F1 : 0.993
--------------------------------------------------------------------------------
epoch:5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 005 // (avg) Loss : 0.138, Accuracy : 95.168, F1 : 0.920


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 005 // (avg) Loss : 0.014, Accuracy : 99.577, F1 : 0.993
--------------------------------------------------------------------------------
epoch:6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 006 // (avg) Loss : 0.126, Accuracy : 95.503, F1 : 0.927


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 006 // (avg) Loss : 0.013, Accuracy : 99.577, F1 : 0.993
--------------------------------------------------------------------------------
epoch:7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 007 // (avg) Loss : 0.115, Accuracy : 95.838, F1 : 0.932


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 007 // (avg) Loss : 0.013, Accuracy : 99.524, F1 : 0.993
--------------------------------------------------------------------------------
epoch:8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 008 // (avg) Loss : 0.079, Accuracy : 97.337, F1 : 0.957


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 008 // (avg) Loss : 0.010, Accuracy : 99.630, F1 : 0.994
--------------------------------------------------------------------------------
epoch:9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 009 // (avg) Loss : 0.062, Accuracy : 98.072, F1 : 0.969


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 009 // (avg) Loss : 0.010, Accuracy : 99.630, F1 : 0.994
gender
--------------------------------------------------------------------------------
epoch:0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 000 // (avg) Loss : 1.749, Accuracy : 64.627, F1 : 0.483


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 000 // (avg) Loss : 1.436, Accuracy : 58.148, F1 : 0.500
--------------------------------------------------------------------------------
epoch:1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 001 // (avg) Loss : 0.909, Accuracy : 67.137, F1 : 0.521


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 001 // (avg) Loss : 0.649, Accuracy : 69.259, F1 : 0.578
--------------------------------------------------------------------------------
epoch:2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 002 // (avg) Loss : 0.670, Accuracy : 68.730, F1 : 0.560


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 002 // (avg) Loss : 0.482, Accuracy : 73.862, F1 : 0.579
--------------------------------------------------------------------------------
epoch:3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 003 // (avg) Loss : 0.611, Accuracy : 69.553, F1 : 0.576


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 003 // (avg) Loss : 0.450, Accuracy : 75.714, F1 : 0.592
--------------------------------------------------------------------------------
epoch:4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 004 // (avg) Loss : 0.588, Accuracy : 70.858, F1 : 0.584


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 004 // (avg) Loss : 0.454, Accuracy : 75.873, F1 : 0.620
--------------------------------------------------------------------------------
epoch:5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 005 // (avg) Loss : 0.569, Accuracy : 71.905, F1 : 0.596


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 005 // (avg) Loss : 0.449, Accuracy : 76.402, F1 : 0.623
--------------------------------------------------------------------------------
epoch:6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 006 // (avg) Loss : 0.558, Accuracy : 72.616, F1 : 0.599


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 006 // (avg) Loss : 0.440, Accuracy : 76.825, F1 : 0.638
--------------------------------------------------------------------------------
epoch:7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 007 // (avg) Loss : 0.542, Accuracy : 73.992, F1 : 0.615


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 007 // (avg) Loss : 0.443, Accuracy : 76.508, F1 : 0.642
--------------------------------------------------------------------------------
epoch:8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 008 // (avg) Loss : 0.535, Accuracy : 74.444, F1 : 0.622


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 008 // (avg) Loss : 0.432, Accuracy : 77.090, F1 : 0.646
--------------------------------------------------------------------------------
epoch:9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 009 // (avg) Loss : 0.519, Accuracy : 76.208, F1 : 0.639


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 009 // (avg) Loss : 0.420, Accuracy : 77.989, F1 : 0.651
age
--------------------------------------------------------------------------------
epoch:0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 000 // (avg) Loss : 12.678, Accuracy : 50.317, F1 : 0.414


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 000 // (avg) Loss : 25.197, Accuracy : 0.000, F1 : 0.000
--------------------------------------------------------------------------------
epoch:1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 001 // (avg) Loss : 7.356, Accuracy : 50.494, F1 : 0.416


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 001 // (avg) Loss : 22.120, Accuracy : 0.000, F1 : 0.000
--------------------------------------------------------------------------------
epoch:2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 002 // (avg) Loss : 4.246, Accuracy : 50.353, F1 : 0.370


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 002 // (avg) Loss : 24.266, Accuracy : 0.053, F1 : 0.001
--------------------------------------------------------------------------------
epoch:3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 003 // (avg) Loss : 3.153, Accuracy : 52.063, F1 : 0.354


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 003 // (avg) Loss : 17.057, Accuracy : 12.857, F1 : 0.102
--------------------------------------------------------------------------------
epoch:4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 004 // (avg) Loss : 2.550, Accuracy : 53.657, F1 : 0.353


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 004 // (avg) Loss : 12.813, Accuracy : 33.122, F1 : 0.225
--------------------------------------------------------------------------------
epoch:5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 005 // (avg) Loss : 2.106, Accuracy : 52.869, F1 : 0.343


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 005 // (avg) Loss : 10.062, Accuracy : 37.513, F1 : 0.248
--------------------------------------------------------------------------------
epoch:6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 006 // (avg) Loss : 1.778, Accuracy : 53.733, F1 : 0.349


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 006 // (avg) Loss : 9.862, Accuracy : 38.095, F1 : 0.256
--------------------------------------------------------------------------------
epoch:7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 007 // (avg) Loss : 1.520, Accuracy : 53.063, F1 : 0.341


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 007 // (avg) Loss : 7.273, Accuracy : 44.815, F1 : 0.306
--------------------------------------------------------------------------------
epoch:8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 008 // (avg) Loss : 1.278, Accuracy : 54.203, F1 : 0.344


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 008 // (avg) Loss : 5.225, Accuracy : 50.370, F1 : 0.341
--------------------------------------------------------------------------------
epoch:9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 009 // (avg) Loss : 1.110, Accuracy : 54.009, F1 : 0.342


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 009 // (avg) Loss : 3.011, Accuracy : 55.397, F1 : 0.369



################################################################################
:::학습종료:::
최고 accuracy : 0.99630, 최저 loss : 0.01007, 최고 F1 : 0.99399
################################################################################


In [9]:
CONCAT_NUM = MASK_CLASS_NUM + AGE_CLASS_NUM + GENDER_CLASS_NUM
merged_model = ClassifierModel(mask_model, gender_model, age_model,
                                concatclasses= CONCAT_NUM, numclasses=CLASS_NUM)
merged_model.to(device)

merged_loss_fn = torch.nn.CrossEntropyLoss()
merged_optm = torch.optim.Adam(merged_model.parameters())

scheduler = torch.optim.lr_scheduler.MultiStepLR(merged_optm, milestones=[6,8,9], gamma=0.1)
lrs = []

In [10]:
best_test_accuracy = 0.
best_test_loss = float('inf')
best_f1 = 0.

for epoch in range(NUM_EPOCH):
    for phase in ["train", "test"]:
        running_loss = 0.
        running_acc = 0.
        running_f1 = 0.
        
        if phase == "train":
            merged_model.train()
        elif phase == "test":
            merged_model.eval() 

        for idx, (images, labels) in enumerate(pbar := tqdm(dataloaders[phase]), start = 1):
            images, labels = images.to(device), labels['merged'].to(device)

            merged_optm.zero_grad()

            with torch.set_grad_enabled(phase == "train"):
                logits = merged_model(images)
                _, preds = torch.max(logits, 1)
                loss = merged_loss_fn(logits, labels)
                if phase == "train":
                    loss.backward()  # 모델의 예측 값과 실제 값의 CrossEntropy 차이를 통해 gradient 계산
                    merged_optm.step()  # 계산된 gradient를 가지고 모델 업데이트
                    lrs.append(merged_optm.param_groups[0]["lr"])
                    scheduler.step()
                        
            running_loss += loss.item() * images.size(0)
            running_acc += torch.sum(preds == labels.data)
            running_f1 += f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average='macro')
            pbar.set_description(f"loss : {running_loss/(idx*BATCH_SIZE):.3f}, acc : {running_acc/(idx*BATCH_SIZE):.3f}, f1 : {running_f1/(idx):.3f}")

        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        epoch_acc = running_acc / len(dataloaders[phase].dataset)
        epoch_f1 = running_f1 / len(dataloaders[phase])

        log.append(f"[{phase.upper():<5}] Epoch {epoch:0>3d} // (avg) Loss : {epoch_loss:.3f}, Accuracy : {epoch_acc*100:.3f}, F1 : {epoch_f1:.3f}")
        print(log[-1])
        
        if phase == "test":
            wandb.run.name = wandb_run_name + '_merged'
            wandb.log({'accuracy': epoch_acc, 'loss': epoch_loss, 'F1': epoch_f1})
            if best_test_accuracy < epoch_acc:
                best_test_accuracy = epoch_acc
            if best_test_loss > epoch_loss:
                best_test_loss = epoch_loss
            if best_f1 < epoch_f1:
                best_f1 = epoch_f1
            if epoch % SAVE_INTERVAL == 0:
                torch.save(merged_model.state_dict(), os.path.join(save_dir, f'{now}_{label_class}_{merged_model.__class__.__name__}_epoch_{epoch:0>3d}.pt'))
torch.save(merged_model.state_dict(), os.path.join(save_dir, f'{now}_{label_class}_{merged_model.__class__.__name__}_finish_{NUM_EPOCH:0>3d}.pt'))


log.append(c)
print(log[-1])   
log.append(c)
print(log[-1])  
log.append(c)
print(log[-1])  
log.append(f'{c:#^80}')
print(log[-1])            
log.append(f':::학습종료:::')
print(log[-1])
log.append(f"최고 accuracy : {best_test_accuracy:.5f}, 최저 loss : {best_test_loss:.5f}, 최고 F1 : {best_f1:.5f}")
print(log[-1])
log.append(f'{c:#^80}')
print(log[-1])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 000 // (avg) Loss : 3.094, Accuracy : 0.247, F1 : 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 000 // (avg) Loss : 3.278, Accuracy : 0.476, F1 : 0.002


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 001 // (avg) Loss : 3.079, Accuracy : 0.347, F1 : 0.002


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 001 // (avg) Loss : 3.259, Accuracy : 0.476, F1 : 0.002


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 002 // (avg) Loss : 3.064, Accuracy : 0.341, F1 : 0.002


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 002 // (avg) Loss : 3.233, Accuracy : 0.476, F1 : 0.002


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))




KeyboardInterrupt: 

In [41]:
class TestDataset(Dataset):
    def __init__(self, img_paths, transform):
        self.img_paths = img_paths
        self.transform = transform

    def __getitem__(self, index):
        image = Image.open(self.img_paths[index])

        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.img_paths)

# meta 데이터와 이미지 경로를 불러옵니다.
submission = pd.read_csv(os.path.join(eval_dir, 'info.csv'))
image_dir = os.path.join(eval_dir, 'images')

# Test Dataset 클래스 객체를 생성하고 DataLoader를 만듭니다.
image_paths = [os.path.join(image_dir, img_id) for img_id in submission.ImageID]
transform = transforms.Compose([
    transforms.Resize((512, 384), Image.BILINEAR),
    # transforms.CenterCrop(300),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
])
dataset = TestDataset(image_paths, transform)

loader = DataLoader(
    dataset,
    shuffle=False
)

# 모델을 정의합니다. (학습한 모델이 있다면 torch.load로 모델을 불러주세요!)
device = torch.device('cuda')
merged_model.eval()

# 모델이 테스트 데이터셋을 예측하고 결과를 저장합니다.
all_predictions = []
for images in tqdm(loader):
    with torch.no_grad():
        images = images.to(device)
        pred = merged_model(images)
        pred = pred.argmax(dim=-1)
        all_predictions.extend(pred.cpu().numpy())
submission['ans'] = all_predictions

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12600.0), HTML(value='')))




In [42]:
# 제출할 파일을 저장합니다.
submission.to_csv(os.path.join(save_dir, f'{now}_result.csv'), index=False)
log.append(f'test inference is done!')
print(log[-1])
log.append(c)
print(log[-1])
log.append(f'{c:-^80}')
print(log[-1])
log.append(c)
print(log[-1])



# log 저장
with open(os.path.join(save_dir, f'{now}.log'), "w") as f:
    now = (dt.datetime.now().astimezone(timezone("Asia/Seoul")).strftime("%Y%m%d_%H%M%S"))
    log.append(f'Finish_Date    : {now}')
    print(log[-1])
    for line in log: 
        f.write(line+'\n')

test inference is done!

--------------------------------------------------------------------------------

Finish_Date    : 20210830_002337
