In [5]:
import os
import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader

from tqdm.notebook import tqdm

from pytz import timezone
import datetime as dt

from sklearn.metrics import f1_score
from module.f1_Loss import F1_Loss

In [6]:
from models.model_kj import resnetbase as MaskModel
from datasets.dataset_kj import basicDatasetA as MaskDataset
from trans.trans_kj import A_random_trans_no_cut as TrainTrans
from trans.trans_kj import A_just_tensor as TestTrans

# ---- Run configuration ------------------------------------------------------
CLASS_NUM: int = 18       # number of target classes (passed to MaskModel and F1_Loss)
NUM_WORKERS: int = 4      # DataLoader worker processes
BATCH_SIZE: int = 32      # samples per DataLoader batch
NUM_EPOCH: int = 30       # number of train/test epochs
SAVE_INTERVAL: int = 2    # save a checkpoint whenever epoch % SAVE_INTERVAL == 0

# Path of a saved state_dict to resume from; '' means start from freshly built weights.
load_path: str = ''

# Free-text note printed in the banner at the top of the run log.
comment: str = ''

In [7]:
# Run-wide state: `c` is the empty string used to build blank/separator lines in the text log,
# and `log` collects every line that is later written to `<now>.log`.
c = ''
log = []

# NOTE: despite its name, `test_dir` is the *training* data root. Both the train and test
# splits below are built from it via the dataset's `train=` argument.
test_dir = '/opt/ml/input/data/train'
eval_dir = '/opt/ml/input/data/eval'
save_dir = '/opt/ml/image-classification-level1-25/save/'
os.makedirs(save_dir, exist_ok=True)  # torch.save / to_csv fail if the directory is missing

# Start timestamp (Asia/Seoul); it prefixes every checkpoint, submission and log file name.
now = dt.datetime.now().astimezone(timezone("Asia/Seoul")).strftime("%Y%m%d_%H%M%S")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = MaskModel(CLASS_NUM)
if load_path:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(load_path, map_location=device))
model = model.to(device)

# F1_Loss is the training objective (torch.nn.CrossEntropyLoss() was the alternative tried).
loss_fn = F1_Loss(classes=CLASS_NUM)
optm = torch.optim.Adam(model.parameters())

TrainTransform = TrainTrans()
TestTransfrom = TestTrans()  # (sic) name kept because the inference cell references it

dataset_train_mask = MaskDataset(test_dir, train='train', transform=TrainTransform)
dataset_test_mask = MaskDataset(test_dir, train='test', transform=TestTransfrom)

