In [1]:
import requests
import os
import argparse
import json

from omegaconf import OmegaConf
from typing import Optional
from PIL import Image

from transformers import CLIPProcessor, CLIPModel
from transformers import Blip2Processor, Blip2ForConditionalGeneration, Blip2VisionModel

import numpy as np
import torch
from torch import nn
import torchvision
from torchvision.io import read_video
from torchvision.transforms import Resize, Normalize
from torch.utils.data import DataLoader

from dataset.base_dataset import load_dataset
from configs.config import Config

  warn(


In [2]:
def parse_args():
    parser = argparse.ArgumentParser(description='LBA method')
    parser.add_argument("--cfg-path", default='configs/runner.yaml', help="path to configuration file.")
    # verbose
    parser.add_argument('--verbose', action='store_true', help='verbose')
    
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    
    args = parser.parse_args()
    return args


In [3]:
def video_to_tensor(video_path, target_size=(224, 224)):
    # Read the video file
    video, audio, info = read_video(video_path)
    
    # video tensor shape: (T, H, W, C)
    # T: number of frames, H: height, W: width, C: channels (usually 3 for RGB)
    
    # Permute dimensions to (T, C, H, W) as expected by most PyTorch models
    video = video.permute(0, 3, 1, 2)
    
    # Convert to float and scale to [0, 1]
    # video = video.float() / 255.0
    
    # Resize the frames
    resize = Resize(target_size)
    video = resize(video)
    
    return video


In [4]:
def read_image_using_clip(model, preprocess, image_paths):
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    
    image_features = []
    for image_path in image_paths:
        pil_image = Image.open(image_path).convert('RGB')
        import pdb; pdb.set_trace()
        image = preprocess(images=pil_image).unsqueeze(0).to(device)

        # Extract features
        with torch.no_grad():
            image_feature = model.encode_image(image)
        
        image_features.append(image_feature)        
    
    return torch.stack(image_features, dim=0)


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
cache_dir = '/model/Salesforce'

# clip_model_id = "openai/clip-vit-large-patch14"
# clip_model = CLIPModel.from_pretrained(
#     clip_model_id, 
#     cache_dir=os.path.join(cfg.model_cfg.cache_dir, clip_model_id.split('/')[0])
# ).to(device)
# clip_processor = CLIPProcessor.from_pretrained(clip_model_id)
# clip_model.eval()

model_id = 'Salesforce/blip2-flan-t5-xl'
processor = Blip2Processor.from_pretrained(model_id)
model = Blip2ForConditionalGeneration.from_pretrained(
    model_id, 
    cache_dir=cache_dir
).to(device)
model.eval()
blip2visionmodel = model.vision_model
# Blip2VisionModel(
#     model_id, 
#     cache_dir=os.path.join(cfg.model_cfg.cache_dir, model_id.split('/')[0])
# ).to(device)
blip2visionmodel.eval()


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Blip2VisionModel(
  (embeddings): Blip2VisionEmbeddings(
    (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
  )
  (encoder): Blip2Encoder(
    (layers): ModuleList(
      (0-38): 39 x Blip2EncoderLayer(
        (self_attn): Blip2Attention(
          (dropout): Dropout(p=0.0, inplace=False)
          (qkv): Linear(in_features=1408, out_features=4224, bias=True)
          (projection): Linear(in_features=1408, out_features=1408, bias=True)
        )
        (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (mlp): Blip2MLP(
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1408, out_features=6144, bias=True)
          (fc2): Linear(in_features=6144, out_features=1408, bias=True)
        )
        (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
      )
    )
  )
  (post_layernorm): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
)

In [6]:
datasets_cfg = {'dataset_name': 'DramaQA', 'num_data': -1, 'split': 'val', 'data_type': 'videos', 'n_frms': 4, 'vqa_acc': False, 'only_scene': False, 'ann_paths': {'train': ['/data/AnotherMissOh/AnotherMissOhQA_train_set.json', '/data/flipped_vqa/dramaqa/clipvitl14.pth'], 'val': ['/data/AnotherMissOh/AnotherMissOhQA_val_set.json', '/data/flipped_vqa/dramaqa/clipvitl14.pth']}, 'vis_root': '/data/AnotherMissOh/images/', 'vis_processor': {'train': {'name': 'alpro_video_train', 'n_frms': 5, 'image_size': 224, 'min_scale': 0.9, 'max_scale': 1.0}, 'eval': {'name': 'alpro_video_eval', 'n_frms': 5, 'image_size': 224, 'min_scale': 0.9, 'max_scale': 1.0}}, 'text_processor': {'train': {'name': 'blip_question'}, 'eval': {'name': 'blip_caption'}}}
from omegaconf import OmegaConf
datasets_cfg = OmegaConf.create(datasets_cfg)

In [7]:
dataset = load_dataset(datasets_cfg, split='val')
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=dataset.collater)

vid_set = []
for i, batch in enumerate(dataloader):
    if i == 0:
        print('batch.keys():', batch.keys())
        for k, v in batch.items():
            if hasattr(v, "shape"):
                print(f'{k:<20s}: {v.shape}')
            elif isinstance(v, list) and hasattr(v[0], "shape"):
                print(f'{k:<20s}: {len(v)} {v[0].shape}')
            elif isinstance(v, list) and isinstance(v[0], list) and hasattr(v[0][0], "shape"):
                print(f'{k:<20s}: {len(v)} {len(v[0])} {v[0][0].shape}')
            elif k != "candidate_list":
                print(f'{k:<20s}: {v}')
    else:
        vid_set.extend(batch['vid'])
        
vid_set = list(set(vid_set))


  3888/  3889 : AnotherMissOh15_039_0000
DramaQAEvalDataset
vis_processor :  None
text_processor :  None
vis_root :  /data/AnotherMissOh/images/
ann_paths :  ['/data/AnotherMissOh/AnotherMissOhQA_val_set.json', '/data/flipped_vqa/dramaqa/clipvitl14.pth']
type(self.annotation), len(self.annotation): <class 'list'> 3889
type(self.vis_features), len(self.vis_features): <class 'dict'> 23125
batch.keys(): dict_keys(['image', 'text_input', 'question_id', 'gt_ans', 'candidate_list', 'answer_sentence', 'type', 'vid'])
image               : 2 4 (768, 1024, 3)
text_input          : ['How is the relationship between Haeyoung1 and Dokyung when the two hug and kiss each other?', 'Why did Haeyoung1 lean against the wall when Haeyoung1 was walking in the alley with Dokyung?']
question_id         : [3205, 3200]
gt_ans              : [0, 3]
answer_sentence     : ['Haeyoung1 and Dokyung are in love and the two went through many things before starting to date.', 'Because Haeyoung1 tried to recall the tim

In [8]:
def read_image_as_one_frame_video_to_tensor(image_paths, target_size=(224, 224)):
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    
    totensor = torchvision.transforms.ToTensor()
    
    result = []
    
    for image_path in image_paths:
        pil_image = Image.open(image_path).convert('RGB')
        tensor = totensor(pil_image)
        result.append(tensor)
        
    return torch.stack(result, dim=0)


In [11]:
def extract(processor, blip2visionmodel, image_paths):
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    
    # List[str] -> List[PIL.Image]
    pil_image_list = [Image.open(image_path).convert('RGB') for image_path in image_paths]
    
    # -> <'BatchFeature'>['pixel_values']: List of torch.Tensor. 4 (3, 224, 224)
    processed = processor(images=np.stack(pil_image_list, axis=0))
    print('processed:', type(processed), type(processed['pixel_values']), len(processed['pixel_values']))
    print("processed['pixel_values'][0]:", type(processed['pixel_values'][0]), processed['pixel_values'][0].shape)
    
    totensor = torchvision.transforms.ToTensor()
    tensor = torch.tensor(np.stack(processed['pixel_values'], axis=0)).to(device)
    
    # return tensor
    
    return blip2visionmodel(tensor, return_dict=True).last_hidden_state
    
    processed = torch.stack(processed['pixel_values'], dim=0)
    return blip2visionmodel(processed['pixel_values'], return_dict=True).last_hidden_state


vid = 'AnotherMissOh13_008_0313'
image_paths = dataset.get_image_path(vid)
feature = extract(processor, blip2visionmodel, image_paths)
print(type(feature), feature.shape)

processed: <class 'transformers.image_processing_base.BatchFeature'> <class 'list'> 4
processed['pixel_values'][0]: <class 'numpy.ndarray'> (3, 224, 224)
<class 'torch.Tensor'> torch.Size([4, 257, 1408])


In [10]:
feature.shape

torch.Size([4, 257, 1408])

In [None]:
features = {}

for i, vid in enumerate(vid_set):
    if vid != 'AnotherMissOh13_008_0313':
        continue
    
    print(f'{i+1}/{len(vid_set)} : {vid}')
    if datasets_cfg.dataset_name == 'DramaQA':
        image_paths = dataset.get_image_path(vid)
        feature = extract(processor, blip2visionmodel, image_paths)
        # feature = image_extract(processor, blip2visionmodel, image_paths)
        
        # features[vid] = read_image_as_one_frame_video_to_tensor(image_paths)
        # features[vid] = read_image_using_clip(clip_model, clip_processor, image_paths)
        # for img_path in image_paths:
        #     # frms.append(Image.open(img_path))
        #     frms.append(np.array(Image.open(img_path)))
        # if len(frms) < cfg.datasets_cfg.n_frms:
        #     # frms = [Image.new('RGB', frms[0].size)] * (self.n_frms - len(frms)) + frms
        #     frms = [np.zeros_like(frms[0])] * (self.n_frms - len(frms)) + frms
        features[vid] = feature
    else:
        vpath = os.path.join(cfg.datasets_cfg.vis_root, f'{vid}.mp4')
        if os.path.exists(vpath):
            features[vid] = torch.load(vpath)
        else:
            print(f'{vpath} not exists')


# video = features['AnotherMissOh13_008_0313']
# inputs = processor(video, questions, return_tensors="pt", padding=True).to("cuda")  # , torch.float16)

# out = model.generate(**inputs)
# print(out)
# print(processor.batch_decode(out, skip_special_tokens=True))


1928/2471 : AnotherMissOh13_008_0313
processed: <class 'transformers.image_processing_base.BatchFeature'> <class 'list'> 4
processed['pixel_values'][0]: <class 'numpy.ndarray'> (3, 224, 224)


In [None]:
print(type(feature), feature.shape)

<class 'torch.Tensor'> torch.Size([4, 257, 1408])
