In [1]:
from transformers import CLIPTokenizer
from datasets import load_dataset
from datasets import Image as HuggingFaceImage
from linformer import Linformer
from vit_pytorch.efficient import ViT
from dalle2_pytorch.tokenizer import tokenizer
import torch
def get_tokenizer() -> CLIPTokenizer:
    """Load the pretrained CLIP tokenizer published as 'openai/clip-vit-base-patch32'.

    Returns:
        The ``CLIPTokenizer`` instance produced by ``CLIPTokenizer.from_pretrained``.
    """
    checkpoint_name = 'openai/clip-vit-base-patch32'
    clip_tokenizer = CLIPTokenizer.from_pretrained(checkpoint_name)
    return clip_tokenizer
def get_tokenizer2():
    """Return the ``tokenizer`` object imported from ``dalle2_pytorch.tokenizer``.

    This is an alternative to ``get_tokenizer``; the same module-level object
    is handed back on every call.
    """
    dalle2_tokenizer = tokenizer
    return dalle2_tokenizer
def prepare_data(tokenizer, max_length: int = 256):
    """Load the font metadata and turn every record into (image, tokens).

    Reads ``train-metadata.jsonl`` and ``test-metadata.jsonl`` from the current
    working directory. For each split it builds a text prompt from the font
    metadata, tokenizes it, drops the metadata columns, casts the ``image``
    column to a HuggingFace ``Image`` feature and switches the split to the
    ``torch`` format.

    Args:
        tokenizer: object whose ``encode(text, padding=..., truncation=...,
            max_length=...)`` returns the token ids (e.g. the tokenizer from
            ``get_tokenizer``). Note that this parameter shadows the
            module-level ``tokenizer`` import inside this function.
        max_length: length every token sequence is padded / truncated to.
            Defaults to 256, the value that was previously hard-coded.

    Returns:
        The dataset dictionary returned by ``load_dataset`` with its
        ``'train'`` and ``'test'`` splits replaced by the processed versions.
    """
    # Columns that only exist to build the prompt; they are dropped afterwards.
    metadata_columns = ['prompt', 'uniqueId', 'ttf_path', 'font_characteristics',
                        'font_properties', 'character', 'vit_label']

    def add_prompt(example):
        """Attach the textual description of the glyph under ``'prompt'``."""
        props = example['font_properties']
        # 'lower_a' -> 'lowercase a'; a name without '_' is used as-is.
        parts = example['character'].split('_')
        character = parts[0] + 'case ' + parts[1] if len(parts) > 1 else parts[0]
        example['prompt'] = f"a {props['font_serifs']} {character} with {props['width']} width {props['rounding']} corners {props['font_weight']} weight and {props['dynamics']} movement with characteristics that can be described by adjectives {example['font_characteristics']}"
        return example

    def map_tokens(example):
        """Tokenize ``'prompt'`` into a fixed-length id list under ``'tokens'``."""
        # truncation=True keeps every row at exactly max_length ids; padding
        # alone would let an over-long prompt produce a longer row.
        # (A tokenizer without these keyword arguments would need the plain
        # form: tokenizer.encode(prompt).)
        example['tokens'] = tokenizer.encode(
            example['prompt'],
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )
        return example

    dataset = load_dataset('json', data_files={'train':'train-metadata.jsonl', 'test':'test-metadata.jsonl'})

    # Both splits go through the same pipeline. map() creates the 'prompt' and
    # 'tokens' columns itself, so no placeholder columns are added beforehand.
    for split_name in ('train', 'test'):
        dataset[split_name] = (
            dataset[split_name]
            .map(add_prompt)
            .map(map_tokens)
            .remove_columns(metadata_columns)
            .cast_column('image', HuggingFaceImage())
            .with_format('torch')
        )
    return dataset
