In [24]:
import os
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm.notebook import tqdm

from pytz import timezone
import datetime as dt

from sklearn.metrics import f1_score

import wandb

In [25]:
from models.model_bj import resnetbase3 as MaskModel
from models.model_bj import MergeFreezeModel as ClassifierModel
from datasets.dataset_bj import SplitLabelsDatasetA as MaskDataset
from trans.trans_kj import A_random_trans_no_cut as TrainTrans
from trans.trans_kj import A_just_tensor as TestTrans

MASK_CLASS_NUM = 3
AGE_CLASS_NUM = 3
GENDER_CLASS_NUM = 2

CLASS_NUM = 18

NUM_WORKERS = 4
BATCH_SIZE = 32
NUM_EPOCH = 5
SAVE_INTERVAL = 3

wandb_run_name = 'RN18_splitL_frz_ranSample'
wandb_project_name = 'lv1_p'
wandb_entity = 'presto105'

load_path = ''

comment = ''

In [26]:
c = ''
log = []

test_dir = '/opt/ml/input/data/train'
eval_dir = '/opt/ml/input/data/eval'
save_dir = '/opt/ml/image-classification-level1-25/save/'
now = (dt.datetime.now().astimezone(timezone("Asia/Seoul")).strftime("%Y%m%d_%H%M%S"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

mask_model = MaskModel(MASK_CLASS_NUM)
gender_model = MaskModel(GENDER_CLASS_NUM)
age_model = MaskModel(AGE_CLASS_NUM)

model_dict = {'mask' : mask_model, 'gender' : gender_model, 'age' : age_model}

if load_path :
    for label_class in model_dict.keys() :
        model_dict[label_class].load_state_dict(torch.load(load_path))
for label_class in model_dict.keys() :
    model_dict[label_class].to(device)

mask_loss_fn = torch.nn.CrossEntropyLoss()
mask_optm = torch.optim.Adam(mask_model.parameters())

gender_loss_fn = torch.nn.CrossEntropyLoss()
gender_optm = torch.optim.Adam(gender_model.parameters())

age_loss_fn = torch.nn.CrossEntropyLoss()
age_optm = torch.optim.Adam(age_model.parameters())

loss_dict = {'mask' : mask_loss_fn, 'gender' : gender_loss_fn, 'age' : age_loss_fn}
optm_dict = {'mask' : mask_optm, 'gender' : gender_optm, 'age' : age_optm}

In [27]:
CONCAT_NUM = MASK_CLASS_NUM + AGE_CLASS_NUM + GENDER_CLASS_NUM
merged_model = ClassifierModel(mask_model, gender_model, age_model,
                                concatclasses= CONCAT_NUM, numclasses=CLASS_NUM)
merged_model.to(device)

merged_loss_fn = torch.nn.CrossEntropyLoss()
merged_optm = torch.optim.Adam(merged_model.parameters())

scheduler = torch.optim.lr_scheduler.MultiStepLR(merged_optm, milestones=[6,8,9], gamma=0.1)
lrs = []

In [28]:
mask_model.superM.conv1.weight[0][0]

tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       device='cuda:0', grad_fn=<SelectBackward>)

In [29]:
next(merged_model.modelMASK.named_parameters())
merged_model.modelMASK.superM.conv1.weight[0][0]

tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],
       device='cuda:0', grad_fn=<SelectBackward>)

In [30]:
TrainTransform = TrainTrans()
TestTransfrom = TestTrans()

dataset_train_mask = MaskDataset(test_dir, train='train', transform=TrainTransform)
dataset_test_mask = MaskDataset(test_dir, train='test', transform=TestTransfrom)

dataloader_train_mask = DataLoader(dataset=dataset_train_mask,
                                      batch_size=BATCH_SIZE,
                                      num_workers=NUM_WORKERS,
                                      )
dataloader_test_mask = DataLoader(dataset=dataset_test_mask,
                                      batch_size=BATCH_SIZE,
                                      num_workers=NUM_WORKERS,
                                      )

