In [None]:
from transformers import CLIPTokenizer
from datasets import load_dataset
from datasets import Image as HuggingFaceImage
from linformer import Linformer
from vit_pytorch.efficient import ViT
import torch
def get_tokenizer() -> CLIPTokenizer:
    """Return the pretrained CLIP tokenizer from the ``openai/clip-vit-base-patch32`` checkpoint."""
    checkpoint_name = 'openai/clip-vit-base-patch32'
    tokenizer = CLIPTokenizer.from_pretrained(checkpoint_name)
    return tokenizer
def prepare_data(tokenizer: CLIPTokenizer,
                 train_file: str = 'train-metadata.jsonl',
                 test_file: str = 'test-metadata.jsonl',
                 max_length: int = 42):
    """Load the glyph metadata and turn both splits into tokenized, torch-formatted datasets.

    For every example a text prompt is built from ``font_serifs``, ``character`` and
    ``font_properties`` and encoded with ``tokenizer`` into a ``tokens`` column. The
    intermediate ``prompt`` column and the raw metadata columns listed in
    ``dropped_columns`` are then removed, ``image`` is cast to an image feature, and the
    split is switched to torch format.

    Args:
        tokenizer: CLIP tokenizer used to encode the prompts.
        train_file: JSON-lines metadata file for the ``train`` split.
        test_file: JSON-lines metadata file for the ``test`` split.
        max_length: Every token sequence is padded *and truncated* to exactly this length,
            so examples can be stacked into one batch tensor.

    Returns:
        The object returned by ``load_dataset`` (indexed by ``'train'`` / ``'test'``), with
        each split processed as described above.
    """
    # Metadata columns that are consumed by the prompt (or unused) and dropped afterwards.
    dropped_columns = ['prompt', 'uniqueId', 'ttf_path', 'font_characteristics',
                       'font_properties', 'character', 'vit_label']

    def add_prompt(example):
        # font_properties is unpacked as a two-element (properties, key) pair.
        props, key = example['font_properties']
        # Names of the form '<x>_<y>' become '<x>case <y>' (presumably 'upper_A' -> 'uppercase A');
        # names without an underscore are used verbatim.
        parts = example['character'].split('_')
        character = parts[0] + 'case ' + parts[1] if len(parts) > 1 else parts[0]
        example['prompt'] = f"a {example['font_serifs']} {character} with {props} {key}"
        return example

    def map_tokens(example):
        # truncation=True matters: with padding='max_length' alone, prompts longer than
        # max_length are left uncut, producing ragged token lists that cannot be batched.
        example['tokens'] = tokenizer.encode(example['prompt'], padding='max_length',
                                             truncation=True, max_length=max_length)
        return example

    dataset = load_dataset('json', data_files={'train': train_file, 'test': test_file})

    # Both splits go through the identical pipeline; `map` creates the new columns itself,
    # so no placeholder columns are needed.
    for split in ('train', 'test'):
        dataset[split] = (
            dataset[split]
            .map(add_prompt)   # adds 'prompt'
            .map(map_tokens)   # adds 'tokens'
            .remove_columns(dropped_columns)
            .cast_column('image', HuggingFaceImage())
            .with_format('torch')
        )
    return dataset
