In [None]:
import os
import requests
from io import BytesIO

# Switch to the project checkout before importing `tld` below. The location can
# be overridden with the TLD_REPO_DIR environment variable on other machines;
# the original cluster path is kept as the default.
os.chdir(os.environ.get("TLD_REPO_DIR", "/w/246/ikozlov/csc2231-project/transformer_latent_diffusion"))

from tld.train import main
from tld.configs import DataDownloadConfig, DataConfig, ModelConfig, TrainConfig, DenoiserConfig, VaeConfig, ClipConfig
from accelerate import notebook_launcher
import wandb
from PIL import Image
import torch
from torchvision.transforms import Resize, ToTensor
from datasets import load_dataset
from tld.data import get_text_and_latent_embeddings_hdf5, encode_text, encode_image
import clip
from diffusers import AutoencoderKL, AutoencoderTiny, StableDiffusionPipeline
import numpy as np
import huggingface_hub as hgf
from transformers import CLIPProcessor, CLIPModel, CLIPTextModel, AutoTokenizer, CLIPTokenizer
from tqdm import tqdm

# Log into W&B and the Hugging Face Hub using credentials from the environment.
# SECURITY: API keys must never be committed in a notebook. The keys that were
# previously hard-coded here have been exposed and should be revoked/rotated.
#   WANDB_API_KEY - read by wandb.login() itself (it prompts when unset)
#   HF_TOKEN      - passed to huggingface_hub.login() (it prompts when None)
wandb.login()
hgf.login(token=os.environ.get("HF_TOKEN"))

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else "cpu"
device

In [None]:
# Side length (pixels) of the square every training image is resized to.
resolution = DataDownloadConfig.image_size

# Where the precomputed image latents and text encodings are written to and
# later read back from by the training cell.
latent_save_path = DataDownloadConfig.latent_save_path
image_latent_path, text_emb_path, val_emb_path = (
    os.path.join(latent_save_path, file_name)
    for file_name in ('image_latents.npy', 'text_encodings.npy', 'val_encs.npy')
)

