In [1]:
import requests
import json
from io import BytesIO
from PIL import Image
import random
import uuid
from hashlib import sha256
from tqdm import tqdm
import time
import os

from datasets import load_dataset

# Seed Python's `random` module so the random.sample / random.choice draws made
# further down are repeatable on a fresh kernel. The example ids come from
# uuid.uuid4(), which this seed does not control.
random.seed( 42 )

In [20]:
# Total number of examples to generate per task class; sequence_class() splits
# this evenly across its sequence lengths.
IMAGE_PER_CLASS = 500

# Directory the generated PNG files are written to.
OUTPUT_IMAGE_PATH = "./test_images/"
# Prefix prepended to each dataset row's 'image' path when opening the source image.
IMAGE_DATASET_PATH = "/scratch/spp9399/mia/pixmo/molmo/"

## HF Dataset

In [4]:
# Load the ask-model-anything split from the local copy of the dataset.
# The location is derived from IMAGE_DATASET_PATH (defined in the config cell)
# so the dataset root is written down in one place only; the resulting path is
# the same string as before.
dataset = load_dataset(
    IMAGE_DATASET_PATH + "data/molmo/torch_datasets/pixmo_datasets/ask-model-anything",
    split='train',
)

Resolving data files:   0%|          | 0/33 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/33 [00:00<?, ?it/s]

In [5]:
dataset

Dataset({
    features: ['image_url', 'image', 'question', 'answer'],
    num_rows: 56248
})

In [6]:
dataset[0]

