In [37]:
import os
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize
import pickle
import matplotlib.pyplot as plt
import timm

In [38]:
def accuracy(output, target):
    """Return the fraction of samples whose arg-max prediction equals the target.

    Args:
        output: Score tensor; the predicted class of each row is the index of
            its largest value along dimension 1.
        target: Ground-truth class indices, one per row of ``output``.

    Returns:
        Python float: number of matching predictions divided by ``len(target)``.

    Raises:
        AssertionError: if the number of predictions differs from ``len(target)``.
    """
    with torch.no_grad():
        predicted = output.argmax(dim=1)
        n_samples = len(target)
        assert predicted.shape[0] == n_samples
        n_correct = (predicted == target).sum().item()
    return n_correct / n_samples

In [39]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [40]:
# Load the pre-computed collection of image paths (indexed by position and
# appended to `folder_path` by the dataset class).
# SECURITY: pickle.load can execute arbitrary code while deserialising --
# only open pickle files that come from a trusted source.
with open('images_path.pickle', 'rb') as pickle_file:
    images_path = pickle.load(pickle_file)

In [41]:
# ---- Paths ---------------------------------------------------------------
# Prefix that the dataset prepends to every entry of `images_path`.
folder_path = '/opt/ml/input/data/train/images/'
# Directory that receives one checkpoint per epoch.
MODEL_PATH = "saved"

# ---- Data loading --------------------------------------------------------
batch_size = 64
# Fraction of the dataset held out for validation.
val_rate = 0.1

# ---- Optimisation --------------------------------------------------------
LEARNING_RATE = 1e-3
EPOCHS = 20

