In [1]:
import os
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize
import albumentations as albu
import albumentations.pytorch
import pickle
import matplotlib.pyplot as plt
import timm 
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from facenet_pytorch import InceptionResnetV1
from pprint import pprint
#!pip install facenet-pytorch
#!conda install -c conda-forge ipywidgets
#!pip install ipywidgets

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# with open('images_path.pickle','rb') as f:
#     images_path = pickle.load(f)
    
batch_size = 256
MODEL_PATH ="saved"
folder_path = '/opt/ml/input/data/train/images/'
LEARNING_RATE = 0.001
EPOCHS = 10
val_rate = 0.1
logs_base_dir = "logs"
os.makedirs(logs_base_dir, exist_ok=True)
exp  = f"{logs_base_dir}/AgeSmooth"
writer = SummaryWriter(exp)

In [3]:
class CustomDataSet(Dataset):
    def __init__(self,folder_path ,image_path ,label,transform=None, train=True):
        self.folder_path = folder_path
        self.image_path = image_path
        self.label = label
        self.transform = transform
        self.info = pd.read_csv(self.image_path)
        self.y = self.info[self.label]
        self.x = self.info['path']
        
    def __len__(self):
        return len(self.y)

    def __getitem__(self, index):
        y = None
        if self.label == 'age':
            if self.y[index] < 30:
                y = torch.tensor([1,0,0])
            elif 30 <= self.y[index] <= 50:
                y = torch.tensor([0,1,0])
            elif 50 < self.y[index] < 60:
                y = torch.tensor([0,1,1])
            else:
                y = torch.tensor([0,0,1])
                
        elif self.label == 'gender':
            if self.y[index] == 'male':
                y = torch.tensor(0)
            else:
                y = torch.tensor(1)
        elif self.label == 'state':
            if self.y[index] == 'mask':
                y = torch.tensor(0)
            elif self.y[index] == 'incorrect':
                y = torch.tensor(1)
            elif self.y[index] == 'normal':
                y = torch.tensor(2)
        x = np.array(Image.open(folder_path+self.x[index]))
        if self.transform:
            x = self.transform(image=x)
        return x['image'], y

In [4]:
def accuracy(output, target):
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        correct = 0
        correct += torch.sum(pred == target).item()
    return correct / len(target)

def accuracy_multi_label(output, target):
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        origin_target = torch.argmax(target, dim=1)
        # [0, 1, 1] -> [0, 1, 0] -> 2
        assert pred.shape[0] == len(target)
        correct = 0
        correct += torch.sum(pred == origin_target).item()
    return correct / len(target)

def f1(output, target, is_training=False):
    pred = torch.argmax(output, dim=1)

    assert pred.ndim == 1
    assert target.ndim == 1 or target.ndim == 2

    if target.ndim == 2:
        target = target.argmax(dim=1)

    tp = (target * pred).sum().to(torch.float32)
    tn = ((1 - target) * (1 - pred)).sum().to(torch.float32)
    fp = ((1 - target) * pred).sum().to(torch.float32)
    fn = (target * (1 - pred)).sum().to(torch.float32)

    epsilon = 1e-7

    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
    f1.requires_grad = is_training
    return f1

def f1_multi_label(output, target, is_training=False):
    pred = torch.argmax(output, dim=1)
    
    origin_target = torch.argmax(target, dim=1)

    assert pred.ndim == 1
    assert target.ndim == 1 or target.ndim == 2

    if origin_target.ndim == 2:
        origin_target = origin_target.argmax(dim=1)

    tp = (origin_target * pred).sum().to(torch.float32)
    tn = ((1 - origin_target) * (1 - pred)).sum().to(torch.float32)
    fp = ((1 - origin_target) * pred).sum().to(torch.float32)
    fn = (origin_target * (1 - pred)).sum().to(torch.float32)

    epsilon = 1e-7

    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
    f1.requires_grad = is_training
    return f1

In [5]:
train_aug = albu.Compose([albu.ColorJitter(brightness=(0.2, 1), contrast=(0.3, 1), saturation=(0.2, 1), hue=(-0.3, 0.3)),
            albu.RandomCrop(300, 300),
            albu.HorizontalFlip(),
            albu.Resize(224, 224),
            albu.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
            albu.pytorch.transforms.ToTensorV2()
        ])
