In [6]:
!. ./venv/bin/activate
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [233]:
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
import uuid
import io
import boto3
import random
from tqdm import tqdm
from datetime import datetime

In [209]:
import toml
# Runtime settings shared by the cells below; this notebook reads the
# "redis", "hugging_face" and "r2" sections from it.
config = toml.load("config.toml")

In [225]:
import redis
from redis.commands.json.path import Path
# Connection settings come from the "redis" section of config.toml.
redis_config = config["redis"]

# Shared Redis connection; the batch loop uses it to store the metadata of
# every generated image.
client = redis.Redis(**{
    "host": redis_config["host"],
    "port": int(redis_config["port"]),  # port is coerced to int before connecting
    "password": redis_config["password"],
})

In [15]:
# Device used for the pipeline and for the latent generator in create_latents.
device = 'cuda'

# NOTE(review): this scheduler is constructed but never handed to the pipeline
# below, so the pipeline keeps the scheduler that ships with the checkpoint.
# Pass `scheduler=lms` to from_pretrained if LMS sampling was the intent
# (doing so changes the generated images, so it is left as-is here).
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    # The loading option is spelled `torch_dtype`; the previous `torch_type`
    # keyword is not a recognised option, so the half-precision request never
    # took effect.
    torch_dtype=torch.float16,
    use_auth_token=config["hugging_face"]["api_key"]
).to(device)

# Trade some speed for a lower peak memory footprint during attention.
pipe.enable_attention_slicing()


def dummy(images, **kwargs):
    """Pass-through replacement for the pipeline's safety checker.

    Returns the images untouched together with a constant ``False`` flag;
    any extra keyword arguments are accepted and ignored.
    """
    return images, False


# Replace the built-in safety checker so no generated image is filtered out.
pipe.safety_checker = dummy

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


