In [1]:
import torch
from transformers import CLIPProcessor, CLIPModel
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from torchvision import transforms
from miniai.imports import *
from miniai.datasets import *
from datasets import load_dataset
import multiprocessing as mp

  from .autonotebook import tqdm as notebook_tqdm


### Get Tiny ImageNet Data

In [2]:
# Column names of the dataset: input image and class label.
xl = 'image'
yl = 'label'

# Hugging Face Hub id of the dataset; load_dataset returns a DatasetDict
# holding every available split.
name = "zh-plus/tiny-imagenet"
dsd = load_dataset(name)

In [3]:
# Display the DatasetDict: available splits, their features and row counts.
dsd

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 100000
    })
    valid: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

### Preprocess Data
1. Convert images to tensors
2. Compute CLIP image embeddings for each batch
3. Add noise to the images to build the diffusion inputs and targets

In [4]:
def abar(t):
    """Cosine schedule: map timestep tensor `t` to cos^2(t * pi / 2).

    Yields 1 at t=0 and (numerically) 0 at t=1; `noisify` uses the value as
    the fraction of signal kept at time `t`.
    """
    angle = t * math.pi / 2
    return angle.cos() ** 2
def inv_abar(x):
    """Inverse of the cosine schedule: return acos(sqrt(x)) * 2 / pi.

    For `x` in [0, 1] this recovers the `t` in [0, 1] with cos^2(t*pi/2) == x.
    """
    angle = x.sqrt().acos()
    return angle * 2 / math.pi

def noisify(x0):
    """Noise a batch of clean inputs at random timesteps.

    Args:
        x0: clean batch tensor. The schedule value is reshaped to
            (n, 1, 1, 1), so a 4-D batch is assumed (presumably
            (n, channels, height, width) -- TODO confirm).

    Returns:
        ((xt, t), eps): the noised batch, the per-sample timesteps and the
        Gaussian noise that was mixed in.
    """
    batch_size = len(x0)
    target_device = x0.device

    # One uniform timestep per sample, cast to x0's dtype/device and kept
    # strictly below 1.
    t = torch.rand(batch_size,).to(x0).clamp(0, 0.999)
    noise = torch.randn(x0.shape, device=target_device)

    # Per-sample signal fraction, shaped to broadcast over the trailing dims.
    signal_frac = abar(t).reshape(-1, 1, 1, 1).to(target_device)
    signal_scale = signal_frac.sqrt()
    noise_scale = (1 - signal_frac).sqrt()

    noised = signal_scale * x0 + noise_scale * noise
    return (noised, t.to(target_device)), noise

def PIL_to_tensor(batch):
    """Convert a sequence of images into one stacked tensor.

    Each image is forced to three-channel RGB before the module-level
    `transform` is applied. Without that, an image in another mode (for
    example grayscale) produces a tensor with a different channel count and
    `torch.stack` fails on a mixed batch.

    Args:
        batch: iterable of images accepted by `transform`; objects exposing a
            PIL-style `convert` method are converted to RGB first, anything
            else is passed through unchanged.

    Returns:
        A tensor stacking the transformed images along a new leading dim.
    """
    images = [
        transform(image.convert("RGB") if hasattr(image, "convert") else image)
        for image in batch
    ]
    return torch.stack(images)

def collate_clip(batch):
    """Collate raw images into a noised batch plus CLIP image embeddings.

    Args:
        batch: the list of images the DataLoader gathered for one batch.

    Returns:
        ((xt, t, image_features), eps): the noised images, their timesteps,
        the CLIP embeddings of the clean images, and the noise target.
    """
    batch = PIL_to_tensor(batch)
    with torch.no_grad():
        # `batch` already holds floats in [0, 1] (ToTensor output), so the
        # processor must not apply its default 1/255 rescale a second time;
        # it still resizes, crops and normalises.
        inputs = processor(images=batch, return_tensors="pt", do_rescale=False)
        image_features = model.get_image_features(pixel_values=inputs["pixel_values"])
    # Noise is applied to the original-resolution tensors, not to the
    # CLIP-preprocessed ones.
    (xt, t), eps = noisify(batch)
    return (xt, t, image_features), eps

def dl_ddpm(ds, bs=16, num_workers=8):
    """Build a DataLoader that collates with `collate_clip`.

    Args:
        ds: dataset (or sequence) of images to draw batches from.
        bs: batch size; defaults to 16, the previously hard-coded value.
        num_workers: number of loader worker processes; defaults to 8, the
            previously hard-coded value. Use 0 to collate in the main process.

    Returns:
        A DataLoader yielding ((xt, t, image_features), eps) batches.
    """
    return DataLoader(ds, batch_size=bs, collate_fn=collate_clip, num_workers=num_workers)

# PIL image -> float tensor conversion used by PIL_to_tensor.
transform = transforms.Compose([transforms.ToTensor()])

# CLIP preprocessing pipeline and model, both loaded from the same checkpoint;
# collate_clip uses them to embed every batch.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")

# One loader per split, built over the "image" column only (labels are not
# used here); train comes first, then valid.
dls = DataLoaders(*(dl_ddpm(dsd[split]["image"]) for split in ("train", "valid")))

# Pull a single training batch to inspect what collate_clip produces.
dl = dls.train
b = next(iter(dl))
(xt, t, image_features), eps = b