In [1]:
import os
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [2]:
# Removed num_to_groups
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))


In [3]:
#!pip install -q -U einops datasets matplotlib tqdm

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F
from pathlib import Path
from torch.optim import AdamW
from PIL import Image
import requests
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize
import numpy as np
from torch.utils.data import DataLoader
import time
import pandas as pd

import dnnlib
import torchvision
from torchvision.utils import save_image

from utils.ema_pytorch import EMA
from torch.cuda.amp import autocast, GradScaler
from cleanfid import fid

from utils.losses_samples import *
from utils.blocks import *
from utils.elucidating import *
from utils.persistence import *
from utils.misc import *


import cv2
import gc
gc.collect()
torch.cuda.empty_cache()

%matplotlib inline

In [4]:
device='cuda'

In [5]:
import lpips
loss_fn_alex = lpips.LPIPS(net='alex').to(device)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: D:\Python\anaconda3\lib\site-packages\lpips\weights\v0.1\alex.pth


In [6]:
class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        print("init_dim:\t", init_dim)
        print("dims:\t\t", dims)
        print("in_out:\t\t", in_out)
        
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
            augm_dim = dim * 4
            self.aug_mlp = nn.Sequential(
                nn.Linear(12, dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            
            if ind < 1:
                self.downs.append(
                    nn.ModuleList(
                        [
                            block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                            block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                            Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                            Downsample(dim_out) if not is_last else nn.Identity(),
                            None
                        ]
                    )
                )
            else:
                self.downs.append(
                    nn.ModuleList(
                        [
                            block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                            block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                            Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                            Downsample(dim_out) if not is_last else nn.Identity(),
                            SEBlock(in_out[ind-1][1], dim_out)
                            #None
                        ]
                    )
                )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            
            if ind < 1:
                self.ups.append(
                    nn.ModuleList(
                        [
                            block_klass(dim_out * 2, dim_out, time_emb_dim=time_dim),
                            block_klass(dim_out, dim_in, time_emb_dim=time_dim),
                            Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                            Upsample(dim_in) if not is_last else nn.Identity(),
                            None
                        ]
                    )
                )
            else:
                self.ups.append(
                    nn.ModuleList(
                        [
                            block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                            block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                            Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                            Upsample(dim_in) if not is_last else nn.Identity(),
                            SEBlock(in_out[len(in_out)-ind][1], dim_in)
                            #None
                        ]
                    )
                )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time, augm):
        x = self.init_conv(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        aug = self.aug_mlp(augm)
        h = []
        se_up = []
        se_down = []

        # downsample
        for block1, block2, attn, downsample, se_layer in self.downs:
            x = block1(x, t, aug)
            x = block2(x, t, aug)
            x = attn(x)
            h.append(x)
            se_down.append(x) 
            x = downsample(x)
            if se_layer is not None:
                x = se_layer(se_down.pop(0), x)
            

        # bottleneck
        x = self.mid_block1(x, t, aug)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t, aug)
        se_up.append(x)
        
        #upsample
        for block1, block2, attn, upsample, se_layer in self.ups:
            
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t, aug)
            x = block2(x, t, aug)
            x = attn(x)
            se_up.append(x)
            x = upsample(x)
            if se_layer is not None:
                x = se_layer(se_up.pop(0), x)
        
        return self.final_conv(x)


In [8]:
# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

def get_noisy_image(x_start, t):
    # add noise
    x_noisy = q_sample(x_start, t=t)

    # turn back into PIL image
    noisy_image = reverse_transform(x_noisy.squeeze())
    return noisy_image



#SAMPLING_IMGAES
#----------------------------------

@torch.no_grad()
def p_sample(model, x, t, t_index, batch_size):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t, torch.zeros(batch_size, 12).to(device)) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

