In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json
import matplotlib.pyplot as plt
from PIL import Image
from types import SimpleNamespace
import random
import textwrap
import ipywidgets as widgets
from IPython.display import display
import uuid
from datetime import datetime

filename = "sis/val.story-in-sequence.json"
image_folder = "images/val"
output_folder = "story_plots_val_horizontal/"


def _to_namespace(value):
    """Recursively turn decoded JSON into attribute-style objects.

    Every dict becomes a SimpleNamespace (nested dicts and dicts inside lists
    included); lists are converted element-wise; scalars are returned as-is.
    This yields the same objects as a json.dumps/json.loads round trip with a
    SimpleNamespace object_hook, without re-serialising every record.
    """
    if isinstance(value, dict):
        return SimpleNamespace(**{key: _to_namespace(item) for key, item in value.items()})
    if isinstance(value, list):
        return [_to_namespace(item) for item in value]
    return value


# Read as UTF-8 explicitly so the result does not depend on the platform's
# default encoding (the CSV export in this notebook is written as UTF-8 too).
with open(filename, 'r', encoding='utf-8') as file:
    data = json.load(file)

images_data = data['images']
albums_data = data.get('albums', [])
annotations_data = data['annotations']

images = [_to_namespace(image) for image in images_data]
albums = [_to_namespace(album) for album in albums_data]

# 'annotations' is a list of lists; the nesting is preserved.
annotations = [
    [_to_namespace(item) for item in annotation_list]
    for annotation_list in annotations_data
]

# Quick sanity preview of the first few records of each kind.
for img in images[:5]:
    print(f"Title: {img.title}, ID: {img.id}, URL: {img.url_o}")

for album in albums[:5]:
    print(f"Description: {album.description}, ID: {album.id}, Title: {album.title}")

for annotation_group in annotations[:5]:
    if not annotation_group:
        # An empty group has no first element to preview.
        continue
    first_annotation = annotation_group[0]
    print(f"Original Text: {first_annotation.original_text}, Story ID: {first_annotation.story_id}, Photo ID: {first_annotation.photo_flickr_id}")

# Group every annotation under its story_id, preserving encounter order.
annotations_dict = {}
for annotationArray in annotations:
    for annotation in annotationArray:
        annotations_dict.setdefault(annotation.story_id, []).append(annotation)

story_ids = list(annotations_dict.keys())

Title: Fourth of July prerequisite, ID: 694227468, URL: https://farm2.staticflickr.com/1125/694227468_f6c433d7d8_o.jpg
Title: Spectacular fireworks, ID: 694227344, URL: https://farm2.staticflickr.com/1330/694227344_58d54d3732_o.jpg
Title: Bubba Burgers, ID: 694227412, URL: https://farm2.staticflickr.com/1008/694227412_001b568f92_o.jpg
Title: "On guard!", ID: 694227488, URL: https://farm2.staticflickr.com/1302/694227488_bb07200c72_o.jpg
Title: BBQin', ID: 694227508, URL: https://farm2.staticflickr.com/1307/694227508_df12d3b4fb_o.jpg
Description: Bubba burgers, beer and BBQ make for a great Fourth of July, ID: 72157600601428727, Title: Fourth of July 2007
Description: , ID: 72157601202851033, Title: WMATA - WDC - 04 July 2007
Description: July 4th images from Brandon parade and fireworks., ID: 72157594187408667, Title: Brandon FL 7/4/2006
Description: Fireworks and friends!, ID: 72157594234060064, Title: 2006 July 4th
Description: just a glimpse into the happenings on and around the day 

In [2]:
import os
import torch
import matplotlib.pyplot as plt
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from PIL import Image

