<a href="https://colab.research.google.com/github/lshus/stitchdiffusion-colab/blob/main/colab_stitchdiffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# StitchDiffusion Colab for synthesizing 360-degree panoramic images
## (Four steps in total; run them in order.)

In [1]:
# @title ## 1. install base environment
# @markdown Clone Kohya Trainer from GitHub (Be patient, it requires several minutes.)
import os
import zipfile
import shutil
import time
from subprocess import getoutput
from IPython.utils import capture
from google.colab import drive

# Restore any variables previously persisted with %store (IPython storemagic).
%store -r

# Directories that live directly under the notebook root.
# root_dir
root_dir = "/content"
deps_dir = os.path.join(root_dir, "deps")
repo_dir = os.path.join(root_dir, "kohya-trainer")
training_dir = os.path.join(root_dir, "LoRA")
pretrained_model = os.path.join(root_dir, "pretrained_model")
vae_dir = os.path.join(root_dir, "vae")
config_dir = os.path.join(training_dir, "config")

# Paths inside the cloned kohya-trainer repository.
# repo_dir
accelerate_config = os.path.join(repo_dir, "accelerate_config/config.yaml")
tools_dir = os.path.join(repo_dir, "tools")
finetune_dir = os.path.join(repo_dir, "finetune")

# Persist every path variable with %store so that later cells can bring them
# back with `%store -r` (step 4 reads repo_dir this way).
for store in [
    "root_dir",
    "deps_dir",
    "repo_dir",
    "training_dir",
    "pretrained_model",
    "vae_dir",
    "accelerate_config",
    "tools_dir",
    "finetune_dir",
    "config_dir",
]:
    # The magic's output is captured and discarded so the cell stays quiet.
    with capture.capture_output() as cap:
        %store {store}
        del cap

repo_url = "https://github.com/Linaqruf/kohya-trainer"
# NOTE(review): the variable name is misspelt ("bitsandytes") but main() refers
# to it under this exact name, so it must be renamed in both places together.
# NOTE(review): the path is tied to Python 3.10; it will not exist on a runtime
# that ships a different Python version - confirm against the target runtime.
bitsandytes_main_py = "/usr/local/lib/python3.10/dist-packages/bitsandbytes/cuda_setup/main.py"
# The form fields below are disabled; the hard-coded values that follow are
# the ones actually used by the functions in this cell.
# branch = ""  # @param {type: "string"}
# install_xformers = True  # @param {'type':'boolean'}
# mount_drive = False  # @param {type: "boolean"}
# verbose = False # @param {type: "boolean"}
branch = ""  # branch or commit to check out; empty keeps the cloned default
install_xformers = True  # also install xformers/triton in install_dependencies()
mount_drive = False  # mount Google Drive at /content/drive in main()
verbose = False  # when False, pip/apt are invoked with their quiet flags

def read_file(filename):
    """Return the entire text content of ``filename``.

    The file is opened in text mode with the platform default encoding and
    is closed before the function returns.
    """
    with open(filename, "r") as source:
        return source.read()


def write_file(filename, contents):
    """Replace the content of ``filename`` with the string ``contents``.

    The file is created if it does not exist and truncated if it does.
    Nothing is returned.
    """
    with open(filename, "w") as sink:
        sink.write(contents)


def clone_repo(url):
    if not os.path.exists(repo_dir):
        os.chdir(root_dir)
        !git clone {url} {repo_dir}
    else:
        os.chdir(repo_dir)
        !git pull origin {branch} if branch else !git pull


def install_dependencies():
    s = getoutput('nvidia-smi')

    if 'T4' in s:
        !sed -i "s@cpu@cuda@" library/model_util.py

    !pip install {'-q' if not verbose else ''} --upgrade -r requirements.txt
    !pip install {'-q' if not verbose else ''} torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 torchtext==0.15.1 torchdata==0.6.0 --extra-index-url https://download.pytorch.org/whl/cu118 -U

    if install_xformers:
        !pip install {'-q' if not verbose else ''} xformers==0.0.19 triton==2.0.0 -U

    from accelerate.utils import write_basic_config

    if not os.path.exists(accelerate_config):
        write_basic_config(save_location=accelerate_config)


