#### <b>Install the StyleGAN-XL Model</b>

* For using the <b>StyleGAN-XL</b> model instance, the following libraries are required.
* After installation, the <b>[Runtime Restart]</b> is recommended.

In [1]:
# Uninstall the current (new) JAX
!pip uninstall jax jaxlib -y

# GPU front-end installation
!pip install "jax[cuda11_cudnn805]==0.3.10" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Downgrade Pytorch for using the StyleGAN-XL
!pip uninstall torch torchvision -y
!pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
!pip install timm==0.4.12 ftfy==6.1.1 ninja==1.10.2 opensimplex
!pip install dill

Found existing installation: jax 0.3.10
Uninstalling jax-0.3.10:
  Successfully uninstalled jax-0.3.10
Found existing installation: jaxlib 0.3.10+cuda11.cudnn805
Uninstalling jaxlib-0.3.10+cuda11.cudnn805:
  Successfully uninstalled jaxlib-0.3.10+cuda11.cudnn805
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda11_cudnn805]==0.3.10
  Using cached jax-0.3.10-py3-none-any.whl
Collecting jaxlib==0.3.10+cuda11.cudnn805 (from jax[cuda11_cudnn805]==0.3.10)
  Using cached https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.10%2Bcuda11.cudnn805-cp310-none-manylinux2014_x86_64.whl (175.7 MB)
Installing collected packages: jaxlib, jax
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.7 requires jax>=0.4.6, but you have jax 0.3.10 which is incompatible.
flax 0.7.1 requires jax>=0.4.2, but you have

In [2]:
!wget https://postechackr-my.sharepoint.com/:u:/g/personal/dongbinna_postech_ac_kr/EcPAsytGJQVEskKjrrfk-vkB-F2c7_6PigPqdkLR_bAsDQ?download=1 -O cifar10.pkl
!git clone https://github.com/autonomousvision/stylegan-xl

--2023-08-11 16:21:14--  https://postechackr-my.sharepoint.com/:u:/g/personal/dongbinna_postech_ac_kr/EcPAsytGJQVEskKjrrfk-vkB-F2c7_6PigPqdkLR_bAsDQ?download=1
Resolving postechackr-my.sharepoint.com (postechackr-my.sharepoint.com)... 13.107.136.8, 2620:1ec:8f8::8
Connecting to postechackr-my.sharepoint.com (postechackr-my.sharepoint.com)|13.107.136.8|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /personal/dongbinna_postech_ac_kr/Documents/Research/models/StyleGAN_v2/cifar10.pkl?ga=1 [following]
--2023-08-11 16:21:15--  https://postechackr-my.sharepoint.com/personal/dongbinna_postech_ac_kr/Documents/Research/models/StyleGAN_v2/cifar10.pkl?ga=1
Reusing existing connection to postechackr-my.sharepoint.com:443.
HTTP request sent, awaiting response... 200 OK
Length: 766269432 (731M) [application/octet-stream]
Saving to: ‘cifar10.pkl’


2023-08-11 16:21:47 (23.2 MB/s) - ‘cifar10.pkl’ saved [766269432/766269432]

Cloning into 'stylegan-xl'...
remote: Enumerat

#### <b>Load the StyleGAN-XL Model</b>

In [7]:
import os
import sys

sys.path.append("./stylegan-xl")

In [4]:
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
"""Generate images using pretrained network pickle."""

import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import numpy as np
import PIL.Image
import torch

import legacy
from torch_utils import gen_utils

#----------------------------------------------------------------------------

