Skip to content

Commit 3aa06fa

Browse files
committed
working version
1 parent b5f7445 commit 3aa06fa

File tree

3 files changed

+94
-88
lines changed

3 files changed

+94
-88
lines changed

images/stock_example.jpg

-64.8 KB
Loading

main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from segment import Segment


def main():
    """Segment the stock example image and display the captioned result."""
    text = "YESONTI YESONTI PANLU CHESTHURU RA NAA PHOTO THONI"
    seg = Segment("images/stock_example.jpg", text)
    # find_segments() has to run first: it stores the segmentation map on the
    # instance, and vis_segmentation() reads that stored map.
    seg.find_segments()
    seg.vis_segmentation()


# Guard the entry point so importing this module does not run the model or
# open a plot window as a side effect.
if __name__ == "__main__":
    main()

segment.py

Lines changed: 92 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,70 @@
55
import numpy as np
66
from matplotlib import gridspec
77
from matplotlib import pyplot as plt
8+
import matplotlib.font_manager as fm
9+
import matplotlib.patheffects as path_effects
10+
import matplotlib
11+
# make sure Tk backend is used
12+
matplotlib.use("TkAgg")
13+
from PIL import Image
814

915

10-
LABEL_NAMES = np.asarray([
11-
'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
12-
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
13-
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'
14-
])
16+
def add_meme_text(_str, image, font_path='fonts/debussy.ttf', color='blue',
                  size=20, stroke_color='white', stroke_width=6):
    """Draw an outlined, centred caption on the current matplotlib axes.

    The caption is anchored at 50% of the image width and 80% of the image
    height, so it lands in the lower part of the picture.

    Args:
        _str: Caption text to draw.
        image: Array whose ``shape[0]`` / ``shape[1]`` give the height and
            width used to position the caption.
        font_path: Path to the TrueType font file used for the caption.
        color: Fill colour of the glyphs.
        size: Font size passed to ``plt.text``.
        stroke_color: Colour of the outline drawn behind the glyphs.
        stroke_width: Line width of that outline.

    Returns:
        The text artist created by ``plt.text`` (the original returned
        nothing, so existing callers that ignore the result are unaffected).
    """
    height, width = image.shape[:2]

    # Meme font
    prop = fm.FontProperties(fname=font_path)

    # Centre-justified text that wraps instead of running off the figure
    caption = plt.text(width * 0.5, height * 0.8, _str, color=color,
                       fontproperties=prop, multialignment='center',
                       wrap=True, ha='center', va='center', size=size)

    # Stroke first, then the normal fill on top, which yields outlined text
    caption.set_path_effects([
        path_effects.Stroke(linewidth=stroke_width, foreground=stroke_color),
        path_effects.Normal(),
    ])
    return caption
30+
31+
32+
def make_transparent(src):
    """Return a 4-channel copy of ``src`` whose dark background is transparent.

    The alpha channel comes from thresholding the grayscale version of
    ``src`` at 0: pixels whose gray value is above 0 get alpha 255, the rest
    get alpha 0.

    Args:
        src: 3-channel image (it is split into exactly three planes).

    Returns:
        Image with the three original channels, in their original order,
        followed by the computed alpha channel.
    """
    gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
    _, alpha = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY)
    first, second, third = cv2.split(src)
    # The second positional parameter of cv2.merge is the output array, not a
    # channel count, so the stray ``4`` the original passed there is dropped.
    return cv2.merge([first, second, third, alpha])
39+
40+
41+
def smooth_edges(image, ksize=35):
    """Smooth the edges of ``image`` with a median blur.

    Args:
        image: Image to blur.
        ksize: Aperture size handed to ``cv2.medianBlur``. Defaults to 35,
            the value that used to be hard-coded. NOTE(review): OpenCV
            documents that this must be odd and greater than 1 -- confirm
            before passing other values.

    Returns:
        The blurred image returned by ``cv2.medianBlur``.
    """
    return cv2.medianBlur(image, ksize)
45+
46+
47+
def contour_mask(image, mask, color=(255, 255, 255), thickness=8):
    """Outline the external contours of ``mask`` on ``image``.

    Args:
        image: Image to draw on. It is modified in place.
        mask: Image whose Canny edges define the contours.
        color: Colour of the outline. Defaults to white.
        thickness: Outline thickness passed to ``cv2.drawContours``.

    Returns:
        The same ``image`` object, with the contours drawn on it.
    """
    # Get edges through Canny edge detection
    edged = cv2.Canny(mask, 30, 200)

    # cv2.findContours returns (contours, hierarchy) in OpenCV 2.x/4.x but
    # (image, contours, hierarchy) in 3.x; the contours are always the
    # second-to-last element, so index from the end instead of unpacking.
    contours = cv2.findContours(edged, cv2.RETR_EXTERNAL,
                                cv2.CHAIN_APPROX_NONE)[-2]

    # -1 signifies drawing all contours
    cv2.drawContours(image, contours, -1, color, thickness)
    return image
58+
59+
60+
def mask_out(src, mask):
    """Keep ``src`` only where ``mask`` allows it.

    Computes ``mask - (mask - src)`` using two ``cv2.subtract`` calls.
    NOTE(review): for 8-bit unsigned inputs, where cv2.subtract saturates at
    zero, this presumably amounts to the element-wise minimum of ``mask``
    and ``src`` -- confirm the input dtypes against the caller.

    Args:
        src: Image to be masked.
        mask: Mask image; passed to ``cv2.subtract`` together with ``src``.

    Returns:
        The masked image.
    """
    removed = cv2.subtract(mask, src)
    return cv2.subtract(mask, removed)
1564

1665

