In [90]:
import torch
import torchvision
from torchvision import transforms
import PIL
from PIL import Image
import os
import glob
from glob import glob
import pandas as pd
import numpy as np
import torch.optim as optm
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Resize, ToTensor, Normalize, CenterCrop

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [112]:
class cfg:
    data_dir = "/opt/ml/input/data/train/"
    img_dir = f"{data_dir}/images"
    df_path = f"{data_dir}/train.csv"
    img_height = 512
    img_width = 384
    batch_size = 128
    lr = 0.001
    epoch = 10

In [123]:
num2class = ['incorrect_mask', 'mask1', 'mask2', 'mask3', 'mask4', 'mask5', 'normal']
class2num = {k : v for v, k in enumerate(num2class)}

df = pd.read_csv(cfg.df_path)

df['gender'] = df['gender'].map({'female':0, 'male':1})
df['age'] = df['age'].map(age_label_func)

df.head()

Unnamed: 0,id,gender,race,age,path
0,1,0,Asian,1,000001_female_Asian_45
1,2,0,Asian,1,000002_female_Asian_52
2,4,1,Asian,1,000004_male_Asian_54
3,5,0,Asian,1,000005_female_Asian_58
4,6,0,Asian,1,000006_female_Asian_59


In [114]:
def get_ext(img_dir, img_id):
    filename = os.listdir(os.path.join(img_dir, img_id))[0]
    ext = os.path.splitext(filename)[-1].lower()
    
    return ext

In [115]:
image_dirs=[]
for path in df.path:
    image_dirs.append(os.path.join(cfg.img_dir, path)) 

In [116]:
#mask labeling
imgs=[]
mask_labels=[]
for path in image_dirs:
    for image in glob(f'{path}/**'):
        if 'normal' in image:
            mask_labels.append(2)
        elif 'incorrect' in image:
            mask_labels.append(1)
        else:
            mask_labels.append(0)
        imgs.append(image)

In [129]:
def age_label_func(x):
    if x<30: return 0
    elif 30<=x<60: return 1
    else: return 2

def labeling():
    result=[]
    genders=[]
    ages=[]
    
    for gender in df['gender']:
        genders.extend([gender] * 7)
        
    for age in df['age']:
        ages.extend([age] * 7)
        
    result.append(ages)
    result.append(genders)

    return result

age_labels, gender_labels = labeling()

In [131]:
class BaseDataset(Dataset):
    def __init__(self, img_paths, transform):
        super().__init__()
        self.img_paths = img_paths
        self.transform = transform
    
    def __len__(self):
        return len(self.img_paths)

In [134]:
class MaskDataset(BaseDataset):
    def __init__(self):
        super().__init__()
    
    def __getitem__(self,index):
        image = Image.open(self.img_paths[index])
        if self.transform:
            image=self.transform(image)
        
        return image, mask_labels[index]

In [137]:
class AgeDataset(BaseDataset):
    def __init__(self):
        super().__init__()
        
    def __getitem__(self,index):
        image = Image.open(self.img_paths[index])
        if self.transform:
            image=self.transform(image)
        
        return image, age_labels[index]

In [136]:
class GenderDataset(BaseDataset):
    def __init__(self):
        super().__init__()
        
    def __getitem__(self,index):
        image = Image.open(self.img_paths[index])
        if self.transform:
            image=self.transform(image)
        
        return image, gender_labels[index]

In [128]:
transform = transforms.Compose([
    Resize((int(cfg.img_height/2), int(cfg.img_width/2)),Image.BILINEAR),
    CenterCrop(int(cfg.img_height/4)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

mask_dataset = MaskDataset(imgs, transform)
mask_dataloader = DataLoader(mask_dataset, batch_size=64, shuffle=True)

# batch_iterator = iter(mask_dataloader)
# images = next(batch_iterator)
# plt.imshow(torchvision.utils.make_grid(images, nrow=5).permute(1, 2, 0))
# plt.show()

In [84]:
class Gyudel(nn.Module):
    def __init__(self, class_num):
        super().__init__()
        self.mask_model = torchvision.models.resnext101_32x8d(pretrained=True)
        self.mask_model.fc=torch.nn.Linear(in_features=2048, out_features=class_num, bias=True)
        
    def forward(self, x):
        x=self.mask_model(x)
        
        return x

In [87]:
model = Gyudel(3)
optim = optm.Adam(model.parameters(), lr=cfg.lr)
loss_fn = nn.CrossEntropyLoss()

In [135]:
model.to(device)
model_dir=os.path.join(os.getcwd(),'saved/total_model/')

with tqdm(mask_dataloader) as pbar:
    running_loss=0.
    running_acc=0.
    
    tot_pred=torch.tensor([]).to(device)
    tot_label=torch.tensor([]).to(device)
    
    for n,(image,label) in enumerate(pbar):
        image=image.to(device)
        label=label.to(device)
            
        logit=model(image)
        _,pred=torch.max(logit,1)
        
        optim.zero_grad()
        loss=loss_fn(logit,label)
        loss.backward()
        optim.step()
        running_loss+=loss.item()*image.size(0)
        running_acc+=torch.sum(pred==label)
        pbar.set_postfix({'epoch' : epoch+1, 'loss' : running_loss/(n+1), 'accuracy' : float(running_acc)/(n+1),'F1 score':f1_score(label.cpu(),pred.cpu(),average='weighted')})
            
        tot_pred=torch.hstack((tot_pred,pred))
        tot_label=torch.hstack((tot_label,label))
        epoch_loss=running_loss/len(mask_dataloader.dataset)
        epoch_acc=running_acc/len(mask_dataloader.dataset)
        
        print(f"현재 epoch-{n+1}의 평균 Loss : {epoch_loss:.3f}, 평균 Accuracy : {epoch_acc:.3f}, F1 score : {f1_score(tot_label.cpu(),tot_pred.cpu(),average='weighted')}" )
        

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=296.0), HTML(value='')))

현재 epoch-1의 평균 Loss : 0.000, 평균 Accuracy : 0.003, F1 score : 1.0
현재 epoch-2의 평균 Loss : 0.000, 평균 Accuracy : 0.007, F1 score : 1.0
현재 epoch-3의 평균 Loss : 0.000, 평균 Accuracy : 0.010, F1 score : 0.9947456601889547
현재 epoch-4의 평균 Loss : 0.000, 평균 Accuracy : 0.013, F1 score : 0.9960669398907104
현재 epoch-5의 평균 Loss : 0.000, 평균 Accuracy : 0.017, F1 score : 0.9968595062552984
현재 epoch-6의 평균 Loss : 0.000, 평균 Accuracy : 0.020, F1 score : 0.9947499051704015
현재 epoch-7의 평균 Loss : 0.000, 평균 Accuracy : 0.024, F1 score : 0.9955059270934484
현재 epoch-8의 평균 Loss : 0.000, 평균 Accuracy : 0.027, F1 score : 0.9960698351406507
현재 epoch-9의 평균 Loss : 0.001, 평균 Accuracy : 0.030, F1 score : 0.9929876008097429
현재 epoch-10의 평균 Loss : 0.001, 평균 Accuracy : 0.034, F1 score : 0.993693967572723
현재 epoch-11의 평균 Loss : 0.001, 평균 Accuracy : 0.037, F1 score : 0.9928250215223172
현재 epoch-12의 평균 Loss : 0.001, 평균 Accuracy : 0.040, F1 score : 0.9934263574383754
현재 epoch-13의 평균 Loss : 0.001, 평균 Accuracy : 0.044, F1 score : 0.9939