In [2]:
import PIL
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

In [21]:
import os
from pathlib import Path
import random

In [7]:
# Dataset root, relative to the notebook's working directory; the dataset
# class below reads its `train` and `val` sub-folders.
ROOT = Path('./data/imagenette2-160/')

In [58]:
def default_tfms(size):
    """Build the default preprocessing pipeline.

    Args:
        size: target edge length; images are resized to ``(size, size)``.

    Returns:
        A ``transforms.Compose`` that resizes, converts to a tensor and then
        normalizes each of the three channels.
    """
    # Per-channel statistics; these are the values commonly used for
    # ImageNet-trained models.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

In [65]:
class ImageNette(Dataset):
    """Map-style dataset over an Imagenette-style folder tree.

    Expects ``ROOT/train/<class>/<image>`` and ``ROOT/val/<class>/<image>``.
    Each item is ``(image, label)``: the transformed image and a 0-dim
    ``torch.long`` tensor holding the class index.

    Args:
        ROOT: dataset root (``pathlib.Path`` or ``str``) containing the
            ``train`` and ``val`` folders.
        train: read the ``train`` split when truthy, otherwise ``val``.
        shuffle: shuffle the file list once at construction time (uses the
            global ``random`` state, so seed it for reproducibility).
        tfms: callable applied to the PIL image; ``default_tfms(size=128)``
            is used when ``None``.

    Attributes:
        n2c: class-folder name -> class index.
        c2n: class index -> class-folder name.
        data: list of image paths (``str``).
    """

    def __init__(self, ROOT, train=True, shuffle=True, tfms=None):
        self.tfms = default_tfms(size=128) if tfms is None else tfms
        self.ROOT = ROOT
        self.path = Path(ROOT) / ('train' if train else 'val')

        # os.listdir returns entries in arbitrary order, so sort to get class
        # indices that are reproducible and identical for train and val.
        # Only directories count as classes (skips stray files).
        classes = sorted(
            d for d in os.listdir(self.path)
            if os.path.isdir(os.path.join(self.path, d)))
        self.n2c = {name: idx for idx, name in enumerate(classes)}
        self.c2n = {idx: name for name, idx in self.n2c.items()}

        data = []
        for c in classes:
            p2fol = os.path.join(self.path, c)
            # Sorted so the order is deterministic when shuffle=False.
            for f in sorted(os.listdir(p2fol)):
                data.append(os.path.join(p2fol, f))

        self.data = data
        if shuffle: random.shuffle(self.data)

    def __len__(self): return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``."""
        p2im = self.data[idx]
        # Close the file handle promptly, and force 3 channels so grayscale
        # or CMYK files still match the 3-channel normalization.
        with PIL.Image.open(p2im) as f:
            im = f.convert('RGB')
        if self.tfms: im = self.tfms(im)
        y = torch.tensor(self.get_cls(p2im), dtype=torch.long)
        return im, y

    def get_cls(self, p2im):
        """Return the class index for an image path.

        The class name is the image's parent folder name, which is
        independent of the OS path separator and of how deep ROOT is.
        Raises ``KeyError`` if that folder is not a known class.
        """
        cname = os.path.basename(os.path.dirname(p2im))
        return self.n2c[cname]

In [70]:
class DataBunch:
    """Bundle the train/validation datasets with their data loaders.

    Args:
        root: dataset root handed to ``ImageNette``.
        bs: batch size used by both loaders.
        tfms: transform passed to both datasets (``None`` lets the dataset
            pick its default).
        num_workers: worker count used by both loaders.

    Attributes:
        train_ds, valid_ds: the two ``ImageNette`` datasets.
        train_dl, valid_dl: loaders over them; only the training loader
            is created with ``shuffle=True``.
    """

    def __init__(self, root, bs=32, tfms=None, num_workers=0):
        # Settings common to both loaders.
        loader_kwargs = dict(batch_size=bs, num_workers=num_workers)

        self.train_ds = ImageNette(root, train=True, tfms=tfms)
        self.valid_ds = ImageNette(root, train=False, tfms=tfms)

        self.train_dl = DataLoader(self.train_ds, shuffle=True, **loader_kwargs)
        self.valid_dl = DataLoader(self.valid_ds, shuffle=False, **loader_kwargs)

In [71]:
# Build both datasets and loaders with the defaults (bs=32, default transforms).
data = DataBunch(ROOT)

In [1]:
from bnet.databunch import DataBunch
from pathlib import Path

In [2]:
# Dataset root, relative to the working directory.
root = Path('./data/imagenette2-160/')
# Uses the DataBunch imported from bnet.databunch, with its default arguments.
data = DataBunch(root)

In [3]:
# Pull a single batch from the training loader as a quick smoke test.
xb, yb = next(iter(data.train_dl))

In [4]:
# Shape of the image batch; the recorded output is [32, 3, 128, 128].
xb.shape

torch.Size([32, 3, 128, 128])

In [5]:
# Shape of the label batch; the recorded output is [32].
yb.shape

torch.Size([32])