def remove_bitsandbytes_message(filename):
    """Make the bitsandbytes welcome banner in ``filename`` optional.

    The unconditional banner printed at the top of ``evaluate_cuda_setup`` is
    replaced with a version that is skipped when the environment variable
    ``BITSANDBYTES_NOWELCOME`` is set to anything other than ``0``.

    The patch is purely cosmetic, so it is best-effort: if ``filename`` does
    not exist the function prints a short notice and returns, and the file is
    rewritten only when the banner was actually found. Calling it again on an
    already patched file is a no-op.

    Args:
        filename: Path of the bitsandbytes ``cuda_setup/main.py`` to patch.
    """
    welcome_message = """
def evaluate_cuda_setup():
    print('')
    print('='*35 + 'BUG REPORT' + '='*35)
    print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
    print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
    print('='*80)"""

    new_welcome_message = """
def evaluate_cuda_setup():
    import os
    if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
        print('')
        print('=' * 35 + 'BUG REPORT' + '=' * 35)
        print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
        print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
        print('To hide this message, set the BITSANDBYTES_NOWELCOME variable like so: export BITSANDBYTES_NOWELCOME=1')
        print('=' * 80)"""

    try:
        with open(filename, "r") as f:
            contents = f.read()
    except FileNotFoundError:
        # The target path depends on the installed Python/bitsandbytes
        # layout; a missing file must not abort the whole setup.
        print(f"bitsandbytes file not found, skipping banner patch: {filename}")
        return

    new_contents = contents.replace(welcome_message, new_welcome_message)

    # Leave the file untouched when there was nothing to replace.
    if new_contents != contents:
        with open(filename, "w") as f:
            f.write(new_contents)


def main():
    os.chdir(root_dir)

    if mount_drive:
        if not os.path.exists("/content/drive"):
            drive.mount("/content/drive")

    for dir in [
        deps_dir,
        training_dir,
        config_dir,
        pretrained_model,
        vae_dir
    ]:
        os.makedirs(dir, exist_ok=True)

    clone_repo(repo_url)

    if branch:
        os.chdir(repo_dir)
        status = os.system(f"git checkout {branch}")
        if status != 0:
            raise Exception("Failed to checkout branch or commit")

    os.chdir(repo_dir)

    !apt install aria2 {'-qq' if not verbose else ''}

    install_dependencies()
    time.sleep(3)

    remove_bitsandbytes_message(bitsandytes_main_py)

    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    os.environ["SAFETENSORS_FAST_GPU"] = "1"

    cuda_path = "/usr/local/cuda-11.8/targets/x86_64-linux/lib/"
    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
    os.environ["LD_LIBRARY_PATH"] = f"{ld_library_path}:{cuda_path}"

main()


