Skip to content

Commit

Permalink
Merge pull request #103 from jjhickmon/main
Browse files Browse the repository at this point in the history
Added support for MPS Apple Silicon and quality of life improvements to the GUI
  • Loading branch information
hkchengrex authored Aug 21, 2023
2 parents 23f3681 + 3371c72 commit 1e1f2ec
Show file tree
Hide file tree
Showing 11 changed files with 200 additions and 91 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ output/
.vscode/
workspace/
run*.sh
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
4 changes: 2 additions & 2 deletions docs/DEMO.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ python interactive_demo.py --video [path to the video] --num_objects 4

* Use the slider to change the current frame. "Play Video" automatically progresses the video.
* Select interaction type: "scribble", "click", or "free". Both scribble and "free" (free-hand drawing) modify an existing mask. Using "click" on an existing object mask (i.e., a mask from propagation or other interaction methods) will reset the mask. This is because f-BRS does not take an existing mask as input.
* Select the target object using the number keys. "1" corresponds to the first object, etc. You need to specify the maximum number of objects when you start the program through the command line.
* Select the target object using the number keys or the object dial. "1" corresponds to the first object, etc. You need to specify the maximum number of objects when you start the program through the command line. On Mac, you can use ctrl+number to select the object.
* Use propagate forward/backward to let XMem do the job. Pause when correction is needed. It will only stop automatically when it hits the end of the video.
* Make sure all objects are correctly labeled before propagating. The program doesn't care which object you have interacted with -- it treats everything as user-provided inputs. Not labelling an object implicitly means that it is part of the background.
* The memory bank might be "polluted" by bad memory frames. Feel free to hit clear memory to erase that. Propagation runs faster with a small memory bank.
Expand All @@ -53,6 +53,6 @@ python interactive_demo.py --video [path to the video] --num_objects 4
- Make sure you specified `--num_objects`. We ignore object IDs that exceed `num_objects`.
2. The GUI feels slow!
- The GUI needs to read/write images and masks on the fly. Ideally this can be implemented with multiple threads with look-ahead but I didn't. The overheads will be smaller if you place the `workspace` on an SSD. You can also use a RAM disk. `eval.py` will almost certainly be faster.
- It takes more time to process more objects. This depends on `num_objects`, but not the actual number of objects that the user has annotated. *This does not mean that running time is directly proportional to the number of objects. There is significant shared computation.*
- It takes more time to process more objects. This depends on `num_objects`, not the actual number of objects that the user has annotated. *This does not mean that running time is directly proportional to the number of objects. There is significant shared computation.*
3. Can I run this on a remote server?
- X11 forwarding should be possible. I have not tried this and would love to know if it works for you.
2 changes: 1 addition & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@
prob = torch.flip(prob, dims=[-1])

# Probability mask -> index mask
out_mask = torch.argmax(prob, dim=0)
out_mask = torch.max(prob, dim=0).indices
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)

if args.save_scores:
Expand Down
9 changes: 8 additions & 1 deletion inference/interact/fbrs/controller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import torch
try:
from torch import mps
except:
pass

from ..fbrs.inference import clicker
from ..fbrs.inference.predictors import get_predictor
Expand Down Expand Up @@ -35,7 +39,10 @@ def add_click(self, x, y, is_positive):
click = clicker.Click(is_positive=is_positive, coords=(y, x))
self.clicker.add_click(click)
pred = self.predictor.get_prediction(self.clicker)
torch.cuda.empty_cache()
if self.device.type == 'cuda':
torch.cuda.empty_cache()
elif self.device.type == 'mps':
mps.empty_cache()

if self.probs_history:
self.probs_history.append((self.probs_history[-1][0], pred))
Expand Down
194 changes: 131 additions & 63 deletions inference/interact/gui.py

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions inference/interact/gui_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar)
from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import (QBoxLayout, QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar)


def create_parameter_box(min_val, max_val, text, step=1, callback=None):
Expand All @@ -10,12 +10,12 @@ def create_parameter_box(min_val, max_val, text, step=1, callback=None):
dial.setMaximumWidth(150)
dial.setMinimum(min_val)
dial.setMaximum(max_val)
dial.setAlignment(Qt.AlignRight)
dial.setAlignment(Qt.AlignmentFlag.AlignRight)
dial.setSingleStep(step)
dial.valueChanged.connect(callback)

label = QLabel(text)
label.setAlignment(Qt.AlignRight)
label.setAlignment(Qt.AlignmentFlag.AlignRight)

layout.addWidget(label)
layout.addWidget(dial)
Expand All @@ -29,12 +29,18 @@ def create_gauge(text):
gauge = QProgressBar()
gauge.setMaximumHeight(28)
gauge.setMaximumWidth(200)
gauge.setAlignment(Qt.AlignCenter)
gauge.setAlignment(Qt.AlignmentFlag.AlignCenter)

label = QLabel(text)
label.setAlignment(Qt.AlignRight)
label.setAlignment(Qt.AlignmentFlag.AlignRight)

layout.addWidget(label)
layout.addWidget(gauge)

return gauge, layout


def apply_to_all_children_widget(layout, func):
    """Call ``func`` on every widget held directly by ``layout``.

    Only the layout's immediate items are visited; nested layouts are
    deliberately not descended into.

    Args:
        layout: Object exposing ``count()`` and ``itemAt(index)``, where each
            returned item exposes ``widget()``.
        func: Callable invoked once per widget, in index order.
    """
    for index in range(layout.count()):
        widget = layout.itemAt(index).widget()
        # Items that do not wrap a widget (e.g. spacers or nested layouts)
        # report None; skip them rather than handing None to the callback.
        if widget is not None:
            func(widget)
