In [1]:
import os

# Restrict this notebook to physical GPU 1. This runs before `import torch` in the
# next cell, so CUDA only ever sees that one device.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [2]:
import torch
from stable_diffusion3 import UniLatentPipeline, retrieve_timesteps
from diffusers import StableDiffusion3Pipeline as OGStableDiffusion3Pipeline

from diffusion.data.builder import build_dataset, build_dataloader
from diffusion.utils.data_sampler import AspectRatioBatchSampler
from torch.utils.data import RandomSampler

from tqdm import tqdm
from diffusers import get_cosine_schedule_with_warmup
from accelerate import Accelerator
from transformers import (
    GPT2Tokenizer,
    CLIPVisionModel,
    CLIPImageProcessor,
)

from caption_decoder import TextDecoder



In [3]:
# Checkpoint locations (cluster-specific absolute paths).
PIPELINE_CKPT = '/mnt/bn/us-aigc-temp/henry/data/clip_test/'
TEXT_DECODER_CKPT = '/mnt/bn/us-aigc-temp/henry/data/clip2text_2gpu/epoch_0_step_7499/'

# Load every pipeline component in fp32, without device_map / low-memory loading.
pipe = UniLatentPipeline.from_pretrained(
    PIPELINE_CKPT,
    device_map=None,
    low_cpu_mem_usage=False,
    torch_dtype=torch.float32,
)

# Swap the pipeline's bundled text decoder for the fine-tuned clip2text checkpoint.
# NOTE(review): the load log reports `transformer.lm_head.weight` as "newly initialized"
# for both decoder checkpoints; presumably it is tied to the input embeddings after
# loading. Confirm, otherwise the decoder head is random.
pipe.text_decoder = TextDecoder.from_pretrained(
    TEXT_DECODER_CKPT,
    low_cpu_mem_usage=False,
    device_map=None,
)

<super: <class 'UniLatentPipeline'>, <UniLatentPipeline object>> ('/mnt/bn/us-aigc-temp/henry/data/clip_test/',) {'device_map': None, 'low_cpu_mem_usage': False, 'torch_dtype': torch.float32}


Loading pipeline components...:   0%|          | 0/11 [00:00<?, ?it/s]

Some weights of TextDecoder were not initialized from the model checkpoint at /mnt/bn/us-aigc-temp/henry/data/clip_test/text_decoder and are newly initialized: ['transformer.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of TextDecoder were not initialized from the model checkpoint at /mnt/bn/us-aigc-temp/henry/data/clip2text_2gpu/epoch_0_step_7499/ and are newly initialized: ['transformer.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Evaluation data: COCO 2014 val images listed in test.json.
data_config = {
    'type': 'FlexibleInternalDataMS',
    'roots': [
        '/mnt/bn/us-aigc-temp/henry/coco_2014/val/val2014/',
        # '/mnt/bn/aigc-us/zjl/laion-coco-aesthetic/data_max1024/',
        # '/mnt/bn/aigc-us/zjl/openimages/data/',
        # '/mnt/bn/aigc-us/zjl/sharegpt4v_processed_data/data/'
    ],
    'json_lst': [
        '/mnt/bn/us-aigc-temp/henry/test.json',
    ],
    'load_vae_feat': False,
    'load_t5_feat': False
}
dataset = build_dataset(
    data_config, resolution=512, aspect_ratio_type='ASPECT_RATIO_512',
    real_prompt_ratio=0.0, max_length=77, return_image_id=True
)

# Seed the base sampler. An unseeded RandomSampler visits the images in a different
# order on every run, so the printed results and the periodic caption checkpoints
# could not be reproduced or compared across runs.
SAMPLER_SEED = 0
eval_sampler = RandomSampler(dataset, generator=torch.Generator().manual_seed(SAMPLER_SEED))

# batch_size=1: the captioning loop captions one image per batch.
batch_sampler = AspectRatioBatchSampler(sampler=eval_sampler, dataset=dataset,
                                        batch_size=1, aspect_ratios=dataset.aspect_ratio, drop_last=True,
                                        ratio_nums=dataset.ratio_nums, valid_num=0)
dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=10)

