<a href="https://colab.research.google.com/github/neohack22/test/blob/master/christmas_story_generation_challenge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<!--- @wandbcode{weaviate-paris-nov-2024} -->

## Welcome to the Weights & Biases Christmas Story Weaving Challenge

### 🧠 Resources

- Challenge instructions: [wandb.me/station-f](https://wandb.me/station-f)
- Challenge starter code (this Colab): [wandb.me/christmas-weaving](http://wandb.me/christmas-weaving)
- Submissions and evaluations project: [wandb.me/paris_dashboard](https://wandb.me/paris_dashboard)
- Weights & Biases Weave [docs](https://wandb.me/docs_paris)

Please run all cells, including the last code block: it is the one that actually submits your work to the dashboard.

In [None]:
!pip install -qU openai fal-client weave

In [None]:
import os
import getpass


os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API Key: ")
os.environ["FAL_KEY"] = getpass.getpass("Enter FalAI API Key: ")

Enter OpenAI API Key: ··········
Enter FalAI API Key: ··········


In [None]:
import base64
import requests
import tempfile
import time
from io import BytesIO
from typing import Any, Callable, Dict, Optional, Union

import fal_client
import weave
from openai import OpenAI
from PIL import Image
from pydantic import BaseModel
from rich.progress import track

## Get your API key

Create or log into your Weights & Biases (W&B) account at [https://wandb.ai](https://wandb.ai/?utm_source=event&utm_medium=demo&utm_campaign=weaviate_paris_nov_2024) and copy your API key from [here](https://wandb.ai/authorize/?utm_source=event&utm_medium=demo&utm_campaign=weaviate_paris_nov_2024).

In [None]:
weave.init("ml-colabs/christmas-weaving-challenge")

Please login to Weights & Biases (https://wandb.ai/) to continue:

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:
 ··········
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Logged in as Weights & Biases user: agatamlyn.
View Weave data at https://wandb.ai/ml-colabs/christmas-weaving-challenge/weave


<weave.trace.weave_client.WeaveClient at 0x7cbf12c45ff0>

In [None]:
# @title
def base64_encode_image(image: Image.Image) -> str:
    byte_arr = BytesIO()
    image.save(byte_arr, format="PNG")
    encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8")
    encoded_string = f"data:image/png;base64,{encoded_string}"
    return str(encoded_string)

def custom_weave_wrapper(name: str) -> Callable[[Callable], Callable]:
    def wrapper(fn: Callable) -> Callable:
        op = weave.op()(fn)
        op.name = name  # type: ignore
        return op

    return wrapper
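
The `base64_encode_image` helper above turns an image into a data URL, which is the form the judge later sends to the vision model. As a rough illustration of what that string looks like, here is a standard-library-only sketch that encodes raw bytes instead of a PIL image. The names `to_data_url` and `from_data_url` are hypothetical and are not part of the notebook; the PNG signature bytes are used only as stand-in data.

```python
import base64

# First 8 bytes of every PNG file (the PNG signature), used here as stand-in data.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def to_data_url(data: bytes, mime_type: str = "image/png") -> str:
    """Encode raw bytes as a base64 data URL (hypothetical helper)."""
    encoded = base64.b64encode(data).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"


def from_data_url(data_url: str) -> bytes:
    """Decode a base64 data URL back into raw bytes (hypothetical helper)."""
    header, encoded = data_url.split(",", 1)
    if not header.endswith(";base64"):
        raise ValueError("expected a base64-encoded data URL")
    return base64.b64decode(encoded)


data_url = to_data_url(PNG_SIGNATURE)
print(data_url)
```

Decoding with `from_data_url(data_url)` returns the original bytes, which is a quick way to check that nothing was lost in the round trip.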

**Weave** is a lightweight toolkit for tracking and evaluating LLM applications, built by Weights & Biases.

Our goal is to bring rigor, best practices, and composability to the inherently experimental process of developing AI applications, without introducing cognitive overhead.

Get started by decorating Python functions with `@weave.op()`.
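
To make the idea of a tracked function concrete, here is a minimal, standard-library-only sketch of what a tracing decorator does conceptually: it records the inputs, output, and latency of each call. This is not how Weave is implemented; `traced_op` and `CALL_LOG` are hypothetical names used only for illustration. With Weave itself, the calls are recorded in the Weave UI rather than in a local list.

```python
import functools
import time
from typing import Any, Callable

# Hypothetical in-memory log standing in for a tracing backend.
CALL_LOG: list[dict[str, Any]] = []


def traced_op(fn: Callable) -> Callable:
    """Record inputs, output, and latency of every call to `fn`."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        CALL_LOG.append(
            {
                "op": fn.__name__,
                "args": args,
                "kwargs": kwargs,
                "output": result,
                "seconds": time.perf_counter() - start,
            }
        )
        return result

    return wrapper


@traced_op
def make_title(theme: str) -> str:
    return f"A {theme} Story"


title = make_title("Christmas")
print(CALL_LOG[0]["op"], "->", CALL_LOG[0]["output"])
```

The decorated function behaves exactly like the original; the only difference is the side effect of logging each call, which is the property that makes this style of tracking cheap to adopt.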


In [None]:
class Story(BaseModel):
    paragraphs: list[str]


class StoryGenerationModel(weave.Model):
    model_name: str
    system_prompt: Optional[str] = None

    @weave.op()
    def frame_messages(
        self,
        prompts: Union[str, list[str]],
        history: Optional[list[dict[str, str]]] = None
    ):
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        if history:
            messages += history
        prompts = [prompts] if isinstance(prompts, str) else prompts
        messages += [{"role": "user", "content": prompt} for prompt in prompts]
        return messages

    @weave.op()
    def predict(
        self,
        prompts: Union[str, list[str]],
        history: Optional[list[dict[str, str]]] = None
    ):
        messages = self.frame_messages(prompts, history)
        completion = OpenAI().beta.chat.completions.parse(
            model=self.model_name, messages=messages, response_format=Story
        )
        return completion.choices[0].message.parsed
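
The `history` argument is what lets you iterate on a story across turns. The sketch below is a standalone mirror of `frame_messages` (no Weave or OpenAI dependency, with `system_prompt` passed as a plain argument) so that the resulting message list can be inspected directly. The system prompt and the earlier conversation turn are made-up examples, not values from the challenge.

```python
from typing import Optional, Union


def frame_messages(
    prompts: Union[str, list[str]],
    system_prompt: Optional[str] = None,
    history: Optional[list[dict[str, str]]] = None,
) -> list[dict[str, str]]:
    """Standalone mirror of `StoryGenerationModel.frame_messages`."""
    messages: list[dict[str, str]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history:
        messages += history
    # A single string prompt is treated as a one-element list.
    prompts = [prompts] if isinstance(prompts, str) else prompts
    messages += [{"role": "user", "content": p} for p in prompts]
    return messages


# Hypothetical earlier turn, used to ask for a revision of a previous story.
history = [
    {"role": "user", "content": "Generate a 3 paragraph long story about Christmas."},
    {"role": "assistant", "content": "Once upon a snowy night..."},
]
messages = frame_messages(
    prompts="Rewrite the story so that it features a snowman.",
    system_prompt="You are a cheerful storyteller.",
    history=history,
)
for message in messages:
    print(message["role"], "->", message["content"])
```

The order is always system prompt first, then history, then the new user prompts, so a follow-up request is read by the model in the context of the earlier story.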

In [None]:
story_generation_model = StoryGenerationModel(model_name="gpt-4o-mini")
story = story_generation_model.predict(
    prompts="Generate a 3 paragraph long story about Christmas."
)

🍩 https://wandb.ai/ml-colabs/christmas-weaving-challenge/r/call/019369d9-b2ac-7c40-b4d2-150934e46d48


In [None]:
class ImageGenerationModel(weave.Model):
    model_name: str
    inference_kwargs: dict[str, Any] = {}

    def download_image_to_pil(self, url: str) -> Image.Image:
        response = requests.get(url)
        response.raise_for_status()
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
            temp_file.write(response.content)
            temp_filename = temp_file.name
        try:
            image = Image.open(temp_filename)
            # Image.open is lazy: read the pixel data now, before the file is deleted.
            image.load()
        finally:
            os.unlink(temp_filename)
        return image

    @weave.op()
    def predict(self, prompt: str) -> Image.Image:
        result = custom_weave_wrapper(
            name="fal_client.subscribe"
        )(fal_client.subscribe)(
            self.model_name,
            arguments={"prompt": prompt, **self.inference_kwargs},
        )
        return self.download_image_to_pil(result["images"][0]["url"])



In [None]:
image_generation_model = ImageGenerationModel(model_name="fal-ai/flux/dev")
image = image_generation_model.predict(prompt=story.paragraphs[0])

🍩 https://wandb.ai/ml-colabs/christmas-weaving-challenge/r/call/019369d9-ca44-70a2-86d8-5a791b00d0be


In [None]:
class IllustratedStoryGenerator(weave.Model):
    story_generation_model: StoryGenerationModel
    image_generation_model: ImageGenerationModel

    @weave.op()
    def generate_story(
        self,
        prompts: Union[str, list[str]],
        history: Optional[list[dict[str, str]]] = None
    ) -> Story:
        return self.story_generation_model.predict(
            prompts=prompts, history=history
        )

    @weave.op()
    def illustrate_story(self, story: Story) -> list[Image.Image]:
        return [
            self.image_generation_model.predict(prompt=paragraph)
            for paragraph in track(
                story.paragraphs, description="Illustrating"
            )
        ]

    @weave.op()
    def predict(
        self,
        prompts: Union[str, list[str]],
        history: Optional[list[dict[str, str]]] = None
    ) -> list[dict[str, Union[str, Image.Image]]]:
        story = self.generate_story(
            prompts=prompts, history=history
        )
        generated_images = self.illustrate_story(story)
        return [
            {"paragraph": paragraph, "image": image}
            for paragraph, image in zip(story.paragraphs, generated_images)
        ]

In [None]:
illustrated_story_generation_model = IllustratedStoryGenerator(
    story_generation_model=story_generation_model,
    image_generation_model=image_generation_model,
)

illustrated_story = illustrated_story_generation_model.predict(
    prompts="Generate a 3 paragraph long story about Christmas."
)

🍩 https://wandb.ai/ml-colabs/christmas-weaving-challenge/r/call/019369d9-daf9-7b33-a098-037833e77d6f


In [None]:
# @title 🎄 Run this cell to get your illustrated story judged by Father Christmas and submitted to the challenge dashboard
class IllustrationScore(BaseModel):
    story_alignment_score: float
    christmass_alignment_score: float


class StoryJudgementScore(BaseModel):
    story_quality: float
    christmass_alignment_score: float
    explanation: str


class IllustratedStoryJudgementScore(BaseModel):
    story_judgement_score: StoryJudgementScore
    illustration_score: IllustrationScore
    final_score: float


class FatherChristmas(weave.Model):
    image_description_model_name: str = "gpt-4o"
    judgement_model_name: str = "gpt-4o"

    @weave.op()
    def decribe_image(self, image: Image.Image) -> str:
        completion = OpenAI().chat.completions.create(
            model=self.image_description_model_name,
            messages=[
                {
                    "role": "system",
                    "content": """
You are a helpful assistant meant to describe images in detail.
First you must give an overall overview describing the image in not more than
2 sentences.
Next, you must analyze the image step-by-step and describe the actions, events,
objects and their relationships, and the overall color palette, mood, and vibe
of the image."""
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": base64_encode_image(image)},
                        },
                    ],
                },
            ],
        )
        return completion.choices[0].message.content

    @weave.op()
    def judge_story(self, story: str) -> StoryJudgementScore:
        completion = OpenAI().beta.chat.completions.parse(
            model=self.judgement_model_name,
            response_format=StoryJudgementScore,
            messages=[
                {
                    "role": "system",
                    "content": """
You are a helpful assistant meant to judge the quality of a story and how
accurately it aligns with the themes of Christmas.

Here are some visual clues with respect to how much the story is aligned with Christmas:

1. Santa Claus
2. Snow
3. Elves
4. Christmas tree
5. presents
6. happy humans
7. gingerbread
8. star
9. snowman
10. jingle bells

You must give a Christmas alignment score that is a fractional number between
0 and 1 which corresponds to how many of the aforementioned visual clues are present.
You must follow the following strategy for predicting the Christmas alignment score:

1. You must assign a score of 1.0 if all the visual clues are present in the story.
2. You must assign a score of 0.9 if only 9 of the visual clues are present in the story.
3. You must assign a score of 0.8 if only 8 of the visual clues are present in the story.
4. You must assign a score of 0.7 if only 7 of the visual clues are present in the story.
5. You must assign a score of 0.6 if only 6 of the visual clues are present in the story.
6. You must assign a score of 0.5 if only 5 of the visual clues are present in the story.
7. You must assign a score of 0.4 if only 4 of the visual clues are present in the story.
8. You must assign a score of 0.3 if only 3 of the visual clues are present in the story.
9. You must assign a score of 0.2 if only 2 of the visual clues are present in the story.
10. You must assign a score of 0.1 if only 1 of the visual clues is present in the story.


You must give a story quality score that is a fractional number between
0 and 1 which corresponds to how creative and charming the story is with respect
to the motif of Christmas. You must follow the following strategy for
predicting the story quality score:

1. If there's no charm or magical Christmas spirit and the story feels dull,
    serious, or devoid of whimsy; you must assign a story quality score
    in the range 0.0 to 0.25.
2. If the story has some charm or magic but isn’t fully inspired and only has a
    few playful elements but doesn't evoke much joy; you must assign a
    story quality score in the range 0.25 to 0.5.
3. If the story has charm, magic, some festive cheer, and enough playful
    elements to make an elf smile; you must assign a story quality score in the
    range 0.5 to 0.75.
4. If the story exudes charm, magic, and abundant festive cheer such that
    it feels alive with holiday spirit and creativity, bringing a smile to any
    elf’s face; you must assign a story quality score in the range 0.75 to 1.0.
"""
                },
                {"role": "user", "content": story},
            ]
        )
        return completion.choices[0].message.parsed

    @weave.op()
    def judge_paragraph_illustration(
        self, story: str, paragraph: str, image: Image.Image
    ) -> IllustrationScore:
        image_description = self.decribe_image(image)
        completion = OpenAI().beta.chat.completions.parse(
            model=self.judgement_model_name,
            response_format=IllustrationScore,
            messages=[
                {
                    "role": "system",
                    "content": """
You are a helpful assistant meant to judge the quality of an image that has to
be used as an illustration for a paragraph from a story about Christmas.

You will be provided the entire story within tags <story>...</story>.
You will be provided the paragraph within tags <paragraph>...</paragraph>.
You will be provided the illustration image corresponding to the paragraph.

You are to closely refer to all the information provided to predict a
fractional score called the story alignment score (between 0 and 1) for the image on
how well it aligns with the story in general and the paragraph in particular.
You must follow the following strategy for predicting the story alignment score:
1. If the image is of poor quality, with significant issues like blurriness,
    distortions, or technical flaws; you must assign a story alignment score
    in the range 0.0 to 0.25.
2. If the image is of average quality, with some issues that detract from its appeal,
    like blurriness or inconsistencies; you must assign a story alignment score
    in the range 0.25 to 0.5.
3. If the image is of good quality, polished, and visually appealing,
    demonstrating good technical skill with only minor issues; you must assign
    a story alignment score in the range 0.5 to 0.75.
4. If the image is of high quality, visually striking, with attention to detail
    and creativity, and is clear, well-composed, and aesthetically pleasing;
    you must assign a story alignment score in the range 0.75 to 1.0.
5. You must also compare the image with the story and deduct 0.2 points for
    every missing visual clue.


You must also predict a fractional score called the Christmas alignment score
(between 0 and 1) for how well the image aligns visually with the idea of Christmas.
Here are some visual clues with respect to how much the image is aligned with Christmas:

1. Santa Claus
2. Snow
3. Elves
4. Christmas tree
5. presents
6. happy humans
7. gingerbread
8. star
9. snowman
10. jingle bells

You must follow the following strategy for predicting the Christmas alignment score:

1. You must assign a score of 1.0 if all the visual clues are present in the image.
2. You must assign a score of 0.9 if 9 of the visual clues are present in the image.
3. You must assign a score of 0.8 if 8 of the visual clues are present in the image.
...
"""
                },
                {"role": "user", "content": f"<story>\n{story}\n</story>."},
                {
                    "role": "user",
                    "content": f"<paragraph>\n{paragraph}\n</paragraph>."
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": base64_encode_image(image)},
                        },
                    ],
                },
            ],
        )
        return completion.choices[0].message.parsed

    @weave.op()
    def predict(
        self, illustrated_story: list[dict[str, Union[str, Image.Image]]]
    ) -> IllustratedStoryJudgementScore:
        story = "\n\n".join([paragraph["paragraph"] for paragraph in illustrated_story])
        story_judgement_score = self.judge_story(story)
        mean_illustration_score = IllustrationScore(
            story_alignment_score=0, christmass_alignment_score=0
        )
        for illustrated_paragraph in track(illustrated_story, description="Judging illustrated story"):
            paragraph = illustrated_paragraph["paragraph"]
            image = illustrated_paragraph["image"]
            illustration_score = self.judge_paragraph_illustration(story, paragraph, image)
            mean_illustration_score.story_alignment_score += illustration_score.story_alignment_score
            mean_illustration_score.christmass_alignment_score += illustration_score.christmass_alignment_score
        mean_illustration_score.story_alignment_score = mean_illustration_score.story_alignment_score / len(illustrated_story)
        mean_illustration_score.christmass_alignment_score = mean_illustration_score.christmass_alignment_score / len(illustrated_story)
        return IllustratedStoryJudgementScore(
            story_judgement_score=story_judgement_score,
            illustration_score=mean_illustration_score,
            final_score=(story_judgement_score.story_quality + story_judgement_score.christmass_alignment_score + mean_illustration_score.story_alignment_score + mean_illustration_score.christmass_alignment_score) * 100 / 4
        )


judge = FatherChristmas()
judgement = judge.predict(illustrated_story)

🍩 https://wandb.ai/ml-colabs/christmas-weaving-challenge/r/call/019369da-1a3e-75b1-87cd-bdd728243790
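
The `final_score` returned by `FatherChristmas.predict` is the mean of four scores in the range 0 to 1 (story quality, story Christmas alignment, mean illustration story alignment, and mean illustration Christmas alignment), scaled to a 0 to 100 range. The sketch below reproduces that arithmetic with made-up numbers; the values are illustrative only and are not real judge outputs.

```python
from statistics import mean

# Hypothetical judge outputs (illustrative values only).
story_quality = 0.8
story_christmas_alignment = 0.6
# One (story_alignment, christmas_alignment) pair per illustrated paragraph.
illustration_scores = [(0.7, 0.5), (0.9, 0.4), (0.5, 0.6)]

# Per-paragraph illustration scores are averaged before being combined.
mean_story_alignment = mean(s for s, _ in illustration_scores)
mean_christmas_alignment = mean(c for _, c in illustration_scores)

# Same formula as in `FatherChristmas.predict`: mean of four scores, times 100.
final_score = (
    story_quality
    + story_christmas_alignment
    + mean_story_alignment
    + mean_christmas_alignment
) * 100 / 4
print(round(final_score, 2))
```

Because all four components carry equal weight, improving the illustrations matters as much as improving the story text.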
