In [None]:
# Install taming-transformers (provides taming.models.vqgan.VQModel used below) and its runtime deps.
# NOTE(review): taming-transformers is installed twice here — once from git via %pip and once as an
# editable checkout via `git clone` + `pip install -e .`; one of the two is probably enough.
# On a re-run `git clone` fails because the directory already exists, so `&&` skips the editable install.
# The `!` line uses whatever `python` the shell finds, which may not be the kernel's interpreter
# (the %pip lines do target the kernel's environment).
%pip install  git+https://github.com/CompVis/taming-transformers.git
!git clone https://github.com/CompVis/taming-transformers.git && cd taming-transformers && python -m pip install -e .
%pip install "omegaconf>=2.0.0" "pytorch-lightning>=1.0.8" einops transformers

# Hard-exit the kernel process so the following cells run in a fresh interpreter that can import the
# packages installed above. This also stops "Run All" at this cell; presumably relies on the front-end
# (e.g. Colab) restarting the kernel automatically — confirm for your environment.
import os
os._exit(00)

In [None]:
#VQGAN ImageNet (f=16), 1024

# !curl -L "https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1" >"last.ckpt"
# !curl -L "https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1" >"model.yaml"


#boris/vqgan_f16_16384
!curl -L "https://huggingface.co/boris/vqgan_f16_16384/raw/main/config.yaml" > "config_vqgan_minidalle.yaml"
!curl -L "https://huggingface.co/boris/vqgan_f16_16384/resolve/main/model.ckpt" > "model_vqgan_minidalle.ckpt"

In [None]:
import sys
sys.path.append(".")

# also disable grad to save memory: this notebook only runs inference (encode/decode),
# so no autograd graph is ever needed.
import torch
torch.set_grad_enabled(False)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from omegaconf import OmegaConf
import taming
from taming.models.vqgan import VQModel

config_path = "config_vqgan_minidalle.yaml"
# config_path = "model.yaml"

config = OmegaConf.load(config_path)
# Build on the CPU; the model is moved to DEVICE only after its weights are loaded.
model = VQModel(**config.model.params)

ckpt_path = "model_vqgan_minidalle.ckpt"
# ckpt_path = "last.ckpt"

# Only the "state_dict" entry of the checkpoint is used, so load the file onto the CPU rather
# than mapping every entry (presumably including training state) straight onto the GPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from sources you trust.
# On newer torch releases the default weights_only=True may reject Lightning checkpoints that
# pickle non-tensor objects; pass weights_only=False explicitly in that case (trusted files only).
sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]

# strict=False tolerates key mismatches between checkpoint and model; the returned report
# (missing / unexpected keys) is displayed as this cell's output so mismatches stay visible.
incompatible_keys = model.load_state_dict(sd, strict=False)

# eval() puts every submodule in inference mode (dropout off, norm layers use stored stats).
model = model.to(DEVICE).eval()

incompatible_keys

In [None]:
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.folder import default_loader
from torchvision.transforms import InterpolationMode
import numpy as np
from PIL import Image
import torchvision.transforms as T

import requests

# Fetch one sample photo to sanity-check the encode -> decode round trip below.
# timeout keeps the cell from hanging forever on a stalled connection.
r = requests.get("https://images.pexels.com/photos/10060920/pexels-photo-10060920.jpeg?cs=srgb&dl=pexels-nataliya-vaitkevich-10060920.jpg&fm=jpg",stream=True, timeout=60)
# Fail loudly on HTTP errors instead of handing an HTML error page to PIL.
r.raise_for_status()
# r.raw is the undecoded byte stream; ask urllib3 to undo any gzip/deflate Content-Encoding.
r.raw.decode_content = True
im = Image.open(r.raw)


def preprocess_vqgan(x):
    """Rescale values from [0, 1] to [-1, 1], the range this notebook feeds to the VQGAN encoder.

    Works element-wise on anything supporting ``*`` and ``-`` with floats (Python floats,
    NumPy arrays, torch tensors).
    """
    return x * 2. - 1.

