In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset

In [None]:
class SimpleDataset(Dataset):
    """Map-style dataset serving (features, target) tensor pairs taken from a DataFrame."""

    def __init__(self, dataframe: pd.DataFrame, target_column: str):
        """
        Splits the given DataFrame into a feature matrix and a target vector.

        Args:
            dataframe (pd.DataFrame): Table holding the feature columns and the target column.
            target_column (str): Name of the column used as the target; all remaining
                columns are treated as features.
        """
        self.dataframe = dataframe
        self.target_column = target_column

        # Materialise both parts as numpy arrays once so item access is plain indexing.
        feature_frame = dataframe.drop(columns=[target_column])
        self.features = feature_frame.values
        self.targets = dataframe[target_column].values

    def __len__(self):
        """Returns the number of rows in the wrapped DataFrame."""
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        """
        Returns the sample stored at position ``idx``.

        Args:
            idx: Positional index into the feature and target arrays.

        Returns:
            A ``(features, target)`` tuple: features as a float tensor, target as a long tensor.
        """
        row = self.features[idx]
        label = self.targets[idx]

        # The target is cast to long; adjust the dtype here for other kinds of targets.
        return (
            torch.tensor(row, dtype=torch.float),
            torch.tensor(label, dtype=torch.long),
        )

In [None]:
from torchvision import datasets, transforms
from tqdm.auto import tqdm

import numpy as np
import random
import pickle


class SimpleTripletDataset(Dataset):
    """
    A dataset class for generating triplets for training models using the triplet loss on MNIST data.

    Attributes:
        is_train (bool): Indicates if the dataset is for training. If True, triplets are generated.
        describe (str): A description used in progress reporting.
        labels (np.ndarray): The array of labels taken from ``dataset.targets``.
        values (np.ndarray): The array of image data taken from ``dataset.data``.
        transform (transforms.ToTensor): A transform for converting images to tensors.
        triplets (List[Tuple[int, int, int]]): Generated (anchor, positive, negative) index
            triplets; empty when the dataset is not in training mode.
    """

    def __init__(self, dataset: "datasets.MNIST", describe: str, train: bool = True) -> None:
        """
        Initializes the SimpleTripletDataset.

        Args:
            dataset (datasets.MNIST): The MNIST dataset from which labels and images are extracted.
            describe (str): Description string used for progress reporting.
            train (bool): Indicates whether the dataset is used for training. Defaults to True.
        """
        self.is_train: bool = train
        self.describe: str = describe

        self.labels: np.ndarray = dataset.targets.numpy()
        self.values: np.ndarray = dataset.data.numpy()

        self.transform: transforms.ToTensor = transforms.ToTensor()

        # Triplets are only needed (and only generated) in training mode.
        self.triplets: list[tuple[int, int, int]] = self.generate_triplets() if self.is_train else []

    def generate_triplets(self) -> list[tuple[int, int, int]]:
        """
        Generates triplets for training, each triplet includes an anchor, a positive, and a negative image index.

        Every sample is used as an anchor once. The positive is drawn at random from the other
        samples with the same label, the negative from the samples with a different label.
        Anchors that have no positive or no negative candidate are skipped, so the result
        may be shorter than the number of samples.

        Returns:
            A list of tuples, where each tuple contains three indices: (anchor_idx, positive_idx, negative_idx).
        """
        unique_labels = np.unique(self.labels)
        same_label_pool: dict[int, np.ndarray] = {
            label: np.where(self.labels == label)[0] for label in unique_labels
        }
        # The negative pool depends only on the label, so build it once per label
        # instead of rescanning the whole label array for every anchor.
        other_label_pool: dict[int, np.ndarray] = {
            label: np.where(self.labels != label)[0] for label in unique_labels
        }

        triplets: list[tuple[int, int, int]] = []

        for anchor_idx in tqdm(range(len(self.labels)), desc=self.describe):
            anchor_label = self.labels[anchor_idx]

            # An anchor must never be its own positive.
            positive_indices: np.ndarray = same_label_pool[anchor_label]
            positive_indices = positive_indices[positive_indices != anchor_idx]
            negative_indices: np.ndarray = other_label_pool[anchor_label]

            if positive_indices.size > 0 and negative_indices.size > 0:
                # Cast numpy scalars to plain ints so the stored triplets match their annotation.
                positive_idx = int(random.choice(positive_indices))
                negative_idx = int(random.choice(negative_indices))
                triplets.append((anchor_idx, positive_idx, negative_idx))

        return triplets

    def __len__(self) -> int:
        """
        Returns the number of retrievable items.

        In training mode this is the number of generated triplets, because ``__getitem__``
        indexes ``self.triplets`` and anchors without a valid positive/negative are skipped.
        Otherwise it is the number of images.
        """
        if self.is_train:
            return len(self.triplets)
        return len(self.values)

    def __getitem__(self, idx: int):
        """
        Retrieves an item from the dataset. For training, returns a triplet and the label;
        for non-training mode, returns the image tensor and its label.

        Args:
            idx (int): Index of the sample to retrieve. In training mode it indexes the
                triplet list, otherwise the image array.

        Returns:
            For training (is_train=True):
                A tuple containing the transformed anchor, positive, negative images and the anchor label.
            Else:
                A tuple containing the transformed image and its label.
        """
        if not self.is_train:
            return self.transform(self.values[idx]), self.labels[idx]

        anchor_idx, positive_idx, negative_idx = self.triplets[idx]
        return (
            self.transform(self.values[anchor_idx]),
            self.transform(self.values[positive_idx]),
            self.transform(self.values[negative_idx]),
            self.labels[anchor_idx],
        )

    def __getstate__(self):
        """Returns the instance state for pickling, leaving out the transform object."""
        state = self.__dict__.copy()
        # The transform is recreated in __setstate__, so it is not stored.
        state.pop('transform', None)
        return state

    def __setstate__(self, state):
        """Restores the pickled state and recreates the transform dropped by __getstate__."""
        self.__dict__.update(state)
        self.transform = transforms.ToTensor()

    def save(self, file_path: str) -> None:
        """
        Saves the current dataset object to a file using pickle.

        Args:
            file_path (str): The file path where the dataset object will be saved.
        """
        with open(file_path, 'wb') as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, file_path: str) -> "SimpleTripletDataset":
        """
        Loads a dataset object from a file using pickle.

        Only load files you trust: unpickling can execute arbitrary code.

        Args:
            file_path (str): The file path from where the dataset object will be loaded.

        Returns:
            SimpleTripletDataset: The loaded dataset object.
        """
        # SECURITY: pickle.load runs code embedded in the file; never use on untrusted input.
        with open(file_path, 'rb') as f:
            return pickle.load(f)
        