In [1]:
from concurrent.futures import ThreadPoolExecutor
import urllib.request
import concurrent
import pandas as pd
import os
from tqdm import tqdm
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, InterpolationMode
import numpy as np

### Get Missing Images

In [None]:
def get_file_path_from_url(url):
    """Derive a flat file name from an image URL.

    Everything after the fifth ``/`` in ``url`` is kept and the remaining
    slashes are replaced by ``-``. A URL containing fewer than five slashes
    yields an empty string.
    """
    pieces = url.split('/', 5)
    if len(pieces) <= 5:
        return ''
    return pieces[5].replace('/', '-')

In [10]:
def get_missing_images(csv_path='products.csv', assets_dir='assets', url_column='searchImage'):
    """Return the image URLs from the product CSV that are not on disk yet.

    Args:
        csv_path: CSV file holding the product rows.
        assets_dir: Directory whose file names are treated as already
            downloaded images.
        url_column: Name of the CSV column that holds the image URLs.

    Returns:
        A list of unique URLs, in order of first appearance in the CSV, whose
        file name (as computed by ``get_file_path_from_url``) is not present
        in ``assets_dir``.

    Side effects:
        Prints the number of files found in ``assets_dir`` and the number of
        missing URLs.
    """
    # Empty CSV cells are read back as float NaN, which has no ``.split`` and
    # would crash ``get_file_path_from_url``; drop them up front.
    all_images = pd.read_csv(csv_path)[url_column].dropna().tolist()
    downloaded_images = set(os.listdir(assets_dir))

    print('Download Images: ', len(downloaded_images))

    # ``dict.fromkeys`` removes duplicate URLs while keeping first-seen order.
    missing_images = [
        url
        for url in dict.fromkeys(all_images)
        if get_file_path_from_url(url) not in downloaded_images
    ]

    print('Missing images: ', len(missing_images))
    return missing_images
# Runs when the cell executes: reads the CSV and the assets folder and keeps
# the resulting URL list for the (commented-out) download calls further down.
missing_images = get_missing_images()

### Download Images

In [4]:
def download_image(image_url):
    """Download ``image_url`` into the ``assets/`` folder.

    The target file name is ``'assets/'`` plus the name produced by
    ``get_file_path_from_url``. The URL is printed before the download starts.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises (for example
        ``urllib.error.URLError`` or ``urllib.error.ContentTooShortError``);
        the exception is re-raised after the partial file has been removed.
    """
    print(image_url)
    file_name = 'assets/' + get_file_path_from_url(image_url)

    # ``urlretrieve`` writes the response body to disk as it arrives, so an
    # interrupted transfer would otherwise leave a truncated file under the
    # final name, which would then look like a finished download. Fetch into a
    # side file and rename it only after the transfer has completed.
    partial_name = file_name + '.part'
    try:
        urllib.request.urlretrieve(image_url, partial_name)
    except BaseException:
        if os.path.exists(partial_name):
            os.remove(partial_name)
        raise
    os.replace(partial_name, file_name)

def multithreaded_download(images, num_threads):
    """Download every URL in ``images`` using a pool of ``num_threads`` threads.

    Args:
        images: Iterable of image URLs, each passed to ``download_image``.
        num_threads: Maximum number of worker threads.

    A download that raises is reported with its URL and does not stop the
    remaining downloads from being collected. The progress bar advances as
    downloads finish.
    """
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Remember which URL each future belongs to so failures can be named.
        future_to_url = {
            executor.submit(download_image, image): image for image in images
        }
        finished = concurrent.futures.as_completed(future_to_url)
        for future in tqdm(finished, total=len(future_to_url)):
            try:
                # ``as_completed`` only yields futures that are already done,
                # so ``result()`` never blocks and a timeout can never fire
                # here; what can happen is the worker's own exception being
                # re-raised, which is what gets handled.
                future.result()
            except Exception as e:
                print(f'Failed to download {future_to_url[future]}: {e}')

def singlethreaded_download(images):
    """Fetch each URL in ``images`` one after another on the calling thread.

    Every URL is handed to ``download_image``; the first exception raised by
    a download propagates to the caller and ends the loop.
    """
    for image_url in images:
        download_image(image_url)
# multithreaded_download(missing_images, 30)
# singlethreaded_download(missing_images)

In [None]:
def find_bad_images(directory='assets-224'):
    """List the files in ``directory`` that cannot be fully decoded as images.

    Each file is opened with PIL and converted to a NumPy array, which forces
    the pixel data to be read. Any exception raised while doing so marks the
    file as bad; the file name and the error are printed.

    Args:
        directory: Folder to scan. Defaults to ``'assets-224'``.

    Returns:
        The list of file names (not full paths) that failed to decode.
    """
    bad_images = []
    for image in tqdm(os.listdir(directory)):
        try:
            # The context manager closes the underlying file even when
            # decoding fails part-way through.
            with Image.open(os.path.join(directory, image)) as img:
                np.array(img)
        except Exception as e:
            print(f'{image} : {e}')
            bad_images.append(image)
    print('Number of Bad Images: ', len(bad_images))
    return bad_images
