In [1]:
from torchmetrics.image.fid import FrechetInceptionDistance
import pandas as pd
import cv2
import numpy as np
from PIL import Image

  warn(f"Failed to load image Python extension: {e}")
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from copy import deepcopy
from omegaconf.dictconfig import DictConfig


import torch
import torch.nn as nn
from matplotlib import pyplot as plt

from kandinsky2 import get_kandinsky2

from kandinsky2.model.utils import get_obj_from_str
from kandinsky2.model.resample import UniformSampler
from kandinsky2.model.prior import PriorDiffusionModel, CustomizedTokenizer

from kandinsky2.model.text_encoders import TextEncoder
from kandinsky2.vqgan.autoencoder import VQModelInterface, AutoencoderKL, MOVQ
from train_utils.trainer_2_1_uclip import train_unclip
from kandinsky2.model.resample import UniformSampler
from omegaconf import OmegaConf
import clip

from kandinsky2 import CONFIG_2_1, Kandinsky2_1 

In [3]:
torch.cuda.is_available()

True

In [4]:
def count_fids(orig, generated):
    orig = orig.copy()
    generated = generated.copy()
    
    fid = FrechetInceptionDistance()

    real_ = torch.as_tensor(orig.transpose(2,0,1))
    real = torch.stack([real_, real_], dim=0)

    fake_ = torch.as_tensor(generated.transpose(2,0,1))
    fake = torch.stack([fake_, fake_], dim=0)

    fid.reset()
    fid.update(real, real=True)
    fid.update(fake, real=False)

    return fid.compute()

In [15]:
def count_fids_for_dataset(model, df):
    for idx in range(len(df)):
        caption, orig_path = df.loc[idx]

        orig = cv2.imread(orig_path)
        height, width, _ = orig.shape
        generated = model.generate_text2img(
                caption,    
                num_steps=100,
                batch_size=1, 
                guidance_scale=4,
                h=height, w=width,
                sampler='p_sampler', 
                prior_cf_scale=4,
                prior_steps="5"
        )[0]
        generated = np.asarray(generated)

        fid = count_fids(orig, generated)
        print(f'FID: {fid}')

# Init dataset for counting fids

In [6]:
df = pd.read_csv('file_mononoke.csv')

In [7]:
df.head()

Unnamed: 0,caption,image_name
0,A man in a blue shirt with a bow in his hands ...,datasets/princess_mononoke/000661.jpg
1,A path that stretches along the reservoir and ...,datasets/princess_mononoke/000176.jpg
2,A man carries a bundle of bamboo on his should...,datasets/princess_mononoke/000161.jpg
3,Five women in kimonos and headdresses with bow...,datasets/princess_mononoke/000330.jpg
4,a girl with red drawings on her face in white ...,datasets/princess_mononoke/000763.jpg


# Load default model

In [9]:
model = get_kandinsky2('cuda', task_type='text2img', cache_dir = "tmp/kand2/kandinsky2",
                       model_version='2.1', use_flash_attention=False)



making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.


# Load Miyazaki model

In [10]:
CONFIG = DictConfig(deepcopy(CONFIG_2_1))

In [11]:
CONFIG['params_path'] = 'tmp/kand2/kandinsky2/2_1/decoder_fp16.ckpt'
CONFIG['tokenizer_name'] = 'tmp/kand2/kandinsky2/2_1/text_encoder'
CONFIG['image_enc_params']['ckpt_path'] = 'tmp/kand2/kandinsky2/2_1/movq_final.ckpt'
CONFIG['text_enc_params']['model_path'] = 'tmp/kand2/kandinsky2/2_1/text_encoder'
CONFIG["prior"]["clip_mean_std_path"] = 'tmp/kand2/kandinsky2/2_1/ViT-L-14_stats.th'

In [12]:
miyazaki_model = Kandinsky2_1(CONFIG, 
                              'output/unclip/model_final_Miyazaki.ckpt',
                              'output/prior/model_final_Miyazaki.ckpt', 
                              device = 'cuda'
)

making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.


# Perform FIDs computation

In [16]:
count_fids_for_dataset(model, df)

100%|██████████| 100/100 [00:15<00:00,  6.59it/s]


FID: 398.29681396484375


100%|██████████| 100/100 [00:15<00:00,  6.61it/s]


FID: 293.22406005859375


100%|██████████| 100/100 [00:15<00:00,  6.60it/s]


FID: 340.4888916015625


100%|██████████| 100/100 [00:15<00:00,  6.60it/s]


FID: 339.2647705078125


100%|██████████| 100/100 [00:15<00:00,  6.59it/s]


FID: 411.7514343261719


 90%|█████████ | 90/100 [00:13<00:01,  6.51it/s]


KeyboardInterrupt: 

In [17]:
count_fids_for_dataset(miyazaki_model, df)

100%|██████████| 100/100 [00:15<00:00,  6.59it/s]


FID: 328.6327819824219


100%|██████████| 100/100 [00:15<00:00,  6.57it/s]


FID: 267.53997802734375


100%|██████████| 100/100 [00:15<00:00,  6.58it/s]


FID: 351.5025634765625


100%|██████████| 100/100 [00:15<00:00,  6.55it/s]


FID: 262.13018798828125


100%|██████████| 100/100 [00:15<00:00,  6.57it/s]


FID: 278.8929443359375


100%|██████████| 100/100 [00:15<00:00,  6.60it/s]


FID: 371.9193420410156


 22%|██▏       | 22/100 [00:03<00:12,  6.27it/s]


KeyboardInterrupt: 