In [None]:
import torchvision.datasets as datasets
from torch.utils.data import Subset
import numpy as np

# Sample positions to keep, read from the saved .npy file as a plain Python list
LSVRC_indices = np.load(
    '../data/badnets_indices.npy',
).tolist()

# Full ImageNet-1K training folder; no transform is passed to ImageFolder here
LSVRC_train = datasets.ImageFolder(
    root='../data/ImageNet-1K/train/',
)

# View of the training set restricted to the selected indices
trainset = Subset(LSVRC_train, LSVRC_indices)

In [None]:
import os
import shutil
from PIL import Image
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

# Output location for the resized copies
destination_folder = '../data/transfer_sets/badnets/'

# (width, height) every saved image is resized to
resize_size = (32, 32)

# Size of the thread pool used for processing
max_workers = 16

# Start from an empty destination folder: drop any previous output, then
# create the directory (and any missing parents) again
if os.path.exists(destination_folder):
    shutil.rmtree(destination_folder)
os.makedirs(destination_folder)

# Error messages collected by the worker threads; reported after processing
error_log = []

# Resize (when needed) and save a single image as PNG
def process_image(index, img, destination_folder, resize_size):
    """Write `img` to `<destination_folder>/<index>.png`, resized to `resize_size`.

    Args:
        index: Value used as the output file name (without extension).
        img: Image object exposing `size`, `resize` and `save`
            (presumably a PIL image - confirm against the caller).
        destination_folder: Directory the PNG file is written into.
        resize_size: Target size; compared against `img.size` and passed
            to `img.resize`.

    Returns:
        None. Any exception raised while resizing or saving is not
        propagated; a message describing it is appended to the
        module-level `error_log` instead.
    """
    out_path = os.path.join(destination_folder, f'{index}.png')
    try:
        # Only resample when the image is not already at the target size
        needs_resize = img.size != resize_size
        output_img = img.resize(resize_size, Image.LANCZOS) if needs_resize else img
        output_img.save(out_path, format='PNG')
    except Exception as exc:
        error_log.append(f"Error processing image at index {index}: {exc}")

# Process the dataset using a ThreadPoolExecutor
def _load_and_process(index):
    """Fetch sample `index` from `trainset` and pass it to `process_image`.

    This runs inside a worker thread, so fetching the source image is
    spread across the pool instead of being done one by one in the main
    thread before submission.

    Args:
        index: Position of the sample within `trainset`; also used by
            `process_image` as the output file name.

    Returns:
        None. A failure while fetching the sample is appended to
        `error_log` (mirroring how `process_image` records resize/save
        failures) so one unreadable sample does not abort the whole run.
    """
    try:
        img, _label = trainset[index]  # Fetch the image (label is unused)
    except Exception as e:
        # Worker-task boundary: record the failure and keep going
        error_log.append(f"Error loading image at index {index}: {e}")
        return
    process_image(index, img, destination_folder, resize_size)


with ThreadPoolExecutor(max_workers=max_workers) as executor:
    with tqdm(total=len(trainset), desc="Processing images") as pbar:
        # executor.map schedules every index right away and yields results
        # in submission order, so the bar advances while the work is
        # actually running instead of only after everything was queued.
        for _ in executor.map(_load_and_process, range(len(trainset))):
            pbar.update(1)

# Print errors if any occurred
if error_log:
    print("The following errors occurred during processing:")
    for error in error_log:
        print(error)
else:
    print("All images processed successfully!")