##New
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
        preds = self.model_predictions(x, t, x_self_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

@torch.no_grad()
def p_sample_new(self, x, t: int, x_self_cond = None):
    b, *_, device = *x.shape, x.device
    batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
    model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
    noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
    pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
    return pred_img, x_start

# Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape, batch_size):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in reversed(range(0, timesteps)):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i, batch_size)
        #imgs.append(img.cpu().numpy())
        imgs.append(img)
    return imgs

@torch.no_grad()
def p_sample_loop_new(self, shape, return_all_timesteps = False):
    batch, device = shape[0], self.betas.device

    img = torch.randn(shape, device = device)
    imgs = [img]

    x_start = None

    for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
        self_cond = x_start if self.self_condition else None
        img, x_start = self.p_sample(img, t, self_cond)
        imgs.append(img)

    ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

    ret = self.unnormalize(ret)
    return ret

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size), batch_size=batch_size)

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1", augm=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t, augm)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

In [9]:
from collections import namedtuple

ModelPrediction =  namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])


def predict_start_from_noise(x_t, t, noise):
    return (
        extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
        extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
    )

def predict_noise_from_start(x_t, t, x0):
    return (
        (extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
        extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    )

def predict_v(self, x_start, t, noise):
    return (
        extract(sqrt_alphas_cumprod, t, x_start.shape) * noise -
        extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
    )

def predict_start_from_v(self, x_t, t, v):
    return (
        extract(sqrt_alphas_cumprod, t, x_t.shape) * x_t -
        extract(sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
    )


def model_predictions(model, x, t, x_self_cond=None, clip_x_start=False, rederive_pred_noise=False, objective='pred_noise'):
    model_output = model(x, t, torch.zeros(x.shape[0], 12).to(device))
    maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
    
    if objective == 'pred_noise':
        pred_noise = model_output
        x_start = predict_start_from_noise(x, t, pred_noise)
        x_start = maybe_clip(x_start)

        if clip_x_start and rederive_pred_noise:
            pred_noise = predict_noise_from_start(x, t, x_start)

    elif self.objective == 'pred_x0':
        x_start = model_output
        x_start = maybe_clip(x_start)
        pred_noise = predict_noise_from_start(x, t, x_start)

    elif self.objective == 'pred_v':
        v = model_output
        x_start = predict_start_from_v(x, t, v)
        x_start = maybe_clip(x_start)
        pred_noise = predict_noise_from_start(x, t, x_start)
        
    return ModelPrediction(pred_noise, x_start)

        

def ddim_sample(model, shape, device, total_timesteps, sampling_timesteps, ddim_sampling_eta, return_all_timesteps = False, objective='pred_noise'):
        
        batch = shape[0]

        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device = device)
        imgs = [img]

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            #self_cond = x_start if self.self_condition else None
            pred_noise, x_start, *_ = model_predictions(model, img, time_cond, x_self_cond=None, clip_x_start = True, rederive_pred_noise = True, objective=objective)

            if time_next < 0:
                img = x_start
                imgs.append(img)
                continue

            alpha = alphas_cumprod[time]
            alpha_next = alphas_cumprod[time_next]

            sigma = ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

        #ret = unnormalize(ret)
        return ret

In [125]:
def get_run_name(dataset, path, run):
    name = dataset+"_"
    dim_mults = None
    aug_fact = None
    conv_mult = None
    
    if '(1, 2, 2)' in run:
        name += '(1,2,2)'
        dim_mults = (1,2,2)
        conv_mult = 2
    elif '(1, 2)' in run:
        name += '(1,2)'
        dim_mults = (1,2)
        conv_mult = 2
    elif '(1, 2, 4)' in run:
        name += '(1,2,4)'
        dim_mults = (1,2,4)
        conv_mult = 3
    else: 
        print("run modelSize could not be determined: ",run)
        return None
    
    name += '_'
    
    if '0.5' in run:
        name += '0.5'
        aug_fact = 0.5
    elif '0.25' in run:
        name += '0.25'
        aug_fact = 0.25
    else:
        print("run augFactor could not be determined: ",run)
        return None
    
    return name, dim_mults, aug_fact, conv_mult

def initialize_testing():
    global timesteps
    global betas
    global alphas
    global alphas_cumprod
    global alphas_cumprod_prev
    global sqrt_recip_alphas
    global sqrt_recip_alphas_cumprod
    global sqrt_recipm1_alphas_cumprod
    global sqrt_alphas_cumprod
    global sqrt_one_minus_alphas_cumprod
    global posterior_variance 

    betas = linear_beta_schedule(timesteps=1000)
        
    timesteps = 1000
    
    #calc_schedules(betas, ts, device)
    # define alphas 
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
    sqrt_recip_alphas_cumprod = torch.sqrt(1.0/alphas_cumprod)

    # calculations for diffusion q(x_t | x_{t-1}) and others
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
    sqrt_recipm1_alphas_cumprod =  torch.sqrt(1. / alphas_cumprod - 1)

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    
    
def calculate_lpips(real, sampled, intrinsic):
    lpips_min_list = []
    closest_original_list, lpips_list = [], []
    overfit = False
        
    for i, sample_image in enumerate(sampled):
        min_original_image = None
        
        lpips_scores = loss_fn_alex(sample_image.to(device), real.to(device))
        lpips_min_list.append(torch.min(lpips_scores).item())
        min_original_image = real[torch.argmin(lpips_scores)].unsqueeze(0).cpu()

        closest_original_list.append(min_original_image)
        
    lpips_min_avg = np.mean(lpips_min_list)
    lpips_min_min = np.mean(np.sort(lpips_min_list)[:10])
    lpips_fraction = lpips_min_avg/intrinsic[0]
    
    if lpips_min_min < 0.05:
        overfit = True
    
    closest_tensor = torch.stack(closest_original_list, dim=0).view(-1,3,96,96)
    
    return lpips_min_avg, lpips_min_min, lpips_fraction, overfit, closest_tensor

def calculate_clip_FID(dataset):
    if dataset == 'Grumpy':
        res= !python fcd.py --path_source "./train_data/Grumpy/img/" --path_test "./score_calculations/"
    elif dataset == 'Obama':
        res= !python fcd.py --path_source "./test_data/Obama/img/" --path_test "./score_calculations/"
    elif dataset == 'Cat':
        res= !python fcd.py --path_source "./train_data/Cat/img/" --path_test "./score_calculations/"
    elif dataset == 'Dog':
        res= !python fcd.py --path_source "./train_data/Dog/img/" --path_test "./score_calculations/"
    elif dataset == 'Panda':
        res= !python fcd.py --path_source "./train_data/Panda/img/" --path_test "./score_calculations/"
    elif dataset == 'Influ':
        res= !python fcd.py --path_source "./train_data/Influ/img/" --path_test "./score_calculations/"
    elif dataset == 'Poke':
        res= !python fcd.py --path_source "./train_data/Poke/img/" --path_test "./score_calculations/"
    elif dataset == 'Art':
        res= !python fcd.py --path_source "./train_data/Art/img/" --path_test "./score_calculations/"
    elif dataset == 'Fauvism':
        res= !python fcd.py --path_source "./train_data/Fauvism/img/" --path_test "./score_calculations/"
    elif dataset == 'Moon':
        res= !python fcd.py --path_source "./train_data/Moon/img/" --path_test "./score_calculations/"
    elif dataset == 'Anime':
        res= !python fcd.py --path_source "./train_data/Anime/img/" --path_test "./score_calculations/"
    elif dataset == 'Shells':
        res= !python fcd.py --path_source "./train_data/Shells/img/" --path_test "./score_calculations/"
    elif dataset == 'Skulls':
        res= !python fcd.py --path_source "./train_data/Skulls/img/" --path_test "./score_calculations/"
    elif dataset == 'FFHQ':
        res= !python fcd.py --path_source "./train_data/FFHQ/img/" --path_test "./score_calculations/"
    elif dataset == 'CelebA':
        print(dataset)
        res= !python fcd.py --path_source "./train_data/CelebA/img/" --path_test "./score_calculations/"
    elif dataset == 'LSUNBed':
        res= !python fcd.py --path_source "./test_data/LSUNBed/img/" --path_test "./score_calculations/"
    elif dataset == 'Dining':
        res= !python fcd.py --path_source "./train_data/Dining/" --path_test "./score_calculations/"
    elif dataset == 'Garbage':
        res= !python fcd.py --path_source "./train_data/Garbage/img/" --path_test "./score_calculations/"
    else:
        return None
    
    return res
                                  


def test_run_final():
    result_dict = dict()
    result_dict['Features'] = ['Dataset','Dim_mults','Augmentation','Iteration','FID','KID','Clip_FID','LPIPS_MinAvg','LPIPS_min','LPIPS_norm','Overfit']
    
    
    model_path = './model_weights/'
    sample_path = './score_calculations/'
    
    intrinsic_df = pd.read_csv('./intrinsic_scores.csv',delimiter=";")
    intrinsic_df.rename(columns={'Unnamed: 0':'Name'}, inplace=True)
    intrinsic_df.loc[:,'Name'][intrinsic_df.Name=='Shell'] = 'Shells'
    intrinsic_df.loc[:,'Name'][intrinsic_df.Name=='Skull'] = 'Skulls'
    
    intrinsic_df = intrinsic_df[['Name','LPIPS']]
    
    display(intrinsic_df)
    
    
    
    #Initialize global variables
    initialize_testing()
    
    #Iterate through datasets
    dataset_list = os.listdir('./model_weights/')
    dataset_test_list = os.listdir()

    #Iterate through dataset
    for dataset in dataset_list:
        if dataset == '.ipynb_checkpoints':
            continue
            
        dataset_path = r"./train_data/"+dataset
        sample_path = './score_calculations/'
        save_path = Path("./sampled_images/{}/".format(dataset))
        save_path.mkdir(exist_ok = True)
        
            
        #load datapoints of this dataset into memory
        dataloader = get_data(dataset_path, batch_size=44, image_size=96)
        full_tens = list()        
        for batch in dataloader:
            batch = batch[0].to(device)
            full_tens.append(batch)
        full_tens = torch.cat(full_tens,dim=0)
        print(full_tens.shape)
        
        
               
        for run in os.listdir(model_path+dataset+'/'):
            print("run: ", run)
            done_flag = False
            if run == '.ipynb_checkpoints':
                continue
            
            run_name, dim_mults, aug_factor, conv_mult = get_run_name(dataset=dataset, path='./model_weights/'+dataset+'/', run=run)
            complete_path = model_path + dataset +'/'+run+'/'
            if dim_mults == (1, 2, 2) and aug_factor == 0.25:
                print("cont")
                continue
            
            
            starting_point = 600000
            plus_one = True     
           
            checkpoint = 'ema-'+str(starting_point)+'.tar'
        
            while checkpoint not in os.listdir(complete_path):
                
                if plus_one:
                    starting_point -= 9999
                    plus_one = False
                else:
                    starting_point -= 10000
                checkpoint = 'ema-'+str(starting_point)+'.tar'
                if starting_point < 50000:
                    break
            if starting_point < 50000:
                    break
                    
            save_path = Path("./sampled_images/{}/{}_{}/".format(dataset, dim_mults,aug_factor))
            save_path.mkdir(exist_ok = True)
            print("start_iteration")
            print(starting_point)
            for m in range(5):
                #Create the directory for the final images
                checkpoint = 'ema-'+str(starting_point)+'.tar'
                save_path = Path("./sampled_images/{}/{}_{}/{}/".format(dataset,dim_mults,aug_factor,starting_point))
                save_path.mkdir(exist_ok = True)
                
                
                print("Checkpoint:\t",checkpoint)
                result_list = [dataset, dim_mults, aug_factor]
                #Delete all files in the sample directory
                for f in os.listdir(sample_path):
                    os.remove(os.path.join(sample_path, f))
                    
                    
                model = Unet2(
                    dim=64,
                    channels=3,
                    dim_mults=dim_mults,
                    convnext_mult = conv_mult
                    )
                    
                ema = EMA(
                        model
                    ).to(device)
                
                    
                if checkpoint not in os.listdir(complete_path+'/'):
                    if plus_one:
                        starting_point -= 9999
                        plus_one=False 
                    else:
                        starting_point -= 10000
                    continue

                ema.load_state_dict(torch.load(complete_path+'/'+checkpoint, map_location=torch.device('cuda')))

                ema.eval()
                    
                #Sample new images and normalize
                all_images_list_ema = []
                with torch.no_grad():
                    for i in range(2):
                        all_images_list_ema.append(ddim_sample(model, (16,3,96,96), device='cuda', total_timesteps=1000, sampling_timesteps=50, ddim_sampling_eta=0))

                sample_image_ema = map_interval(torch.cat(all_images_list_ema, axis=0))
                    
                for j, image in enumerate(sample_image_ema):
                    save_image((image+1)*0.5, './score_calculations/sample_{}.png'.format(j))
                    
                #Calculate LPIPS scores and closes sample
                lpips_min_avg, lpips_min_min, lpips_fraction, overfit, closest_tensor = calculate_lpips(full_tens, sample_image_ema,intrinsic_df[intrinsic_df.Name==dataset]['LPIPS'].values)
                    
                #Calculate FID scores
                print(dataset_path+"/img/")
                fid_score = fid.compute_fid(dataset_path+"/img/", sample_path,num_workers=0)
                

                    
                #Calculate KID scores
                kid_score = fid.compute_kid(dataset_path, sample_path, mode='clean',num_workers=0)
                
                #Calculate Clip-FID scores
                clip_fid = calculate_clip_FID(dataset)
                print(clip_fid)
                print(clip_fid[-1].replace(',','.'))
                try:
                    clip_fid = float(clip_fid[-1])
                except:
                    clip_fid = -1.  
                    

                save_image((sample_image_ema[0:32]+1)*0.5, "./sampled_images/{}/{}_{}/{}/sample_images.png".format(dataset,dim_mults,aug_factor,starting_point))
                save_image((closest_tensor[0:32]+1)*0.5, "./sampled_images/{}/{}_{}/{}/real_images.png".format(dataset,dim_mults,aug_factor,starting_point))
                for k in range(10):
                    save_image((sample_image_ema[k]+1)*0.5, "./sampled_images/{}/{}_{}/{}/sample_image_{}.png".format(dataset,dim_mults,aug_factor,starting_point,k))
                    save_image((closest_tensor[k]+1)*0.5, "./sampled_images/{}/{}_{}/{}/real_image_{}.png".format(dataset,dim_mults,aug_factor,starting_point,k))
                    
                    
                #Write results into a result_list
                result_list += [starting_point, fid_score, kid_score, clip_fid, lpips_min_avg, lpips_min_min, lpips_fraction, overfit]
                print(result_list)
                if plus_one:
                    starting_point -= 19999
                    plus_one=False 
                elif not plus_one and overfit:
                    starting_point -= 40000
                else:
                    starting_point -= 20000
                    
                result_dict[dataset+str(dim_mults)+str(m)+'_'+str(aug_factor)] = result_list
        
            
                result_df= pd.DataFrame.from_dict(result_dict,orient='index').transpose().set_index('Features').transpose()
                display(result_df)
                result_df.to_csv('./score_tables/all_scores.csv')
                        
                    
                
                
                
            
            
            

In [128]:
df = test_run_final()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  intrinsic_df.loc[:,'Name'][intrinsic_df.Name=='Shell'] = 'Shells'
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  intrinsic_df.loc[:,'Name'][intrinsic_df.Name=='Skull'] = 'Skulls'


Unnamed: 0,Name,LPIPS
0,Obama,0.21
1,Grumpy,0.2
2,Cat,0.21
3,Dog,0.2
4,Panda,0.22
5,Influ,0.27
6,Moon,0.25
7,Anime,0.23
8,Poke,0.24
9,Art,0.26


torch.Size([100, 3, 96, 96])
run:  CelebA96_linear_16_64_(1, 2)_1000_0.25_l20.0003_1
start_iteration
600000
Checkpoint:	 ema-600000.tar
init_dim:	 None
dims:		 [42, 64, 128]
in_out:		 [(42, 64), (64, 128)]


sampling loop time step:   0%|          | 0/50 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/50 [00:00<?, ?it/s]

./train_data/CelebA/img/
./
./inception-2015-12-05.pt
<torch.jit.CompilationUnit object at 0x0000029376D39A30>
None
{}
compute FID between two folders
Found 200 images in the folder ./train_data/CelebA/img/


FID  : 100%|█████████████████████████████████████████████████████████████████████████████| 7/7 [00:08<00:00,  1.17s/it]


Found 64 images in the folder ./score_calculations/


FID  : 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.72it/s]


