In [1]:
import lightning.pytorch as pl
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from transformers import CLIPTextModel, CLIPTokenizer
import torch

from tqdm import tqdm 
from diffusers import AutoencoderKL



    PyTorch 2.0.0+cu118 with CUDA 1108 (you have 1.13.0+cu116)
    Python  3.8.16 (you have 3.8.10)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
2023-06-06 14:24:56.228962: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-06 14:24:56.373435: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-06-06 14:24:57.009905: W tenso

In [2]:
class CustomImageFolder(ImageFolder):
    def __init__(self, root, transform=None, target_transform=None, samples_per_class=None):
        super(CustomImageFolder, self).__init__(root, transform=transform, target_transform=target_transform)
        
        self.samples_per_class = samples_per_class
        
        if samples_per_class is not None:
            # New class_to_idx dictionary
            new_class_to_idx = {}
            # New samples list
            new_samples = []
            # New targets list
            new_targets = []
            
            # For each class in the original class_to_idx
            for class_name in self.class_to_idx:
                # Get all the samples for this class
                class_samples = [(s, t) for s, t in self.samples if t == self.class_to_idx[class_name]]
                # If there are more samples than samples_per_class, trim the list
                if len(class_samples) > samples_per_class:
                    class_samples = class_samples[:samples_per_class]
                
                # Append the samples to the new samples and targets list
                new_samples.extend(class_samples)
                new_targets.extend([self.class_to_idx[class_name]] * len(class_samples))
                # Set the class_to_idx for the new class
                new_class_to_idx[class_name] = self.class_to_idx[class_name]
            
            # Set the new class_to_idx, samples, and targets
            self.class_to_idx = new_class_to_idx
            self.samples = new_samples
            self.targets = new_targets
    
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict: {'images': sample, 'ids': target}
        """
        img, target = super(CustomImageFolder, self).__getitem__(index)
        target = self.classes[target] # Get class name
        
        return {'images': img, 'targets': target}

    
def collate_fn(examples):
    input_ids = [example["targets"] for example in examples]
    pixel_values = [example["images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    batch = {
        "targets": input_ids,
        "pixel_values": pixel_values,
    }
    return batch
size = ((496, 496))
transform = transforms.Compose([
    transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

dataset = CustomImageFolder(root="/home/flix/Documents/oct-data/CellData/OCT/train/", transform=transform)

# Define the split sizes. In this case, we will split 70% for train and 30% for validation.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

# Split the dataset
torch.manual_seed(0)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Now you can create DataLoaders for your training and validation datasets
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)


print("Dataset size:", len(train_dataset))
print("Dataset size:", len(val_dataset))

for i in range(25):
    print(train_dataset[i]['targets'])

print(dataset.class_to_idx)


Dataset size: 92062
Dataset size: 16247
CNV
CNV
CNV
CNV
CNV
CNV
DME
NORMAL
CNV
CNV
CNV
CNV
CNV
CNV
NORMAL
CNV
DRUSEN
NORMAL
NORMAL
CNV
DRUSEN
NORMAL
DRUSEN
DME
NORMAL
{'CNV': 0, 'DME': 1, 'DRUSEN': 2, 'NORMAL': 3}


In [3]:
import pickle
import os

device = "cuda"
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae").to(device, dtype=torch.float16)
vae.eval()

def precompute_latents_pickle(vae, train_dataloader, val_dataloader, classes):


    counter = 0
    # create folder structure
    os.makedirs("latents", exist_ok=True)
    os.makedirs("latents/train", exist_ok=True)
    os.makedirs("latents/val", exist_ok=True)

    # create folder for each class in train and val
    for c in classes:
        os.makedirs(f"latents/train/{c}", exist_ok=True)
        os.makedirs(f"latents/val/{c}", exist_ok=True)


    for batch in tqdm(train_dataloader):
        pixel_values = batch["pixel_values"].to("cuda", dtype=torch.float16)
        with torch.no_grad():
            latent_dist = vae.encode(pixel_values).latent_dist

        
        target = batch["targets"][0]
            # save latent distribution
        with open(f'latents/train/{target}/{target}-({counter}).pkl', 'wb') as output:
            pickle.dump(latent_dist, output, pickle.HIGHEST_PROTOCOL)
        counter += 1

    counter = 0
    for batch in tqdm(val_dataloader):
        pixel_values = batch["pixel_values"].to("cuda", dtype=torch.float16)
        with torch.no_grad():
            latent_dist = vae.encode(pixel_values).latent_dist

        
        target = batch["targets"][0]
            # save latent distribution
        with open(f'latents/val/{target}/{target}-({counter}).pkl', 'wb') as output:
            pickle.dump(latent_dist, output, pickle.HIGHEST_PROTOCOL)
        counter += 1

precompute_latents_pickle(vae, train_loader, val_loader, dataset.classes)

  8%|▊         | 7340/92062 [09:38<1:51:13, 12.70it/s]


In [8]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder
from torchvision.datasets.folder import make_dataset
from torchvision.datasets.vision import VisionDataset
from torchvision import transforms
import pickle
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import lightning.pytorch as pl

def pickle_loader(path: str) -> Any:
    with open(path, 'rb') as f:
        return pickle.load(f)

class PickleFolder(DatasetFolder):
    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = pickle_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super(PickleFolder, self).__init__(root, loader, ('.pkl',),
                                            transform=transform,
                                            target_transform=target_transform,
                                            is_valid_file=is_valid_file)
        self.imgs = self.samples
        
    def __getitem__(self, index: int) -> Dict[str, Any]:
        path, target = self.samples[index]
        sample = self.loader(path)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return {"latents": sample, "target": target}

def collate_fn(examples):
    targets = [example["target"] for example in examples]
    pixel_values = [example["latents"].sample() for example in examples]
    pixel_values = torch.stack(pixel_values).squeeze(1)
    
    batch = {
        "latents": pixel_values,
        "classes": targets,
    }
    
    return batch

class PickleDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = './latents', batch_size: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self.pickle_train = PickleFolder(f'{self.data_dir}/train')
            self.pickle_val = PickleFolder(f'{self.data_dir}/val')
            
            print("Trainset: ", len(self.pickle_train))
            print("Valset: ", len(self.pickle_val))

        # # Assign test dataset for use in dataloader(s)
        # if stage == 'test' or stage is None:
        #     self.pickle_test = PickleFolder(f'{self.data_dir}/test')

    def train_dataloader(self):
        return DataLoader(self.pickle_train, batch_size=self.batch_size, collate_fn=collate_fn, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.pickle_val, batch_size=self.batch_size, collate_fn=collate_fn, shuffle=False)

    # def test_dataloader(self):
    #     return DataLoader(self.pickle_test, batch_size=self.batch_size, collate_fn=collate_fn)


In [9]:
pickle_data = PickleDataModule(batch_size=1)

pickle_data.setup()

for batch in pickle_data.train_dataloader():
    print(batch["latents"].shape)
    print(batch["classes"])

Trainset:  13600
Valset:  2400
torch.Size([1, 4, 64, 64])
[3]
torch.Size([1, 4, 64, 64])
[1]
torch.Size([1, 4, 64, 64])
[0]
torch.Size([1, 4, 64, 64])
[2]
torch.Size([1, 4, 64, 64])
[2]
torch.Size([1, 4, 64, 64])
[2]
torch.Size([1, 4, 64, 64])
[0]
torch.Size([1, 4, 64, 64])
[2]
torch.Size([1, 4, 64, 64])
[0]
torch.Size([1, 4, 64, 64])
[3]
torch.Size([1, 4, 64, 64])
[2]
torch.Size([1, 4, 64, 64])
[2]
torch.Size([1, 4, 64, 64])
[2]
torch.Size([1, 4, 64, 64])
[1]
torch.Size([1, 4, 64, 64])
[0]
torch.Size([1, 4, 64, 64])
[1]
torch.Size([1, 4, 64, 64])
[3]
torch.Size([1, 4, 64, 64])
[1]
torch.Size([1, 4, 64, 64])
[3]
torch.Size([1, 4, 64, 64])
[2]
torch.Size([1, 4, 64, 64])
[1]
torch.Size([1, 4, 64, 64])
[0]
torch.Size([1, 4, 64, 64])
[0]
torch.Size([1, 4, 64, 64])
[1]
torch.Size([1, 4, 64, 64])
[3]
torch.Size([1, 4, 64, 64])
[2]
torch.Size([1, 4, 64, 64])
[3]
torch.Size([1, 4, 64, 64])
[1]
torch.Size([1, 4, 64, 64])
[1]
torch.Size([1, 4, 64, 64])
[1]
torch.Size([1, 4, 64, 64])
[3]
torch.Si