# LynxDataset & dataloader tests

## Setup notebook and imports

So far, everything has been tested with PyTorch 2.0.1.

Note: `albumentations` had to be installed separately; it was not present in the base environment.

In [1]:
# Allow reloading of libraries without restarting the kernel:
# `%autoreload 2` re-imports all modules (e.g. local edits to `lynx_id`) before executing code.
%load_ext autoreload
%autoreload 2

In [51]:
from itertools import islice
from pathlib import Path

from lynx_id.data.dataset import LynxDataset
# Star import provides the collate functions (e.g. `collate_single`).
# Keep the explicit imports below it so they take precedence over any same-named export.
from lynx_id.data.collate import *
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

In [52]:
# Location of the extracted lynx dataset CSV on the shared scratch filesystem.
dataset_csv = Path("/gpfsscratch/rech/ads/commun/datasets/extracted/lynx_dataset_full.csv")

## Single mode

In [53]:
# Build the dataset without a `mode` argument (the "single mode" of this section), reading images with PIL.
dataset = LynxDataset(dataset_csv=dataset_csv, loader="pil")

# A sample is a pair of dicts: model inputs/metadata and targets.
input, output = dataset[0]

# Pull out the image and its label; the remaining metadata stays available in `input`.
image, lynx_id = input['image'], output['lynx_id']

In [54]:
# Batched, shuffled loader for single mode.
# `collate_single` is presumably provided by the star import from lynx_id.data.collate.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_single,
    num_workers=10,
    prefetch_factor=2,        # batches loaded in advance by each worker
    persistent_workers=True,  # keep worker processes alive between iterations over the loader
    pin_memory=True,
)

In [29]:
stop_after = 50  # set to len(dataloader) to iterate over the whole dataset
# islice fetches at most `stop_after` batches and lets tqdm finish (and close) normally,
# instead of breaking out mid-iteration and leaving the bar stuck at 49/50.
n_batches = min(stop_after, len(dataloader))
i = 0  # stays 0 if the loader yields nothing
for i, (input, output) in enumerate(tqdm(islice(dataloader, n_batches), total=n_batches, desc="Processing"), start=1):
    pass  # only measuring loading throughput; `input`/`output` keep the last batch for inspection
if i >= stop_after:
    print("Reached stop condition after", i, "iterations.")
        

Processing:  98%|█████████▊| 49/50 [00:20<00:00,  2.37it/s]

Reached stop condition after 50 iterations.





In [55]:
# Show the keys of the current `input` / `output` dicts and the type of the stored image, one per line.
print(input.keys(), output.keys(), type(input["image"]), sep="\n")

dict_keys(['image', 'source', 'pattern', 'date', 'location', 'image_number', 'conf', 'x', 'y', 'width', 'height', 'filepath'])
dict_keys(['lynx_id'])
<class 'numpy.ndarray'>


## Triplet mode

In [58]:
import torchvision.models as models
import torch

# ResNet-50 backbone passed to the triplet-mode LynxDataset below.
# map_location="cpu": the checkpoint may have been saved from a GPU; loading on CPU works on any node,
# and load_state_dict copies the tensors into the (CPU) model anyway.
weights = torch.load("/gpfsscratch/rech/ads/commun/models/resnet50/pretrained_weights.pt", map_location="cpu")
# `pretrained=` is deprecated since torchvision 0.13; `weights=None` builds the same randomly-initialised net.
model = models.resnet50(weights=None)
model.load_state_dict(weights)
# Drop the last child (the final classification layer), keeping the feature extractor.
# NOTE(review): the model is left in train mode; confirm LynxDataset calls .eval() before computing embeddings.
model = torch.nn.Sequential(*(list(model.children())[:-1]))