dataloaders = {
        "train": dataloader_train_mask,
        "test": dataloader_test_mask
}

In [31]:
log.append(f'{c:#^80}')
log.append(f'  [Comment]')
log.append(f'{comment}')
log.append(f'{c:#^80}')
log.append(c); log.append(c); log.append(c)

log.append(f'Model         : {merged_model.__class__.__name__}')
log.append(f'  load_state  : {load_path}')
log.append(f'Dataset       : {dataset_train_mask.__class__.__name__}')
log.append(f'  train_len    {len(dataset_train_mask):>10}')
log.append(f'  test_len     {len(dataset_test_mask):>10}')
log.append(f'Train_trans   : {TrainTrans.__name__}')
log.append(f'Test_trans    : {TestTrans.__name__}')
log.append(f'Start_Date    : {now}')
log.append(f'Device        : {device}')
log.append(f'CLASS_NUM     : {CLASS_NUM}')
log.append(f'NUM_WORKERS   : {NUM_WORKERS}')
log.append(f'BATCH_SIZE    : {BATCH_SIZE}')
log.append(f'NUM_EPOCH     : {NUM_EPOCH}')
log.append(f'SAVE_INTERVAL : {SAVE_INTERVAL}')


for line in log:
    print(line)
    
log.append(c); log.append(c); log.append(c)

################################################################################
  [Comment]

################################################################################



Model         : MergeFreezeModel
  load_state  : 
Dataset       : SplitLabelsDatasetA
  train_len         17010
  test_len           1890
Train_trans   : A_random_trans_no_cut
Test_trans    : A_just_tensor
Start_Date    : 20210830_102118
Device        : cuda:0
CLASS_NUM     : 18
NUM_WORKERS   : 4
BATCH_SIZE    : 32
NUM_EPOCH     : 5
SAVE_INTERVAL : 3


In [32]:
config={"epochs": NUM_EPOCH, "batch_size": BATCH_SIZE}

best_test_accuracy = 0.
best_test_loss = float('inf')
best_f1 = 0.

for label_class in ['mask', 'gender', 'age'] :    
    wandb.init(project=wandb_project_name, entity=wandb_entity, config=config)
    wandb.run.name = wandb_run_name + '_' + label_class
    log.append(f"{label_class}")
    print(label_class)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optm_dict[label_class], milestones=[6,8,9], gamma=0.1)
    lrs = []
    for epoch in range(NUM_EPOCH):
        print(f'-'*80)
        print(f'epoch:{epoch}')
        for phase in ["train", "test"]:
            running_loss = 0.
            running_acc = 0.
            running_f1 = 0.
            
            if phase == "train":
                model_dict[label_class].train()
            elif phase == "test":
                model_dict[label_class].eval() 

            for idx, (images, labels) in enumerate(pbar := tqdm(dataloaders[phase]), start = 1):
                images, labels = images.to(device), labels[label_class].to(device)

                optm_dict[label_class].zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    logits = model_dict[label_class](images)
                    _, preds = torch.max(logits, 1)
                    loss = loss_dict[label_class](logits, labels)
                    if phase == "train":
                        loss.backward()  # 모델의 예측 값과 실제 값의 CrossEntropy 차이를 통해 gradient 계산
                        optm_dict[label_class].step()  # 계산된 gradient를 가지고 모델 업데이트
                        lrs.append(optm_dict[label_class].param_groups[0]["lr"])
                        scheduler.step()
                           
                running_loss += loss.item() * images.size(0)
                running_acc += torch.sum(preds == labels.data)
                running_f1 += f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average='macro')
                pbar.set_description(f"loss : {running_loss/(idx*BATCH_SIZE):.3f}, acc : {running_acc/(idx*BATCH_SIZE):.3f}, f1 : {running_f1/(idx):.3f}")

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_acc / len(dataloaders[phase].dataset)
            epoch_f1 = running_f1 / len(dataloaders[phase])

            log.append(f"[{phase.upper():<5}] Epoch {epoch:0>3d} // (avg) Loss : {epoch_loss:.3f}, Accuracy : {epoch_acc*100:.3f}, F1 : {epoch_f1:.3f}")
            print(log[-1])
            
            if phase == "test":
                wandb.log({'test accuracy': epoch_acc*100, 'test loss': epoch_loss, 'test F1': epoch_f1})
                if best_test_accuracy < epoch_acc:
                    best_test_accuracy = epoch_acc
                if best_test_loss > epoch_loss:
                    best_test_loss = epoch_loss
                if best_f1 < epoch_f1:
                    best_f1 = epoch_f1
                if epoch % SAVE_INTERVAL == 0:
                    torch.save(model_dict[label_class].state_dict(), os.path.join(save_dir, f'{now}_{label_class}_{model_dict[label_class].__class__.__name__}_epoch_{epoch:0>3d}.pt'))
            else :
                wandb.log({'train accuracy': epoch_acc*100, 'train loss': epoch_loss, 'train F1': epoch_f1})
    torch.save(model_dict[label_class].state_dict(), os.path.join(save_dir, f'{now}_{label_class}_{model_dict[label_class].__class__.__name__}_finish_{NUM_EPOCH:0>3d}.pt'))


