In [5]:
import json
import os
import random

import PIL

In [7]:
class FashionIQ():
    """In-memory index over a FashionIQ-style dataset directory.

    The constructor reads every JSON file under ``<path>/image_splits`` and
    ``<path>/captions`` whose file name matches the requested split, and builds:

    * ``self.imgs``    -- list of ``{'asin', 'file_path', 'captions'}`` dicts,
      one per ASIN listed in the split files (position in the list is the
      image id);
    * ``self.queries`` -- the caption records, each extended with integer
      ``source_id`` / ``target_id`` fields;
    * ``self.test_queries`` -- only when ``split == 'test'``.

    Side effect: the ASIN -> image id mapping is written to
    ``attribute2idx.json`` in the current working directory.

    Args:
        path: dataset root containing ``image_splits/``, ``captions/`` and
            ``images/``.
        split: ``'train'`` loads files with ``'train'`` in their name;
            ``'test'`` loads files with ``'val'`` in their name. Any other
            value matches no file and yields an empty dataset.
        transform: optional callable applied to each image by ``get_img``.
        verbose: when true (the default) print the directory listings and a
            notice before writing ``attribute2idx.json``.

    Raises:
        KeyError: if a caption record references an ASIN that is not listed
            in the loaded split files.
        IndexError: for ``split='test'``, if a record has fewer than two
            captions.
    """

    def __init__(self, path, split='train', transform=None, verbose=True):
        super(FashionIQ, self).__init__()

        self.split = split
        self.transform = transform
        self.img_path = path + '/'

        # ASINs that are skipped everywhere (presumably images that could not
        # be obtained -- TODO confirm). A frozenset gives O(1) membership tests.
        failures = frozenset([
            'B00AZJJGDS', 'B00BTDROGU', 'B00AZJJER6', 'B00AZJJYBM',
            'B009H3SDPK', 'B00AKQBDKU', 'B00AZJJE1C', 'B00BTDS338',
            'B0082P4LLO', 'B00CM2RZ3E', 'B00C621JPK', 'B00AZGHX7M',
            'B00A885VUI', 'B0057LRI6Q', 'B008583YKC', 'B009NY5YZA',
        ])

        data = {
            'image_splits': {},
            'captions': {}
        }
        for k in data:
            filenames = os.listdir(path + '/' + k)
            if verbose:
                print(filenames)
            for f in filenames:
                if verbose:
                    print(f)
                # The 'test' split is served from the '*val*' files.
                if (split == 'train' and 'train' in f) or (split == 'test' and 'val' in f):
                    # Context manager so the handle is closed deterministically.
                    with open(path + '/' + k + '/' + f, encoding='utf-8') as fh:
                        data[k][f] = json.load(fh)

        imgs = []
        asin2id = {}
        for k in data['image_splits']:
            for asin in data['image_splits'][k]:
                if asin in failures:
                    continue
                asin2id[asin] = len(imgs)
                imgs.append({
                    'asin': asin,
                    'file_path': path + '/images/' + asin + '.jpg',
                    'captions': [asin2id[asin]]
                })

        if verbose:
            print('write attribute2idx.json')
        # Persist the ASIN -> image id mapping next to the notebook.
        with open("attribute2idx.json", "w") as outfile:
            outfile.write(json.dumps(asin2id, indent=4))

        queries = []
        for k in data['captions']:
            for query in data['captions'][k]:
                if query['candidate'] in failures or query['target'] in failures:
                    continue
                query['source_id'] = asin2id[query['candidate']]
                query['target_id'] = asin2id[query['target']]
                # Captions stay as ``str``: encoding them to bytes made the
                # string concatenations below raise TypeError on Python 3.
                query['captions'] = list(query['captions'])
                queries.append(query)

        self.data = data
        self.imgs = imgs
        self.queries = queries

        if split == 'test':
            # NOTE: 'target_caption' carries the target image id, not text.
            self.test_queries = [{
                'source_img_id': query['source_id'],
                'target_img_id': query['target_id'],
                'target_caption': query['target_id'],
                'mod': {'str': query['captions'][0] + ' inadditiontothat ' + query['captions'][1]}
            } for query in queries]

    def get_all_texts(self):
        """Return the joiner token followed by every caption of every query."""
        texts = ['inadditiontothat']
        for query in self.queries:
            texts += query['captions']
        return texts

    def __len__(self):
        """Number of indexed images (not the number of queries)."""
        return len(self.imgs)

    def generate_random_query_target(self):
        """Sample one query and return its source/target images and modifier.

        The two captions are joined with ``' inadditiontothat '`` in a
        randomly chosen order. ``'target_caption'`` holds the target image id.
        """
        query = random.choice(self.queries)
        mod_str = random.choice([
            query['captions'][0] + ' inadditiontothat ' + query['captions'][1],
            query['captions'][1] + ' inadditiontothat ' + query['captions'][0]
        ])

        return {
            'source_img_id': query['source_id'],
            'source_img_data': self.get_img(query['source_id']),
            'target_img_id': query['target_id'],
            'target_caption': query['target_id'],
            'target_img_data': self.get_img(query['target_id']),
            'mod': {'str': mod_str}
        }

    def get_img(self, idx, raw_img=False):
        """Load image ``idx`` as RGB.

        Args:
            idx: index into ``self.imgs``.
            raw_img: when true, return the PIL image without applying
                ``self.transform``.

        Returns:
            The RGB PIL image, or the result of ``self.transform`` on it when
            a transform is set and ``raw_img`` is false.
        """
        # A bare ``import PIL`` does not load the ``Image`` submodule, so
        # import it explicitly here instead of relying on ``PIL.Image``.
        from PIL import Image

        img_path = self.imgs[idx]['file_path']
        with open(img_path, 'rb') as f:
            img = Image.open(f)
            # convert() forces the pixel data to be read while the file is open.
            img = img.convert('RGB')
        if raw_img:
            return img
        if self.transform:
            img = self.transform(img)
        return img

In [8]:
# Dataset root. The default is the original machine-specific location; set the
# FASHIONIQ_DATA_DIR environment variable to run the notebook elsewhere
# without editing this cell.
path = os.environ.get(
    'FASHIONIQ_DATA_DIR',
    '/home/piai/chan/largescale_multimedia/project/fashion-iq/data',
)
# Builds the train-split index; also writes attribute2idx.json to the CWD.
fashioniq = FashionIQ(path=path)

['split.toptee.train.json', 'split.dress.test.json', 'split.shirt.train.json', 'split.toptee.val.json', 'split.dress.train.json', 'split.shirt.val.json', 'split.dress.val.json', 'split.shirt.test.json', 'split.toptee.test.json']
split.toptee.train.json
split.dress.test.json
split.shirt.train.json
split.toptee.val.json
split.dress.train.json
split.shirt.val.json
split.dress.val.json
split.shirt.test.json
split.toptee.test.json
['cap.toptee.val.json', 'cap.toptee.test.json', 'cap.shirt.test.json', 'cap.shirt.val.json', 'cap.dress.val.json', 'cap.dress.test.json', 'cap.toptee.train.json', 'cap.dress.train.json', 'cap.shirt.train.json', 'dict.dress.json', 'dict.shirt.json', 'dict.toptee.json']
cap.toptee.val.json
cap.toptee.test.json
cap.shirt.test.json
cap.shirt.val.json
cap.dress.val.json
cap.dress.test.json
cap.toptee.train.json
cap.dress.train.json
cap.shirt.train.json
dict.dress.json
dict.shirt.json
dict.toptee.json
write attribute2idx.json
