In [1]:
from dataset import SourceImgDataset
from torch.utils.data import DataLoader

# cat
dataset = SourceImgDataset('/home/jupyter/datasphere/project/edm/datasets/afhqv2-64x64', lbl_val=0)
train_loader = DataLoader(dataset, batch_size=256, shuffle=False)

# wild
tgt_dataset = SourceImgDataset('/home/jupyter/datasphere/project/edm/datasets/afhqv2-64x64', lbl_val=1)
tgt_loader = DataLoader(tgt_dataset, batch_size=256, shuffle=False)

In [2]:
%cd edm
from dnnlib import util
import torch_utils
%cd ..
import pickle

# load pretrained diffusion
device = 'cuda:0'
with util.open_url('wild64.pkl') as f:
    net = pickle.load(f)['ema'].to(device)

/home/jupyter/work/resources/edm
/home/jupyter/work/resources


In [None]:
# SDEDIT

%cd edm
from fid import calculate_inception_stats, calculate_fid_from_inception_stats
from dnnlib.util import open_url
%cd ..

from metrics import compute_metrics_and_save_imgs, save_model_samples
import json

sampling_params = {
    'device': 'cuda',
    'sigma_min': 0.02,
    'sigma_max': 10.0,
    'num_steps': 10,
    'rho': 7.0,
    'vis_steps': 1,
    'stochastic': False,
    'cfg': 0
}

exp_results = {}

batch = 512
num_samples = 1024
orig_path = 'orig_imgs'
gen_path = 'gen_imgs'
save_model_samples(orig_path, tgt_loader, num_samples)
mu_real, sigma_real = calculate_inception_stats(image_path=orig_path, num_expected=num_samples, max_batch_size=batch)

sigmas = [5, 10, 25, 40]
steps = [18, 32, 50]

for sigma in sigmas:
    for step in steps:
        exp_name = 'sigma={:.1f};n_steps={}'.format(sigma, step)

        sampling_params['sigma_max'] = sigma
        sampling_params['num_steps'] = step
        res_json = compute_metrics_and_save_imgs(gen_path, train_loader, 'sdedit', net, sampling_params, to_see=num_samples)
        
        mu_gen, sigma_gen = calculate_inception_stats(image_path=gen_path, num_expected=num_samples, max_batch_size=batch)
        fid = calculate_fid_from_inception_stats(mu_gen, sigma_gen, mu_real, sigma_real)
        
        res_json['FID'] = fid
        exp_results[exp_name] = res_json
        
        with open('sdedit_results.json', 'w') as f:
            json.dump(exp_results, f)

/home/jupyter/work/resources/edm
/home/jupyter/work/resources


1024 images saved: 100%|██████████| 1024/1024 [00:03<00:00, 277.48it/s]


Loading Inception-v3 model...
Loading images from "orig_imgs"...
Calculating statistics for 1024 images...


100%|██████████| 2/2 [00:02<00:00,  1.47s/batch]


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /tmp/xdg_cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:09<00:00, 60.1MB/s] 


Loading model from: /home/jupyter/.local/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


 14%|█▎        | 3/22 [09:37<1:00:59, 192.61s/it]

Loading Inception-v3 model...





Loading images from "gen_imgs"...
Calculating statistics for 1024 images...


100%|██████████| 2/2 [00:02<00:00,  1.28s/batch]


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /home/jupyter/.local/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


 14%|█▎        | 3/22 [17:07<1:48:26, 342.45s/it]

Loading Inception-v3 model...





Loading images from "gen_imgs"...
Calculating statistics for 1024 images...


100%|██████████| 2/2 [00:02<00:00,  1.26s/batch]


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /home/jupyter/.local/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


  0%|          | 0/22 [00:00<?, ?it/s]

In [None]:
#ILVR

%cd edm
from fid import calculate_inception_stats, calculate_fid_from_inception_stats
from dnnlib.util import open_url
%cd ..

from metrics import compute_metrics_and_save_imgs, save_model_samples
import json

sampling_params = {
    'device': 'cuda',
    'sigma_min': 0.02,
    'sigma_max': 80.0,
    'num_steps': 10,
    'rho': 7.0,
    'vis_steps': 1,
    'stochastic': False,
    'cfg': 0,
    'scale_factor': 2
}

exp_results = {}

batch = 512
num_samples = 1024
#orig_path = 'orig_imgs'
gen_path = 'gen_imgs'
#save_model_samples(orig_path, tgt_loader, num_samples)
#mu_real, sigma_real = calculate_inception_stats(image_path=orig_path, num_expected=num_samples, max_batch_size=batch)

Ns = [4, 8, 16, 32]
steps = [18, 32, 50]

