In [7]:
from transformers import CLIPTokenizer
from datasets import load_dataset
from datasets import Image as HuggingFaceImage
from linformer import Linformer
from vit_pytorch.efficient import ViT
from dalle2_pytorch.tokenizer import tokenizer
import torch
from tqdm import tqdm 

def get_tokenizer() -> CLIPTokenizer:
    return CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
def get_tokenizer2():
    return tokenizer
def prepare_data(tokenizer):
    def add_prompt(example):
        props = example['font_properties']
        character = example['character']
        split = character.split('_')
        if len(split) > 1:
            character = split[0] + 'case ' + split[1]
        else:
            character = split[0]
        prompt = f"a {props['font_serifs']} {character} with {props['width']} width {props['rounding']} corners {props['font_weight']} weight and {props['dynamics']} movement with characteristics that can be described by adjectives {example['font_characteristics']}" 
        example['prompt'] = prompt
        return example
    def map_tokens(example):
        prompt = example['prompt']
        tokens = tokenizer.encode(prompt, padding='max_length', max_length=256)
        #tokens = tokenizer.encode(prompt)
        example['tokens'] = tokens
        return example
    dataset = load_dataset('json', data_files={'train':'train-metadata.jsonl', 'test':'test-metadata.jsonl'})
    
    train_new_column = ['foo'] * len(dataset['train'])
    dataset['train'] = dataset['train'].add_column('prompt', train_new_column)
    dataset['train'] = dataset['train'].add_column('tokens', train_new_column)
    dataset['train'] = dataset['train'].map(add_prompt)
    dataset['train'] = dataset['train'].map(map_tokens)
    dataset['train'] = dataset['train'].remove_columns(['prompt', 'uniqueId', 'ttf_path', 'font_characteristics', 'font_properties', 'character', 'vit_label'])
    dataset['train'] = dataset['train'].cast_column('image', HuggingFaceImage())
    dataset['train'] = dataset['train'].with_format('torch')
    
    test_new_column = ['bar'] * len(dataset['test'])
    dataset['test'] = dataset['test'].add_column('prompt', test_new_column)
    dataset['test'] = dataset['test'].add_column('tokens', test_new_column)
    dataset['test'] = dataset['test'].map(add_prompt)
    dataset['test'] = dataset['test'].map(map_tokens)
    dataset['test'] = dataset['test'].remove_columns(['prompt', 'uniqueId', 'ttf_path', 'font_characteristics', 'font_properties', 'character', 'vit_label'])
    dataset['test'] = dataset['test'].cast_column('image', HuggingFaceImage())
    dataset['test'] = dataset['test'].with_format('torch')
    return dataset
def get_vit_model(image_size: int, patch_size: int, dim: int, depth: int, num_heads: int, k: int, device: str):
    sequence_length = (image_size//patch_size)**2 + 1
    # for 512x512px image with 32x32px patches: 16x16 + 1 CLS token
    efficient_transformer = Linformer(
        dim=dim,
        seq_len=sequence_length,  
        depth=depth,
        heads=num_heads,
        k=k
    )
    model = ViT(
        dim=dim,
        image_size=image_size,
        patch_size=patch_size,
        num_classes=62,
        transformer=efficient_transformer,
        channels=1,
    )
    return model 
def get_vit(image_size, patch_size, vit_dim, vit_depth, vit_num_heads, k, device, vit_checkpoint_path):
    vit = get_vit_model(image_size=image_size, 
                        patch_size=patch_size, 
                        dim=vit_dim, 
                        depth=vit_depth, 
                        num_heads=vit_num_heads, 
                        k=k, 
                        device=device)
    if vit_checkpoint_path != None:
        vit_checkpoint = torch.load(vit_checkpoint_path)
        vit.load_state_dict(vit_checkpoint['model_state_dict'])
        print('Loaded ViT model from checkpoint:', vit_checkpoint_path)
    return vit


### Training Decoder/Prior: OpenAI CLIP 

#### Get Data

In [4]:
import numpy as np 
dataset = prepare_data(get_tokenizer())
#images_tensor = dataset['train'][0:2]['image']
#texts_tensor = dataset['train'][0:2]['tokens']


#images_tensor = images_tensor.permute(0, 3, 1, 2)
#images_tensor = images_tensor.float()

#print(images_tensor.shape)
#print(texts_tensor.shape)


Found cached dataset json (C:/Users/rinat/.cache/huggingface/datasets/json/default-1e07ea3eabca7683/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\rinat\.cache\huggingface\datasets\json\default-1e07ea3eabca7683\0.0.0\fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e\cache-54018895e1b774e9.arrow
Loading cached processed dataset at C:\Users\rinat\.cache\huggingface\datasets\json\default-1e07ea3eabca7683\0.0.0\fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e\cache-8115de56d8df858f.arrow
Loading cached processed dataset at C:\Users\rinat\.cache\huggingface\datasets\json\default-1e07ea3eabca7683\0.0.0\fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e\cache-2ad7280546491b5e.arrow
Loading cached processed dataset at C:\Users\rinat\.cache\huggingface\datasets\json\default-1e07ea3eabca7683\0.0.0\fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e\cache-5c61c0fc59a2958d.arrow


#### Train the Decoder 

In [3]:
import torch
from dalle2_pytorch import DALLE2, Unet, Decoder, OpenAIClipAdapter
from dalle2_pytorch import DecoderTrainer
# openai pretrained clip - defaults to ViT-B/32
n_epochs = 1
print_loss_every = 500
clip = OpenAIClipAdapter()

# decoder (with unet)

unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8),
    text_embed_dim = 512,
    cond_on_text_encodings = True  # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 1000,
    sample_timesteps = (250, 27),
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr = 1e-5,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)

