In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from detector import AgeGenderEstimator
import numpy as np
from skimage import io
import matplotlib.pyplot as plt
import cv2
import tqdm

In [None]:
class TrainDataset(Dataset):
    """Face bounding-box dataset read from a WIDER FACE style ``label.txt``.

    The label file alternates image header lines and box lines::

        # relative/path/to/image.jpg
        x y w h [extra values ...]
        ...

    Only the first four numbers of each box line (``x, y, w, h``) are used when
    building samples. Image paths are resolved against the directory obtained by
    replacing ``label.txt`` with ``images/`` in ``txt_path``.

    Args:
        txt_path: Path to the label file (expected to contain ``label.txt``).
        transform: Optional callable applied to each ``{'img', 'annot'}`` sample.
        flip: Stored on the instance; not read by this class.

    Raises:
        ValueError: if a box line appears before any ``#`` image header.
    """

    def __init__(self, txt_path, transform=None, flip=False):
        self.imgs_path = []   # path of each image, in label-file order
        self.words = []       # per image: list of float rows parsed from its box lines
        self.transform = transform
        self.flip = flip
        self.batch_count = 0  # not read by this class
        self.img_size = 112   # not read by this class

        image_dir = txt_path.replace('label.txt', 'images/')
        # `with` guarantees the file handle is released even if parsing fails.
        with open(txt_path, 'r') as f:
            for line_no, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue  # tolerate blank lines (e.g. a trailing empty line)
                if line.startswith('#'):
                    # New image header: subsequent box lines belong to this image.
                    self.imgs_path.append(image_dir + line[1:].strip())
                    self.words.append([])
                elif not self.words:
                    raise ValueError(
                        f"{txt_path}:{line_no}: box line before any '# <image>' header"
                    )
                else:
                    # split() rather than split(' ') so repeated spaces don't yield '' tokens.
                    self.words[-1].append([float(x) for x in line.split()])

    def __len__(self):
        """Number of images listed in the label file."""
        return len(self.imgs_path)

    def __getitem__(self, index):
        """Load image ``index`` together with its boxes in corner format.

        Returns:
            A dict with ``'img'`` (whatever ``io.imread`` returns for the image)
            and ``'annot'`` (float64 array of shape ``(n_boxes, 4)`` holding
            ``x1, y1, x2, y2``; ``n_boxes`` may be 0). If ``transform`` is set,
            its return value for that dict is returned instead.
        """
        img = io.imread(self.imgs_path[index])

        labels = self.words[index]
        # Preallocate instead of np.append in a loop (which copies every iteration);
        # an image without boxes naturally yields shape (0, 4).
        annotations = np.zeros((len(labels), 4))
        for row, label in enumerate(labels):
            x, y, w, h = label[:4]
            annotations[row] = (x, y, x + w, y + h)  # (x, y, w, h) -> (x1, y1, x2, y2)

        # Always return the same dict structure so callers can index 'img'/'annot'.
        sample = {'img': img, 'annot': annotations}
        if self.transform is not None:
            sample = self.transform(sample)

        return sample

In [None]:
# Parse the WIDER FACE training labels into a face-box dataset and instantiate
# the age/gender estimator that will be run on every face crop below.
LABEL_PATH = "../data/widerface/train/label.txt"
ds = TrainDataset(LABEL_PATH)
model = AgeGenderEstimator()

In [None]:
# Keep every raw line of the label file (trailing whitespace/newline stripped) so
# predictions can be appended to the bounding-box lines in a later cell.
# `with` closes the file even if reading fails.
with open("../data/widerface/train/label.txt", 'r') as f:
    lines = f.readlines()
labels = [line.rstrip() for line in lines]

In [None]:
# Run the age/gender estimator on every annotated face crop, in label-file order.
# `result_list` receives exactly one entry per bounding-box line so the next cell
# can attach predictions back onto `labels`; any face that cannot be processed
# gets the placeholder ["UNK", 0] to keep that alignment.
FACE_SIZE = 112  # crop side length fed to the estimator (same value as TrainDataset.img_size)

result_list = []
failed_faces = []  # (image index, box index, reason) for every face that fell back to UNK

with torch.no_grad():  # pure inference: no autograd graph needed
    for img_idx in tqdm.tqdm(range(len(ds))):
        data = ds[img_idx]
        if not isinstance(data, dict):
            # An image without boxes has no bbox lines, so it contributes no results.
            continue
        img, annot = data['img'], data['annot']
        for box_idx, box in enumerate(annot):
            # Clamp to >= 0: a negative slice index would silently wrap around the image.
            x1, y1, x2, y2 = (max(int(v), 0) for v in box)
            crop = img[y1:y2, x1:x2]
            if crop.size == 0:
                failed_faces.append((img_idx, box_idx, "empty crop"))
                result_list.append(["UNK", 0])
                continue
            try:
                face = cv2.resize(crop, (FACE_SIZE, FACE_SIZE))
                output = model.detect(torch.tensor(face).unsqueeze(0))
                result_list.append([output[0][0], output[1][0]])
            except Exception as exc:  # not bare: KeyboardInterrupt must still stop the loop
                failed_faces.append((img_idx, box_idx, repr(exc)))
                result_list.append(["UNK", 0])

print(f"{len(failed_faces)} of {len(result_list)} faces fell back to UNK")

In [None]:
# Append the two estimator outputs (or the "UNK 0" placeholder) to every
# bounding-box line of the raw label file; '#' image headers and blank lines are
# left untouched. f-string formatting is used because the outputs are not
# guaranteed to be str (the placeholder age is the int 0).
# NOTE: this mutates `labels` in place, so re-running this cell without first
# re-reading label.txt would append the predictions a second time.
bbox_line_ids = [
    i for i, line in enumerate(labels)
    if line.strip() and not line.startswith("#")
]
# Check alignment up front so a mismatch cannot leave `labels` half-updated.
assert len(bbox_line_ids) == len(result_list), (
    f"{len(bbox_line_ids)} bbox lines but {len(result_list)} predictions"
)

for i, (pred_0, pred_1) in zip(bbox_line_ids, result_list):
    labels[i] = f"{labels[i]} {pred_0} {pred_1}"