1766
class Segment:
18-
def __init__(self, img_path):
67+
def __init__(self, img_path, meme_text, dpi=127.68):
1968
# Variables
2069
self.segment_map = None
70+
self.meme_text = meme_text
71+
self.dpi = dpi
2172

2273
# Initialize TF model
2374
print("Using model: " + settings.model_file)
@@ -30,104 +81,58 @@ def __init__(self, img_path):
3081
self.image = cv2.imread(img_path)
3182
print("Loaded image")
3283

33-
# Additional variables
34-
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
35-
self.FULL_COLOR_MAP = self.label_to_color_image(FULL_LABEL_MAP)
36-
3784
def find_segments(self):
    """Run the model on the loaded image and cache the resulting class map.

    Returns:
        The value produced by ``run_tf.output_to_classes``; the same value
        is also stored on ``self.segment_map``.
    """
    raw_outputs = run_tf.run_model(self.interpreter, self.image)
    class_map = run_tf.output_to_classes(raw_outputs)
    self.segment_map = class_map
    return self.segment_map
4188

42-
# Code taken from Google Colab
43-
# https://github.com/tensorflow/models/blob/master/research/deeplab/deeplab_demo.ipynb
44-
@staticmethod
45-
def create_pascal_label_colormap():
46-
"""Creates a label colormap used in PASCAL VOC segmentation benchmark.
47-
48-
Returns:
49-
A Colormap for visualizing segmentation results.
50-
"""
51-
colormap = np.zeros((256, 3), dtype=int)
52-
ind = np.arange(256, dtype=int)
53-
54-
for shift in reversed(range(8)):
55-
for channel in range(3):
56-
colormap[:, channel] |= ((ind >> channel) & 1) << shift
57-
ind >>= 3
58-
59-
return colormap
60-
61-
# Code taken from Google Colab
62-
# https://github.com/tensorflow/models/blob/master/research/deeplab/deeplab_demo.ipynb
63-
def label_to_color_image(self, label):
64-
"""Adds color defined by the dataset colormap to the label.
65-
66-
Args:
67-
label: A 2D array with integer type, storing the segmentation label.
68-
69-
Returns:
70-
result: A 2D array with floating type. The element of the array
71-
is the color indexed by the corresponding element in the input label
72-
to the PASCAL color map.
73-
74-
Raises:
75-
ValueError: If label is not of rank 2 or its value is larger than color
76-
map maximum entry.
77-
"""
78-
if label.ndim != 2:
79-
raise ValueError('Expect 2-D input label')
80-
81-
colormap = self.create_pascal_label_colormap()
82-
83-
if np.max(label) >= len(colormap):
84-
raise ValueError('label value too large.')
85-
86-
return colormap[label]
87-
88-
# Code taken from Google Colab
89+
# Code taken partially from Google Colab
8990
# https://github.com/tensorflow/models/blob/master/research/deeplab/deeplab_demo.ipynb
9091
def vis_segmentation(self):
    """Render the segmented subject with a white outline and a meme caption.

    Steps: scale the stored segmentation map to a 0-255 mask, resize it to
    the input image, median-smooth it, mask the image with it, outline the
    mask contours, resize the result to 512x512, draw ``self.meme_text``,
    add an alpha channel, then save the figure to ``test.png`` and show it.

    Raises:
        RuntimeError: If no segmentation map has been stored yet, i.e.
            ``find_segments()`` was not called first.
    """
    if self.segment_map is None:
        raise RuntimeError(
            "No segmentation map available; call find_segments() first")

    # Current details
    image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB).astype(np.uint8)
    seg_map = np.asarray(self.segment_map)

    # Figure sized so that it is 512x512 pixels at the configured dpi
    plt.figure(figsize=(512 / self.dpi, 512 / self.dpi), dpi=self.dpi)
    grid_spec = gridspec.GridSpec(1, 1)

    # Scale the mask to 0-255. An all-zero map (nothing segmented) is left
    # untouched rather than divided by zero, which would produce NaNs.
    peak = np.max(seg_map)
    if peak > 0:
        seg_map = (seg_map / peak) * 255
    seg_image = seg_map.astype('uint8')

    # mask_out() combines the mask with a 3-channel image, so the mask must
    # be 3-channel too. NOTE(review): the map appears to be a 2-D label
    # array, for which the gray->BGR conversion is the valid one; the
    # original RGB->BGR conversion is kept for inputs that already carry
    # colour channels -- confirm the shape run_tf.output_to_classes returns.
    if seg_image.ndim == 2:
        seg_image = cv2.cvtColor(seg_image, cv2.COLOR_GRAY2BGR)
    else:
        seg_image = cv2.cvtColor(seg_image, cv2.COLOR_RGB2BGR)

    # Resize segmentation mask to the input image. The third positional
    # parameter of cv2.resize is the output array, so the channel count the
    # original passed there is dropped.
    height, width = image.shape[:2]
    res_seg_image = cv2.resize(seg_image, (width, height))

    # Postprocess mask
    res_seg_image = smooth_edges(res_seg_image)

    # Mask out image
    res_image = mask_out(image, res_seg_image)

    # Contour mask
    res_image = contour_mask(res_image, res_seg_image)

    plt.subplot(grid_spec[0])

    # Resize to standard 512x512 before display, then add the caption
    res_image = cv2.resize(res_image, (512, 512))
    add_meme_text(self.meme_text, res_image)

    # Make image transparent
    res_image = make_transparent(res_image)

    # Show image
    plt.imshow(res_image)
    plt.axis('off')

    # Save before show(): the figure is written while it is still current
    plt.savefig('test.png', transparent=True)
    plt.show()
133138

0 commit comments

Comments
 (0)