diff --git a/README.md b/README.md
index c5061e51..18604176 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,7 @@ See our paper: [Visual ChatGPT: Talking, Drawing and Editing with V
 ## Updates:
 
+- Add DirectML support
 - Add custom GPU/CPU assignment
 - Add windows support
 - Merge HuggingFace ControlNet, Remove download.sh
@@ -25,7 +26,7 @@ See our paper: [Visual ChatGPT: Talking, Drawing and Editing with V
 ## Insight & Goal:
 One the one hand, **ChatGPT (or LLMs)** serves as a **general interface** that provides a broad and diverse understanding of a wide range of topics. On the other hand, **Foundation Models** serve as **domain experts** by providing deep knowledge in specific domains.
-By leveraging **both general and deep knowledge**, we aim at building an AI that is capable of handling a various of tasks.
+By leveraging **both general and deep knowledge**, we aim at building an AI that is capable of handling various tasks.
 
 ## Demo
@@ -64,7 +65,7 @@ set OPENAI_API_KEY={Your_Private_Openai_Key}
 # Start Visual ChatGPT !
 # You can specify the GPU/CPU assignment by "--load", the parameter indicates which
 # Visual Foundation Model to use and where it will be loaded to
-# The model and device are sperated by underline '_', the different models are seperated by comma ','
+# The model and device are separated by underline '_', the different models are separated by comma ','
 # The available Visual Foundation Models can be found in the following table
 # For example, if you want to load ImageCaptioning to cpu and Text2Image to cuda:0
 # You can use: "ImageCaptioning_cpu,Text2Image_cuda:0"
@@ -83,7 +84,10 @@ python visual_chatgpt.py --load "ImageCaptioning_cuda:0,ImageEditing_cuda:0,
                                 Image2Seg_cpu,SegText2Image_cuda:2,Image2Pose_cpu,PoseText2Image_cuda:2,
                                 Image2Hed_cpu,HedText2Image_cuda:3,Image2Normal_cpu,
                                 NormalText2Image_cuda:3,Image2Line_cpu,LineText2Image_cuda:3"
-
+
+# Advice for DirectML backend (with any GPU supporting DX12, like AMD Radeon RX 6700XT)
+python visual_chatgpt.py --load "ImageCaptioning_dml,Text2Image_dml"
+
 ```
 
 ## GPU memory usage
diff --git a/requirements.txt b/requirements.txt
index 9e031af8..679d3a71 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 langchain==0.0.101
-torch==1.12.1
+torch==1.13.1
 torchvision==0.13.1
+torch_directml
 gradio==3.20.1
 accelerate
 addict
diff --git a/visual_chatgpt.py b/visual_chatgpt.py
index b1a8ff0c..6aad3d34 100644
--- a/visual_chatgpt.py
+++ b/visual_chatgpt.py
@@ -2,6 +2,7 @@
 import gradio as gr
 import random
 import torch
+import torch_directml
 import cv2
 import re
 import uuid
@@ -9,7 +10,7 @@
 import numpy as np
 import argparse
 
-from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
+from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
 from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
 from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
@@ -119,6 +120,22 @@ def get_new_image_name(org_img_name, func_name="update"):
     return os.path.join(head, new_file_name)
 
 
+def get_touch_dtype(name, device):
+    if name == 'ImageEditing':
+        if isinstance(device, str):
+            revision = 'fp16' if 'cuda' in device else None
+            torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        else:
+            revision = None
+            torch_dtype = torch.float32
+        return revision, torch_dtype
+    elif isinstance(device, str):
+        torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+    else:
+        torch_dtype = torch.float32
+    return torch_dtype
+
+
 class MaskFormer:
     def __init__(self, device):
         print(f"Initializing MaskFormer to {device}")
@@ -154,8 +171,7 @@ def __init__(self, device):
         print(f"Initializing ImageEditing to {device}")
         self.device = device
         self.mask_former = MaskFormer(device=self.device)
-        self.revision = 'fp16' if 'cuda' in device else None
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.revision, self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
             "runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
@@ -193,7 +209,7 @@ class InstructPix2Pix:
     def __init__(self, device):
         print(f"Initializing InstructPix2Pix to {device}")
         self.device = device
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix",
                                                                            safety_checker=None,
                                                                            torch_dtype=self.torch_dtype).to(device)
@@ -221,7 +237,7 @@ class Text2Image:
     def __init__(self, device):
         print(f"Initializing Text2Image to {device}")
         self.device = device
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                                             torch_dtype=self.torch_dtype)
         self.pipe.to(device)
@@ -247,7 +263,7 @@ class ImageCaptioning:
     def __init__(self, device):
         print(f"Initializing ImageCaptioning to {device}")
         self.device = device
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
         self.model = BlipForConditionalGeneration.from_pretrained(
             "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype).to(self.device)
@@ -290,7 +306,7 @@ def inference(self, inputs):
 class CannyText2Image:
     def __init__(self, device):
         print(f"Initializing CannyText2Image to {device}")
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-canny",
                                                           torch_dtype=self.torch_dtype)
         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -346,7 +362,7 @@ def inference(self, inputs):
 class LineText2Image:
     def __init__(self, device):
         print(f"Initializing LineText2Image to {device}")
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-mlsd",
                                                           torch_dtype=self.torch_dtype)
         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -404,7 +420,7 @@ def inference(self, inputs):
 class HedText2Image:
     def __init__(self, device):
         print(f"Initializing HedText2Image to {device}")
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-hed",
                                                           torch_dtype=self.torch_dtype)
         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -462,7 +478,7 @@ def inference(self, inputs):
 class ScribbleText2Image:
     def __init__(self, device):
         print(f"Initializing ScribbleText2Image to {device}")
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-scribble",
                                                           torch_dtype=self.torch_dtype)
         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -517,7 +533,7 @@ def inference(self, inputs):
 class PoseText2Image:
     def __init__(self, device):
         print(f"Initializing PoseText2Image to {device}")
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose",
                                                           torch_dtype=self.torch_dtype)
         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -624,7 +640,7 @@ def inference(self, inputs):
 class SegText2Image:
     def __init__(self, device):
         print(f"Initializing SegText2Image to {device}")
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-seg",
                                                           torch_dtype=self.torch_dtype)
         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -683,7 +699,7 @@ def inference(self, inputs):
 class DepthText2Image:
     def __init__(self, device):
         print(f"Initializing DepthText2Image to {device}")
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.controlnet = ControlNetModel.from_pretrained(
             "fusing/stable-diffusion-v1-5-controlnet-depth", torch_dtype=self.torch_dtype)
         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -754,7 +770,7 @@ def inference(self, inputs):
 class NormalText2Image:
     def __init__(self, device):
         print(f"Initializing NormalText2Image to {device}")
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.controlnet = ControlNetModel.from_pretrained(
             "fusing/stable-diffusion-v1-5-controlnet-normal", torch_dtype=self.torch_dtype)
         self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -791,7 +807,7 @@ def inference(self, inputs):
 class VisualQuestionAnswering:
     def __init__(self, device):
         print(f"Initializing VisualQuestionAnswering to {device}")
-        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.torch_dtype = get_touch_dtype(self.__class__.__name__, device)
         self.device = device
         self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
         self.model = BlipForQuestionAnswering.from_pretrained(
@@ -824,6 +840,8 @@ def __init__(self, load_dict):
 
         self.models = {}
         for class_name, device in load_dict.items():
+            if device == "dml":
+                device = torch_directml.device()
             self.models[class_name] = globals()[class_name](device=device)
 
         self.tools = []
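
For context on the README change above: the `--load` string is a comma-separated list of `Model_device` pairs, and with this patch a device of `dml` is swapped for a DirectML device object before the model classes are constructed. Below is a minimal sketch of that flow, assuming the split-on-underscore convention the README describes; the `parse_load` helper is illustrative, not code from the repository.

```python
import torch_directml


def parse_load(load: str) -> dict:
    """Turn e.g. "ImageCaptioning_dml,Text2Image_cuda:0" into {model_name: device}."""
    load_dict = {}
    for entry in load.split(','):
        model_name, device = entry.strip().split('_', 1)
        load_dict[model_name] = device
    return load_dict


load_dict = parse_load("ImageCaptioning_dml,Text2Image_dml")
# Mirror the ConversationBot change above: replace the literal "dml" with a real
# torch.device backed by the default DirectML adapter (e.g. an RX 6700XT).
for model_name, device in load_dict.items():
    if device == "dml":
        load_dict[model_name] = torch_directml.device()
```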
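
The `get_touch_dtype` helper added in the patch decides precision from the device: CUDA device strings keep the half-precision path (plus the `fp16` weights revision that `ImageEditing` needs), while CPU strings and DirectML devices, which arrive as `torch.device` objects rather than strings, fall back to `float32`. A rough usage sketch follows, assuming the patched `visual_chatgpt.py` is importable from the working directory.

```python
import torch
import torch_directml
from visual_chatgpt import get_touch_dtype  # assumes the patched module is on the import path

# ImageEditing is special-cased: it also returns the weights revision used by
# the StableDiffusionInpaintPipeline.
revision, dtype = get_touch_dtype('ImageEditing', 'cuda:0')  # ('fp16', torch.float16)
revision, dtype = get_touch_dtype('ImageEditing', 'cpu')     # (None, torch.float32)

# Every other model only gets a dtype back.
dtype = get_touch_dtype('Text2Image', 'cuda:0')              # torch.float16
# A DirectML device is a torch.device, not a str, so isinstance(device, str)
# is False and the helper falls through to the float32 branch.
dtype = get_touch_dtype('Text2Image', torch_directml.device())  # torch.float32
```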