# Lab 11 - Part 1
Visual Saliency Detection, K-Means Segmentation, and GraphCut Segmentation

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import urllib.request
import os
from IPython.display import Video, display
from google.colab.patches import cv2_imshow

In [None]:
# Fetch the sample image and video used by the cells below.
# A file that is already present in the working directory is not downloaded again.
sample_image_url = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/lena.jpg'
sample_video_url = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi'

for local_name, remote_url in (
    ('sample_image.jpg', sample_image_url),
    ('sample_video.avi', sample_video_url),
):
    if os.path.exists(local_name):
        continue
    urllib.request.urlretrieve(remote_url, local_name)

In [None]:
print("--- Visual Saliency Detection ---")

# Load the sample image.
image = cv2.imread('sample_image.jpg')
if image is None:
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError("Could not read 'sample_image.jpg'; run the download cell first.")

# Create the fine-grained static saliency detector. The `saliency` object is
# reused by the video cells further down.
saliency = cv2.saliency.StaticSaliencyFineGrained_create()
success, saliencyMap = saliency.computeSaliency(image)
if not success:
    raise RuntimeError("Saliency computation failed for 'sample_image.jpg'.")

# Scale by 255 and cast to 8-bit for display (this is a rescale, not a threshold).
# NOTE(review): presumably the raw map is floating point in [0, 1] -- confirm.
saliencyMap = (saliencyMap * 255).astype("uint8")

# cv2_imshow renders inline in the notebook, so no HighGUI window is created;
# waiting for a key press or destroying windows is therefore unnecessary here.
cv2_imshow(image)
cv2_imshow(saliencyMap)

In [None]:
# Saliency Detection on Video (inline preview)
print("Running Saliency Detection on Video...")

# Only every Nth frame is previewed: rendering two inline images for every
# frame floods the cell output and bloats the saved notebook.
PREVIEW_EVERY_N_FRAMES = 50

cap = cv2.VideoCapture('sample_video.avi')
if not cap.isOpened():
    raise FileNotFoundError("Could not open 'sample_video.avi'; run the download cell first.")

try:
    frame_index = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            # End of stream (or a read error): stop iterating.
            break
        if frame_index % PREVIEW_EVERY_N_FRAMES == 0:
            # Saliency is computed only for the frames that are actually shown.
            success, saliencyMap = saliency.computeSaliency(frame)
            if success:
                saliencyMap = (saliencyMap * 255).astype("uint8")
                cv2_imshow(frame)
                cv2_imshow(saliencyMap)
        frame_index += 1
finally:
    # Release the capture even if saliency computation or display raises.
    # No key polling / window teardown: cv2_imshow draws inline, so there is
    # no HighGUI window that could receive a 'q' key press.
    cap.release()

In [None]:
# Saliency Detection for Full Video and Save
print("Running Saliency Detection on Video and saving output...")
cap = cv2.VideoCapture('sample_video.avi')
if not cap.isOpened():
    raise FileNotFoundError("Could not open 'sample_video.avi'; run the download cell first.")

out = None
try:
    # Mirror the source geometry and frame rate so the output plays at the
    # same speed as the input instead of at a hard-coded rate.
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):
        # Some containers/backends do not report a frame rate; use the previous default.
        fps = 20.0

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    # Final argument False (isColor): the saliency maps are single-channel frames.
    out = cv2.VideoWriter('saliency_output.avi', fourcc, fps, frame_size, False)
    if not out.isOpened():
        raise RuntimeError("Could not open 'saliency_output.avi' for writing.")

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        success, saliencyMap = saliency.computeSaliency(frame)
        if not success:
            # Skip frames whose saliency map could not be computed rather than
            # writing an invalid frame.
            continue
        saliencyMap = (saliencyMap * 255).astype("uint8")
        out.write(saliencyMap)
finally:
    # Always release both handles so the output file is finalized on disk.
    cap.release()
    if out is not None:
        out.release()

print("Displaying Saliency Output Video:")
# NOTE(review): browsers may not be able to play XVID/AVI inline -- confirm the
# embedded player actually renders in the target environment.
display(Video('saliency_output.avi', embed=True, width=600, height=400))

In [None]:
# K-Means Segmentation
print("--- Unsupervised Image Segmentation with K-Means ---")
img = cv2.imread('sample_image.jpg')
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError("Could not read 'sample_image.jpg'; run the download cell first.")
# Convert BGR -> RGB so matplotlib displays the colours correctly.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Flatten to one row per pixel (3 colour channels) in float32, as cv2.kmeans expects.
Z = img.reshape((-1, 3))
Z = np.float32(Z)

# Stop after 10 iterations or once the centres move by less than 1.0.
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
K = 4

# KMEANS_RANDOM_CENTERS draws the initial centres from OpenCV's global RNG;
# seed it so that re-running the cell reproduces the same segmentation.
cv2.setRNGSeed(42)
ret, label, center = cv2.kmeans(Z, K, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)

# Replace every pixel by the (8-bit) centre of the cluster it was assigned to.
center = np.uint8(center)
res = center[label.flatten()]
segmented_image = res.reshape((img.shape))

plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.imshow(img)
plt.title('Original Image')
plt.axis('off')
plt.subplot(122)
plt.imshow(segmented_image)
plt.title('Segmented Image (KMeans)')
plt.axis('off')
plt.show()

In [None]:
# Graph Cut Segmentation
print("--- Graph Cut Segmentation ---")
img = cv2.imread('sample_image.jpg')
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError("Could not read 'sample_image.jpg'; run the download cell first.")

# GrabCut working buffers: a per-pixel label mask plus the background and
# foreground model arrays that the algorithm updates in place.
mask = np.zeros(img.shape[:2], np.uint8)
bgdModel = np.zeros((1, 65), np.float64)
fgdModel = np.zeros((1, 65), np.float64)

# Initial rectangle as (x, y, width, height); run 5 iterations seeded from it.
rect = (50, 50, 400, 400)
cv2.grabCut(img, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)

# Collapse the labels to a binary mask: definite and probable background -> 0,
# everything else (foreground labels) -> 1.
is_background = (mask == cv2.GC_PR_BGD) | (mask == cv2.GC_BGD)
mask2 = np.where(is_background, 0, 1).astype('uint8')
# Broadcast the mask over the colour channels to black out the background.
segmented_img = img * mask2[:, :, np.newaxis]

plt.figure(figsize=(12, 5))
plt.subplot(121)
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.title('Original Image')
plt.axis('off')
plt.subplot(122)
plt.imshow(cv2.cvtColor(segmented_img, cv2.COLOR_BGR2RGB))
plt.title('GraphCut Segmented Image')
plt.axis('off')
plt.show()

In [None]:
# Graph Visualization
# Draws a 2D grid graph whose dimensions are the image height and width
# divided by 10, using the `img` loaded by the Graph Cut cell.
print("Visualizing simplified graph structure...")
height, width = img.shape[0], img.shape[1]
G = nx.grid_2d_graph(height // 10, width // 10)

# Place node (row, col) at x=col, y=-row so that row 0 is drawn at the top.
pos = dict((node, (node[1], -node[0])) for node in G.nodes())

plt.figure(figsize=(8, 8))
nx.draw(
    G,
    pos=pos,
    with_labels=False,
    node_size=20,
    node_color='lightblue',
    edge_color='gray',
)
plt.title('Simplified Graph Structure (Downsampled)')
plt.show()
print("--- DONE! ---")