In [None]:
import argparse
import math
from torch import nn
import torch
from torchvision import utils
from PIL import Image
import matplotlib.pyplot as plt
from random import randint
import pickle
!wget https://raw.githubusercontent.com/rosinality/style-based-gan-pytorch/master/model.py
from model import Generator,PixelNorm,EqualLinear

class StyledGenerator(nn.Module):
    """StyleGAN generator: mapping MLP (``self.style``) feeding a synthesis ``Generator``.

    Args:
        code_dim: width of each latent code and style vector; passed to
            ``Generator`` and used as the in/out width of every mapping layer.
        n_mlp: number of ``EqualLinear`` + ``LeakyReLU(0.2)`` pairs stacked after
            the leading ``PixelNorm`` in the mapping network.
    """

    def __init__(self, code_dim=512, n_mlp=8):
        super().__init__()
        # Synthesis network is built first, then the mapping layers.
        self.generator = Generator(code_dim)
        mapping_layers = [PixelNorm()]
        for _ in range(n_mlp):
            mapping_layers += [EqualLinear(code_dim, code_dim), nn.LeakyReLU(0.2)]
        self.style = nn.Sequential(*mapping_layers)

    def forward(self, input, noise=None, step=0, alpha=-1, mean_style=None, style_weight=0, mixing_range=(-1, -1)):
        """Map latent code(s) to styles and run the synthesis network.

        Args:
            input: a single latent tensor, or a list/tuple of latent tensors
                (several entries are handed to ``Generator`` for style mixing).
            noise: per-resolution noise maps. When None, one standard-normal map
                of shape (batch, 1, 4 * 2**i, 4 * 2**i) is drawn for each
                i in 0..step, on the device of the first latent.
            step, alpha, mixing_range: forwarded unchanged to ``Generator``.
            mean_style: when given, every style ``s`` is replaced by
                ``mean_style + style_weight * (s - mean_style)`` (truncation
                toward the mean style).
            style_weight: interpolation weight used together with ``mean_style``.

        Returns:
            Whatever ``self.generator`` returns for these styles and noise maps.
        """
        latents = input if type(input) in (list, tuple) else [input]
        styles = [self.style(code) for code in latents]
        batch = latents[0].shape[0]

        if noise is None:
            dev = latents[0].device
            noise = [
                torch.randn(batch, 1, 4 * 2 ** res, 4 * 2 ** res, device=dev)
                for res in range(step + 1)
            ]

        if mean_style is not None:
            styles = [mean_style + style_weight * (s - mean_style) for s in styles]

        return self.generator(styles, noise, step, alpha, mixing_range=mixing_range)

    def mean_style(self, input):
        """Return the mean style vector (over dim 0, kept as a size-1 dim) of a latent batch."""
        return self.style(input).mean(0, keepdim=True)

--2020-05-01 04:22:30--  https://raw.githubusercontent.com/rosinality/style-based-gan-pytorch/master/model.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 17017 (17K) [text/plain]
Saving to: ‘model.py’


2020-05-01 04:22:30 (35.8 MB/s) - ‘model.py’ saved [17017/17017]



In [None]:
@torch.no_grad()
def get_mean_style(generator, device, n_batches=10, batch_size=1024, code_dim=512):
    """Estimate the average style vector used as the truncation centre.

    Draws ``n_batches`` batches of ``batch_size`` standard-normal latent codes,
    asks ``generator.mean_style`` for each batch's mean style, and averages the
    per-batch means (equal batch sizes, so this equals the overall mean).

    Args:
        generator: object exposing ``mean_style(z)`` (e.g. ``StyledGenerator``).
        device: device the latent codes are moved to before the forward pass.
        n_batches: number of batches to average (previously fixed at 10).
        batch_size: latent codes per batch (previously fixed at 1024).
        code_dim: latent width; must match the generator's ``code_dim``
            (previously fixed at 512).

    Returns:
        The averaged tensor, same shape as a single ``generator.mean_style`` result.

    Raises:
        ValueError: if ``n_batches`` is less than 1.
    """
    if n_batches < 1:
        raise ValueError(f'n_batches must be >= 1, got {n_batches}')
    total = None
    for _ in range(n_batches):
        # Sample on CPU then move, matching the original RNG stream.
        latents = torch.randn(batch_size, code_dim).to(device)
        style = generator.mean_style(latents)
        total = style if total is None else total + style
    return total / n_batches
    
