In [3]:
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

In [None]:
class LazyImageDataset(Dataset):
    """Dataset that downloads images from a list of URLs on demand.

    At construction time every URL is downloaded once and checked with
    ``Image.verify()``; only the positions that pass are exposed through
    ``__len__`` / ``__getitem__``. Each item access downloads the image
    again, so no image data is kept in memory between accesses.

    Args:
        image_links: Indexable, iterable collection of image URLs.
        transform: Optional callable applied to the opened PIL image before
            it is returned.
        timeout: Value forwarded to ``requests.get(timeout=...)``. Without a
            timeout a single stalled server would block forever; pass
            ``None`` to wait indefinitely.
    """

    def __init__(self, image_links, transform=None, timeout=10):
        self.image_links = image_links
        self.transform = transform
        # Must be set before validation, which already issues requests.
        self.timeout = timeout
        self.valid_indices = self._find_valid_images()

    def __len__(self):
        """Return the number of URLs that passed validation."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Download and return the ``idx``-th valid image.

        Returns:
            The PIL image, or the result of ``transform(image)`` when a
            transform was supplied.
        """
        # Map the dataset position to the position in the original URL list.
        url = self.image_links[self.valid_indices[idx]]
        image = self._fetch_image(url)

        if self.transform:
            image = self.transform(image)

        return image

    def _fetch_image(self, url):
        """Download ``url`` and open the response body as a PIL image."""
        response = requests.get(url, timeout=self.timeout)
        # Fail on 4xx/5xx rather than trying to decode an error page.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))

    def _find_valid_images(self):
        """Return the indices of all URLs whose image downloads and verifies.

        URLs that fail are reported on stdout and skipped.
        """
        valid_indices = []
        for idx, url in enumerate(self.image_links):
            try:
                # verify() checks the file for corruption; the image object
                # is discarded afterwards and re-opened in __getitem__.
                self._fetch_image(url).verify()
            except Exception:
                # `Exception` (not a bare except) so that Ctrl-C can still
                # interrupt a long validation pass.
                print(f"Corrupted image: {url}")
            else:
                valid_indices.append(idx)
        return valid_indices


# Example usage
CSV_PATH = "/inditextech_hackupc_challenge_images.csv"
N_IMAGES = 100          # number of links taken from the top of the CSV
IMAGE_SIZE = (224, 224)  # size every image is resized to
BATCH_SIZE = 32

df = pd.read_csv(CSV_PATH)
# First column of the first N_IMAGES rows, as a plain Python list.
image_links = df.iloc[:N_IMAGES, 0].tolist()
# Alternative without the CSV -- a hand-written list of image links:
# image_links = [
#     "https://static.zara.net/photos///2024/V/0/3/p/5767/521/712/2/w/2048/5767521712_6_1_1.jpg?ts=1707751045954",
#     "https://static.zara.net/photos///2024/V/0/2/p/9621/451/406/2/w/2048/9621451406_6_1_1.jpg?ts=1708614924346",
# ]
transform = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])
# Constructing the dataset downloads every link once to validate it.
dataset = LazyImageDataset(image_links, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Now you can use this dataloader in your training loop

In [None]:
# Define a function to show images
def show_images(images):
    """Display image tensors side by side in a single row.

    Args:
        images: Sized iterable of tensors (e.g. one batch from a DataLoader).
            Each tensor's axes are reordered with ``transpose((1, 2, 0))``
            before plotting -- assumes a (channels, height, width) layout.

    Does nothing when ``images`` is empty.
    """
    if len(images) == 0:
        # plt.subplots(1, 0) would raise; there is nothing to draw anyway.
        return
    # squeeze=False always yields a 2-D array of axes, so a single image
    # works as well (otherwise subplots returns a bare Axes object).
    _, axes = plt.subplots(1, len(images), figsize=(15, 5), squeeze=False)
    for ax, image in zip(axes[0], images):
        # detach/cpu so tensors on a GPU or tracking gradients can be
        # converted; then move the channel axis last for imshow.
        np_image = image.detach().cpu().numpy().transpose((1, 2, 0))
        # Clamp to the [0, 1] range imshow expects for float data.
        np_image = np.clip(np_image, 0, 1)
        ax.imshow(np_image)
        ax.axis('off')
    plt.show()

# Pull a single batch out of the dataloader and display it
_batch_iter = iter(dataloader)
batch_images = next(_batch_iter)
show_images(batch_images)