def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.
    Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
    '''
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges

#----------------------------------------------------------------------------

def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    '''Parse a floating point 2-vector of syntax 'a,b'.
    Example:
        '0,1' returns (0,1)
    '''
    if isinstance(s, tuple): return s
    parts = s.split(',')
    if len(parts) == 2:
        return (float(parts[0]), float(parts[1]))
    raise ValueError(f'cannot parse 2-vector {s}')

#----------------------------------------------------------------------------

def make_transform(translate: Tuple[float,float], angle: float):
    m = np.eye(3)
    s = np.sin(angle/360.0*np.pi*2)
    c = np.cos(angle/360.0*np.pi*2)
    m[0][0] = c
    m[0][1] = s
    m[0][2] = translate[0]
    m[1][0] = -s
    m[1][1] = c
    m[1][2] = translate[1]
    return m

#----------------------------------------------------------------------------

In [5]:
import pickle

device = torch.device('cuda')

# Load the pre-trained StyleGAN model.
print('Loading networks from ...')
with dnnlib.util.open_url("./cifar10.pkl") as f:
    G = legacy.load_network_pkl(f)['G_ema']
    G = G.eval().requires_grad_(False).to(device)

# Construct an inverse rotation/translation matrix and pass it to the generator.
rotate = 0
translate = [0,0]

if hasattr(G.synthesis, 'input'):
    m = make_transform(translate, rotate)
    m = np.linalg.inv(m)
    G.synthesis.input.transform.copy_(torch.from_numpy(m))

Loading networks from ...


#### <b>Generate the Latent Bank (Average Latents)</b>

In [6]:
# Calculate the average latent vectors for each class.
batch_sz = 1
truncation_psi = 0
seed = 1234
centroids_path = None

class_centers = []
for class_idx in range(10):
    w = gen_utils.get_w_from_seed(G, batch_sz, device, truncation_psi, seed=seed, centroids_path=centroids_path, class_idx=class_idx)
    class_centers.append(w)

Setting up PyTorch plugin "bias_act_plugin"... Done.


#### <b>Generate the Synthesized OOD Images</b>

In [8]:
import time
import random

outdir = "./DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR"
outdir_grid = "./DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR_grid"

os.makedirs(outdir, exist_ok=True)
os.makedirs(outdir_grid, exist_ok=True)

made_cnt = 0
truncation_psi = 1.0
start_time = time.time()
max_iters = 6250 # 6,250 X 8 = 50,000

for iter in range(max_iters):
    # Generate latent vectors whose size is the batch size.
    batch_sz = 8
    # The number of latent vectors to mix (Multiple Latent Mix-up).
    k = 5
    results = []
    for i in range(batch_sz):
        result = None # A latent vector for a image.
        sampled = random.sample(range(10), k) # Randomly pick k samples from 10 classes.
        for class_idx in sampled:
            seed = random.randint(0, int(2 ** 32 - 1))
            w = gen_utils.get_w_from_seed(G, 1, device, truncation_psi, seed=seed, centroids_path=centroids_path, class_idx=class_idx)
            w = w.to(device)
            # w -= class_centers[class_idx] # Disentangling Marginal Representations (DMR)
            if result == None:
                result = (w / k)
            else:
                result += (w / k)
        results.append(result)
    ws = torch.cat(results, dim=0).to(device)

    # Strengthen the represenrtations with a probability of 50%.
    """
    for i in range(len(ws)):
        random_data = random.randint(1, 2)
        if random_data == 1:
            ws[i] *= 3
    """

    # Generate the images using the latent vectors.
    imgs = gen_utils.w_to_img(G, ws, to_np=True)

    # Save the grid image.
    seed = random.randint(0, int(2 ** 32 - 1))
    PIL.Image.fromarray(gen_utils.create_image_grid(imgs), 'RGB').save(f'{outdir_grid}/seed{seed}.png')

    # Save all images individually.
    for i, img in enumerate(imgs):
        PIL.Image.fromarray(gen_utils.create_image_grid(np.expand_dims(img, axis=0)), 'RGB').save(f'{outdir}/{made_cnt}.png')
        made_cnt += 1

    if (iter + 1) % 10 == 0:
        print(f"[{iter}/{max_iters}] {time.time() - start_time:.2f} seconds elapsed.")

Setting up PyTorch plugin "filtered_lrelu_plugin"... Done.
[9/6250] 181.81 seconds elapsed.
[19/6250] 183.29 seconds elapsed.
[29/6250] 184.98 seconds elapsed.
[39/6250] 186.75 seconds elapsed.
[49/6250] 188.44 seconds elapsed.
[59/6250] 189.90 seconds elapsed.
[69/6250] 191.51 seconds elapsed.
[79/6250] 192.97 seconds elapsed.
[89/6250] 194.44 seconds elapsed.
[99/6250] 195.90 seconds elapsed.
[109/6250] 197.36 seconds elapsed.
[119/6250] 199.03 seconds elapsed.
[129/6250] 200.77 seconds elapsed.
[139/6250] 202.50 seconds elapsed.
[149/6250] 203.99 seconds elapsed.
[159/6250] 205.50 seconds elapsed.
[169/6250] 207.00 seconds elapsed.
[179/6250] 208.47 seconds elapsed.
[189/6250] 209.95 seconds elapsed.
[199/6250] 211.40 seconds elapsed.
[209/6250] 213.01 seconds elapsed.
[219/6250] 214.77 seconds elapsed.
[229/6250] 216.68 seconds elapsed.
[239/6250] 218.17 seconds elapsed.
[249/6250] 219.62 seconds elapsed.
[259/6250] 221.09 seconds elapsed.
[269/6250] 222.55 seconds elapsed.
[279/62

In [9]:
!find ./DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR -type f | wc -l
!find ./DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR_grid -type f | wc -l

50000
6250


In [11]:
!zip -r ./CIFAR10_OOD_training_images_using_MLM_without_DMR.zip ./DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/3416.png (stored 0%)
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/30865.png (stored 0%)
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/39284.png (stored 0%)
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/20179.png (stored 0%)
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/48873.png (stored 0%)
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/8581.png (stored 0%)
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/7796.png (stored 0%)
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/49453.png (stored 0%)
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/48424.png (stored 0%)
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/12099.png (stored 0%)
  adding: DMR/CIFAR10_OOD_training_images_using_MLM_without_DMR/11726.png (stored 0%)
  addin