In [None]:
# default_exp data

# Data

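> A PyTorch Lightning data module that embeds a Hugging Face dataset with a sentence encoder and attaches to every example its nearest neighbours from the training split, retrieved through a FAISS index.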

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# export
import numpy as np
import pytorch_lightning as pl

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from torch.utils.data.dataloader import DataLoader

In [None]:
# export
class RetroDataset(pl.LightningDataModule):
    """Data module that pairs every example with its nearest training neighbours.

    `setup` loads the `train` and `validation` splits of a Hugging Face dataset,
    embeds `column` with a `SentenceTransformer`, builds a FAISS index over the
    training embeddings and stores, for every train/validation example, the `k`
    closest training texts under the `"retrieved_examples"` key.

    Args:
        dataset_name: Name passed to `datasets.load_dataset`.
        column: Name of the text column that is embedded and retrieved.
        encoder_name: Model name passed to `SentenceTransformer`.
        dataset_config: Optional dataset configuration passed to `load_dataset`.
        batch_size: Batch size used by the data loaders.
        k: Number of neighbours retrieved for every example during `setup`.
        n_perc: Percentage of each split to load (100 loads the full split).
    """

    def __init__(self, dataset_name, column, encoder_name, dataset_config=None, batch_size=32, k=10, n_perc=100):
        # The base class initialises internal state of its own; skipping this
        # call leaves the data module only partially constructed.
        super().__init__()
        self.dataset_name = dataset_name
        self.column = column
        self.encoder_name = encoder_name
        self.dataset_config = dataset_config
        self.batch_size = batch_size
        self.k = k
        self.n_perc = n_perc

    def setup(self, stage=None):
        """Load the splits, embed them and attach the retrieved neighbours.

        After this call the instance exposes `model` (the encoder), `index_ds`
        (training split carrying the FAISS index), and `train_ds` / `valid_ds`
        (splits with the extra `"retrieved_examples"` column).
        """
        self.model = SentenceTransformer(self.encoder_name)
        # The trailing `%` makes the slice a percentage; without it
        # `train[:100]` would select the first 100 rows instead of 100 percent.
        train_ds = load_dataset(self.dataset_name, self.dataset_config, split=f"train[:{self.n_perc}%]")
        valid_ds = load_dataset(self.dataset_name, self.dataset_config, split=f"validation[:{self.n_perc}%]")

        def embed(batch):
            return {"embeddings": self.model.encode(batch[self.column])}

        train_ds = train_ds.map(embed, batched=True)
        train_ds.add_faiss_index(column="embeddings")
        # Validation examples are only ever used as queries against the
        # training index, so no index is built for the validation split.
        valid_ds = valid_ds.map(embed, batched=True)

        # `.map` below returns new datasets that do not carry the FAISS index,
        # so keep a handle on the indexed dataset for later ad-hoc queries.
        self.index_ds = train_ds

        def get_nearest_neighbors(example):
            # Stored embeddings are read back as plain Python lists, while the
            # FAISS search expects a float32 NumPy array.
            query = np.asarray(example["embeddings"], dtype=np.float32)
            # NOTE(review): for training examples the query itself is part of
            # the index, so it is presumably returned as its own closest
            # neighbour - confirm whether that is intended.
            _, retrieved_examples = train_ds.get_nearest_examples("embeddings", query, k=self.k)
            example["retrieved_examples"] = retrieved_examples[self.column]

            return example

        self.train_ds = train_ds.map(get_nearest_neighbors)
        self.valid_ds = valid_ds.map(get_nearest_neighbors)

    def train_dataloader(self):
        """Return a shuffled `DataLoader` over the training split."""
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation `DataLoader` under the hook name Lightning calls.

        Shuffling is disabled so validation batches come in a stable order.
        """
        return DataLoader(self.valid_ds, batch_size=self.batch_size, shuffle=False)

    def valid_dataloader(self):
        """Backward-compatible alias of `val_dataloader`."""
        return self.val_dataloader()

    def get_nearest_neighbors(self, example, k=10):
        """Return the `k` training texts closest to `example`.

        `example` is encoded with the sentence encoder and searched against the
        FAISS index built in `setup`, so `setup` must have been called first.
        Presumably `example` is a single string; verify before passing a batch.
        """
        embed = self.model.encode(example)
        # Query the dataset that actually holds the index; `self.train_ds` is a
        # mapped copy without it.
        _, retrieved_examples = self.index_ds.get_nearest_examples("embeddings", embed, k=k)

        return retrieved_examples[self.column]

In [None]:
# hide
# nbdev build step: convert the notebook cells flagged with `# export` into
# the library's Python module (the target is named by `default_exp` above).
from nbdev.export import notebook2script

# Called without arguments; presumably processes every notebook in the project - confirm.
notebook2script()