In [6]:
from transformers import CLIPTokenizer
from datasets import load_dataset
from datasets import Image as HuggingFaceImage
from linformer import Linformer
from vit_pytorch.efficient import ViT
import torch
def get_tokenizer() -> CLIPTokenizer:
    return CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
def prepare_data(tokenizer: CLIPTokenizer):
    def add_prompt(example):
        props = example['font_properties']
        character = example['character']
        split = character.split('_')
        if len(split) > 1:
            character = split[0] + 'case ' + split[1]
        else:
            character = split[0]
        prompt = f"a {props['font_serifs']} {character} with {props['width']} width {props['rounding']} corners {props['font_weight']} weight and {props['dynamics']} movement with characteristics that can be described by adjectives {example['font_characteristics']}" 
        example['prompt'] = prompt
        return example
    def map_tokens(example):
        prompt = example['prompt']
        tokens = tokenizer.encode(prompt, padding='max_length', max_length=42)
        example['tokens'] = tokens
        return example
    dataset = load_dataset('json', data_files={'train':'train-metadata.jsonl', 'test':'test-metadata.jsonl'})
    
    train_new_column = ['foo'] * len(dataset['train'])
    dataset['train'] = dataset['train'].add_column('prompt', train_new_column)
    dataset['train'] = dataset['train'].add_column('tokens', train_new_column)
    dataset['train'] = dataset['train'].map(add_prompt)
    dataset['train'] = dataset['train'].map(map_tokens)
    #dataset['train'] = dataset['train'].remove_columns(['prompt', 'uniqueId', 'ttf_path', 'font_characteristics', 'font_properties', 'character', 'vit_label'])
    dataset['train'] = dataset['train'].cast_column('image', HuggingFaceImage())
    dataset['train'] = dataset['train'].with_format('torch')
    
    test_new_column = ['bar'] * len(dataset['test'])
    dataset['test'] = dataset['test'].add_column('prompt', test_new_column)
    dataset['test'] = dataset['test'].add_column('tokens', test_new_column)
    dataset['test'] = dataset['test'].map(add_prompt)
    dataset['test'] = dataset['test'].map(map_tokens)
    #dataset['test'] = dataset['test'].remove_columns(['prompt', 'uniqueId', 'ttf_path', 'font_characteristics', 'font_properties', 'character', 'vit_label'])
    dataset['test'] = dataset['test'].cast_column('image', HuggingFaceImage())
    dataset['test'] = dataset['test'].with_format('torch')
    return dataset
