In [3]:
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import copy
import numpy as np
import os

import pandas as pd
import torch
import torchvision.models as models
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.autograd import Variable
from glob import glob

from PIL import Image

# True when torch reports a usable CUDA device. None of the cells shown here
# read this global (load_image takes its own `use_cuda` argument instead).
use_cuda = torch.cuda.is_available()

In [4]:
def load_image(path, imsize=224, volatile=True, use_cuda=False, padding=10):
    """Load a sketch, crop it to its drawn region, and return it as a 1-image batch.

    The image is flattened onto a white background (via RGBA2RGB), cropped to
    the square spanned by its non-white pixels, padded, resized and converted
    to a tensor with a leading batch dimension added.

    Parameters
    ----------
    path : str
        Image file to load.
    imsize : int
        Size passed to the torchvision resize transform.
    volatile : bool
        Forwarded to ``Variable(..., volatile=...)``.
    use_cuda : bool
        Accepted for interface compatibility; not used by this function.
    padding : int
        Border passed to ``transforms.Pad``. The original body referenced an
        undefined name ``padding`` (NameError on every call).
        NOTE(review): the default of 10 is a placeholder -- confirm the
        intended value.

    Returns
    -------
    torch.autograd.Variable
        The transformed image after ``unsqueeze(0)``.
    """
    im = Image.open(path)
    im = RGBA2RGB(im)

    # crop to sketch only (eliminate white space)
    arr = np.asarray(im)
    rows, cols, _ = np.where(arr < 255)  # coordinates of every non-white value
    if len(rows) == 0:
        # Completely white image: report it (as before) and keep the full
        # canvas instead of crashing in min() on an empty sequence.
        print(path)
    else:
        lb = int(min(cols.min(), rows.min()))
        # PIL crop boxes exclude the right/lower edge, so add 1 to keep the
        # last row/column that contains ink.
        ub = int(max(cols.max(), rows.max())) + 1
        im = im.crop((lb, lb, ub, ub))

    # torchvision renamed Scale to Resize; use whichever this install provides.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale
    loader = transforms.Compose([
        transforms.Pad(padding),
        resize(imsize),
        transforms.ToTensor()])

    im = Variable(loader(im), volatile=volatile)
    im = im.unsqueeze(0)
    return im

        
def RGBA2RGB(image, color=(255, 255, 255)):
    """Flatten `image` onto a solid background and return the RGB result.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image. Images that are not already in 'RGBA' mode are converted
        first, so inputs without an alpha band no longer fail with an
        IndexError on ``split()[3]``.
    color : tuple
        Fill colour of the new background (white by default).

    Returns
    -------
    PIL.Image.Image
        A new 'RGB' image of the same size with `image` pasted over `color`,
        using the alpha band as the paste mask.
    """
    if image.mode != 'RGBA':
        image = image.convert('RGBA')
    image.load()  # needed for split()
    background = Image.new('RGB', image.size, color)
    background.paste(image, mask=image.split()[3])  # band 3 is the alpha channel
    return background


def list_files(path, ext='png'):
    """Collect every ``*.<ext>`` file found in `path` or any directory below it.

    Directories are visited with os.walk and each one is globbed for the
    extension; the matches are returned as a single flat list of paths.
    """
    pattern = '*.%s' % ext
    matches = []
    for dirpath, _subdirs, _files in os.walk(path):
        matches.extend(glob(os.path.join(dirpath, pattern)))
    return matches


def get_metadata_from_path(path,cohort):
    """Derive (label, age, session) for a sketch from its file path.

    The label is the name of the directory holding the file. For the 'kid'
    cohort the age is the third '_'-separated token of the file name and the
    session is the last two '_'-separated tokens of the name up to its first
    '.', joined by '_'. For 'adult' the age is 'adult' and the session is
    'unknown'. Any other cohort prints a reminder and yields 'unknown' for
    both fields.
    """
    segments = path.split('/')
    label = segments[-2]

    if cohort == 'adult':
        return label, 'adult', 'unknown'

    if cohort != 'kid':
        print('Need to specify a cohort: "kid" or "adult"!')
        return label, 'unknown', 'unknown'

    fname = segments[-1]
    age = fname.split('_')[2]
    stem_tokens = fname.split('.')[0].split('_')
    session = stem_tokens[-2] + '_' + stem_tokens[-1]
    return label, age, session

