# B-cos Stable Diffusion

This notebook contains example code for using the B-cos diffusion architectures.


In [1]:
# Automatically reload edited modules (e.g. the project's `ldm` package) before each cell executes.
%load_ext autoreload
%autoreload 2

You will need to adjust the paths below. Note that the logs and checkpoints can easily require 5–10 GB of space, so choose the log directory accordingly.

In [None]:
# Train a model with the given config. NOTE(review): --data_root and --logdir are site-specific absolute paths; replace them with your own.
%run scripts/train.py --data_root "/cluster/apps/vogtlab/users/nbernold/laion-ae" --device 'cuda' --base 'configs/stable-diffusion/debugging.yaml' -t --logdir '/cluster/apps/vogtlab/users/nbernold/logs'

In [None]:
# Sample from a trained checkpoint for a single prompt. NOTE(review): --ckpt is a site-specific absolute path; point it at your own checkpoint.
%run scripts/txt2img.py --config 'configs/stable-diffusion/v2-bcos-x0o-inference.yaml' --ckpt '/cluster/apps/vogtlab/users/nbernold/logs/x0o/checkpoints/last.ckpt' --device "cuda" --prompt 'A photo of a flamingo'

# Reconstructions and Explanations

In [None]:
import argparse, os, sys
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import notebook, tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext

from ldm.util import instantiate_from_config
import ldm.models.diffusion.ddim 
DDIMSampler = ldm.models.diffusion.ddim.DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

# Turn autograd off globally: this notebook runs inference, and run() re-enables
# gradients locally (torch.enable_grad) only where the explanations need them.
torch.set_grad_enabled(False)
from IPython.utils import io

import logging
# Quieten PyTorch Lightning: only ERROR records are emitted, and records are not
# passed on to the handlers of ancestor loggers.
log = logging.getLogger("pytorch_lightning")
log.setLevel(logging.ERROR)
log.propagate = False

def chunk(it, size):
    """Split an iterable into consecutive tuples of at most ``size`` items.

    ``iter(it)`` is called eagerly, so a non-iterable argument fails at call
    time. The returned iterator stops at the first empty tuple, so the final
    tuple may be shorter than ``size``.
    """
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                return
            yield batch

    return _batches()


def load_model_from_config(config, ckpt=None, device=torch.device("cuda"), verbose=False):
    """Instantiate ``config.model`` and optionally load checkpoint weights into it.

    Args:
        config: Config object whose ``model`` entry is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint, or ``None`` / the string ``"None"`` to skip
            loading weights (the string form arises from ``f"{None}"``).
        device: ``torch.device`` or device string. Only the ``cuda`` and ``cpu``
            device types are supported.
        verbose: If True, print missing and unexpected state-dict keys.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``device`` is not a CUDA or CPU device. This is checked
            before the checkpoint is read.
    """
    # Accept strings such as "cuda"/"cuda:0"/"cpu" as well as torch.device objects,
    # and reject unsupported devices before doing any expensive work.
    device = torch.device(device)
    if device.type not in ("cuda", "cpu"):
        raise ValueError(f"Incorrect device name. Received: {device}")

    if ckpt == "None":
        ckpt = None

    sd = None
    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        # Load onto the CPU; the whole model is moved to `device` further below.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        # Lightning checkpoints wrap the weights in "state_dict"; also accept a bare state dict.
        sd = pl_sd.get("state_dict", pl_sd)

    model = instantiate_from_config(config.model)
    if sd is not None:
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if verbose and len(missing) > 0:
            print("missing keys:")
            print(missing)
        if verbose and len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

    if device.type == "cuda":
        model.to(device)
    else:
        model.cpu()
        # The conditioning encoder keeps its own device attribute.
        model.cond_stage_model.device = "cpu"
    model.eval()
    return model


The following block contains settings that can be changed; note, however, that not all options are necessarily supported.
Make sure to download the checkpoints from Hugging Face and adjust the paths accordingly.

In [None]:
seed = 42
seed_everything(seed)

