In [None]:
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr
import argparse
import cv2
from PIL import Image
import numpy as np

from detgpt.common.config import Config
from detgpt.common.dist_utils import get_rank
from detgpt.common.registry import registry
from detgpt.conversation.conversation import Chat, Conversation, SeparatorStyle  # , CONV_VISION
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict
import GroundingDINO.groundingdino.datasets.transforms as T
from segment_anything import build_sam, SamPredictor 
from huggingface_hub import hf_hub_download
from supervision.draw.color import Color, ColorPalette
import re
import sys
sys.path.append(sys.path[0]+"/tracker")
from tracker.base_tracker import BaseTracker
from tqdm import tqdm

In [None]:
class Args:
    """Stand-in for parsed command-line arguments.

    An instance is handed to ``Config`` and a few attributes are read
    directly by later cells (``system_path``, ``dino_version``).
    Edit the values below to match your local setup.
    """

    def __init__(self):
        # Configuration file and optional overrides -- replace with your own.
        self.cfg_path = "configs/detgpt_tasktune_13b_coco.yaml"
        self.options = None

        # Optional text file holding a custom system message (None -> built-in prompt).
        self.system_path = None
        self.enable_system = False

        # Detector selection: "swinb" picks the SwinB config, anything else SwinT.
        self.dino_version = "swinb"
        self.disable_detector = True

        # GPU index.
        self.gpu_id = 0

# Instantiate the stand-in arguments and build the project configuration from them.
# `args` and `cfg` are module-level names read by later cells.
args = Args()
print('Initializing Chat')
cfg = Config(args)


In [None]:
# Device placement: every model goes to the GPU selected by `args.gpu_id`
# (these were placeholder-less f-strings that always produced "cuda:0").
cuda_llm = f"cuda:{args.gpu_id}"
cuda_detector = f"cuda:{args.gpu_id}"
cuda_sam = f"cuda:{args.gpu_id}"

# Checkpoint locations: GroundingDINO weights come from the Hugging Face Hub,
# the SAM weights are expected on local disk.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_sam_path = "output_models/sam_vit_h_4b8939.pth"

# Pick the GroundingDINO config/checkpoint pair matching the requested backbone.
# NOTE: `ckpt_filenmae` (sic) is read by a later cell, so the name is kept as is.
if args.dino_version == "swinb":
    config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinB_cfg.py"
    ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
else:
    config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    ckpt_filenmae = "groundingdino_swint_ogc.pth"

# System prompt: read it from `args.system_path` when given, otherwise fall
# back to the built-in three-step instruction.
if args.system_path:
    # Explicit encoding so the prompt is decoded the same way on every platform.
    with open(args.system_path, 'r', encoding="utf-8") as file:
        system_message = file.read()
    print(f"system message: \n {system_message}")
else:
    system_message = "You must strictly answer the question step by step:\n" \
                     "Step-1. describe the given image in detail.\n" \
                     "Step-2. find all the objects related to user input, and concisely explain why these objects meet the requirement.\n" \
                     "Step-3. list out all related objects existing in the image strictly as follows: <Therefore the answer is: [object_names]>.\n" \
                     "Complete all 3 steps as detailed as possible.\n" \
                     "You must finish the answer with complete sentences."

