In [None]:
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision import datasets
from tqdm import tqdm

# Presumably provides RIVAL10_constants and LocalRIVAL10 used below. Kept last so
# any names it exports take precedence, exactly as in the original ordering.
from rival10 import *
# Root directory of the local RIVAL10 copy. Point the RIVAL10_DIR environment
# variable at your own copy instead of editing this line; the fallback is the
# original author's path so existing setups keep working unchanged.
RIVAL10_DIR = os.environ.get("RIVAL10_DIR", "/home/ksas/Public/datasets/RIVAL10/")
RIVAL10_constants.set_rival10_dir(RIVAL10_DIR)

def load_local_rival10(batch_size: int = 16,
                       num_workers: int = 4,
                       preprocess: Optional[transforms.Compose] = None,
                       shuffle_train: bool = False):
    """Build the local RIVAL10 train/test datasets and their DataLoaders.

    Both splits are constructed with identical ``LocalRIVAL10`` options (only
    the ``train`` flag differs) and request the ``"img"`` and
    ``"og_class_label"`` fields via ``cherrypick_list``.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Number of DataLoader worker processes for both loaders.
        preprocess: Passed to ``LocalRIVAL10`` as ``transform`` for both
            splits (may be ``None``).
        shuffle_train: Whether the training loader reshuffles each epoch.
            Defaults to ``False`` (the previous behaviour: fixed order, useful
            for inspection); set ``True`` when actually training.

    Returns:
        ``(trainset, testset, train_loader, test_loader, class_to_idx,
        idx_to_class)``. ``class_to_idx`` maps each name in
        ``RIVAL10_constants._ALL_CLASSNAMES`` to its position in that
        sequence, and ``idx_to_class`` is the inverse mapping.
    """

    def _make_split(train: bool):
        # Single place for the dataset options so the two splits cannot drift apart.
        return LocalRIVAL10(train=train,
                            cherrypick_list=["img", "og_class_label"],
                            masks_dict=True,
                            transform=preprocess,
                            verbose="key")

    trainset = _make_split(train=True)
    testset = _make_split(train=False)

    # Class indices follow the order of RIVAL10_constants._ALL_CLASSNAMES.
    class_to_idx = {c: i for i, c in enumerate(RIVAL10_constants._ALL_CLASSNAMES)}
    idx_to_class = {i: c for c, i in class_to_idx.items()}

    train_loader = DataLoader(trainset, batch_size=batch_size,
                              shuffle=shuffle_train, num_workers=num_workers)
    # The test loader is never shuffled so evaluation order is reproducible.
    test_loader = DataLoader(testset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)

    return trainset, testset, train_loader, test_loader, class_to_idx, idx_to_class

In [None]:
# Preprocessing constants. The mean/std values match OpenAI CLIP's published
# image normalization statistics.
IMAGE_SIZE = 224
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

transform = transforms.Compose([
    transforms.ToTensor(),
    # Resize with an int only scales the *shorter* edge to IMAGE_SIZE, so
    # non-square images would come out with different shapes and the default
    # DataLoader collate could not stack them. CenterCrop makes every image
    # exactly IMAGE_SIZE x IMAGE_SIZE (a no-op for square inputs).
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.Normalize(CLIP_MEAN, CLIP_STD),
])

trainset, testset, train_loader, test_loader, class_to_idx, idx_to_class = load_local_rival10(preprocess=transform)

In [None]:
# Sanity-check the pipeline: pull one batch from the training loader and
# report the tensor shape of the images and of the original class labels.
img, og_class_label = next(iter(train_loader))
print("\n".join(
    f"{field}: {value.size()}"
    for field, value in (("img", img), ("og_class_label", og_class_label))
))