In [1]:
import torch

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)

# Optimizer step size. No cell visible in this notebook trains, so it is currently unused here.
LEARNING_RATE = 1e-5

In [2]:
### Path containing the training set ###
# Root of the training data: train.csv is read from here, and it is passed to every
# TrainValidDataset as base_path.
BASE_PATH = '../data/train/'

In [3]:
import torchvision.transforms as transforms
### Transforms ###
# Convert the image to a tensor, then keep only its central 300x300 region.
# The same pipeline is used for the train/val datasets and for test-time inference.
_transform_steps = [
    transforms.ToTensor(),
    transforms.CenterCrop(300),
]
transform = transforms.Compose(_transform_steps)

In [None]:
from Dataset import TrainValidDataset

In [None]:
import pandas as pd

# Training metadata table; its 'age' and 'gender' columns drive the stratified split below.
train_csv_path = BASE_PATH + 'train.csv'
df = pd.read_csv(train_csv_path)

In [None]:
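# Alternative (unused): train on the full training set without a separate validation split.
# NOTE: AllClassDataset is not imported anywhere in this notebook; import it before re-enabling.
# train_dataset = AllClassDataset(
#     base_path = BASE_PATH, 
#     data = df, 
#     transform = transform
# )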

In [None]:
# train / validation data
from sklearn.model_selection import train_test_split
def mapAgeGender(age, gender):
    """Combine age and gender into one bucket id in 0..5 (used to stratify the split).

    Age bucket: age < 30 -> 0, age >= 60 -> 2, anything else -> 1
    (note: a NaN age fails both comparisons and therefore lands in bucket 1).
    Gender offset: exactly 'male' -> +0, any other value -> +3.

    Args:
        age: numeric age of the person.
        gender: gender string from the metadata table.

    Returns:
        int: age bucket plus gender offset.
    """
    # Keep the comparison order: "< 30" first, then ">= 60", else middle bucket.
    if age < 30:
        age_bucket = 0
    elif age >= 60:
        age_bucket = 2
    else:
        age_bucket = 1

    gender_offset = 0 if gender == 'male' else 3
    return age_bucket + gender_offset

# Per-row combined age/gender bucket (0-5); used as the stratification label below.
y_data = df.apply(
    lambda row: mapAgeGender(age=row['age'], gender=row['gender']),
    axis=1,
)
# Split the row labels (df.index), not the rows themselves, 80/20 while keeping
# each bucket's proportion; random_state pins the split so it is reproducible.
x_train, x_val, y_train, y_val = train_test_split(
    df.index,
    y_data,
    stratify=y_data,
    test_size=0.2,
    random_state=42,
)

In [None]:
# Training-split datasets, one per prediction target. Each dataset receives its own
# df.loc[x_train] selection, built in the order mask -> age -> gender.
mask_dataset, age_dataset, gender_dataset = (
    TrainValidDataset(
        base_path=BASE_PATH,
        data=df.loc[x_train],
        transform=transform,
        label=target,
    )
    for target in ("mask", "age", "gender")
)

In [None]:
# Validation-split datasets, one per prediction target. Each dataset receives its own
# df.loc[x_val] selection, built in the order mask -> age -> gender.
mask_val_dataset, age_val_dataset, gender_val_dataset = (
    TrainValidDataset(
        base_path=BASE_PATH,
        data=df.loc[x_val],
        transform=transform,
        label=target,
    )
    for target in ("mask", "age", "gender")
)

### Dataset 확인용 Code (sanity check: display one validation sample)
```python
import matplotlib.pyplot as plt
image, label = mask_val_dataset[0]
plt.title(f'Class {label}')
plt.imshow(image.permute(1,2,0))
plt.show()
```

In [None]:
# Training loaders, one per target (batch size 64, one worker process).
# shuffle=True draws a new sample order every epoch. DataLoader's default
# (shuffle=False) yields batches in the dataset's fixed index order every epoch,
# which can produce correlated mini-batches (e.g. if rows are grouped by person).
mask_trainloader, age_trainloader, gender_trainloader = (
    torch.utils.data.DataLoader(
        dataset,
        batch_size=64,
        shuffle=True,
        num_workers=1,
    )
    for dataset in (mask_dataset, age_dataset, gender_dataset)
)

In [None]:
# Validation loaders, one per target. They keep DataLoader's default order
# (no shuffling), so evaluation always sees the samples in the same sequence.
mask_valloader, age_valloader, gender_valloader = (
    torch.utils.data.DataLoader(
        dataset,
        batch_size=64,
        num_workers=1,
    )
    for dataset in (mask_val_dataset, age_val_dataset, gender_val_dataset)
)

### Dataloader 확인용 Code (sanity check: show the first image of one training batch)
```python
import matplotlib.pyplot as plt
images, labels = next(iter(mask_trainloader))
plt.title(f'Class {labels[0]}')
plt.imshow(images[0].permute(1,2,0))
plt.show()
```

## Model

In [4]:
from Model import ResnetModel

In [5]:
# One ResNet-based classifier per target, each moved to DEVICE.
# Output classes: mask -> 3, age -> 3, gender -> 2 (built in that order).
mask_model, age_model, gender_model = (
    ResnetModel(num_classes=n_classes).to(DEVICE)
    for n_classes in (3, 3, 2)
)

## Test

In [6]:
# NOTE(review): `os` is not used by any visible cell, and pandas was already imported above.
import os
import pandas as pd
# NOTE(review): `test` is also the name of a CPython standard-library package; this
# presumably resolves to a local test.py only because the notebook's working directory
# is searched first — consider renaming the module. Note that the name `test` is
# rebound to a Test instance in the next cell.
from test import Test
### Test data dir ###
# Directory of evaluation images to run inference on (relative to the working directory).
test_dir = '../data/eval'

In [7]:
# Bundle the evaluation directory and the three per-target models for inference on DEVICE.
# This rebinds the name `test` (previously only used for the `from test import Test` import).
test = Test(
    test_dir=test_dir,
    mask_model=mask_model, 
    age_model=age_model,
    gender_model=gender_model, 
    device=DEVICE
)

In [8]:
# Load trained weights for each target model from the checkpoint files below.
# NOTE(review): checkpoint paths are hard-coded; the numbers in the filenames look like
# epoch / loss / metric values, but their meaning is not defined here — confirm.
# The call's return value (the models) is echoed as the cell output; append `;` to hide it.
test.loadSavedModel( 
    mask='saved/mask/resnet18_mask_16_0.00_1.000000.pt',
    age='saved/resnet18_age_hflip_.3_SGD/resnet18_age_.5_hflip_18_0.23_0.92.pt',
    gender='saved/gender/resnet18_gender_16_0.01_0.998286.pt',
)

(ResnetModel(
   (pretrained): ResNet(
     (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
     (layer1): Sequential(
       (0): BasicBlock(
         (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (relu): ReLU(inplace=True)
         (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       )
       (1): BasicBlock(
         (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1

In [9]:
# Run inference over the evaluation images using the same `transform` as the
# train/val datasets; the returned table (ImageID, ans) is displayed below.
# NOTE(review): `ans` presumably combines the mask/age/gender predictions into one
# class id — the encoding lives in the Test class and is not visible here.
test.predictTestData(
    transform=transform
)

100%|██████████| 12600/12600 [04:24<00:00, 47.57it/s]


Unnamed: 0,ImageID,ans
0,cbc5c6e168e63498590db46022617123f1fe1268.jpg,13
1,0e72482bf56b3581c081f7da2a6180b8792c7089.jpg,1
2,b549040c49190cedc41327748aeb197c1670f14d.jpg,13
3,4f9cb2a045c6d5b9e50ad3459ea7b791eb6e18bc.jpg,13
4,248428d9a4a5b6229a7081c32851b90cb8d38d0c.jpg,12
...,...,...
12595,d71d4570505d6af8f777690e63edfa8d85ea4476.jpg,1
12596,6cf1300e8e218716728d5820c0bab553306c2cfd.jpg,4
12597,8140edbba31c3a824e817e6d5fb95343199e2387.jpg,9
12598,030d439efe6fb5a7bafda45a393fc19f2bf57f54.jpg,1


In [10]:
# Write the predictions to submission.csv (file format is defined by Test.submission).
test.submission('submission.csv')

test inference is done!


### 참고 (References)
https://pangate.com/967