In [1]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from datasets import load_dataset
import numpy as np
import random
from transformers import CLIPTextModel, CLIPTokenizer
from torchvision import transforms
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel

import h5py
import torch
import os
import timm
from cleanfid import fid
import wandb

from tqdm import tqdm

from pytorch_lightning.loggers import WandbLogger
from typing import List, Dict
import gc

    PyTorch 2.0.0+cu118 with CUDA 1108 (you have 1.13.0+cu116)
    Python  3.8.16 (you have 3.8.10)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
2023-05-26 12:03:07.623886: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-26 12:03:07.706518: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-26 12:03:08.093441: W tenso

In [9]:
from transformers import CLIPTokenizer, CLIPTextModel

class LatentDiffusionDataModule(pl.LightningDataModule):
    """Lightning data module for the OCT latent-diffusion experiments.

    It serves one of two kinds of batches:

    - Image/caption batches from a Hugging Face dataset, with keys ``pixel_values``,
      ``input_ids`` and ``captions``.
    - With ``use_latents=True``, precomputed latents and labels read from HDF5 files.

    ``setup`` also exports the raw validation images to ``val_image_dir`` and builds
    clean-fid reference statistics for them.

    Args:
        dataset_name: Dataset id passed to ``datasets.load_dataset``. The result must
            have ``"train"`` and ``"val"`` splits.
        image_column: Column holding the PIL images.
        caption_column: Column holding captions (str, or list/ndarray of str).
        tokenizer: Tokenizer used to turn captions into ``input_ids``.
        resolution: Size passed to ``Resize`` and to the crop transform.
        center_crop: Use ``CenterCrop`` if True, otherwise ``RandomCrop``.
        random_flip: Add ``RandomHorizontalFlip`` if True.
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_workers: ``DataLoader`` worker count.
        latent_file_path_train: HDF5 file with ``latents``/``labels`` datasets (train).
        latent_file_path_val: HDF5 file with ``latents``/``labels`` datasets (val).
        use_latents: Serve the HDF5 latents instead of images.
        val_image_dir: Directory the validation images are exported to for FID.
        fid_classifier_checkpoint: State dict for the 4-class timm Inception-v3 used
            for the ``"octv3-val"`` FID statistics.
        compute_fid_stats: Set False to skip the image export and FID statistics.
    """

    class H5PyTorchDataset(Dataset):
        """Map-style dataset over an HDF5 file with ``latents`` and ``labels`` datasets.

        Both arrays are read fully into memory in ``__init__`` and the file is closed
        right away. Defining the class at class level (not inside a method) keeps it
        picklable for DataLoader worker processes.
        """

        def __init__(self, file_path):
            self.file_path = file_path
            with h5py.File(self.file_path, "r") as f:
                self.latents = f["latents"][:]
                self.labels = f["labels"][:]

        def __getitem__(self, index):
            latent = self.latents[index]
            label = self.labels[index]
            # torch.from_numpy needs an ndarray, so scalar labels become shape-(1,) arrays.
            if np.isscalar(label):
                label = np.array([label])
            return {"latents": torch.from_numpy(latent), "labels": torch.from_numpy(label)}

        def __len__(self):
            return len(self.latents)

    def __init__(self,
                 dataset_name,
                 image_column,
                 caption_column,
                 tokenizer,
                 resolution,
                 center_crop,
                 random_flip,
                 train_batch_size,
                 val_batch_size,
                 num_workers,
                 latent_file_path_train=None,
                 latent_file_path_val=None,
                 use_latents=False,
                 val_image_dir="./val_images/",
                 fid_classifier_checkpoint="./finetuned_best.pt",
                 compute_fid_stats=True):
        super().__init__()
        self.dataset_name = dataset_name
        self.image_column = image_column
        self.caption_column = caption_column
        self.tokenizer = tokenizer
        self.resolution = resolution
        self.center_crop = center_crop
        self.random_flip = random_flip
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.latent_file_path_train = latent_file_path_train
        self.latent_file_path_val = latent_file_path_val
        self.use_latents = use_latents
        self.val_image_dir = val_image_dir
        self.fid_classifier_checkpoint = fid_classifier_checkpoint
        self.compute_fid_stats = compute_fid_stats
        # Populated by prepare_data() or, failing that, lazily by setup().
        self.dataset = None
        self.latent_dataset = None
        self._fid_stats_ready = False

    def prepare_data(self):
        """Download/cache the dataset and, if requested, load the HDF5 latents.

        Lightning calls this hook on a single process in distributed training, so the
        attributes assigned here are not visible on other ranks. ``setup`` reloads
        them when they are missing. Returns nothing.
        """
        self.dataset = load_dataset(self.dataset_name)
        if self.use_latents:
            self._load_latent_datasets()

    def _load_latent_datasets(self):
        """Read the train/val HDF5 latent files into ``self.latent_dataset``."""
        self.latent_dataset = {
            "train": self.H5PyTorchDataset(self.latent_file_path_train),
            "val": self.H5PyTorchDataset(self.latent_file_path_val),
        }

    def setup(self, stage=None):
        """Build the transformed image datasets (and latent datasets) for all stages.

        Args:
            stage: Lightning stage name. It is unused: the same datasets serve every stage.
        """
        # State from prepare_data may be absent on this process (DDP), so reload it
        # if needed. load_dataset reads from the local cache.
        if self.dataset is None:
            self.dataset = load_dataset(self.dataset_name)
        if self.use_latents and self.latent_dataset is None:
            self._load_latent_datasets()

        steps = [
            transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.resolution) if self.center_crop else transforms.RandomCrop(self.resolution),
        ]
        # Append the flip only when requested, instead of using an (unpicklable) identity lambda.
        if self.random_flip:
            steps.append(transforms.RandomHorizontalFlip())
        steps += [
            transforms.ToTensor(),
            # Maps [0, 1] to [-1, 1]. The single-element mean/std broadcast over channels.
            transforms.Normalize([0.5], [0.5]),
        ]
        self.train_transforms = transforms.Compose(steps)

        # NOTE: validation reuses the training transforms, including RandomCrop when
        # center_crop=False.
        self.train_dataset = self.dataset["train"].with_transform(self.preprocess_train)
        self.val_dataset = self.dataset["val"].with_transform(self.preprocess_train)

        if self.use_latents:
            self.train_latent_dataset = self.latent_dataset["train"]
            self.val_latent_dataset = self.latent_dataset["val"]

        # Lightning may call setup once per stage. The FID preparation only needs to run once.
        if self.compute_fid_stats and not self._fid_stats_ready:
            self._export_val_images()
            self._prepare_fid_stats()
            self._fid_stats_ready = True

    def _export_val_images(self):
        """Save every raw validation image as ``{val_image_dir}/{caption}-{i}.jpg``.

        This is skipped when the directory already holds at least as many entries as
        the validation split.
        """
        from PIL import Image
        import glob

        os.makedirs(self.val_image_dir, exist_ok=True)
        raw_val = self.dataset["val"]
        if len(glob.glob(os.path.join(self.val_image_dir, "*"))) >= len(raw_val):
            return
        for i in tqdm(range(len(raw_val))):
            # Read from the untransformed split and fetch each row once. The transformed
            # split would resize, crop and tokenize the row only to discard the result.
            row = raw_val[i]
            img = Image.fromarray(np.array(row[self.image_column]))
            img.save(os.path.join(self.val_image_dir, f"{row[self.caption_column]}-{i}.jpg"))

    def _prepare_fid_stats(self):
        """Create the clean-fid custom statistics ``"val"`` and ``"octv3-val"``.

        ``"val"`` uses clean-fid's default feature extractor. ``"octv3-val"`` uses the
        fine-tuned 4-class Inception-v3 with its last child module removed. A failure to
        build either set of stats is printed and skipped. A missing classifier
        checkpoint still raises.
        """
        try:
            fid.make_custom_stats("val", fdir=self.val_image_dir)
        except Exception as exc:
            # clean-fid refuses to overwrite existing stats, but other failures land here
            # too. Report the actual reason instead of assuming the stats exist.
            print(f"Skipping FID stats 'val': {exc}")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # pretrained=False: the strict load_state_dict below overwrites every weight,
        # so downloading the ImageNet weights first would be wasted work.
        model = timm.create_model("inception_v3", pretrained=False, num_classes=4).to(device)
        # map_location lets a GPU-saved checkpoint load on CPU-only machines too.
        model.load_state_dict(torch.load(self.fid_classifier_checkpoint, map_location=device))
        model.eval()
        # Drop the last child (presumably the classification head) to get features.
        # This assumes children() are registered in forward order. TODO confirm.
        feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
        try:
            # NOTE(review): `model=` / `model_name="custom"` are not in upstream
            # clean-fid's make_custom_stats signature. This presumably relies on a
            # patched clean-fid; confirm. With upstream clean-fid the error is printed.
            fid.make_custom_stats("octv3-val", fdir=self.val_image_dir, model=feature_extractor, model_name="custom")
        except Exception as exc:
            print(f"Skipping FID stats 'octv3-val': {exc}")
        finally:
            # Release the classifier's GPU memory before training starts.
            feature_extractor.to("cpu")
            del model, feature_extractor
            torch.cuda.empty_cache()

    def preprocess_train(self, examples):
        """Batch transform used with ``datasets.Dataset.with_transform``.

        Adds three keys to the batch dict and returns it:

        - ``pixel_values``: transformed RGB tensors.
        - ``caption``: a copy of the caption column.
        - ``input_ids``: tokenized captions.
        """
        examples["pixel_values"] = [
            self.train_transforms(image.convert("RGB")) for image in examples[self.image_column]
        ]
        examples["caption"] = list(examples[self.caption_column])
        examples["input_ids"] = self.tokenize_captions(examples)
        return examples

    def tokenize_captions(self, examples, is_train=True):
        """Tokenize the batch's captions, padded/truncated to the tokenizer's max length.

        When a row holds several captions, one is picked at random if ``is_train`` is
        True; otherwise the first is used.

        Raises:
            ValueError: if a caption is neither a str nor a list/ndarray.
        """
        captions = []
        for caption in examples[self.caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{self.caption_column}` should contain either strings or lists of strings."
                )
        inputs = self.tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    def train_dataloader(self):
        """Return a shuffled loader over the latent or image training set."""
        if self.use_latents:
            dataset, collate = self.train_latent_dataset, None
        else:
            dataset, collate = self.train_dataset, self.collate_fn
        return DataLoader(
            dataset,
            shuffle=True,
            collate_fn=collate,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return a loader over the latent or image validation set.

        Both paths use ``val_batch_size``. The latent path previously used
        ``train_batch_size`` by mistake.
        """
        if self.use_latents:
            dataset, collate = self.val_latent_dataset, None
        else:
            dataset, collate = self.val_dataset, self.collate_fn
        # NOTE(review): Lightning recommends shuffle=False for validation. It is kept
        # True to preserve the existing behaviour; confirm it is intended.
        return DataLoader(
            dataset,
            shuffle=True,
            collate_fn=collate,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
        )

    def collate_fn(self, examples):
        """Stack preprocessed rows into ``pixel_values``/``input_ids`` tensors plus a caption list."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        captions = [example["caption"] for example in examples]
        return {"pixel_values": pixel_values, "input_ids": input_ids, "captions": captions}

def precompute_latents(vae, text_encoder, dataloader, train=False, classes=("CNV", "DME", "DRUSEN", "NORMAL"),
                       output_path=None, compression=None):
    """Encode every batch of ``dataloader`` with ``vae`` and save latents and labels to HDF5.

    Args:
        vae: Model exposing ``encode(x).latent_dist.sample()`` (e.g. diffusers
            ``AutoencoderKL``). It is switched to eval mode, and inputs are moved to the
            device of its parameters.
        text_encoder: Unused. Kept so existing calls keep working.
        dataloader: Yields dicts with ``"pixel_values"`` (tensor) and ``"captions"``
            (class-name strings), as produced by ``LatentDiffusionDataModule.collate_fn``.
        train: Selects the ``train``/``test`` suffix of the default output file name.
        classes: Class names. A caption's label is its index in this sequence.
        output_path: HDF5 file to write. Defaults to ``trained_vae_kl_dv3_{train|test}.h5``.
        compression: Optional h5py compression filter (e.g. ``"gzip"``). ``None`` writes
            uncompressed data.

    Returns:
        The path of the written file. It holds ``latents`` (concatenated
        ``latent_dist.sample()`` outputs, in loader order) and ``labels`` (class indices).

    Raises:
        ValueError: if ``dataloader`` yields no batches, or a caption is not in ``classes``.
    """
    vae.eval()
    # Follow the VAE's device instead of hard-coding "cuda".
    device = next(vae.parameters()).device

    latent_batches = []
    label_batches = []
    for batch in tqdm(dataloader):
        pixel_values = batch["pixel_values"].to(device)
        with torch.no_grad():
            # Samples from the latent posterior (not its mean), so results are stochastic.
            latent_batches.append(vae.encode(pixel_values).latent_dist.sample().cpu().numpy())
        label_batches.append(np.asarray([classes.index(name) for name in batch["captions"]]))

    if not latent_batches:
        raise ValueError("dataloader yielded no batches; nothing to save")

    latents = np.concatenate(latent_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)

    if output_path is None:
        mode = "train" if train else "test"
        output_path = f"trained_vae_kl_dv3_{mode}.h5"

    # The context manager closes the file even if a write fails.
    with h5py.File(output_path, "w") as file:
        file.create_dataset("latents", data=latents, compression=compression)
        file.create_dataset("labels", data=labels, compression=compression)
    return output_path

from diffusers import AutoencoderKL

# Load the SD 2.1-base tokenizer/text encoder and the fine-tuned KL VAE.
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained("flix-k/custom_model_parts", subfolder="vae_trained_kl").to("cuda")
text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder").to("cuda")

data_module = LatentDiffusionDataModule(
    dataset_name="flix-k/oct-dataset-val1kv3",
    image_column="image",
    caption_column="caption",
    tokenizer=tokenizer,
    resolution=512,
    center_crop=False,
    random_flip=False,
    train_batch_size=1,
    val_batch_size=1,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()

# This cell prepares the *validation* split. The next cell consumes it under the
# historical name `train_dataloader`, so that alias is kept.
val_dataloader = data_module.val_dataloader()
train_dataloader = val_dataloader


Found cached dataset parquet (/home/flix/.cache/huggingface/datasets/flix-k___parquet/flix-k--oct-dataset-val1kv3-899ad0f348fd8f48/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 4000/4000 [00:22<00:00, 178.22it/s]


Stats already exist
Stats already exist


In [10]:
# Encode the loader built in the previous cell. train=False selects the "test" file suffix.
train_latents = precompute_latents(vae, text_encoder, dataloader=train_dataloader, train=False)

100%|██████████| 4000/4000 [09:51<00:00,  6.77it/s]


In [4]:
from transformers import CLIPTokenizer, CLIPTextModel

class LatentDiffusionDataModule(pl.LightningDataModule):
    """Lightning data module for the OCT latent-diffusion experiments.

    It serves one of two kinds of batches:

    - Image/caption batches from a Hugging Face dataset, with keys ``pixel_values``,
      ``input_ids`` and ``captions``.
    - With ``use_latents=True``, precomputed latents and labels read from HDF5 files.

    ``setup`` also exports the raw validation images to ``val_image_dir`` and builds
    clean-fid reference statistics for them.

    Args:
        dataset_name: Dataset id passed to ``datasets.load_dataset``. The result must
            have ``"train"`` and ``"val"`` splits.
        image_column: Column holding the PIL images.
        caption_column: Column holding captions (str, or list/ndarray of str).
        tokenizer: Tokenizer used to turn captions into ``input_ids``.
        resolution: Size passed to ``Resize`` and to the crop transform.
        center_crop: Use ``CenterCrop`` if True, otherwise ``RandomCrop``.
        random_flip: Add ``RandomHorizontalFlip`` if True.
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_workers: ``DataLoader`` worker count.
        latent_file_path_train: HDF5 file with ``latents``/``labels`` datasets (train).
        latent_file_path_val: HDF5 file with ``latents``/``labels`` datasets (val).
        use_latents: Serve the HDF5 latents instead of images.
        val_image_dir: Directory the validation images are exported to for FID.
        fid_classifier_checkpoint: State dict for the 4-class timm Inception-v3 used
            for the ``"octv3-val"`` FID statistics.
        compute_fid_stats: Set False to skip the image export and FID statistics.
    """

    class H5PyTorchDataset(Dataset):
        """Map-style dataset over an HDF5 file with ``latents`` and ``labels`` datasets.

        Both arrays are read fully into memory in ``__init__`` and the file is closed
        right away. Defining the class at class level (not inside a method) keeps it
        picklable for DataLoader worker processes.
        """

        def __init__(self, file_path):
            self.file_path = file_path
            with h5py.File(self.file_path, "r") as f:
                self.latents = f["latents"][:]
                self.labels = f["labels"][:]

        def __getitem__(self, index):
            latent = self.latents[index]
            label = self.labels[index]
            # torch.from_numpy needs an ndarray, so scalar labels become shape-(1,) arrays.
            if np.isscalar(label):
                label = np.array([label])
            return {"latents": torch.from_numpy(latent), "labels": torch.from_numpy(label)}

        # Fixed: this line previously had stray text pasted after the colon, which made
        # the whole cell fail to compile.
        def __len__(self):
            return len(self.latents)

    def __init__(self,
                 dataset_name,
                 image_column,
                 caption_column,
                 tokenizer,
                 resolution,
                 center_crop,
                 random_flip,
                 train_batch_size,
                 val_batch_size,
                 num_workers,
                 latent_file_path_train=None,
                 latent_file_path_val=None,
                 use_latents=False,
                 val_image_dir="./val_images/",
                 fid_classifier_checkpoint="./finetuned_best.pt",
                 compute_fid_stats=True):
        super().__init__()
        self.dataset_name = dataset_name
        self.image_column = image_column
        self.caption_column = caption_column
        self.tokenizer = tokenizer
        self.resolution = resolution
        self.center_crop = center_crop
        self.random_flip = random_flip
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.latent_file_path_train = latent_file_path_train
        self.latent_file_path_val = latent_file_path_val
        self.use_latents = use_latents
        self.val_image_dir = val_image_dir
        self.fid_classifier_checkpoint = fid_classifier_checkpoint
        self.compute_fid_stats = compute_fid_stats
        # Populated by prepare_data() or, failing that, lazily by setup().
        self.dataset = None
        self.latent_dataset = None
        self._fid_stats_ready = False

    def prepare_data(self):
        """Download/cache the dataset and, if requested, load the HDF5 latents.

        Lightning calls this hook on a single process in distributed training, so the
        attributes assigned here are not visible on other ranks. ``setup`` reloads
        them when they are missing. Returns nothing.
        """
        self.dataset = load_dataset(self.dataset_name)
        if self.use_latents:
            self._load_latent_datasets()

    def _load_latent_datasets(self):
        """Read the train/val HDF5 latent files into ``self.latent_dataset``."""
        self.latent_dataset = {
            "train": self.H5PyTorchDataset(self.latent_file_path_train),
            "val": self.H5PyTorchDataset(self.latent_file_path_val),
        }

    def setup(self, stage=None):
        """Build the transformed image datasets (and latent datasets) for all stages.

        Args:
            stage: Lightning stage name. It is unused: the same datasets serve every stage.
        """
        # State from prepare_data may be absent on this process (DDP), so reload it
        # if needed. load_dataset reads from the local cache.
        if self.dataset is None:
            self.dataset = load_dataset(self.dataset_name)
        if self.use_latents and self.latent_dataset is None:
            self._load_latent_datasets()

        steps = [
            transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.resolution) if self.center_crop else transforms.RandomCrop(self.resolution),
        ]
        # Append the flip only when requested, instead of using an (unpicklable) identity lambda.
        if self.random_flip:
            steps.append(transforms.RandomHorizontalFlip())
        steps += [
            transforms.ToTensor(),
            # Maps [0, 1] to [-1, 1]. The single-element mean/std broadcast over channels.
            transforms.Normalize([0.5], [0.5]),
        ]
        self.train_transforms = transforms.Compose(steps)

        # NOTE: validation reuses the training transforms, including RandomCrop when
        # center_crop=False.
        self.train_dataset = self.dataset["train"].with_transform(self.preprocess_train)
        self.val_dataset = self.dataset["val"].with_transform(self.preprocess_train)

        if self.use_latents:
            self.train_latent_dataset = self.latent_dataset["train"]
            self.val_latent_dataset = self.latent_dataset["val"]

        # Lightning may call setup once per stage. The FID preparation only needs to run once.
        if self.compute_fid_stats and not self._fid_stats_ready:
            self._export_val_images()
            self._prepare_fid_stats()
            self._fid_stats_ready = True

    def _export_val_images(self):
        """Save every raw validation image as ``{val_image_dir}/{caption}-{i}.jpg``.

        This is skipped when the directory already holds at least as many entries as
        the validation split.
        """
        from PIL import Image
        import glob

        os.makedirs(self.val_image_dir, exist_ok=True)
        raw_val = self.dataset["val"]
        if len(glob.glob(os.path.join(self.val_image_dir, "*"))) >= len(raw_val):
            return
        for i in tqdm(range(len(raw_val))):
            # Read from the untransformed split and fetch each row once. The transformed
            # split would resize, crop and tokenize the row only to discard the result.
            row = raw_val[i]
            img = Image.fromarray(np.array(row[self.image_column]))
            img.save(os.path.join(self.val_image_dir, f"{row[self.caption_column]}-{i}.jpg"))

    def _prepare_fid_stats(self):
        """Create the clean-fid custom statistics ``"val"`` and ``"octv3-val"``.

        ``"val"`` uses clean-fid's default feature extractor. ``"octv3-val"`` uses the
        fine-tuned 4-class Inception-v3 with its last child module removed. A failure to
        build either set of stats is printed and skipped. A missing classifier
        checkpoint still raises.
        """
        try:
            fid.make_custom_stats("val", fdir=self.val_image_dir)
        except Exception as exc:
            # clean-fid refuses to overwrite existing stats, but other failures land here
            # too. Report the actual reason instead of assuming the stats exist.
            print(f"Skipping FID stats 'val': {exc}")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # pretrained=False: the strict load_state_dict below overwrites every weight,
        # so downloading the ImageNet weights first would be wasted work.
        model = timm.create_model("inception_v3", pretrained=False, num_classes=4).to(device)
        # map_location lets a GPU-saved checkpoint load on CPU-only machines too.
        model.load_state_dict(torch.load(self.fid_classifier_checkpoint, map_location=device))
        model.eval()
        # Drop the last child (presumably the classification head) to get features.
        # This assumes children() are registered in forward order. TODO confirm.
        feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
        try:
            # NOTE(review): `model=` / `model_name="custom"` are not in upstream
            # clean-fid's make_custom_stats signature. This presumably relies on a
            # patched clean-fid; confirm. With upstream clean-fid the error is printed.
            fid.make_custom_stats("octv3-val", fdir=self.val_image_dir, model=feature_extractor, model_name="custom")
        except Exception as exc:
            print(f"Skipping FID stats 'octv3-val': {exc}")
        finally:
            # Release the classifier's GPU memory before training starts.
            feature_extractor.to("cpu")
            del model, feature_extractor
            torch.cuda.empty_cache()

    def preprocess_train(self, examples):
        """Batch transform used with ``datasets.Dataset.with_transform``.

        Adds three keys to the batch dict and returns it:

        - ``pixel_values``: transformed RGB tensors.
        - ``caption``: a copy of the caption column.
        - ``input_ids``: tokenized captions.
        """
        examples["pixel_values"] = [
            self.train_transforms(image.convert("RGB")) for image in examples[self.image_column]
        ]
        examples["caption"] = list(examples[self.caption_column])
        examples["input_ids"] = self.tokenize_captions(examples)
        return examples

    def tokenize_captions(self, examples, is_train=True):
        """Tokenize the batch's captions, padded/truncated to the tokenizer's max length.

        When a row holds several captions, one is picked at random if ``is_train`` is
        True; otherwise the first is used.

        Raises:
            ValueError: if a caption is neither a str nor a list/ndarray.
        """
        captions = []
        for caption in examples[self.caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{self.caption_column}` should contain either strings or lists of strings."
                )
        inputs = self.tokenizer(
            captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    def train_dataloader(self):
        """Return a shuffled loader over the latent or image training set."""
        if self.use_latents:
            dataset, collate = self.train_latent_dataset, None
        else:
            dataset, collate = self.train_dataset, self.collate_fn
        return DataLoader(
            dataset,
            shuffle=True,
            collate_fn=collate,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return a loader over the latent or image validation set.

        Both paths use ``val_batch_size``. The latent path previously used
        ``train_batch_size`` by mistake.
        """
        if self.use_latents:
            dataset, collate = self.val_latent_dataset, None
        else:
            dataset, collate = self.val_dataset, self.collate_fn
        # NOTE(review): Lightning recommends shuffle=False for validation. It is kept
        # True to preserve the existing behaviour; confirm it is intended.
        return DataLoader(
            dataset,
            shuffle=True,
            collate_fn=collate,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
        )

    def collate_fn(self, examples):
        """Stack preprocessed rows into ``pixel_values``/``input_ids`` tensors plus a caption list."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        captions = [example["caption"] for example in examples]
        return {"pixel_values": pixel_values, "input_ids": input_ids, "captions": captions}

def precompute_latents(vae, text_encoder, dataloader, train=False, classes=("CNV", "DME", "DRUSEN", "NORMAL"),
                       output_path=None, compression=None):
    """Encode every batch of ``dataloader`` with ``vae`` and save latents and labels to HDF5.

    Args:
        vae: Model exposing ``encode(x).latent_dist.sample()`` (e.g. diffusers
            ``AutoencoderKL``). It is switched to eval mode, and inputs are moved to the
            device of its parameters.
        text_encoder: Unused. Kept so existing calls keep working.
        dataloader: Yields dicts with ``"pixel_values"`` (tensor) and ``"captions"``
            (class-name strings), as produced by ``LatentDiffusionDataModule.collate_fn``.
        train: Selects the ``train``/``test`` suffix of the default output file name.
        classes: Class names. A caption's label is its index in this sequence.
        output_path: HDF5 file to write. Defaults to ``trained_vae_dv3_{train|test}.h5``.
        compression: Optional h5py compression filter (e.g. ``"gzip"``). ``None`` writes
            uncompressed data.

    Returns:
        The path of the written file. It holds ``latents`` (concatenated
        ``latent_dist.sample()`` outputs, in loader order) and ``labels`` (class indices).

    Raises:
        ValueError: if ``dataloader`` yields no batches, or a caption is not in ``classes``.
    """
    vae.eval()
    # Follow the VAE's device instead of hard-coding "cuda".
    device = next(vae.parameters()).device

    latent_batches = []
    label_batches = []
    for batch in tqdm(dataloader):
        pixel_values = batch["pixel_values"].to(device)
        with torch.no_grad():
            # Samples from the latent posterior (not its mean), so results are stochastic.
            latent_batches.append(vae.encode(pixel_values).latent_dist.sample().cpu().numpy())
        label_batches.append(np.asarray([classes.index(name) for name in batch["captions"]]))

    if not latent_batches:
        raise ValueError("dataloader yielded no batches; nothing to save")

    latents = np.concatenate(latent_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)

    if output_path is None:
        mode = "train" if train else "test"
        output_path = f"trained_vae_dv3_{mode}.h5"

    # The context manager closes the file even if a write fails.
    with h5py.File(output_path, "w") as file:
        file.create_dataset("latents", data=latents, compression=compression)
        file.create_dataset("labels", data=labels, compression=compression)
    return output_path

from diffusers import AutoencoderKL

SD_BASE_REPO = "stabilityai/stable-diffusion-2-1-base"

# SD 2.1-base tokenizer/text encoder plus the fine-tuned (non-KL) VAE.
tokenizer = CLIPTokenizer.from_pretrained(SD_BASE_REPO, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained("flix-k/custom_model_parts", subfolder="vae_trained").to("cuda")
text_encoder = CLIPTextModel.from_pretrained(SD_BASE_REPO, subfolder="text_encoder").to("cuda")

datamodule_kwargs = dict(
    dataset_name="flix-k/oct-dataset-val1kv3",
    image_column="image",
    caption_column="caption",
    tokenizer=tokenizer,
    resolution=512,
    center_crop=False,
    random_flip=False,
    train_batch_size=1,
    val_batch_size=1,
    num_workers=0,
)
data_module = LatentDiffusionDataModule(**datamodule_kwargs)
data_module.prepare_data()
data_module.setup()

# Encode the whole training split. train=True selects the "train" file suffix.
train_dataloader = data_module.train_dataloader()
train_latents = precompute_latents(vae, text_encoder, train_dataloader, train=True)

Found cached dataset parquet (/home/flix/.cache/huggingface/datasets/flix-k___parquet/flix-k--oct-dataset-val1kv3-899ad0f348fd8f48/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

Stats already exist
Stats already exist


100%|██████████| 96712/96712 [3:47:47<00:00,  7.08it/s]  


In [14]:
from transformers import CLIPTextModel, CLIPTokenizer
import torch

# Inspect the CLIP text embedding of one class name, tokenized without padding.
string = "DRUSEN"
sd_repo = "stabilityai/stable-diffusion-2-1-base"

tokenizer = CLIPTokenizer.from_pretrained(sd_repo, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd_repo, subfolder="text_encoder").to("cuda")

token_batch = tokenizer(
    string,
    max_length=tokenizer.model_max_length,
    padding="do_not_pad",
    truncation=True,
    return_tensors="pt",
)
inputs = token_batch.input_ids.to("cuda")

with torch.no_grad():
    encoded = text_encoder(inputs)[0]

print(encoded.shape)

torch.Size([1, 4, 1024])


In [11]:
# Reload the SD 2.1-base CLIP text encoder and place it on the GPU.
text_encoder = CLIPTextModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base",
    subfolder="text_encoder",
).to("cuda")

In [16]:
string = "NORMAL"
inputs = tokenizer(
    string, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True, return_tensors="pt"
).input_ids
# text_encoder is loaded with .to("cuda"), so the CPU token ids must move to its device.
inputs = inputs.to(text_encoder.device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = text_encoder(inputs)

# CLIPTextModel returns a ModelOutput, which has no `.shape`. The per-token
# embeddings are in `last_hidden_state`.
print(output.last_hidden_state.shape)