10 changes: 7 additions & 3 deletions README.md
@@ -14,6 +14,7 @@ See our paper: [<font size=5>Visual ChatGPT: Talking, Drawing and Editing with V

## Updates:

- Add DirectML support
- Add custom GPU/CPU assignment
- Add Windows support
- Merge HuggingFace ControlNet, Remove download.sh
@@ -25,7 +26,7 @@ See our paper: [<font size=5>Visual ChatGPT: Talking, Drawing and Editing with V
## Insight & Goal:
On the one hand, **ChatGPT (or LLMs)** serves as a **general interface** that provides a broad and diverse understanding of a
wide range of topics. On the other hand, **Foundation Models** serve as **domain experts** by providing deep knowledge in specific domains.
By leveraging **both general and deep knowledge**, we aim at building an AI that is capable of handling a various of tasks.
By leveraging **both general and deep knowledge**, we aim to build an AI capable of handling various tasks.


## Demo
@@ -64,7 +65,7 @@ set OPENAI_API_KEY={Your_Private_Openai_Key}
# Start Visual ChatGPT!
# You can specify the GPU/CPU assignment with "--load"; the parameter indicates which
# Visual Foundation Models to use and the devices they will be loaded onto
# The model and device are sperated by underline '_', the different models are seperated by comma ','
# A model and its device are separated by an underscore '_'; different models are separated by a comma ','
# The available Visual Foundation Models can be found in the following table
# For example, if you want to load ImageCaptioning to cpu and Text2Image to cuda:0
# You can use: "ImageCaptioning_cpu,Text2Image_cuda:0"
@@ -83,7 +84,10 @@ python visual_chatgpt.py --load "ImageCaptioning_cuda:0,ImageEditing_cuda:0,
Image2Seg_cpu,SegText2Image_cuda:2,Image2Pose_cpu,PoseText2Image_cuda:2,
Image2Hed_cpu,HedText2Image_cuda:3,Image2Normal_cpu,
NormalText2Image_cuda:3,Image2Line_cpu,LineText2Image_cuda:3"


# Example for the DirectML backend (works with any DX12-capable GPU, e.g. the AMD Radeon RX 6700 XT)
python visual_chatgpt.py --load "ImageCaptioning_dml,Text2Image_dml"

```
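Note: the DirectML example assumes the `torch_directml` package is importable; it is added to requirements.txt below, and can also be installed directly with `pip install torch-directml` (the PyPI name uses a hyphen, the import name an underscore).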

## GPU memory usage
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,6 +1,7 @@
langchain==0.0.101
torch==1.12.1
torch==1.13.1
torchvision==0.13.1
torch_directml
gradio==3.20.1
accelerate
addict
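(The torch bump to 1.13.1 plausibly tracks the PyTorch version that contemporary torch_directml builds were pinned against.)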
48 changes: 33 additions & 15 deletions visual_chatgpt.py
@@ -2,14 +2,15 @@
import gradio as gr
import random
import torch
import torch_directml  # DirectML backend for DX12-capable GPUs
import cv2
import re
import uuid
from PIL import Image
import numpy as np
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation

@@ -119,6 +120,22 @@ def get_new_image_name(org_img_name, func_name="update"):
return os.path.join(head, new_file_name)


def get_torch_dtype(name, device):
    # Resolve the torch dtype (and, for ImageEditing, the checkpoint revision)
    # for a given model class and target device. DirectML devices are device
    # objects rather than strings, so they fail the isinstance check and
    # always fall back to full-precision float32.
    if name == 'ImageEditing':
        if isinstance(device, str):
            revision = 'fp16' if 'cuda' in device else None
            torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        else:
            revision = None
            torch_dtype = torch.float32
        return revision, torch_dtype
    elif isinstance(device, str):
        torch_dtype = torch.float16 if 'cuda' in device else torch.float32
    else:
        torch_dtype = torch.float32
    return torch_dtype
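
# Illustrative behavior (annotation, not part of the PR):
#   get_torch_dtype('ImageEditing', 'cuda:0')              -> ('fp16', torch.float16)
#   get_torch_dtype('Text2Image', 'cpu')                   -> torch.float32
#   get_torch_dtype('Text2Image', torch_directml.device()) -> torch.float32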


class MaskFormer:
def __init__(self, device):
print(f"Initializing MaskFormer to {device}")
@@ -154,8 +171,7 @@ def __init__(self, device):
print(f"Initializing ImageEditing to {device}")
self.device = device
self.mask_former = MaskFormer(device=self.device)
self.revision = 'fp16' if 'cuda' in device else None
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.revision, self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)

@@ -193,7 +209,7 @@ class InstructPix2Pix:
def __init__(self, device):
print(f"Initializing InstructPix2Pix to {device}")
self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix",
safety_checker=None,
torch_dtype=self.torch_dtype).to(device)
@@ -221,7 +237,7 @@ class Text2Image:
def __init__(self, device):
print(f"Initializing Text2Image to {device}")
self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
torch_dtype=self.torch_dtype)
self.pipe.to(device)
@@ -247,7 +263,7 @@ class ImageCaptioning:
def __init__(self, device):
print(f"Initializing ImageCaptioning to {device}")
self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
self.model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype).to(self.device)
@@ -290,7 +306,7 @@ def inference(self, inputs):
class CannyText2Image:
def __init__(self, device):
print(f"Initializing CannyText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-canny",
torch_dtype=self.torch_dtype)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -346,7 +362,7 @@ def inference(self, inputs):
class LineText2Image:
def __init__(self, device):
print(f"Initializing LineText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-mlsd",
torch_dtype=self.torch_dtype)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -404,7 +420,7 @@ def inference(self, inputs):
class HedText2Image:
def __init__(self, device):
print(f"Initializing HedText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-hed",
torch_dtype=self.torch_dtype)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -462,7 +478,7 @@ def inference(self, inputs):
class ScribbleText2Image:
def __init__(self, device):
print(f"Initializing ScribbleText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-scribble",
torch_dtype=self.torch_dtype)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -517,7 +533,7 @@ def inference(self, inputs):
class PoseText2Image:
def __init__(self, device):
print(f"Initializing PoseText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose",
torch_dtype=self.torch_dtype)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -624,7 +640,7 @@ def inference(self, inputs):
class SegText2Image:
def __init__(self, device):
print(f"Initializing SegText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-seg",
torch_dtype=self.torch_dtype)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -683,7 +699,7 @@ def inference(self, inputs):
class DepthText2Image:
def __init__(self, device):
print(f"Initializing DepthText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.controlnet = ControlNetModel.from_pretrained(
"fusing/stable-diffusion-v1-5-controlnet-depth", torch_dtype=self.torch_dtype)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -754,7 +770,7 @@ def inference(self, inputs):
class NormalText2Image:
def __init__(self, device):
print(f"Initializing NormalText2Image to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.controlnet = ControlNetModel.from_pretrained(
"fusing/stable-diffusion-v1-5-controlnet-normal", torch_dtype=self.torch_dtype)
self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -791,7 +807,7 @@ def inference(self, inputs):
class VisualQuestionAnswering:
def __init__(self, device):
print(f"Initializing VisualQuestionAnswering to {device}")
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.torch_dtype = get_torch_dtype(self.__class__.__name__, device)
self.device = device
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
self.model = BlipForQuestionAnswering.from_pretrained(
@@ -824,6 +840,8 @@ def __init__(self, load_dict):

self.models = {}
for class_name, device in load_dict.items():
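            # "dml" is a symbolic device name from the --load argument; map it to
            # the default DirectML device object before instantiating the model.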
if device == "dml":
device = torch_directml.device()
self.models[class_name] = globals()[class_name](device=device)

self.tools = []
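
For reference, a minimal sketch (not part of the PR; `parse_load_arg` is an illustrative name, not the script's actual parser) of how a `--load` string maps onto the `load_dict` consumed above:

```
def parse_load_arg(load_str):
    # "ImageCaptioning_dml,Text2Image_dml" -> {"ImageCaptioning": "dml", "Text2Image": "dml"}
    # Split on the last underscore so device specs like "cuda:0" pass through intact.
    pairs = (item.rsplit('_', 1) for item in load_str.split(','))
    return {name.strip(): device.strip() for name, device in pairs}

load_dict = parse_load_arg("ImageCaptioning_dml,Text2Image_dml")
# Each "dml" value is then swapped for torch_directml.device() in the loop above.
```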