In [2]:
import os
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm.notebook import tqdm

from pytz import timezone
import datetime as dt

from sklearn.metrics import f1_score

import wandb

In [3]:
from models.model_bj import resnetbase3 as MaskModel
from datasets.dataset_bj import basicDatasetA as MaskDataset
from trans.trans_bj import A_random_trans as TrainTrans
from trans.trans_bj import A_just_tensor as TestTrans

CLASS_NUM = 18
NUM_WORKERS = 4
BATCH_SIZE = 32
NUM_EPOCH = 30
SAVE_INTERVAL = 3

wandb_run_name = 'RN18_BS32_BL_JS_EP30'
wandb_project_name = 'lv1_p'
wandb_entity = 'presto105'

load_path = ''

comment = ''

In [6]:
c = ''
log = []

test_dir = '/opt/ml/input/data/train'
eval_dir = '/opt/ml/input/data/eval'
save_dir = '/opt/ml/image-classification-level1-25/save/'
now = (dt.datetime.now().astimezone(timezone("Asia/Seoul")).strftime("%Y%m%d_%H%M%S"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = MaskModel(CLASS_NUM)
if load_path : model.load_state_dict(torch.load(load_path))    
model = model.to(device)

loss_fn = torch.nn.CrossEntropyLoss()
optm = torch.optim.Adam(model.parameters())

scheduler = torch.optim.lr_scheduler.MultiStepLR(optm, milestones=[6,8,9], gamma=0.1)
lrs = []


In [7]:
TrainTransform = TrainTrans()
TestTransfrom = TestTrans()

dataset_train_mask = MaskDataset(test_dir, train='train', transform=TrainTransform)
dataset_test_mask = MaskDataset(test_dir, train='test', transform=TestTransfrom)

dataloader_train_mask = DataLoader(dataset=dataset_train_mask,
                                      batch_size=BATCH_SIZE,
                                      num_workers=NUM_WORKERS,
                                      )
dataloader_test_mask = DataLoader(dataset=dataset_test_mask,
                                      batch_size=BATCH_SIZE,
                                      num_workers=NUM_WORKERS,
                                      )

dataloaders = {
        "train": dataloader_train_mask,
        "test": dataloader_test_mask
}

In [8]:
log.append(f'{c:#^80}')
log.append(f'  [Comment]')
log.append(f'{comment}')
log.append(f'{c:#^80}')
log.append(c); log.append(c); log.append(c)

log.append(f'Model         : {model.__class__.__name__}')
log.append(f'  load_state  : {load_path}')
log.append(f'Dataset       : {dataset_train_mask.__class__.__name__}')
log.append(f'  train_len    {len(dataset_train_mask):>10}')
log.append(f'  test_len     {len(dataset_test_mask):>10}')
log.append(f'Train_trans   : {TrainTrans.__name__}')
log.append(f'Test_trans    : {TestTrans.__name__}')
log.append(f'Start_Date    : {now}')
log.append(f'Device        : {device}')
log.append(f'CLASS_NUM     : {CLASS_NUM}')
log.append(f'NUM_WORKERS   : {NUM_WORKERS}')
log.append(f'BATCH_SIZE    : {BATCH_SIZE}')
log.append(f'NUM_EPOCH     : {NUM_EPOCH}')
log.append(f'SAVE_INTERVAL : {SAVE_INTERVAL}')


for line in log:
    print(line)
    
log.append(c); log.append(c); log.append(c)

################################################################################
  [Comment]

################################################################################



Model         : resnetbase3
  load_state  : 
Dataset       : basicDatasetA
  train_len         15120
  test_len           3780
Train_trans   : A_random_trans
Test_trans    : A_just_tensor
Start_Date    : 20210827_155926
Device        : cuda:0
CLASS_NUM     : 18
NUM_WORKERS   : 4
BATCH_SIZE    : 32
NUM_EPOCH     : 30
SAVE_INTERVAL : 3


In [9]:
config={"epochs": NUM_EPOCH, "batch_size": BATCH_SIZE}
wandb.init(project=wandb_project_name, entity=wandb_entity, config=config)
wandb.run.name = wandb_run_name

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpresto105[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [10]:
best_test_accuracy = 0.
best_test_loss = float('inf')
best_f1 = 0.

for epoch in range(NUM_EPOCH):
    for phase in ["train", "test"]:
        running_loss = 0.
        running_acc = 0.
        running_f1 = 0.
        
        if phase == "train":
            model.train()
        elif phase == "test":
            model.eval() 
            
        for idx, (images, labels) in enumerate(pbar := tqdm(dataloaders[phase]), start = 1):
            images, labels = images.to(device), labels.to(device)

            optm.zero_grad()
            
            with torch.set_grad_enabled(phase == "train"):
                logits = model(images)
                _, preds = torch.max(logits, 1)
                loss = loss_fn(logits, labels)
                if phase == "train":
                    loss.backward()  # 모델의 예측 값과 실제 값의 CrossEntropy 차이를 통해 gradient 계산
                    optm.step()  # 계산된 gradient를 가지고 모델 업데이트
                    lrs.append(optm.param_groups[0]["lr"])
                    scheduler.step()
            
            running_loss += loss.item() * images.size(0)
            running_acc += torch.sum(preds == labels.data)
            running_f1 += f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average='macro')
            pbar.set_description(f"loss : {running_loss/(idx*BATCH_SIZE):.3f}, acc : {running_acc/(idx*BATCH_SIZE):.3f}, f1 : {running_f1/(idx):.3f}")
    
        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        epoch_acc = running_acc / len(dataloaders[phase].dataset)
        epoch_f1 = running_f1 / len(dataloaders[phase])

        log.append(f"[{phase.upper():<5}] Epoch {epoch:0>3d} // (avg) Loss : {epoch_loss:.3f}, Accuracy : {epoch_acc*100:.3f}, F1 : {epoch_f1:.3f}")
        print(log[-1])
        
        if phase == "test":
            wandb.log({'accuracy': epoch_acc, 'loss': epoch_loss, 'F1': epoch_f1})
            if best_test_accuracy < epoch_acc:
                best_test_accuracy = epoch_acc
            if best_test_loss > epoch_loss:
                best_test_loss = epoch_loss
            if best_f1 < epoch_f1:
                best_f1 = epoch_f1
            if epoch % SAVE_INTERVAL == 0:
                torch.save(model.state_dict(), os.path.join(save_dir, f'{now}_{model.__class__.__name__}_epoch_{epoch:0>3d}.pt'))

torch.save(model.state_dict(), os.path.join(save_dir, f'{now}_{model.__class__.__name__}_finish_{NUM_EPOCH:0>3d}.pt'))


log.append(c)
print(log[-1])   
log.append(c)
print(log[-1])  
log.append(c)
print(log[-1])  
log.append(f'{c:#^80}')
print(log[-1])            
log.append(f':::학습종료:::')
print(log[-1])
log.append(f"최고 accuracy : {best_test_accuracy:.5f}, 최저 loss : {best_test_loss:.5f}, 최고 F1 : {best_f1:.5f}")
print(log[-1])
log.append(f'{c:#^80}')
print(log[-1]) 

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 000 // (avg) Loss : 2.875, Accuracy : 0.302, F1 : 0.170


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 000 // (avg) Loss : 1.987, Accuracy : 0.476, F1 : 0.299


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 001 // (avg) Loss : 2.280, Accuracy : 0.357, F1 : 0.221


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 001 // (avg) Loss : 1.814, Accuracy : 0.493, F1 : 0.302


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 002 // (avg) Loss : 2.048, Accuracy : 0.385, F1 : 0.250


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 002 // (avg) Loss : 1.688, Accuracy : 0.542, F1 : 0.324


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 003 // (avg) Loss : 1.913, Accuracy : 0.398, F1 : 0.262


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 003 // (avg) Loss : 1.553, Accuracy : 0.589, F1 : 0.373


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 004 // (avg) Loss : 1.790, Accuracy : 0.430, F1 : 0.292


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 004 // (avg) Loss : 1.452, Accuracy : 0.608, F1 : 0.392


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 005 // (avg) Loss : 1.699, Accuracy : 0.456, F1 : 0.310


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 005 // (avg) Loss : 1.345, Accuracy : 0.644, F1 : 0.446


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 006 // (avg) Loss : 1.632, Accuracy : 0.467, F1 : 0.322


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 006 // (avg) Loss : 1.256, Accuracy : 0.659, F1 : 0.455


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 007 // (avg) Loss : 1.559, Accuracy : 0.489, F1 : 0.341


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 007 // (avg) Loss : 1.207, Accuracy : 0.670, F1 : 0.470


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 008 // (avg) Loss : 1.485, Accuracy : 0.509, F1 : 0.368


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 008 // (avg) Loss : 1.159, Accuracy : 0.690, F1 : 0.492


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 009 // (avg) Loss : 1.436, Accuracy : 0.517, F1 : 0.375


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 009 // (avg) Loss : 1.144, Accuracy : 0.696, F1 : 0.498


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 010 // (avg) Loss : 1.395, Accuracy : 0.531, F1 : 0.388


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 010 // (avg) Loss : 1.092, Accuracy : 0.721, F1 : 0.540


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 011 // (avg) Loss : 1.344, Accuracy : 0.551, F1 : 0.398


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 011 // (avg) Loss : 1.071, Accuracy : 0.727, F1 : 0.549


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 012 // (avg) Loss : 1.308, Accuracy : 0.559, F1 : 0.413


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 012 // (avg) Loss : 1.031, Accuracy : 0.735, F1 : 0.553


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 013 // (avg) Loss : 1.271, Accuracy : 0.570, F1 : 0.427


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 013 // (avg) Loss : 1.024, Accuracy : 0.742, F1 : 0.569


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 014 // (avg) Loss : 1.246, Accuracy : 0.575, F1 : 0.428


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 014 // (avg) Loss : 1.014, Accuracy : 0.744, F1 : 0.572


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 015 // (avg) Loss : 1.223, Accuracy : 0.585, F1 : 0.439


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 015 // (avg) Loss : 0.989, Accuracy : 0.761, F1 : 0.593


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 016 // (avg) Loss : 1.194, Accuracy : 0.590, F1 : 0.444


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 016 // (avg) Loss : 0.986, Accuracy : 0.762, F1 : 0.593


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 017 // (avg) Loss : 1.161, Accuracy : 0.603, F1 : 0.453


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 017 // (avg) Loss : 0.981, Accuracy : 0.762, F1 : 0.594


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 018 // (avg) Loss : 1.149, Accuracy : 0.611, F1 : 0.468


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 018 // (avg) Loss : 0.963, Accuracy : 0.766, F1 : 0.603


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 019 // (avg) Loss : 1.138, Accuracy : 0.613, F1 : 0.466


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 019 // (avg) Loss : 0.947, Accuracy : 0.778, F1 : 0.617


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 020 // (avg) Loss : 1.111, Accuracy : 0.616, F1 : 0.465


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 020 // (avg) Loss : 0.923, Accuracy : 0.789, F1 : 0.630


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 021 // (avg) Loss : 1.087, Accuracy : 0.627, F1 : 0.487


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 021 // (avg) Loss : 0.913, Accuracy : 0.793, F1 : 0.637


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 022 // (avg) Loss : 1.081, Accuracy : 0.628, F1 : 0.488


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 022 // (avg) Loss : 0.896, Accuracy : 0.801, F1 : 0.651


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 023 // (avg) Loss : 1.060, Accuracy : 0.638, F1 : 0.492


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 023 // (avg) Loss : 0.893, Accuracy : 0.799, F1 : 0.653


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 024 // (avg) Loss : 1.038, Accuracy : 0.644, F1 : 0.498


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 024 // (avg) Loss : 0.873, Accuracy : 0.803, F1 : 0.657


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 025 // (avg) Loss : 1.025, Accuracy : 0.646, F1 : 0.503


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 025 // (avg) Loss : 0.860, Accuracy : 0.808, F1 : 0.659


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 026 // (avg) Loss : 1.009, Accuracy : 0.656, F1 : 0.513


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 026 // (avg) Loss : 0.859, Accuracy : 0.810, F1 : 0.668


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 027 // (avg) Loss : 0.987, Accuracy : 0.661, F1 : 0.518


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 027 // (avg) Loss : 0.864, Accuracy : 0.813, F1 : 0.673


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 028 // (avg) Loss : 0.975, Accuracy : 0.666, F1 : 0.523


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 028 // (avg) Loss : 0.831, Accuracy : 0.822, F1 : 0.688


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


[TRAIN] Epoch 029 // (avg) Loss : 0.963, Accuracy : 0.668, F1 : 0.525


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))


