In [None]:
from pathlib import Path

from gorillatracker.data.nlet import NletDataModule, build_quadlet
from gorillatracker.data.ssl import SSLDataset
from gorillatracker.ssl_pipeline.ssl_config import SSLConfig
from torchvision.transforms import Resize

# Repository root inside the dev container. Every data path below is derived from it, so the
# notebook can be pointed at another checkout by editing this single constant.
PROJECT_ROOT = Path("/workspaces/gorillatracker")

# Directory of cropped images handed to the data module.
DATA_DIR = PROJECT_ROOT / "video_data" / "cropped-images" / "2024-04-18"
# Pickled SSL split file, consumed via SSLConfig(split_path=...).
split_path = (
    PROJECT_ROOT / "data" / "splits" / "SSL" / "SSL-1k-woCXL_1k-100-1k_split_20240716_1037.pkl"
)

# Fail fast with a readable message instead of an opaque error deep inside setup("fit").
assert DATA_DIR.is_dir(), f"DATA_DIR does not exist: {DATA_DIR}"
assert split_path.is_file(), f"split file does not exist: {split_path}"


DATASET_CLS = SSLDataset

# SSL sampling configuration. The intent is to "sample everything": n_samples is set high
# (100k), presumably above the number of samples that survive the confidence / tracking-length /
# size filters below. TODO confirm against the actual dataset size.
CONFIG = SSLConfig(
    tff_selection="movement",
    negative_mining="overlapping",
    n_samples=100_000,
    feature_types=["body_with_face"],
    min_confidence=0.8,
    min_images_per_tracking=50,
    split_path=split_path,
    width_range=(80, None),  # minimum width 80, no upper bound
    height_range=(80, None),  # minimum height 80, no upper bound
    movement_delta=0.05,
)

data_module = NletDataModule(
    data_dir=DATA_DIR,
    dataset_class=DATASET_CLS,
    nlet_builder=build_quadlet,  # quadlets: four images per sample
    batch_size=64,
    workers=10,
    # Every image is resized to 224x224.
    model_transforms=Resize((224, 224)),
    # Identity: no training-time augmentation, so the images below are shown as sampled.
    # NOTE(review): lambdas are not picklable. If DataLoader workers are started with
    # 'spawn'/'forkserver' rather than 'fork', this may fail. Confirm for your platform.
    training_transforms=lambda x: x,
    dataset_names=["Showcase"],
    ssl_config=CONFIG,
)

# Build the datasets for the "fit" stage; train_dataloader() is used in the next cell.
data_module.setup("fit")

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

# Draw one batch from the training loader and show the first sample of every nlet position
# side by side, as a quick visual sanity check of what the SSL sampler produces.
# NOTE: the train loader presumably shuffles, so the images shown may differ between runs.
batch = next(iter(data_module.train_dataloader()), None)
assert batch is not None, "train dataloader yielded no batches"
ids, _, _ = batch  # only the image identifiers (file paths) are used here

# Presumably `ids` has one entry per nlet position, each indexed by sample within the batch;
# take sample 0 of every position.
nlet = [Path(position[0]) for position in ids]

# squeeze=False keeps `axes` 2-D even for a single panel, so iterating axes[0] always works.
fig, axes = plt.subplots(1, len(nlet), figsize=(20, 4), squeeze=False)
for idx, (ax, path) in enumerate(zip(axes[0], nlet)):
    # Close the file handle Pillow opens lazily. imshow converts the PIL image to an array
    # immediately, so closing after drawing is safe.
    with Image.open(path) as img:
        ax.imshow(img)
    ax.set_title(f"[{idx}] {path.name}", fontsize=8)
    ax.axis("off")
fig.tight_layout()
plt.show()