Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEATURE: add support for final stable diffusion xl model #122

Merged
merged 1 commit into from Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions config/settings.yml
Expand Up @@ -103,10 +103,11 @@ plugins:
ai_stability_api_url:
default: "https://api.stability.ai"
ai_stability_engine:
default: "stable-diffusion-xl-beta-v2-2-2"
default: "stable-diffusion-xl-1024-v1-0"
type: enum
choices:
- "stable-diffusion-xl-beta-v2-2-2"
- "stable-diffusion-xl-1024-v1-0"
- "stable-diffusion-768-v2-1"
- "stable-diffusion-v1-5"
ai_hugging_face_api_url:
default: ""
Expand Down
33 changes: 29 additions & 4 deletions lib/shared/inference/stability_generator.rb
Expand Up @@ -3,20 +3,40 @@
module ::DiscourseAi
module Inference
class StabilityGenerator
def self.perform!(prompt)
def self.perform!(prompt, width: nil, height: nil)
headers = {
"Referer" => Discourse.base_url,
"Content-Type" => "application/json",
"Accept" => "application/json",
"Authorization" => "Bearer #{SiteSetting.ai_stability_api_key}",
}

sdxl_allowed_dimentions = [
[1024, 1024],
[1152, 896],
[1216, 832],
[1344, 768],
[1536, 640],
[640, 1536],
[768, 1344],
[832, 1216],
[896, 1152],
]

if (!width && !height)
if SiteSetting.ai_stability_engine.include? "xl"
width, height = sdxl_allowed_dimentions[0]
else
width, height = [512, 512]
end
end

payload = {
text_prompts: [{ text: prompt }],
cfg_scale: 7,
clip_guidance_preset: "FAST_BLUE",
height: 512,
width: 512,
height: width,
width: height,
samples: 4,
steps: 30,
}
Expand All @@ -27,7 +47,12 @@ def self.perform!(prompt)

response = Faraday.post("#{base_url}/#{endpoint}", payload.to_json, headers)

raise Net::HTTPBadResponse if response.status != 200
if response.status != 200
Rails.logger.error(
"AI stability generator failed with status #{response.status}: #{response.body}}",
)
raise Net::HTTPBadResponse
end

JSON.parse(response.body, symbolize_names: true)
end
Expand Down
50 changes: 50 additions & 0 deletions spec/shared/inference/stability_generator_spec.rb
@@ -0,0 +1,50 @@
# frozen_string_literal: true
require "rails_helper"

describe DiscourseAi::Inference::StabilityGenerator do
def gen(prompt)
DiscourseAi::Inference::StabilityGenerator.perform!(prompt)
end

it "sets dimentions to 512x512 for non XL model" do
SiteSetting.ai_stability_engine = "stable-diffusion-v1-5"
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
SiteSetting.ai_stability_api_key = "123"

stub_request(:post, "http://www.a.b.c/v1/generation/stable-diffusion-v1-5/text-to-image")
.with do |request|
json = JSON.parse(request.body)
expect(json["text_prompts"][0]["text"]).to eq("a cow")
expect(json["width"]).to eq(512)
expect(json["height"]).to eq(512)
expect(request.headers["Authorization"]).to eq("Bearer 123")
expect(request.headers["Content-Type"]).to eq("application/json")
true
end
.to_return(status: 200, body: "{}", headers: {})

gen("a cow")
end

it "sets dimentions to 1024x1024 for XL model" do
SiteSetting.ai_stability_engine = "stable-diffusion-xl-1024-v1-0"
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
SiteSetting.ai_stability_api_key = "123"
stub_request(
:post,
"http://www.a.b.c/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image",
)
.with do |request|
json = JSON.parse(request.body)
expect(json["text_prompts"][0]["text"]).to eq("a cow")
expect(json["width"]).to eq(1024)
expect(json["height"]).to eq(1024)
expect(request.headers["Authorization"]).to eq("Bearer 123")
expect(request.headers["Content-Type"]).to eq("application/json")
true
end
.to_return(status: 200, body: "{}", headers: {})

gen("a cow")
end
end