./
./inception-2015-12-05.pt
<torch.jit.CompilationUnit object at 0x0000029376D77630>
None
{}
compute KID between two folders
Found 200 images in the folder ./train_data/CelebA


KID CelebA : 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:08<00:00,  1.23s/it]


Found 64 images in the folder ./score_calculations/


KID  : 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.09it/s]


CelebA
['Preprocess Images from : ./train_data/CelebA/img/', '', '  0%|          | 0/100 [00:00<?, ?it/s]', ' 55%|#####5    | 55/100 [00:00<00:00, 547.11it/s]', '100%|##########| 100/100 [00:00<00:00, 579.46it/s]', 'Preprocess Images from : ./score_calculations/', '', '  0%|          | 0/32 [00:00<?, ?it/s]', '100%|##########| 32/32 [00:00<00:00, 695.65it/s]', 'Infernce from CLIP', '', '', '0it [00:00, ?it/s]', '1it [00:03,  3.03s/it]', '1it [00:03,  3.03s/it]', 'Calc FCD Score:', '', '47.41742767055652']
47.41742767055652
['CelebA', (1, 2), 0.25, 600000, 150.0366949594007, 0.042815012137095175, 47.41742767055652, 0.16459925496019423, 0.10233676806092262, 0.8229962748009712, False]


