In [1]:
import os
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize
import pickle
import matplotlib.pyplot as plt
import timm

In [2]:
def accuracy(output, target):
    """Return the fraction of samples whose arg-max prediction matches the target.

    Args:
        output: Score tensor; the predicted class of each row is the arg-max
            along dimension 1.
        target: Labels compared element-wise against the predictions; must have
            as many entries as ``output`` has rows.

    Returns:
        float: number of matching predictions divided by ``len(target)``.
    """
    with torch.no_grad():
        predicted = output.argmax(dim=1)
        assert predicted.shape[0] == len(target)
        hits = (predicted == target).sum().item()
    return hits / len(target)

In [3]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
# Load the pickled collection of image paths consumed by the dataset below.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('images_path.pickle', 'rb') as pickle_file:
    images_path = pickle.load(pickle_file)

In [5]:
# Run configuration.
EPOCHS = 10                                        # number of passes over the data loader
LEARNING_RATE = 1e-3                               # learning rate handed to the Adam optimizer
batch_size = 64                                    # samples per DataLoader batch
folder_path = '/opt/ml/input/data/train/images/'   # prefix put in front of every image path
MODEL_PATH = 'saved'                               # directory that receives the saved state dicts

In [6]:
class AgeClassifier(nn.Module):
    """timm ``efficientnet_b1`` backbone followed by a linear classification head.

    The backbone keeps its own classifier layer; that layer's output is fed to
    ``fc``, which maps ``m.classifier.out_features`` values to
    ``num_of_classes`` logits.
    """

    def __init__(self, num_of_classes = 3):
        super().__init__()
        # Backbone created with pretrained weights.
        backbone = timm.create_model('efficientnet_b1', pretrained=True)
        self.m = backbone
        # The head consumes whatever the backbone's classifier emits.
        head_inputs = backbone.classifier.out_features
        self.fc = nn.Linear(head_inputs, num_of_classes)

    def forward(self, x):
        """Pass ``x`` through the backbone, then through the linear head."""
        backbone_out = self.m(x)
        return self.fc(backbone_out)

In [7]:
class MaskAgeDataSet(Dataset):
    """Dataset pairing each image file with its ``age`` label from a CSV file.

    Sample ``i`` is the image at ``f_path + images_path[i]`` together with the
    ``i``-th value of the CSV's ``age`` column, so ``images_path`` is assumed to
    be in the same order as the CSV rows (TODO: confirm against how both files
    were produced).

    Args:
        f_path: Prefix concatenated (plain string ``+``) in front of every
            entry of ``images_path``.
        images_path: Indexable collection of image paths relative to ``f_path``.
        transform: Optional callable applied to the loaded PIL image.
        train: Unused; kept so existing callers that pass it keep working.
        info_path: CSV file that provides the ``age`` column. Defaults to
            ``'image_info.csv'`` in the working directory.
    """

    def __init__(self, f_path, images_path, transform=None, train=True,
                 info_path='image_info.csv'):
        self.f_path = f_path
        self.images_path = images_path
        self.transform = transform
        self.info = pd.read_csv(info_path)
        self.y = self.info['age']

    def __len__(self):
        """Return the number of label rows read from the CSV."""
        return len(self.y)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at position ``index``.

        The image is always converted to RGB so that three-channel transforms
        also work for grayscale, palette or RGBA files.

        Raises:
            IndexError: if ``index`` is outside the label column.
        """
        # Positional lookup: does not depend on the Series' index labels and
        # raises IndexError (rather than KeyError) when out of range, which is
        # what Python's sequence-iteration protocol expects.
        y = torch.tensor(self.y.iloc[index])

        # convert() returns an independent, fully loaded copy, so the file
        # handle can be released as soon as the block exits.
        with Image.open(self.f_path + self.images_path[index]) as img:
            x = img.convert('RGB')
        if self.transform:
            x = self.transform(x)

        return x, y

In [8]:
age_clf = AgeClassifier()

# Optional (disabled): train only the final `fc` layer by freezing every
# parameter and then re-enabling gradients for `fc`.
# for parm in age_clf.parameters():
#     parm.requires_grad = False
# for parm in age_clf.fc.parameters():
#     parm.requires_grad = True

# Preprocessing: resize to 260x260, convert to a tensor, then normalize each of
# the three channels with mean 0.5 and std 0.2.
mask_age_data = MaskAgeDataSet(
    f_path=folder_path,
    images_path=images_path,
    transform=transforms.Compose([
        Resize((260, 260)),
        ToTensor(),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
    ]),
)

# Shuffled batches of `batch_size` samples, loaded with four workers.
mask_age_data_loader = DataLoader(
    dataset=mask_age_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [9]:
age_clf = age_clf.to(device)
# Class weights: classes 1 and 2 count twice as much as class 0 in the loss.
# TODO (original author's note): these weights still need to be revised.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1., 2., 2.], device=device))
optimizer = optim.Adam(age_clf.parameters(), lr=LEARNING_RATE)

# Create the checkpoint directory once, before any training time is spent.
os.makedirs(MODEL_PATH, exist_ok=True)
num_batches = len(mask_age_data_loader)

for e in range(1, EPOCHS+1):
    # Make sure layers such as dropout / batch-norm are in training mode.
    age_clf.train()
    epoch_loss = 0.0
    epoch_acc = 0.0
    for X_batch, y_batch in mask_age_data_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        y_pred = age_clf(X_batch)

        loss = criterion(y_pred, y_batch)
        acc = accuracy(y_pred, y_batch)

        loss.backward()
        optimizer.step()

        # Accumulate a plain Python float: adding the loss tensor itself would
        # keep every batch's autograd history alive for the whole epoch.
        epoch_loss += loss.item()
        epoch_acc += acc

    # Epoch metrics are the unweighted mean of the per-batch values; format
    # them once and reuse the strings for the log line and the file name.
    loss_str = f'{epoch_loss/num_batches:.5f}'
    acc_str = f'{epoch_acc/num_batches:.3f}'
    print(f'Epoch {e:03}: | Loss: {loss_str} | Acc: {acc_str}')

    # One checkpoint per epoch, named after that epoch's mean loss and accuracy.
    torch.save(age_clf.state_dict(), os.path.join(MODEL_PATH, f'age_model_{loss_str}_{acc_str}.pt'))

Epoch 001: | Loss: 0.18679 | Acc: 0.934
Epoch 002: | Loss: 0.06075 | Acc: 0.981
Epoch 003: | Loss: 0.06131 | Acc: 0.982
Epoch 004: | Loss: 0.03123 | Acc: 0.991
Epoch 005: | Loss: 0.02771 | Acc: 0.992
Epoch 006: | Loss: 0.02553 | Acc: 0.992
Epoch 007: | Loss: 0.02824 | Acc: 0.991
Epoch 008: | Loss: 0.02281 | Acc: 0.993
Epoch 009: | Loss: 0.01100 | Acc: 0.996
Epoch 010: | Loss: 0.02033 | Acc: 0.995


In [33]:
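# Scratch check (disabled): CrossEntropyLoss with class weights [1, 2, 2] on a
# single sample with logits [1, 0, 0] and target class 0, 1 or 2.
# With the default reduction='mean' the weighted loss is divided by the sum of
# the target weights, so for a single sample the weight cancels out. That is
# why the printed values below equal the unweighted loss (1.5514 for classes 1
# and 2) and do not show the 2x weighting.
# a = nn.CrossEntropyLoss(weight=torch.tensor([1.,2.,2]))
# b = torch.tensor([[1.,0.,0.]],requires_grad=True)
# c = torch.LongTensor([0])
# d = torch.LongTensor([1])
# e = torch.LongTensor([2])
# print(a(b,c))
# print(a(b,d))
# print(a(b,e))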

tensor(0.5514, grad_fn=<NllLossBackward>)
tensor(1.5514, grad_fn=<NllLossBackward>)
tensor(1.5514, grad_fn=<NllLossBackward>)
