In [9]:
import csv
import glob
import pandas as pd
import numpy as np
import PIL
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import copy
import os
import wandb
import time

# 1. 데이터 로드

In [10]:
train_dir = '../../input/data/train'
test_dir = '../../input/data/eval'
save_dir = '../saved/models/'

### 하이퍼파라미터

In [11]:
model_name = 'efficientnet_b1'
learning_rate = 1e-5
batch_size = 16
step_size = 5
epochs = 3
earlystop = 5

A_transform = {
    'train':
        A.Compose([
            A.Resize(224, 224),
            #A.RandomCrop(384, 384),
            A.HorizontalFlip(p=0.5),
            A.Cutout(num_holes=8, max_h_size=32,max_w_size=32),
            A.ElasticTransform(),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ]),
    'valid':
        A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ]),
    'test':
        A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
}

In [12]:
class LoadCSV():
    def __init__(self, dir):
        self.dir = dir
        self.img_dir =train_dir + '/new_images/'
        self.origin_csv_path = train_dir + '/train.csv'
        self.trans_csv_path = train_dir + '/trans_train.csv'
        
        if not os.path.exists(self.trans_csv_path):
            self._makeCSV()
        self.df = pd.read_csv(self.trans_csv_path)
        self.df = self.df[:200]
    def _makeCSV(self):        
        with open(self.trans_csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["path", "label"])

            df = pd.read_csv(self.origin_csv_path)
            for idx in range(len(df)):
                data = df.iloc[idx]
                img_path_base = os.path.join(os.path.join(self.img_dir, data['path']), '*')
                for img_path in glob.glob(img_path_base):
                    label = 0
                    if "incorrect" in img_path:
                        label+=6
                    elif 'normal' in img_path:
                        label+=12
                    elif data['gender']=='female':
                        label+=3
                    elif data['age'] >= 30 and data['age'] < 60:
                        label+=1
                    elif data['age'] >= 60:
                        label+=2
                    writer.writerow([img_path, label])
        f.close()

class MaskDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        super().__init__()
        self.df = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        class_id = torch.tensor(self.df['label'].iloc[idx])
        img = PIL.Image.open(self.df['path'].iloc[idx])
        img = np.array(img.convert("RGB"))
        if self.transform:
            img = self.transform(image=img)['image']
        return img, class_id

# 2. 모델 설계


In [13]:
class MyModel(nn.Module):
    def __init__(self, model_name, num_classes):
        super(MyModel, self).__init__()
        self.num_classes = num_classes
        self.model = timm.create_model(model_name, pretrained=True)

        n_features = self.model.classifier.in_features
        self.model.classifier = torch.nn.Linear(in_features=n_features, out_features=num_classes, bias=True)
        # n_features = self.model.head.in_features
        # self.model.head = torch.nn.Linear(in_features=n_features, out_features=self.num_classes, bias=True)

        torch.nn.init.xavier_uniform_(self.model.classifier.weight)
        stdv = 1/np.sqrt(self.num_classes)
        self.model.classifier.bias.data.uniform_(-stdv, stdv)
        
    def forward(self, x):
        return self.model(x)

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel(model_name, 18).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), learning_rate)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0)

# 3. 학습

In [15]:
today = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
if not os.path.exists(save_dir + today):
    os.makedirs(save_dir + today)

In [17]:
earlystop_value = 0
best_model = copy.deepcopy(model.state_dict())
best_acc = 0
best_loss = 999999999

from sklearn.model_selection import StratifiedKFold
mask_csv = LoadCSV(train_dir)
kfold = StratifiedKFold(n_splits=5, shuffle=False)