2024-07-24 02:36:56,492 - PixArt - INFO - Constructing dataset FlexibleInternalDataMS...
2024-07-24 02:36:56,494 - PixArt - INFO - T5 max token length: 77
2024-07-24 02:36:56,494 - PixArt - INFO - ratio of real user prompt: 0.0
2024-07-24 02:36:56,525 - PixArt - INFO - /mnt/bn/us-aigc-temp/henry/test.json data volume: 5000
2024-07-24 02:36:56,546 - PixArt - INFO - Dataset FlexibleInternalDataMS constructed. time: 0.05 s, length (use/ori): 5000/5000


In [5]:
from torch.nn.parallel import DistributedDataParallel as DDP

# mixed_precision='fp16': accelerate autocasts the prepared models' forward passes.
# The pipeline itself was loaded with torch_dtype=float32.
accelerator = Accelerator(mixed_precision='fp16')

# Hand each model component to accelerate, in the same order as the original cell.
# The result may be DDP-wrapped; the captioning loop checks for that.
for component_name in (
    'transformer',
    'text_encoder',
    'text_encoder_2',
    'clip_image_encoder',
    'text_decoder',
):
    setattr(pipe, component_name, accelerator.prepare(getattr(pipe, component_name)))

Detected kernel version 5.4.143, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
# Number of batches; with batch_size=1 this is also the number of images.
num_batches = len(dataloader)
num_batches

5000

In [8]:
# Preview a few caption records. `json_list` is only created by the captioning loop
# in a later cell. On a fresh kernel (Restart & Run All), a bare `json_list` here would
# raise NameError, so fall back to an empty list. The preview is limited to 5 records
# instead of dumping all of them.
json_list_preview = globals().get('json_list', [])[:5]
json_list_preview

[{'image_id': tensor([545385]),
  'caption': 'A piece of cake on a plate with a fork.'}]

In [9]:
import json

# Where generated captions are written, and how often (in images) to checkpoint them.
SAVE_PATH = '/mnt/bn/us-aigc-temp/henry/clip2text_captions.json'
SAVE_EVERY = 50
# Special tokens that may appear in the decoded text and must not leak into captions.
SPECIAL_TOKENS = ('<|endoftext|>', '<|EOS|>')


def clean_caption(text, special_tokens=SPECIAL_TOKENS):
    """Remove decoder special tokens, edge '!' runs and surrounding whitespace from a caption.

    Whole tokens are removed with ``str.replace``. ``str.strip`` must not be used for
    this, because its argument is a *set of characters*, not a substring:
    ``.strip('<|endoftext|>')`` also deletes caption letters such as 'e', 'n', 't', 'x'
    at either end ("... standing next" -> "... standing "). Likewise ``.strip('<|EOS|>')``
    deletes a leading 'E', 'O' or 'S' ("Sheep ..." -> "heep ...").

    Args:
        text: raw string from ``decoder_tokenizer.batch_decode``.
        special_tokens: literal token strings to delete wherever they occur.

    Returns:
        The cleaned caption.
    """
    for token in special_tokens:
        text = text.replace(token, '')
    # '!' is stripped from both ends as before (presumably generation padding, TODO confirm).
    return text.strip().strip('!').strip()


def save_captions(records, path=SAVE_PATH):
    """Overwrite ``path`` with ``records`` serialized as JSON."""
    with open(path, 'w') as f:
        json.dump(records, f)


# Loop invariants, resolved once. accelerator.prepare may wrap the decoder in DDP,
# which exposes the custom ``generate_captions`` method only via ``.module``.
raw_text_decoder = pipe.text_decoder.module if isinstance(pipe.text_decoder, DDP) else pipe.text_decoder
generate_captions = raw_text_decoder.generate_captions
eos_token_id = pipe.decoder_tokenizer.eos_token_id

