In [2]:
import torch
from model import load_resnet_model
from dataloader import create_dataloader
from train import train_model
from checkpoint import save_checkpoint, load_checkpoint
from tqdm import tqdm

In [3]:
# Build the network through the project's load_resnet_model helper.
# 'resnet50' and num_classes=10 are passed through as-is; presumably this
# yields a ResNet-50 with a 10-way output head — confirm in model.py.
model = load_resnet_model('resnet50', num_classes=10)



In [4]:
# dataloader = create_dataloader('./data/records/train/data_batch_1.pth',32)

In [28]:
import torch
from torch.utils.data import Dataset, DataLoader
import os

class MultiFileDataset(Dataset):
    def __init__(self, directory, file_pattern='data_batch_{}.pth'):
        """
        A dataset that exposes one numbered tensor file at a time.

        The first file (number 1) is loaded on construction. ``len()`` and
        indexing always refer to the file that is currently loaded; call
        ``load_next_batch()`` explicitly to move on to the next file.

        Args:
        - directory (str): Directory containing the data files.
        - file_pattern (str): Pattern of the filenames.
                              The '{}' will be replaced by the batch number.
        """
        self.directory = directory
        self.file_pattern = file_pattern
        self.current_batch = 0
        self.data = None
        self.labels = None
        # Initialised here so the attribute exists even if no file is found.
        self.index = 0
        self.load_next_batch()

    def load_next_batch(self):
        """
        Loads the next batch of data from file.

        Increments ``current_batch`` and reads the matching file. When that
        file does not exist, ``data`` and ``labels`` are set to None, so the
        dataset reports a length of 0.
        """
        self.current_batch += 1
        file_path = os.path.join(self.directory, self.file_pattern.format(self.current_batch))

        if os.path.isfile(file_path):
            # NOTE: torch.load unpickles the file; only use trusted data files.
            batch = torch.load(file_path)
            self.data = batch['data']
            self.labels = batch['labels']
            self.index = 0  # Reset index
        else:
            self.data = None
            self.labels = None

    def __len__(self):
        """Number of samples in the currently loaded file (0 if none)."""
        return len(self.data) if self.data is not None else 0

    def __getitem__(self, idx):
        """
        Return ``(data[idx], labels[idx])`` from the currently loaded file.

        Raises:
        - IndexError: if no file is loaded or ``idx`` is past the end of the
                      current file.
        """
        # An out-of-range lookup must not advance to the next file: the index
        # is relative to the current file, so silently switching files would
        # return the wrong sample (or nothing) and corrupt the dataset state.
        # IndexError is the exception the sequence protocol expects here;
        # StopIteration could silently cut iteration short.
        if self.data is None or idx >= len(self):
            raise IndexError(
                'index {} is out of range for batch file {} with {} samples'.format(
                    idx, self.current_batch, len(self)))
        return self.data[idx], self.labels[idx]

def create_dataloader(directory, batch_size, num_workers=0, shuffle=False):
    """
    Build a DataLoader over a MultiFileDataset rooted at ``directory``.

    Args:
    - directory (str): Directory handed to MultiFileDataset.
    - batch_size (int): Forwarded to DataLoader.
    - num_workers (int): Forwarded to DataLoader.
    - shuffle (bool): Forwarded to DataLoader.
    """
    source = MultiFileDataset(directory)
    loader = DataLoader(
        source,
        num_workers=num_workers,
        shuffle=shuffle,
        batch_size=batch_size,
    )
    return loader


In [29]:
# Training loader over the demo record directory, 32 samples per batch.
dataloader = create_dataloader(directory='./data/records_demo/train/', batch_size=32)

In [33]:
class MultiFileDataset(Dataset):
    def __init__(self, directory, file_pattern='data_batch_{}.pth', total_files=None,
                 samples_per_file=1024):
        """
        A dataset that loads tensors from multiple files in a directory.

        Files are numbered from 1. A global sample index is mapped to
        (file number, position inside that file) assuming every file holds
        exactly ``samples_per_file`` samples.

        Args:
        - directory (str): Directory containing the data files.
        - file_pattern (str): Pattern of the filenames.
                              The '{}' will be replaced by the file number.
        - total_files (int): Total number of batch files. When None, the
                             files are counted on disk.
        - samples_per_file (int): Number of samples stored in each file.

        Raises:
        - ValueError: if ``samples_per_file`` is not positive.
        """
        if samples_per_file <= 0:
            raise ValueError('samples_per_file must be positive, got {}'.format(samples_per_file))
        self.directory = directory
        self.file_pattern = file_pattern
        self.samples_per_file = samples_per_file
        # Compare against None so that an explicit total_files=0 is honoured.
        self.total_files = self._get_total_files() if total_files is None else total_files
        # Most recently loaded file, kept so that consecutive lookups in the
        # same file do not re-read it from disk.
        self._cached_file_idx = None
        self._cached_batch = None

    def _get_total_files(self):
        """Count consecutively numbered files (1, 2, ...) matching file_pattern."""
        count = 0
        while os.path.isfile(os.path.join(self.directory, self.file_pattern.format(count + 1))):
            count += 1
        return count

    def _load_file(self, file_idx):
        """Return the contents of file number ``file_idx``, reusing the cached copy if possible."""
        if file_idx != self._cached_file_idx:
            file_path = os.path.join(self.directory, self.file_pattern.format(file_idx))
            # NOTE: torch.load unpickles the file; only use trusted data files.
            self._cached_batch = torch.load(file_path)
            self._cached_file_idx = file_idx
        return self._cached_batch

    def __len__(self):
        # Assuming each file contains the same number of samples
        # If different, need a more sophisticated method
        return self.samples_per_file * self.total_files

    def __getitem__(self, idx):
        """
        Return ``(data, label)`` for the global sample index ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
        - IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            # Fail with the exception the sequence protocol expects instead
            # of trying to open a file that does not exist.
            raise IndexError('index out of range for dataset of size {}'.format(size))

        file_offset, in_file_idx = divmod(idx, self.samples_per_file)
        batch = self._load_file(file_offset + 1)  # files are numbered from 1

        return batch['data'][in_file_idx], batch['labels'][in_file_idx]

def create_dataloader(directory, batch_size, num_workers=4, shuffle=True):
    """
    Wrap a MultiFileDataset for ``directory`` in a DataLoader.

    Args:
    - directory (str): Directory handed to MultiFileDataset.
    - batch_size (int): Forwarded to DataLoader.
    - num_workers (int): Forwarded to DataLoader.
    - shuffle (bool): Forwarded to DataLoader.

    NOTE(review): with num_workers > 0 the dataset is used from worker
    processes; a dataset class defined only in a notebook cell may not be
    usable there on every platform — confirm, or pass num_workers=0.
    """
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
    }
    return DataLoader(MultiFileDataset(directory), **loader_options)


In [36]:
# Smoke-test the loader in the main process (num_workers=0): walk every batch
# and print a running batch count.
dataloader = create_dataloader('./data/records_demo/train/', batch_size=32, num_workers=0)

count = 0  # stays 0 if the loader yields no batches
for count, (_, _) in enumerate(dataloader, start=1):
    print(count)

RuntimeError: DataLoader worker (pid(s) 55156, 31116, 54316, 59940) exited unexpectedly