{'image_url': 'https://cdn.wallpapersafari.com/40/8/AYhcIr.jpg',
 'image': './data/molmo/torch_datasets/pixmo_images/6d83e897c03fa5c199fd03b537fd4894118915a73b3624229b74b4f8dd94074f',
 'question': ['Can you challenge this bird to a wrestling match in the style of macho man Randy savage?'],
 'answer': ["Ooh yeah, dig it! Listen up, you ebony-feathered menace! The Macho Man Randy Savage is here to challenge you to the ultimate showdown, you Common Raven perched on that branch!\n\nYou think you're the cream of the crop with those glossy black feathers and beady eyes? Well, let me tell you somethin', bird brain! The Macho Man's gonna soar higher than you ever could, dropping elbow after elbow from the heavens!\n\nIn the squared circle, I'm gonna pluck your pride and ground your flight! You may have a sharp beak, but I've got biceps that'll make your wings look like toothpicks! \n\nThe Macho Man's gonna teach you the meaning of sky-high pain! You'll be squawking my name when I unleash my si

In [7]:
# Shuffle the rows with a fixed seed.
# NOTE(review): this rebinds `dataset`, so re-running the cell presumably
# shuffles the already-shuffled rows again - restart the kernel and run top to
# bottom to get a repeatable order.
dataset = dataset.shuffle(seed=42)

# Create 500 examples for sequence, collage and pic-in-pic.
# Save the sha256 of each selected image!

In [9]:
dataset[0]

{'image_url': 'http://www.hambledonsurrey.co.uk/wp-content/uploads/2020/05/IMG_0903-1536x1152.jpeg',
 'image': './data/molmo/torch_datasets/pixmo_images/cffee398f971f6d0c19652826b27bcb21d0b784fda150bf79c65a6b77c9a1ff0',
 'question': ['Where is the camera in this scene?',
  'When was this photograph taken?',
  'Where is the security camera?',
  'What is the name of the company that supplied the newspapers?'],
 'answer': ["The camera in this scene is not directly visible. However, we can infer its approximate position based on the elements in the image. The newspapers are laid out on a table, and we can see a window with a CCTV sticker on it. Given this view, it's likely that the camera is positioned slightly above and in front of the table, capturing the newspapers and the window behind them. This suggests the camera is probably handheld or on a tripod, positioned to take in the full scene of the newspaper display and its surroundings.",
  'While an exact date isn\'t visible, we can inf

## Sequence
In the sequence task, the examples are split equally across sequences of 2, 3, 4 and 5 images.

In [21]:
def get_images( x ):
    """
    Randomly draw `x` rows from the global `dataset` that all refer to
    different image files.

    Distinct row indices are not enough: several rows can share the same
    image with different question/answer pairs. Rows are therefore compared
    by their 'image' path (which appears to be named after the image's hash),
    and the draw is repeated until all `x` paths are different.

    Args:
        x: number of rows to draw.

    Returns:
        list of `x` dataset rows (dicts), in the order they were drawn.

    Raises:
        ValueError: from random.sample when `x` exceeds the dataset size.
        RuntimeError: if no duplicate-free draw is found after a bounded
            number of attempts.
    """
    max_attempts = 100  # guard against looping forever on a degenerate dataset

    for _ in range(max_attempts):
        images_index = random.sample(range(len(dataset)), x)
        images = [dataset[i] for i in images_index]

        # Accept the draw only when no two rows point to the same image file.
        if len({row["image"] for row in images}) == len(images):
            return images

    raise RuntimeError(
        f"Could not draw {x} rows with distinct images in {max_attempts} attempts"
    )

In [26]:
def sequence_class():
    """
    Build the multi-image "sequence" examples and write their images to disk.

    IMAGE_PER_CLASS examples are split evenly over sequences of 2, 3, 4 and 5
    images. For each example the rows returned by get_images() are re-saved
    as PNG files under OUTPUT_IMAGE_PATH, one image is picked at random as the
    target, and one of that image's question/answer pairs is used.

    Each image of an example gets its own file, "<id>_<k>.png" with k being
    the 0-based position in the sequence, so the images no longer overwrite
    each other on disk.

    Returns:
        list of dicts, one per example, with the keys:
            'id'        - random hex id (uuid4) of the example,
            'image'     - list of "test_images/<id>_<k>.png" paths in prompt order,
            'prompt'    - one "Image<k>:<image>" line per image, followed by
                          "In Image<target>, " and the chosen question,
            'answer'    - the answer paired with the chosen question,
            'image_url' - the source URLs of all images joined by spaces.
    """
    sequences = [2, 3, 4, 5]
    images_per_sequence = IMAGE_PER_CLASS // len(sequences)

    # Image.save() does not create missing directories.
    os.makedirs(OUTPUT_IMAGE_PATH, exist_ok=True)

    res = []

    for seq in sequences:
        for _ in tqdm(range( images_per_sequence ), desc="Sequence " + str(seq)):
            images = get_images( seq )
            image_length = len( images )

            # Randomly choose the image whose qa pair will be used.
            image_index = random.choice( range(image_length) )

            data = {}
            example_id = uuid.uuid4().hex
            data['id'] = example_id

            # One output file per image in the sequence.
            file_names = [f"{example_id}_{k}.png" for k in range(image_length)]
            data['image'] = ["test_images/" + name for name in file_names]

            for row, name in zip(images, file_names):
                # The context manager closes the source file once it is re-saved.
                with Image.open(IMAGE_DATASET_PATH + row["image"]) as img:
                    # PNGs can't be saved in CMYK
                    png_ready = img.convert("RGB") if img.mode == "CMYK" else img
                    png_ready.save(OUTPUT_IMAGE_PATH + name, "PNG")

            # "Image1:<image>\n...ImageN:<image>\n" followed by the target reference.
            prompt = "".join(
                f"Image{k}:<image>\n" for k in range(1, image_length + 1)
            )
            prompt += f'In Image{image_index + 1}, '

            # If there are multiple questions, then select one.
            target = images[image_index]
            ques_index = random.choice( range(len(target['question'])) )
            data['prompt'] = prompt + target["question"][ques_index]
            data['answer'] = target["answer"][ques_index]

            data['image_url'] = " ".join( [ row['image_url'] for row in images ] )

            res.append( data )
    return res

In [27]:
# Generate the sequence examples. Besides returning the list of example dicts,
# this writes PNG files under OUTPUT_IMAGE_PATH; the recorded run below took
# roughly eight and a half minutes in total.
sequence_data = sequence_class()

Sequence 2: 100%|██████████| 125/125 [01:33<00:00,  1.34it/s]
Sequence 3: 100%|██████████| 125/125 [01:42<00:00,  1.22it/s]
Sequence 4: 100%|██████████| 125/125 [02:32<00:00,  1.22s/it]
Sequence 5: 100%|██████████| 125/125 [02:41<00:00,  1.29s/it]


In [29]:
# Persist the generated sequence examples as pretty-printed JSON (4-space indent).
with open("./sequence_test_images.json", "w") as json_file:
    json.dump(sequence_data, json_file, indent=4)

## Collage 

In [None]:
# NOTE(review): this redefines get_images(), which is already defined in the
# Sequence section; running this cell replaces that definition.
def get_images( x ):
    """
    Randomly draw `x` rows from the global `dataset` that all refer to
    different image files.

    Distinct row indices are not enough: several rows can share the same
    image with different question/answer pairs. Rows are therefore compared
    by their 'image' path (which appears to be named after the image's hash),
    and the draw is repeated until all `x` paths are different.

    Args:
        x: number of rows to draw.

    Returns:
        list of `x` dataset rows (dicts), in the order they were drawn.

    Raises:
        ValueError: from random.sample when `x` exceeds the dataset size.
        RuntimeError: if no duplicate-free draw is found after a bounded
            number of attempts.
    """
    max_attempts = 100  # guard against looping forever on a degenerate dataset

    for _ in range(max_attempts):
        images_index = random.sample(range(len(dataset)), x)
        images = [dataset[i] for i in images_index]

        # Accept the draw only when no two rows point to the same image file.
        if len({row["image"] for row in images}) == len(images):
            return images

    raise RuntimeError(
        f"Could not draw {x} rows with distinct images in {max_attempts} attempts"
    )

# NOTE(review): this redefines sequence_class(), which is already defined in the
# Sequence section; presumably a placeholder to be adapted for the collage
# task - confirm before running this cell.
def sequence_class():
    """
    Build the multi-image "sequence" examples and write their images to disk.

    IMAGE_PER_CLASS examples are split evenly over sequences of 2, 3, 4 and 5
    images. For each example the rows returned by get_images() are re-saved
    as PNG files under OUTPUT_IMAGE_PATH, one image is picked at random as the
    target, and one of that image's question/answer pairs is used.

    Each image of an example gets its own file, "<id>_<k>.png" with k being
    the 0-based position in the sequence, so the images no longer overwrite
    each other on disk.

    Returns:
        list of dicts, one per example, with the keys:
            'id'        - random hex id (uuid4) of the example,
            'image'     - list of "test_images/<id>_<k>.png" paths in prompt order,
            'prompt'    - one "Image<k>:<image>" line per image, followed by
                          "In Image<target>, " and the chosen question,
            'answer'    - the answer paired with the chosen question,
            'image_url' - the source URLs of all images joined by spaces.
    """
    sequences = [2, 3, 4, 5]
    images_per_sequence = IMAGE_PER_CLASS // len(sequences)

    # Image.save() does not create missing directories.
    os.makedirs(OUTPUT_IMAGE_PATH, exist_ok=True)

    res = []

    for seq in sequences:
        for _ in tqdm(range( images_per_sequence ), desc="Sequence " + str(seq)):
            images = get_images( seq )
            image_length = len( images )

            # Randomly choose the image whose qa pair will be used.
            image_index = random.choice( range(image_length) )

            data = {}
            example_id = uuid.uuid4().hex
            data['id'] = example_id

            # One output file per image in the sequence.
            file_names = [f"{example_id}_{k}.png" for k in range(image_length)]
            data['image'] = ["test_images/" + name for name in file_names]

            for row, name in zip(images, file_names):
                # The context manager closes the source file once it is re-saved.
                with Image.open(IMAGE_DATASET_PATH + row["image"]) as img:
                    # PNGs can't be saved in CMYK
                    png_ready = img.convert("RGB") if img.mode == "CMYK" else img
                    png_ready.save(OUTPUT_IMAGE_PATH + name, "PNG")

            # "Image1:<image>\n...ImageN:<image>\n" followed by the target reference.
            prompt = "".join(
                f"Image{k}:<image>\n" for k in range(1, image_length + 1)
            )
            prompt += f'In Image{image_index + 1}, '

            # If there are multiple questions, then select one.
            target = images[image_index]
            ques_index = random.choice( range(len(target['question'])) )
            data['prompt'] = prompt + target["question"][ques_index]
            data['answer'] = target["answer"][ques_index]

            data['image_url'] = " ".join( [ row['image_url'] for row in images ] )

            res.append( data )
    return res

## Pic-in-Pic