# Shuffle the training split every epoch; without shuffle=True the DataLoader walks the
# dataset in the same index order each epoch. The evaluation split is read in order.
dataloader_train_mask = DataLoader(
    dataset=dataset_train_mask,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
dataloader_test_mask = DataLoader(
    dataset=dataset_test_mask,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

dataloaders = {
    "train": dataloader_train_mask,
    "test": dataloader_test_mask,
}

In [8]:
# Append the run header to `log`: a '#' banner around the free-text comment, three blank
# lines, and a summary of the model/dataset/transforms/hyper-parameters of this run.
# Then echo everything accumulated in `log` so far and add three blank separator lines.
hash_rule = f'{c:#^80}'  # an 80-character line of '#' (c is the empty string)
log.extend([
    hash_rule,
    f'  [Comment]',
    f'{comment}',
    hash_rule,
    c, c, c,
    f'Model         : {model.__class__.__name__}',
    f'  load_state  : {load_path}',
    f'Dataset       : {dataset_train_mask.__class__.__name__}',
    f'  train_len    {len(dataset_train_mask):>10}',
    f'  test_len     {len(dataset_test_mask):>10}',
    f'Train_trans   : {TrainTrans.__name__}',
    f'Test_trans    : {TestTrans.__name__}',
    f'Start_Date    : {now}',
    f'Device        : {device}',
    f'CLASS_NUM     : {CLASS_NUM}',
    f'NUM_WORKERS   : {NUM_WORKERS}',
    f'BATCH_SIZE    : {BATCH_SIZE}',
    f'NUM_EPOCH     : {NUM_EPOCH}',
    f'SAVE_INTERVAL : {SAVE_INTERVAL}',
])

print(*log, sep='\n')

log.extend([c] * 3)

################################################################################
  [Comment]
no_cut test
################################################################################



Model         : resnetbase
  load_state  : 
Dataset       : basicDatasetA
  train_len         17010
  test_len           1890
Train_trans   : A_random_trans_no_cut
Test_trans    : A_just_tensor
Start_Date    : 20210827_162101
Device        : cuda:0
CLASS_NUM     : 18
NUM_WORKERS   : 4
BATCH_SIZE    : 32
NUM_EPOCH     : 30
SAVE_INTERVAL : 2


In [9]:
# Train / evaluate loop (adapted from the competition sample code).
# Each epoch runs a "train" pass and a "test" pass over `dataloaders`. For each phase it logs
# the average loss, accuracy and macro F1, tracks the best test-phase values, and checkpoints
# the model every SAVE_INTERVAL epochs. A final checkpoint is written after the last epoch.
best_test_accuracy = 0.
best_test_loss = float('inf')
best_f1 = 0.

for epoch in range(NUM_EPOCH):
    for phase in ["train", "test"]:
        is_train = phase == "train"
        if is_train:
            model.train()
        else:
            model.eval()

        running_loss = 0.
        running_correct = 0
        seen = 0
        # Macro F1 is not additive over batches, so predictions and labels for the whole epoch
        # are collected and scored once at the end, not averaged per batch.
        epoch_preds = []
        epoch_labels = []

        pbar = tqdm(dataloaders[phase])
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            if is_train:
                optm.zero_grad()

            with torch.set_grad_enabled(is_train):
                logits = model(images)
                _, preds = torch.max(logits, 1)
                loss = loss_fn(logits, labels)
                if is_train:
                    loss.backward()  # gradients of the loss w.r.t. the model parameters
                    optm.step()      # update the parameters with those gradients

            batch_len = images.size(0)  # the last batch can be smaller than BATCH_SIZE
            seen += batch_len
            running_loss += loss.item() * batch_len
            running_correct += (preds == labels).sum().item()  # plain int, not a CUDA tensor
            epoch_preds.append(preds.cpu())
            epoch_labels.append(labels.cpu())
            pbar.set_description(f"loss : {running_loss/seen:.3f}, acc : {running_correct/seen:.3f}")

        epoch_loss = running_loss / seen
        epoch_acc = running_correct / seen
        epoch_f1 = float(f1_score(torch.cat(epoch_labels).numpy(),
                                  torch.cat(epoch_preds).numpy(),
                                  average='macro'))

        log.append(f"[{phase.upper():<5}] Epoch {epoch:0>3d} // (avg) Loss : {epoch_loss:.3f}, Accuracy : {epoch_acc:.3f}, F1 : {epoch_f1:.3f}")
        print(log[-1])

        if phase == "test":
            best_test_accuracy = max(best_test_accuracy, epoch_acc)
            best_test_loss = min(best_test_loss, epoch_loss)
            best_f1 = max(best_f1, epoch_f1)
            if epoch % SAVE_INTERVAL == 0:
                torch.save(model.state_dict(), os.path.join(save_dir, f'{now}_{model.__class__.__name__}_epoch_{epoch:0>3d}.pt'))

torch.save(model.state_dict(), os.path.join(save_dir, f'{now}_{model.__class__.__name__}_finish_{NUM_EPOCH:0>3d}.pt'))

# Closing banner with the best test-phase metrics (plain floats, so no tensor repr is printed).
for line in [
    c, c, c,
    f'{c:#^80}',
    ':::학습종료:::',
    f"최고 accuracy : {best_test_accuracy}, 최저 loss : {best_test_loss}, 최고 F1 : {best_f1}",
    f'{c:#^80}',
]:
    log.append(line)
    print(line)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 000 // (avg) Loss : 0.943, Accuracy : 0.201, Macro_f1 : 0.09452


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 000 // (avg) Loss : 0.943, Accuracy : 0.173, Macro_f1 : 0.10520


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 001 // (avg) Loss : 0.934, Accuracy : 0.212, Macro_f1 : 0.10905


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 001 // (avg) Loss : 0.928, Accuracy : 0.292, Macro_f1 : 0.13775


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 002 // (avg) Loss : 0.922, Accuracy : 0.245, Macro_f1 : 0.13081


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 002 // (avg) Loss : 0.933, Accuracy : 0.280, Macro_f1 : 0.11481


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 003 // (avg) Loss : 0.873, Accuracy : 0.325, Macro_f1 : 0.20275


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 003 // (avg) Loss : 0.837, Accuracy : 0.374, Macro_f1 : 0.27106


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 004 // (avg) Loss : 0.839, Accuracy : 0.385, Macro_f1 : 0.25611


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 004 // (avg) Loss : 0.761, Accuracy : 0.541, Macro_f1 : 0.39407


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 005 // (avg) Loss : 0.796, Accuracy : 0.453, Macro_f1 : 0.32415


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 005 // (avg) Loss : 0.694, Accuracy : 0.625, Macro_f1 : 0.51933


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 006 // (avg) Loss : 0.762, Accuracy : 0.490, Macro_f1 : 0.37402


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 006 // (avg) Loss : 0.734, Accuracy : 0.587, Macro_f1 : 0.45340


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 007 // (avg) Loss : 0.746, Accuracy : 0.514, Macro_f1 : 0.39940


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 007 // (avg) Loss : 0.705, Accuracy : 0.625, Macro_f1 : 0.51974


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 008 // (avg) Loss : 0.738, Accuracy : 0.533, Macro_f1 : 0.41066


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 008 // (avg) Loss : 0.682, Accuracy : 0.637, Macro_f1 : 0.53794


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 009 // (avg) Loss : 0.723, Accuracy : 0.552, Macro_f1 : 0.43061


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 009 // (avg) Loss : 0.669, Accuracy : 0.685, Macro_f1 : 0.56916


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 010 // (avg) Loss : 0.714, Accuracy : 0.568, Macro_f1 : 0.44156


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 010 // (avg) Loss : 0.662, Accuracy : 0.683, Macro_f1 : 0.58276


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 011 // (avg) Loss : 0.710, Accuracy : 0.578, Macro_f1 : 0.45063


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 011 // (avg) Loss : 0.656, Accuracy : 0.696, Macro_f1 : 0.60244


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 012 // (avg) Loss : 0.702, Accuracy : 0.594, Macro_f1 : 0.46978


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 012 // (avg) Loss : 0.629, Accuracy : 0.750, Macro_f1 : 0.63412


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 013 // (avg) Loss : 0.689, Accuracy : 0.612, Macro_f1 : 0.48541


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 013 // (avg) Loss : 0.632, Accuracy : 0.740, Macro_f1 : 0.64404


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 014 // (avg) Loss : 0.691, Accuracy : 0.614, Macro_f1 : 0.48752


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 014 // (avg) Loss : 0.637, Accuracy : 0.728, Macro_f1 : 0.61290


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 015 // (avg) Loss : 0.683, Accuracy : 0.622, Macro_f1 : 0.49409


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 015 // (avg) Loss : 0.656, Accuracy : 0.684, Macro_f1 : 0.58669


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 016 // (avg) Loss : 0.679, Accuracy : 0.634, Macro_f1 : 0.49739


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 016 // (avg) Loss : 0.623, Accuracy : 0.756, Macro_f1 : 0.64969


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 017 // (avg) Loss : 0.668, Accuracy : 0.649, Macro_f1 : 0.51671


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 017 // (avg) Loss : 0.612, Accuracy : 0.777, Macro_f1 : 0.67112


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 018 // (avg) Loss : 0.666, Accuracy : 0.649, Macro_f1 : 0.51216


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 018 // (avg) Loss : 0.620, Accuracy : 0.757, Macro_f1 : 0.65576


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 019 // (avg) Loss : 0.662, Accuracy : 0.657, Macro_f1 : 0.51295


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 019 // (avg) Loss : 0.610, Accuracy : 0.775, Macro_f1 : 0.65384


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 020 // (avg) Loss : 0.657, Accuracy : 0.666, Macro_f1 : 0.51634


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 020 // (avg) Loss : 0.613, Accuracy : 0.744, Macro_f1 : 0.66781


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 021 // (avg) Loss : 0.654, Accuracy : 0.668, Macro_f1 : 0.52427


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 021 // (avg) Loss : 0.607, Accuracy : 0.768, Macro_f1 : 0.67304


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 022 // (avg) Loss : 0.652, Accuracy : 0.671, Macro_f1 : 0.52691


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 022 // (avg) Loss : 0.605, Accuracy : 0.775, Macro_f1 : 0.67229


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 023 // (avg) Loss : 0.646, Accuracy : 0.686, Macro_f1 : 0.54416


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 023 // (avg) Loss : 0.605, Accuracy : 0.768, Macro_f1 : 0.67579


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 024 // (avg) Loss : 0.643, Accuracy : 0.689, Macro_f1 : 0.54659


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 024 // (avg) Loss : 0.603, Accuracy : 0.784, Macro_f1 : 0.69337


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 025 // (avg) Loss : 0.640, Accuracy : 0.697, Macro_f1 : 0.55153


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 025 // (avg) Loss : 0.586, Accuracy : 0.817, Macro_f1 : 0.72153


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 026 // (avg) Loss : 0.640, Accuracy : 0.698, Macro_f1 : 0.54697


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 026 // (avg) Loss : 0.591, Accuracy : 0.774, Macro_f1 : 0.68934


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 027 // (avg) Loss : 0.636, Accuracy : 0.701, Macro_f1 : 0.55330


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 027 // (avg) Loss : 0.593, Accuracy : 0.798, Macro_f1 : 0.70769


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 028 // (avg) Loss : 0.634, Accuracy : 0.708, Macro_f1 : 0.55987


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 028 // (avg) Loss : 0.597, Accuracy : 0.786, Macro_f1 : 0.68126


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=532.0), HTML(value='')))