log.append(c)
print(log[-1])   
log.append(c)
print(log[-1])  
log.append(c)
print(log[-1])  
log.append(f'{c:#^80}')
print(log[-1])            
log.append(f':::학습종료:::')
print(log[-1])
log.append(f"최고 accuracy : {best_test_accuracy:.5f}, 최저 loss : {best_test_loss:.5f}, 최고 F1 : {best_f1:.5f}")
print(log[-1])
log.append(f'{c:#^80}')
print(log[-1])

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train accuracy,65.60258
train loss,0.6675
train F1,0.52272
_runtime,301.0
_timestamp,1630286401.0
_step,5.0
test accuracy,72.32804
test loss,0.53157
test F1,0.5415


0,1
train accuracy,▁▄█
train loss,█▂▁
train F1,▁▅█
_runtime,▁▁▄▅██
_timestamp,▁▁▄▅██
_step,▁▂▄▅▇█
test accuracy,▁▆█
test loss,█▂▁
test F1,▁▅█


[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


mask
--------------------------------------------------------------------------------
epoch:0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 000 // (avg) Loss : 0.257, Accuracy : 91.581, F1 : 0.867


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 000 // (avg) Loss : 0.049, Accuracy : 98.519, F1 : 0.980
--------------------------------------------------------------------------------
epoch:1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 001 // (avg) Loss : 0.164, Accuracy : 95.103, F1 : 0.917


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 001 // (avg) Loss : 0.032, Accuracy : 99.048, F1 : 0.986
--------------------------------------------------------------------------------
epoch:2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 002 // (avg) Loss : 0.136, Accuracy : 96.020, F1 : 0.933


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 002 // (avg) Loss : 0.022, Accuracy : 99.418, F1 : 0.991
--------------------------------------------------------------------------------
epoch:3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 003 // (avg) Loss : 0.113, Accuracy : 96.508, F1 : 0.942


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 003 // (avg) Loss : 0.019, Accuracy : 99.471, F1 : 0.992
--------------------------------------------------------------------------------
epoch:4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 004 // (avg) Loss : 0.099, Accuracy : 96.908, F1 : 0.949


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 004 // (avg) Loss : 0.015, Accuracy : 99.471, F1 : 0.992


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train accuracy,96.9077
train loss,0.09917
train F1,0.94858
_runtime,494.0
_timestamp,1630286985.0
_step,9.0
test accuracy,99.4709
test loss,0.01524
test F1,0.99208


0,1
train accuracy,▁▆▇▇█
train loss,█▄▃▂▁
train F1,▁▅▇▇█
_runtime,▁▁▃▃▄▅▆▆██
_timestamp,▁▁▃▃▄▅▆▆██
_step,▁▂▃▃▄▅▆▆▇█
test accuracy,▁▅███
test loss,█▄▂▂▁
test F1,▁▅▇██


[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


gender
--------------------------------------------------------------------------------
epoch:0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 000 // (avg) Loss : 1.922, Accuracy : 64.403, F1 : 0.464


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 000 // (avg) Loss : 1.479, Accuracy : 60.000, F1 : 0.528
--------------------------------------------------------------------------------
epoch:1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 001 // (avg) Loss : 1.054, Accuracy : 67.543, F1 : 0.520


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 001 // (avg) Loss : 0.857, Accuracy : 66.720, F1 : 0.580
--------------------------------------------------------------------------------
epoch:2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 002 // (avg) Loss : 0.745, Accuracy : 68.348, F1 : 0.551


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 002 // (avg) Loss : 0.604, Accuracy : 69.418, F1 : 0.596
--------------------------------------------------------------------------------
epoch:3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 003 // (avg) Loss : 0.634, Accuracy : 69.283, F1 : 0.567


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 003 // (avg) Loss : 0.491, Accuracy : 72.804, F1 : 0.623
--------------------------------------------------------------------------------
epoch:4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 004 // (avg) Loss : 0.600, Accuracy : 70.270, F1 : 0.577


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 004 // (avg) Loss : 0.478, Accuracy : 73.439, F1 : 0.632


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train accuracy,70.27043
train loss,0.60018
train F1,0.57671
_runtime,496.0
_timestamp,1630287486.0
_step,9.0
test accuracy,73.43916
test loss,0.47788
test F1,0.63167


0,1
train accuracy,▁▅▆▇█
train loss,█▃▂▁▁
train F1,▁▄▆▇█
_runtime,▁▁▃▃▄▅▆▆██
_timestamp,▁▁▃▃▄▅▆▆██
_step,▁▂▃▃▄▅▆▆▇█
test accuracy,▁▅▆██
test loss,█▄▂▁▁
test F1,▁▄▆▇█


[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


age
--------------------------------------------------------------------------------
epoch:0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 000 // (avg) Loss : 13.888, Accuracy : 50.406, F1 : 0.415


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 000 // (avg) Loss : 21.837, Accuracy : 0.000, F1 : 0.000
--------------------------------------------------------------------------------
epoch:1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 001 // (avg) Loss : 10.142, Accuracy : 50.494, F1 : 0.416


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 001 // (avg) Loss : 16.027, Accuracy : 0.000, F1 : 0.000
--------------------------------------------------------------------------------
epoch:2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 002 // (avg) Loss : 7.151, Accuracy : 50.494, F1 : 0.416


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 002 // (avg) Loss : 9.702, Accuracy : 0.000, F1 : 0.000
--------------------------------------------------------------------------------
epoch:3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 003 // (avg) Loss : 5.129, Accuracy : 50.482, F1 : 0.415


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 003 // (avg) Loss : 5.051, Accuracy : 0.000, F1 : 0.000
--------------------------------------------------------------------------------
epoch:4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 004 // (avg) Loss : 3.594, Accuracy : 50.241, F1 : 0.391


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 004 // (avg) Loss : 2.334, Accuracy : 0.159, F1 : 0.002



################################################################################
:::학습종료:::
최고 accuracy : 0.99471, 최저 loss : 0.01524, 최고 F1 : 0.99208
################################################################################


In [13]:
# CONCAT_NUM = MASK_CLASS_NUM + AGE_CLASS_NUM + GENDER_CLASS_NUM
# merged_model = ClassifierModel(mask_model, gender_model, age_model,
#                                 concatclasses= CONCAT_NUM, numclasses=CLASS_NUM)
# merged_model.to(device)

# merged_loss_fn = torch.nn.CrossEntropyLoss()
# merged_optm = torch.optim.Adam(merged_model.parameters())

# scheduler = torch.optim.lr_scheduler.MultiStepLR(merged_optm, milestones=[6,8,9], gamma=0.1)
# lrs = []

In [33]:
wandb.init(project=wandb_project_name, entity=wandb_entity, config=config)
wandb.run.name = wandb_run_name + '_merged'

best_test_accuracy = 0.
best_test_loss = float('inf')
best_f1 = 0.

NUM_EPOCH = 10
for epoch in range(NUM_EPOCH):
    for phase in ["train", "test"]:
        running_loss = 0.
        running_acc = 0.
        running_f1 = 0.
        
        if phase == "train":
            merged_model.train()
        elif phase == "test":
            merged_model.eval() 

        for idx, (images, labels) in enumerate(pbar := tqdm(dataloaders[phase]), start = 1):
            images, labels = images.to(device), labels['merged'].to(device)

            merged_optm.zero_grad()

            with torch.set_grad_enabled(phase == "train"):
                logits = merged_model(images)
                _, preds = torch.max(logits, 1)
                loss = merged_loss_fn(logits, labels)
                if phase == "train":
                    loss.backward()  # 모델의 예측 값과 실제 값의 CrossEntropy 차이를 통해 gradient 계산
                    merged_optm.step()  # 계산된 gradient를 가지고 모델 업데이트
                    lrs.append(merged_optm.param_groups[0]["lr"])
                    scheduler.step()
                        
            running_loss += loss.item() * images.size(0)
            running_acc += torch.sum(preds == labels.data)
            running_f1 += f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average='macro')
            pbar.set_description(f"loss : {running_loss/(idx*BATCH_SIZE):.3f}, acc : {running_acc/(idx*BATCH_SIZE):.3f}, f1 : {running_f1/(idx):.3f}")

        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        epoch_acc = running_acc / len(dataloaders[phase].dataset)
        epoch_f1 = running_f1 / len(dataloaders[phase])

        log.append(f"[{phase.upper():<5}] Epoch {epoch:0>3d} // (avg) Loss : {epoch_loss:.3f}, Accuracy : {epoch_acc*100:.3f}, F1 : {epoch_f1:.3f}")
        print(log[-1])
        
        if phase == "test":
            wandb.log({'test accuracy': epoch_acc*100, 'test loss': epoch_loss, 'test F1': epoch_f1})
            if best_test_accuracy < epoch_acc:
                best_test_accuracy = epoch_acc
            if best_test_loss > epoch_loss:
                best_test_loss = epoch_loss
            if best_f1 < epoch_f1:
                best_f1 = epoch_f1
            if epoch % SAVE_INTERVAL == 0:
                torch.save(merged_model.state_dict(), os.path.join(save_dir, f'{now}_merged_{merged_model.__class__.__name__}_epoch_{epoch:0>3d}.pt'))
        else :
            wandb.log({'train accuracy': epoch_acc*100, 'train loss': epoch_loss, 'train F1': epoch_f1})
torch.save(merged_model.state_dict(), os.path.join(save_dir, f'{now}_merged_{merged_model.__class__.__name__}_finish_{NUM_EPOCH:0>3d}.pt'))


log.append(c)
print(log[-1])   
log.append(c)
print(log[-1])  
log.append(c)
print(log[-1])  
log.append(f'{c:#^80}')
print(log[-1])            
log.append(f':::학습종료:::')
print(log[-1])
log.append(f"최고 accuracy : {best_test_accuracy:.5f}, 최저 loss : {best_test_loss:.5f}, 최고 F1 : {best_f1:.5f}")
print(log[-1])
log.append(f'{c:#^80}')
print(log[-1])

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train accuracy,50.24104
train loss,3.59416
train F1,0.39133
_runtime,497.0
_timestamp,1630287987.0
_step,9.0
test accuracy,0.15873
test loss,2.33392
test F1,0.00152


0,1
train accuracy,▆███▁
train loss,█▅▃▂▁
train F1,████▁
_runtime,▁▁▃▃▄▅▆▆██
_timestamp,▁▁▃▃▄▅▆▆██
_step,▁▂▃▃▄▅▆▆▇█
test accuracy,▁▁▁▁█
test loss,█▆▄▂▁
test F1,▁▁▁▁█


[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 000 // (avg) Loss : 2.008, Accuracy : 34.968, F1 : 0.233


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 000 // (avg) Loss : 1.200, Accuracy : 52.169, F1 : 0.469


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 001 // (avg) Loss : 1.653, Accuracy : 38.460, F1 : 0.277


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 001 // (avg) Loss : 1.205, Accuracy : 52.169, F1 : 0.468


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 002 // (avg) Loss : 1.627, Accuracy : 38.366, F1 : 0.273


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 002 // (avg) Loss : 1.203, Accuracy : 52.011, F1 : 0.465


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 003 // (avg) Loss : 1.602, Accuracy : 38.883, F1 : 0.278


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 003 // (avg) Loss : 1.204, Accuracy : 51.799, F1 : 0.452


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 004 // (avg) Loss : 1.583, Accuracy : 39.406, F1 : 0.283


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 004 // (avg) Loss : 1.185, Accuracy : 51.058, F1 : 0.413


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 005 // (avg) Loss : 1.576, Accuracy : 39.518, F1 : 0.284


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 005 // (avg) Loss : 1.194, Accuracy : 49.894, F1 : 0.395


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 006 // (avg) Loss : 1.561, Accuracy : 40.141, F1 : 0.285


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 006 // (avg) Loss : 1.165, Accuracy : 48.624, F1 : 0.376


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 007 // (avg) Loss : 1.562, Accuracy : 40.270, F1 : 0.280


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 007 // (avg) Loss : 1.165, Accuracy : 48.677, F1 : 0.385


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 008 // (avg) Loss : 1.558, Accuracy : 41.217, F1 : 0.289


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 008 // (avg) Loss : 1.179, Accuracy : 48.783, F1 : 0.382


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 009 // (avg) Loss : 1.560, Accuracy : 39.965, F1 : 0.276


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 009 // (avg) Loss : 1.211, Accuracy : 46.614, F1 : 0.365



################################################################################
:::학습종료:::
최고 accuracy : 0.52169, 최저 loss : 1.16482, 최고 F1 : 0.46888
################################################################################


In [34]:
class TestDataset(Dataset):
    def __init__(self, img_paths, transform):
        self.img_paths = img_paths
        self.transform = transform

    def __getitem__(self, index):
        image = Image.open(self.img_paths[index])

        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.img_paths)

# meta 데이터와 이미지 경로를 불러옵니다.
submission = pd.read_csv(os.path.join(eval_dir, 'info.csv'))
image_dir = os.path.join(eval_dir, 'images')

# Test Dataset 클래스 객체를 생성하고 DataLoader를 만듭니다.
image_paths = [os.path.join(image_dir, img_id) for img_id in submission.ImageID]
transform = transforms.Compose([
    transforms.Resize((512, 384), Image.BILINEAR),
    # transforms.CenterCrop(300),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
])
dataset = TestDataset(image_paths, transform)

loader = DataLoader(
    dataset,
    shuffle=False
)

# 모델을 정의합니다. (학습한 모델이 있다면 torch.load로 모델을 불러주세요!)
device = torch.device('cuda')
merged_model.eval()

# 모델이 테스트 데이터셋을 예측하고 결과를 저장합니다.
all_predictions = []
for images in tqdm(loader):
    with torch.no_grad():
        images = images.to(device)
        pred = merged_model(images)
        pred = pred.argmax(dim=-1)
        all_predictions.extend(pred.cpu().numpy())
submission['ans'] = all_predictions

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12600.0), HTML(value='')))




In [35]:
# 제출할 파일을 저장합니다.
submission.to_csv(os.path.join(save_dir, f'{now}_result.csv'), index=False)
log.append(f'test inference is done!')
print(log[-1])
log.append(c)
print(log[-1])
log.append(f'{c:-^80}')
print(log[-1])
log.append(c)
print(log[-1])



# log 저장
with open(os.path.join(save_dir, f'{now}.log'), "w") as f:
    now = (dt.datetime.now().astimezone(timezone("Asia/Seoul")).strftime("%Y%m%d_%H%M%S"))
    log.append(f'Finish_Date    : {now}')
    print(log[-1])
    for line in log: 
        f.write(line+'\n')

test inference is done!

--------------------------------------------------------------------------------

Finish_Date    : 20210830_111329