for fold, (train_idx, valid_idx) in enumerate(kfold.split(mask_csv.df['path'], mask_csv.df['label'])):
    print(f'FOLD {fold}')
    mask_train = MaskDataset(mask_csv.df,  transform=A_transform['train'])
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    valid_subsampler = torch.utils.data.SubsetRandomSampler(valid_idx)

    train_loader = DataLoader(mask_train, batch_size=batch_size, sampler=train_subsampler, drop_last=False, num_workers=8, pin_memory=True)
    valid_loader = DataLoader(mask_train, batch_size=batch_size, sampler=valid_subsampler, drop_last=False, num_workers=8, pin_memory=True)
    dataloaders = {'train': train_loader, 'valid':valid_loader}

    for epoch in range(epochs):
        if earlystop_value >= earlystop:
            break
        train_loss, valid_loss, train_acc_list, valid_acc_list = 0, 0, [],[]

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            with tqdm(dataloaders[phase], total=dataloaders[phase].__len__(), unit="batch") as train_bar:
                for inputs, labels in train_bar:
                    train_bar.set_description(f"{phase} Epoch {epoch} ")
                    inputs, labels = inputs.to(device), labels.to(device)

                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    outputs = outputs.cpu().detach().numpy()
                    labels = labels.cpu().detach().numpy()

                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += (np.argmax(outputs, axis=1)== labels).mean()
                    epoch_loss = running_loss / len(dataloaders[phase].dataset)
                    epoch_acc = running_corrects / len(dataloaders[phase].dataset)
                    train_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)

            lr_scheduler.step()
            if phase=='valid':
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(best_model_wts, f'{save_dir}{today}/baseline_{model_name}_lr{learning_rate}_stepLR{step_size}_batch{batch_size}_epoch{epoch}_valid_loss_{epoch_loss:.5f}.pt')
                    earlystop_value = 0
                else:
                    earlystop_value += 1

    model.load_state_dict(best_model_wts)

  0%|          | 0/10 [00:00<?, ?batch/s]

FOLD 0
[ 18  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57
  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75
  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91  92  93
  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108 109 110 111
 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199] [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40]


train Epoch 0 : 100%|██████████| 10/10 [00:02<00:00,  3.98batch/s, acc=0.00344, loss=2.51]
valid Epoch 0 : 100%|██████████| 3/3 [00:01<00:00,  2.48batch/s, acc=0.00125, loss=0.631]
train Epoch 1 : 100%|██████████| 10/10 [00:01<00:00,  5.02batch/s, acc=0.00313, loss=2.48]
valid Epoch 1 : 100%|██████████| 3/3 [00:01<00:00,  2.71batch/s, acc=0.00156, loss=0.612]
train Epoch 2 : 100%|██████████| 10/10 [00:02<00:00,  4.49batch/s, acc=0.00469, loss=2.39]
valid Epoch 2 : 100%|██████████| 3/3 [00:01<00:00,  2.42batch/s, acc=0.0025, loss=0.563]
  0%|          | 0/10 [00:00<?, ?batch/s]

FOLD 1
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
  37  38  39  40  66  67  76  77  78  79  80  81  83  89  90  91  92  93
  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108 109 110 111
 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199] [18 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
 64 65 68 69 70 71 72 73 74 75 82 84 85 86 87 88]


train Epoch 0 : 100%|██████████| 10/10 [00:02<00:00,  4.59batch/s, acc=0.00531, loss=2.31]
valid Epoch 0 : 100%|██████████| 3/3 [00:01<00:00,  2.88batch/s, acc=0.00281, loss=0.569]
train Epoch 1 : 100%|██████████| 10/10 [00:02<00:00,  4.57batch/s, acc=0.00781, loss=2.25]
valid Epoch 1 : 100%|██████████| 3/3 [00:01<00:00,  2.34batch/s, acc=0.00281, loss=0.567]
train Epoch 2 : 100%|██████████| 10/10 [00:01<00:00,  5.03batch/s, acc=0.00813, loss=2.21]
valid Epoch 2 : 100%|██████████| 3/3 [00:01<00:00,  2.33batch/s, acc=0.00219, loss=0.556]
  0%|          | 0/10 [00:00<?, ?batch/s]

FOLD 2
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  68  69  70  71  72  73
  74  75  79  80  81  82  84  85  86  87  88 118 124 125 126 127 128 129
 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199] [ 66  67  76  77  78  83  89  90  91  92  93  94  95  96  97  98  99 100
 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 119
 120 121 122 123]


train Epoch 0 : 100%|██████████| 10/10 [00:02<00:00,  4.90batch/s, acc=0.00562, loss=2.2]
valid Epoch 0 : 100%|██████████| 3/3 [00:01<00:00,  2.38batch/s, acc=0.00219, loss=0.527]
train Epoch 1 : 100%|██████████| 10/10 [00:02<00:00,  4.08batch/s, acc=0.0109, loss=2.09]
valid Epoch 1 : 100%|██████████| 3/3 [00:01<00:00,  2.53batch/s, acc=0.00375, loss=0.501]
train Epoch 2 : 100%|██████████| 10/10 [00:02<00:00,  4.67batch/s, acc=0.0116, loss=2.01]
valid Epoch 2 : 100%|██████████| 3/3 [00:01<00:00,  2.56batch/s, acc=0.00344, loss=0.492]
  0%|          | 0/10 [00:00<?, ?batch/s]

