In [6]:
import copy
import gc
import os

import torch
import torch.nn.utils.prune as prune

In [1]:
import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

# Load Stable Diffusion 2 in fp16 on the GPU and render one reference image
# that the pruned pipelines later in the notebook can be compared against.
model_id = "stabilityai/stable-diffusion-2"

# Use the Euler scheduler here instead of the pipeline's default one.
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_attention_slicing()  # For low GPU RAM

# Sampling is stochastic: without a seeded generator the reference image
# changes on every run, so it could not be reproduced after a kernel restart.
BASELINE_SEED = 0
prompt = "a photo of an astronaut riding a horse on mars"
baseline_generator = torch.Generator("cuda").manual_seed(BASELINE_SEED)
image = pipe(prompt, generator=baseline_generator).images[0]

image.save("astronaut_rides_horse.png")

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]



  0%|          | 0/50 [00:00<?, ?it/s]

In [5]:
def _trainable_params(module):
    """Return the total element count of `module`'s parameters with requires_grad=True."""
    return sum(t.numel() for t in module.parameters() if t.requires_grad)

# Same report for each pipeline component, with `_` as the thousands separator.
print(f'UNET: {_trainable_params(pipe.unet):_}')
print(f'VAE: {_trainable_params(pipe.vae):_}')
print(f'CLIP: {_trainable_params(pipe.text_encoder):_}')

UNET: 865_910_724
VAE: 83_653_863
CLIP: 340_387_840


In [None]:
# NOTE(review): this cell reads `amounts`, `norm`, `dim`, `total`, `prompts`,
# `seed` and `plt`, none of which are defined in an earlier visible cell, and it
# calls `image_grid`, which is only defined in a *later* cell. Restart & Run All
# fails here until those are defined above this cell -- confirm where they live.


def _prune_targets(p):
    """Return ``(unet_params, vae_params)``: the ``(module, 'weight')`` pairs to prune.

    Selects ``conv1`` of the first two resnets in every UNet down block and up
    block, then ``conv1`` of the first three resnets in every VAE-decoder up
    block. Order: down blocks, then up blocks, each block's resnets in index order.
    """
    unet_params = tuple(
        (resnet.conv1, 'weight')
        for blocks in (p.unet.down_blocks, p.unet.up_blocks)
        for block in blocks
        for resnet in block.resnets[:2]
    )
    vae_params = tuple(
        (resnet.conv1, 'weight')
        for block in p.vae.decoder.up_blocks
        for resnet in block.resnets[:3]
    )
    return unet_params, vae_params


# grid.save() below raises if the output folder does not exist yet.
os.makedirs("figures/all", exist_ok=True)

for amount in amounts:
    print(amount)
    # Drop the previous iteration's pruned copy *before* cloning, so only the
    # original `pipe` plus one copy occupy GPU memory at the same time.
    lofi_pipe = None
    gc.collect()
    torch.cuda.empty_cache()
    lofi_pipe = copy.deepcopy(pipe).to("cuda")

    unet_params, vae_params = _prune_targets(lofi_pipe)
    params = unet_params + vae_params
    for layer, name in params:
        prune.ln_structured(layer, name=name, amount=amount, n=norm, dim=dim)
        # Make the pruning permanent. This removes the `<name>_orig`/`<name>_mask`
        # tensors and the forward pre-hook; the effective weight values are unchanged.
        prune.remove(layer, name)
    num_weights = sum(torch.count_nonzero(layer.weight) for layer, name in params)
    # NOTE(review): `total` presumably holds the weight count of these layers
    # before pruning -- confirm its definition.
    compression = round((num_weights / total).item(), 2)
    print('weights: ', num_weights)
    print('compressed: ', compression)

    for prompt in prompts:
        # Same seed for every amount, so the grids differ only by the pruning.
        generator = torch.Generator("cuda").manual_seed(seed)
        images = lofi_pipe([prompt] * 3, num_inference_steps=50, generator=generator, height=512, width=512).images
        grid = image_grid(images, rows=3, cols=1)
        # A '/' in the prompt would otherwise be read as a directory separator.
        safe_prompt = prompt.replace("/", "_")
        grid.save(f"figures/all/compressed-{amount}-{safe_prompt}-{compression}.png")
        fig, ax = plt.subplots(figsize=(30, 12))
        ax.set_title(f'prune={amount}\tcompressed={compression}')
        ax.imshow(grid)
        ax.axis("off")
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the sweep


In [None]:
from PIL import Image

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows*cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    grid_w, grid_h = grid.size
    
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return gri