def check_invalid_sketch(filenames,invalids_path='drawings_to_exclude_clean.txt'):
    """Drop sketches whose base name is listed in the exclusion file.

    Parameters
    ----------
    filenames : list of str
        Sketch paths; the part after the last '/' is compared against the
        exclusion list.
    invalids_path : str
        Headerless single-column CSV of file names to exclude. If it does not
        exist a message is printed and nothing is excluded.

    Returns
    -------
    list of str
        The entries of `filenames` that are not excluded, in their original
        order.
    """
    if not os.path.exists(invalids_path):
        print('No file containing invalid paths at {}'.format(invalids_path))
        invalids = set()
    else:
        x = pd.read_csv(invalids_path, header=None)
        x.columns = ['filenames']
        # A set gives constant-time membership tests; the original scanned a
        # list once per sketch.
        invalids = set(x.filenames.values)
    return [f for f in filenames if f.split('/')[-1] not in invalids]

def convert_age(Ages):
    '''
    Convert a sequence of ages to ints, handling trials where we didn't have
    age information.

    Any entry that cannot be parsed as an integer -- the empty string, None,
    or a placeholder such as the 'unknown' / 'adult' values produced by
    get_metadata_from_path -- becomes -1 instead of raising. Entries that are
    already numeric are passed through int().

    Returns a new list of ints with the same length and order as `Ages`.
    '''
    ages = []
    for a in Ages:
        try:
            ages.append(int(a))
        except (TypeError, ValueError):
            ages.append(-1)
    return ages

def make_dataframe(Labels,Ages,Sessions):
    """Bundle the three parallel metadata lists into one DataFrame.

    Each list becomes one row of an intermediate frame named after its column,
    and the frame is then transposed so the result has one row per sketch and
    the columns 'label', 'age' and 'session'.
    """
    by_field = pd.DataFrame([Labels, Ages, Sessions],
                            index=['label', 'age', 'session'])
    return by_field.T


def preprocess_features(Features, Y):
    """Sort the metadata, reorder the feature rows to match, and normalize them.

    Y is sorted by label, then age, then session. The sorted frame's index
    labels are used as row positions into Features, so Y is expected to carry
    its default 0..n-1 index. The reordered features are passed through
    normalize() and the metadata is returned with a fresh index.
    """
    sorted_meta = Y.sort_values(['label','age','session'])
    row_order = np.array(sorted_meta.index)
    reordered = Features[row_order]
    # drop the old index so metadata rows line up with feature rows again
    return normalize(reordered), sorted_meta.reset_index(drop=True)

def normalize(X):
    """Standardize X along axis 0.

    Each column has its mean subtracted and is then divided by its standard
    deviation, with the divisor floored at 1e-5 so constant columns do not
    divide by zero.
    """
    centered = X - X.mean(0)
    spread = centered.std(0)
    return centered / np.maximum(spread, 1e-5)

def remove_nans(Features, Y):
    """Remove data where we don't have age information.

    Rows whose 'age' is not greater than 0 (e.g. the -1 placeholder written by
    convert_age) are dropped from both the metadata and the feature matrix.

    The selection is a positional boolean mask, so it stays correct whatever
    index Y carries. The original looked up Features rows by Y's index labels,
    which is only valid when those labels are exactly 0..n-1.

    Parameters
    ----------
    Features : numpy.ndarray
        One row per row of Y, in the same order.
    Y : pandas.DataFrame
        Metadata with an 'age' column.

    Returns
    -------
    (numpy.ndarray, pandas.DataFrame)
        The kept feature rows and the kept metadata rows (index labels of Y
        are preserved, not reset).
    """
    keep = np.asarray(Y['age'] > 0, dtype=bool)
    _Y = Y.loc[keep]
    _Features = Features[keep]
    return _Features, _Y