[TEST ] Epoch 029 // (avg) Loss : 0.828, Accuracy : 0.822, F1 : 0.687



################################################################################
:::학습종료:::
최고 accuracy : 0.82222, 최저 loss : 0.82838, 최고 F1 : 0.68847
################################################################################


In [None]:
class TestDataset(Dataset):
    def __init__(self, img_paths, transform):
        self.img_paths = img_paths
        self.transform = transform

    def __getitem__(self, index):
        image = Image.open(self.img_paths[index])

        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.img_paths)

# meta 데이터와 이미지 경로를 불러옵니다.
submission = pd.read_csv(os.path.join(eval_dir, 'info.csv'))
image_dir = os.path.join(eval_dir, 'images')

# Test Dataset 클래스 객체를 생성하고 DataLoader를 만듭니다.
image_paths = [os.path.join(image_dir, img_id) for img_id in submission.ImageID]
transform = transforms.Compose([
    transforms.Resize((512, 384), Image.BILINEAR),
    # transforms.CenterCrop(300),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
])
dataset = TestDataset(image_paths, transform)

loader = DataLoader(
    dataset,
    shuffle=False
)

# 모델을 정의합니다. (학습한 모델이 있다면 torch.load로 모델을 불러주세요!)
device = torch.device('cuda')
model.eval()

# 모델이 테스트 데이터셋을 예측하고 결과를 저장합니다.
all_predictions = []
for images in tqdm(loader):
    with torch.no_grad():
        images = images.to(device)
        pred = model(images)
        pred = pred.argmax(dim=-1)
        all_predictions.extend(pred.cpu().numpy())
submission['ans'] = all_predictions

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12600.0), HTML(value='')))




In [None]:
# 제출할 파일을 저장합니다.
submission.to_csv(os.path.join(save_dir, f'{now}_result.csv'), index=False)
log.append(f'test inference is done!')
print(log[-1])
log.append(c)
print(log[-1])
log.append(f'{c:-^80}')
print(log[-1])
log.append(c)
print(log[-1])



# log 저장
with open(os.path.join(save_dir, f'{now}.log'), "w") as f:
    now = (dt.datetime.now().astimezone(timezone("Asia/Seoul")).strftime("%Y%m%d_%H%M%S"))
    log.append(f'Finish_Date    : {now}')
    print(log[-1])
    for line in log: 
        f.write(line+'\n')

test inference is done!

--------------------------------------------------------------------------------

Finish_Date    : 20210826_210245
