Skip to content

Commit

Permalink
minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rasakereh committed Jan 29, 2024
1 parent 5fa1b29 commit 84cef19
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 13 deletions.
52 changes: 40 additions & 12 deletions MedSAM/MedSAMLite/MedSAMLite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
try:
from numpysocket import NumpySocket
except:
slicer.util.pip_install('numpysocket')
from numpysocket import NumpySocket
pass # no installation anymore, shorter plugin load

#
# MedSAMLite
Expand Down Expand Up @@ -167,11 +166,12 @@ def setup(self) -> None:
# Create logic class. Logic implements all computations that should be possible to run
# in batch mode, without a graphical user interface.
self.logic = MedSAMLiteLogic()
self.logic.widget = self

DEPENDENCIES_AVAILABLE = False

# Initial Dependency Setup
if os.path.isfile('medsam_info') and os.path.isfile(os.path.join(open('medsam_info', 'r').read(), 'server_essentials/server.py')):
if self.is_setting_available():
try:
from segment_anything.modeling import MaskDecoder
DEPENDENCIES_AVAILABLE = True
Expand Down Expand Up @@ -199,7 +199,7 @@ def setup(self) -> None:

return

self.logic.server_dir = os.path.join(open('medsam_info', 'r').read(), 'server_essentials')
self.logic.server_dir = os.path.join(self.read_setting(), 'server_essentials')

# Load widget from .ui file (created by Qt Designer).
# Additional widgets can be instantiated manually and added to self.layout.
Expand All @@ -218,6 +218,7 @@ def setup(self) -> None:
# print(self.ui.clbtnOperation.layout().__dict__)
self.ui.clbtnOperation.layout().addWidget(self.editor)
# self.layout.addWidget(self.editor)
# self.editor.currentSegmentIDChanged.connect(print)
############################################################################

# Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's
Expand Down Expand Up @@ -355,6 +356,29 @@ def setParameterNode(self, inputParameterNode: Optional[MedSAMLiteParameterNode]
# Note: in the .ui file, a Qt dynamic property called "SlicerParameterName" is set on each
# ui element that needs connection.
self._parameterNodeGuiTag = self._parameterNode.connectGui(self.ui)

def is_setting_available(self):
if not (os.path.isfile('.medsam_info') or os.path.isfile(os.path.expanduser('~/.medsam_info'))):
return False

setting_file = '.medsam_info' if os.path.isfile('.medsam_info') else os.path.expanduser('~/.medsam_info')
server_file = os.path.join(self.read_setting(), 'server_essentials/server.py')

return os.path.isfile(server_file)

def read_setting(self):
setting_file = '.medsam_info' if os.path.isfile('.medsam_info') else os.path.expanduser('~/.medsam_info')
with open(setting_file, 'r') as settings:
server_essentials_root = settings.read()
return server_essentials_root

def write_setting(self, setting):
try:
with open('.medsam_info', 'w') as settings:
settings.write(setting)
except:
with open(os.path.expanduser('~/.medsam_info'), 'w') as settings:
settings.write(setting)


#
Expand All @@ -378,6 +402,7 @@ class MedSAMLiteLogic(ScriptedLoadableModuleLogic):
timer = None
progressbar = None
server_dir = None
widget = None

def __init__(self) -> None:
"""
Expand Down Expand Up @@ -432,7 +457,7 @@ def pip_install_wrapper(self, command, event):
def download_wrapper(self, url, filename, event):
with urlopen(url) as r:
# self.setTotalProgress.emit(int(r.info()["Content-Length"]))
with open(filename, "ab") as f:
with open(filename, "wb") as f:
while True:
chunk = r.read(1024)
if chunk is None:
Expand All @@ -452,8 +477,7 @@ def install_dependencies(self, ctk_path):
return

print('Installation will happen in %s'%ctk_path.currentPath)
with open('medsam_info', 'w') as fp:
fp.write(ctk_path.currentPath)
self.widget.write_setting(ctk_path.currentPath)

file_url = 'https://github.com/rasakereh/medsam-3dslicer/raw/master/server_essentials.zip'
filename = os.path.join(ctk_path.currentPath, 'server_essentials.zip')
Expand All @@ -467,7 +491,7 @@ def install_dependencies(self, ctk_path):
'Numpy Socket': 'numpysocket',
'FastAPI': 'fastapi',
'Uvicorn': 'uvicorn',
'MedSam Lite Server': '-e %s'%(self.server_dir)
'MedSam Lite Server': '-e "%s"'%(self.server_dir)
}

for dependency in dependencies:
Expand Down Expand Up @@ -595,10 +619,14 @@ def showSegmentation(self, segmentation_mask):
segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode")
slicer.modules.segmentations.logic().ExportAllSegmentsToLabelmapNode(loaded_seg_file, segment_volume, slicer.vtkSegmentation.EXTENT_REFERENCE_GEOMETRY)

if self.segment_res_group is None:
self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node)
slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, self.segment_res_group)
current_seg_group = self.widget.editor.segmentationNode()
if current_seg_group is None:
if self.segment_res_group is None:
self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node)
current_seg_group = self.segment_res_group

slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group)
slicer.mrmlScene.RemoveNode(segment_volume)
slicer.mrmlScene.RemoveNode(loaded_seg_file)

Expand Down
4 changes: 3 additions & 1 deletion server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, height, width):
PARENT_DIR = os.path.dirname(os.path.abspath(__file__))
MedSAM_CKPT_PATH = os.path.join(PARENT_DIR , "medsam_lite.pth")
MEDSAM_IMG_INPUT_SIZE = 1024
device = 'cpu'#torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

H, W = None, None
image = None
Expand Down Expand Up @@ -260,6 +260,8 @@ def get_image(wmin: int, wmax: int):

def get_bbox1024(mask_1024, bbox_shift=3):
y_indices, x_indices = np.where(mask_1024 > 0)
if x_indices.shape[0] == 0:
return np.array([0, 0, bbox_shift, bbox_shift])
x_min, x_max = np.min(x_indices), np.max(x_indices)
y_min, y_max = np.min(y_indices), np.max(y_indices)
# add perturbation to bounding box coordinates
Expand Down
Binary file modified server_essentials.zip
Binary file not shown.

0 comments on commit 84cef19

Please sign in to comment.