In [None]:
#default_exp crafter

# Crafter

Takes a list of image filenames and transforms them into batches of the correct dimensions for CLIP. We need to work around the idiosyncrasies of torchvision's loader here: it currently only loads images from subfolders, but this needs to work when pointed at a single folder of images, at a folder tree recursively, or at an arbitrary list of files.

Then, too, it would be nice to eventually put this work on the client computer using TorchScript or something similar, so that it only sends 224x224x3 images over the wire. We also only have to compute those once per image, since we're storing a database of finished vectors, which should be even smaller.


In [None]:
#export
import torch
from torchvision.datasets import VisionDataset
from PIL import Image

In [None]:
#export
def make_dataset(new_files):
    '''Build index-labelled sample and slug lists from (path, slug) pairs.

    Every entry of `new_files` is unpacked as (path, slug). The integer
    "class" attached to an entry is simply its position in `new_files`.

    Returns a tuple (samples, slugs), where `samples` is a list of
    (str(path), index) and `slugs` is a list of (slug, index).'''
    # Unpack each entry exactly once, so one-shot iterables are handled too.
    indexed = [(idx, path, slug) for idx, (path, slug) in enumerate(new_files)]
    samples = [(str(path), idx) for idx, path, _ in indexed]
    slugs = [(slug, idx) for idx, _, slug in indexed]
    return samples, slugs

In [None]:
#export
def pil_loader(path: str) -> Image.Image:
    '''Load the image stored at `path` and return it converted to RGB mode.'''
    # Hand Pillow an explicit file handle rather than the path, to avoid the
    # ResourceWarning described in https://github.com/python-pillow/Pillow/issues/835
    with open(path, 'rb') as handle:
        # The conversion happens while the handle is still open.
        return Image.open(handle).convert('RGB')

In [None]:
#export
class DatasetImagePaths(VisionDataset):
    '''Dataset over an explicit list of image files.

    `new_files` holds (path, slug) pairs. Each item yielded by the dataset is
    (image, index), where `index` is the position of the file in `new_files`
    and `image` is the loaded picture, passed through `transforms` when one
    was supplied.

    Note that `transforms` is called with the image only (not with the
    target), as `__getitem__` below shows.
    '''
    def __init__(self, new_files, transforms = None):
        super().__init__(new_files, transforms=transforms)
        # samples: [(path_str, index)], slugs: [(slug, index)]
        self.samples, self.slugs = make_dataset(new_files)
        self.loader = pil_loader
        # Replace the root recorded by the base class with a descriptive label.
        self.root = 'file dataset'

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, target = self.samples[index]
        image = self.loader(path)
        if self.transforms is None:
            return image, target
        return self.transforms(image), target

In [None]:
# Build an untransformed dataset so items come back as plain PIL images.
# NOTE(review): `new_files` is not defined in any cell visible here; presumably
# it is a list of (path, slug) pairs from an earlier step. Define it before running.
crafted = DatasetImagePaths(new_files)

In [None]:
# Smoke test: open the first image in the system viewer; skipped when the dataset is empty.
if len(crafted) > 0:
    crafted[0][0].show()

Okay, that seems to work decently. Test with transforms, which I will just find in CLIP source code and copy over, to prevent having to import CLIP in this executor.

In [None]:
#export
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

In [None]:
#export
def clip_transform(n_px):
    '''Return the CLIP image preprocessing pipeline for `n_px` x `n_px` inputs.

    The pipeline resizes with bicubic interpolation, center-crops to
    `n_px`, converts the image to a tensor and normalises each channel.

    Args:
        n_px: target side length in pixels, used for both the resize and the
            crop.

    Returns:
        A torchvision `Compose` that is applied to a single PIL image.
    '''
    # Newer torchvision releases expect an InterpolationMode enum and deprecate
    # PIL's integer constants; older releases do not provide the enum at all,
    # so fall back to the PIL constant there.
    try:
        from torchvision.transforms import InterpolationMode
        bicubic = InterpolationMode.BICUBIC
    except ImportError:
        bicubic = Image.BICUBIC

    return Compose([
        Resize(n_px, interpolation=bicubic),
        CenterCrop(n_px),
        ToTensor(),
        # Per-channel normalisation constants copied from the CLIP source code.
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

In [None]:
# Same dataset as before, but with the CLIP preprocessing applied at 224 px.
crafted_transformed = DatasetImagePaths(new_files, clip_transform(224))

Put that all together, and wrap it in a DataLoader for batching. In the future, we need to figure out how to pick the batch size and number of workers programmatically, based on device capabilities.

In [None]:
#export
def crafter(new_files, device, batch_size=128, num_workers=4):
    '''Wrap `new_files` in a DataLoader that yields CLIP-ready batches.

    Args:
        new_files: (path, slug) pairs describing the images to load.
        device: currently unused; kept so existing callers keep working.
        batch_size: number of images per batch.
        num_workers: number of DataLoader worker processes.

    Returns:
        A `torch.utils.data.DataLoader` over a `DatasetImagePaths` built with
        `clip_transform(224)`, iterated in the order of `new_files`
        (`shuffle=False`).
    '''
    # No `torch.no_grad()` here: building the dataset and the loader performs
    # no tensor operations, so the context manager had no effect. Callers that
    # run a model on the batches should apply `no_grad` around that code.
    dataset = DatasetImagePaths(new_files, clip_transform(224))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

In [None]:
# Build the batching loader with the default batch size and worker count.
# NOTE(review): `new_files` and `device` are not defined in any cell visible
# here; they must exist in the kernel before this cell runs.
crafted_files = crafter(new_files, device)

In [None]:
# Display the loader's repr to confirm it was constructed.
crafted_files

<torch.utils.data.dataloader.DataLoader at 0x7fc5ac6822d0>