In [202]:
def create_latents(width, height, seed):
    """Build the initial noise tensor for one image.

    Args:
        width: Target image width in pixels; the latent width is ``width // 8``.
        height: Target image height in pixels; the latent height is ``height // 8``.
        seed: Seed for the noise generator. ``0`` or ``None`` means "let the
            generator pick a fresh seed".

    Returns:
        A ``(seed_used, latents)`` tuple, where ``latents`` has shape
        ``(1, pipe.unet.in_channels, height // 8, width // 8)`` and lives on
        the module-level ``device``.
    """
    rng = torch.Generator(device=device)

    if seed is None or seed == 0:
        # No usable seed supplied: ask the generator for one and remember it.
        new_seed = rng.seed()
    else:
        new_seed = seed
    rng = rng.manual_seed(new_seed)

    latent_shape = (1, pipe.unet.in_channels, height // 8, width // 8)
    image_latents = torch.randn(latent_shape, device=device, generator=rng)

    return new_seed, image_latents


def send(image, file_prefix):
    """Encode ``image`` as JPEG in memory and upload it to the R2 bucket.

    Args:
        image: Object exposing ``save(fileobj, format=...)`` (the pipeline's
            output image is passed in by the caller).
        file_prefix: Text placed before the random UUID in the object key.

    Returns:
        The object key that was uploaded, ``"<file_prefix>_<uuid4>.jpg"``.
    """
    filename = f"{file_prefix}_{uuid.uuid4()}.jpg"

    # Serialise to an in-memory buffer and rewind it so the upload reads
    # from the first byte.
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG')
    buffer.seek(0)

    r2 = config["r2"]
    s3 = boto3.client(
        service_name='s3',
        endpoint_url=r2["endpoint"],
        aws_access_key_id=r2["key"],
        aws_secret_access_key=r2["secret"],
    )

    # NOTE(review): BUCKET is not defined in any cell visible in this
    # notebook; presumably it comes from earlier kernel state - confirm and
    # define it explicitly.
    s3.upload_fileobj(
        buffer,
        Bucket=BUCKET,
        Key=filename,
        ExtraArgs={'ContentType': 'image/jpeg'},
    )

    return filename

In [195]:
# Character and image options; the cells below read the "classes",
# "physical_attributes", "prompt" and "image_attributes" sections.
attributes = toml.load("attributes.toml")

def random_prompt():
    """Assemble a random character description from ``attributes``.

    Returns:
        A ``(prompt, metadata)`` tuple. ``prompt`` is the character sentence
        followed by ``" :: "`` and the configured style; ``metadata`` holds
        the chosen class, subclass, gender, physique, hair style (with the
        colour already substituted in) and armor.
    """
    # The draws happen in a fixed order so that a seeded `random` module
    # always produces the same character.
    class_name = random.choice(attributes["classes"])
    subclass = random.choice(class_name["subclasses"])

    physical = attributes["physical_attributes"]
    gender = random.choice(physical["gender"])
    physique = random.choice(physical["physique"])
    hair_template = random.choice(physical["hair_style"])
    hair_color = random.choice(physical["hair_color"])
    armor = random.choice(class_name["armor_types"])
    ethnicity = random.choice(physical["ethnicity"])

    # Hair styles carry a "{color}" placeholder that takes the drawn colour.
    hair_style = hair_template.replace("{color}", hair_color)

    metadata = {
        "class": class_name["base"],
        "subclass": subclass,
        "gender": gender,
        "physique": physique,
        "hair_style": hair_style,
        "armor": armor,
    }

    character = f"a {physique} {gender} {ethnicity} {subclass} with {hair_style}, wearing cyberpunk inspired {armor}"
    style = attributes["prompt"]["style"]

    return f"{character} :: {style}", metadata

In [196]:
def generate_image():
    """Render one randomly described character.

    Draws a prompt with ``random_prompt``, reads the render settings from
    ``attributes["image_attributes"]``, seeds fresh latents and runs the
    pipeline under CUDA autocast.

    Returns:
        An ``(image, punk)`` tuple: the first image produced by the pipeline
        and the metadata dict describing the character.
    """
    prompt, punk = random_prompt()

    settings = attributes["image_attributes"]
    width = settings["width"]
    height = settings["height"]
    steps = settings["steps"]
    guidance_scale = settings["guidance_scale"]

    seed = random.randint(0, 1000000)
    # The seed value returned by create_latents is not used any further here.
    image_seed, latents = create_latents(width, height, seed)

    with autocast("cuda"):
        result = pipe(
            prompt,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            num_inference_steps=steps,
            latents=latents,
        )
        image = result["images"][0]

    return image, punk


In [235]:
# Switch between the two ways of running the generator:
#   True  -> serve a small Gradio page that renders one character per request
#   False -> batch-generate characters, upload them and index them in Redis
web_interface = False

if not web_interface:
    for n in tqdm(range(3000)):
        image, punk = generate_image()
        filename = send(image, "cyberpunk")

        # The Redis key is the uploaded file name up to its first dot.
        key = filename.partition(".")[0]
        punk.update({
            "image": config["r2"]["public_access"] + filename,
            "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        })
        client.json().set(key, Path.root_path(), punk)

else:
    # Imported only when the UI is requested, so batch runs do not need gradio.
    import gradio as gr

    # Shut down any interface left over from a previous run of this cell.
    gr.close_all()

    server = gr.Interface(
        fn=generate_image,
        title="Random Cyberpunk",
        inputs=[],
        outputs=["image", "json"],
        allow_flagging=False
    )
    server.launch(share=False)

  0%|                                                                                                  | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/51 [00:00<?, ?it/s]

  0%|                                                                                        | 1/3000 [00:11<9:24:36, 11.30s/it]

  0%|          | 0/51 [00:00<?, ?it/s]

  0%|                                                                                        | 2/3000 [00:22<9:27:22, 11.35s/it]

  0%|          | 0/51 [00:00<?, ?it/s]

  0%|                                                                                        | 3/3000 [00:33<9:25:29, 11.32s/it]

  0%|          | 0/51 [00:00<?, ?it/s]

  0%|                                                                                        | 4/3000 [00:45<9:29:54, 11.41s/it]

  0%|          | 0/51 [00:00<?, ?it/s]

  0%|▏                                                                                       | 5/3000 [00:56<9:24:06, 11.30s/it]

  0%|          | 0/51 [00:00<?, ?it/s]

  0%|▏                                                                                       | 6/3000 [01:06<9:07:57, 10.98s/it]

  0%|          | 0/51 [00:00<?, ?it/s]

  0%|▏                                                                                       | 7/3000 [01:18<9:09:05, 11.01s/it]

  0%|          | 0/51 [00:00<?, ?it/s]