Features,Dataset,Dim_mults,Augmentation,Iteration,FID,KID,Clip_FID,LPIPS_MinAvg,LPIPS_min,LPIPS_norm,Overfit
"CelebA(1, 2)0_0.25",CelebA,"(1, 2)",0.25,600000,150.036695,0.042815,47.417428,0.164599,0.102337,0.822996,False


Checkpoint:	 ema-580001.tar
init_dim:	 None
dims:		 [42, 64, 128]
in_out:		 [(42, 64), (64, 128)]


sampling loop time step:   0%|          | 0/50 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/50 [00:00<?, ?it/s]

./train_data/CelebA/img/
./
./inception-2015-12-05.pt
<torch.jit.CompilationUnit object at 0x0000029374600C70>
None
{}
compute FID between two folders
Found 200 images in the folder ./train_data/CelebA/img/


FID  : 100%|█████████████████████████████████████████████████████████████████████████████| 7/7 [00:08<00:00,  1.19s/it]


Found 64 images in the folder ./score_calculations/


FID  : 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.85it/s]


./
./inception-2015-12-05.pt
<torch.jit.CompilationUnit object at 0x00000293744C3F70>
None
{}
compute KID between two folders
Found 200 images in the folder ./train_data/CelebA


KID CelebA : 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:08<00:00,  1.21s/it]


