In [1]:
import os
import cv2
import pandas as pd
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

import data_loader as D
from collections import Counter

ModuleNotFoundError: No module named 'config'

In [5]:
# Root of the training data and the image folder inside it.
train_dir = '/opt/ml/input/data/train'
train_img_dir = os.path.join(train_dir, 'images')

# Keep only the directories directly under the image root (stray files are skipped).
train_img_sub_dirs = [
    os.path.join(train_img_dir, sub_dir)
    for sub_dir in os.listdir(train_img_dir)
    if os.path.isdir(os.path.join(train_img_dir, sub_dir))
]

# Build one flat array of image paths, skipping hidden files (names starting with '.').
# The flattening is done in the comprehension itself rather than via
# np.array(list_of_lists).flatten(): that form only works when every sub-directory
# holds exactly the same number of files, and otherwise yields an object array of
# lists (or raises on recent NumPy versions). Ordering is unchanged: sub-directory
# by sub-directory, files in os.listdir order.
default_img_paths = np.array([
    os.path.join(sub_dir, img)
    for sub_dir in train_img_sub_dirs
    for img in os.listdir(sub_dir)
    if not img.startswith('.')
])

# Default preprocessing pipeline: convert the input to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.2.
default_transforms = transforms.Compose(
    [
        ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
    ]
)

In [11]:
    
def _cal_label_weight(gender, age):
    weight = np.zeros(gender.shape)
    # gender
    weight += (gender == 'female')*3
    # age
    weight += ((age >= 30) & (age < 60))*1
    weight += (age >= 60)*2

    return weight

In [24]:
# Load the training metadata fresh from disk.
train_info = pd.read_csv(os.path.join(train_dir, 'train.csv'))

# Replace the gender strings with integer codes (female -> 3, male -> 0).
gender_codes = {'female': 3, 'male': 0}
train_info['gender'] = train_info['gender'].map(gender_codes)
print(train_info)

# Replace the raw age with a bracket code: 0 for < 30, 1 for 30-59, 2 for >= 60.
# Each threshold crossed adds one, which gives the same codes as adding
# 1 for the middle bracket and 2 for the upper bracket.
age = train_info['age']
weight = (age >= 30) * 1 + (age >= 60) * 1
train_info['age'] = weight
print(train_info)

# Spot-check a single row by its folder name.
sample_row = train_info.loc[train_info['path'] == '000001_female_Asian_45', 'age']
print(sample_row.values[0])

          id  gender   race  age                    path
0     000001       3  Asian   45  000001_female_Asian_45
1     000002       3  Asian   52  000002_female_Asian_52
2     000004       0  Asian   54    000004_male_Asian_54
3     000005       3  Asian   58  000005_female_Asian_58
4     000006       3  Asian   59  000006_female_Asian_59
...      ...     ...    ...  ...                     ...
2695  006954       0  Asian   19    006954_male_Asian_19
2696  006955       0  Asian   19    006955_male_Asian_19
2697  006956       0  Asian   19    006956_male_Asian_19
2698  006957       0  Asian   20    006957_male_Asian_20
2699  006959       0  Asian   19    006959_male_Asian_19

[2700 rows x 5 columns]
          id  gender   race  age                    path
0     000001       3  Asian    1  000001_female_Asian_45
1     000002       3  Asian    1  000002_female_Asian_52
2     000004       0  Asian    1    000004_male_Asian_54
3     000005       3  Asian    1  000005_female_Asian_58
4     

In [13]:
# _cal_label_weight expects the raw columns (gender as 'female'/'male' strings, age
# in years). The encoding cell above overwrites train_info['gender'] with 3/0 and
# train_info['age'] with 0/1/2, so when the notebook is run top to bottom, passing
# those columns here would make every weight 0. Compute the weights from a fresh
# read of the CSV instead so the result does not depend on cell execution order.
# Assumes train_info still has the CSV's row order (no filtering/sorting in between).
train_info_raw = pd.read_csv(os.path.join(train_dir, 'train.csv'))
train_info['label_weight'] = _cal_label_weight(train_info_raw['gender'], train_info_raw['age'])
train_info

Unnamed: 0,id,gender,race,age,path,label_weight
0,000001,female,Asian,45,000001_female_Asian_45,4.0
1,000002,female,Asian,52,000002_female_Asian_52,4.0
2,000004,male,Asian,54,000004_male_Asian_54,1.0
3,000005,female,Asian,58,000005_female_Asian_58,4.0
4,000006,female,Asian,59,000006_female_Asian_59,4.0
...,...,...,...,...,...,...
2695,006954,male,Asian,19,006954_male_Asian_19,0.0
2696,006955,male,Asian,19,006955_male_Asian_19,0.0
2697,006956,male,Asian,19,006956_male_Asian_19,0.0
2698,006957,male,Asian,20,006957_male_Asian_20,0.0


In [51]:
# Pull one batch from the mask loader and show its second element
# (presumably the label tensor -- confirm against data_loader).
mask_iter = D.mask_train_img_iter_numworker_batch
first_mask_batch = next(iter(mask_iter))
first_mask_batch[1]

tensor([ 0,  0,  6,  0,  0,  0, 12, 12,  6, 12, 12,  6, 12,  6,  6,  6, 12,  0,
         0,  6])

In [50]:
# Pull one batch from the age loader and show its second element
# (presumably the label tensor -- confirm against data_loader).
# It has to be printed explicitly: a notebook only renders the last expression of a
# cell, so a bare `next(iter(age_iter))[1]` on a non-final line is silently dropped.
age_iter = D.age_train_img_iter_numworker_batch
print(next(iter(age_iter))[1])

# Row counts per age code in the metadata frame, for comparison with the batch.
train_info.groupby('age').count()

Unnamed: 0_level_0,id,gender,race,path
age,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,1281,1281,1281,1281
1,1227,1227,1227,1227
2,192,192,192,192


In [49]:
# Pull one batch from the gender loader and print its second element
# (presumably the label tensor -- confirm against data_loader).
# NOTE(review): the recorded output for this cell is all zeros, while the frame
# below has gender codes 0 and 3 -- verify the loader's gender labels.
gen_iter = D.gender_train_img_iter_numworker_batch
first_gender_batch = next(iter(gen_iter))
print(first_gender_batch[1])

# Row counts per gender code in the metadata frame, for comparison with the batch.
gender_counts = train_info.groupby('gender').count()
gender_counts

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


Unnamed: 0_level_0,id,race,age,path
gender,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,1042,1042,1042,1042
3,1658,1658,1658,1658