vaild_preprocess = albu.Compose([albu.Resize(224, 224),
                                 albu.Normalize(mean=(0.5, 0.5, 0.5),std=(0.2, 0.2, 0.2)),
                                 albu.pytorch.transforms.ToTensorV2()
        ])

train_age_data = CustomDataSet(folder_path,'train_path.csv', 'age', transform = train_aug, train=True)
train_age_data_loader = DataLoader(dataset=train_age_data,batch_size=batch_size, shuffle=True,drop_last=True,num_workers=2)
vaild_age_data = CustomDataSet(folder_path,'train_path.csv', 'age', transform = vaild_preprocess, train=False)
vaild_age_data_loader = DataLoader(dataset=vaild_age_data,batch_size=batch_size, shuffle=False,drop_last=False,num_workers=2)

In [6]:
class AgeSmoothClassifier(nn.Module):
    def __init__(self, num_of_classes = 3):
        super().__init__()
        self.m = InceptionResnetV1(classify=True,num_classes=3)
    def forward(self, x):
        x = self.m(x)
        return x

In [7]:
age_clf = AgeSmoothClassifier()
age_clf = age_clf.to(device)

In [10]:

#criterion = nn.CrossEntropyLoss(weight = torch.tensor([1.5,1.0]).to(device)) #1.5,1.0
criterion = nn.MultiLabelSoftMarginLoss()
optimizer = optim.Adam(age_clf.parameters(), lr=LEARNING_RATE,amsgrad=True)


for e in range(1, EPOCHS+1):
    epoch_loss = 0
    epoch_acc = 0
    for X_batch, y_batch in train_age_data_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        
        optimizer.zero_grad()
        y_pred = age_clf(X_batch)
        
        loss = criterion(y_pred, y_batch)
        acc = accuracy_multi_label(y_pred, y_batch)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item() 
        epoch_acc += acc
    writer.add_scalar('Loss/train', loss, e)
    writer.add_scalar('Acc/train', acc, e)
    
    age_clf.eval()
    vaild_acc = 0
    f1 = 0
    for data, labels, in vaild_age_data_loader:
        with torch.no_grad():
            data, labels = data.to(device), labels.to(device)
        
            y_pred = age_clf(data)
            vacc = accuracy_multi_label(y_pred, labels)
            vaild_acc += vacc
            f1 += f1_multi_label(y_pred,labels)
    writer.add_scalar('F1/val', f1, e)
    age_clf.train()
    va = f'{vaild_acc/len(vaild_age_data_loader):.3f}' 
    l = f'{epoch_loss/len(train_age_data_loader):.5f}'
    a = f'{epoch_acc/len(train_age_data_loader):.3f}'
    f1 = f'{f1/len(vaild_age_data_loader):.3f}'
    print(f'Epoch {e+0:03}: | Loss: {l} | Acc: {a} | VAcc: {va} | F1: {f1}')

    if not os.path.exists(MODEL_PATH):
        os.makedirs(MODEL_PATH)
    torch.save(age_clf.state_dict(), os.path.join(MODEL_PATH, f'age_smooth_model_{e+0:03}.pt'))

Epoch 001: | Loss: 0.17836 | Acc: 0.820 | VAcc: 0.734 | F1: 0.788


KeyboardInterrupt: 

In [None]:
# images, labels =next(iter(train_age_data_loader))


# plt.figure(figsize=(12,12))
# for n, (image, label) in enumerate(zip(images, labels), start=1):
#     plt.subplot(4,4,n)
#     plt.imshow(transforms.ToPILImage()(image))  # Normalize 처리때문에 복구
#     plt.title("{}".format(label))
#     plt.axis('off')
# plt.tight_layout()
# plt.show()    

# images, labels =next(iter(vaild_age_data_loader))


# plt.figure(figsize=(12,12))
# for n, (image, label) in enumerate(zip(images, labels), start=1):
#     plt.subplot(4,4,n)
#     plt.imshow(transforms.ToPILImage()(image))  # Normalize 처리때문에 복구
#     plt.title("{}".format(label))
#     plt.axis('off')
# plt.tight_layout()
# plt.show()    