for epoch in range(n_epochs):
    print('Epoch: ', epoch)
    u1_loss = 0 
    u2_loss = 0

    for i, data in enumerate(tqdm(dataset['train'])):
        img, emb = data['image'], data['tokens'] 
        img = img.unsqueeze(0)
        emb = emb.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        img = img.float()

        for unet_number in (1, 2):
            loss = decoder_trainer(
                img,
                text = emb,
                unet_number = unet_number, # which unet to train on
                max_batch_size = 4         # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
            )

            if i % print_loss_every == 0:
                print('Batch ', i)
                print('Loss:', loss)
            decoder_trainer.update(unet_number)

        if i%100 == 0:
            print('Saving decoder model...')
            decoder_trainer.save("best_decoder.pth")
print('Saving last decoder')
decoder_trainer.save("last_decoder.pth")

Epoch:  0


  0%|          | 0/12090 [00:00<?, ?it/s]

Batch  0
Loss: 1.0043487548828125
Batch  0
Loss: 1.0021426677703857
Saving decoder model...


  1%|          | 100/12090 [01:02<1:51:43,  1.79it/s]

Batch  100
Loss: 0.6581845879554749
Batch  100
Loss: 1.0761290788650513
Saving decoder model...


  2%|▏         | 200/12090 [02:01<1:39:16,  2.00it/s]

Batch  200
Loss: 1.4763939380645752
Batch  200
Loss: 0.6246916055679321
Saving decoder model...


  2%|▏         | 300/12090 [03:02<1:51:20,  1.76it/s]

Batch  300
Loss: 2.0950543880462646
Batch  300
Loss: 0.6427910327911377
Saving decoder model...


  3%|▎         | 400/12090 [04:02<1:43:36,  1.88it/s]

Batch  400
Loss: 0.44975772500038147
Batch  400
Loss: 0.9298602342605591
Saving decoder model...


  4%|▍         | 500/12090 [04:59<1:42:20,  1.89it/s]

Batch  500
Loss: 0.574957549571991
Batch  500
Loss: 0.17286932468414307
Saving decoder model...


  5%|▍         | 600/12090 [06:00<1:41:16,  1.89it/s]

Batch  600
Loss: 0.3430556654930115
Batch  600
Loss: 0.6291264891624451
Saving decoder model...


  6%|▌         | 700/12090 [06:58<1:43:21,  1.84it/s]

Batch  700
Loss: 0.07377719134092331
Batch  700
Loss: 0.03940613567829132
Saving decoder model...


  7%|▋         | 800/12090 [08:00<1:46:02,  1.77it/s]

Batch  800
Loss: 0.14477425813674927
Batch  800
Loss: 0.0780532956123352
Saving decoder model...


  7%|▋         | 900/12090 [08:54<1:28:35,  2.11it/s]

Batch  900
Loss: nan
Batch  900
Loss: nan
Saving decoder model...


  8%|▊         | 1000/12090 [09:52<1:49:24,  1.69it/s]