[TRAIN] Epoch 029 // (avg) Loss : 0.626, Accuracy : 0.716, Macro_f1 : 0.57920


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=60.0), HTML(value='')))


[TEST ] Epoch 029 // (avg) Loss : 0.587, Accuracy : 0.812, Macro_f1 : 0.72092



################################################################################
:::학습종료:::
최고 accuracy : 0.81693, 최저 loss : 0.58552, 최고 Macro_F1 : 0.72153
################################################################################


In [10]:
class TestDataset(Dataset):
    """Label-free dataset over a list of image files, used for inference.

    Args:
        img_paths: indexable sequence of image file paths; item ``i`` reads ``img_paths[i]``.
        transform: callable invoked as ``transform(image=<np.ndarray>)`` that returns a mapping
            whose ``'image'`` entry supports ``.float()`` (albumentations-style, presumably —
            confirm against trans.trans_kj). If falsy, the opened PIL image is returned as-is.
    """

    def __init__(self, img_paths, transform):
        self.img_paths = img_paths
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        picture = Image.open(self.img_paths[index])
        if not self.transform:
            return picture
        transformed = self.transform(image=np.array(picture))
        return transformed['image'].float()

# Load the evaluation metadata and image directory; ImageID values are joined onto `images/`
# to form each file path.
submission = pd.read_csv(os.path.join(eval_dir, 'info.csv'))
image_dir = os.path.join(eval_dir, 'images')

