#### Import Packages

In [None]:
import torch
import os 
import pandas as pd
import numpy as np
import gzip
import lzma

from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForImageTextRetrieval
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from utils.system import *
from class_data.preprocess import Preprocess
from class_data.image_tensor import ImageTensor

In [None]:
# Report the installed PyTorch version, whether CUDA is usable, and the CUDA
# version PyTorch was built against (one value per line).
print(
    torch.__version__,
    torch.cuda.is_available(),
    torch.version.cuda,
    sep="\n",
)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

In [None]:
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is a CUDA debugging switch (synchronous
# kernel launches). It is normally expected to be set before CUDA is
# initialised - confirm it still takes effect at this point in the notebook,
# and consider removing it for full-size runs.
os.environ.update({'CUDA_LAUNCH_BLOCKING': "1"})

#### Data

In [None]:
# Locate the chunked parquet files and concatenate them into one dataset.
folder_path = get_data() / 'all' / 'chunks'
file_pattern = 'all_data_tokenize_*.parquet.brotli'
preprocess = Preprocess(folder_path=folder_path, file_pattern=file_pattern)
all_data = preprocess._concat_files()

#### Prepare Data

In [None]:
# Tiny ViT requires an image size of 224
image_size = 224
batch_size = 6

# Per-channel mean / std used to normalise the image tensors.
normalize = transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)

# Pipeline: bicubic resize to a square -> tensor conversion -> normalisation.
transform_train = transforms.Compose(
    [
        transforms.Resize(
            size=(image_size, image_size),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        normalize,
    ]
)

In [None]:
# Wrap the concatenated data in the project's image/caption dataset.
dataset = ImageTensor(
    data=all_data,
    image_column='image_name',
    caption_column='caption',
    transform=transform_train,
)
# shuffle=False: batches are drawn in dataset order.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

#### Get HR and ATTN

In [None]:
# Save one embedding to disk and return its file name
def save_embedding(dir, embedding, idx, prefix):
    """Write ``embedding`` to ``<dir>/<prefix>_<idx>.npy.gz`` and return the file name.

    The array is serialised with ``np.save`` into a gzip-compressed file.

    Args:
        dir: Target directory (string or path-like). It is created, including
            missing parents, when it does not exist yet.
        embedding: Array-like object accepted by ``np.save``.
        idx: Identifier embedded in the file name.
        prefix: File-name prefix placed before ``idx``.

    Returns:
        The bare file name (not the full path) of the written file.
    """
    # NOTE: `dir` shadows the builtin; the name is kept so existing callers
    # (positional or keyword) keep working.
    # gzip.open does not create missing directories, so make sure it exists.
    os.makedirs(dir, exist_ok=True)
    file_name = f"{prefix}_{idx}.npy.gz"
    with gzip.open(os.path.join(dir, file_name), 'wb') as f:
        np.save(f, embedding)
    return file_name

In [None]:
# Load in Blip Image Captioning Model
processor_caption = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_caption = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float32).to(device)
# Load in Blip Image Retrieval Model
processor_retrieval = BlipProcessor.from_pretrained("Salesforce/blip-itm-large-flickr")
blip_retrieval = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-large-flickr").to(device)

In [None]:
# Params
total_batches = len(dataloader)
export_every = 50  # number of batches accumulated before each export
prompt = "a photography of"
embeddings_dir = get_data() / 'blip' / 'emb'
# gzip.open / to_parquet do not create missing directories. Creating the
# embeddings directory also creates its parent, where the stores are written.
os.makedirs(embeddings_dir, exist_ok=True)

all_hr_caption = []
all_hr_retrieval = []
all_attn_caption = []
all_attn_retrieval = []
all_indices = []
export_num = 0

