55import numpy as np
66from matplotlib import gridspec
77from matplotlib import pyplot as plt
8+ import matplotlib .font_manager as fm
9+ import matplotlib .patheffects as path_effects
10+ import matplotlib
11+ # make sure Tk backend is used
12+ matplotlib .use ("TkAgg" )
13+ from PIL import Image
814
915
10- LABEL_NAMES = np .asarray ([
11- 'background' , 'aeroplane' , 'bicycle' , 'bird' , 'boat' , 'bottle' , 'bus' ,
12- 'car' , 'cat' , 'chair' , 'cow' , 'diningtable' , 'dog' , 'horse' , 'motorbike' ,
13- 'person' , 'pottedplant' , 'sheep' , 'sofa' , 'train' , 'tv'
14- ])
16+ def add_meme_text (_str , image ):
17+ # Find location to add
18+ _width = image .shape [1 ]
19+ _height = image .shape [0 ]
20+
21+ # Meme font
22+ prop = fm .FontProperties (fname = 'fonts/debussy.ttf' )
23+
24+ # Justify text
25+ text = plt .text (_width * 0.5 , _height * 0.8 , _str , color = 'blue' , fontproperties = prop ,
26+ multialignment = 'center' , wrap = True ,
27+ ha = 'center' , va = 'center' , size = 20 )
28+ text .set_path_effects ([path_effects .Stroke (linewidth = 6 , foreground = 'white' ),
29+ path_effects .Normal ()])
30+
31+
32+ def make_transparent (src ):
33+ tmp = cv2 .cvtColor (src , cv2 .COLOR_BGR2GRAY )
34+ _ , alpha = cv2 .threshold (tmp , 0 , 255 , cv2 .THRESH_BINARY )
35+ b , g , r = cv2 .split (src )
36+ rgba = [b , g , r , alpha ]
37+ dst = cv2 .merge (rgba , 4 )
38+ return dst
39+
40+
41+ def smooth_edges (image ):
42+ # Median blur - smooth edges
43+ img = cv2 .medianBlur (image , 35 )
44+ return img
45+
46+
47+ def contour_mask (image , mask ):
48+ # Get edges through Canny edge detection
49+ edged = cv2 .Canny (mask , 30 , 200 )
50+ # Finding Contours
51+ contours , hierarchy = cv2 .findContours (edged ,
52+ cv2 .RETR_EXTERNAL , cv2 .CHAIN_APPROX_NONE )
53+ # Draw contours
54+ # -1 signifies drawing all contours
55+ cv2 .drawContours (image , contours , - 1 , (255 , 255 , 255 ), 8 )
56+
57+ return image
58+
59+
60+ def mask_out (src , mask ):
61+ _mask_out = cv2 .subtract (mask , src )
62+ _mask_out = cv2 .subtract (mask , _mask_out )
63+ return _mask_out
1564
1665
1766class Segment :
18- def __init__ (self , img_path ):
67+ def __init__ (self , img_path , meme_text , dpi = 127.68 ):
1968 # Variables
2069 self .segment_map = None
70+ self .meme_text = meme_text
71+ self .dpi = dpi
2172
2273 # Initialize TF model
2374 print ("Using model: " + settings .model_file )
@@ -30,104 +81,58 @@ def __init__(self, img_path):
3081 self .image = cv2 .imread (img_path )
3182 print ("Loaded image" )
3283
33- # Additional variables
34- FULL_LABEL_MAP = np .arange (len (LABEL_NAMES )).reshape (len (LABEL_NAMES ), 1 )
35- self .FULL_COLOR_MAP = self .label_to_color_image (FULL_LABEL_MAP )
36-
3784 def find_segments (self ):
3885 output_tensors = run_tf .run_model (self .interpreter , self .image )
3986 self .segment_map = run_tf .output_to_classes (output_tensors )
4087 return self .segment_map
4188
42- # Code taken from Google Colab
43- # https://github.com/tensorflow/models/blob/master/research/deeplab/deeplab_demo.ipynb
44- @staticmethod
45- def create_pascal_label_colormap ():
46- """Creates a label colormap used in PASCAL VOC segmentation benchmark.
47-
48- Returns:
49- A Colormap for visualizing segmentation results.
50- """
51- colormap = np .zeros ((256 , 3 ), dtype = int )
52- ind = np .arange (256 , dtype = int )
53-
54- for shift in reversed (range (8 )):
55- for channel in range (3 ):
56- colormap [:, channel ] |= ((ind >> channel ) & 1 ) << shift
57- ind >>= 3
58-
59- return colormap
60-
61- # Code taken from Google Colab
62- # https://github.com/tensorflow/models/blob/master/research/deeplab/deeplab_demo.ipynb
63- def label_to_color_image (self , label ):
64- """Adds color defined by the dataset colormap to the label.
65-
66- Args:
67- label: A 2D array with integer type, storing the segmentation label.
68-
69- Returns:
70- result: A 2D array with floating type. The element of the array
71- is the color indexed by the corresponding element in the input label
72- to the PASCAL color map.
73-
74- Raises:
75- ValueError: If label is not of rank 2 or its value is larger than color
76- map maximum entry.
77- """
78- if label .ndim != 2 :
79- raise ValueError ('Expect 2-D input label' )
80-
81- colormap = self .create_pascal_label_colormap ()
82-
83- if np .max (label ) >= len (colormap ):
84- raise ValueError ('label value too large.' )
85-
86- return colormap [label ]
87-
88- # Code taken from Google Colab
89+ # Code taken partially from Google Colab
8990 # https://github.com/tensorflow/models/blob/master/research/deeplab/deeplab_demo.ipynb
9091 def vis_segmentation (self ):
9192 """Visualizes input image, segmentation map and overlay view."""
9293 # Current details
9394 image = cv2 .cvtColor (self .image , cv2 .COLOR_BGR2RGB ).astype (np .uint8 )
9495 seg_map = self .segment_map
9596
96- plt .figure (figsize = (15 , 5 ) )
97- grid_spec = gridspec .GridSpec (1 , 4 , width_ratios = [ 6 , 6 , 6 , 1 ] )
97+ plt .figure (figsize = (512 / self . dpi , 512 / self . dpi ), dpi = self . dpi )
98+ grid_spec = gridspec .GridSpec (1 , 1 )
9899
99- plt .subplot (grid_spec [0 ])
100- plt .imshow (image )
101- plt .axis ('off' )
102- plt .title ('input image' )
100+ # Process segmentation mask
101+ # Scale seg_map
102+ seg_map = (seg_map / np .max (seg_map )) * 255
103+ seg_image = Image .fromarray (seg_map .astype ('uint8' ))
104+ seg_image = cv2 .cvtColor (np .array (seg_image ), cv2 .COLOR_RGB2BGR )
103105
104- plt .subplot (grid_spec [1 ])
105- seg_image = self .label_to_color_image (seg_map ).astype (np .uint8 )
106- # Blur seg_image
107- seg_image = cv2 .GaussianBlur (seg_image , (5 , 5 ), 0 )
108- plt .imshow (seg_image )
109- plt .axis ('off' )
110- plt .title ('segmentation map' )
111-
112- plt .subplot (grid_spec [2 ])
106+ # Resize segmentation mask
113107 _width = image .shape [1 ]
114108 _height = image .shape [0 ]
115109 _num_channels = 3
116- res_seg_image = cv2 .resize (seg_image , (_width , _height ), _num_channels )
117- plt .imshow (image )
118- plt .imshow (res_seg_image , alpha = 0.7 )
110+ res_seg_image = cv2 .resize (seg_image , (_width , _height ),
111+ _num_channels )
112+
113+ # Postprocess mask
114+ res_seg_image = smooth_edges (res_seg_image )
115+
116+ # Mask out image
117+ res_image = mask_out (image , res_seg_image )
118+
119+ # Contour mask
120+ res_image = contour_mask (res_image , res_seg_image )
121+
122+ # Add text
123+ plt .subplot (grid_spec [0 ])
124+
125+ # Resize to standard 512x512 before display
126+ res_image = cv2 .resize (res_image , (512 , 512 ))
127+ add_meme_text (self .meme_text , res_image )
128+
129+ # Make image transparent
130+ res_image = make_transparent (res_image )
131+
132+ # Show image
133+ plt .imshow (res_image )
119134 plt .axis ('off' )
120- plt .title ('segmentation overlay' )
121-
122- unique_labels = np .unique (seg_map )
123- ax = plt .subplot (grid_spec [3 ])
124- plt .imshow (
125- self .FULL_COLOR_MAP [unique_labels ].astype (np .uint8 ), interpolation = 'nearest' )
126- ax .yaxis .tick_right ()
127- plt .yticks (range (len (unique_labels )), LABEL_NAMES [unique_labels ])
128- plt .xticks ([], [])
129- ax .tick_params (width = 0.0 )
130- plt .grid ('off' )
131- plt .show ()
132135
136+ plt .savefig ('test.png' , transparent = True )
137+ plt .show ()
133138
0 commit comments