def get_vit_model(image_size: int, patch_size: int, dim: int, depth: int, num_heads: int, k: int, device: str,
                  num_classes: int = 62, channels: int = 1):
    """Build a ``vit_pytorch.efficient.ViT`` classifier backed by a Linformer.

    Args:
        image_size: side length of the square input image, in pixels.
        patch_size: side length of each square patch, in pixels.
        dim: embedding width, passed to both the Linformer and the ViT.
        depth: passed to ``Linformer`` as ``depth``.
        num_heads: passed to ``Linformer`` as ``heads``.
        k: passed to ``Linformer`` as ``k``.
        device: device the model is moved to. ``None`` leaves the model on
            the device it was created on (previously this argument was
            accepted but ignored).
        num_classes: size of the classification head. Defaults to 62, the
            value that was previously hard-coded.
        channels: number of input image channels. Defaults to 1, the value
            that was previously hard-coded.

    Returns:
        The constructed ``ViT`` model.
    """
    # One token per patch plus one CLS token,
    # e.g. 512x512px image with 32x32px patches: 16x16 + 1.
    sequence_length = (image_size // patch_size) ** 2 + 1
    efficient_transformer = Linformer(
        dim=dim,
        seq_len=sequence_length,
        depth=depth,
        heads=num_heads,
        k=k
    )
    model = ViT(
        dim=dim,
        image_size=image_size,
        patch_size=patch_size,
        num_classes=num_classes,
        transformer=efficient_transformer,
        channels=channels,
    )
    # Honour the requested device instead of silently dropping the argument.
    if device is not None:
        model = model.to(device)
    return model
def get_vit(image_size, patch_size, vit_dim, vit_depth, vit_num_heads, k, device, vit_checkpoint_path):
    """Build the ViT via ``get_vit_model`` and optionally restore its weights.

    Args:
        image_size: forwarded to ``get_vit_model`` as ``image_size``.
        patch_size: forwarded to ``get_vit_model`` as ``patch_size``.
        vit_dim: forwarded to ``get_vit_model`` as ``dim``.
        vit_depth: forwarded to ``get_vit_model`` as ``depth``.
        vit_num_heads: forwarded to ``get_vit_model`` as ``num_heads``.
        k: forwarded to ``get_vit_model`` as ``k``.
        device: forwarded to ``get_vit_model``; also used as ``map_location``
            when a checkpoint is loaded. May be ``None``.
        vit_checkpoint_path: path of a checkpoint file holding a
            ``'model_state_dict'`` entry, or ``None`` to skip loading.

    Returns:
        The ViT model, with checkpoint weights loaded when a path was given.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    vit = get_vit_model(
        image_size=image_size,
        patch_size=patch_size,
        dim=vit_dim,
        depth=vit_depth,
        num_heads=vit_num_heads,
        k=k,
        device=device,
    )
    if vit_checkpoint_path is not None:
        # map_location lets a checkpoint saved on one device be restored on
        # another (e.g. saved on GPU, loaded on a CPU-only machine);
        # map_location=None is torch.load's default behaviour.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        vit_checkpoint = torch.load(vit_checkpoint_path, map_location=device)
        vit.load_state_dict(vit_checkpoint['model_state_dict'])
        print('Loaded ViT model from checkpoint:', vit_checkpoint_path)
    return vit


In [2]:
# import os
# from transformers import CLIPTextModel
# from x_clip_train import get_tokenizer
# import torch
# from x_clip import CLIP
# from vit_pytorch.extractor import Extractor

# image_size = 512
# patch_size = 32
# vit_dim = 128
# vit_depth = 12
# vit_num_heads = 8
# k = 64
# vit_checkpoint = './vit-checkpoints/model-512-epoch3.pt'
# base_vit = get_vit(image_size, 
#                     patch_size, 
#                     vit_dim, 
#                     vit_depth, 
#                     vit_num_heads, 
#                     k, 
#                     device=None, 
#                     vit_checkpoint_path=None)
# image_encoder = Extractor(
#     base_vit,
#     return_embeddings_only = True
# )
# clip_tokenizer = get_tokenizer(True)
# text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-base-patch32')
# text_encoder.resize_token_embeddings(len(clip_tokenizer))
# path = os.path.join(os.getcwd(), 'clip-checkpoints', 'clip-epoch-9.pt')
# checkpoint = torch.load(path)
# clip = CLIP(
#     image_encoder = image_encoder,
#     text_encoder = text_encoder,
#     dim_image=128,
#     dim_text=512,
#     dim_latent=384,
#     text_encode_without_mask=False,
#     use_all_token_embeds=False,
#     text_has_cls_token=True,
#     visual_has_cls_token=True,
#     num_text_tokens=text_encoder.vocab_size,
#     text_pad_id=clip_tokenizer.pad_token_id,
#     text_eos_id=clip_tokenizer.eos_token_id,
#     use_mlm=True,
#     mlm_mask_token_id=clip_tokenizer.mask_token_id,
#     mlm_pad_token_id=clip_tokenizer.pad_token_id,
#     mlm_mask_ignore_token_ids=[clip_tokenizer.bos_token_id],
#     channels=1
# ).to('cuda')
# clip.load_state_dict(checkpoint['model_state_dict'])
# # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# start_epoch = checkpoint['epoch']
# loss = checkpoint['loss']
# print('Loaded model from checkpoint:', path)

### Moving forward: Train Decoder (Custom Linformer + CLIP)

In [4]:
# import torch
# from dalle2_pytorch import Unet, Decoder, CLIP


# # unet for the decoder
# unet = Unet(
#     dim = 128,
#     image_embed_dim = 512,
#     cond_dim = 128,
#     channels = 1,
#     dim_mults=(1, 2, 4, 8)
# ).cuda()

# # decoder, which contains the unet and clip

# decoder = Decoder(
#     unet = unet,
#     clip = clip,
#     timesteps = 100,
#     image_cond_drop_prob = 0.1,
#     text_cond_drop_prob = 0.5,
#     channels=1
# ).cuda()

# # mock images (get a lot of this)

# images = torch.randn(1, 1, 512, 512).cuda()

# # feed images into decoder

# loss = decoder(images)
# loss.backward()

torch.Size([1, 65, 128])
KV LEN 65 Seq len 257


AssertionError: the sequence length of the key / values must be 257 - 65 given

### Training Decoder/Prior: OpenAI CLIP 

In [2]:
import numpy as np

# Tokenize the metadata with the CLIP tokenizer and build both splits.
dataset = prepare_data(get_tokenizer())

# One-example batch from the training split (slice 0:1 keeps the batch axis).
# The image axes are reordered with permute(0, 3, 1, 2) and cast to float.
images_tensor = dataset['train'][0:1]['image'].permute(0, 3, 1, 2).float()
texts_tensor = dataset['train'][0:1]['tokens']

# Sanity check of the shapes that are fed to the models below.
print(images_tensor.shape)
print(texts_tensor.shape)


Found cached dataset json (C:/Users/rinat/.cache/huggingface/datasets/json/default-1e07ea3eabca7683/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\rinat\.cache\huggingface\datasets\json\default-1e07ea3eabca7683\0.0.0\fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e\cache-54018895e1b774e9.arrow
Loading cached processed dataset at C:\Users\rinat\.cache\huggingface\datasets\json\default-1e07ea3eabca7683\0.0.0\fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e\cache-8115de56d8df858f.arrow
Loading cached processed dataset at C:\Users\rinat\.cache\huggingface\datasets\json\default-1e07ea3eabca7683\0.0.0\fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e\cache-2ad7280546491b5e.arrow
Loading cached processed dataset at C:\Users\rinat\.cache\huggingface\datasets\json\default-1e07ea3eabca7683\0.0.0\fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e\cache-5c61c0fc59a2958d.arrow


torch.Size([1, 3, 512, 512])
torch.Size([1, 256])


In [3]:
# End-to-end DALLE-2 smoke test: one forward/backward pass through the
# diffusion prior and through each decoder unet, then one text-to-image sample.
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter

# openai pretrained clip - defaults to ViT-B/32

clip = OpenAIClipAdapter()

# input data: the tensors bound to `texts_tensor` / `images_tensor` by an earlier cell
# NOTE(review): this relies on those globals still holding the batched,
# channels-first tensors; a later cell rebinds both names - confirm run order.

text = texts_tensor.cuda()
images_tensor = images_tensor.float()
images = images_tensor.cuda()

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 10000,
    cond_drop_prob = 0.2
).cuda()

# single forward/backward pass through the prior; no optimizer step is taken in this cell
loss = diffusion_prior(text, images)
loss.backward()

# decoder (with unet)

# first unet of the cascade, conditioned on text encodings
unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8),
    text_embed_dim = 512,
    cond_on_text_encodings = True  # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()

# second unet of the cascade (no text-encoding conditioning requested)
unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

# image_sizes / sample_timesteps hold one entry per unet, in the same order as `unet`
decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 10000,
    sample_timesteps = (250, 27),
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# one forward/backward pass per unet; `loss` is rebound on every iteration
for unet_number in (1, 2):
    loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

# NOTE(review): `images` is rebound here - from the input batch to the value returned by dalle2(...)
images = dalle2(
    ['A lowercase a which has traits blocky and properties black square sans serif static extended all caps'],
    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)

sampling loop time step:   0%|          | 0/10000 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/27 [00:00<?, ?it/s]

In [4]:
# Save the value bound to `images` (the result of the dalle2(...) call in the
# previous cell) to 'example_output.png' in the current working directory.
from torchvision.utils import save_image
save_image(images, 'example_output.png')

In [62]:
# HuggingFace Image column - 512 x 512 x 3 (channels last)
# Token length - 42

# Required format:
# Image: (B, C, W, H) ~ channels: 3 , W,H: 512
# Text: (B, T)

# NOTE(review): this cell's execution count (62) is out of sequence with the
# cells above, and it rebinds `images_tensor` / `texts_tensor`, which the
# DALLE-2 cell reads. It indexes with [0] where the earlier cell used the
# slice [0:1]; presumably this yields a single example without the batch
# axis - confirm before re-running the cells above on these values.
images_tensor = dataset['train'][0]['image']
texts_tensor = dataset['train'][0]['tokens']