In [1]:
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE = 1e-3

In [2]:
### training set 있는 Path ###
BASE_PATH = '../data/train/'

In [3]:
import torchvision.transforms as transforms
from noise import AddGaussianNoise
### Transforms ###
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop(300),
    transforms.RandomHorizontalFlip(p=0.5)
#     AddGaussianNoise(0., 1.)
#     transforms.Resize(300)
])
val_trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop(300)
])

## Load dataset

In [4]:
from Dataset import TrainValidDataset

In [5]:
import pandas as pd
df = pd.read_csv(BASE_PATH + 'train.csv')

In [6]:
# training set (validation set 따로 안 나눈 경우)
# train_dataset = AllClassDataset(
#     base_path = BASE_PATH, 
#     data = df, 
#     transform = transform
# )

In [7]:
# train / validation data
from sklearn.model_selection import train_test_split
def mapAgeGender(age, gender):
    answer = 0
    if age < 30:
        answer += 0
    elif age >= 60:
        answer += 2
    else:
        answer += 1
    return answer if gender == 'male' else answer + 3

y_data = df.apply(lambda x: mapAgeGender(x['age'], x['gender']), axis=1)
x_train, x_val, y_train, y_val = train_test_split(df.index, y_data, test_size=0.2, random_state=42, stratify=y_data)

In [8]:
age_dataset = TrainValidDataset(
    base_path = BASE_PATH, 
    data = df.loc[x_train], 
    transform = transform,
    label="age"
)

100%|██████████| 2160/2160 [00:17<00:00, 124.22it/s]


In [9]:
age_val_dataset = TrainValidDataset(
    base_path = BASE_PATH, 
    data = df.loc[x_val], 
    transform = val_trans,
    label="age"
)

100%|██████████| 540/540 [00:03<00:00, 139.21it/s]


### Dataset 확인용 Code
```python
import matplotlib.pyplot as plt
image, label = val_dataset[0]
plt.title(f'Class {label}')
plt.imshow(image.permute(1,2,0))
plt.show()
```

## DataLoader

In [10]:
age_trainloader = torch.utils.data.DataLoader(
    age_dataset,
    batch_size=128,
    num_workers=1
)

In [11]:
age_valloader = torch.utils.data.DataLoader(
    age_val_dataset,
    batch_size=128,
    num_workers=1
)

### Dataloader 확인용 Code
```python
import matplotlib.pyplot as plt
images, labels = next(iter(trainloader))
plt.title(f'Class {labels[0]}')
plt.imshow(images[0].permute(1,2,0))
plt.show()
```

## Model

In [12]:
from Model import ResnetModel

In [13]:
age_model = ResnetModel(num_classes = 3).to(DEVICE)

In [14]:
age_optimizer = torch.optim.SGD(age_model.parameters(), lr=LEARNING_RATE)

## Train

In [15]:
from train import train

In [16]:
# Tensorboard
from torch.utils.tensorboard import SummaryWriter
train_writer = SummaryWriter('runs/train_resnet18_age_hflip_.3_SGD')
val_writer = SummaryWriter('runs/val_resnet18_age_hflip_.3_SGD')

In [17]:
### Parameter 변경 ###
train(
    model=age_model, 
    optimizer=age_optimizer,
    train_loader=age_trainloader, 
    val_loader=age_valloader,
    num_classes=3,
    device=DEVICE, 
    epochs=40, 
    save=True, 
    saved_folder="saved/resnet18_age_hflip_.3_SGD",
    train_writer=train_writer,
    val_writer=val_writer
)

