## Init Stable Diffusion

In [None]:
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionUpscalePipeline
import torch

def initdiffusionpipeline(model_id):
    """Build a Stable Diffusion pipeline for ``model_id`` and place it on CUDA.

    Args:
        model_id: Model identifier handed to
            ``StableDiffusionPipeline.from_pretrained``.

    Returns:
        The result of moving the loaded pipeline to ``"cuda"``; its scheduler
        has been replaced by a ``DPMSolverMultistepScheduler``.
    """
    sd = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
    )
    # Swap in the DPMSolverMultistepScheduler (DPM-Solver++), built from the
    # configuration of the scheduler the pipeline was loaded with.
    solver = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
    sd.scheduler = solver
    return sd.to("cuda")

# Module-level text-to-image pipeline; read as a global by generateimage().
diffusionpipeline = initdiffusionpipeline(model_id="stabilityai/stable-diffusion-2-1")


In [None]:
def initupscalepipeline(model_id):
    """Build a Stable Diffusion upscale pipeline for ``model_id`` on CUDA.

    Args:
        model_id: Model identifier handed to
            ``StableDiffusionUpscalePipeline.from_pretrained``.

    Returns:
        The result of moving the loaded pipeline to ``"cuda"``, with attention
        slicing enabled.
    """
    upscaler = StableDiffusionUpscalePipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
    )
    # Attention slicing is switched on before the pipeline is moved to the GPU.
    upscaler.enable_attention_slicing()
    return upscaler.to("cuda")

# Module-level x4 upscaling pipeline.
upscalepipeline = initupscalepipeline(model_id="stabilityai/stable-diffusion-x4-upscaler")

In [None]:
def generateimage(prompt,seed=42):
    """Generate an image for ``prompt``, save it under ``images/`` and show it.

    The image is produced by the module-level ``diffusionpipeline`` using a
    CUDA random generator seeded with ``seed``, so the same prompt and seed
    use the same starting random state.

    Args:
        prompt: Text prompt passed to the diffusion pipeline. It is also used
            (after sanitizing) as the output file name.
        seed: Seed for the ``torch.Generator``; defaults to 42.

    Returns:
        The first entry of the pipeline result's ``images`` list.
    """
    # Local imports keep this cell runnable regardless of which other cells
    # have already been executed.
    import os
    import re

    generator = torch.Generator("cuda").manual_seed(seed)
    image = diffusionpipeline(prompt, generator=generator).images[0]

    # Build a file-system-safe stem from the prompt. Spaces become dashes (as
    # before); path separators and characters that are invalid in file names on
    # common platforms are also replaced, so a prompt containing "/" no longer
    # points into a non-existent sub-directory.
    stem = re.sub(r'[<>:"/\\|?*\x00-\x1f]+', "-", prompt.replace(" ", "-"))
    # LLM-extended prompts can be long; keep the name well below the usual
    # 255-character file-name limit. Fall back to a fixed name for an empty prompt.
    stem = stem[:200] or "image"

    # The output directory may not exist on a fresh checkout.
    os.makedirs("images", exist_ok=True)
    fname = f"images/{stem}.png"
    image.save(fname)

    # NOTE(review): `display` is not imported in this file; presumably the
    # IPython/Jupyter built-in — this function is expected to run in a notebook.
    display(image)
    return image

## Init LLM

In [None]:
from langchain.llms import HuggingFaceHub
from langchain import PromptTemplate, LLMChain
import os
from dotenv import load_dotenv, find_dotenv

load_dotenv(find_dotenv()) # take environment variables from .env.

# Fail fast with an explicit error instead of `assert`: assertions are removed
# when Python runs with -O, and an empty token is as unusable as a missing one.
if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
    raise RuntimeError(
        "HUGGINGFACEHUB_API_TOKEN is not set; define it in your .env file "
        "or in the environment before running this notebook."
    )

In [None]:
def initllm(model_id):
    """Create a ``HuggingFaceHub`` LLM wrapper for the given model repository.

    Args:
        model_id: Repository id passed to ``HuggingFaceHub`` as ``repo_id``.

    Returns:
        A ``HuggingFaceHub`` instance configured with temperature 0.1 and a
        ``max_length`` of 64.
    """
    # See https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads for some other options
    generation_kwargs = {"temperature": 0.1, "max_length": 64}
    return HuggingFaceHub(repo_id=model_id, model_kwargs=generation_kwargs)

# Other LLMs tuned for Stable Diffusion prompting can be found here:
# https://huggingface.co/models?search=stable%20diffusion%20prompt
# Alternative: llm = initllm("DrishtiSharma/StableDiffusion-Prompt-Generator-GPT-Neo-125M")
# Module-level prompt-generator LLM; read as a global by generateprompt().
llm = initllm(model_id="Ar4ikov/gpt2-650k-stable-diffusion-prompt-generator")

In [None]:
def generateprompt(basicprompt):
    """Feed ``basicprompt`` to the module-level ``llm`` and return its output.

    The generated text is also printed so it is visible in the notebook.

    Args:
        basicprompt: Prompt text passed to ``llm``.

    Returns:
        Whatever ``llm(basicprompt)`` returns.
    """
    generated = llm(basicprompt)
    print(generated)
    return generated

## Run

In [None]:
# Base prompt used for both the plain and the LLM-extended image.
myprompt = "a photo of an astronaut riding a horse on mars"

In [None]:

# Baseline image generated from the unmodified prompt.
img = generateimage(prompt=myprompt)

In [None]:
# Append the LLM output directly to the base prompt (no separator is inserted).
betterprompt = myprompt + generateprompt(basicprompt=myprompt)

In [None]:
# Image generated from the LLM-extended prompt, for comparison with `img`.
betterimg = generateimage(prompt=betterprompt)