# Which pretrained variant to load: "debug", "vanilla", "x0" or "eps".
model_selection = "x0"

if model_selection == "debug":
    opt_config = "configs/stable-diffusion/debugging-inference.yaml"
    # NOTE(review): site-specific absolute path; adjust to your own log directory.
    opt_ckpt = '/cluster/apps/vogtlab/users/nbernold/logs/laion-ae2025-03-12T12-04-17_debugging/checkpoints/last.ckpt'

elif model_selection == "vanilla": # Vanilla model used in thesis
    opt_config = "configs/stable-diffusion/v2-vanilla-inference.yaml"
    opt_ckpt = None # insert checkpoint path for vanilla.ckpt

elif model_selection == "x0": # B-cos x0 model used in thesis
    opt_config = "configs/stable-diffusion/v2-bcos-x0o-inference.yaml"
    opt_ckpt = None # insert checkpoint path for bcos_x0.ckpt

elif model_selection == "eps": # B-cos eps model used in thesis
    opt_config = "configs/stable-diffusion/v2-bcos-inference.yaml"
    opt_ckpt = None # insert checkpoint path for bcos_eps.ckpt

else:
    # Without this branch an unknown name only surfaces later as a NameError on opt_config.
    raise ValueError(
        f"Unknown model_selection {model_selection!r}; expected 'debug', 'vanilla', 'x0' or 'eps'."
    )

if opt_ckpt is None:
    print(f"Warning: no checkpoint path set for {model_selection!r}; no checkpoint weights will be loaded.")

opt_device = "cuda"
opt_outdir = "outputs/txt2img-samples"
opt_precision = "autocast"
opt_bf16 = False
# Channels / height / width of the tensor that run() samples.
opt_C, opt_H, opt_W = 6, 64, 64

config = OmegaConf.load(opt_config)
device = torch.device("cuda") if opt_device == "cuda" else torch.device("cpu")
print("Device: ", device)
# opt_ckpt is passed through unchanged; load_model_from_config skips weight loading for None.
model = load_model_from_config(config, opt_ckpt, device)

sampler = DDIMSampler(model, device=device)

os.makedirs(opt_outdir, exist_ok=True)
outpath = opt_outdir

batch_size = 1
n_rows = 1

sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)