# Runs when the cell executes. The resulting list of undecodable file names is
# read as a module-level variable by get_image_files() further down.
bad_images = find_bad_images()

### Resize Images

In [51]:
def resize_images(files, input_folder='assets', output_folder='assets-224'):
    """Resize and centre-crop the given image files to 224x224.

    Each file is read from ``input_folder``, passed through
    ``Resize(224, bicubic)`` followed by ``CenterCrop((224, 224))`` and saved
    under the same name in ``output_folder``.

    Args:
        files: Iterable of file names (relative to ``input_folder``).
        input_folder: Folder holding the original images.
        output_folder: Folder receiving the resized images; created if it
            does not exist.

    Errors for an individual file are printed together with the file name and
    do not stop the remaining files from being processed.
    """
    transform = Compose([
        Resize(224, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(size=(224, 224))
    ])
    # Without the folder every single save would fail; ``exist_ok`` keeps this
    # safe when several worker threads get here at the same time.
    os.makedirs(output_folder, exist_ok=True)
    for file in files:
        try:
            # Save while the source is still open, then let the context
            # manager release the file handle deterministically.
            with Image.open(os.path.join(input_folder, file)) as img:
                transform(img).save(os.path.join(output_folder, file))
        except Exception as e:
            print(e, '===', file)

def get_image_files():
    """Return the file names in ``assets`` that still need to be resized.

    A file is skipped when a file of the same name already exists in
    ``assets-224`` or when its name appears in the module-level ``bad_images``
    list.

    Returns:
        List of file names, in the order ``os.listdir('assets')`` yields them.
    """
    # One set for both exclusion sources: a membership test against the
    # ``bad_images`` list inside the comprehension would rescan that list for
    # every file in ``assets``.
    skipped = set(os.listdir('assets-224'))
    skipped.update(bad_images)
    return [name for name in os.listdir('assets') if name not in skipped]

def multithreaded_resize_images(num_threads=10, batch_size=200):
    """Resize all pending images on a thread pool, one batch per task.

    The pending file names come from ``get_image_files``; they are cut into
    consecutive slices of at most ``batch_size`` names and each slice is
    submitted to ``resize_images`` on a pool of ``num_threads`` workers.
    Leaving the ``with`` block waits for every submitted batch to finish.
    """
    pending_files = get_image_files()
    print('received image files')
    total = len(pending_files)
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool:
        for start in range(0, total, batch_size):
            stop = start + batch_size
            pool.submit(resize_images, pending_files[start:stop])

# multithreaded_resize_images()

In [54]:
# Names of the files in 'assets' that get_image_files() still reports, i.e.
# those with no counterpart in 'assets-224' and not listed in bad_images.
truncated_image_files = get_image_files()

In [None]:
from torch.utils.data import DataLoader, Dataset
import torch

# NOTE(review): `all_images`, `preprocess`, `model` and `device` are used below
# but are not defined in the visible cells; they are presumably created in a
# cell that is not shown. Confirm before a Restart & Run All.

batch_size = 500

all_images_features = torch.Tensor()
input_dir = '/kaggle/input/fashion-224-cropped/assets-224'


class ImageDataset(Dataset):
    """Dataset that loads images from ``input_dir`` by file name.

    Args:
        image_paths: Sequence of file names relative to ``input_dir``.
        transform: Callable applied to the opened PIL image. When ``None``,
            the module-level ``preprocess`` is used instead.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # Honour the transform given to the constructor; fall back to the
        # global `preprocess` so existing `ImageDataset(paths)` calls still work.
        transform = self.transform if self.transform is not None else preprocess
        # Close the file handle as soon as the image has been transformed.
        with Image.open(input_dir + '/' + image_path) as raw_image:
            # NOTE(review): the DataLoader stacks items along a new leading
            # axis, so this extra unsqueeze(0) leaves a size-1 axis in every
            # batch; confirm that `model.encode_image` expects that layout.
            image = transform(raw_image).unsqueeze(0)
        return image


image_dataset = ImageDataset(all_images, transform=preprocess)
image_dataloader = DataLoader(image_dataset, batch_size=batch_size)

# Collect per-batch features and concatenate once at the end; growing the
# result with torch.cat on every iteration re-copies all earlier features.
feature_batches = []
# Feature extraction only: without no_grad every batch would keep its autograd
# graph alive for as long as the features are held.
with torch.no_grad():
    for batch_images in tqdm(image_dataloader):
        batch_images = batch_images.to(device)

        batch_features = model.encode_image(batch_images)

        feature_batches.append(batch_features)

# torch.cat rejects an empty list, so keep the empty tensor when the loader
# produced no batches.
if feature_batches:
    all_images_features = torch.cat(feature_batches, 0)


In [55]:
# Spot check: open the first pending original from 'assets' and display its
# size (PIL reports it as (width, height)).
img = Image.open('assets/' + truncated_image_files[0])
img.size

(1080, 1440)