Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
jaketae committed Jan 13, 2023
2 parents 725f7fb + a5a8637 commit c13340b
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 71 deletions.
64 changes: 54 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,80 @@ Given a prompt as an opening line of a story, GPT writes the rest of the plot; S

![out](https://user-images.githubusercontent.com/25360440/210071764-51ed5872-ba56-4ed0-919b-d9ce65110185.gif)

## Installation

## Quickstart
### PyPI

Story Teller is available on [PyPI](https://pypi.org/project/storyteller-core/).

```
$ pip install storyteller-core
```

### Source

1. Clone the repository.

```
$ git clone https://github.com/jaketae/storyteller.git
$ cd storyteller
```

2. Install dependencies.

```
$ pip install .
```

*Note for Apple M1 users, `mecab-python3` is not available for M1. You would need to install `mecab` before running the following commands. You can do this with brew `brew install mecab` and continue with the next steps. You can get more information [here](https://github.com/SamuraiT/mecab-python3/issues/84).*
*Note: For Apple M1/M2 users, [`mecab-python3`](https://github.com/SamuraiT/mecab-python3) is not available. You need to install `mecab` before running `pip install`. You can do this with [Homebrew](https://brew.sh) via `brew install mecab`. For more information, refer to [this issue](https://github.com/SamuraiT/mecab-python3/issues/84).*

2. Install package requirements.

3. (Optional) To develop locally, install `dev` dependencies and install pre-commit hooks. This will automatically trigger linting and code quality checks before each commit.

```
$ pip install --upgrade pip wheel
$ pip install -e .
# for dev requirements, do:
# pip install -e .[dev]
$ pip install -e .[dev]
$ pre-commit install
```

3. Run the demo. The final video will be saved as `/out/out.mp4`, alongside other intermediate images, audio files, and subtitles.
## Quickstart

The quickest way to run a demo is through the CLI. Simply type

```
$ storyteller
# alternatively with make, do:
# make run
```

The final video will be saved as `out/out.mp4`, alongside other intermediate images, audio files, and subtitles.

To override the defaults with custom parameters, pass the corresponding CLI flags as needed.

```
$ storyteller --help
usage: storyteller [-h] [--writer_prompt WRITER_PROMPT]
[--painter_prompt_prefix PAINTER_PROMPT_PREFIX] [--num_images NUM_IMAGES]
[--output_dir OUTPUT_DIR] [--seed SEED] [--max_new_tokens MAX_NEW_TOKENS]
[--writer WRITER] [--painter PAINTER] [--speaker SPEAKER]
[--writer_device WRITER_DEVICE] [--painter_device PAINTER_DEVICE]
optional arguments:
-h, --help show this help message and exit
--writer_prompt WRITER_PROMPT
--painter_prompt_prefix PAINTER_PROMPT_PREFIX
--num_images NUM_IMAGES
--output_dir OUTPUT_DIR
--seed SEED
--max_new_tokens MAX_NEW_TOKENS
--writer WRITER
--painter PAINTER
--speaker SPEAKER
--writer_device WRITER_DEVICE
--painter_device PAINTER_DEVICE
```

## Usage

For more advanced use cases, you can also directly interface with Story Teller in Python code.

1. Load the model with defaults.

```python
Expand Down
8 changes: 3 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "storyteller-core"
version = "0.0.1"
version = "0.0.2"
description = "Multimodal AI Story Teller, built with Stable Diffusion, GPT, and neural text-to-speech"
authors = ["Jaesung Tae <jaesungtae@gmail.com>"]
readme = "README.md"
Expand All @@ -14,12 +14,10 @@ soundfile = "^0.11.0"
tts = "^0.10.1"
diffusers = "^0.11.1"
transformers = "^4.25.1"

[tool.poetry.group.dev.dependencies]
pre-commit = "^2.21.0"
pre-commit = {version = "^2.21.0", extras = ["dev"]}

[tool.poetry.scripts]
storyteller = "storyteller.__main__:main"
storyteller = "storyteller.cli:main"

[build-system]
requires = ["poetry-core"]
Expand Down
23 changes: 0 additions & 23 deletions src/storyteller/__main__.py

This file was deleted.

47 changes: 47 additions & 0 deletions src/storyteller/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import argparse
import dataclasses
import logging
import os

from storyteller import StoryTeller, StoryTellerConfig
from storyteller.utils import set_log_level, set_seed


def get_args() -> argparse.Namespace:
    """Build the demo's command-line parser and parse the process arguments.

    Five run-level flags are registered first (``--writer_prompt``,
    ``--painter_prompt_prefix``, ``--num_images``, ``--output_dir`` and
    ``--seed``). After those, one flag is added per field of a
    default-constructed ``StoryTellerConfig``: the flag carries the field's
    name, defaults to the field's default value, and converts its input with
    the type of that default.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    cli = argparse.ArgumentParser()

    # Run-level options, written as (flag name, converter, default).
    run_options = [
        ("writer_prompt", str, "Once upon a time, unicorns roamed the Earth."),
        ("painter_prompt_prefix", str, "Beautiful painting"),
        ("num_images", int, 10),
        ("output_dir", str, "out"),
        ("seed", int, 42),
    ]

    # Model options are derived from the config dataclass, so every config
    # field gets a matching flag without being listed here by hand.
    config_defaults = dataclasses.asdict(StoryTellerConfig())
    model_options = [
        (field_name, type(default), default)
        for field_name, default in config_defaults.items()
    ]

    for flag, converter, default in run_options + model_options:
        cli.add_argument(f"--{flag}", type=converter, default=default)
    return cli.parse_args()


def main() -> None:
    """Run the story-generation demo from the command line.

    Parses the CLI flags, seeds the run with ``set_seed``, applies
    ``set_log_level(logging.WARNING)``, and sets ``TOKENIZERS_PARALLELISM``
    to ``"false"`` in the environment. Every ``StoryTellerConfig`` field is
    then overwritten with the parsed flag of the same name, a ``StoryTeller``
    is built from that config, the output directory is created if needed,
    and ``generate`` is called with the prompt, prompt prefix, image count
    and output directory taken from the flags.
    """
    cli_args = get_args()

    set_seed(cli_args.seed)
    set_log_level(logging.WARNING)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Start from the defaults, then copy each field's value from the flags.
    config = StoryTellerConfig()
    for field_name in [spec.name for spec in dataclasses.fields(config)]:
        setattr(config, field_name, getattr(cli_args, field_name))

    teller = StoryTeller(config)
    os.makedirs(cli_args.output_dir, exist_ok=True)
    teller.generate(
        cli_args.writer_prompt,
        cli_args.painter_prompt_prefix,
        cli_args.num_images,
        cli_args.output_dir,
    )


# Run the CLI only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()
11 changes: 4 additions & 7 deletions src/storyteller/config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from dataclasses import dataclass
from pathlib import Path

import torch


@dataclass
class StoryTellerConfig:
    """Settings read by ``StoryTeller`` when loading and running its models."""

    # Passed as both ``height`` and ``width`` when the painter is loaded.
    image_size: int = 512
    # Forwarded as ``max_new_tokens`` on every writer call.
    max_new_tokens: int = 50
    # Model name given to the "text-generation" pipeline.
    writer: str = "gpt2"
    # Model name given to ``StableDiffusionPipeline.from_pretrained``.
    painter: str = "stabilityai/stable-diffusion-2"
    # Model name given to ``TTS``.
    speaker: str = "tts_models/en/ljspeech/glow-tts"
    # NOTE(review): writer_device and painter_device are listed twice in this
    # view (presumably the before/after lines of the change); in a class body
    # the later assignment is the one that takes effect.
    writer_device: str = "cuda:0"
    painter_device: str = "cuda:0"
    # NOTE(review): annotated as ``str`` but the default is a ``pathlib.Path``.
    output_dir: str = Path(__file__).parent.parent / "out"
    seed: int = 42
    diffusion_prompt_prefix: str = "Beautiful painting"
    # Defaults are evaluated once, when the class body runs: CUDA device 0 if
    # available at that moment, otherwise the CPU.
    writer_device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
    painter_device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
44 changes: 18 additions & 26 deletions src/storyteller/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import os
from typing import List

Expand All @@ -15,69 +14,62 @@
make_timeline_string,
require_ffmpeg,
require_punkt,
set_seed,
subprocess_run,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.getLogger("diffusers").setLevel(logging.CRITICAL)
logging.getLogger("transformers").setLevel(logging.CRITICAL)


class StoryTeller:
@require_ffmpeg
@require_punkt
def __init__(self, config: StoryTellerConfig):
    """Load the writer, painter and speaker models described by ``config``.

    Args:
        config: Settings naming the three models to load and the devices
            the writer and the painter are placed on.
    """
    set_seed(config.seed)
    self.config = config
    os.makedirs(config.output_dir, exist_ok=True)
    writer_device = torch.device(config.writer_device)
    # The painter has its own device setting. This line used to read
    # ``config.writer_device``, which silently ignored ``painter_device``
    # and always placed the painter on the writer's device.
    painter_device = torch.device(config.painter_device)
    self.writer = pipeline(
        "text-generation", model=config.writer, device=writer_device
    )
    self.painter = StableDiffusionPipeline.from_pretrained(
        config.painter,
        height=self.config.image_size,
        width=self.config.image_size,
        use_auth_token=False,
    ).to(painter_device)
    self.speaker = TTS(config.speaker)
    # Kept on the instance so audio length in samples can be related to time.
    self.sample_rate = self.speaker.synthesizer.output_sample_rate
    # Filled in by ``generate`` before any output path is built.
    self.output_dir = None

@classmethod
def from_default(cls):
    """Alternate constructor: build an instance from ``StoryTellerConfig()`` defaults."""
    return cls(StoryTellerConfig())

@torch.inference_mode()
def paint(self, prompt) -> Image:
return self.painter(f"{self.config.diffusion_prompt_prefix}: {prompt}").images[
0
]
def paint(self, prompt: str) -> Image:
return self.painter(prompt).images[0]

@torch.inference_mode()
def speak(self, prompt) -> List[int]:
def speak(self, prompt: str) -> List[int]:
return self.speaker.tts(prompt)

@torch.inference_mode()
def write(self, prompt) -> str:
def write(self, prompt: str) -> str:
return self.writer(prompt, max_new_tokens=self.config.max_new_tokens)[0][
"generated_text"
]

def get_output_path(self, file):
return os.path.join(self.config.output_dir, file)
def get_output_path(self, file: str) -> str:
return os.path.join(self.output_dir, file)

def generate(
self,
prompt: str,
writer_prompt: str,
painter_prompt_prefix: str,
num_images: int,
output_dir: str,
) -> None:
video_paths = []
sentences = self.write_story(prompt, num_images)
self.output_dir = output_dir
sentences = self.write_story(writer_prompt, num_images)
for i, sentence in enumerate(sentences):
video_path = self._generate(i, sentence)
video_path = self._generate(i, sentence, painter_prompt_prefix)
video_paths.append(video_path)
self.concat_videos(video_paths)

Expand All @@ -89,12 +81,12 @@ def concat_videos(self, video_paths: List[str]) -> None:
f.write(f"file {os.path.split(video_path)[-1]}\n")
subprocess_run(f"ffmpeg -f concat -i {files_path} -c copy {output_path}")

def _generate(self, id_: int, sentence: str) -> str:
def _generate(self, id_: int, sentence: str, painter_prompt_prefix: str) -> str:
image_path = self.get_output_path(f"{id_}.png")
audio_path = self.get_output_path(f"{id_}.wav")
subtitle_path = self.get_output_path(f"{id_}.srt")
video_path = self.get_output_path(f"{id_}.mp4")
image = self.paint(sentence)
image = self.paint(f"{painter_prompt_prefix} {sentence}")
image.save(image_path)
audio = self.speak(sentence)
duration, remainder = divmod(len(audio), self.sample_rate)
Expand All @@ -110,11 +102,11 @@ def _generate(self, id_: int, sentence: str) -> str:
)
return video_path

def write_story(self, prompt: str, num_sentences: int) -> List[str]:
def write_story(self, writer_prompt: str, num_sentences: int) -> List[str]:
sentences = []
while len(sentences) < num_sentences + 1:
prompt = self.write(prompt)
sentences = sent_tokenize(prompt)
writer_prompt = self.write(writer_prompt)
sentences = sent_tokenize(writer_prompt)
while len(sentences) > num_sentences:
sentences.pop()
return sentences
6 changes: 6 additions & 0 deletions src/storyteller/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
import random
import shutil
Expand Down Expand Up @@ -66,3 +67,8 @@ def set_seed(seed):
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True


def set_log_level(level: int) -> None:
    """Silence log records whose severity is ``level`` or lower.

    Thin wrapper over ``logging.disable``; pass ``logging.NOTSET`` to lift
    the restriction again.

    Args:
        level: Highest severity to suppress, e.g. ``logging.WARNING``.
    """
    logging.disable(level=level)

0 comments on commit c13340b

Please sign in to comment.