In [5]:
def run():
    """Sample an image for the global ``prompt`` and optionally compute its model summary.

    Settings are read from notebook globals rather than parameters:
    ``prompt``, ``seed``, ``use_ddim``, the ``opt_*`` options, ``batch_size``,
    ``sample_path``/``outpath``, ``explain``, ``patchsize``, ``t_rem``, ``rest``,
    plus ``model``, ``sampler`` and ``device``.

    Side effects: every sample is saved as ``<sample_path>/NNNNN.png`` and a grid
    of all collected images as ``<outpath>/grid-NNNN.png``.

    Returns:
        ``(DLW, context, sample_out)``. When ``explain`` is true, ``DLW`` has shape
        ``(opt_C, opt_H // patchsize, opt_W // patchsize, 77, 1024)``. Each entry
        ``DLW[ch, i, j]`` is the gradient, with respect to the conditioning, of the
        mean of channel ``ch`` over patch ``(i, j)`` of the decoded step output.
        ``context`` is the conditioning of the last prompt batch. When ``explain``
        is false, both are ``None``. ``sample_out`` is the decoded sample before
        clamping. If ``prompt`` is ``None``, ``(None, None, None)`` is returned.
    """
    if prompt is None:
        # Keep the 3-tuple contract so `DLW, context, pred_x0 = run()` still unpacks.
        return None, None, None

    # Seed first so the backup seed drawn next is reproducible for a given `seed`.
    with io.capture_output() as captured:
        seed_everything(seed)

    backup_seed = np.random.randint(0, 2147483647)

    base_count = len(os.listdir(sample_path))
    # `outpath` also holds the "samples" sub-directory, hence the -1.
    grid_count = len(os.listdir(outpath)) - 1
    sample_count = 0

    data = [batch_size * [prompt]]
    start_code = torch.randn([opt_n_samples, opt_C, opt_H, opt_W], device=device)
    if model.encode_noise:
        # Channels 3: are set to the negation of channels :3.
        start_code[:,3:,:,:] = -start_code[:,:3,:,:]
    start_code = model.mean + start_code*model.stdev

    # Set to True to also append the sampler's pred_x0 intermediates to the grid.
    show_intermediates = False

    precision_scope = autocast if opt_precision=="autocast" or opt_bf16 else nullcontext
    with torch.no_grad(), \
        precision_scope(opt_device), \
        model.ema_scope():
            all_samples = list()
            for _ in trange(1, desc="Sampling"):
                for prompts in tqdm(data, desc="data"):
                    uc = None
                    # Re-seed before sampling; the explanation pass below re-seeds the same way.
                    with io.capture_output() as captured:
                        seed_everything(seed)
                    if opt_scale != 1.0:
                        uc = model.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = model.get_learned_conditioning(prompts)
                    shape = [opt_C, opt_H, opt_W]

                    if use_ddim:
                        samples, intermediates = sampler.sample(S=opt_steps,
                                                        conditioning=c,
                                                        batch_size=opt_n_samples,
                                                        shape=shape,
                                                        verbose=False,
                                                        unconditional_guidance_scale=opt_scale,
                                                        unconditional_conditioning=uc,
                                                        eta=opt_ddim_eta,
                                                        x_T=start_code.clone(),
                                                        log_every_t=1,
                                                        backup_seed=(t_rem, backup_seed))
                    else:
                        # NOTE(review): batch_size is fixed to 1 here while start_code has opt_n_samples rows.
                        samples, intermediates = model.sample(c, batch_size=1, return_intermediates=True, x_T=start_code,
                        verbose=True, timesteps=None, quantize_denoised=False, mask=None, x0=None, shape=None, log_every_t=50)
                        intermediates = {"pred_x0" : intermediates}

                    x_samples = model.decode_first_stage(samples)
                    # Unclamped decoded sample; returned as the third value.
                    sample_out = x_samples.clone()

                    x_samples = torch.clamp(x_samples, min=0.0, max=1.0)

                    for x_sample in x_samples:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        x_sample = x_sample[:,:,:3]
                        img = Image.fromarray(x_sample.astype(np.uint8))
                        img.save(os.path.join(sample_path, f"{base_count:05}.png"))
                        base_count += 1
                        sample_count += 1

                    all_samples.append(x_samples)

                    if show_intermediates:
                        for q in intermediates["pred_x0"]:
                            q = torch.clamp(q, min=0.0, max=1.0)
                            all_samples.append(q)

            if explain:
                for prompts in tqdm(data, desc="data"):
                    x_start = start_code.clone()
                    context = model.get_learned_conditioning(prompts)

                    with io.capture_output() as captured:
                        seed_everything(seed)

                    # NOTE(review): with no_grad=False this sampler variant appears to return
                    # ((x_t, cond), step_fn); step_fn(x_t, cond) -> (x_prev, pred_x0). Confirm in ddim.py.
                    (_img, _cond), _fun = sampler.sample(S=opt_steps,
                        conditioning=context,
                        batch_size=opt_n_samples,
                        shape=shape,
                        verbose=False,
                        unconditional_guidance_scale=opt_scale,
                        unconditional_conditioning=uc,
                        eta=opt_ddim_eta,
                        x_T=x_start,
                        no_grad=False,
                        disable_tqdm=True,
                        t_remaining=(t_rem, 0) if rest else t_rem,
                        backup_seed=(t_rem, backup_seed),
                        return_eps=False)

                    with io.capture_output() as captured:
                        seed_everything(seed)

                    x_prev, pred_x0 = _fun(_img, _cond)
                    x_samples = model.decode_first_stage(_img)
                    y_samples = model.decode_first_stage(pred_x0)
                    z_samples = model.decode_first_stage(x_prev)
                    z_samples = torch.clamp(z_samples, min=0.0, max=1.0)
                    x_samples = torch.clamp(x_samples, min=0.0, max=1.0)
                    y_samples = torch.clamp(y_samples, min=0.0, max=1.0)

                    all_samples.append(x_samples) # input
                    all_samples.append(z_samples) # prev
                    all_samples.append(y_samples) # pred x0

                    # One backward pass per (channel, patch row, patch col): gradient of the
                    # patch mean of the decoded step output w.r.t. the conditioning.
                    DLW = torch.zeros((opt_C, opt_H // patchsize, opt_W // patchsize, 77, 1024), device=device)
                    indices = [(ch,i,j) for ch in range(opt_C) for i in range(opt_H // patchsize) for j in range(opt_W // patchsize)]
                    with torch.enable_grad():
                        for ch,i,j in notebook.tqdm(indices):
                            with io.capture_output() as captured:
                                seed_everything(seed)
                            cond = _cond.clone().requires_grad_()
                            with model.explanation_mode():
                                x_prev, pred_x0 = _fun(_img, cond)
                                ins = model.decode_first_stage(x_prev)
                                mid = torch.mean(ins[:,ch,patchsize*i:patchsize*(i+1), patchsize*j:patchsize*(j+1)], dim=(1,2))
                                out = mid[0]

                                out.backward(inputs=[cond])

                                DLW[ch,i,j] = cond.grad[0]

                    # Reconstruct the first three channels by contracting DLW with the conditioning.
                    x_samples = torch.einsum('chwij,ij->chw', DLW[:3], context[0])
                    x_samples = x_samples.unsqueeze(0)

                    # Min-max scale to [0, 1]; clamp the divisor so a constant map yields zeros, not NaN.
                    x_samples -= x_samples.min()
                    x_samples /= x_samples.abs().max().clamp_min(1e-12)

                    # Repeat the 3 channels twice (presumably to match opt_C == 6) so the tile stacks with the others.
                    x_samples = x_samples.repeat([1,2,1,1])

                    x_samples = torch.nn.functional.interpolate(x_samples, scale_factor=patchsize)
                    all_samples.append(x_samples)

            # additionally, save as grid
            grid = torch.stack(all_samples, 0)
            grid = rearrange(grid, 'n b c h w -> (n b) c h w')
            grid = make_grid(grid, nrow=5)

            # to image
            grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
            grid = grid[:,:,:3]
            grid = Image.fromarray(grid.astype(np.uint8))
            grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
            grid_count += 1

    if explain:
        return DLW, context, sample_out
    else:
        return None, None, sample_out

In the following block, you can adjust the settings for your generations.


In [None]:

prompt = "a penguin"

expl_id = 2  # token index explained in the post-processing cells
use_ddim = True # strongly recommended
opt_scale = 10.0 # unconditional guidance scale
opt_steps = 3 # number of sampling steps
opt_n_samples = 1
opt_ddim_eta = 0.0
explain = True # whether the model summary should be computed
patchsize = 8 # patchsize to aggregate the output
seed = 43 # seed
t_rem = 2 # a number from 0 to opt_steps-1, 0 means explaining the last step, opt_steps-1 means explaining the first step
rest = True # if this is true, the entire denoising process after step t_rem is explained

# Fail fast on invalid settings before the slow call to run();
# the model summary built in run() has 77 token slots.
if not 0 <= t_rem < opt_steps:
    raise ValueError(f"t_rem must be in [0, {opt_steps - 1}], got {t_rem}")
if not 0 <= expl_id < 77:
    raise ValueError(f"expl_id must be a token index in [0, 76], got {expl_id}")

DLW, context, pred_x0 = run()

## Post-processing of the model summaries
The following blocks contain various examples of how to create reconstructions or explanations from the model summaries.

In [None]:
# To identify word-to-token correspondences.
import open_clip
token_ids = open_clip.tokenize(prompt)[0]
# Words are numbered from 1 on the assumption that position 0 holds the start-of-text token.
# NOTE(review): word k sits at token position k only if every word is a single BPE token;
# compare against the non-padding ids below (assumes padding id 0; confirm).
print(list(enumerate(prompt.split(), 1)))
print("non-padding token ids:", token_ids[token_ids != 0].tolist())

In [16]:
import matplotlib.pyplot as plt

In [17]:
def save_img(img, name, outdir='outputs/expls'):
    """Save the first three channels of a channel-first image tensor as a borderless PNG.

    Values are clipped to [0, 1]. Each call draws on its own figure and closes
    it afterwards, so repeated calls do not stack images on shared axes or leak
    figures. As a consequence, nothing is displayed inline.

    Args:
        img: Tensor indexed ``(channel, height, width)``; only ``img[:3]`` is used.
        name: File name without extension.
        outdir: Output directory. It is created if missing; the default matches
            the previously hard-coded location.
    """
    os.makedirs(outdir, exist_ok=True)
    fig, ax = plt.subplots()
    ax.imshow(img[:3].clip(0, 1).permute(1, 2, 0).detach().cpu().numpy())
    ax.grid(False)
    ax.axis("off")
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, name + '.png'), bbox_inches='tight', pad_inches=0)
    plt.close(fig)

In [28]:
def expl2img(expl, image, smooth=0, quantile=0.99):
    """Convert a channel-first explanation into a colour image weighted by contribution.

    Colour comes from the ratio of the positive weights on the first three
    channels to those on the first three plus the remaining channels. Opacity is
    the magnitude of the signed contribution ``sum_c(expl * image)``, which is
    zeroed (set to 1e-12) where the contribution is negative and scaled so that
    the given quantile maps to 1.

    Args:
        expl: Explanation tensor with channels on dim 0 (assumed six channels,
            i.e. two groups of three; TODO confirm).
        image: Tensor that, after ``squeeze()``, broadcasts against ``expl``.
        smooth: Odd kernel size for average-pool smoothing of the opacity map;
            0 disables smoothing, which was the previous hard-coded behaviour.
        quantile: Quantile of the opacity map that is mapped to full opacity.

    Returns:
        Tensor with the three colour channels moved to the last axis.

    Raises:
        ValueError: If ``smooth`` is negative or even. With symmetric padding,
            such a kernel would change the spatial size of the opacity map.
    """
    print(expl.shape, image.shape)
    dlw = expl.clone()
    # Signed contribution per position, summed over channels.
    contrib = (dlw * image.squeeze()).sum(0, keepdim=True)
    dlw = torch.nn.functional.relu(dlw)
    dmg = dlw[:3] / (dlw[:3] + dlw[3:] + 1e-12)
    alpha = torch.sqrt(torch.sum(contrib**2, dim=0)).unsqueeze(0)
    alpha = torch.where(contrib < 0, 1e-12, alpha)
    if smooth:
        if smooth < 0 or smooth % 2 == 0:
            raise ValueError(f"smooth must be 0 or a positive odd kernel size, got {smooth}")
        alpha = torch.nn.functional.avg_pool2d(alpha, smooth, stride=1, padding=(smooth - 1) // 2)
    # Clamp the divisor so an all-zero opacity map yields zeros instead of NaN.
    scale = torch.quantile(alpha, quantile).clamp_min(1e-12)
    alpha = (alpha / scale).clip(0, 1)
    dmg = (dmg * alpha[0][None]).permute(1, 2, 0)
    return dmg

In [None]:
context = model.get_learned_conditioning(prompt)[0]
pred = pred_x0[0]
id = expl_id  # token to explain; the cells below also read `id`

name = "expl"

# DLW is built in run() with shape (channel, patch_row, patch_col, token, embedding_dim).
_, h_p, w_p, n_tokens, _ = DLW.shape

# Total absolute weight of each token in each patch (summed over channels and embedding dims),
# then each token's rank within its patch: 0 = smallest total weight, n_tokens - 1 = largest.
token_mass = DLW.abs().sum(dim=(0, 4))
DR = token_mass.argsort(dim=-1).argsort(dim=-1)
DRx = 1 - DR[:, :, id:id+1].repeat(1, 1, 3) / (n_tokens - 1)
print('Strength', DRx.abs().max())
DRx /= DRx.abs().max().clamp_min(1e-12)
# NOTE(review): a patch where token `id` is the *largest* contributor gets DRx == 0, so the
# masked explanation below suppresses exactly those patches. Confirm this inversion is intended.
DRx = DRx.permute(2, 0, 1)

os.makedirs('outputs/expls', exist_ok=True)

# Full Reconstruction (suffix "_rec": scaled by the max; "_nrec": normalized ratio)
Rec = torch.einsum('chwji,ji->chw', DLW, context)
save_img(Rec[:3]/Rec[:3].max(), f"{name}_{seed}_rec")
Rec = Rec[:3]/(Rec[:3]+Rec[3:])
save_img(Rec, f"{name}_{seed}_nrec")

# Partial Reconstruction (same suffix convention, restricted to token `id`)
Rec = torch.einsum('chwi,i->chw', DLW[...,id,:], context[id])
save_img(Rec[:3]/Rec[:3].max(), f"{name}_{seed}_partial_rec")
Rec = Rec[:3]/(Rec[:3]+Rec[3:])
save_img(Rec, f"{name}_{seed}_partial_nrec")

# Explanation: total |weight| of token `id`, first three channels, scaled to a max of 1
Rec = torch.einsum('chwi->chw', DLW[...,id,:].abs())
Rec = Rec[:3]
Rec /= Rec.max()
save_img(Rec, f"{name}_{seed}_expl")

# Masked Explanation
Rec *= DRx
save_img(Rec, f"{name}_{seed}_masked_expl")

print("Done")

In [7]:
# Recompute the prompt conditioning (indexed token x embedding by the einsums below) and take
# the first decoded sample, so the following cells do not depend on the explanation cell above.
context = model.get_learned_conditioning(prompt)[0]
pred = pred_x0[0]

In [None]:
# Sample: display the first three channels of the decoded sample returned by run().
# Clipping to [0, 1] matches what imshow does for float RGB, minus the out-of-range warning.
plt.imshow(pred[:3].clip(0, 1).permute(1, 2, 0).cpu().numpy());

In [None]:
# Unnormalized Reconstruction
# Contract DLW with the prompt conditioning over (token, embedding dim), keep the
# first three channels, and scale them so the maximum is 1.
Rec = torch.einsum('chwji,ji->chw', DLW, context)[:3]
Rec = Rec.permute(1, 2, 0) / Rec.max()

plt.imshow(Rec.cpu().numpy())

In [None]:
# Normalized Reconstruction
# Per pixel, divide the first three channels by the sum of the first three and the remaining channels.
Rec = torch.einsum('chwji,ji->chw', DLW, context)
Rec = (Rec[:3] / (Rec[:3] + Rec[3:])).permute(1, 2, 0)

plt.imshow(Rec.cpu().numpy())

In [None]:
# Unnormalized Partial Reconstruction
# Only token `id`'s contribution; reassign `id` here to inspect another token.
id = id
Rec = torch.einsum('chwi,i->chw', DLW[..., id, :], context[id])[:3]
Rec = Rec.permute(1, 2, 0) / Rec.max()

plt.imshow(Rec.cpu().numpy())

In [None]:
# Normalized Partial Reconstruction
# (The previous header said "Unnormalized", but this cell computes the per-pixel ratio
# of the first three channels to the sum of the first three and the remaining channels.)
id = id # potentially change token
Rec = torch.einsum('chwi,i->chw', DLW[...,id,:], context[id]).permute(1,2,0)
Rec = Rec[...,:3]/(Rec[...,:3]+Rec[...,3:])

plt.imshow(Rec.cpu().numpy())

In [None]:
# Explanation
# Total |weight| of token `id` over the embedding dimension, first three channels,
# scaled to a maximum of 1. Reassign `id` here to inspect another token.
id = id
Expl = torch.einsum('chwi->chw', DLW[..., id, :].abs())[:3]
Expl = Expl.permute(1, 2, 0) / Expl.max()
plt.imshow(Expl.cpu().numpy())