FOLD 3
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  82  83  84  85  86  87  88  89  90  91  92
  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108 109 110
 111 112 113 114 115 116 117 119 120 121 122 123 159 160 161 162 163 164
 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 183
 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199] [ 79  80  81 118 124 125 126 127 128 129 130 131 132 133 134 135 136 137
 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
 156 157 158 182]


train Epoch 0 : 100%|██████████| 10/10 [00:02<00:00,  4.46batch/s, acc=0.0122, loss=1.98]
valid Epoch 0 : 100%|██████████| 3/3 [00:01<00:00,  2.49batch/s, acc=0.0025, loss=0.485]
train Epoch 1 : 100%|██████████| 10/10 [00:02<00:00,  4.63batch/s, acc=0.0147, loss=1.9]
valid Epoch 1 : 100%|██████████| 3/3 [00:01<00:00,  2.92batch/s, acc=0.00344, loss=0.491]
train Epoch 2 : 100%|██████████| 10/10 [00:02<00:00,  4.69batch/s, acc=0.0197, loss=1.8]
valid Epoch 2 : 100%|██████████| 3/3 [00:01<00:00,  2.61batch/s, acc=0.00469, loss=0.476]
  0%|          | 0/10 [00:00<?, ?batch/s]

FOLD 4
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 182] [159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
 177 178 179 180 181 183 184 185 186 187 188 189 190 191 192 193 194 195
 196 197 198 199]


train Epoch 0 : 100%|██████████| 10/10 [00:01<00:00,  5.17batch/s, acc=0.0187, loss=1.81]
valid Epoch 0 : 100%|██████████| 3/3 [00:01<00:00,  2.62batch/s, acc=0.00594, loss=0.461]
train Epoch 1 : 100%|██████████| 10/10 [00:02<00:00,  4.72batch/s, acc=0.0153, loss=1.86]
valid Epoch 1 : 100%|██████████| 3/3 [00:01<00:00,  2.98batch/s, acc=0.00719, loss=0.422]
train Epoch 2 : 100%|██████████| 10/10 [00:01<00:00,  5.01batch/s, acc=0.0169, loss=1.84]
valid Epoch 2 : 100%|██████████| 3/3 [00:01<00:00,  2.63batch/s, acc=0.00781, loss=0.421]


# 4. 추론

In [None]:
class TestDataset(Dataset):
    def __init__(self, img_paths, transform):
        self.img_paths = img_paths
        self.transform = transform

    def __getitem__(self, index):
        image = PIL.Image.open(self.img_paths[index])
        image = np.array(image.convert("RGB"))
        if self.transform:
            image = self.transform(image=image)
            image = image['image']
        return image

    def __len__(self):
        return len(self.img_paths)

In [None]:
submission = pd.read_csv(os.path.join(test_dir, 'info.csv'))
image_dir = os.path.join(test_dir, 'new_images')

image_paths = [os.path.join(image_dir, img_id) for img_id in submission.ImageID]
dataset = TestDataset(image_paths, A_transform['test'])
test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

model.eval()
all_predictions = []
with tqdm(test_loader, total=test_loader.__len__(), unit="batch") as test_bar:
    for images in test_bar:
        with torch.no_grad():
            images = images.to(device)
            pred = model(images)
            pred = pred.argmax(dim=-1)
            all_predictions.extend(pred.cpu().numpy())
    
submission['ans'] = all_predictions
submission.to_csv(os.path.join(test_dir, 'submission.csv'), index=False)
print('test inference is done!')

  1%|          | 6/788 [00:00<00:57, 13.52batch/s]


FileNotFoundError: [Errno 2] No such file or directory: '../../input/data/eval/new_images/d8c0d7ae6cf662506012135e730b855c321dbc8f.jpg'

In [None]:
import wandb
wandb.login()

with wandb.init(project=model_name + str(today), entity='nudago'):
    wandb_config = wandb.config
    wandb_config.learning_rate = 0.01

    wandb.log({"loss": loss})
# wandb.log({"loss:" loss.item()}, step=example_ct)

[34m[1mwandb[0m: Currently logged in as: [33mnudago[0m (use `wandb login --relogin` to force relogin)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…