def get_story_images(story_id, annotations_dict, images, image_folder):
    """Return the image filenames of a story in the worker-arranged order.

    Args:
        story_id: Key into ``annotations_dict``.
        annotations_dict: Mapping of story id -> list of annotation objects
            exposing ``worker_arranged_photo_order`` and ``photo_flickr_id``.
        images: Iterable of image objects exposing ``id``; annotations whose
            photo id is not among these ids are ignored.
        image_folder: Directory expected to contain ``<id>.jpg`` files.

    Returns:
        List of ``"<id>.jpg"`` names (relative to ``image_folder``) for the
        files that exist on disk, ordered by ``worker_arranged_photo_order``.
        An unknown ``story_id`` yields an empty list.
    """
    story_annotations = sorted(
        annotations_dict.get(story_id, []),
        key=lambda ann: ann.worker_arranged_photo_order
    )

    # Built once per call; the membership test below is then O(1) per photo.
    known_image_ids = {img.id for img in images}

    image_filenames = []
    # Iterate the *sorted annotations* (not `images`) so the story order is
    # what determines the order of the returned filenames.
    for ann in story_annotations:
        photo_id = ann.photo_flickr_id
        if photo_id not in known_image_ids:
            continue
        image_path = os.path.join(image_folder, f"{photo_id}.jpg")
        if os.path.exists(image_path):
            image_filenames.append(f"{photo_id}.jpg")
        else:
            print("Files not exist", image_path)
    return image_filenames

In [3]:
import torch
import os
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# (model, processor) pairs keyed by model name. Callers invoke
# qwen_test_few_shots once per story, so without this cache the checkpoint
# would be loaded again on every call.
_QWEN_MODEL_CACHE = {}


def _load_qwen(model_name):
    """Return the (model, processor) pair for model_name, loading it on first use."""
    if model_name not in _QWEN_MODEL_CACHE:
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",  # or "sequential" if you have specific needs
            # NOTE(review): this disables the KV cache during generation;
            # presumably chosen to save memory -- confirm it is intentional.
            use_cache=False
        ).eval()
        # Inference only: make sure no parameter tracks gradients.
        for param in model.parameters():
            param.requires_grad = False

        processor = AutoProcessor.from_pretrained(model_name)
        _QWEN_MODEL_CACHE[model_name] = (model, processor)
    return _QWEN_MODEL_CACHE[model_name]


def qwen_test_few_shots(
    images_list_of_lists,
    prompts,
    model_name="Qwen/Qwen2-VL-7B-Instruct",
    image_dir="images/val/",
    resize_to=(224, 224),
    device_index=0
):
    """
    Run a multi-turn (few-shot) prompt with images through Qwen2-VL.

    images_list_of_lists: list of lists of image filenames
        E.g. [
          ["img1.jpg","img2.jpg"],  # Turn 1 images
          ["img3.jpg"],            # Turn 2 images
          ...
        ]

    prompts: list of dictionaries with "role" and "text"
        E.g. [
          {"role": "system", "text": "..."},
          {"role": "user", "text": "..."},
        ]
        Must have the same length as images_list_of_lists; turn i is built
        from images_list_of_lists[i] followed by prompts[i]["text"].

    model_name: checkpoint name passed to from_pretrained.
    image_dir: directory the image filenames are resolved against.
    resize_to: (width, height) every image is resized to before processing.
    device_index: CUDA device index the input tensors are moved to when CUDA
        is available; otherwise the inputs stay on the CPU.

    Returns:
        (processed_images, text): the list of loaded/resized PIL images in
        turn order and the decoded generation. If no image could be loaded,
        returns ([], "No valid images were processed.") without running the
        model.

    Raises:
        ValueError: if images_list_of_lists and prompts differ in length.
    """
    # Validate first so a malformed call fails before any GPU work or
    # checkpoint loading happens.
    if len(images_list_of_lists) != len(prompts):
        raise ValueError(
            f"Mismatched lengths: got {len(images_list_of_lists)} image-turns "
            f"and {len(prompts)} prompt-turns."
        )

    # Build the multi-turn conversation structure
    conversation = []
    all_processed_images = []

    for img_filenames, prompt_dict in zip(images_list_of_lists, prompts):
        role = prompt_dict.get("role", "user")  # default to "user" if not provided
        text = prompt_dict.get("text", "")

        # Load images for this turn
        turn_images = []
        for fn in img_filenames:
            path = os.path.join(image_dir, fn)
            try:
                # The context manager closes the underlying file as soon as
                # the pixel data has been copied by convert().
                with Image.open(path) as source_image:
                    rgb_image = source_image.convert("RGB")
                # Use Image.Resampling.LANCZOS if you see a PIL deprecation warning
                turn_images.append(rgb_image.resize(resize_to, Image.LANCZOS))
            except Exception as e:
                # Best effort: an unreadable image is reported and skipped.
                print(f"Could not load image {path} due to error: {e}")

        # Add them to the global list of images
        all_processed_images.extend(turn_images)

        # Create the content block for this turn: "image" placeholders + text prompt
        turn_content = [{"type": "image"} for _ in turn_images]
        if text:  # If there's text in this turn
            turn_content.append({"type": "text", "text": text})

        conversation.append({
            "role": role,
            "content": turn_content
        })

    if not all_processed_images:
        return [], "No valid images were processed."

    cuda_available = torch.cuda.is_available()
    device = torch.device(f"cuda:{device_index}" if cuda_available else "cpu")
    if cuda_available:
        # These calls are CUDA-only; guarding them keeps the CPU fallback usable.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    model, processor = _load_qwen(model_name)

    # Apply Qwen's chat template to the entire multi-turn conversation
    text_prompt = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True
    )

    # Prepare model inputs. No max_length is passed: without truncation it has
    # no effect, and truncating the prompt would cut image placeholder tokens.
    inputs = processor(
        text=[text_prompt],
        images=all_processed_images,
        return_tensors="pt",
        padding=True
    )
    inputs = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            num_return_sequences=1,
            do_sample=False
        )

    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids = [
        output_id[len(input_id):]
        for input_id, output_id in zip(inputs['input_ids'], output_ids)
    ]
    story_description = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )[0]
    return all_processed_images, story_description