In [75]:
# Triplet-mode dataset. The ResNet-50 feature extractor above is passed as `model`
# (presumably used to compute embeddings for triplet mining — confirm in LynxDataset).
# NOTE(review): the output saved with this cell shows
#   IndexError: index 6360 is out of bounds for dimension 0 with size 4743
# which suggests the file at `load_triplet_path` may not match `dataset_csv`.
# Note also that it loads from gpfsscratch but saves to gpfswork; verify before relying on it.
dataset = LynxDataset(dataset_csv=dataset_csv,
                      loader="pil",
                      mode='triplet',
                      load_triplet_path="/gpfsscratch/rech/ads/commun/precompute/triplet_precompute.npz",
                      save_triplet_path="/gpfswork/rech/ads/commun/kg_tests/dataloader_tests/triplet_precompute.npz",
                      model=model,
                      device="auto",
                      verbose=True)

# Fetch a single sample. This replaces the previous loop header, which was a SyntaxError:
# it was missing `in`, the colon and a body.
# NOTE(review): assumes triplet mode returns (anchor, positive, negative) — confirm against LynxDataset.
anchor, positive, negative = dataset[0]

IndexError: index 6360 is out of bounds for dimension 0 with size 4743

In [28]:
# Batched, shuffled loader over the triplet-mode dataset, with the same settings as in single mode.
# NOTE(review): this still uses `collate_single`; confirm it handles triplet-mode samples
# (the loop below unpacks each batch as (input, output)).
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_single,
    num_workers=10,
    prefetch_factor=2,
    persistent_workers=True,
    pin_memory=True,
)

In [29]:
stop_after = 50  # set to len(dataloader) to iterate over the whole dataset
show_data_info = False  # print the batch keys on every iteration
# islice fetches at most `stop_after` batches and lets tqdm finish (and close) normally,
# instead of breaking out mid-iteration and leaving the bar stuck at 49/50.
n_batches = min(stop_after, len(dataloader))
i = 0  # stays 0 if the loader yields nothing
for i, (input, output) in enumerate(tqdm(islice(dataloader, n_batches), total=n_batches, desc="Processing"), start=1):
    if show_data_info:
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"{input.keys()}\n{output.keys()}")
if i >= stop_after:
    print("Reached stop condition after", i, "iterations.")

Processing:  98%|█████████▊| 49/50 [00:20<00:00,  2.37it/s]

Reached stop condition after 50 iterations.





In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader

def my_collate(batch):
    """Collate (input_dict, output_dict) samples into dicts of per-field lists.

    Only a subset of the input fields is kept. Each field is gathered into a list
    under a pluralised key (e.g. 'image' -> 'images'). Images are not stacked, so
    samples with differing image sizes are fine.

    Args:
        batch: iterable of (input_dict, output_dict) pairs, as produced by the dataset.

    Returns:
        (batched_input_dict, batched_output_dict), where every value is a list with one
        entry per sample, in batch order.
    """
    # Batched key -> per-sample input key it is gathered from (order preserved).
    field_map = {
        'images': 'image',
        'sources': 'source',
        'patterns': 'pattern',
        'dates': 'date',
        'locations': 'location',
        'image_numbers': 'image_number',
    }
    batched_input_dict = {batched_key: [] for batched_key in field_map}
    collected_ids = []

    for sample_input, sample_output in batch:
        for batched_key, sample_key in field_map.items():
            batched_input_dict[batched_key].append(sample_input[sample_key])
        collected_ids.append(sample_output['lynx_id'])

    return batched_input_dict, {'lynx_ids': collected_ids}


# Loader using the ad-hoc `my_collate` defined in this cell.
# NOTE(review): `dataset` here is whatever was bound last (the triplet-mode dataset in a top-to-bottom run);
# `my_collate` expects (input_dict, output_dict) samples — confirm they match.
dataloader = DataLoader(
    dataset,
    collate_fn=my_collate,
    num_workers=4,
    shuffle=True,
    batch_size=32,
)


In [None]:
# Smoke-test `my_collate` on a single batch. `next(iter(...))` fetches exactly one batch;
# the previous `enumerate` wrapper only prepended a meaningless index 0 to the displayed value.
first_inputs, first_outputs = next(iter(dataloader))
# Show the number of entries per field instead of dumping a whole batch of raw images into the output.
{key: len(values) for key, values in {**first_inputs, **first_outputs}.items()}

In [None]:
# Iterate over the DataLoader
for batch in dataloader:
    inputs, outputs = batch