Skip to content

Commit

Permalink
node denoising
Browse files Browse the repository at this point in the history
  • Loading branch information
gaudelbijay committed Aug 15, 2023
1 parent caaf770 commit c4eb10e
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions src/diffusion_model/node_denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self,
train_dir='../data/raw/',
test_dir='../data/test/',
batch_size = 1,
device = "cpu") -> None:
device = "cuda") -> None:

self.image_size = image_size

Expand Down Expand Up @@ -65,23 +65,35 @@ def model_init(self,):
)
self.bridge = CvBridge()

def denoise(self, image, t=torch.tensor([5])):
def denoise(self, image, t=torch.tensor([1])):
t = t.to(self.device)
predicted_image = self.diffusion.to(self.device).model_predictions(image, t)[1]
return predicted_image

def call_back(self, image):
try:
cv_image = self.bridge.imgmsg_to_cv2(image, desired_encoding="bgr8")
image_tensor = torch.from_numpy(cv_image.transpose(2,0,1)).float().unsqueeze(0)

# publish directly without denoising
# image_tensor = self.bridge.cv2_to_imgmsg(cv_image, "bgr8")
# self.pub_denoised_img.publish(image_tensor)

# denoise the image
image_tensor = torch.from_numpy(cv_image.transpose(2,0,1)).float().unsqueeze(0)
denoised_image_tensor = self.denoise(image_tensor.to(self.device))

# convert denoised image back to opencv format
if denoised_image_tensor.is_cuda:
denoised_image_tensor = denoised_image_tensor.cpu().detach()

# # convert denoised image back to opencv format
denoised_cv_image = denoised_image_tensor.squeeze(0).byte().numpy().transpose(1,2,0)
cv2.namedWindow('Display', cv2.WINDOW_NORMAL)
cv2.imshow('Display', denoised_cv_image)
cv2.waitKey(1)

# show image in opencv
# cv2.namedWindow('Display', cv2.WINDOW_NORMAL)
# cv2.imshow('Display', denoised_cv_image)
# cv2.waitKey(1)

#convert image to ros message
denoised_ros_image_msg = self.bridge.cv2_to_imgmsg(denoised_cv_image, "bgr8")
#publish the image
self.pub_denoised_img.publish(denoised_ros_image_msg)
Expand Down

0 comments on commit c4eb10e

Please sign in to comment.