In [1]:
import torch
from data_utils import dataset_x
from data import Augmentation, SSLImageDataset
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torchvision
from torchvision import transforms

from ffcv.loader import OrderOption
import ffcv
from ffcv_ssl import DivideImageBy255

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
basedataset = "imagenette"
use_ffcv = True
batchsize = 8
numworkers = 10

In [3]:



def build_ffcv_nonsslloader(write_path, imgsize, mean, std, batchsize, numworkers, mode="train"):
    image_pipeline1 = [
        ffcv.fields.rgb_image.SimpleRGBImageDecoder(),
        #ffcv.fields.rgb_image.CenterCropRGBImageDecoder(output_size=imgsize, ratio=160/160),
        ffcv.transforms.ops.ToTensor(),
        #ToDevice(torch.device('cuda:0'), non_blocking=True),
        ffcv.transforms.ops.ToTorchImage(convert_back_int16=False),
        DivideImageBy255(torch.float32),
        torchvision.transforms.Normalize(mean, std)
        #ffcv.transforms.normalize.NormalizeImage(mean=np.array(mean)*255.0, std=np.array(std)*255.0, type=np.float32)
    ]

    label_pipeline = [
        ffcv.fields.basics.IntDecoder(),
        ffcv.transforms.ops.ToTensor(),
        ffcv.transforms.common.Squeeze(1),
        #ToDevice(torch.device('cuda:0'), non_blocking=True),  # not int on gpu
    ]

    loader = ffcv.loader.Loader(
        write_path,
        num_workers=numworkers,
        batch_size=batchsize,
        pipelines={
            "image": image_pipeline1,
            "label": label_pipeline,
        },
        order=OrderOption.RANDOM if mode == 'train' else OrderOption.SEQUENTIAL,
        drop_last=False,
        os_cache=True,
        seed=42
    )
    return loader

In [4]:
train_dataset, test_dataset, num_classes , imgsize, mean, std = dataset_x(basedataset)
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
if not use_ffcv:

    test_augmentation = Augmentation(imgsize, mean, std, mode="test", num_views=1)
    train_dataset.transform = test_dataset.transform = test_augmentation
    trainloader = DataLoader(train_dataset, batch_size=batchsize, num_workers=numworkers,shuffle=False, pin_memory=True, drop_last=False)
    testloader = DataLoader(test_dataset, batch_size=batchsize, num_workers=numworkers, shuffle=False, pin_memory=True, drop_last=False)
else:

    trainloader = build_ffcv_nonsslloader(
        write_path=f"output/{basedataset}/trainds.beton",
        mean=mean,
        std=std,
        imgsize=imgsize,
        batchsize=batchsize,
        numworkers=numworkers,
        mode="test"
    )
    testloader = build_ffcv_nonsslloader(
        write_path=f"output/{basedataset}/testds.beton",
        mean=mean,
        std=std,
        imgsize=imgsize,
        batchsize=batchsize,
        numworkers=numworkers,
        mode="test"
    )

160 160
36 31


TypeError: SimpleRGBImageDecoder ony supports constant image,
consider RandomResizedCropRGBImageDecoder or CenterCropRGBImageDecoder
instead.

In [None]:
plaindataset = torchvision.datasets.Imagenette(
        root="./data",
        split="train",
        transform=transforms.Compose([
            #transforms.Resize((160, 160)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ]),
        download=False,
        size="160px",
    )

vizdataloader = DataLoader(plaindataset, batch_size=batchsize, shuffle=False)


In [None]:
len(next(iter(trainloader)))

In [None]:
imgs_aug1, y = next(iter(trainloader))


In [None]:
imgs_orig, labels = next(iter(vizdataloader))

In [None]:
labels

In [None]:
imgs_orig.shape

In [None]:

label2name = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck"
]

def plot_imgs(imgs_orig, imgs_aug1, labels):
    maximg = min(4, imgs_orig.shape[0])
    imgs_orig, imgs_aug1, labels = imgs_orig[:maximg], imgs_aug1[:maximg], labels[:maximg]
    fig, ax = plt.subplots(maximg, 2, figsize=(25, 25))
    for i, (img_orig, img_aug1) in enumerate(zip(imgs_orig, imgs_aug1)):
        ax[i, 0].imshow(img_orig.permute(1,2,0) * 0.5 + 0.5)
        ax[i, 1].imshow(img_aug1.permute(1,2,0) * 0.5 + 0.5)
        
        ax[i, 0].set_title("Original " + label2name[int(labels[i])], fontsize=16, pad=5)
        ax[i, 1].set_title("Augmented " + label2name[int(labels[i])] + " 1", fontsize=16, pad=5)
        
    plt.tight_layout()
    fig.show()

In [None]:
plot_imgs(imgs_orig, imgs_aug1, labels)

In [None]:
plot_imgs(imgs_orig, imgs_aug1, labels)