In [None]:
# Move the kernel's working directory one level up so that the `core` / `examples`
# imports and the relative `pretrained/...` checkpoint path used below resolve
# (presumably this lands on the repository root -- verify for your checkout).
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel moves up one more level each time.
%cd ../

In [None]:
# import dnnlib
import pickle as pkl
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import random

from PIL import Image
from pathlib import Path
from torchvision.transforms import Resize

from core.utils.example_utils import Inferencer, vstack_with_lines, hstack_with_lines, to_im
from core.utils.image_utils import construct_paper_image_grid
from core.utils.reading_weights import read_weights
from core.uda_models import OffsetsTunningGenerator
from core.sparse_models import SparsedModel

from examples.draw_util import weights

In [None]:
# Device on which the generator and all tensors in this notebook live.
device = 'cuda:0'

# Build the generator from the StyleGAN2 FFHQ checkpoint, apply the 's_delta'
# layer patching, then move the result to `device`.
stylegan_ckpt = 'pretrained/StyleGAN2/stylegan2-ffhq-config-f.pt'
g = OffsetsTunningGenerator(checkpoint_path=stylegan_ckpt)
g = g.patch_layers('s_delta')
g = g.to(device)

In [None]:
# Render one row of samples per pruning level for a single target domain,
# followed by one row produced by the generator without any offsets.
# Results are collected in `images` (one entry per row, in that order).

# Values handed to `SparsedModel.pruned_offsets`, one per output row.
# NOTE(review): 0.9 is listed twice, so two rows use the same pruning level --
# confirm whether one of them was meant to be a different value.
percentiles = [0.7, 0.8, 0.9, 0.9, 0.95]

domain = 'sketch'   # key into `weights` selecting the adapted checkpoint
bs = 4              # number of latent codes, i.e. images per row
truncation = 0.8    # forwarded to the generator's `truncation` argument
seed = 0            # seed for the latent codes below

model = SparsedModel(device, read_weights(weights[domain]))

# Draw the latents from a dedicated, seeded CPU generator so the same latent
# codes are used on every run without touching the global RNG state.
latent_rng = torch.Generator().manual_seed(seed)
z = [torch.randn(bs, 512, generator=latent_rng).to(device)]
resize = Resize(256)

images = []

# Pure inference: disable autograd so no graph is recorded for each forward
# pass. `.detach()` is kept as a cheap safeguard before converting to images.
with torch.no_grad():
    for perc in percentiles:
        offsets = model.pruned_offsets(perc)
        im, _ = g(z, offsets=offsets, truncation=truncation)
        images.append(to_im(resize(im.detach()), padding=0))

    # Reference row: same latents, no offsets applied.
    orig_ims, _ = g(z, truncation=truncation)
    images.append(to_im(resize(orig_ims.detach()), padding=0))

In [None]:
# Stack the generated rows into a single figure and label rows by pruning
# level and columns by sample index.

ext = 2          # figure inches allotted to each image tile
im_size = 256    # tile side in pixels; must match the Resize(...) applied to the images
line_width = 10  # separator thickness passed to vstack_with_lines

n_rows = len(percentiles) + 1    # one row per pruning level plus the original
row_pitch = im_size + line_width  # vertical distance between consecutive row tops

fig, ax = plt.subplots(figsize=(bs * ext, n_rows * ext))
ax.imshow(vstack_with_lines(images, line_width))

# Column ticks sit at the horizontal centre of each tile.
ax.set_xticks(np.arange(im_size // 2, bs * im_size, im_size))
ax.set_xticklabels([f"id {k}" for k in range(bs)])

# Row ticks sit at the vertical centre of each row, accounting for separators.
ax.set_yticks(im_size // 2 + row_pitch * np.arange(n_rows))
# `:g` avoids float-repr noise such as "70.0" or "56.99999999999999" in labels.
ax.set_yticklabels([f"{p * 100:g}% pruned" for p in percentiles] + ['Original'])

plt.show()