diff --git a/MedSAM/MedSAMLite/MedSAMLite.py b/MedSAM/MedSAMLite/MedSAMLite.py index f9e31a1..873ab24 100644 --- a/MedSAM/MedSAMLite/MedSAMLite.py +++ b/MedSAM/MedSAMLite/MedSAMLite.py @@ -33,8 +33,7 @@ try: from numpysocket import NumpySocket except: - slicer.util.pip_install('numpysocket') - from numpysocket import NumpySocket + pass # no installation anymore, shorter plugin load # # MedSAMLite @@ -167,11 +166,12 @@ def setup(self) -> None: # Create logic class. Logic implements all computations that should be possible to run # in batch mode, without a graphical user interface. self.logic = MedSAMLiteLogic() + self.logic.widget = self DEPENDENCIES_AVAILABLE = False # Initial Dependency Setup - if os.path.isfile('medsam_info') and os.path.isfile(os.path.join(open('medsam_info', 'r').read(), 'server_essentials/server.py')): + if self.is_setting_available(): try: from segment_anything.modeling import MaskDecoder DEPENDENCIES_AVAILABLE = True @@ -199,7 +199,7 @@ def setup(self) -> None: return - self.logic.server_dir = os.path.join(open('medsam_info', 'r').read(), 'server_essentials') + self.logic.server_dir = os.path.join(self.read_setting(), 'server_essentials') # Load widget from .ui file (created by Qt Designer). # Additional widgets can be instantiated manually and added to self.layout. @@ -218,6 +218,7 @@ def setup(self) -> None: # print(self.ui.clbtnOperation.layout().__dict__) self.ui.clbtnOperation.layout().addWidget(self.editor) # self.layout.addWidget(self.editor) + # self.editor.currentSegmentIDChanged.connect(print) ############################################################################ # Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's @@ -355,6 +356,29 @@ def setParameterNode(self, inputParameterNode: Optional[MedSAMLiteParameterNode] # Note: in the .ui file, a Qt dynamic property called "SlicerParameterName" is set on each # ui element that needs connection. self._parameterNodeGuiTag = self._parameterNode.connectGui(self.ui) + + def is_setting_available(self): + if not (os.path.isfile('.medsam_info') or os.path.isfile(os.path.expanduser('~/.medsam_info'))): + return False + + setting_file = '.medsam_info' if os.path.isfile('.medsam_info') else os.path.expanduser('~/.medsam_info') + server_file = os.path.join(self.read_setting(), 'server_essentials/server.py') + + return os.path.isfile(server_file) + + def read_setting(self): + setting_file = '.medsam_info' if os.path.isfile('.medsam_info') else os.path.expanduser('~/.medsam_info') + with open(setting_file, 'r') as settings: + server_essentials_root = settings.read() + return server_essentials_root + + def write_setting(self, setting): + try: + with open('.medsam_info', 'w') as settings: + settings.write(setting) + except: + with open(os.path.expanduser('~/.medsam_info'), 'w') as settings: + settings.write(setting) # @@ -378,6 +402,7 @@ class MedSAMLiteLogic(ScriptedLoadableModuleLogic): timer = None progressbar = None server_dir = None + widget = None def __init__(self) -> None: """ @@ -432,7 +457,7 @@ def pip_install_wrapper(self, command, event): def download_wrapper(self, url, filename, event): with urlopen(url) as r: # self.setTotalProgress.emit(int(r.info()["Content-Length"])) - with open(filename, "ab") as f: + with open(filename, "wb") as f: while True: chunk = r.read(1024) if chunk is None: @@ -452,8 +477,7 @@ def install_dependencies(self, ctk_path): return print('Installation will happen in %s'%ctk_path.currentPath) - with open('medsam_info', 'w') as fp: - fp.write(ctk_path.currentPath) + self.widget.write_setting(ctk_path.currentPath) file_url = 'https://github.com/rasakereh/medsam-3dslicer/raw/master/server_essentials.zip' filename = os.path.join(ctk_path.currentPath, 'server_essentials.zip') @@ -467,7 +491,7 @@ def install_dependencies(self, ctk_path): 'Numpy Socket': 'numpysocket', 'FastAPI': 'fastapi', 'Uvicorn': 'uvicorn', - 'MedSam Lite Server': '-e %s'%(self.server_dir) + 'MedSam Lite Server': '-e "%s"'%(self.server_dir) } for dependency in dependencies: @@ -595,10 +619,14 @@ def showSegmentation(self, segmentation_mask): segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode") slicer.modules.segmentations.logic().ExportAllSegmentsToLabelmapNode(loaded_seg_file, segment_volume, slicer.vtkSegmentation.EXTENT_REFERENCE_GEOMETRY) - if self.segment_res_group is None: - self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") - self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node) - slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, self.segment_res_group) + current_seg_group = self.widget.editor.segmentationNode() + if current_seg_group is None: + if self.segment_res_group is None: + self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") + self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node) + current_seg_group = self.segment_res_group + + slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group) slicer.mrmlScene.RemoveNode(segment_volume) slicer.mrmlScene.RemoveNode(loaded_seg_file) diff --git a/server/server.py b/server/server.py index a60fe9e..67e434a 100644 --- a/server/server.py +++ b/server/server.py @@ -66,7 +66,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, height, width): PARENT_DIR = os.path.dirname(os.path.abspath(__file__)) MedSAM_CKPT_PATH = os.path.join(PARENT_DIR , "medsam_lite.pth") MEDSAM_IMG_INPUT_SIZE = 1024 -device = 'cpu'#torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") H, W = None, None image = None @@ -260,6 +260,8 @@ def get_image(wmin: int, wmax: int): def get_bbox1024(mask_1024, bbox_shift=3): y_indices, x_indices = np.where(mask_1024 > 0) + if x_indices.shape[0] == 0: + return np.array([0, 0, bbox_shift, bbox_shift]) x_min, x_max = np.min(x_indices), np.max(x_indices) y_min, y_max = np.min(y_indices), np.max(y_indices) # add perturbation to bounding box coordinates diff --git a/server_essentials.zip b/server_essentials.zip index a3cad2f..a8698f4 100644 Binary files a/server_essentials.zip and b/server_essentials.zip differ