Batch  1000
Loss: nan
Batch  1000
Loss: nan
Saving decoder model...


  9%|▉         | 1100/12090 [10:50<1:29:03,  2.06it/s]

Batch  1100
Loss: nan
Batch  1100
Loss: nan
Saving decoder model...


 10%|▉         | 1200/12090 [11:43<1:27:22,  2.08it/s]

Batch  1200
Loss: nan
Batch  1200
Loss: nan
Saving decoder model...


 11%|█         | 1300/12090 [12:41<1:48:11,  1.66it/s]

Batch  1300
Loss: nan
Batch  1300
Loss: nan
Saving decoder model...


 12%|█▏        | 1400/12090 [13:39<1:40:19,  1.78it/s]

Batch  1400
Loss: nan
Batch  1400
Loss: nan
Saving decoder model...


 12%|█▏        | 1469/12090 [14:22<1:43:59,  1.70it/s]


KeyboardInterrupt: 

#### Train the Prior 

In [9]:
#### Train the Prior 
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, DiffusionPriorTrainer
from dalle2_pytorch import OpenAIClipAdapter

n_epochs = 1
print_loss_every = 500

clip = OpenAIClipAdapter()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

diffusion_prior_trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 1e-5,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)

for i, data in enumerate(tqdm(dataset['train'])):
        img, emb = data['image'], data['tokens'] 
        img = img.unsqueeze(0)
        emb = emb.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        img = img.float()

        loss = diffusion_prior_trainer(emb, img, max_batch_size = 4)


        if i % print_loss_every == 0:
            print('Batch ', i)
            print('Loss:', loss)
        diffusion_prior_trainer.update()

        if i%100 == 0:
            print('Saving prior model...')
            diffusion_prior_trainer.save("best_prior.pth")
print('Saving last prior model')
diffusion_prior_trainer.save("last_prior.pth")

  0%|          | 0/12090 [00:00<?, ?it/s]

Batch  0
Loss: 1.3394761085510254
Saving prior model...
Saving checkpoint at step: 1


  1%|          | 99/12090 [00:10<17:44, 11.26it/s] 

Saving prior model...
Saving checkpoint at step: 101


  2%|▏         | 200/12090 [00:20<17:52, 11.09it/s]

Saving prior model...
Saving checkpoint at step: 201


  2%|▏         | 300/12090 [00:30<17:37, 11.15it/s]

Saving prior model...
Saving checkpoint at step: 301


  3%|▎         | 400/12090 [00:41<18:08, 10.74it/s]

Saving prior model...
Saving checkpoint at step: 401


  4%|▍         | 499/12090 [00:52<18:24, 10.49it/s]

Batch  500
Loss: 0.0503176786005497
Saving prior model...
Saving checkpoint at step: 501


  5%|▍         | 600/12090 [01:03<17:56, 10.68it/s]

Saving prior model...
Saving checkpoint at step: 601


  6%|▌         | 700/12090 [01:13<18:04, 10.50it/s]

Saving prior model...
Saving checkpoint at step: 701


  7%|▋         | 799/12090 [01:24<18:02, 10.43it/s]

Saving prior model...
Saving checkpoint at step: 801


  7%|▋         | 815/12090 [01:27<20:13,  9.29it/s]

#### Sample from the model 

In [None]:
#Load the decoder 
decoder_path = ''
unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8),
    text_embed_dim = 512,
    cond_on_text_encodings = True  # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 1000,
    sample_timesteps = (250, 27),
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

decoder.load_state_dict(torch.load(decoder_path, map_location='cpu'))

#Load the prior 
prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 1000,
    cond_drop_prob = 0.2
).cuda()
prior_path = 'PATH_TO_PRIOR'
diffusion_prior.load_state_dict(torch.load(prior_path, map_location='cpu'))

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['A lowercase a which has traits blocky and properties black square sans serif static extended all caps'],
    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)

In [11]:
#Save image 
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage

#print(images.squeeze(0).shape)
images = images.squeeze(0)
img = ToPILImage()(images)
img.show()
#save_image(images, 'example_output.png')

ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

In [62]:
#HugginFaceImage - 512 x 512 x 3 
#Token legnth - 42

#Required format: 
#Image: (B, C, W, H) ~ channels: 3 , W,H: 512 
#Text: (B, T)

