# Import required libraries

In [2]:
import os
import pandas as pd
import numpy as np
import random
from PIL import Image
import tqdm
import timm

from sklearn.metrics import f1_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import torchvision
import matplotlib.pyplot as plt

# Check version and Set Seed

In [3]:
# Report the installed PyTorch build and pick the compute device (GPU when visible).
print("PyTorch version:[{}].".format(torch.__version__))
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device:[{}].".format(device))

def set_seed(random_seed):
    """Seed every RNG the notebook touches and make cuDNN deterministic.

    random_seed : int seed applied to Python's `random`, NumPy's global RNG,
        the PyTorch CPU generator and the CUDA generators (current + all devices).
    """
    # Python / NumPy global generators.
    random.seed(random_seed)
    np.random.seed(random_seed)

    # PyTorch generators; manual_seed_all covers the multi-GPU case.
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

PyTorch version:[1.7.1].
device:[cuda].


# Configuration

In [4]:
# Device string used by the cells below: "cuda" when a GPU is visible, otherwise "cpu".
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Face Crop

In [5]:
!pip install facenet_pytorch
from facenet_pytorch import MTCNN

mtcnn = MTCNN(keep_all = True, device = device)



In [190]:
class _FaceCrop():
    '''
    Crop the most confident face detected by MTCNN and resize it to a fixed size.

    device : device handed to MTCNN. Defaults to "cuda".
    margin : int. Number of pixels added on every side of the box MTCNN returns. Defaults to 30.
    not_detected : [x, y, w, h]. Region cropped when MTCNN finds no face. Defaults to (100, 50, 300, 300).
    input_shape : [Height, Width] of the expected input image. Defaults to (384, 512).
        Only stored (self.shape / self.img_h / self.img_w) for backward compatibility:
        the crop is clamped to the size of the image actually passed in, so a wrong
        or swapped value here can no longer truncate the face box.
    output_shape : [Height, Width] of the returned image. Defaults to (300, 300).
    '''

    def __init__(self, device="cuda", margin=30, not_detected=(100, 50, 300, 300),
                 input_shape=(384, 512), output_shape=(300, 300)):
        # Tuple defaults: no mutable default object is shared between instances.
        self.margin = margin

        # Fallback crop box used when no face is detected.
        self.not_detected = not_detected
        self.x = self.not_detected[0]
        self.y = self.not_detected[1]
        self.w = self.not_detected[2]
        self.h = self.not_detected[3]

        # Nominal input shape (see class docstring); not used for clamping.
        self.shape = input_shape
        self.img_h = self.shape[0]
        self.img_w = self.shape[1]

        self.device = device
        self.mtcnn = MTCNN(keep_all = True, device = self.device, margin = self.margin)
        self.transform = torchvision.transforms.Resize(output_shape)

    def __call__(self, img):
        '''
        Return the face crop of `img` as a CHW tensor passed through self.transform.

        img : CHW torch.Tensor (assumed to hold values in [0, 1], like ToTensor output)
              or any object torchvision's ToTensor accepts, e.g. a PIL image.
        '''
        if isinstance(img, torch.Tensor):
            img = img.permute(1, 2, 0) * 255
        else:
            img = torchvision.transforms.ToTensor()(img).permute(1, 2, 0) * 255
        # img is now HWC with values scaled by 255 for the detector.
        img_h, img_w = img.shape[0], img.shape[1]

        # Use this instance's detector (not a module-level global) so the
        # device chosen in __init__ is honoured.
        boxes, probs = self.mtcnn.detect(img)

        detected = (
            boxes is not None
            and probs is not None
            and len(probs) > 0
            and probs[0] is not None
        )
        if detected:
            # Keep the highest-probability box, grow it by `margin`, and clamp it
            # to the real image bounds (negative starts would wrap when slicing).
            box = boxes[int(np.argmax(probs))]
            x_min = max(int(box[0]) - self.margin, 0)
            y_min = max(int(box[1]) - self.margin, 0)
            x_max = min(int(box[2]) + self.margin, img_w)
            y_max = min(int(box[3]) + self.margin, img_h)
        else:
            # No face found: use the configured fallback region.
            x_min = self.x
            y_min = self.y
            x_max = self.x + self.w
            y_max = self.y + self.h

        img = img[y_min:y_max, x_min:x_max, :] / 255
        img = img.permute(2, 0, 1)

        return self.transform(img)

# Pass the notebook's configured device: the class default is "cuda", which
# would fail on a CPU-only kernel.
facecrop = _FaceCrop(device=device)

def crop_collate(samples, resize=(300, 300)):
    """Collate (image, label) samples into a batch of face crops.

    Every image is passed through the module-level `facecrop`; labels are
    packed into a LongTensor.
    NOTE(review): `resize` is accepted but never used — the crop size comes
    from `facecrop`'s own Resize transform.

    Returns (crops stacked along a new leading dim, labels as torch.LongTensor).
    """
    faces = torch.stack([facecrop(sample[0]) for sample in samples])
    labels = torch.LongTensor([sample[1] for sample in samples])
    return faces, labels

In [191]:
class Dataset(torch.utils.data.Dataset):
    """Image dataset backed by a CSV index.

    The first CSV column holds image file paths and the second holds class
    labels; columns are read by position, so their names do not matter.
    Note: this class shadows the `Dataset` name imported from torch.utils.data.
    """

    def __init__(self, csv):
        super().__init__()
        # Parse the CSV once and slice both columns from the same array.
        table = pd.read_csv(csv).values
        self.path = table[:, 0]
        self.class_ = table[:, 1]

    def __len__(self):
        return len(self.path)

    def __getitem__(self, idx):
        """Return (image as a ToTensor tensor, label) for row `idx`."""
        img_path = self.path[idx]
        # Close the file deterministically; convert() returns a converted copy,
        # which stays usable after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        img = torchvision.transforms.ToTensor()(img)
        label = self.class_[idx]

        return img, label

In [192]:
# Training index: first column = image path, second column = class label.
dataset = Dataset(csv="train_data_path_and_class.csv")

In [193]:
# Iterate in file order (no shuffling), 32 samples per batch; crop_collate
# face-crops every image while the batch is assembled.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=crop_collate,
)