def get_vit_model(image_size: int, patch_size: int, dim: int, depth: int, num_heads: int, k: int, device: str):
    sequence_length = (image_size//patch_size)**2 + 1
    # for 512x512px image with 32x32px patches: 16x16 + 1 CLS token
    efficient_transformer = Linformer(
        dim=dim,
        seq_len=sequence_length,  
        depth=depth,
        heads=num_heads,
        k=k
    )
    model = ViT(
        dim=dim,
        image_size=image_size,
        patch_size=patch_size,
        num_classes=62,
        transformer=efficient_transformer,
        channels=1,
    )
    return model 
def get_vit(image_size, patch_size, vit_dim, vit_depth, vit_num_heads, k, device, vit_checkpoint_path):
    vit = get_vit_model(image_size=image_size, 
                        patch_size=patch_size, 
                        dim=vit_dim, 
                        depth=vit_depth, 
                        num_heads=vit_num_heads, 
                        k=k, 
                        device=device)
    if vit_checkpoint_path != None:
        vit_checkpoint = torch.load(vit_checkpoint_path)
        vit.load_state_dict(vit_checkpoint['model_state_dict'])
        print('Loaded ViT model from checkpoint:', vit_checkpoint_path)
    return vit


In [3]:
import os
from transformers import CLIPTextModel
from x_clip_train import get_tokenizer
import torch
from x_clip import CLIP
from vit_pytorch.extractor import Extractor

image_size = 512
patch_size = 32
vit_dim = 128
vit_depth = 12
vit_num_heads = 8
k = 64
vit_checkpoint = './vit-checkpoints/model-512-epoch3.pt'
base_vit = get_vit(image_size, 
                    patch_size, 
                    vit_dim, 
                    vit_depth, 
                    vit_num_heads, 
                    k, 
                    device=None, 
                    vit_checkpoint_path=None)
image_encoder = Extractor(
    base_vit,
    return_embeddings_only = True
)
clip_tokenizer = get_tokenizer(True)
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-base-patch32')
text_encoder.resize_token_embeddings(len(clip_tokenizer))
path = os.path.join(os.getcwd(), 'clip-checkpoints', 'clip-epoch-9.pt')
checkpoint = torch.load(path)
clip = CLIP(
    image_encoder = image_encoder,
    text_encoder = text_encoder,
    dim_image=128,
    dim_text=512,
    dim_latent=384,
    text_encode_without_mask=False,
    use_all_token_embeds=False,
    text_has_cls_token=True,
    visual_has_cls_token=True,
    num_text_tokens=text_encoder.vocab_size,
    text_pad_id=clip_tokenizer.pad_token_id,
    text_eos_id=clip_tokenizer.eos_token_id,
    use_mlm=True,
    mlm_mask_token_id=clip_tokenizer.mask_token_id,
    mlm_pad_token_id=clip_tokenizer.pad_token_id,
    mlm_mask_ignore_token_ids=[clip_tokenizer.bos_token_id],
    channels=1
).to('cuda')
clip.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']
print('Loaded model from checkpoint:', path)

Added special tokens:  {'mask_token': '<|mask_token|>'}


Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.5.self_attn.out_proj.bias', 'vision_model.encoder.layers.11.self_attn.k_proj.weight', 'vision_model.encoder.layers.7.self_attn.out_proj.bias', 'vision_model.encoder.layers.8.self_attn.q_proj.bias', 'vision_model.encoder.layers.6.self_attn.v_proj.weight', 'vision_model.encoder.layers.4.self_attn.out_proj.weight', 'vision_model.encoder.layers.5.mlp.fc1.bias', 'vision_model.embeddings.class_embedding', 'vision_model.encoder.layers.7.self_attn.q_proj.bias', 'vision_model.encoder.layers.9.self_attn.k_proj.weight', 'vision_model.encoder.layers.5.self_attn.v_proj.weight', 'vision_model.encoder.layers.7.self_attn.q_proj.weight', 'vision_model.encoder.layers.8.mlp.fc2.bias', 'vision_model.encoder.layers.3.self_attn.q_proj.bias', 'vision_model.encoder.layers.10.mlp.fc1.bias', 'vision_model.encoder.layers.1.mlp.fc1.bias', 'vision_model.encoder.layers.8.

Loaded model from checkpoint: c:\Users\rinat\Documents\font-diffusion\clip-checkpoints\clip-epoch-9.pt


### Moving forward: Train Decoder (Custom Linformer + CLIP)

In [4]:
import torch
from dalle2_pytorch import Unet, Decoder, CLIP


# unet for the decoder
unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 1,
    dim_mults=(1, 2, 4, 8)
).cuda()

# decoder, which contains the unet and clip

decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    channels=1
).cuda()

# mock images (get a lot of this)

images = torch.randn(1, 1, 512, 512).cuda()

# feed images into decoder

loss = decoder(images)
loss.backward()

torch.Size([1, 65, 128])
KV LEN 65 Seq len 257


AssertionError: the sequence length of the key / values must be 257 - 65 given

### Training Decoder: OpenAI CLIP 

In [4]:
import torch
from dalle2_pytorch import Unet, Decoder, CLIP, OpenAIClipAdapter

clip = OpenAIClipAdapter()


# mock images (get a lot of this)
text = torch.randint(0, 49408, (3, 256)).cuda()
print(text.shape)
images = torch.randn(1, 3, 512, 512).cuda()
print(images.shape)

# unet for the decoder
unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

# decoder, which contains the unet and clip

decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    channels=3
).cuda()

# feed images into decoder

loss = decoder(images)
loss.backward()

torch.Size([3, 256])
torch.Size([1, 3, 512, 512])


### Training Decoder: DEFAULT ON WEBSITE 

In [5]:
import torch
from dalle2_pytorch import Unet, Decoder, CLIP

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    channels=1
).cuda()

# unet for the decoder

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 1,
    dim_mults=(1, 2, 4, 8)
).cuda()

# decoder, which contains the unet and clip

decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    channels=1
).cuda()

# mock images (get a lot of this)

images = torch.randn(1, 1, 256, 256).cuda()

# feed images into decoder

loss = decoder(images)
loss.backward()

# do the above for many many many many steps
# then it will learn to generate images based on the CLIP image embeddings

OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 8.00 GiB total capacity; 7.20 GiB already allocated; 0 bytes free; 7.29 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [3]:
dataset

NameError: name 'dataset' is not defined

In [8]:
dataset = prepare_data(get_tokenizer())

Found cached dataset json (C:/Users/rinat/.cache/huggingface/datasets/json/default-1e07ea3eabca7683/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\rinat\.cache\huggingface\datasets\json\default-1e07ea3eabca7683\0.0.0\fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e\cache-54018895e1b774e9.arrow


Map:   0%|          | 0/12090 [00:00<?, ? examples/s]

Map:   0%|          | 0/208 [00:00<?, ? examples/s]

Map:   0%|          | 0/208 [00:00<?, ? examples/s]

In [15]:
dataset['train']

Dataset({
    features: ['image', 'tokens'],
    num_rows: 12090
})