Cloning into '/content/kohya-trainer'...
remote: Enumerating objects: 2500, done.[K
remote: Counting objects: 100% (1166/1166), done.[K
remote: Compressing objects: 100% (361/361), done.[K
remote: Total 2500 (delta 898), reused 950 (delta 805), pack-reused 1334[K
Receiving objects: 100% (2500/2500), 4.90 MiB | 11.28 MiB/s, done.
Resolving deltas: 100% (1655/1655), done.
The following additional packages will be installed:
  libaria2-0 libc-ares2
The following NEW packages will be installed:
  aria2 libaria2-0 libc-ares2
0 upgraded, 3 newly installed, 0 to remove and 10 not upgraded.
Need to get 1,513 kB of archives.
After this operation, 5,441 kB of additional disk space will be used.
Selecting previously unselected package libc-ares2:amd64.
(Reading database ... 120880 files and directories currently installed.)
Preparing to unpack .../libc-ares2_1.18.1-1ubuntu0.22.04.2_amd64.deb ...
Unpacking libc-ares2:amd64 (1.18.1-1ubuntu0.22.04.2) ...
Selecting previously unselected package l

In [2]:
# @title ## 2. download SD 2-1 base and its vae

# Download the pre-trained Stable Diffusion 2.1 base checkpoint into
# /content/pretrained_model (the directory is created by main() in step 1).
# -O stores it under the file name that step 4 passes as `model`.
%cd /content/pretrained_model
!wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -O stable-diffusion-2-1-base.safetensors

# Download the VAE into /content/vae (also created in step 1).
# -O stores it under the file name that step 4 passes as `vae`.
%cd /content/vae
!wget https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt -O stablediffusion.vae.pt

/content/pretrained_model
--2023-11-24 09:30:58--  https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.55, 18.164.174.17, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/24/cb/24cbc2f7542236eb613b4f16b6802d7c2bef443e86cf9d076719733866e66c3a/df955bdf6b682338ea9b55dfc0d8f3475aadf4836e204893d28b82355e0956d2?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27v2-1_512-ema-pruned.safetensors%3B+filename%3D%22v2-1_512-ema-pruned.safetensors%22%3B&Expires=1701077458&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcwMTA3NzQ1OH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8yNC9jYi8yNGNiYzJmNzU0MjIzNmViNjEzYjRmMTZiNjgwMmQ3YzJiZWY0NDNlODZjZjlkMDc2NzE5NzMzODY2ZTY2YzNhL2

In [3]:
# @title ## 3. download LoRA and test file

import gdown

# URL of the pre-trained LoRA file on Google Drive
lora_url = "https://drive.google.com/u/0/uc?id=1MiaG8v0ZmkTwwrzIEFtVoBj-Jjqi_5lz&export=download"

# Destination path to save the downloaded file
lora_save_path = "/content/lora.safetensors"

# Download the file using gdown
gdown.download(lora_url, lora_save_path, quiet=False)

# download test file
!wget -O /content/kohya-trainer/stitchdiffusion_test.py https://raw.githubusercontent.com/lshus/stitchdiffusion-colab/main/stitchdiffusion_test.py

Downloading...
From: https://drive.google.com/u/0/uc?id=1MiaG8v0ZmkTwwrzIEFtVoBj-Jjqi_5lz&export=download
To: /content/lora.safetensors
100%|██████████| 55.0M/55.0M [00:02<00:00, 23.8MB/s]


--2023-11-24 09:33:20--  https://raw.githubusercontent.com/lshus/stitchdiffusion-colab/main/stitchdiffusion_test.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 65127 (64K) [text/plain]
Saving to: ‘/content/kohya-trainer/stitchdiffusion_test.py’


2023-11-24 09:33:20 (4.80 MB/s) - ‘/content/kohya-trainer/stitchdiffusion_test.py’ saved [65127/65127]



In [14]:
# @title ## 4. 360-degree panoramic image generation
# @markdown trigger word V*: '360-degree panoramic image'

# @markdown It is necessary to contain the trigger word in the
# @markdown input prompt, if you hope that the customized model
# @markdown generates a 360-degree panorama.
%store -r

network_weight = "/content/lora.safetensors"  # @param {'type':'string'}
network_mul = 0.7
network_module = "networks.lora"
network_args = ""

v2 = True
v_parameterization = False
prompt = "360-degree panoramic image, Hogwarts campus, hyper realistic"  # @param {type: "string"}
negative = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
model = "/content/pretrained_model/stable-diffusion-2-1-base.safetensors"  # @param {type: "string"}
vae = "/content/vae/stablediffusion.vae.pt"  # @param {type: "string"}
outdir = "/content/output"  # @param {type: "string"}
scale = 7
sampler = "ddim"
steps = 50 # @param {type: "integer"}
precision = "fp16"
width = 2048
height = 512
images_per_prompt = 1
batch_size = 1  # @param {type: "integer"}
clip_skip = 2
seed = 11  # @param {type: "integer"}

final_prompt = f"{prompt} --n {negative}"

config = {
    "v2": v2,
    "v_parameterization": v_parameterization,
    "network_module": network_module,
    "network_weight": network_weight,
    "network_mul": float(network_mul),
    "network_args": eval(network_args) if network_args else None,
    "ckpt": model,
    "outdir": outdir,
    "xformers": True,
    "vae": vae if vae else None,
    "fp16": True,
    "W": width,
    "H": height,
    "seed": seed if seed > 0 else None,
    "scale": scale,
    "sampler": sampler,
    "steps": steps,
    "max_embeddings_multiples": 3,
    "batch_size": batch_size,
    "images_per_prompt": images_per_prompt,
    "clip_skip": clip_skip if not v2 else None,
    "prompt": final_prompt,
}

args = ""
for k, v in config.items():
    if k.startswith("_"):
        args += f'"{v}" '
    elif isinstance(v, str):
        args += f'--{k}="{v}" '
    elif isinstance(v, bool) and v:
        args += f"--{k} "
    elif isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    elif isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python stitchdiffusion_test.py {args}"

os.chdir(repo_dir)
!{final_args}

load StableDiffusion checkpoint
loading u-net: <All keys matched successfully>
loading vae: <All keys matched successfully>
loading text encoder: <All keys matched successfully>
load VAE: /content/vae/stablediffusion.vae.pt
additional VAE loaded
loading tokenizer
prepare tokenizer
import network module: networks.lora
load network weights from: /content/lora.safetensors
create LoRA network from weights
create LoRA for Text Encoder: 138 modules.
create LoRA for U-Net: 192 modules.
enable LoRA for text encoder
enable LoRA for U-Net
weights are loaded: <All keys matched successfully>
pipeline is ready.
iteration 1/1
prompt 1/1: 360-degree panoramic image, Hogwarts campus, hyper realistic
negative prompt: lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry
100% 50/50 [04:18<00:00,  5.17s/it]
done!