# For every batch: run both BLIP models, keep the stacked hidden states and the
# last attention map of each sample (as float16 numpy arrays), and periodically
# write them to disk together with a parquet index of the written file names.
for i, (image, caption, idx) in enumerate(dataloader):
    # Log Progress
    print("-"*60)
    print(f"Processing batch: {i+1}/{total_batches}...")

    # The final batch can hold fewer than `batch_size` samples, so the real
    # batch length is used everywhere below.
    n_samples = image.shape[0]

    with torch.no_grad():
        # Load images to device
        image = image.to(device, non_blocking=True)

        # Create compatible image for parent model: rescale the tensor to 0..255
        # uint8. min/max are taken over the whole batch tensor.
        # NOTE(review): this is a joint min-max rescale, not the inverse of the
        # Normalize transform - confirm that this is the intended input.
        parent_image = ((image - image.min()) * (1 / (image.max() - image.min()) * 255)).cpu().numpy().astype('uint8')

        # One prompt per image so text and image batch sizes agree.
        text = [prompt] * n_samples

        # Caption
        # Inputs are cast to the model's own dtype; a hard-coded float16 cast
        # does not match weights loaded in float32.
        inputs_caption = processor_caption(images=parent_image, text=text, return_tensors="pt").to(device, blip_caption.dtype)
        outputs_caption = blip_caption(**inputs_caption, output_hidden_states=True, output_attentions=True)
        hr_parent_caption = torch.stack(outputs_caption.hidden_states)
        attn_parent_caption = outputs_caption.attentions[-1]

        # Retrieval
        inputs_retrieval = processor_retrieval(images=parent_image, text=text, return_tensors="pt").to(device, blip_retrieval.dtype)
        outputs_retrieval = blip_retrieval(**inputs_retrieval, output_hidden_states=True, output_attentions=True)
        hr_parent_retrieval = torch.stack(outputs_retrieval.hidden_states)
        attn_parent_retrieval = outputs_retrieval.attentions[-1]
        print(hr_parent_retrieval.shape)
        print(attn_parent_retrieval.shape)

    # Store: the stacked hidden states are sliced per sample on axis 1, the
    # attention maps on axis 0.
    for j in range(n_samples):
        all_hr_caption.append(hr_parent_caption[:, j, :, :].to(torch.float16).cpu().numpy())
        all_hr_retrieval.append(hr_parent_retrieval[:, j, :, :].to(torch.float16).cpu().numpy())
        all_attn_caption.append(attn_parent_caption[j, :, :, :].to(torch.float16).cpu().numpy())
        all_attn_retrieval.append(attn_parent_retrieval[j, :, :, :].to(torch.float16).cpu().numpy())
        all_indices.append(idx[j].item())

    # Export every `export_every` batches, plus once after the last batch so
    # that a trailing partial group is not lost.
    batches_done = i + 1
    if batches_done % export_every == 0 or batches_done == total_batches:
        print("Exporting...")
        data = {'idx': [], 'hr_caption_path': [], 'hr_retrieval_path': [], 'attn_caption_path': [], 'attn_retrieval_path': []}

        # Save embeddings and get file path name. Dedicated loop names are used
        # so the outer `i` / `idx` are not overwritten.
        for sample_idx, hr_cap, hr_ret, attn_cap, attn_ret in zip(
            all_indices, all_hr_caption, all_hr_retrieval, all_attn_caption, all_attn_retrieval
        ):
            data['idx'].append(sample_idx)
            data['hr_caption_path'].append(save_embedding(embeddings_dir, hr_cap, sample_idx, 'hr_caption'))
            data['hr_retrieval_path'].append(save_embedding(embeddings_dir, hr_ret, sample_idx, 'hr_retrieval'))
            data['attn_caption_path'].append(save_embedding(embeddings_dir, attn_cap, sample_idx, 'attn_caption'))
            data['attn_retrieval_path'].append(save_embedding(embeddings_dir, attn_ret, sample_idx, 'attn_retrieval'))

        # Sort Index
        parent_store = pd.DataFrame(data).set_index('idx').sort_index()
        # Export Data
        parent_store.to_parquet(get_data() / 'blip' / f'blip_store_{export_num}.parquet.brotli', compression='brotli')

        # Reset Data
        export_num += 1
        all_hr_caption = []
        all_hr_retrieval = []
        all_attn_caption = []
        all_attn_retrieval = []
        all_indices = []

#### Load HR and ATTN

In [None]:
def load_embedding(dir, filename):
    """Read a gzip-compressed ``.npy`` file from ``dir`` and return its array.

    Args:
        dir: Directory that contains the file (string or path-like).
        filename: Name of the ``.npy.gz`` file inside ``dir``.

    Returns:
        The array decoded by ``np.load``.
    """
    with gzip.open(os.path.join(dir, filename), 'rb') as compressed:
        return np.load(compressed)

# Usage: read one stored attention array back from the embeddings directory
filename = 'attn_caption_0.npy.gz'
embedding_dir = get_data() / 'blip' / 'emb'
loaded_embedding = load_embedding(dir=embedding_dir, filename=filename)