Found 64 images in the folder ./score_calculations/


KID  : 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.02it/s]


CelebA
['Preprocess Images from : ./train_data/CelebA/img/', '', '  0%|          | 0/100 [00:00<?, ?it/s]', ' 70%|#######   | 70/100 [00:00<00:00, 689.57it/s]', '100%|##########| 100/100 [00:00<00:00, 709.11it/s]', 'Preprocess Images from : ./score_calculations/', '', '  0%|          | 0/32 [00:00<?, ?it/s]', '100%|##########| 32/32 [00:00<00:00, 727.20it/s]', 'Infernce from CLIP', '', '', '0it [00:00, ?it/s]', '1it [00:04,  4.14s/it]', '1it [00:04,  4.14s/it]', 'Calc FCD Score:', '', '49.94611490685234']
49.94611490685234
['CelebA', (1, 2), 0.25, 580001, 153.57381260982584, 0.05216074379663618, 49.94611490685234, 0.1812200634740293, 0.11702277511358261, 0.9061003173701465, False]


Features,Dataset,Dim_mults,Augmentation,Iteration,FID,KID,Clip_FID,LPIPS_MinAvg,LPIPS_min,LPIPS_norm,Overfit
"CelebA(1, 2)0_0.25",CelebA,"(1, 2)",0.25,600000,150.036695,0.042815,47.417428,0.164599,0.102337,0.822996,False
"CelebA(1, 2)1_0.25",CelebA,"(1, 2)",0.25,580001,153.573813,0.052161,49.946115,0.18122,0.117023,0.9061,False


