In [1]:
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE = 1e-5

In [2]:
### training set 있는 Path ###
BASE_PATH = '../data/train/'

In [3]:
import torchvision.transforms as transforms
### Transforms ###
transform = transforms.Compose([
    transforms.ToTensor(),
])

## Load dataset

In [4]:
from Dataset import AllClassDataset

In [5]:
import pandas as pd
df = pd.read_csv(BASE_PATH + 'train.csv')

In [6]:
# training set (validation set 따로 안 나눈 경우)
# train_dataset = AllClassDataset(
#     base_path = BASE_PATH, 
#     data = df, 
#     transform = transform
# )

In [7]:
# train / validation data
from sklearn.model_selection import train_test_split
def mapAgeGender(age, gender):
    answer = 0
    if age < 30:
        answer += 0
    elif age >= 60:
        answer += 2
    else:
        answer += 1
    return answer if gender == 'male' else answer + 3

y_data = df.apply(lambda x: mapAgeGender(x['age'], x['gender']), axis=1)
x_train, x_val, y_train, y_val = train_test_split(df.index, y_data, test_size=0.2, random_state=42, stratify=y_data)

In [8]:
train_dataset = AllClassDataset(
    base_path = BASE_PATH, 
    data = df.loc[x_train], 
    transform = transform
)

100%|██████████| 2160/2160 [00:18<00:00, 119.36it/s]


In [9]:
val_dataset = AllClassDataset(
    base_path = BASE_PATH, 
    data = df.loc[x_val], 
    transform = transform
)

100%|██████████| 540/540 [00:03<00:00, 136.47it/s]


### Dataset 확인용 Code
```python
import matplotlib.pyplot as plt
image, label = val_dataset[0]
plt.title(f'Class {label}')
plt.imshow(image.permute(1,2,0))
plt.show()
```

## DataLoader

In [10]:
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    num_workers=1
)

In [11]:
valloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=64,
    num_workers=1
)

### Dataloader 확인용 Code
```python
import matplotlib.pyplot as plt
images, labels = next(iter(trainloader))
plt.title(f'Class {labels[0]}')
plt.imshow(images[0].permute(1,2,0))
plt.show()
```

## Model

In [12]:
from Model import ModifiedModel

In [13]:
model = ModifiedModel(num_classes = 18).to(DEVICE)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /opt/ml/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## Train

In [15]:
from train import train

In [16]:
# Tensorboard
from torch.utils.tensorboard import SummaryWriter
train_writer = SummaryWriter('runs/train_resnet18_.5_50_fz')
val_writer = SummaryWriter('runs/val_resnet18_.5_50_fz')

In [17]:
### Parameter 변경 ###
train(
    model=model, 
    optimizer=optimizer,
    train_loader=trainloader, 
    val_loader=valloader,
    device=DEVICE, 
    epochs=50, 
    save=True, 
    saved_folder="saved",
    train_writer=train_writer,
    val_writer=val_writer
)

100%|██████████| 237/237 [01:19<00:00,  2.96it/s]
  0%|          | 0/237 [00:00<?, ?it/s]

Epoch 028: Loss: 0.0078 / Acc: 1.000 / F1: 1.00        | Val Loss: 0.3676 / Val Acc: 0.909 / Val F1: 0.83


100%|██████████| 237/237 [01:15<00:00,  3.12it/s]
  0%|          | 0/237 [00:00<?, ?it/s]

Epoch 029: Loss: 0.0067 / Acc: 1.000 / F1: 1.00        | Val Loss: 0.3724 / Val Acc: 0.909 / Val F1: 0.83


100%|██████████| 237/237 [01:19<00:00,  2.96it/s]
  0%|          | 0/237 [00:00<?, ?it/s]

Epoch 030: Loss: 0.0058 / Acc: 1.000 / F1: 1.00        | Val Loss: 0.3773 / Val Acc: 0.908 / Val F1: 0.83


 39%|███▉      | 92/237 [00:31<00:49,  2.93it/s]
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/opt/conda/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/opt/conda/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/opt/conda/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

## Test

In [18]:
import os
import pandas as pd
from test import Test
### Test data dir ###
test_dir = '../data/eval'

In [19]:
test = Test(
    test_dir=test_dir,
    model=model, 
    optimizer=optimizer,
    device=DEVICE
)

In [20]:
test.loadSavedModel( 
    checkpoint_path='saved/ModifiedModel_13_0.03_0.992155.pt'
)

(ModifiedModel(
   (pretrained): ResNet(
     (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
     (layer1): Sequential(
       (0): BasicBlock(
         (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (relu): ReLU(inplace=True)
         (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       )
       (1): BasicBlock(
         (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn1): BatchNorm2d(64, eps=1e-05, momentum=0

In [21]:
test.predictTestData(
    transform=transform
)

  0%|          | 5/12600 [00:00<04:40, 44.89it/s]

Test Data loaded


100%|██████████| 12600/12600 [02:16<00:00, 92.60it/s] 


Unnamed: 0,ImageID,ans
0,cbc5c6e168e63498590db46022617123f1fe1268.jpg,13
1,0e72482bf56b3581c081f7da2a6180b8792c7089.jpg,1
2,b549040c49190cedc41327748aeb197c1670f14d.jpg,13
3,4f9cb2a045c6d5b9e50ad3459ea7b791eb6e18bc.jpg,13
4,248428d9a4a5b6229a7081c32851b90cb8d38d0c.jpg,12
...,...,...
12595,d71d4570505d6af8f777690e63edfa8d85ea4476.jpg,1
12596,6cf1300e8e218716728d5820c0bab553306c2cfd.jpg,4
12597,8140edbba31c3a824e817e6d5fb95343199e2387.jpg,9
12598,030d439efe6fb5a7bafda45a393fc19f2bf57f54.jpg,1


In [22]:
test.submission('submission.csv')

test inference is done!


### 참고
https://pangate.com/967