In [12]:
import os
import sys
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from lightning.pytorch import seed_everything

sys.path.append('/proj/vondrick4/naveen/CoIR')
from src.datasets.fiq_corpus_dataset import fiq_corpus_dataset_clip
from src.datasets.fiq_retrieval_dataset import fiq_retrieval_dataset_clip

In [6]:
dataset = fiq_corpus_dataset_clip('val', '/proj/vondrick4/naveen/coir-data/FashionIQ', '/local/vondrick/naveen/pretrained_models/clip/clip-vit-base-patch32')

dataloader =  DataLoader(
    dataset=dataset, 
    batch_size=10, 
    shuffle=True,
    num_workers=2, 
    collate_fn=dataset.collate_fn,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True
)

In [7]:
for batch_idx, batch in enumerate(dataloader):
    print(batch['image']['pixel_values'].shape)
    break

torch.Size([10, 3, 224, 224])


In [21]:
ret_dataset = fiq_retrieval_dataset_clip('toptee', '/proj/vondrick4/naveen/coir-data/FashionIQ', '/local/vondrick/naveen/pretrained_models/clip/clip-vit-base-patch32')

ret_dataloader =  DataLoader(
    dataset=ret_dataset, 
    batch_size=10, 
    shuffle=True,
    num_workers=2, 
    collate_fn=ret_dataset.collate_fn,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True
)

In [22]:
for batch_idx, batch in enumerate(ret_dataloader):
    print(batch)
    break

{'query-image-id': ['B005X4EAGS', 'B004F3NKGO', 'B002ZNKR3K', 'B004XBR0XM', 'B005A1U8OC', 'B006L2ZRHC', 'B007S3RM9Y', 'B00COQ5O9K', 'B0079K2OAI', 'B00BR5R2JO'], 'query-image': {'pixel_values': tensor([[[[1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303],
          [1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303],
          [1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303],
          ...,
          [1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303],
          [1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303],
          [1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303]],

         [[2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749],
          [2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749],
          [2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749],
          ...,
          [2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749],
          [2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749],
          [2.0749, 2.0749, 2.0749,  ..., 2.0749, 

In [9]:
images_path = '/proj/vondrick4/naveen/coir-data/FashionIQ/image_splits'
caps_path = '/proj/vondrick4/naveen/coir-data/FashionIQ/captions'

In [None]:
corpus_dress = json.load(open(os.path.join(images_path, 'split.dress.val.json'), 'r'))
corpus_shirt = json.load(open(os.path.join(images_path, 'split.shirt.val.json'), 'r'))
corpus_toptee = json.load(open(os.path.join(images_path, 'split.toptee.val.json'), 'r'))

In [10]:
caps_dress = json.load(open(os.path.join(caps_path, 'cap.dress.val.json'), 'r'))
caps_shirt = json.load(open(os.path.join(caps_path, 'cap.shirt.val.json'), 'r'))
caps_toptee = json.load(open(os.path.join(caps_path, 'cap.toptee.val.json'), 'r'))

In [11]:
caps_shirt

[{'target': 'B005AD7WZI',
  'candidate': 'B00CZ7QJUG',
  'captions': ['is solid white', 'is a lighter color']},
 {'target': 'B00BPD4N5E',
  'candidate': 'B0083I6W08',
  'captions': ['is green with a four leaf clover',
   'is green and has no text']},
 {'target': 'B008VNAKSU',
  'candidate': 'B0083WOL16',
  'captions': ['Is a brown tee shirt with diseal logo',
   'is darker and has short sleeves']},
 {'target': 'B005Y4JJ0Y',
  'candidate': 'B005Y4KFPM',
  'captions': ['is dark blue', 'is blue with a different character.']},
 {'target': 'B005NYBUF2',
  'candidate': 'B006QHPYIY',
  'captions': ['is red and a tshirt', 'more warmer colors']},
 {'target': 'B001OAN6BK',
  'candidate': 'B001OATAN8',
  'captions': ['is the provided product', 'look exact same']},
 {'target': 'B007A4H2IW',
  'candidate': 'B0075G2YFG',
  'captions': ['is white with mickey mouse on it',
   'has a mickey mouse graphic and is paler in colour']},
 {'target': 'B003IB71I2',
  'candidate': 'B005XIGOL8',
  'captions': ['h

In [None]:
len(corpus_dress)

In [None]:
len(corpus_shirt)

In [None]:
len(corpus_toptee)

In [None]:
corpus_all = set(corpus_dress + corpus_shirt + corpus_toptee)

In [None]:
len(corpus_all)

In [None]:
len(corpus_dress + corpus_shirt + corpus_toptee)/3

In [None]:
len(caps_dress)

In [None]:
len(caps_shirt)

In [None]:
len(caps_toptee)

In [None]:
caps_dress

In [None]:
caps_dress[]

In [None]:
images = os.listdir(images_path)

In [None]:
set(map(lambda x: x.split('.')[-1], images))

In [None]:
images[0].split('.')

In [None]:
len(images)