In [4]:
# Among the first 60 stories, keep only the first story seen for each
# distinct tuple of image filenames.
seen_sequences = set()
unique_ids = []

for story_id in story_ids[:60]:
    story_images = get_story_images(story_id, annotations_dict, images, image_folder)
    images_key = tuple(story_images)

    if images_key in seen_sequences:
        # Another story already covers this exact image sequence.
        continue

    seen_sequences.add(images_key)
    unique_ids.append(story_id)

temp_story_ids = unique_ids

In [None]:
import os
import csv
import uuid
import matplotlib.pyplot as plt
from datetime import datetime

# Instruction text used as the system-role prompt for every story in this cell.
# It describes the task (a 5-sentence story for a 5-image sequence), lists the
# aspects to express, and ends with the generation request itself.
general_instruction = """You are an advanced assistant designed to create a 5-sentence story based on a 5-image sequence input. 
Your task is to generate the appropriate text for each image input. 
You are given a 5-sequence image and also an aspect to enhance. 
This aspect will consist of the aspect name and definition, explaining how to express that particular aspect. 
Your objective is to generate a 5-sentence story that expresses this aspect while still accurately visualizing the related image. 
Vary between first-person and third-person viewpoints. 
You can generate a named entity for the entities detected in the image.

Aspect list:
1) Immersion: A proper immersion is a story that has a consistent World Building. The world must have its own rules and logic. Ensure the world feels real.
2) Structure: A proper structure uses a clear beginning, middle, and end.
Generate a story based on the input image
"""
# Single-aspect variants of the aspect list, kept for reference (not used):
#1) Immersion: A proper immersion is a story that has a consistent World Building. The world must have its own rules and logic. Ensure the world feels real.
#1) Structure: A proper structure uses a clear beginning, middle, and end.
# Separate generation prompt, currently folded into general_instruction above:
# generation_prompt = (
#     "Generate a story based on the input image"
# )

# Among the first 60 stories, keep only the first story seen for each
# distinct tuple of image filenames.
temp_story_ids = story_ids[:60]
unique_ids = []
seen_sequences = set()