In [None]:
def load_image_from_url(url, timeout=10):
    """Download and decode an image, returning ``None`` on any failure.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server (forwarded to ``requests.get``).
            Without a timeout a single stalled host can hang dataset
            preprocessing indefinitely.

    Returns:
        A fully decoded PIL image, or ``None`` if the request failed, the
        server answered with an HTTP error status, or the payload could not be
        decoded. Failures are reported with ``print`` instead of raised so a
        few bad URLs do not abort preprocessing.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raises stored HTTPError, if one occurred.
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; force the decode here so truncated or corrupt
        # payloads are caught by this handler instead of failing later on resize.
        image.load()
        return image
    except requests.exceptions.RequestException as e:
        print(f"Error fetching {url}: {e}")
        return None
    except Exception as e:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        print(f"Something else went wrong fetching {url}: {e}")
        return None

def transform(example):
    """Attach a ``resolution x resolution`` RGB PIL image to one dataset row.

    The image source is, in priority order:
      1. an existing ``image`` value,
      2. a URL in ``image_url``,
      3. a URL in ``link``.
    URLs are fetched with ``load_image_from_url``.

    Every image obtained this way is converted to RGB and resized with bicubic
    resampling. This applies to all sources: the ``image_url`` path used to
    skip the resize, which later broke ``torch.stack`` in the collate function.
    If no image could be obtained, ``example['image']`` is set to ``None`` so
    the row can be dropped by a subsequent ``filter``.

    Returns the updated ``example``.
    """
    image = example.get('image')
    if image is None:
        if 'image_url' in example:
            image = load_image_from_url(example['image_url'])
        elif 'link' in example:
            image = load_image_from_url(example['link'])

    if image is not None:
        # Normalise palette / grayscale / RGBA / CMYK inputs to plain RGB so
        # every tensor produced downstream has the same three colour channels.
        image = image.convert('RGB').resize((resolution, resolution), Image.BICUBIC)

    example['image'] = image
    return example

def custom_collate_fn(batch):
    """Collate dataset rows into an image batch tensor and a list of captions.

    Each row's ``image`` is converted with ``ToTensor`` (channel-first) and
    coerced to exactly three channels:
      * 3 or more channels -> the first three are kept (e.g. alpha dropped),
      * fewer than 3       -> channel 0 is replicated into all three.

    Returns:
        ``(images, texts)``. ``images`` is ``torch.stack`` of the per-row
        tensors, so all images in a batch must share height and width.
        ``texts`` lists the rows' ``text`` fields in batch order.
    """
    to_tensor = ToTensor()  # build once rather than once per row
    images, texts = [], []
    for item in batch:
        image = to_tensor(item['image'])
        if image.shape[0] >= 3:
            # Previously a 3-channel image matched neither branch and was
            # replaced by an all-zeros tensor, blanking every ordinary RGB image.
            image_rgb = image[:3]
        else:
            image_rgb = image[:1].repeat(3, 1, 1)
        images.append(image_rgb)
        texts.append(item['text'])

    return torch.stack(images), texts

def get_text_and_latent_embeddings(dataloader, vae, model, device):
    """Encode captions and images of every batch in a single pass.

    Args:
        dataloader: Iterable of ``(images, labels)`` batches. Images are
            presumably in [0, 1] (``ToTensor`` output) and are rescaled here
            to [-1, 1].
        vae: Autoencoder exposing ``encode(x, return_dict=False)``, whose
            first element has a ``sample()`` method.
        model: Text model called directly on ``clip.tokenize`` output.
        device: Device both models and every batch are moved to.

    Returns:
        ``(img_encoding_ds, text_encoding_ds)``: lists holding one float16
        numpy array per batch.
    """
    img_encoding_ds = []
    text_encoding_ds = []

    # Move both models once up front. The VAE used to never be moved, so a
    # CUDA `device` mismatched its weights. The text model used to be moved
    # on every batch.
    model = model.to(device)
    vae = vae.to(device)

    # Inference only. Without no_grad the outputs require grad, and
    # `.numpy()` raises on them.
    with torch.no_grad():
        for img, label in tqdm(dataloader):
            text_tokens = clip.tokenize(label, truncate=True).to(device)
            text_encoding = model(text_tokens).cpu().numpy().astype(np.float16)

            # Same input dtype as get_img_embeddings, matching how the VAE is loaded.
            x = img.to(device).to(VaeConfig.vae_dtype)
            x = x * 2 - 1  # to make it between -1 and 1.
            img_encoding = vae.encode(x, return_dict=False)[0].sample().cpu().numpy().astype(np.float16)

            text_encoding_ds.append(text_encoding)
            img_encoding_ds.append(img_encoding)

    return img_encoding_ds, text_encoding_ds

def get_img_embeddings(dataloader, vae, device):
    """Encode every image batch from ``dataloader`` into sampled VAE latents.

    Args:
        dataloader: Iterable of ``(images, texts)`` batches; texts are ignored.
            Images are presumably in [0, 1] (``ToTensor`` output) and are
            rescaled here to [-1, 1].
        vae: Autoencoder exposing ``encode(x, return_dict=False)``, whose
            first element has a ``sample()`` method.
        device: Device the VAE and each batch are moved to.

    Returns:
        List with one float16 numpy array of sampled latents per batch.
    """
    img_encoding_ds = []

    print("Generating image embeddings...")
    vae.to(device)  # move once instead of on every batch
    # Inference only: don't record an autograd graph (and its activations)
    # for every encode call.
    with torch.no_grad():
        for img, _ in tqdm(dataloader):
            x = img.to(device).to(VaeConfig.vae_dtype)
            x = x * 2 - 1  # to make it between -1 and 1.
            latent = vae.encode(x, return_dict=False)[0].sample()
            img_encoding_ds.append(latent.cpu().numpy().astype(np.float16))

    return img_encoding_ds

def _encode_texts(model, tokenizer, device, labels):
    """Tokenize ``labels`` and return the model's last hidden state as float16 numpy.

    ``labels`` is passed straight to ``tokenizer``, so it may be a single
    string or a list of strings. Sequences are truncated and padded to the
    tokenizer's max length.
    """
    tokens = tokenizer(labels, truncation=True, padding="max_length", return_tensors="pt")
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    hidden = model(input_ids, attention_mask=attention_mask).last_hidden_state
    return hidden.detach().cpu().numpy().astype(np.float16)


def get_text_embeddings(model, tokenizer, device, text_list=None, dataloader=None):
    """Encode prompts or captions with a text encoder.

    Exactly one source is used:
      * if ``text_list`` is non-empty, each prompt is encoded on its own, giving
        one result array per prompt (with a leading batch dimension of 1);
      * otherwise, if ``dataloader`` is given, the texts of each
        ``(images, texts)`` batch are encoded together, giving one result
        array per batch (images are ignored);
      * with neither, an empty list is returned.

    Args:
        model: Text encoder called as ``model(input_ids, attention_mask=...)``;
            its output must expose ``last_hidden_state``.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors for ``return_tensors="pt"``.
        device: Device the model and token tensors are moved to.
        text_list: Optional list of prompt strings.
        dataloader: Optional iterable of ``(images, texts)`` batches.

    Returns:
        List of float16 numpy arrays of last hidden states.
    """
    print("Generating text embeddings...")
    if text_list:
        label_batches = tqdm(text_list)
    elif dataloader is not None:
        label_batches = (labels for _, labels in tqdm(dataloader))
    else:
        return []

    model = model.to(device)  # move once instead of on every batch
    # Inference only: no autograd graph is needed for the encodings.
    with torch.no_grad():
        return [_encode_texts(model, tokenizer, device, labels) for labels in label_batches]

#######################################
#            Load Data Set            #
#######################################

# Alternative datasets: 'lambdalabs/pokemon-blip-captions', "laion/gpt4v-dataset",
# "saxon/T2IScoreScore", 'valhalla/pokemon-dataset'.
dataset_name = "fantasyfish/laion-art"
dataset = load_dataset(dataset_name)

# Resize (or download, then resize) every training image, then drop the rows
# whose image could not be obtained (transform leaves those as None).
dataset["train"] = (
    dataset["train"]
    .map(transform)
    .filter(lambda example: example['image'] is not None)
)

# shuffle=False matters: images and texts are encoded in two separate passes
# over this loader, and the saved arrays only line up if the order is fixed.
train_dataloader = torch.utils.data.DataLoader(
    dataset["train"],
    batch_size=DataDownloadConfig.batch_size,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

#######################################
#    Get Text and Image Encodings     #
#######################################

# NOTE(review): the two *_hf_path names and the teacher_* names below are not
# used by any visible cell. The encoders are loaded from
# `pretrained_model_name_or_path` instead.
text_encoder_hf_path = ClipConfig.clip_model_name # "openai/clip-vit-base-patch32" # "openai/clip-vit-large-patch14"
img_encoder_hf_path = VaeConfig.vae_name # "madebyollin/taesd" #"madebyollin/sdxl-vae-fp16-fix"

pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
teacher_text_encoder_hf_path = pretrained_model_name_or_path
teacher_img_encoder_hf_path = pretrained_model_name_or_path

# Create the latents directory, including missing parents. os.mkdir failed
# whenever a parent directory did not exist yet.
os.makedirs(latent_save_path, exist_ok=True)

# Text encoder, tokenizer and VAE all come from the Stable Diffusion v1.5 checkpoint.
model = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=ClipConfig.clip_dtype)
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", torch_dtype=VaeConfig.vae_dtype)

# Two passes over the same unshuffled dataloader, so row i of the image
# latents corresponds to row i of the text encodings.
img_encodings = get_img_embeddings(train_dataloader, vae, device)
text_encodings = get_text_embeddings(model, tokenizer, device, dataloader=train_dataloader)

image_latents = np.concatenate(img_encodings, axis=0)
text_embeddings = np.concatenate(text_encodings, axis=0)
assert len(image_latents) == len(text_embeddings), "image latents and text encodings are misaligned"

np.save(image_latent_path, image_latents)
np.save(text_emb_path, text_embeddings)

# TODO: load teacher models / teacher text embeddings.

#######################################
#  Save Validation Prompt Encodings   #
#######################################

# Fixed prompts encoded once and written to `val_emb_path`, which the training
# cell passes to DataConfig as `val_path`.
# NOTE(review): these prompts describe Pokemon-style creatures while the
# training data above is laion-art; confirm they are the intended validation set.
creature_descriptions = [
    "A drawing of a small, blue aquatic creature with a fin on its head and a light blue tail.",
    "A picture of a fiery orange and red mythical dragon-like figure, with smoke billowing from its nostrils.",
    "A cartoon image of a character that looks like a yellow sunflower with a smiling face in the center.",
    "An illustration of a rock-like creature, gray and rugged, with crystals emerging from its back.",
    "A sketch of a ghostly figure, transparent and white, with glowing red eyes and ethereal trails.",
    "A drawing of a cute, furry, brown bear cub-like character, with large, round ears and a small nose.",
    "An image of an electric-type creature, bright yellow with black stripes, radiating energy.",
    "A picture of an ice-like character, resembling a small, crystalline snowflake with a shimmering, icy body."
]

val_encodings = get_text_embeddings(model, tokenizer, device, text_list=creature_descriptions)
np.save(val_emb_path, val_encodings)
print("Done with conversion to latents.")

In [None]:
run_id='' #@param {type:"string"}
n_epoch=40 #@param {type:"integer"}


# Point training at the latents and encodings written by the preprocessing cell.
data_config = DataConfig(latent_path=image_latent_path,
                        text_emb_path=text_emb_path,
                        val_path=val_emb_path)

# The denoiser works on VAE latents, so its spatial size is the image size
# divided by 8 (presumably the VAE's downsampling factor).
denoiser_config = DenoiserConfig(image_size=int(resolution/8))

# Forward the form parameters above. They were previously defined but never
# used, so TrainConfig silently fell back to its own defaults. run_id is only
# passed when set, which keeps TrainConfig's own default otherwise.
train_kwargs = {"n_epoch": n_epoch}
if run_id:
    train_kwargs["run_id"] = run_id

model_cfg = ModelConfig(
    data_config=data_config,
    denoiser_config=denoiser_config,
    train_config=TrainConfig(**train_kwargs),
)

notebook_launcher(main, (model_cfg,), num_processes=1)