@torch.no_grad()
def style_mixing(generator, step, mean_style, n_source, n_target, device, scode,
                 style_weight=0.7, mixing_range=(0, 1)):
    """Generate style-mixed images from fixed source codes and random target codes.

    For each of ``n_target`` freshly sampled target latents, the target code is
    repeated ``n_source`` times and passed to ``generator`` together with
    ``scode`` as a two-style input ``[target, source]``. ``mixing_range`` selects
    which synthesis layers use which style; with the upstream rosinality
    ``Generator`` the layers inside the range presumably take the second
    (source) style — confirm against model.py.

    Args:
        generator: a ``StyledGenerator``-like callable.
        step: progressive-growing step (output resolution is 4 * 2**step).
        mean_style: truncation centre passed to the generator.
        n_source: number of rows in ``scode``.
        n_target: number of random target codes to mix with the sources.
        device: device for the sampled target codes.
        scode: source latent codes, first dimension ``n_source``.
        style_weight: truncation weight (default 0.7, as before).
        mixing_range: layer range forwarded to the generator (default (0, 1)).

    Returns:
        The ``n_target`` generator outputs concatenated along dim 0. Previously
        only the last target's output was returned; for ``n_target == 1`` (the
        only way this notebook calls it) the result is the same single output.

    Raises:
        ValueError: if ``n_target < 1`` or ``scode.shape[0] != n_source``.
    """
    if n_target < 1:
        raise ValueError(f'n_target must be >= 1, got {n_target}')
    if scode.shape[0] != n_source:
        raise ValueError(f'scode has {scode.shape[0]} rows but n_source={n_source}')

    target_code = torch.randn(n_target, 512).to(device)
    # The former unconditional forward pass on target_code alone was never used
    # and has been dropped to save a full generator evaluation per call.
    images = []
    for tgt in target_code:
        image = generator(
            [tgt.unsqueeze(0).repeat(n_source, 1), scode],
            step=step,
            alpha=1,
            mean_style=mean_style,
            style_weight=style_weight,
            mixing_range=mixing_range,
        )
        images.append(image)
    return torch.cat(images, 0)

In [None]:
# Load the pre-trained 256x256 generator, estimate its mean style, and derive
# the progressive-growing step for the target resolution.
size = 256
path = '/content/drive/My Drive/CV/stylegan-256px-new.model'   ## Location of pre-trained model
# Fall back to CPU when no GPU is attached instead of failing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = StyledGenerator(512).to(device)
# map_location lets a checkpoint saved from GPU load on whichever device is active.
generator.load_state_dict(torch.load(path, map_location=device)['g_running'])
generator.eval()
mean_style = get_mean_style(generator, device)
# Step 0 renders 4x4 (noise sizes are 4 * 2**i), so 256px -> step 6.
# log2 is exact for powers of two, unlike the float log(size, 2).
assert size >= 4 and size & (size - 1) == 0, 'size must be a power of two >= 4'
step = int(math.log2(size)) - 2

In [None]:
# Generate N_SRC source images (each latent code is also saved as a .pt file next
# to its PNG), then N_MIX style-mixed images whose source code is picked uniformly.
import os
import random

