## This notebook borrows code from fastai, OpenAI and ChatGPT

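The aim is to verify whether Jeremy's `transformi()` produces the same batches as our version, which moves the padding and the -0.5 shift out of `transformi()` and into `collate_ddpm()`.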

In [1]:
import torch
from transformers import CLIPProcessor, CLIPModel

In [2]:
import torchvision.transforms.functional as TF,torch.nn.functional as F
from miniai.imports import *
from miniai.datasets import *
from datasets import load_dataset,load_dataset_builder

In [3]:
from huggingface_hub import hf_hub_download
import json 

In [4]:
class GPUCUDAMissing(Exception):
    """Raised when PyTorch cannot see a CUDA-capable GPU.

    Derives from Exception (not BaseException) so that ordinary `except Exception`
    handlers catch it; BaseException is meant for exits/interrupts such as
    KeyboardInterrupt and SystemExit.
    """


# Report how many GPUs PyTorch can see. When none is available, only print an
# error (do not abort the notebook) -- same observable behavior as before, without
# raising and immediately catching an exception for control flow.
if torch.cuda.is_available():
    print('# of GPUs available = ', torch.cuda.device_count())
else:
    print("ERROR: GPU is missing")

# of GPUs available =  1


### Jeremy's version

In [5]:
import random
import torch

# Seed PyTorch's and Python's global RNGs so this section's random draws are reproducible.
torch.manual_seed(0)
random.seed(0)

In [6]:
# Column names used for inputs and targets throughout the notebook.
xl = 'image'
yl = 'label'

# Tiny ImageNet from the Hugging Face Hub; the notebook later uses its 'train' and 'valid' splits.
name = "zh-plus/tiny-imagenet"
dsd = load_dataset(name)

In [7]:
@inplace
def transformi(b):
    # PIL image -> float tensor in [0, 1]. Multiplying by a (3, 1, 1) ones tensor broadcasts
    # 1-channel (grayscale) tensors to 3 channels and leaves 3-channel tensors unchanged.
    b[xl] = [torch.ones([3,1,1]) * TF.to_tensor(img) for img in b[xl]]

# Batch size, also used by `dl_ddpm` further down.
bs = 1024
# `with_transform` applies `transformi` on the fly whenever examples are accessed.
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=7)

In [8]:
# Pull a single (images, labels) batch from the training loader.
# NOTE(review): creating/iterating a DataLoader likely consumes torch's global RNG
# (sampler shuffling and the iterator's base seed), so later random draws in this
# section no longer start right after `torch.manual_seed(0)`.
dt = dls.train
(xb, yb) = next(iter(dt))

In [9]:
#| export
def abar(t):
    "Cosine schedule: abar(t) = cos(pi*t/2)**2, going from 1 at t=0 down to 0 at t=1."
    return (t*math.pi/2).cos()**2

def inv_abar(x):
    "Inverse of `abar` on [0, 1]: returns t such that abar(t) == x."
    return x.sqrt().acos()*2/math.pi

def noisify(x0):
    """Noise a batch `x0` at random timesteps using the cosine schedule.

    Draws one t ~ U[0, 1) per sample (clamped to at most 0.999) and standard Gaussian
    noise ε with the shape of `x0`, then returns ((xt, t), ε) where
    xt = sqrt(abar(t)) * x0 + sqrt(1 - abar(t)) * ε.

    The first dimension of `x0` is the batch dimension; any number of trailing
    dimensions is supported (4-D image batches behave exactly as before).
    """
    device = x0.device
    n = len(x0)
    # Drawn on the default generator and then cast to x0's dtype/device, exactly as
    # before, so the sequence of random numbers is unchanged.
    t = torch.rand(n,).to(x0).clamp(0,0.999)
    ε = torch.randn(x0.shape, device=device)
    # Shape (n, 1, ..., 1) so abar_t broadcasts over every non-batch dim of x0
    # (previously hard-coded to 4-D). `t` already lives on x0's device.
    abar_t = abar(t).reshape(-1, *([1] * (x0.ndim - 1)))
    xt = abar_t.sqrt()*x0 + (1-abar_t).sqrt()*ε
    return (xt, t), ε

def collate_ddpm(b):
    "Collate a list of dataset items and noisify the collated `xl` column."
    return noisify(default_collate(b)[xl])

Info: to work around a CUDA multiprocessing issue (the root cause is not yet understood here), set `num_workers=0`, as shown in the cell below. Source: https://github.com/pytorch/pytorch/issues/40403#issuecomment-731782611

In [10]:
def dl_ddpm(ds):
    "DataLoader over `ds` batching with `collate_ddpm`, batch size `bs`, and no worker processes."
    return DataLoader(ds, collate_fn=collate_ddpm, batch_size=bs, num_workers=0)

In [11]:
def collate_ddpm(b):
    "Collate raw items, then noisify the collated image tensor stored under `xl`."
    images = default_collate(b)[xl]
    noised, noise = noisify(images)
    xt, t = noised
    return (xt, t), noise

In [12]:
@inplace
def transformi(b):
    # Per image: PIL -> float tensor in [0, 1], zero-pad H and W by 2 on each side, shift by -0.5.
    # NOTE(review): unlike the other transformi variants in this notebook, this one does not
    # broadcast 1-channel images to 3 channels; if the dataset contains grayscale images,
    # default_collate would fail on a batch that mixes them with RGB images.
    images = b[xl]
    b[xl] = [F.pad(TF.to_tensor(img), (2,2,2,2)) - 0.5 for img in images]

In [13]:
tds = dsd.with_transform(transformi)
# One noisifying loader per split, in (train, valid) order.
dls = DataLoaders(*[dl_ddpm(tds[split]) for split in ('train', 'valid')])

dl = dls.train

In [14]:
%%time
(xt,t),eps = b = next(iter(dl))

CPU times: user 650 ms, sys: 314 ms, total: 965 ms
Wall time: 650 ms


In [15]:
# Keep a reference to Jeremy's noised batch. The "Our version" section below rebinds `xt`
# rather than modifying this tensor in place, so no copy is needed.
xt_ori = xt

### Our version

In [16]:
import random
import torch

# Re-seed PyTorch's and Python's global RNGs before building our version.
torch.manual_seed(0)
random.seed(0)

In [17]:
#| export
def abar(t):
    "Cosine schedule: abar(t) = cos(pi*t/2)**2, going from 1 at t=0 down to 0 at t=1."
    return (t*math.pi/2).cos()**2

def inv_abar(x):
    "Inverse of `abar` on [0, 1]: returns t such that abar(t) == x."
    return x.sqrt().acos()*2/math.pi

def noisify(x0):
    """Noise a batch `x0` at random timesteps using the cosine schedule.

    Draws one t ~ U[0, 1) per sample (clamped to at most 0.999) and standard Gaussian
    noise ε with the shape of `x0`, then returns ((xt, t), ε) where
    xt = sqrt(abar(t)) * x0 + sqrt(1 - abar(t)) * ε.

    The first dimension of `x0` is the batch dimension; any number of trailing
    dimensions is supported (4-D image batches behave exactly as before).
    """
    device = x0.device
    n = len(x0)
    # Drawn on the default generator and then cast to x0's dtype/device, exactly as
    # before, so the sequence of random numbers is unchanged.
    t = torch.rand(n,).to(x0).clamp(0,0.999)
    ε = torch.randn(x0.shape, device=device)
    # Shape (n, 1, ..., 1) so abar_t broadcasts over every non-batch dim of x0
    # (previously hard-coded to 4-D). `t` already lives on x0's device.
    abar_t = abar(t).reshape(-1, *([1] * (x0.ndim - 1)))
    xt = abar_t.sqrt()*x0 + (1-abar_t).sqrt()*ε
    return (xt, t), ε

def collate_ddpm(b):
    "Collate a list of dataset items and noisify the collated `xl` column."
    return noisify(default_collate(b)[xl])

Info: to work around a CUDA multiprocessing issue (the root cause is not yet understood here), set `num_workers=0`, as shown in the cell below. Source: https://github.com/pytorch/pytorch/issues/40403#issuecomment-731782611

In [18]:
def dl_ddpm(ds):
    "DataLoader over `ds` batching with `collate_ddpm`, batch size `bs`, and no worker processes."
    return DataLoader(ds, collate_fn=collate_ddpm, batch_size=bs, num_workers=0)

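TODO: Double-check that the per-image `transformi` below is reproduced correctly by the batched `collate_ddpm()` in the next cell. Reasoning so far: with a 4-element pad tuple, `F.pad` pads only the last two dimensions (W, then H). Padding the collated `(N, C, H, W)` batch should therefore match padding each `(C, H, W)` image, and subtracting 0.5 is elementwise in both cases.

```python
def transformi(b): b[xl] = [F.pad( (torch.ones([3,1,1])*TF.to_tensor(o))  , (2,2,2,2))-0.5 for o in b[xl]]
```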

In [19]:
def collate_ddpm(b):
    """Collate raw items, then pad and shift the whole image batch before noisifying it.

    This performs, once per batch, the `F.pad(..., (2,2,2,2)) - 0.5` step that Jeremy's
    `transformi` applies to each image individually.
    """
    images = default_collate(b)[xl]
    # A 4-element pad tuple pads only the last two dims (W, then H) by 2 on each side, so
    # padding the batch matches padding each image. Subtracting 0.5 then moves pixel values
    # from [0, 1] to [-0.5, 0.5]; the zero padding ends up at -0.5.
    images = F.pad(images, (2,2,2,2)) - 0.5
    (xt, t), eps = noisify(images)
    return (xt, t), eps

In [20]:
@inplace
def transformi(b):
    # Only tensor conversion happens here; padding and the -0.5 shift are done per batch
    # in `collate_ddpm`. Multiplying by a (3, 1, 1) ones tensor broadcasts 1-channel
    # images to 3 channels.
    def _to_rgb_tensor(img): return torch.ones([3,1,1]) * TF.to_tensor(img)
    b[xl] = [_to_rgb_tensor(img) for img in b[xl]]

In [21]:
tds = dsd.with_transform(transformi)
# One noisifying loader per split, in (train, valid) order.
dls = DataLoaders(*[dl_ddpm(tds[split]) for split in ('train', 'valid')])

dl = dls.train

In [22]:
%%time
(xt,t),eps = b = next(iter(dl))

CPU times: user 2.02 s, sys: 350 ms, total: 2.37 s
Wall time: 624 ms


In [23]:
# Summary statistics of our noised batch.
xt.mean(), xt.std()

(tensor(-0.0984), tensor(0.7335))

In [24]:
# Summary statistics of Jeremy's noised batch.
# NOTE(review): these are close to, but not identical with, the stats above. Both sections
# seed torch with 0, but Jeremy's section also builds and iterates an extra loader (the
# `next(iter(dt))` cell) before drawing its batch, which presumably advances torch's global
# RNG, so the noise differs. This mismatch says nothing about the transforms themselves.
# For an exact check, re-seed right before each `next(iter(dl))` and compare with
# `torch.allclose(xt, xt_ori)`.
xt_ori.mean(), xt_ori.std()

(tensor(-0.0993), tensor(0.7348))