json_list = []
for i, batch in enumerate(dataloader):
    with torch.no_grad():
        # Caption only the first image of the batch (the sampler uses batch_size=1).
        image_embd = pipe.encode_image(batch[0][:1], device=accelerator.device, dtype=torch.float16)
        decoded_tokens = generate_captions(image_embd, eos_token_id=eos_token_id,
                                           device=accelerator.device)[0]
        decoded_text = pipe.decoder_tokenizer.batch_decode(decoded_tokens)[0]

    caption = clean_caption(decoded_text)
    # .item() turns the 1-element image_id tensor into a JSON-serializable int.
    json_list.append({'image_id': batch[-1]['image_id'].item(), 'caption': caption})
    print(f"Image: {i:05d} | Predicted: {caption} | True: {batch[1][0]}")

    if (i + 1) % SAVE_EVERY == 0:
        save_captions(json_list)

# Final save, so the trailing (< SAVE_EVERY) captions are kept even if no later cell runs.
save_captions(json_list)

Image: 00000 | Predicted: Two people riding surfboards on the waves. | True: A man riding a wave on a surfboard.
Image: 00001 | Predicted: A cat sitting on top of a bed next to a bed. | True: A brown and white cat lying on a bed
Image: 00002 | Predicted: A boy in a coat is standing next to an old suitcase. | True: A black and white photo of a child putting on gloves next to a suitcase.
Image: 00003 | Predicted: Two people riding on the back of a motorcycle. | True: A couple riding a motorcycle down a street.
Image: 00004 | Predicted: An old computer sitting on top of a desk. | True: A computer desk topped with a large monitor.
Image: 00005 | Predicted: A man in a black shirt and blue tie. | True: The man with black outfit and royal blue necktie poses for a photo at the event
Image: 00006 | Predicted: A man sitting at a table with a pizza. | True: The man smiles as someone is cutting a pizza.
Image: 00007 | Predicted: A young boy playing frisbee in the park. | True: A young boy standing

In [None]:
# NOTE(review): disabled experiment, a batched variant of the captioning loop.
# Delete it, or fix these issues before re-enabling:
#   * `str.strip` takes a *set of characters*, so `.strip('<|endoftext|>')` and
#     `.strip('<|EOS|>')` also remove caption letters at either end
#     (e.g. 'e', 'n', 't', 'x', 'E', 'O', 'S'). Remove whole tokens with `str.replace`.
#   * `image_id` here is a tensor (the earlier `json_list` output shows
#     `tensor([545385])`), which `json.dump` cannot serialize. Use `image_id.item()`.
#   * The print always shows `batch[1][0]`, not the ground-truth caption for the
#     current image in the batch.
#   * `print(image_embd.norm())` is leftover debugging output.
# json_list = []
# iterloader = iter(dataloader)
# for i in range(len(dataloader)):
#     batch = next(iterloader)

#     with torch.no_grad():
#         image_embds = pipe.encode_image(batch[0], device=accelerator.device, dtype=torch.float16)
#         for image_embd, image_id in zip(image_embds, batch[-1]['image_id']):
#             print(image_embd.norm())
#             generate_captions = pipe.text_decoder.module.generate_captions if isinstance(pipe.text_decoder, DDP) else pipe.text_decoder.generate_captions
#             decoded_tokens = generate_captions(image_embd[None], 
#                                 eos_token_id=pipe.decoder_tokenizer.eos_token_id, device=accelerator.device)[0]
#             decoded_text = pipe.decoder_tokenizer.batch_decode(decoded_tokens)[0]
    
#             caption = decoded_text.strip('!').strip('<|endoftext|>').strip('<|EOS|>').strip()
#             json_list.append({'image_id': image_id, 'caption': caption})
#             print(f"Image: {i:05d} | Predicted: {caption} | True: {batch[1][0]}")

In [None]:
import json

# Write every caption collected so far to disk. Safe to re-run: it overwrites the file.
save_path = '/mnt/bn/us-aigc-temp/henry/clip2text_captions.json'
with open(save_path, 'w') as f:
    # json.dump returns None, so its result is not bound to any name.
    json.dump(json_list, f)