Accompanying Blog Post: https://barrymoo.dev/blog/simple-scaleable-preprocessing-with-pytorch-and-ray-0

In [1]:
from math import ceil

def split_word(word, num_chars=2, overlap=1):
    """Split ``word`` into overlapping fixed-length character segments.

    Segments are ``num_chars`` characters long. Consecutive segments share
    ``overlap`` characters, so each segment starts ``num_chars - overlap``
    characters after the previous one. The final segment is anchored to the
    end of the word so no trailing characters are dropped. As a result it may
    share more than ``overlap`` characters with its predecessor.

    Args:
        word: The string to split.
        num_chars: Length of each segment; must be >= 2 and > ``overlap``.
        overlap: Number of characters shared by consecutive segments;
            must be >= 1.

    Returns:
        A list of segments. If ``word`` is not longer than ``num_chars``
        (including the empty string), it is returned unchanged as ``[word]``.

    Raises:
        AssertionError: If the ``num_chars`` / ``overlap`` constraints are
            violated. These checks are skipped under ``python -O``.
    """
    # Validate up front: a stride (num_chars - overlap) <= 0 would make the
    # segment count below divide by zero or go negative.
    assert num_chars > overlap, (
        "The number of characters should be greater than the overlap, "
        f"got num_chars={num_chars}, overlap={overlap}"
    )
    assert num_chars >= 2, f"Number of characters should be greater than or equal to 2, got {num_chars}"
    assert overlap >= 1, f"Overlap should be greater than or equal to 1, got {overlap}"

    word_length = len(word)
    if word_length <= num_chars:
        return [word]

    step = num_chars - overlap
    num_segments = ceil((word_length - overlap) / step)

    # All segments but the last start at multiples of `step`.
    segments = [word[i * step : i * step + num_chars] for i in range(num_segments - 1)]
    # The last segment is pinned to the end of the word so it is always full
    # length and covers the final character.
    segments.append(word[word_length - num_chars:])
    return segments

In [2]:
# Sample word used by the split_word demos in the following cells.
a = 'hello'

In [3]:
# Character bigrams: consecutive segments share one character.
split_word(word=a, num_chars=2, overlap=1)

['he', 'el', 'll', 'lo']

In [4]:
# Character trigrams: consecutive segments share two characters.
split_word(word=a, num_chars=3, overlap=2)

['hel', 'ell', 'llo']

In [5]:
from torch.utils.data import Dataset

class WordSplitter(Dataset):
    """Map-style dataset that reads a word from a file and splits it.

    Item ``idx`` reads the file at ``inputs[idx]`` as text and strips
    surrounding whitespace. The result is treated as a single word. The item
    returned is ``split_word(word, num_chars, overlap)``, a list of
    overlapping character segments.

    Args:
        inputs: Sequence of paths to text files.
        num_chars: Segment length forwarded to ``split_word``.
        overlap: Segment overlap forwarded to ``split_word``.
        encoding: Text encoding used to read the input files. Defaults to
            UTF-8 so results do not depend on the platform's locale.
    """

    def __init__(self, inputs, num_chars=2, overlap=1, encoding="utf-8"):
        self.inputs = inputs
        self.num_chars = num_chars
        self.overlap = overlap
        self.encoding = encoding

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        filename = self.inputs[idx]

        with open(filename, "r", encoding=self.encoding) as f:
            word = f.read().strip()

        return split_word(
            word,
            num_chars=self.num_chars,
            overlap=self.overlap
        )

    @staticmethod
    def collate_fn(batch):
        """Return a DataLoader batch (a list of per-item results) unchanged.

        This bypasses the DataLoader's default collation, which would treat
        each variable-length segment list as a sequence to transpose/stack.
        Passing the batch through keeps one segment list per word.
        """
        return batch

In [6]:
import csv

# The csv module requires files opened with newline="" so that quoted fields
# containing line breaks are parsed correctly. Each row's `input` column names
# a file that is looked up under inputs/.
with open("dataset.csv", "r", newline="", encoding="utf-8") as csv_file:
    input_files = [f"inputs/{row['input']}" for row in csv.DictReader(csv_file)]
input_files

['inputs/a.txt', 'inputs/b.txt', 'inputs/c.txt', 'inputs/d.txt']

In [7]:
# Trigram segments that advance one character at a time (3 chars, 2 shared).
word_splitter = WordSplitter(inputs=input_files, num_chars=3, overlap=2)

In [8]:
# Sanity check: the first input file should hold the word "simple", which
# splits into four trigrams that each advance by one character.
assert ['sim', 'imp', 'mpl', 'ple'] == word_splitter[0]

In [9]:
from torch.utils.data import DataLoader

import os

# One worker per item does not scale: a large dataset would spawn one process
# per file. Cap the pool at the CPU count instead. Zero workers (empty
# dataset) means loading happens in the main process.
num_workers = min(len(word_splitter), os.cpu_count() or 1)

loader = DataLoader(
    word_splitter,
    batch_size=1,
    shuffle=False,
    num_workers=num_workers,
    # Keep each batch as a plain list of per-word segment lists.
    collate_fn=WordSplitter.collate_fn,
)

In [10]:
# Flatten the batches and print each word's segment list on its own line.
for segments in (item for batch in loader for item in batch):
    print(segments)

['sim', 'imp', 'mpl', 'ple']
['sca', 'cal', 'ale', 'lea', 'eab', 'abl', 'ble']
['pre', 'rep', 'epr', 'pro', 'roc', 'oce', 'ces', 'ess', 'ssi', 'sin', 'ing']
['pyt', 'yto', 'tor', 'orc', 'rch']
