<a href="https://colab.research.google.com/github/craigls/sd-lora-test/blob/main/nes_style.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the training dependencies. `%pip` installs into the environment of the
# running kernel, and the extras are quoted so the shell never treats the square
# brackets as a glob pattern.
%pip install git+https://github.com/huggingface/diffusers "datasets[vision]" accelerate comet_ml Pillow "huggingface_hub[hf_xet]"
%pip install bitsandbytes
# Fetch the LoRA training script (the URL is pinned to a specific gist revision).
!wget https://gist.githubusercontent.com/craigls/7adeb56a9a32e387f6355baee947653a/raw/5963845edc9f6bd9d0ea2c95faa67fb4874ee997/train_text_to_image_lora.py -O train_text_to_image_lora.py
# Write a default accelerate configuration for the `accelerate launch` call below.
!accelerate config default

In [None]:
import re
import os
import zipfile
import random
import json
import shutil
import urllib.request
from google.colab import drive
from datetime import datetime
import uuid

# --- Locations on the mounted Google Drive --------------------------------
ROOT = "/content/drive"
RUN_ID = str(uuid.uuid4())  # fresh identifier per run; keeps each run's input images separate
INPUT_DIR = f"{ROOT}/MyDrive/sd-lora-nes/input/{RUN_ID}"
OUTPUT_DIR = f"{ROOT}/MyDrive/sd-lora-nes/output"
TMP_DIR = f"{ROOT}/MyDrive/tmp"

# --- Base model -------------------------------------------------------------
MODEL_NAME = "stabilityai/stable-diffusion-2-1-base"

# --- Source screenshots -----------------------------------------------------
IMAGES_ZIPFILE_URL = "https://archive.org/download/No-Intro_Thumbnails_2016-04-10/Nintendo%20-%20Nintendo%20Entertainment%20System.zip"
IMAGES_ZIPFILE = f"{TMP_DIR}/NES.zip"

# Upper bound on the number of screenshots copied into the training set.
TRAINING_SET_SIZE = 200

# Archive members that are PNG files under a "Named_Snaps" directory.
RE_SCREENSHOT = re.compile(r".*/Named_Snaps/.+\.png$")
# Captures the path component text that precedes the first " (" in a member name.
RE_TITLE = re.compile(r'.*/(.+?) \(')

# Drive has to be mounted before any directory below ROOT can be created.
drive.mount(ROOT)

for required_dir in (INPUT_DIR, OUTPUT_DIR, TMP_DIR):
  os.makedirs(required_dir, exist_ok=True)


In [None]:
from PIL import Image

# Caption attached to every training image.
CAPTION = "nes_style"

# Download the screenshots zip file (only once). The download goes to a
# temporary name and is renamed on success, so an interrupted transfer is never
# mistaken for a complete archive on the next run.
if not os.path.exists(IMAGES_ZIPFILE):
  partial_download = f"{IMAGES_ZIPFILE}.part"
  urllib.request.urlretrieve(IMAGES_ZIPFILE_URL, partial_download)
  os.replace(partial_download, IMAGES_ZIPFILE)

# Extract screenshots from zip and create metadata.jsonl for training data input
with zipfile.ZipFile(IMAGES_ZIPFILE) as zf:
  screenshots = [fn for fn in zf.namelist() if RE_SCREENSHOT.match(fn)]
  if not screenshots:
    raise RuntimeError(f"No screenshots matching {RE_SCREENSHOT.pattern!r} found in {IMAGES_ZIPFILE}")

  # Random subset (in random order) of at most TRAINING_SET_SIZE screenshots.
  selected = random.sample(screenshots, min(TRAINING_SET_SIZE, len(screenshots)))

  # metadata.jsonl is opened only after the archive has been validated, so a
  # bad archive does not leave behind an empty metadata file.
  with open(f'{INPUT_DIR}/metadata.jsonl', 'w') as mf:
    for fn in selected:
      basename = os.path.basename(fn)
      destfn = os.path.join(INPUT_DIR, basename)

      # extract() returns the path it actually wrote, which may differ from
      # a naive join if the member name had to be sanitised.
      extracted = zf.extract(fn, TMP_DIR)

      # Upscale 3x with nearest-neighbour resampling so pixel edges stay sharp.
      with Image.open(extracted) as img:
        resized = img.resize((256 * 3, 224 * 3), Image.NEAREST)
        resized.save(destfn)

      # Add to metadata.jsonl. file_name is relative to metadata.jsonl (both
      # live in INPUT_DIR), as the image-folder dataset layout expects.
      mf.write(json.dumps({"file_name": basename, "text": CAPTION}) + '\n')
      print(f"Selected image: {destfn} using caption: {CAPTION}")

In [None]:
# Authenticate with Comet before training; the training command below is
# launched with --report_to=comet_ml.
import comet_ml
comet_ml.login()

In [None]:
# Run the downloaded LoRA training script through accelerate. IPython substitutes
# {MODEL_NAME}, {INPUT_DIR} and {OUTPUT_DIR} from the notebook namespace before
# the command reaches the shell. Settings: batch size 16, 30 epochs, constant
# learning rate 1e-4, rank 32, metrics reported to Comet.
!accelerate launch train_text_to_image_lora.py \
  --pretrained_model_name_or_path="{MODEL_NAME}" \
  --train_data_dir="{INPUT_DIR}" \
  --train_batch_size=16 \
  --num_train_epochs=30 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --output_dir="{OUTPUT_DIR}" \
  --report_to=comet_ml \
  --rank=32

In [None]:
from diffusers import AutoPipelineForText2Image
import torch

# Load the base model in half precision and move it to the GPU.
pipe = AutoPipelineForText2Image.from_pretrained(
  MODEL_NAME,
  torch_dtype=torch.float16,
  use_safetensors=False,
)
pipe = pipe.to("cuda")

# Apply the LoRA weights from OUTPUT_DIR, the directory the training command
# was given as its --output_dir.
pipe.load_lora_weights(
  f"{OUTPUT_DIR}",
  weight_name="pytorch_lora_weights.safetensors",
)

In [None]:
# Sample prompts (generated with ChatGPT). Each one starts with the caption
# token used for the training images.
prompts = [
  "nes_style mercenary exploring a ruined cathedral at twilight",
  "nes_style robot racing through a collapsing digital city",
  "nes_style adventurer scaling a forgotten temple deep in the jungle",
  "nes_style hacker navigating through a glowing data stream",
  "nes_style knight battling shadow creatures in an ancient crypt",
  "nes_style pilot escaping a crumbling space station under attack",
  "nes_style treasure hunter crossing a desert to reach a hidden pyramid",
  "nes_style samurai fighting ghostly enemies on a stormy mountain pass",
  "nes_style detective chasing criminals through neon-lit city streets",
  "nes_style soldier protecting a secret base on an icy planet",
]

# One image per prompt, generated in a single batched call, then shown inline.
generation = pipe(prompts, num_images_per_prompt=1)
for generated_image in generation.images:
  display(generated_image)
