In [1]:
import os
import sys
import numpy as np
import shutil
from typing import Optional, Tuple, List, Dict
from tqdm import tqdm

import torch
import torch.nn.functional as F

from lightning.pytorch import LightningModule

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel

from modules.evaluator import Evaluator

from cleanfid import fid
import timm
import shutil
import wandb
import torch.nn as nn

    PyTorch 2.0.0+cu118 with CUDA 1108 (you have 1.13.0+cu116)
    Python  3.8.16 (you have 3.8.10)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
2023-05-26 10:20:57.870079: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-26 10:20:58.018085: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-26 10:20:58.516770: W tenso

In [3]:
# Local directory holding the fine-tuned UNet weights (the directory name suggests an epoch-1 checkpoint).
unet_checkpoint_dir = "/home/flix/epoch1/"
unet = UNet2DConditionModel.from_pretrained(unet_checkpoint_dir)

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_images = 25  # images generated per class
class_prompts = ["CNV", "DME", "DRUSEN", "NORMAL"]

# Sampling configuration. The output folder name is derived from the guidance
# scale so runs with different scales do not overwrite each other.
guidance_scale = 4
num_inference_steps = 25
num_images_per_prompt = 1
seed = 0  # base seed; each class uses seed + its index in class_prompts

# NOTE(review): the pipeline is requested in float16, which generally needs a GPU;
# the CPU fallback for `device` is probably not usable for this cell — confirm.
pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base",
    unet=unet,
    vae=AutoencoderKL.from_pretrained("flix-k/custom_model_parts", subfolder="vae_trained_kl"),
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipeline.set_progress_bar_config(disable=True)
pipeline.to(device)

# Leftover from a multi-GPU script: with a single GPU every image is generated here.
num_gpus = 1
images_per_gpu = num_images // num_gpus
save_path = f"./synth_data_GS{guidance_scale}/"
# Zero-pad ids to a fixed width so file names sort in generation order.
id_width = len(str(images_per_gpu * num_images_per_prompt))

for class_idx, class_prompt in enumerate(class_prompts):
    print(f"Generating images for class {class_prompt}")
    save_path_class = os.path.join(save_path, class_prompt)
    os.makedirs(save_path_class, exist_ok=True)
    # A seeded generator makes the sample set reproducible across reruns; the
    # per-class offset keeps classes from sharing the same initial noise.
    generator = torch.Generator(device=device).manual_seed(seed + class_idx)
    for it in tqdm(range(images_per_gpu), desc=class_prompt):
        # `unet` and the VAE are loaded without `torch_dtype` while the pipeline asks for
        # float16; autocast presumably reconciles the mixed dtypes — verify before removing it.
        with torch.autocast("cuda"):
            images = pipeline(
                prompt=class_prompt,
                height=512,
                width=512,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=None,
                num_images_per_prompt=num_images_per_prompt,
                generator=generator,
            ).images
        for idx, image in enumerate(images):
            # Ids stay unique within a class even when one call returns several images.
            image_id = str(it * num_images_per_prompt + idx).zfill(id_width)
            image.save(os.path.join(save_path_class, f"{class_prompt}-({image_id}).jpg"))

Generating images for class CNV


100%|██████████| 25/25 [01:49<00:00,  4.40s/it]


Generating images for class DME


100%|██████████| 25/25 [01:50<00:00,  4.42s/it]


Generating images for class DRUSEN


100%|██████████| 25/25 [01:50<00:00,  4.40s/it]


Generating images for class NORMAL


100%|██████████| 25/25 [01:50<00:00,  4.43s/it]


In [12]:
folder = "./synth_data_GS4/"

os.makedirs(f"./{folder}/ALL", exist_ok=True)

for c in ["NORMAL", "CNV", "DRUSEN", "DME"]:
    for f in os.listdir(f"./{folder}/{c}"):
        if ".jpg" in f:
            shutil.copy(f"./{folder}/{c}/{f}", f"./synth_data/ALL/{f}")

In [13]:
import timm
from cleanfid import fid
import torch


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model('inception_v3', pretrained=True, num_classes=4).to(device)
model.load_state_dict(torch.load("/home/flix/Documents/DeepFlix/generative/diffusion/evaluation/models/finetuned_best.pt"))
model.eval()
model = torch.nn.Sequential(*(list(model.children())[:-1]))

score = fid.compute_fid(
    "./synth_data/ALL", 
    "/home/flix/Documents/hf_datasets/OCT-datasetv3/val", 
    mode="clean",
    custom_feat_extractor=model,
)
print(score)

compute FID between two folders
Found 100 images in the folder ./synth_data/ALL


FID ALL : 100%|██████████| 4/4 [00:01<00:00,  3.93it/s]


Found 4000 images in the folder /home/flix/Documents/hf_datasets/OCT-datasetv3/val


FID val : 100%|██████████| 125/125 [00:11<00:00, 10.96it/s]


19.998051014705993


: 

FID: 5.590877777494846