# Build the inference Dataset/DataLoader in info.csv row order (shuffle=False keeps the
# predictions aligned with the rows of `submission`). Batching only affects speed: the model
# runs in eval mode below.
image_paths = [os.path.join(image_dir, img_id) for img_id in submission.ImageID]
dataset = TestDataset(image_paths, TestTransfrom)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# Run on whatever device the trained model already lives on. A hard-coded 'cuda' device
# would crash on CPU-only machines and would ignore the configured device.
device = next(model.parameters()).device
model.eval()

# Predict a class index for every evaluation image and store it in the `ans` column.
all_predictions = []
with torch.no_grad():
    for images in tqdm(loader):
        logits = model(images.to(device))
        all_predictions.extend(logits.argmax(dim=-1).cpu().numpy())
submission['ans'] = all_predictions

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12600.0), HTML(value='')))

test inference is done!

--------------------------------------------------------------------------------

Finish_Date    : 20210826_102851


In [11]:
# Write the submission CSV, append the closing lines to the run log, and save the log.
# Both files are named with the run's start timestamp `now`, matching the checkpoints.
submission.to_csv(os.path.join(save_dir, f'{now}_result.csv'), index=False)

for line in ['test inference is done!', c, f'{c:-^80}', c]:
    log.append(line)
    print(line)

# Keep the finish time in its own variable. Reassigning `now` here meant a re-run of this cell
# wrote the CSV and log under a different timestamp than the checkpoints.
finish_date = dt.datetime.now().astimezone(timezone("Asia/Seoul")).strftime("%Y%m%d_%H%M%S")
log.append(f'Finish_Date    : {finish_date}')
print(log[-1])

# The log contains non-ASCII (Korean) text, so the encoding is pinned instead of locale-dependent.
with open(os.path.join(save_dir, f'{now}.log'), "w", encoding="utf-8") as f:
    f.writelines(f'{entry}\n' for entry in log)

test inference is done!

--------------------------------------------------------------------------------

Finish_Date    : 20210827_171235