# Conversation template copied for every new chat.
CONV_VISION = Conversation(
    system=system_message,
    roles=("Human", "Assistant"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)


def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
    """Build a GroundingDINO model and load weights downloaded from the Hugging Face Hub.

    Args:
        model_config_path: Path to the model config file read by ``SLConfig.fromfile``.
        repo_id: Hub repository that hosts the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.
        device: Device the returned model is moved to. Defaults to ``'cpu'``,
            which matches the previous behaviour (the argument used to be ignored).

    Returns:
        The model in eval mode, placed on ``device``.
    """
    model_args = SLConfig.fromfile(model_config_path)
    model = build_model(model_args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY NOTE: torch.load unpickles the file; only use checkpoints from
    # repositories you trust.
    # Weights are always loaded onto CPU first and moved afterwards.
    checkpoint = torch.load(cache_file, map_location='cpu')
    # strict=False: mismatching keys are reported in `log` instead of raising.
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    model.eval()
    return model.to(device)


def image_transform_grounding(init_image):
    """Preprocess an image for the detector.

    Applies resize (target 800, ``max_size=1333``), tensor conversion and
    normalisation through the GroundingDINO transform pipeline.

    Returns:
        ``(init_image, tensor)``: the untouched input and the transformed
        image (presumably laid out as 3 x H x W -- TODO confirm).
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    pipeline = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    # The pipeline takes (image, target) pairs; no target is supplied here.
    tensor, _unused_target = pipeline(init_image, None)
    return init_image, tensor


def image_transform_grounding_for_vis(init_image):
    """Resize an image the same way as the detector input, without tensor conversion.

    Only the resize step (target 800, ``max_size=1333``) is applied, so the
    result can be used for drawing at the resolution the detector saw.
    """
    resize_only = T.Compose([T.RandomResize([800], max_size=1333)])
    # The pipeline takes (image, target) pairs; no target is supplied here.
    resized, _unused_target = resize_only(init_image, None)
    return resized


def print_format(message):
    """Print a 20-asterisk rule, then ``message`` surrounded by blank lines."""
    rule = "*" * 20
    print(rule, f"\n{message}\n", sep="\n")


def list_to_str(cat_list, sep=". "):
    """Concatenate the names in ``cat_list``, each followed by ``sep``, minus the final character.

    With the default separator this yields e.g. ``"cat. dog."``: the trailing
    space is dropped while the final period stays. An empty list gives ``""``.
    """
    joined = "".join(cat + sep for cat in cat_list)
    # Exactly one trailing character is removed, whatever the separator length.
    return joined[:-1]


def run_grounding(input_image, llm_message_original, box_threshold, text_threshold):
    """Parse object names from the last LLM reply, then detect, segment and draw them.

    NOTE: a later cell of this notebook defines another
    ``run_grounding(img, categories)``; once that cell has run, it replaces
    this function in the notebook namespace.

    Args:
        input_image: PIL image; it is converted to RGB before use.
        llm_message_original: Sequence of LLM replies; only the last one is parsed.
        box_threshold: Forwarded unchanged to ``predict``.
        text_threshold: Forwarded unchanged to ``predict``.

    Returns:
        ``(image_with_box_sam, categories)``: a PIL image with the overlays and
        the category prompt string (``""`` when no answer sentence was found).
    """
    init_image = input_image.convert("RGB")
    _, image_tensor = image_transform_grounding(init_image)
    image_pil = image_transform_grounding_for_vis(init_image)
    response_message = llm_message_original[-1]
    print_format(f"From run grounding, oringinal response message {response_message}")

    # Accepted lead-in phrases, tried in order; `names` matches the list that follows.
    prefixes = (
        r"(?i)therefore,?\s+the\s+answer\s+is:?[\s\[\],]*",
        r"(?i)therefore,?\s+the\s+target\s+objects?\s+are:?[\s\[\],]*",
    )
    names = r"(\w+[\s,]*)+([ ,]\w+[\s,]*)*"

    categories = ""
    for prefix in prefixes:
        match = re.search(prefix + names, response_message)
        if match is None:
            continue
        # Strip the lead-in phrase and any brackets, leaving "a, b, c".
        listed = re.sub(prefix, "", match.group(0))
        listed = re.sub(r"[\[\]]", "", listed)
        candidates = (name.strip() for name in listed.split(','))
        # De-duplicate while keeping first-seen order (a set would make the
        # prompt order vary between runs) and drop empty entries.
        cat_list = list(dict.fromkeys(name for name in candidates if name))
        categories = list_to_str(cat_list)
        print_format(f"Detected categores: {categories}")
        break
    else:
        print_format("No match found.")

    # Run grounding with the parsed categories as the text prompt.
    boxes, logits, phrases = predict(detector, image_tensor, categories, box_threshold, text_threshold,
                                     device=cuda_detector)
    print_format(f"Detector predicted phrases {phrases}")

    frame = np.asarray(image_pil)
    annotated_frame = annotate(image_source=frame, boxes=boxes, logits=logits, phrases=phrases)
    # Reverse the channel axis (presumably BGR -> RGB -- TODO confirm against `annotate`).
    annotated_frame = annotated_frame[..., ::-1]
    segmented_frame_masks = run_segment(frame, sam_predictor, boxes=boxes)
    # Only the first entry of the mask batch is drawn.
    annotated_frame_with_mask = draw_masks(segmented_frame_masks[0], annotated_frame)
    image_with_box_sam = Image.fromarray(annotated_frame_with_mask)

    return image_with_box_sam, f"{categories}"
def run_segment(image, sam_model, boxes):
    """Predict one mask per box with the given SAM predictor.

    Args:
        image: Array whose ``shape`` unpacks as ``(H, W, channels)``.
        sam_model: Predictor exposing ``set_image``, ``transform`` and ``predict_torch``.
        boxes: Boxes in (cx, cy, w, h) form; they are scaled by the image
            width/height here, so presumably normalised to [0, 1] -- TODO confirm.

    Returns:
        The predicted masks, moved to CPU.
    """
    sam_model.set_image(image)
    height, width, _ = image.shape

    # Convert to corner format and scale to pixel coordinates.
    pixel_scale = torch.Tensor([width, height, width, height])
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * pixel_scale

    sam_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(cuda_sam), image.shape[:2])
    prediction = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes,
        multimask_output=False,
    )
    masks, _, _ = prediction
    return masks.cpu()
  

def draw_masks(masks, image):
    """Alpha-composite every mask onto ``image``, one palette colour per mask.

    Args:
        masks: Iterable of masks; each is reshaped to ``(h, w, 1)`` using its
            last two dimensions and must support ``.cpu().numpy()`` after scaling.
        image: Array accepted by ``Image.fromarray``.

    Returns:
        RGBA ``numpy`` array of the composited picture.
    """
    canvas = Image.fromarray(image).convert("RGBA")
    palette = ColorPalette.default()
    for position, mask in enumerate(masks):
        rgb = palette.by_idx(position).as_rgb()
        # RGB scaled to [0, 1] plus a fixed 0.6 alpha.
        rgba = np.array([rgb[channel] / 255 for channel in range(3)] + [0.6])

        height, width = mask.shape[-2:]
        overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
        overlay_u8 = (overlay.cpu().numpy() * 255).astype(np.uint8)

        overlay_pil = Image.fromarray(overlay_u8).convert("RGBA")
        canvas = Image.alpha_composite(canvas, overlay_pil)
    return np.array(canvas)
def draw_mask(mask, image):
    """Alpha-composite a single mask onto ``image`` using the first palette colour.

    Args:
        mask: Mask reshaped to ``(h, w, 1)`` using its last two dimensions;
            must support ``.cpu().numpy()`` after scaling.
        image: Array accepted by ``Image.fromarray``.

    Returns:
        RGBA ``numpy`` array of the composited picture.
    """
    canvas = Image.fromarray(image).convert("RGBA")
    rgb = ColorPalette.default().by_idx(0).as_rgb()
    # RGB scaled to [0, 1] plus a fixed 0.6 alpha.
    rgba = np.array([rgb[channel] / 255 for channel in range(3)] + [0.6])

    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    overlay_u8 = (overlay.cpu().numpy() * 255).astype(np.uint8)

    overlay_pil = Image.fromarray(overlay_u8).convert("RGBA")
    return np.array(Image.alpha_composite(canvas, overlay_pil))

def setup_seeds(config):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    The seed is ``config.run_cfg.seed`` offset by the value of ``get_rank()``.
    """
    seed = config.run_cfg.seed + get_rank()

    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
### sam 30384m
# NOTE(review): "30384m" above presumably records GPU memory in use after this cell -- confirm.
# Build the GroundingDINO detector from Hub weights (loaded on CPU), then move it to its device.
detector = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
detector = detector.to(cuda_detector)
# Build SAM from the local checkpoint, move it to its device and wrap it in a predictor.
sam_predictor = SamPredictor(build_sam(checkpoint=ckpt_sam_path).to(cuda_sam))
# Tracker used by semi_track. Note it is given the generic "cuda" device rather than
# one of the cuda_* names defined earlier.
xmem = BaseTracker("./output_models/XMem-s012.pth", device="cuda")

In [None]:

### language model 27830m
# NOTE(review): "27830m" above presumably records GPU memory in use after this cell -- confirm.
# Resolve the model class named in the config via the registry and instantiate it on the LLM device.
model_config = cfg.model_cfg
model_config.device_8bit = cuda_llm
model_cls = registry.get_model_class(model_config.arch)
model_llm = model_cls.from_config(model_config).to(cuda_llm)

# Build the visual processor from the `coco_align` dataset's train-time processor config.
vis_processor_cfg = cfg.datasets_cfg.coco_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
# `chat` is the module-level handle used by gpt() to talk to the model.
chat = Chat(model_llm, vis_processor, device=cuda_llm)
print_format('Initialization Finished')

In [None]:
def gpt(img, user_message):
    """Ask the multimodal chat model which object to track in ``img``.

    A fresh copy of ``CONV_VISION`` is used for every call, so no history is
    shared between calls.

    Args:
        img: Image handed to ``chat.upload_img``.
        user_message: The user's request; the fixed tracking instruction is
            appended after it.

    Returns:
        The first element of what ``chat.answer`` returns.
    """
    # NOTE(review): the instruction ends with "My question:" but is appended
    # *after* the user message -- confirm the intended order.
    user_prompt = "\nThis is the first frame of a video. You need to answer me and decide a object to track according to my question. End the answer by listing out the only target object to my question strictly as follows: <Therefore the answer is: [object_names]>. \nMy question:"

    conversation = CONV_VISION.copy()
    uploaded_images = []
    chat.upload_img(img, conversation, uploaded_images)
    chat.ask(user_message + user_prompt, conversation)

    # Beam search without sampling.
    generation_options = dict(
        num_beams=5,
        temperature=0,
        length_penalty=0.5,
        do_sample=False,
        max_new_tokens=300,
        max_length=2000,
    )
    answer = chat.answer(conv=conversation, img_list=uploaded_images, **generation_options)
    return answer[0]

def get_categories(message):
    """Extract the object names listed in the LLM's final answer sentence.

    Two lead-in phrases are recognised, case-insensitively and in this order:
    "Therefore(,) the answer is" and "Therefore(,) the target object(s) are".
    The names after the phrase are split on commas, stripped, de-duplicated
    in first-seen order, and joined with ``list_to_str``.

    Args:
        message: Full text of the LLM reply.

    Returns:
        The joined category string, or ``""`` when neither phrase is found.
    """
    prefixes = (
        r"(?i)therefore,?\s+the\s+answer\s+is:?[\s\[\],]*",
        r"(?i)therefore,?\s+the\s+target\s+objects?\s+are:?[\s\[\],]*",
    )
    names = r"(\w+[\s,]*)+([ ,]\w+[\s,]*)*"

    for prefix in prefixes:
        match = re.search(prefix + names, message)
        if match is None:
            continue
        # Strip the lead-in phrase and any brackets, leaving "a, b, c".
        listed = re.sub(prefix, "", match.group(0))
        listed = re.sub(r"[\[\]]", "", listed)
        candidates = (name.strip() for name in listed.split(','))
        # dict.fromkeys keeps first-seen order, so the prompt is reproducible
        # (a set would reorder names between runs); empty entries are dropped.
        cat_list = list(dict.fromkeys(name for name in candidates if name))
        categories = list_to_str(cat_list)
        print_format(f"Detected categores: {categories}")
        return categories
    return ""

def run_grounding(img,categories):
    """Detect ``categories`` in ``img`` and segment the highest-scoring detection.

    Both detector thresholds are fixed at 0.25. Only the box with the largest
    logit is passed on to SAM.

    Previously this function also built an annotated preview that was never
    returned, and indexed the one-element mask batch with the detection
    index, which raised IndexError whenever the best box was not the first
    one. That dead code has been removed.

    Args:
        img: PIL image of the frame to search.
        categories: Text prompt for the detector.

    Returns:
        The masks returned by ``run_segment`` for the single best box; the
        caller reads the mask at ``[0][0]``.

    Raises:
        ValueError: If the detector returns no boxes.
    """
    box_threshold, text_threshold = 0.25, 0.25
    _, image_tensor = image_transform_grounding(img)
    image_pil = image_transform_grounding_for_vis(img)

    # Run grounding.
    boxes, logits, phrases = predict(detector, image_tensor, categories, box_threshold, text_threshold,
                                     device=cuda_detector)
    print_format(f"Detector predicted phrases {phrases}")

    if len(boxes) == 0:
        raise ValueError(f"No objects detected for categories {categories!r}")

    # Keep only the detection with the highest score, as a batch of one box.
    best = int(logits.argmax())
    best_box = boxes[best].reshape((1, -1))
    segmented_frame_masks = run_segment(np.asarray(image_pil), sam_predictor, boxes=best_box)
    return segmented_frame_masks

def semi_track(images, template_mask):
    """Propagate ``template_mask`` from the first frame through ``images``.

    The frames passed in are the ones tracked. Previously the argument was
    silently replaced by the global ``video_state["painted_images"]``, so the
    function only worked after a later cell had defined that global and it
    ignored the caller's frame selection.

    Args:
        images: Sequence of frames as arrays; the first frame is the one the
            mask belongs to.
        template_mask: Array mask for the first frame; it is cast to uint8 and
            resized (nearest neighbour) to the first frame's height and width.

    Returns:
        List with the painted image returned by the tracker for every frame;
        empty when ``images`` is empty.
    """
    xmem.clear_memory()
    if len(images) == 0:
        return []

    # cv2.resize takes the target size as (width, height).
    frame_height, frame_width = images[0].shape[0], images[0].shape[1]
    template_mask = cv2.resize(template_mask.astype(np.uint8), (frame_width, frame_height),
                               interpolation=cv2.INTER_NEAREST)

    painted_images = []
    for index, frame in enumerate(tqdm(images, desc="Tracking image")):
        if index == 0:
            # The first frame seeds the tracker with the template mask.
            _mask, _logit, painted_image = xmem.track(frame, template_mask)
        else:
            _mask, _logit, painted_image = xmem.track(frame)
        painted_images.append(painted_image)
    return painted_images

In [None]:
def track(images, user_message):
    """Run the whole pipeline: ask the LLM, ground and segment the target, then track it.

    Args:
        images: Sequence of frames as arrays; the first one is shown to the LLM
            and used for detection.
        user_message: Natural-language request describing what to track.

    Returns:
        The painted frames produced by ``semi_track``.
    """
    first_frame = Image.fromarray(images[0]).convert('RGB')

    response_message = gpt(first_frame, user_message)
    print_format(f"From run grounding, oringinal response message {response_message}")

    categories = get_categories(response_message)
    masks = run_grounding(first_frame, categories)

    # First mask of the first box, as a NumPy array.
    template_mask = masks[0][0].cpu().numpy()
    return semi_track(images, template_mask)


In [None]:
import gradio as gr
import argparse
import gdown
import cv2
import numpy as np
import os
import sys
import psutil
import time

# Decode the example video into a list of RGB frames.
video_path = "./examples/elon.mp4"
frames = []
user_name = time.time()
cap = None
try:
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes to BGR; the rest of the notebook works with RGB.
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # Stop early instead of exhausting host memory on long/high-res videos.
        if psutil.virtual_memory().percent > 90:
            print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
            break
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
    print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
finally:
    # Always release the decoder handle, even when reading failed part-way.
    if cap is not None:
        cap.release()

# Fail with a clear message rather than an IndexError on frames[0] below.
if not frames:
    raise RuntimeError("No frames could be read from {}".format(video_path))

image_size = (frames[0].shape[0],frames[0].shape[1])
# initialize video_state (a plain dict; a gr.State wrapper is not needed outside a Gradio app)
video_state = {
    "user_name": user_name,
    "video_name": os.path.split(video_path)[-1],
    "origin_images": frames,
    "painted_images": frames.copy(),
    # One independent mask per frame (list multiplication would alias a single array).
    "masks": [np.zeros(image_size, np.uint8) for _ in frames],
    "logits": [None]*len(frames),
    "select_frame_number": 0,
    "fps": fps
    }
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
img = Image.fromarray(video_state["origin_images"][video_state["select_frame_number"]]).convert("RGB")

In [None]:
# Frames from the selected start frame through the end of the video.
following_frames = video_state["origin_images"][video_state["select_frame_number"]:]

In [None]:
# Run the full pipeline (LLM answer -> category parsing -> grounding/segmentation -> tracking).
painted_images = track(following_frames, "I want to track elon.")

In [None]:
import torchvision

def generate_video_from_frames(frames, output_path, fps=30):
    """
    Generates a video from a list of frames.

    The parent directory of ``output_path`` is created when it is missing.
    A bare file name (no directory component) is written to the current
    working directory; previously that case raised because ``os.makedirs("")``
    was called.

    Args:
        frames (list of numpy arrays): The frames to include in the video.
        output_path (str): The path to save the generated video.
        fps (int, optional): The frame rate of the output video. Defaults to 30.

    Returns:
        str: ``output_path``, unchanged.
    """
    frames = torch.from_numpy(np.asarray(frames))
    output_dir = os.path.dirname(output_path)
    if output_dir:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)
    torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
    return output_path

In [None]:
# Write the tracked frames to disk.
# NOTE(review): fps is hard-coded to 30 here while the source rate is stored in
# video_state["fps"] -- confirm which one is intended.
generate_video_from_frames(painted_images,"./examples/elon_output.mp4", fps=30)