Let's use another dataset than mnist to see if everything works the same

In [2]:
import torch
import numpy as np
import random
from torch import nn,tensor
import matplotlib.pyplot as plt
from datasets import load_dataset
from torchmetrics.classification import MulticlassAccuracy 
import torchvision.transforms.functional as TF
from torch.optim.lr_scheduler import OneCycleLR
from torch.nn import init
import fastcore.all as fc
from lib import *
from pathlib import Path
from torch.optim.lr_scheduler import OneCycleLR
from diffusers import UNet2DModel
from accelerate import Accelerator

set_seed(42)
device = "mps" if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
x,y = 'img','label'

@inplace
def transformi(b): 
    b[x] = [F.pad(TF.to_tensor(o), (2,2,2,2)) for o in b[x]]

dsd = load_dataset("cifar10")
tds = dsd.with_transform(transformi)

betamin,betamax,n_steps = 0.0001,0.02,1000
beta = torch.linspace(betamin, betamax, n_steps)
alpha = 1.-beta
alphabar = alpha.cumprod(dim=0)
sigma = beta.sqrt()

Using the latest cached version of the module from /home/marconobile/.cache/huggingface/modules/datasets_modules/datasets/cifar10/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4 (last modified on Mon Dec 11 19:23:00 2023) since it couldn't be found locally at cifar10., or remotely on the Hugging Face Hub.


In [3]:
bs = 512
dls = DataLoaders.from_datasetDict(tds, bs)
dt = dls.train
xb,yb = next(iter(dt))

In [4]:
def show_img(im, ax=None, figsize=None, title=None, **kwargs):
    if ax is None:
        _, ax = plt.subplots(figure=figsize)
    if im.shape[0] == 3:
        ax.imshow(im.permute(1, 2, 0), **kwargs)
    else:
        ax.imshow(im, **kwargs)
    if title:
        ax.set_title(title)
    ax.axis('off')
    return ax

def showImgGroup(data, grid=(3,3), **kwargs):
    fig, axs = plt.subplots(grid[0], grid[1])
    imgs = data[: (grid[0]* grid[1])]
    for ax, img in zip(axs.flat, imgs):
        show_img(img.squeeze(), ax)

class AccelerateCB(TrainCB):
    order = DeviceCB.order+10
    def __init__(self, n_inp=1, mixed_precision="fp16"):
        super().__init__(n_inp=n_inp)
        self.acc = Accelerator(mixed_precision=mixed_precision)

    def before_fit(self, learn):
        learn.model,learn.opt,learn.dls.train,learn.dls.valid = self.acc.prepare(
            learn.model, learn.opt, learn.dls.train, learn.dls.valid)

    def backward(self, learn): 
        sealf.acc.backward(learn.loss)


def noisify(x0, ᾱ):
    device = x0.device
    n = len(x0)
    t = torch.randint(0, n_steps, (n,), dtype=torch.long)
    ε = torch.randn(x0.shape, device=device)
    ᾱ_t = ᾱ[t].reshape(-1, 1, 1, 1).to(device)
    xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε
    return xt, t.to(device), ε

def collate_ddpm(batch):
    '''
    input: [(xi,yi), (xj,yj), ...]
    output: (tensor(X_batch_with noise), tensor(noise))
    '''
    return noisify(default_collate(batch)[x], alphabar)

def dl_ddpm(ds): 
    return DataLoader(ds, batch_size=bs, collate_fn=collate_ddpm, num_workers=4)

class DDPMCB2(Callback):
    def after_predict(self, learn): 
        learn.preds = learn.preds.sample

def init_ddpm(model):
    for o in model.down_blocks:
        for p in o.resnets:
            p.conv2.weight.data.zero_()
            for p in fc.L(o.downsamplers): init.orthogonal_(p.conv.weight)

    for o in model.up_blocks:
        for p in o.resnets: p.conv2.weight.data.zero_()

    model.conv_out.weight.data.zero_()

class UNet(UNet2DModel):
    def forward(self, x): 
        return super().forward(*x).sample

In [5]:
# l = noisify(xb[:10], alphabar)
# noisy_data = [l[0][i] for i in range(10)]
# titles = [f'{i.data}' for i in  l[1]]
# print(titles)
# showImgGroup(noisy_data, grid=(3,3), titles=titles)

In [10]:
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

lr = 1e-2
epochs = 8
tmax = epochs * len(dls.train)
scheduler = partial(torch.optim.lr_scheduler.OneCycleLR, max_lr = lr, total_steps=tmax)
opt_func = partial(optim.Adam, eps=1e-5)

model = UNet(in_channels=3, out_channels=3, block_out_channels=(32, 64, 128, 256), norm_num_groups=8)
init_ddpm(model)

cbs = [DeviceCB(), ProgressCB(plot=True), MetricCB(), BatchSchedCB(scheduler)] 

#AccelerateCB(n_inp=2) # HERE U DEFINE THE NUMBER OF INPUTS OF THE MODEL
#DDPMCB2()
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)

In [11]:
learn.fit(epochs)

AttributeError: loss

In [None]:
mdl_path = Path('models')
mdl_path.mkdir(exist_ok=True)
#learn.model = torch.load(mdl_path/'fashion_mnist_ddpmMineCPU.pkl')

In [None]:
@torch.no_grad()
def sample(model, sz):
    ps = next(model.parameters())
    x_t = torch.randn(sz).to(ps)
    preds = []
    for t in reversed(range(n_steps)):
        t_batch = torch.full((x_t.shape[0],), t, device=ps.device, dtype=torch.long)
        z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(ps)
        ᾱ_t1 = alphabar[t-1]  if t > 0 else torch.tensor(1)
        b̄_t = 1-alphabar[t]
        b̄_t1 = 1-ᾱ_t1
        noise = model((x_t, t_batch))
        x_0_hat = ((x_t - b̄_t.sqrt() * noise)/alphabar[t].sqrt())
        x_t = x_0_hat * ᾱ_t1.sqrt()*(1-alpha[t])/b̄_t + x_t * alpha[t].sqrt()*b̄_t1/b̄_t + sigma[t]*z
        preds.append(x_t.float().cpu())
    return preds

In [None]:
samples = sample(model, (1, 3, 32, 32))
len(samples)

In [None]:
# showImgGroup([samples[-1][i] for i in range(n_samples)], (3,3))

In [None]:
plt.imshow(samples[-1].squeeze().permute(1, 2, 0))

In [None]:
samples[-1].shape