SEED = 0        # seeds torch and Python `random` so the dataset is reproducible
N_SRC = 50      # number of source latents / images
N_MIX = 2500    # number of style-mixed images
SRC_DIR = '/content/drive/My Drive/CV/srcim_pytorch_stylegan_256'
MIX_DIR = '/content/drive/My Drive/CV/dataset_ffhq_sgan_256'
RND_PATH = '/content/drive/My Drive/CV/rndforall.txt'
# NOTE: newer torchvision renamed make_grid's `range` kwarg to `value_range` and may
# reject `range`; switch the key if save_image raises a TypeError.
SAVE_KW = dict(nrow=2, normalize=True, range=(-1, 1))

os.makedirs(SRC_DIR, exist_ok=True)
os.makedirs(MIX_DIR, exist_ok=True)
random.seed(SEED)
torch.manual_seed(SEED)

# Source images: inference only, so skip autograd bookkeeping.
src_codes = []
with torch.no_grad():
    for i in range(N_SRC):
        src_code = torch.randn(1, 512).to(device)
        src_codes.append(src_code)
        torch.save(src_code, f'{SRC_DIR}/src256_{i}.pt')
        src_im = generator(src_code, step=step, alpha=1, mean_style=mean_style, style_weight=0.7)
        utils.save_image(src_im, f'{SRC_DIR}/src256_{i}.png', **SAVE_KW)

# Style-mixed images: reuse the in-memory source codes (the same tensors that
# were saved above) instead of re-reading a .pt file from Drive every iteration.
rnd_list = []
for j in range(N_MIX):
    randm = randint(0, N_SRC - 1)
    rnd_list.append(randm)
    img = style_mixing(generator, step, mean_style, 1, 1, device, src_codes[randm])
    utils.save_image(img, f'{MIX_DIR}/final_{j}.png', **SAVE_KW)

# Save the source index of every mixed image (a pickled list, despite the .txt name).
with open(RND_PATH, 'wb') as file:
    pickle.dump(rnd_list, file)

In [None]:
# Hand-assigned gender/age labels for the 50 source images. Each style-mixed image
# inherits the labels of the source image it was mixed from; results are pickled.

# 512x512 source gender,age values (kept for reference; unused for the 256px dataset)
# src_gndr = [1,0,1,0,1,0,0,0,0,0,1,1,0,1,0,1,1,1,1,1,1,1,0,0,0,1,0,1,0,0,1,1,1,0,1,1,1,0,0,0,0,0,0,0,1,0,1,0,0,0]  ## 0-female 1-male
# src_age = [1,1,1,1,1,1,1,1,0,1,0,0,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,0,1,1,0]  ## 0-child 1-adult

# 256x256 source gender,age values
src_gndr = [0,0,0,1,0,1,1,0,1,0,0,0,0,1,0,0,0,1,0,1,1,0,0,0,1,1,1,0,0,0,1,1,1,0,1,1,1,0,0,1,1,1,1,1,0,0,0,1,1,0]  ## 0-female 1-male
src_age = [1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,0,1,1,0,0,1,1,1,0,1,1,0,1,0,1,1,0,1,1,1,1,1,1,1,1,0,0,1,0]  ## 0-child 1-adult

# Reload the source indices written by the generation cell so this cell does not
# depend on kernel state (it also works after a restart).
with open('/content/drive/My Drive/CV/rndforall.txt', 'rb') as file:
    rnd_list = pickle.load(file)

assert len(src_gndr) == len(src_age), 'gender and age label lists differ in length'
assert set(src_gndr) <= {0, 1} and set(src_age) <= {0, 1}, 'labels must be 0/1'
# Negative indices would silently wrap around, so check the range explicitly.
assert all(0 <= idx < len(src_gndr) for idx in rnd_list), 'source index out of range'

gndr_fnl = [src_gndr[itm] for itm in rnd_list]
age_fnl = [src_age[itm] for itm in rnd_list]
with open('/content/drive/My Drive/CV/gender_ffhq256.txt', 'wb') as file:
    pickle.dump(gndr_fnl, file)
with open('/content/drive/My Drive/CV/age_ffhq256.txt', 'wb') as file:
    pickle.dump(age_fnl, file)