for story_id in temp_story_ids:
    story_images = get_story_images(story_id, annotations_dict, images, image_folder)
    images_key = tuple(story_images)
    if images_key not in seen_sequences:
        seen_sequences.add(images_key)
        unique_ids.append(story_id)

temp_story_ids = unique_ids

folder_name = "1_feb_generate_immersion_structure_one_shot"
os.makedirs(folder_name, exist_ok=True)

# Optional few-shot examples, shared by every story. Each entry is a dict with
# the keys "story_id", "question" and "answer".
sample_pairs = [
    # Add or remove any few-shot examples here if needed.
]

# Prepare CSV file to export data (3 columns: story_id, prompts, generated_story)
csv_filename = os.path.join(folder_name, f"{folder_name}.csv")
with open(csv_filename, 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["story_id", "prompts", "generated_story"])  # CSV header

    for final_story_id in temp_story_ids:
        final_story_images = get_story_images(final_story_id, annotations_dict, images, image_folder)
        system_prompt = {"role": "system", "text": general_instruction}

        # qwen_test_few_shots pairs images_list_of_lists[i] with prompts[i],
        # so both lists are always appended to together to stay aligned.
        if sample_pairs:
            prompts = [system_prompt]
            images_list_of_lists = [[]]  # the system turn carries no images

            for sample in sample_pairs:
                sample_images = get_story_images(sample["story_id"], annotations_dict, images, image_folder)
                prompts.append({"role": "user", "text": sample["question"]})
                images_list_of_lists.append(sample_images)
                prompts.append({"role": "assistant", "text": sample["answer"]})
                images_list_of_lists.append([])

            # The story to generate gets its own image-only user turn after
            # the examples; the request text is already in the system prompt.
            prompts.append({"role": "user", "text": ""})
            images_list_of_lists.append(final_story_images)
        else:
            # Zero-shot: a single turn holding the instruction and the images.
            prompts = [system_prompt]
            images_list_of_lists = [final_story_images]

        # Call your function that returns processed_images and the response
        processed_images, response = qwen_test_few_shots(
            images_list_of_lists=images_list_of_lists,
            prompts=prompts,
            image_dir=image_folder  # same folder get_story_images checked
        )

        # Generate figure: images on the top row, texts on the bottom half
        plt.figure(figsize=(20, 10))
        num_images = len(processed_images)
        for i, img in enumerate(processed_images, 1):
            plt.subplot(2, num_images, i)
            plt.imshow(img)
            plt.axis('off')
            plt.title(f'Image {i}')

        # Display the response
        plt.subplot(2, 1, 2)
        plt.text(
            0.5, 0.7, "Response: " + response,
            horizontalalignment='center', verticalalignment='center',
            wrap=True, fontsize=10, bbox=dict(facecolor='white', alpha=0.5)
        )

        # Same prompt formatting is used in the figure and in the CSV
        formatted_prompts = "\n".join(
            f"{idx+1}) Role: {entry['role'].capitalize()}, Text: {entry['text']}"
            for idx, entry in enumerate(prompts)
        )

        plt.text(
            0.5, 0.3,
            "Prompt: " + formatted_prompts,
            horizontalalignment='center', verticalalignment='center',
            wrap=True, fontsize=10, bbox=dict(facecolor='white', alpha=0.5)
        )
        plt.axis('off')
        plt.title(f'Final Story for Story ID: {final_story_id}')
        plt.tight_layout()

        # Timestamp + random suffix keeps filenames unique across runs
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        combined_filename =  f'story_combined_{timestamp}_{uuid.uuid4().hex}.png'
        plt.savefig(os.path.join(folder_name, combined_filename))
        plt.show()
        plt.close()

        # Write row to CSV: Story ID, Formatted Prompts, Generated Story
        writer.writerow([final_story_id, formatted_prompts, response])
        # Flush per story so finished rows survive an interrupted run.
        csvfile.flush()

print(f"Exported data to CSV: {csv_filename}")


`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.