In [5]:
### Adult sketches: image geometry and source directory
imSize = 224
numPixels = imSize ** 2  # one feature column per pixel of an imSize x imSize image
data_path = '/data2/jefan/quickDraw/png_micro'
all_adult_pngs = list_files(data_path)
  

In [6]:
## Adult features: one row of raw pixel values per sketch
Labels = []
Ages = []
Sessions = []
Features = np.zeros((len(all_adult_pngs), numPixels))

for vi, v in enumerate(all_adult_pngs):
    im = Image.open(v)
    im = RGBA2RGB(im)
    # Resize to imSize (not a literal 224) so each flattened image always has
    # the numPixels columns that Features was allocated with. Image.LANCZOS is
    # the same filter as the older Image.ANTIALIAS name.
    im2 = im.resize((imSize, imSize), Image.LANCZOS)
    arr = np.array(im2)
    oneChannel = arr[:, :, 1]  # keep a single colour channel (index 1)
    Features[vi, :] = np.ravel(oneChannel)
    label, age, session = get_metadata_from_path(v, 'adult')
    Labels.append(label)
    Ages.append(age)
    Sessions.append(session)

Y = make_dataframe(Labels, Ages, Sessions)
_Features, _Y = preprocess_features(Features, Y)

In [8]:
## Save adult pixels
cohort = 'adult'
# np.save and to_csv do not create missing directories, so make sure the
# output folder exists before writing.
if not os.path.isdir('./features'):
    os.makedirs('./features')
np.save('./features/FEATURES_{}_{}.npy'.format('pixels', cohort), _Features)
_Y.to_csv('./features/METADATA_{}.csv'.format(cohort))


In [7]:
# total number of values in the feature matrix (rows * columns)
_Features.size

190668800

In [9]:
## Now for kid features
data_path_kids = '/home/jefan/kiddraw/analysis/museumdraw/sketches'

## gather every sketch under the kid directory, then filter out invalid sketches
all_pngs_kids = check_invalid_sketch(list_files(data_path_kids))
print('Length of sketch_paths after filtering: {}'.format(len(all_pngs_kids)))


Length of sketch_paths after filtering: 462


In [10]:
## Kid features: one row of raw pixel values per sketch
Labels = []
Ages = []
Sessions = []
Features = np.zeros((len(all_pngs_kids), numPixels))

for vi, v in enumerate(all_pngs_kids):
    im = Image.open(v)
    im = RGBA2RGB(im)
    # Resize to imSize (not a literal 224) so each flattened image always has
    # the numPixels columns that Features was allocated with. Image.LANCZOS is
    # the same filter as the older Image.ANTIALIAS name.
    im2 = im.resize((imSize, imSize), Image.LANCZOS)
    arr = np.array(im2)
    oneChannel = arr[:, :, 1]  # keep a single colour channel (index 1)
    Features[vi, :] = np.ravel(oneChannel)
    label, age, session = get_metadata_from_path(v, 'kid')
    Labels.append(label)
    Ages.append(age)
    Sessions.append(session)

# ages come out of the file names as strings; missing ones become -1
Ages = convert_age(Ages)

In [11]:
# organize metadata into dataframe
Y = make_dataframe(Labels, Ages, Sessions)
_Features, _Y = preprocess_features(Features, Y)
# NOTE(review): normalization above runs over all sketches, including the
# ones without an age that are dropped on the next line -- confirm that this
# is intended.
_Features, _Y = remove_nans(_Features, _Y)  # remove nans from kid dataframe (where we didn't have age information)

cohort = 'kid'
# np.save and to_csv do not create missing directories, so make sure the
# output folder exists before writing.
if not os.path.isdir('./features'):
    os.makedirs('./features')
np.save('./features/FEATURES_{}_{}.npy'.format('pixels', cohort), _Features)
_Y.to_csv('./features/METADATA_{}.csv'.format(cohort))
