In [1]:
import contextlib
import requests
import torch
import transformers
from PIL import Image

contextlib.ExitStack().enter_context(torch.inference_mode())

In [2]:
model_name = "llava-hf/llava-1.5-7b-hf"

model = transformers.LlavaForConditionalGeneration.from_pretrained(
    model_name,
    device_map=7,
    quantization_config=transformers.BitsAndBytesConfig(load_in_8bit=True),
)
processor = transformers.AutoProcessor.from_pretrained(model_name)

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

Some kwargs in processor config are unused and will not have any effect: num_additional_image_tokens. 


In [3]:
from datasets import load_dataset

imagenet_train = load_dataset('Maysee/tiny-imagenet', split='train')
imagenet_val_combined = load_dataset('Maysee/tiny-imagenet', split='valid')

In [4]:
imagenet_val_test = imagenet_val_combined.train_test_split(test_size=0.5, stratify_by_column='label')
imagenet_val = imagenet_val_test['train']
imagenet_test = imagenet_val_test['test']

In [5]:
def get_hidden_states(image, text):
    conv = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": text},
        ]},
    ]
    prompt = processor.apply_chat_template(conv, add_generation_prompt=True)

    inp = processor(image, prompt, return_tensors="pt").to(model.device)

    output = model(
        **inp,
        output_hidden_states=True,
        num_logits_to_keep=1,
        use_cache=False,
    )
    return torch.vstack(output.hidden_states)

In [6]:
import torch
from torch.utils.data import Dataset

class ImageNetDataset(Dataset):
    def __init__(self, huggingface_dataset, transform=None):
        """
        Args:
            huggingface_dataset: Our ImageNet dataset from huggingface
            transform: Potential transformation for the images
        """
        self.dataset = huggingface_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]['image']
        label = self.dataset[idx]['label']

        # Apply the transform if specified
        if self.transform:
            image = self.transform(image)

        return image, label, idx

In [9]:
from torch.utils.data import DataLoader
from tqdm import tqdm 
import h5py

from torch.utils.data import DataLoader
from torchvision import transforms

# Example transformation function: this is for use with Vision Transformers
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x), # apparently some images are not RGB
])

train_dataset = ImageNetDataset(imagenet_train, transform=transform)
train_dataloader = DataLoader(train_dataset, shuffle=True)

hdf5_file = h5py.File("hidden_states.h5", "w")

for image, label, idx in tqdm(train_dataloader):
    hidden_states = get_hidden_states(image, "describe the image")
    hidden_states = hidden_states[-1]
    print(hidden_states.shape)
    torch.save(hidden_states, f"/raid/lawrence/hidden_states/{label}_{idx}.pt")
    break

  0%|                                                                                                                                     | 0/100000 [00:00<?, ?it/s]It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.
  0%|                                                                                                                                     | 0/100000 [00:33<?, ?it/s]

torch.Size([592, 4096])





    conv = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "describe the image"},
    ]},
]
prompt = processor.apply_chat_template(conv, add_generation_prompt=True)

im_url = "https://llava-vl.github.io/static/images/view.jpg"

with requests.get(im_url, stream=True) as im:
    with Image.open(im.raw) as im:
        inp = processor(im, prompt, return_tensors="pt")

In [76]:
streamer = transformers.TextStreamer(
    processor.tokenizer,
    True,
    skip_special_tokens=True,
)
output1 = model.generate(
    **inp.to(model.device),
    max_new_tokens=1,
    streamer=streamer,
    output_hidden_states=True,
    return_dict_in_generate=True,
)

hidden1 = torch.vstack(output1.hidden_states[0])

The


In [77]:
streamer = transformers.TextStreamer(
    processor.tokenizer,
    True,
    skip_special_tokens=True,
)
output2 = model.generate(
    **inp.to(model.device),
    max_new_tokens=10,
    streamer=streamer,
    output_hidden_states=True,
    return_dict_in_generate=True,
)

hidden2 = torch.vstack(output2.hidden_states[0])

The image features a pier extending out into a large


In [79]:
hidden1[-1] == hidden2[-1]

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')

In [9]:
from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, text=prompt, return_tensors="pt")

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=1)
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some kwargs in processor config are unused and will not have any effect: num_additional_image_tokens. 


"USER:  \nWhat's the content of the image? ASSISTANT: The"