In [None]:
import torch
import torch.nn as nn
import random
from model.stylegan.model import Generator, Discriminator
from model.stylegan.dataset import MultiResolutionDataset
from torchvision import transforms, utils
import matplotlib.pyplot as plt
from model.stylegan import lpips

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
# Kept as a plain string (not torch.device) so string methods such as
# `startswith` work on it.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

In [None]:
def make_noise(batch, latent_dim, n_noise, device):
    """Draw standard-normal latent codes.

    Args:
        batch: number of codes per tensor.
        latent_dim: length of each code.
        n_noise: how many independent tensors to draw.
        device: device the tensors are created on.

    Returns:
        A single ``(batch, latent_dim)`` tensor when ``n_noise == 1``;
        otherwise a tuple of ``n_noise`` such tensors, sampled together in one
        ``torch.randn`` call and split along the leading axis.
    """
    if n_noise != 1:
        stacked = torch.randn(n_noise, batch, latent_dim, device=device)
        return stacked.unbind(0)
    return torch.randn(batch, latent_dim, device=device)


def mixing_noise(batch, latent_dim, prob, device):
    """Sample latent noise, optionally for style mixing.

    With probability ``prob`` this returns a tuple of two independent
    ``(batch, latent_dim)`` noise tensors. Otherwise it returns a one-element
    list holding a single such tensor.

    When ``prob <= 0`` no number is drawn from ``random``, so the Python RNG
    state is left untouched.
    """
    use_mixing = prob > 0 and random.random() < prob
    if not use_mixing:
        return [make_noise(batch, latent_dim, 1, device)]
    return make_noise(batch, latent_dim, 2, device)

In [None]:
# Build the generator and load the checkpoint's "g_ema" entry (presumably the
# EMA-averaged generator weights) into it.
# NOTE(review): positional args are presumably (size, style_dim, n_mlp); 512
# matches the latent size used when sampling noise. TODO confirm against
# model.stylegan.model.Generator.
generator = Generator(1024, 512, 8, channel_multiplier=2).to(device)

# Deserialize every tensor onto the CPU. load_state_dict then copies the values
# into the generator's existing parameters, which already live on `device`.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
ckpt = torch.load('./checkpoint/stylegan2-ffhq-config-f.pt', map_location='cpu')
generator.load_state_dict(ckpt['g_ema'])
print(ckpt.keys())

# Training-data pipeline (currently disabled):
# transform = transforms.Compose(
#     [
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize(
#             (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
#     ]
# )
# dataset = MultiResolutionDataset('./data/cartoon/lmdb/', transform, 1024)

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, figsize=(12, 12)):
    """Display a batch of image tensors as a single grid.

    Args:
        image_tensor: tensor holding N images. It is reshaped to
            ``(-1, *size)``, so any layout with ``N * prod(size)`` elements
            works (e.g. a ``(N, 1, C, H, W)`` stack).
        num_images: maximum number of images to draw (the first ones are used).
        size: per-image shape ``(C, H, W)``.
        nrow: number of images per grid row.
        figsize: matplotlib figure size in inches.
    """
    # reshape (not view) so non-contiguous inputs such as permuted or sliced
    # tensors work too; view raises on those.
    images = image_tensor.detach().cpu().reshape(-1, *size)
    # For single-channel input, make_grid repeats the channel to build a
    # 3-channel greyscale image, so the grid is laid out as (C, H, W).
    grid = utils.make_grid(images[:num_images], nrow=nrow)
    if grid.is_floating_point():
        # matplotlib clips float RGB data to [0, 1] anyway. Clamping here gives
        # the same picture without its "Clipping input data" warning.
        grid = grid.clamp(0, 1)
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(grid.permute(1, 2, 0).numpy())
    ax.axis('off')
    plt.show()

In [None]:
# Sample `num` images one at a time, so each image gets its own style-mixing
# decision, then show them in a 4x4 grid.
num = 16
LATENT_DIM = 512   # must match the 512 passed to Generator(...) when it was built
MIXING_PROB = 0.5  # probability that a sample mixes two latent codes
SEED = 0

# Seed both RNGs used here: `random` decides whether to mix styles, and torch
# draws the latents (torch.manual_seed also covers CUDA devices).
random.seed(SEED)
torch.manual_seed(SEED)

fake_imgs = []
with torch.no_grad():  # inference only: avoid building an autograd graph per image
    for _ in range(num):
        noise = mixing_noise(1, LATENT_DIM, MIXING_PROB, device)
        fake_img = generator(noise)[0]  # generator returns (image, latent)
        # Rescale from the generator's output range (assumed ~[-1, 1]) to [0, 1].
        fake_img = (fake_img + 1) / 2
        fake_imgs.append(fake_img.cpu())

# Each entry is (1, C, H, W); concatenate along the batch axis.
show_tensor_images(torch.cat(fake_imgs), num, nrow=4, size=fake_imgs[0].shape[1:])

In [None]:
# LPIPS perceptual loss ("net-lin" model over a VGG backbone). It is placed on
# the GPU only when `device` names a CUDA device.
percept = lpips.PerceptualLoss(
    model="net-lin",
    net="vgg",
    use_gpu=device.startswith("cuda"),
)