def get_vit_model(image_size: int, patch_size: int, dim: int, depth: int, num_heads: int, k: int, device: str,
                  num_classes: int = 62, channels: int = 1):
    """Build a ``vit_pytorch`` efficient ViT whose transformer is a Linformer.

    Args:
        image_size: Input image size in pixels; the efficient ViT requires it to be a
            multiple of ``patch_size``.
        patch_size: Patch size in pixels.
        dim: Embedding dimension shared by the ViT and the Linformer.
        depth: Number of Linformer layers.
        num_heads: Number of attention heads per layer.
        k: Linformer projection dimension for keys/values.
        device: Device the model is moved to. ``None`` leaves it where it was built.
        num_classes: Size of the classification head (62 by default, as before; presumably
            upper- and lower-case letters plus digits).
        channels: Number of input image channels (1 by default, as before).

    Returns:
        The ViT module, on ``device`` when one is given.
    """
    # Linformer needs the exact sequence length: one token per patch plus the CLS token,
    # e.g. a 512x512px image with 32x32px patches gives 16x16 + 1.
    sequence_length = (image_size // patch_size) ** 2 + 1
    efficient_transformer = Linformer(
        dim=dim,
        seq_len=sequence_length,
        depth=depth,
        heads=num_heads,
        k=k,
    )
    model = ViT(
        dim=dim,
        image_size=image_size,
        patch_size=patch_size,
        num_classes=num_classes,
        transformer=efficient_transformer,
        channels=channels,
    )
    # `device` was previously accepted but silently ignored.
    if device is not None:
        model = model.to(device)
    return model
def get_vit(image_size, patch_size, vit_dim, vit_depth, vit_num_heads, k, device, vit_checkpoint_path):
    """Build the Linformer ViT via ``get_vit_model`` and optionally restore its weights.

    Args:
        image_size, patch_size, vit_dim, vit_depth, vit_num_heads, k, device:
            Forwarded to ``get_vit_model`` (as ``dim``, ``depth``, ``num_heads`` ...).
        vit_checkpoint_path: Path to a checkpoint dict that holds the weights under
            ``'model_state_dict'``, or ``None`` to keep the fresh initialization.

    Returns:
        The ViT module, with checkpoint weights loaded when a path is given.
    """
    vit = get_vit_model(image_size=image_size,
                        patch_size=patch_size,
                        dim=vit_dim,
                        depth=vit_depth,
                        num_heads=vit_num_heads,
                        k=k,
                        device=device)
    if vit_checkpoint_path is not None:
        # map_location='cpu' lets a checkpoint saved on a GPU load on any host;
        # load_state_dict then copies the tensors onto whatever device `vit` lives on.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        vit_checkpoint = torch.load(vit_checkpoint_path, map_location='cpu')
        vit.load_state_dict(vit_checkpoint['model_state_dict'])
        print('Loaded ViT model from checkpoint:', vit_checkpoint_path)
    return vit

In [None]:
import os
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, TrOCRProcessor, VisionEncoderDecoderModel
from torchvision.transforms import Compose, Resize, Normalize, ToTensor

# CLIP-style preprocessing: resize to 224x224, convert to a tensor and normalize per channel
# (these mean/std values match the CLIP image processor defaults). The 3-element mean/std
# require a 3-channel (RGB) input. Not referenced by the TrOCR cells shown here.
image_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

# TrOCR checkpoints used for inference below. (To compare against CLIP instead, load
# CLIPProcessor / CLIPModel from "openai/clip-vit-base-patch32".)
# NOTE(review): processor and model come from different repositories; an earlier note said
# "microsoft/trocr-large-handwritten works" for the processor too. Confirm the custom
# processor's tokenizer matches the model's vocabulary, or decoded text will be wrong.
TROCR_PROCESSOR_ID = "anaghasavit/trocr-processor"
TROCR_MODEL_ID = "microsoft/trocr-large-handwritten"

processor = TrOCRProcessor.from_pretrained(TROCR_PROCESSOR_ID)
model = VisionEncoderDecoderModel.from_pretrained(TROCR_MODEL_ID)

In [None]:
from x_clip_train import prepare_batch, get_dataloaders, get_tokenizer, prepare_data
# NOTE(review): get_tokenizer and prepare_data imported from x_clip_train shadow the functions
# of the same names defined in the first cell, so the x_clip_train versions are the ones
# called here (the local get_tokenizer takes no arguments, so `True` is only valid for the
# imported one). Confirm this is intended.
clip_tokenizer = get_tokenizer(True)
dataset = prepare_data(clip_tokenizer)

train_dataset = dataset['train']
test_dataset = dataset['test']

# The third positional argument (32) is presumably the batch size; confirm in x_clip_train.
train_loader, valid_loader = get_dataloaders(train_dataset, test_dataset, 32)

In [None]:
from PIL import Image
# Inference function
def generate_caption(image_pil, processor, model, top_k=1):
    """Transcribe one or more images with an encoder-decoder OCR model such as TrOCR.

    Follows the same processor -> ``model.generate`` -> decode flow used elsewhere in this
    notebook.

    Args:
        image_pil: A PIL image, or a list of PIL images.
        processor: Callable as ``processor(images=..., return_tensors="pt")`` returning a
            mapping with ``"pixel_values"``, and providing ``batch_decode``
            (e.g. ``TrOCRProcessor``).
        model: Object with a Hugging Face style ``generate`` method
            (e.g. ``VisionEncoderDecoderModel``).
        top_k: Number of candidate transcriptions per image. Beam search with ``top_k``
            beams is used; the default ``1`` means no beam search.

    Returns:
        A list of ``len(images) * top_k`` strings, grouped per image in the order
        ``generate`` returns them (highest-scoring beam first).

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    images = image_pil if isinstance(image_pil, list) else [image_pil]
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    if not images:
        return []

    inputs = processor(images=images, return_tensors="pt")
    # generate() already runs without gradient tracking; num_return_sequences must not
    # exceed num_beams, hence both are set to top_k.
    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        num_beams=top_k,
        num_return_sequences=top_k,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from PIL import Image
import torchvision.transforms.functional as TF

# Qualitative check: caption the first image of the first validation batch with TrOCR.
# Everything stays on the CPU: the processor consumes a PIL image, so a round trip through a
# GPU is unnecessary (and would fail on CPU-only hosts).
batch = next(iter(valid_loader), None)
if batch is not None:
    batch_imgs, batch_tokens = prepare_batch(batch)
    if batch_imgs.shape[0] > 0:
        image_tensor = batch_imgs[0].cpu()
        # normalize=True with scale_each=True min-max rescales the image to [0, 1]; make_grid
        # also replicates a single-channel image to 3 channels, giving an RGB PIL image.
        grid = make_grid(image_tensor, nrow=1, normalize=True, scale_each=True, padding=0)
        image_pil = TF.to_pil_image(grid)
        inputs = processor(images=image_pil, return_tensors="pt")
        outputs = model.generate(**inputs)
        caption = processor.decode(outputs.squeeze(), skip_special_tokens=True)
        print("Generated Caption:", caption)
        display(image_pil)  # IPython's rich display, available in Jupyter