2 changes: 1 addition & 1 deletion inference/interact/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def end_path(self):
self.curr_path = [[] for _ in range(self.K + 1)]

def predict(self):
self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1).cuda()
self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1)
# self.out_prob = torch.from_numpy(self.drawn_map).float().cuda()
# self.out_prob, _ = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:])
# self.out_prob = aggregate_sbg(self.out_prob, keep_bg=True)
Expand Down
20 changes: 14 additions & 6 deletions inference/interact/interactive_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def image_to_torch(frame: np.ndarray, device='cuda'):
return frame_norm, frame

def torch_prob_to_numpy_mask(prob):
mask = torch.argmax(prob, dim=0)
mask = torch.max(prob, dim=0).indices
mask = mask.cpu().numpy().astype(np.uint8)
return mask

Expand All @@ -26,16 +26,24 @@ def index_numpy_to_one_hot_torch(mask, num_classes):
"""
Some constants for visualization
"""
# Pick the device that holds the visualization constants below:
# CUDA first, then Apple-Silicon MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    # torch.backends.mps may be absent on older PyTorch builds, so look it up
    # explicitly instead of hiding every possible error behind a bare except.
    _mps_backend = getattr(torch.backends, "mps", None)
    if _mps_backend is not None and _mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
# scales for better visualization
color_map_np = (color_map_np.astype(np.float32)*1.5).clip(0, 255).astype(np.uint8)
color_map = color_map_np.tolist()
if torch.cuda.is_available():
color_map_torch = torch.from_numpy(color_map_np).cuda() / 255
color_map_torch = torch.from_numpy(color_map_np).to(device) / 255

grayscale_weights = np.array([[0.3,0.59,0.11]]).astype(np.float32)
if torch.cuda.is_available():
grayscale_weights_torch = torch.from_numpy(grayscale_weights).cuda().unsqueeze(0)
grayscale_weights_torch = torch.from_numpy(grayscale_weights).to(device).unsqueeze(0)

def get_visualization(mode, image, mask, layer, target_object):
if mode == 'fade':
Expand Down Expand Up @@ -112,7 +120,7 @@ def overlay_davis_torch(image, mask, alpha=0.5, fade=False):
# Changes the image in-place to avoid copying
image = image.permute(1, 2, 0)
im_overlay = image
mask = torch.argmax(mask, dim=0)
mask = torch.max(mask, dim=0).indices

colored_mask = color_map_torch[mask]
foreground = image*alpha + (1-alpha)*colored_mask
Expand Down
1 change: 1 addition & 0 deletions inference/interact/s2m_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, s2m_net:S2M, num_objects, ignore_class, device='cuda:0'):
self.device = device

def interact(self, image, prev_mask, scr_mask):
print(self.device)
image = image.to(self.device, non_blocking=True)
prev_mask = prev_mask.unsqueeze(0)

Expand Down
38 changes: 28 additions & 10 deletions interactive_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import os
from os import path
# fix for Windows
if 'QT_QPA_PLATFORM_PLUGIN_PATH' not in os.environ:
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = ''
Expand All @@ -17,15 +18,21 @@
from inference.interact.fbrs_controller import FBRSController
from inference.interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M

from PyQt5.QtWidgets import QApplication
from PyQt6.QtWidgets import QApplication
from inference.interact.gui import App
from inference.interact.resource_manager import ResourceManager
from contextlib import nullcontext

torch.set_grad_enabled(False)

if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")

if __name__ == '__main__':

# Arguments parsing
parser = ArgumentParser()
parser.add_argument('--model', default='./saves/XMem.pth')
Expand Down Expand Up @@ -64,32 +71,43 @@
help='Resize the shorter side to this size. -1 to use original resolution. ')
args = parser.parse_args()

# create temporary workspace if not specified
config = vars(args)
config['enable_long_term'] = True
config['enable_long_term_count_usage'] = True

with torch.cuda.amp.autocast(enabled=not args.no_amp):
if config["workspace"] is None:
if config["images"] is not None:
basename = path.basename(config["images"])
elif config["video"] is not None:
basename = path.basename(config["video"])[:-4]
else:
raise NotImplementedError(
'Either images, video, or workspace has to be specified')

config["workspace"] = path.join('./workspace', basename)

with torch.cuda.amp.autocast(enabled=not args.no_amp) if device.type == 'cuda' else nullcontext():
# Load our checkpoint
network = XMem(config, args.model).cuda().eval()
network = XMem(config, args.model, map_location=device).to(device).eval()

# Loads the S2M model
if args.s2m_model is not None:
s2m_saved = torch.load(args.s2m_model)
s2m_model = S2M().cuda().eval()
s2m_saved = torch.load(args.s2m_model, map_location=device)
s2m_model = S2M().to(device).eval()
s2m_model.load_state_dict(s2m_saved)
else:
s2m_model = None

s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255)
s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255, device=device)
if args.fbrs_model is not None:
fbrs_controller = FBRSController(args.fbrs_model)
fbrs_controller = FBRSController(args.fbrs_model, device=device)
else:
fbrs_controller = None

# Manages most IO
resource_manager = ResourceManager(config)

app = QApplication(sys.argv)
ex = App(network, resource_manager, s2m_controller, fbrs_controller, config)
sys.exit(app.exec_())
ex = App(network, resource_manager, s2m_controller, fbrs_controller, config, device)
sys.exit(app.exec())
2 changes: 1 addition & 1 deletion requirements_demo.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
PyQt5
PyQt6
Cython
scipy

0 comments on commit 1e1f2ec

Please sign in to comment.