for N in Ns:
    for step in steps:
        exp_name = 'N={};n_steps={}'.format(N, step)

        sampling_params['scale_factor'] = N
        sampling_params['num_steps'] = step
        res_json = compute_metrics_and_save_imgs(gen_path, train_loader, 'ilvr', net, sampling_params, to_see=num_samples)
        
        mu_gen, sigma_gen = calculate_inception_stats(image_path=gen_path, num_expected=num_samples, max_batch_size=batch)
        fid = calculate_fid_from_inception_stats(mu_gen, sigma_gen, mu_real, sigma_real)
        
        res_json['FID'] = fid
        exp_results[exp_name] = res_json
        
        with open('ilvr_results.json', 'w') as f:
            json.dump(exp_results, f)

In [None]:
# EGSDE

# Uncomment this to train classifier model

# from train_classifier import train_loop
# from dataset import CombinedImgDataset
# from torch.utils.data import DataLoader

# dataset = CombinedImgDataset('/home/jupyter/datasphere/project/edm/datasets/afhqv2-64x64')

# train_dataset, val_dataset = torch.utils.data.random_split(dataset, [9634, 1000])

# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# class_model = class_model.to('cuda')
# classifier_model = train_loop(train_loader, val_loader, class_model, 10)

In [3]:
%cd guided-diffusion
from guided_diffusion.unet import EncoderUNetModel
from guided_diffusion.nn import timestep_embedding
%cd ..

class EGClassifier(EncoderUNetModel):
    def forward(self, x, timesteps):
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
        res = self.middle_block(h, emb)
        h_res = res.type(x.dtype)
        return res, self.out(h_res)


/home/jupyter/work/resources/guided-diffusion
/home/jupyter/work/resources


In [4]:
class_model = EGClassifier(image_size=64,
                           in_channels=3,
                           out_channels=1000,
                           model_channels=128,
                           channel_mult=(1,2,3,4),
                           attention_resolutions=[32, 16, 8],
                           use_scale_shift_norm=True,
                           resblock_updown=True,
                           num_res_blocks=4,
                           num_head_channels=64,
                           pool='attention'
                           )

In [5]:
import torch
import torch.nn as nn

pretrained_dict = torch.load('64x64_classifier.pt', map_location="cuda")
model_dict = class_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
class_model.load_state_dict(model_dict, strict=False)

class_model.out[2].c_proj = nn.Conv1d(512, 2,  kernel_size=(1,), stride=(1,))

In [6]:
pretrained_class = torch.load('checkpoints/checkpoint_epoch_9', map_location="cuda")
class_model.load_state_dict(pretrained_class)

# need to lower batch size
train_loader = DataLoader(dataset, batch_size=32, shuffle=False)

In [None]:
%cd edm
from fid import calculate_inception_stats, calculate_fid_from_inception_stats
from dnnlib.util import open_url
%cd ..

from metrics import compute_metrics_and_save_imgs, save_model_samples
import json

sampling_params = {
    'device': 'cuda',
    'sigma_min': 0.02,
    'sigma_max': 80.0,
    'num_steps': 10,
    'rho': 7.0,
    'vis_steps': 1,
    'stochastic': False,
    'cfg': 0,
    'scale_factor': 2,
    'class_model': class_model,
    'l_1': 2,
    'l_2': 500
}

exp_results = {}

batch = 512
num_samples = 1024
#orig_path = 'orig_imgs'
gen_path = 'gen_imgs'
#save_model_samples(orig_path, tgt_loader, num_samples)
#mu_real, sigma_real = calculate_inception_stats(image_path=orig_path, num_expected=num_samples, max_batch_size=batch)

sigmas = [10, 25]
Ns = [8, 16, 32]
steps = [18, 32]

for sigma in sigmas:
    for N in Ns:
        for step in steps:
            exp_name = 'sigma={};N={};n_steps={}'.format(sigma, N, step)

            sampling_params['scale_factor'] = N
            sampling_params['num_steps'] = step
            sampling_params['sigma_max'] = sigma
            res_json = compute_metrics_and_save_imgs(gen_path, train_loader, 'egsde', net, sampling_params, to_see=num_samples)

            mu_gen, sigma_gen = calculate_inception_stats(image_path=gen_path, num_expected=num_samples, max_batch_size=batch)
            fid = calculate_fid_from_inception_stats(mu_gen, sigma_gen, mu_real, sigma_real)

            res_json['FID'] = fid
            exp_results[exp_name] = res_json

            with open('egsde_results.json', 'w') as f:
                json.dump(exp_results, f)

/home/jupyter/work/resources/edm
/home/jupyter/work/resources
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /tmp/xdg_cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:05<00:00, 107MB/s]  


Loading model from: /home/jupyter/.local/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


  2%|▏         | 4/174 [01:27<1:01:31, 21.71s/it]