In [1]:
import os
from diffusers import StableDiffusionPipeline
import torch
from torchvision import models
from torchvision import transforms
#from decimal import Decimal

# Device used for diffusion inference and, later, by the classifier.
device = "cuda"

# Directory holding the fine-tuned LoRA attention-processor weights
# (e.g. pytorch_lora_weights.bin).
model_path = "./trained_weights/als_10_rings/"

# Stable Diffusion v1.5 base model in float16, with the safety checker and its
# feature extractor disabled.
pipeline_options = dict(
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
)
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", **pipeline_options)

# Swap the UNet's attention processors for the LoRA-trained ones.
pipe.unet.load_attn_procs(model_path)

# Move all pipeline components to the device; as the cell's last expression,
# the returned pipeline is what the notebook renders.
pipe.to(device)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.20.0.dev0",
  "_name_or_path": "runwayml/stable-diffusion-v1-5",
  "feature_extractor": [
    null,
    null
  ],
  "requires_safety_checker": false,
  "safety_checker": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [2]:
def generate(prompt, out_name='im', out_dir='40k_generated/test', num_inference_steps=30):
    """Generate one image for ``prompt`` with the global ``pipe`` and save it as JPEG.

    Args:
        prompt: Text prompt passed to the Stable Diffusion pipeline.
        out_name: File stem of the saved image (``.jpg`` is appended).
        out_dir: Directory to save into, relative to the current working
            directory unless absolute; created (with parents) if missing.
        num_inference_steps: Number of denoising steps for the pipeline.

    Returns:
        The path of the saved ``.jpg`` file.
    """
    # exist_ok avoids the check-then-create race and works for absolute paths.
    os.makedirs(out_dir, exist_ok=True)
    # Only the first image of the pipeline output is kept.
    image = pipe(prompt, num_inference_steps=num_inference_steps).images[0]
    out_path = os.path.join(out_dir, f"{out_name}.jpg")
    image.save(out_path)
    return out_path
    

In [3]:
def classifier(model, image, threshold=0.5, positive_class=1) -> "tuple[torch.Tensor, torch.Tensor]":
    """Score a single image with ``model`` and threshold one class's probability.

    The model is moved to the global ``device`` and put in eval mode; the image
    is moved to the same device and given a leading batch dimension of 1.

    Args:
        model: A ``torch.nn.Module`` whose output is softmaxed over dim 1
            (presumably ``(batch, num_classes)`` logits -- confirm per model).
        image: One preprocessed image tensor without a batch dimension.
        threshold: Value the ``positive_class`` probability must strictly exceed.
        positive_class: Index along dim 1 of the class to threshold (default 1).

    Returns:
        ``(passed, prob)`` where ``prob`` is the softmax of the model output and
        ``passed`` is the boolean tensor ``prob[:, positive_class] > threshold``
        (one element per batch row, i.e. one here).
    """
    model = model.to(device)
    model.eval()
    batch = image.to(device).unsqueeze(0)
    with torch.no_grad():
        prob = torch.softmax(model(batch), dim=1)
    return prob[:, positive_class] > threshold, prob

In [4]:
import ipywidgets as widgets
from IPython.display import display, Image, clear_output
from PIL import Image as PIL_Image
import math

# Free-text prompt passed to the diffusion pipeline.
text_input = widgets.Text(
    value='',
    placeholder='Enter a prompt',
    description='Prompt:',
    disabled=False
)

# Threshold in [0, 1] that the classifier's probability must exceed for an
# image to be accepted.
threshold_input = widgets.FloatSlider(
    value=0.5,
    min=0,
    max=1,
    step=0.01,
    description='Threshold:',
    orientation='horizontal'
)

# NOTE(review): shown in the UI but not forwarded to generate(), which saves a
# single image per call.
batch_size = widgets.IntSlider(
    value=1,
    min=1,
    max=48,
    step=1,
    description='Batchsize:',
    orientation='horizontal'
)

# Maximum number of generate-and-classify attempts per button click.
max_epoch = widgets.IntSlider(
    value=5,
    min=1,
    max=100,
    step=1,
    description='N_attempts:',
    orientation='horizontal'
)

display_button = widgets.Button(
    description="Generate Image",
    button_style="info"
)


# Display name -> checkpoint file name under ./discriminator_data/checkpoints/.
pretrained_classfiers = {
    'ResNet50': 'resnet_100epochs.pth',
    'ViT-16x16': 'vit_p16_50epochs.pth',
    'ViT-32x32': 'vit_p32_50epochs.pth',
}

dropdown = widgets.Dropdown(
    options=pretrained_classfiers,
    description='Classifier:',
    disabled=False
)


# Output area where status messages and the accepted image are shown.
output_image = widgets.Output()
# Image written by generate() with its default arguments. Kept relative to the
# current working directory (where generate() saves) instead of a user-specific
# absolute path, so the classifier reads the file that was just generated.
filename = os.path.join('40k_generated', 'test', 'im.jpg')

# Define a function to display the image when the button is clicked
def display_image(button):
    """Generate images for the current prompt until one passes the classifier.

    Click handler for ``display_button``. Loads the classifier checkpoint
    selected in ``dropdown`` once, then repeatedly calls ``generate`` and scores
    the image saved at ``filename`` with ``classifier``. It stops when the class-1
    probability exceeds the threshold slider or after ``max_epoch.value`` attempts.
    The accepted image is displayed; otherwise a message reports that none passed.

    Args:
        button: The clicked button widget (unused; required by ``on_click``).
    """
    with output_image:
        clear_output()  # Clear any previous output
        image_prompt = text_input.value.strip()
        if not image_prompt:
            # Nothing to generate; previously this spun through every attempt.
            print("Please enter a prompt.")
            return

        threshold = threshold_input.value
        n_attempts = max_epoch.value

        # Loop-invariant setup, hoisted so it is not repeated on every attempt.
        # NOTE(review): torch.load unpickles a full model object, which can run
        # arbitrary code -- only load trusted checkpoints.
        cls_dir = './discriminator_data/checkpoints/'
        cls_model = torch.load(os.path.join(cls_dir, dropdown.value))
        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        prev_prob = 0.0
        for attempt in range(1, n_attempts + 1):
            clear_output()
            # TODO: batch_size is not forwarded; generate() saves one image per call.
            print(f"Generating 1 image, attempt {attempt}/{n_attempts}.")
            print(f'Previous probability: {round(prev_prob, 2)}')
            generate(image_prompt)

            # Context manager closes the file handle once the tensor is built.
            with PIL_Image.open(filename) as pil_image:
                image_tensor = transform(pil_image.convert('RGB'))
            passed, prob = classifier(cls_model, image_tensor, threshold=threshold)
            prev_prob = prob[0, 1].item()
            print(f'Current probability: {round(prev_prob, 2)}')

            if passed.item():
                display(Image(filename=filename))
                return

        print(f'No image exceeded threshold {threshold} in {n_attempts} attempts.')



# Run the generate-and-classify loop whenever the button is clicked.
display_button.on_click(display_image)

# Stack the controls vertically with the output area last. As the cell's final
# expression, the VBox is what the notebook renders below the cell.
widgets.VBox(children=[
    text_input,
    dropdown,
    threshold_input,
    batch_size,
    max_epoch,
    display_button,
    output_image,
])


VBox(children=(Text(value='', description='Prompt:', placeholder='Enter a prompt'), Dropdown(description='Clas…