def custom_to_pil(x):
    """Convert one channels-first image tensor with values in [-1, 1] into an RGB PIL image.

    The tensor is detached and copied to the CPU, clipped to [-1, 1], rescaled to [0, 255]
    and truncated to uint8; the resulting image is converted to RGB if PIL picked another mode.
    """
    pixels = torch.clamp(x.detach().cpu(), -1., 1.)
    # [-1, 1] -> [0, 1], then (C, H, W) -> (H, W, C) as PIL expects.
    pixels = ((pixels + 1.) / 2.).permute(1, 2, 0).numpy()
    img = Image.fromarray((255 * pixels).astype(np.uint8))
    return img if img.mode == "RGB" else img.convert("RGB")


def resize_image(image, size: int = 256):
    """Turn a PIL image into a single-image VQGAN input batch.

    The image is converted to RGB, resized with Lanczos filtering so its shorter side equals
    ``size`` (aspect ratio kept), center-cropped to ``size`` x ``size``, converted to a tensor,
    given a leading batch dimension, moved to ``DEVICE`` and rescaled from [0, 1] to [-1, 1].

    Args:
        image: PIL image of any mode and size.
        size: side length of the square output in pixels (default 256).

    Returns:
        Tensor of shape (1, 3, size, size) on ``DEVICE`` with values in [-1, 1].
    """
    # PIL reports (width, height) while TF.resize expects (height, width).
    width, height = image.size
    scale = size / min(width, height)
    target_hw = (round(scale * height), round(scale * width))

    rgb = image.convert('RGB')
    resized = TF.resize(rgb, target_hw, interpolation=InterpolationMode.LANCZOS)
    cropped = TF.center_crop(resized, output_size=[size, size])
    batch = torch.unsqueeze(T.ToTensor()(cropped), 0)

    return preprocess_vqgan(batch.to(DEVICE))

# Encode the sample photo; the result is unpacked (per the names used here) into the quantized
# latents, the embedding loss and an info tuple.
sample_batch = resize_image(im)
quant, emb_loss, info = model.encode(sample_batch)


In [None]:
# Decode the quantized latents back to pixels and show the reconstruction of the (single) sample.
x_rec = model.decode(quant)
reconstruction_image = custom_to_pil(x_rec[0])
display(reconstruction_image)

In [None]:
from tqdm import tqdm
import pandas as pd
import os


def create_encoding_file_parquet(base_dir_images, output_path=None):
    """Encode every image file in a directory with the VQGAN and save the codes as Parquet.

    Each row holds an image's file name and ``indices.tolist()``, where ``indices`` is the third
    element of the info tuple returned by ``model.encode``.

    Args:
        base_dir_images: directory holding the images (not searched recursively). A trailing
            path separator is tolerated.
        output_path: Parquet file to write; defaults to ``<base_dir_images>.parquet``, i.e. a
            file next to the directory.

    Raises:
        Whatever ``Image.open`` / ``model.encode`` raise for a file that is not a readable image.
    """
    # normpath strips a trailing separator so "dir/" still yields "dir.parquet", not "dir/.parquet".
    base_dir_images = os.path.normpath(base_dir_images)
    if output_path is None:
        output_path = base_dir_images + ".parquet"

    # sorted(): os.listdir order is arbitrary, so sorting makes the row order reproducible.
    # Sub-directories (e.g. __MACOSX from zip archives) and hidden files (e.g. .DS_Store) are
    # skipped because PIL cannot open them and one such entry would abort the whole run.
    file_names = sorted(
        name
        for name in os.listdir(base_dir_images)
        if not name.startswith(".") and os.path.isfile(os.path.join(base_dir_images, name))
    )

    list_file = []
    list_encoded = []
    for file_name in tqdm(file_names, desc='files'):
        # `with` closes the file handle promptly instead of leaving it to garbage collection.
        with Image.open(os.path.join(base_dir_images, file_name)) as image:
            _, _, (_, _, indices) = model.encode(resize_image(image))
        list_file.append(file_name)
        list_encoded.append(indices.tolist())

    df = pd.DataFrame({"file_name": list_file, "encoding": list_encoded})
    df.to_parquet(output_path)


In [None]:
# !wget archive_train.zip
# !wget archive_val.zip
!unzip archive_train.zip
!unzip archive_val.zip

In [None]:
# Encode each extracted split; every call writes "<split_dir>.parquet" next to the directory.
for split_dir in ("archive_val", "archive_train"):
    create_encoding_file_parquet(split_dir)