In [None]:
import torch
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
import os

# Environment sanity checks. Neither value is rendered: a notebook only shows
# the last expression of a cell, and neither of these calls is last.
os.getlogin()
os.path.expanduser("~")

from requests import get

# A timeout keeps the cell from hanging forever when the lookup service is
# unreachable, and raise_for_status() stops an HTTP error body from being
# printed as though it were the address.
response = get('https://api.ipify.org', timeout=10)
response.raise_for_status()

ip = response.content.decode('utf8')
print('My public IP address is: {}'.format(ip))

In [None]:
from diffusers import DPMSolverMultistepScheduler, DiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Scheduler, text encoder and tokenizer are each pulled from their own
# subfolder of the same repository and then handed to the pipeline explicitly.
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
# NOTE(review): no torch_dtype is requested for the text encoder, unlike the
# pipeline below - confirm the mixed precision is intended.
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=scheduler,
)
# Move the assembled pipeline onto the Apple "mps" device.
pipeline = pipeline.to("mps")


In [None]:
# Run the base pipeline for 20 denoising steps and keep the first returned image.
tree_output = pipeline(prompt="chua mia tee painting, tree", num_inference_steps=20)
image = tree_output.images[0]
image

In [None]:
from diffusers import AutoPipelineForText2Image
import torch

# Pick the accelerator at runtime instead of hard-coding "cuda": the other
# pipelines in this notebook run on "mps", so a fixed "cuda" made it impossible
# to run every cell on one machine. CUDA is still preferred when it is present.
if torch.cuda.is_available():
    lora_device = "cuda"
elif torch.backends.mps.is_available():
    lora_device = "mps"
else:
    lora_device = "cpu"

# Half precision is kept on accelerators; on CPU fall back to float32, where
# float16 support is incomplete.
lora_dtype = torch.float32 if lora_device == "cpu" else torch.float16

cpipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=lora_dtype
).to(lora_device)

# Attach the LoRA weights from the "heypoom/chuamiatee-1" repository to the
# freshly loaded pipeline.
cpipe.load_lora_weights(
    "heypoom/chuamiatee-1",
    weight_name="pytorch_lora_weights.safetensors"
)

In [None]:
# Generate with the LoRA-equipped pipeline (100 denoising steps) and keep the
# first returned image.
flags_output = cpipe(prompt="chua mia tee painting, flags", num_inference_steps=100)
image = flags_output.images[0]
image

In [None]:
# Persist whatever `image` currently holds as a PNG in the working directory.
image.save("./chua_mia_tee.png")

In [None]:
# Preview images appended by denoising_callback, one per invocation. Re-running
# this cell starts again from an empty list.
denoised_images = []

def latents_to_rgb(latents):
    """Build a quick RGB preview image from latents using a fixed linear map.

    Every output colour channel is a weighted sum of the four latent channels
    plus a per-colour bias. The result is clamped to 0-255 and converted to
    8-bit before being wrapped in a PIL image.

    Args:
        latents: Tensor whose last three axes are (latent_channel, height,
            width) with four latent channels. Either one sample ``(4, H, W)``
            or a batch ``(B, 4, H, W)``; for a batch only the first sample is
            rendered.

    Returns:
        A PIL image built from an ``(H, W, 3)`` uint8 array.

    Raises:
        ValueError: If ``latents`` is neither 3-D nor 4-D.
    """
    if latents.ndim not in (3, 4):
        raise ValueError(
            "latents must have shape (4, H, W) or (B, 4, H, W), got {}".format(
                tuple(latents.shape)
            )
        )

    # One row per output colour (R, G, B), one column per latent channel.
    # NOTE(review): these coefficients look like the SDXL latent-to-RGB
    # approximation from the diffusers callback guide - confirm the source.
    weights = (
        (60, -60, 25, -70),
        (60,  -5, 15, -50),
        (60,  10, -5, -35)
    )

    # Transposed to (latent_channel, colour) so the einsum can contract over
    # the latent-channel axis.
    weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
    rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)

    # A batched input previews its first sample; an unbatched one is already
    # (3, H, W) and must not be indexed again.
    if rgb_tensor.ndim == 4:
        rgb_tensor = rgb_tensor[0]

    image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy()
    # Channels-first (3, H, W) -> channels-last (H, W, 3) for PIL.
    image_array = image_array.transpose(1, 2, 0)

    return Image.fromarray(image_array)

def denoising_callback(pipe, step, timestep, callback_kwargs):
    """Record an RGB preview of the current latents, then pass the kwargs through.

    Intended for use as a pipeline ``callback_on_step_end`` hook. ``pipe``,
    ``step`` and ``timestep`` are accepted but not used.

    Args:
        pipe: The calling pipeline (unused).
        step: Index of the step that just finished (unused).
        timestep: Scheduler timestep of that step (unused).
        callback_kwargs: Mapping that must contain a ``"latents"`` entry.

    Returns:
        The same ``callback_kwargs`` object, unmodified.
    """
    preview = latents_to_rgb(callback_kwargs["latents"])
    denoised_images.append(preview)
    return callback_kwargs
    
prompt = "a rabbit sleeps"

# Short 5-step run whose only purpose is to feed denoising_callback: the
# "latents" tensor is requested so the callback can render a preview per step.
pipeline_result = pipeline(
    prompt,
    guidance_scale=7.5,
    num_inference_steps=5,
    callback_on_step_end_tensor_inputs=['latents'],
    callback_on_step_end=denoising_callback,
)


In [None]:
# Show every captured denoising preview side by side, in capture order.
preview_count = len(denoised_images)

if preview_count == 0:
    # plt.subplots rejects a zero-column grid, so report the empty state
    # instead of failing with an unrelated-looking error.
    print("No denoising previews captured - run the callback cell first.")
else:
    # squeeze=False always returns a 2-D array of axes, so a single preview
    # (where the default would hand back a bare Axes) is indexed the same way
    # as many previews.
    fig, axs = plt.subplots(1, preview_count, figsize=(20, 4), squeeze=False)
    axs = axs[0]

    for ax, preview in zip(axs, denoised_images):
        ax.imshow(preview)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
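# NOTE: disabled experiment, kept for reference only. It uses `malaya`, which is
# loaded in a later cell, so run that cell first if this is ever re-enabled.
# from diffusers import StableDiffusionImageVariationPipeline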
# 
# variation_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
#   "lambdalabs/sd-image-variations-diffusers",
#   revision="v2.0",
#   torch_dtype=torch.float16,
# ).to("mps")
# 
# v_out = variation_pipe(malaya, num_images_per_prompt=5, num_inference_steps=50, guidance_scale=0)

In [None]:
# Source picture for the img2img cells below; read from the working directory.
malaya = Image.open("./malaya.png")

In [None]:
from diffusers import StableDiffusionImg2ImgPipeline

# Image-to-image pipeline loaded with its default settings, then moved to the
# Apple "mps" device.
img2img_model_id = "runwayml/stable-diffusion-v1-5"
img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(img2img_model_id)
img_pipe = img_pipe.to("mps")

In [None]:
# Scratch preview: six values spaced 1.0 apart, starting at 0.5.
[0.5 + step for step in range(6)]

In [None]:
# Scratch preview: six evenly spaced fractions from 0.0 to 1.0, the strength
# values swept by the img2img loop.
[numerator / 5 for numerator in range(6)]

In [None]:
# Results of the sweep, appended guidance-major: all strengths for the first
# guidance value, then all strengths for the next. Reset on every run.
out_images = []

Gc = 2      # number of guidance-scale settings
Sc = 6      # number of strength settings
SIZE = 512  # square edge length, in pixels, of the image fed to the pipeline

# Loop-invariant: resize and convert the source picture once instead of
# repeating the same work for every pipeline call.
init_image = malaya.resize((SIZE, SIZE)).convert("RGB")

# Strength runs in even steps from 0 to 1 across the Sc settings. Deriving the
# divisor from Sc (5 for Sc = 6) keeps the values inside [0, 1] if Sc changes;
# max() avoids a division by zero when Sc is 1.
# NOTE(review): the first setting is strength = 0 - confirm the pipeline
# accepts it.
strength_divisor = max(Sc - 1, 1)

for g in range(Gc):
    for s in range(Sc):
        result = img_pipe(
            prompt="a dream",
            image=init_image,
            strength=s / strength_divisor,
            guidance_scale=g + 5.5,
            num_images_per_prompt=1,
            num_inference_steps=20,
        )

        out_images.append(result.images[0])

In [None]:
import matplotlib.pyplot as plt

# Lay the sweep results out as a Gc-row by Sc-column grid.
fig = plt.figure(figsize=(10, 10))  # width, height in inches

for index in range(Gc * Sc):
    # Subplot slots are 1-based and fill row by row, which matches the order
    # in which out_images was appended.
    ax = plt.subplot(Gc, Sc, index + 1)  # nrows, ncols, index
    ax.imshow(out_images[index])
    ax.axis('off')

plt.tight_layout()
plt.show()