In [42]:
class MaskGenderDataSet(Dataset):
    """Image dataset yielding ``(image, gender_label)`` pairs.

    Labels are read from the ``sex`` column of a CSV file; row ``i`` of the
    CSV is paired with ``images_path[i]``. The value ``'male'`` is encoded as
    ``0`` and every other value as ``1``.

    Args:
        f_path: String prefix placed directly in front of each entry of
            ``images_path`` (plain concatenation, so it must end with a path
            separator).
        images_path: Indexable collection of image paths relative to ``f_path``.
        transform: Optional callable applied to the loaded RGB PIL image.
        train: Unused; kept so existing call sites keep working.
        info_path: CSV file containing a ``sex`` column. Defaults to
            ``'image_info.csv'`` in the current working directory.
    """

    def __init__(self, f_path, images_path, transform=None, train=True,
                 info_path='image_info.csv'):
        self.f_path = f_path
        self.images_path = images_path
        self.transform = transform
        self.info = pd.read_csv(info_path)
        self.y = self.info['sex']

    def __len__(self):
        """Number of samples, i.e. number of rows in the label CSV."""
        return len(self.y)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at position ``index``."""
        # Positional lookup: stays correct even if the frame's index is not
        # the default 0..n-1 range.
        label = 0 if self.y.iloc[index] == 'male' else 1

        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to load so the handle can be closed right away, and it
        # guarantees three channels for 3-channel transforms such as
        # Normalize(mean=(..3 values..)).
        with Image.open(self.f_path + self.images_path[index]) as img:
            x = img.convert('RGB')

        if self.transform:
            x = self.transform(x)

        return x, label
    

In [43]:
# Augmentation + preprocessing pipeline applied to every image the dataset
# returns. The dataset is split into train/validation subsets afterwards, so
# both subsets share this (random) pipeline.
gender_transform = transforms.Compose([
    transforms.RandomCrop(300),
    # NOTE(review): RandomPerspective appears twice in a row -- confirm this
    # is intentional and not a copy/paste leftover.
    transforms.RandomPerspective(),
    transforms.RandomPerspective(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomGrayscale(),
    Resize((260, 260)),
    ToTensor(),
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
])

mask_gender_data = MaskGenderDataSet(
    f_path=folder_path,
    images_path=images_path,
    transform=gender_transform,
)


In [44]:
# Hold out `val_rate` of the samples for validation. The training size is
# obtained by subtraction so the two lengths always add up to the dataset
# length; truncating both products with int() independently can come up one
# short, in which case random_split raises ValueError.
n_total = len(mask_gender_data)
n_val = int(n_total * val_rate)
n_train = n_total - n_val

# NOTE: `mask_gender_data` is rebound to the training subset here, so re-run
# the dataset-construction cell before re-running this one.
mask_gender_data, validation_data = torch.utils.data.random_split(
    mask_gender_data, [n_train, n_val])

# Training batches are shuffled; the incomplete final batch is dropped so
# every batch has exactly `batch_size` samples.
mask_gender_data_loader = DataLoader(dataset=mask_gender_data, batch_size=batch_size,
                                     shuffle=True, drop_last=True, num_workers=4)

# Validation visits every held-out sample exactly once, in a fixed order, so
# the metric is computed on the same set each epoch.
validation_loader = DataLoader(dataset=validation_data, batch_size=batch_size,
                               shuffle=False, drop_last=False)

                               

In [46]:
# images, labels =next(iter(validation_loader))

# plt.figure(figsize=(12,12))
# for n, (image, label) in enumerate(zip(images, labels), start=1):
#     plt.subplot(4,4,n)
#     plt.imshow(transforms.ToPILImage()(image))  # convert the Normalize-d tensor back to a PIL image for display
#     plt.title("{}".format(label))
#     plt.axis('off')
# plt.tight_layout()
# plt.show()    


In [47]:
# images, labels =next(iter(mask_age_data_loader))

# plt.figure(figsize=(12,12))
# for n, (image, label) in enumerate(zip(images, labels), start=1):
#     plt.subplot(4,4,n)
#     plt.imshow(transforms.ToPILImage()(image))  # convert the Normalize-d tensor back to a PIL image for display
#     plt.title("{}".format(label))
#     plt.axis('off')
# plt.tight_layout()
# plt.show()

In [48]:
# images, labels =next(iter(mask_state_data_loader))

# plt.figure(figsize=(12,12))
# for n, (image, label) in enumerate(zip(images, labels), start=1):
#     plt.subplot(4,4,n)
#     plt.imshow(transforms.ToPILImage()(image))  # convert the Normalize-d tensor back to a PIL image for display
#     plt.title("{}".format(label))
#     plt.axis('off')
# plt.tight_layout()
# plt.show()

In [49]:
class GenderClassifier(nn.Module):
    """timm ``efficientnet_b1`` (pretrained) followed by a linear output layer.

    The backbone keeps its own classifier head; ``fc`` maps that head's
    outputs to ``num_of_classes`` scores.

    NOTE(review): ``fc`` is stacked on top of the backbone's existing
    classifier output (presumably the ImageNet logits -- confirm). Replacing
    the head instead would change the ``state_dict`` keys of saved checkpoints.

    Args:
        num_of_classes: Number of output scores produced by ``fc``.
    """

    def __init__(self, num_of_classes = 2):
        super().__init__()
        backbone = timm.create_model('efficientnet_b1', pretrained=True)
        head_width = backbone.classifier.out_features
        self.m = backbone
        self.fc = nn.Linear(head_width, num_of_classes)

    def forward(self, x):
        """Run the backbone, then project its output through ``fc``."""
        backbone_out = self.m(x)
        scores = self.fc(backbone_out)
        return scores

In [50]:
def accuracy(output, target):
    """Top-1 accuracy of ``output`` against ``target`` as a Python float.

    NOTE: an identical ``accuracy`` function is already defined in an earlier
    cell of this notebook; this re-definition shadows it without changing
    behaviour.

    Args:
        output: Score tensor; predictions are taken with arg-max over dim 1.
        target: Ground-truth class indices, one per row of ``output``.

    Returns:
        Count of rows where prediction equals target, divided by ``len(target)``.

    Raises:
        AssertionError: if the prediction count differs from ``len(target)``.
    """
    with torch.no_grad():
        top1 = torch.argmax(output, dim=1)
        total = len(target)
        assert top1.shape[0] == total
        matches = top1 == target
        hits = torch.sum(matches).item()
    return hits / total

In [51]:
# Two output classes: the dataset encodes 'male' as 0 and everything else as 1.
gender_clf = GenderClassifier(num_of_classes=2)

# All parameters stay trainable (full fine-tuning). To train only the final
# layer instead, freeze everything and re-enable `fc`:
#     for param in gender_clf.parameters():
#         param.requires_grad = False
#     for param in gender_clf.fc.parameters():
#         param.requires_grad = True


In [52]:
gender_clf = gender_clf.to(device)
# Per-class loss weights: class index 0 is weighted 1.5, class index 1 is
# weighted 1.0 (presumably to offset class imbalance -- confirm).
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.5, 1.0], device=device))
optimizer = optim.Adam(gender_clf.parameters(), lr=LEARNING_RATE)

# Create the checkpoint directory once instead of re-checking every epoch.
os.makedirs(MODEL_PATH, exist_ok=True)

for e in range(1, EPOCHS + 1):
    # ---- training pass ----
    # Set train mode at the start of every epoch: if a previous run was
    # interrupted during validation the model would otherwise stay in eval mode.
    gender_clf.train()
    epoch_loss = 0.0
    train_correct = 0
    train_seen = 0
    for X_batch, y_batch in mask_gender_data_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        y_pred = gender_clf(X_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        # Count correct samples rather than averaging per-batch accuracies,
        # so the result is right even when batch sizes differ.
        train_correct += (y_pred.argmax(dim=1) == y_batch).sum().item()
        train_seen += y_batch.size(0)

    # ---- validation pass ----
    gender_clf.eval()
    valid_correct = 0
    valid_seen = 0
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for data, labels in validation_loader:
            data, labels = data.to(device), labels.to(device)
            y_pred = gender_clf(data)
            valid_correct += (y_pred.argmax(dim=1) == labels).sum().item()
            valid_seen += labels.size(0)

    # max(..., 1) keeps the report from crashing on an empty loader.
    valid_acc_text = f'{valid_correct / max(valid_seen, 1):.3f}'
    loss_text = f'{epoch_loss / max(len(mask_gender_data_loader), 1):.5f}'
    train_acc_text = f'{train_correct / max(train_seen, 1):.3f}'
    print(f'Epoch {e:03}: | Loss: {loss_text} | Acc: {train_acc_text} | VAcc: {valid_acc_text}')

    # One checkpoint per epoch so any epoch can be restored later.
    torch.save(gender_clf.state_dict(), os.path.join(MODEL_PATH, f'gender_model_{e:03}.pt'))

Epoch 001: | Loss: 0.14454 | Acc: 0.942 | VAcc: 0.980
Epoch 002: | Loss: 0.07320 | Acc: 0.973 | VAcc: 0.983
Epoch 003: | Loss: 0.05096 | Acc: 0.980 | VAcc: 0.978
Epoch 004: | Loss: 0.04540 | Acc: 0.983 | VAcc: 0.985
Epoch 005: | Loss: 0.04186 | Acc: 0.985 | VAcc: 0.988
Epoch 006: | Loss: 0.03795 | Acc: 0.986 | VAcc: 0.990
Epoch 007: | Loss: 0.03576 | Acc: 0.988 | VAcc: 0.988
Epoch 008: | Loss: 0.02905 | Acc: 0.990 | VAcc: 0.991
Epoch 009: | Loss: 0.02725 | Acc: 0.990 | VAcc: 0.987
Epoch 010: | Loss: 0.03216 | Acc: 0.989 | VAcc: 0.989


KeyboardInterrupt: 

In [None]:
# MODEL_PATH ="saved"
# if not os.path.exists(MODEL_PATH):
#     os.makedirs(MODEL_PATH)
# torch.save(gender_clf.state_dict(), os.path.join(MODEL_PATH, "gender_model.pt"))
# '''
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# new_model = TheModelClass()
# new_model.load_state_dict(torch.load(os.path.join(
#     MODEL_PATH, "model.pt")))
# '''

In [None]:
# del mask_age_data_loader
# del age_clf
# torch.cuda.empty_cache()