100%|██████████| 119/119 [02:24<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 001: Loss: 0.7990 / Acc: 0.676 / F1: 0.52        | Val Loss: 0.6733 / Val Acc: 0.784 / Val F1: 0.61


100%|██████████| 119/119 [02:24<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 002: Loss: 0.6090 / Acc: 0.816 / F1: 0.63        | Val Loss: 0.5311 / Val Acc: 0.840 / Val F1: 0.65


100%|██████████| 119/119 [02:24<00:00,  1.22s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 003: Loss: 0.5067 / Acc: 0.842 / F1: 0.65        | Val Loss: 0.4470 / Val Acc: 0.855 / Val F1: 0.67


100%|██████████| 119/119 [02:25<00:00,  1.22s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 004: Loss: 0.4431 / Acc: 0.850 / F1: 0.66        | Val Loss: 0.3957 / Val Acc: 0.865 / Val F1: 0.67


100%|██████████| 119/119 [02:24<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 005: Loss: 0.4005 / Acc: 0.857 / F1: 0.66        | Val Loss: 0.3595 / Val Acc: 0.872 / Val F1: 0.68


100%|██████████| 119/119 [02:24<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 006: Loss: 0.3681 / Acc: 0.865 / F1: 0.67        | Val Loss: 0.3353 / Val Acc: 0.874 / Val F1: 0.68


100%|██████████| 119/119 [02:24<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 007: Loss: 0.3439 / Acc: 0.870 / F1: 0.68        | Val Loss: 0.3164 / Val Acc: 0.879 / Val F1: 0.69


100%|██████████| 119/119 [02:23<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 008: Loss: 0.3243 / Acc: 0.875 / F1: 0.69        | Val Loss: 0.3022 / Val Acc: 0.881 / Val F1: 0.69


100%|██████████| 119/119 [02:24<00:00,  1.22s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 009: Loss: 0.3079 / Acc: 0.881 / F1: 0.70        | Val Loss: 0.2908 / Val Acc: 0.885 / Val F1: 0.69


100%|██████████| 119/119 [02:23<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 010: Loss: 0.2951 / Acc: 0.886 / F1: 0.72        | Val Loss: 0.2830 / Val Acc: 0.887 / Val F1: 0.70


100%|██████████| 119/119 [02:24<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 011: Loss: 0.2827 / Acc: 0.891 / F1: 0.73        | Val Loss: 0.2763 / Val Acc: 0.892 / Val F1: 0.72


100%|██████████| 119/119 [02:23<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 012: Loss: 0.2721 / Acc: 0.896 / F1: 0.75        | Val Loss: 0.2711 / Val Acc: 0.897 / Val F1: 0.73


100%|██████████| 119/119 [02:23<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 013: Loss: 0.2633 / Acc: 0.902 / F1: 0.76        | Val Loss: 0.2660 / Val Acc: 0.901 / Val F1: 0.74


100%|██████████| 119/119 [02:24<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 014: Loss: 0.2554 / Acc: 0.905 / F1: 0.78        | Val Loss: 0.2628 / Val Acc: 0.902 / Val F1: 0.75


100%|██████████| 119/119 [02:24<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 015: Loss: 0.2473 / Acc: 0.907 / F1: 0.78        | Val Loss: 0.2597 / Val Acc: 0.905 / Val F1: 0.76


100%|██████████| 119/119 [02:24<00:00,  1.22s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 016: Loss: 0.2386 / Acc: 0.913 / F1: 0.79        | Val Loss: 0.2555 / Val Acc: 0.907 / Val F1: 0.76


100%|██████████| 119/119 [02:23<00:00,  1.20s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 017: Loss: 0.2337 / Acc: 0.913 / F1: 0.80        | Val Loss: 0.2536 / Val Acc: 0.909 / Val F1: 0.76


100%|██████████| 119/119 [02:24<00:00,  1.21s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 018: Loss: 0.2269 / Acc: 0.918 / F1: 0.81        | Val Loss: 0.2487 / Val Acc: 0.908 / Val F1: 0.76


100%|██████████| 119/119 [02:24<00:00,  1.22s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 019: Loss: 0.2213 / Acc: 0.921 / F1: 0.82        | Val Loss: 0.2530 / Val Acc: 0.907 / Val F1: 0.75


100%|██████████| 119/119 [02:24<00:00,  1.22s/it]
  0%|          | 0/119 [00:00<?, ?it/s]

Epoch 020: Loss: 0.2158 / Acc: 0.922 / F1: 0.81        | Val Loss: 0.2488 / Val Acc: 0.909 / Val F1: 0.76


 50%|█████     | 60/119 [01:14<01:12,  1.23s/it]


KeyboardInterrupt: 

In [None]:
train_writer.close()
val_writer.close()

In [3]:
# rm -rf /opt/ml/.cache/torch/hub/checkpoints/