Checkpoint:	 ema-560001.tar
init_dim:	 None
dims:		 [42, 64, 128]
in_out:		 [(42, 64), (64, 128)]


sampling loop time step:   0%|          | 0/50 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/50 [00:00<?, ?it/s]

./train_data/CelebA/img/
./
./inception-2015-12-05.pt
<torch.jit.CompilationUnit object at 0x0000029374CECBF0>
None
{}
compute FID between two folders
Found 200 images in the folder ./train_data/CelebA/img/


FID  : 100%|█████████████████████████████████████████████████████████████████████████████| 7/7 [00:08<00:00,  1.18s/it]


Found 64 images in the folder ./score_calculations/


FID  : 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.67it/s]


./
./inception-2015-12-05.pt
<torch.jit.CompilationUnit object at 0x0000029376D4C7B0>
None
{}
compute KID between two folders
Found 200 images in the folder ./train_data/CelebA


KID CelebA : 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:08<00:00,  1.22s/it]


Found 64 images in the folder ./score_calculations/


KID  : 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.89it/s]


CelebA
['Preprocess Images from : ./train_data/CelebA/img/', '', '  0%|          | 0/100 [00:00<?, ?it/s]', ' 70%|#######   | 70/100 [00:00<00:00, 696.43it/s]', '100%|##########| 100/100 [00:00<00:00, 709.11it/s]', 'Preprocess Images from : ./score_calculations/', '', '  0%|          | 0/32 [00:00<?, ?it/s]', '100%|##########| 32/32 [00:00<00:00, 711.08it/s]', 'Infernce from CLIP', '', '', '0it [00:00, ?it/s]', '1it [00:04,  4.17s/it]', '1it [00:04,  4.17s/it]', 'Calc FCD Score:', '', '48.89877148784375']
48.89877148784375
['CelebA', (1, 2), 0.25, 560001, 142.77305663633018, 0.038254205612909244, 48.89877148784375, 0.17621236538980156, 0.08999713696539402, 0.8810618269490078, False]


Features,Dataset,Dim_mults,Augmentation,Iteration,FID,KID,Clip_FID,LPIPS_MinAvg,LPIPS_min,LPIPS_norm,Overfit
"CelebA(1, 2)0_0.25",CelebA,"(1, 2)",0.25,600000,150.036695,0.042815,47.417428,0.164599,0.102337,0.822996,False
"CelebA(1, 2)1_0.25",CelebA,"(1, 2)",0.25,580001,153.573813,0.052161,49.946115,0.18122,0.117023,0.9061,False
"CelebA(1, 2)2_0.25",CelebA,"(1, 2)",0.25,560001,142.773057,0.038254,48.898771,0.176212,0.089997,0.881062,False


Checkpoint:	 ema-540001.tar
init_dim:	 None
dims:		 [42, 64, 128]
in_out:		 [(42, 64), (64, 128)]
Checkpoint:	 ema-530001.tar
init_dim:	 None
dims:		 [42, 64, 128]
in_out:		 [(42, 64), (64, 128)]
