Inference notebook for the FBA dataset
- Model (LoRA adapter): "artisanalwasp/sdxl-base-1.0-fbadataset5e-4-lrwrmp0-ep20-withoutpadding-noflip-lora-newlabels"
- Training dataset: "artisanalwasp/resized_fba_with_lanczos_wo_wearscores_refactoredlabels"

- Install dependencies

In [1]:
! pip install -U peft transformers diffusers

Collecting peft
  Using cached peft-0.14.0-py3-none-any.whl.metadata (13 kB)
Collecting transformers
  Using cached transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
Collecting diffusers
  Using cached diffusers-0.32.2-py3-none-any.whl.metadata (18 kB)
Collecting accelerate>=0.21.0 (from peft)
  Downloading accelerate-1.4.0-py3-none-any.whl.metadata (19 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Using cached tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Using cached peft-0.14.0-py3-none-any.whl (374 kB)
Downloading transformers-4.49.0-py3-none-any.whl (10.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hUsing cached diffusers-0.32.2-py3-none-any.whl (3.2 MB)
Downloading accelerate-1.4.0-py3-none-any.whl (342 kB)
Using cached tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
Installing col

Diffusers pipeline setup
Load Stable Diffusion XL Base 1.0 as the base model, then attach the trained LoRA adapter weights.


In [1]:
%load_ext autoreload
%autoreload 2

from diffusers import DiffusionPipeline
import torch
import os
from src.prompt_generator import PromptGenerator

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

# Optional CPU offloading to save same GPU memory
pipe.enable_model_cpu_offload()

model_path = "artisanalwasp/sdxl-base-1.0-fbadataset5e-4-lrwrmp0-ep20-withoutpadding-noflip-lora-newlabels"
model_name = (model_path).split("/")[-1]
print("Name of the model: ", model_name)

# Loading Trained LoRa weights
pipe.load_lora_weights(model_path)

  from .autonotebook import tqdm as notebook_tqdm
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  4.85it/s]


Name of the model:  sdxl-base-1.0-fbadataset5e-4-lrwrmp0-ep20-withoutpadding-noflip-lora-newlabels


Generate images whose prompt includes built-up edge (BUE) wear
- Prompt describing only BUE wear
- Random prompt that always includes BUE wear (it may contain other wear types too)

In [None]:
# Generate `total_num_images` images from one fixed prompt that describes only BUE wear.
total_num_images = 50
tool_type = "insert"
cat = "bue"
seed = None  # set to an int for reproducible images (image i uses seed + i)
prompt = PromptGenerator.generate_with_one_cat(tool_type, cat)

output_dir = os.path.join("./generated_images", model_name, "bue_only")
os.makedirs(output_dir, exist_ok=True)

# The prompt becomes part of the file name; a "/" in it would be interpreted as a
# directory separator and make the save fail, so replace it.
# NOTE(review): very long prompts could still exceed the file-system name limit
# (typically 255 bytes) — confirm PromptGenerator's prompt lengths.
file_stem = prompt.replace("/", "_")

for i in range(total_num_images):
    # None keeps the pipeline's default (unseeded) sampling.
    generator = torch.Generator("cpu").manual_seed(seed + i) if seed is not None else None
    image = pipe(
        prompt=prompt,
        num_inference_steps=50,
        width=1280,
        height=1024,
        guidance_scale=7.0,  # how strongly the prompt steers generation (higher = follows prompt more closely)
        generator=generator,
    ).images[0]

    image.save(os.path.join(output_dir, f"{file_stem}_{i}.bmp"))

100%|██████████| 50/50 [00:11<00:00,  4.48it/s]


Image 1 generated in 13.27 seconds


Generate images for the groove wear category only

In [2]:
# Generate `total_num_images` images from one fixed prompt that describes only groove wear.
total_num_images = 50
tool_type = "insert"
cat = "groove"
seed = None  # set to an int for reproducible images (image i uses seed + i)
prompt = PromptGenerator.generate_with_one_cat(tool_type, cat)

output_dir = os.path.join("./generated_images", model_name, "groove_only")
os.makedirs(output_dir, exist_ok=True)

# The prompt becomes part of the file name; a "/" in it would be interpreted as a
# directory separator and make the save fail, so replace it.
# NOTE(review): very long prompts could still exceed the file-system name limit
# (typically 255 bytes) — confirm PromptGenerator's prompt lengths.
file_stem = prompt.replace("/", "_")

for i in range(total_num_images):
    # None keeps the pipeline's default (unseeded) sampling.
    generator = torch.Generator("cpu").manual_seed(seed + i) if seed is not None else None
    image = pipe(
        prompt=prompt,
        num_inference_steps=50,
        width=1280,
        height=1024,
        guidance_scale=7.0,  # how strongly the prompt steers generation (higher = follows prompt more closely)
        generator=generator,
    ).images[0]

    image.save(os.path.join(output_dir, f"{file_stem}_{i}.bmp"))

100%|██████████| 50/50 [00:11<00:00,  4.28it/s]
100%|██████████| 50/50 [00:10<00:00,  4.63it/s]
100%|██████████| 50/50 [00:10<00:00,  4.66it/s]
100%|██████████| 50/50 [00:10<00:00,  4.72it/s]
100%|██████████| 50/50 [00:10<00:00,  4.70it/s]
100%|██████████| 50/50 [00:10<00:00,  4.72it/s]
 76%|███████▌  | 38/50 [00:08<00:02,  4.54it/s]


KeyboardInterrupt: 

In [None]:
# Generate images from random prompts that always contain BUE wear (`fixed_cat`);
# the other entries of `categories` may also appear in a prompt.
total_num_images = 30
tool_type = "insert"
categories = ["bue", "groove", "flank"]
cat = "bue"
# Seeds only the diffusion noise; prompt sampling inside PromptGenerator is not seeded here.
seed = None  # set to an int for reproducible images (image i uses seed + i)

output_dir = os.path.join("./generated_images", model_name, "bue")
os.makedirs(output_dir, exist_ok=True)

for i in range(total_num_images):
    random_prompt = PromptGenerator.generate_random_prompt_with_a_category(tool_type=tool_type, categories=categories, fixed_cat=cat)
    # None keeps the pipeline's default (unseeded) sampling.
    generator = torch.Generator("cpu").manual_seed(seed + i) if seed is not None else None
    image = pipe(
        prompt=random_prompt,
        num_inference_steps=50,
        width=1280,
        height=1024,
        guidance_scale=7.0,  # how strongly the prompt steers generation (higher = follows prompt more closely)
        generator=generator,
    ).images[0]

    # The prompt is embedded in the file name; "/" would be read as a directory separator.
    # NOTE(review): very long prompts could exceed the file-system name limit — confirm.
    file_stem = random_prompt.replace("/", "_")
    image.save(os.path.join(output_dir, f"{file_stem}_{i}.bmp"))

Random prompts without any wear-type restriction

In [28]:
# Generate images from fully random prompts drawn from `categories` (no wear type is forced).
total_num_images = 50
tool_type = "insert"
categories = ["bue", "groove", "flank"]
# Seeds only the diffusion noise; prompt sampling inside PromptGenerator is not seeded here.
seed = None  # set to an int for reproducible images (image i uses seed + i)

output_dir = os.path.join("./generated_images", model_name, "random_prompt2")
os.makedirs(output_dir, exist_ok=True)

for i in range(total_num_images):
    random_prompt = PromptGenerator.generate_random_prompt(tool_type, categories)
    # None keeps the pipeline's default (unseeded) sampling.
    generator = torch.Generator("cpu").manual_seed(seed + i) if seed is not None else None
    image = pipe(
        prompt=random_prompt,
        num_inference_steps=50,
        width=1280,
        height=1024,
        guidance_scale=7.0,  # how strongly the prompt steers generation (higher = follows prompt more closely)
        generator=generator,
    ).images[0]

    # The prompt is embedded in the file name; "/" would be read as a directory separator.
    # NOTE(review): very long prompts could exceed the file-system name limit — confirm.
    file_stem = random_prompt.replace("/", "_")
    image.save(os.path.join(output_dir, f"{file_stem}_{i}.bmp"))

100%|██████████| 50/50 [00:10<00:00,  4.60it/s]
100%|██████████| 50/50 [00:10<00:00,  4.68it/s]
100%|██████████| 50/50 [00:10<00:00,  4.72it/s]
100%|██████████| 50/50 [00:10<00:00,  4.70it/s]
100%|██████████| 50/50 [00:11<00:00,  4.54it/s]
100%|██████████| 50/50 [00:10<00:00,  4.70it/s]
100%|██████████| 50/50 [00:10<00:00,  4.68it/s]
100%|██████████| 50/50 [00:10<00:00,  4.67it/s]
100%|██████████| 50/50 [00:10<00:00,  4.66it/s]
100%|██████████| 50/50 [00:10<00:00,  4.68it/s]
100%|██████████| 50/50 [00:10<00:00,  4.67it/s]
100%|██████████| 50/50 [00:10<00:00,  4.67it/s]
100%|██████████| 50/50 [00:10<00:00,  4.68it/s]
100%|██████████| 50/50 [00:10<00:00,  4.67it/s]
100%|██████████| 50/50 [00:10<00:00,  4.66it/s]
100%|██████████| 50/50 [00:10<00:00,  4.60it/s]
100%|██████████| 50/50 [00:10<00:00,  4.64it/s]
100%|██████████| 50/50 [00:10<00:00,  4.68it/s]
100%|██████████| 50/50 [00:11<00:00,  4.54it/s]
100%|██████████| 50/50 [00:10<00:00,  4.65it/s]
100%|██████████| 50/50 [00:10<00:00,  4.