# LynxDataset tests

## Setup notebook and imports

So far, everything in this notebook has only been tested with PyTorch 2.0.1.

The `albumentations` package also had to be installed separately.

In [1]:
# Allow reloading of libraries without restarting the kernel.
# With `%autoreload 2`, imported modules are re-imported before each cell runs,
# so edits made to local source files are picked up immediately.
%load_ext autoreload
%autoreload 2

In [2]:
from dataset import LynxDataset
# CSV file passed to LynxDataset as `dataset_csv` in the cells below.
# NOTE(review): absolute path, presumably specific to one cluster's scratch
# filesystem; adjust it when running elsewhere.
dataset_csv = '/gpfsscratch/rech/ads/commun/datasets/extracted/lynx_dataset.csv'

## Accessing the elements of the dataset

In [3]:
# Create an instance of the dataset, using the "pil" loader
dataset = LynxDataset(dataset_csv=dataset_csv, loader="pil")

# Indexing the dataset yields a pair that is unpacked into two objects, both
# indexed by string keys below. They are named `input_dict` / `output_dict`
# (not `input` / `output`) so that the `input` builtin is not shadowed for the
# rest of the kernel session.
input_dict, output_dict = dataset[0]  # Example for getting the first item

# Accessing data
image = input_dict['image']
lynx_id = output_dict['lynx_id']
# Access other metadata from input_dict as needed

## Iterating through the dataset

In [4]:
from tqdm import tqdm

# Failures seen during the scan, keyed by dataset index. Keeping them in a
# variable (instead of only printing them) lets later cells inspect or filter
# out the samples that could not be loaded. Only a text description of the
# exception is stored so that no traceback objects are kept alive.
load_errors = {}

# Using tqdm to add a progress bar to the loop
for idx in tqdm(range(len(dataset)), desc="Processing dataset"):
    try:
        # Attempt to load the image data and other information
        input_dict, output_dict = dataset[idx]

        # Access key elements to ensure they're loaded correctly
        _ = input_dict['image']
        _ = output_dict['lynx_id']
        # Add checks for other elements if necessary

    except Exception as e:
        # Deliberate best-effort scan: report the failure, remember it, and
        # carry on with the next index rather than aborting the whole pass.
        print(f"Error at index {idx}: {e}")
        load_errors[idx] = f"{type(e).__name__}: {e}"

print("Dataset iteration completed with checks.")

Processing dataset:  15%|█▌        | 503/3330 [00:32<04:30, 10.45it/s]

Error at index 502: image file is truncated


Processing dataset:  61%|██████    | 2033/3330 [02:04<00:51, 24.99it/s]

Error at index 2032: image file is truncated


Processing dataset: 100%|██████████| 3330/3330 [03:09<00:00, 17.60it/s]

Dataset iteration completed with checks.





## Comparing image loading speed: PIL vs OpenCV

In [16]:
import time
# Two views of the same CSV that differ only in the loader used:
# one with loader='pil' ...
dataset_pil = LynxDataset(loader='pil', dataset_csv=dataset_csv)
# ... and one with loader='opencv'.
dataset_opencv = LynxDataset(loader='opencv', dataset_csv=dataset_csv)

def measure_performance(dataset, num_samples=100):
    """Time how long it takes to fetch the first ``num_samples`` items.

    Items are accessed sequentially as ``dataset[0]`` ... ``dataset[num_samples - 1]``
    and the fetched values are discarded.

    Args:
        dataset: Any object supporting integer indexing (``dataset[i]``).
        num_samples: Number of leading items to fetch. The caller must make
            sure the dataset holds at least that many items; otherwise the
            indexing error raised by the dataset propagates.

    Returns:
        Elapsed time in seconds, as a float.
    """
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() is wall-clock time and can jump if the
    # system clock is adjusted while the benchmark runs.
    start_time = time.perf_counter()
    for i in range(num_samples):
        _ = dataset[i]
    return time.perf_counter() - start_time

# Measure performance
# The order of the three calls below matters: an untimed warm-up pass runs
# first, then the two timed passes.

_ = measure_performance(dataset_pil)  # warm-up pass, result discarded: for fairness, so that neither timed pass is the first to touch the samples (presumably avoids a cache difference; TODO confirm)
pil_time = measure_performance(dataset_pil)
opencv_time = measure_performance(dataset_opencv)

# Report the raw elapsed times returned by measure_performance for each loader.
print(f"Time taken with PIL: {pil_time} seconds")
print(f"Time taken with OpenCV: {opencv_time} seconds")

Time taken with PIL: 7.331503629684448 seconds
Time taken with OpenCV: 9.016743421554565 seconds
