## Environment Set-up

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import cv2
import glob
import os

## Image pre-processing

Extract frames from videos

In [None]:
# Root directory that is walked for videos; each sub-directory's basename is
# used as the "video number" prefix of the generated names.
path = 'AWP/SAM_Clustering/videos'
# Directory that receives the extracted PNG frames
frames_path = 'AWP/SAM_Clustering/frames'

# Make sure the frame directory exists
os.makedirs(frames_path, exist_ok=True)

# Names of every visited video file, in processing order
video_names = []

# Global settings for frame extraction
NUM_FRAMES = 9  # maximum number of frames to extract from each video
TIME_INTERVAL = 0.1  # time interval in seconds between each extracted frame


def _extract_frames(video_path, video_name):
    """Save and display frames taken from the 45%-50% window of one video.

    Frames are written to ``frames_path`` as ``{video_name}_{n}.png`` with
    ``n`` starting at 1. At most ``NUM_FRAMES`` frames are written, and only
    frames whose index is a multiple of the frame skip derived from
    ``TIME_INTERVAL`` are considered.

    Args:
        video_path: Path of the video file to open with OpenCV.
        video_name: Prefix used for the generated frame file names.

    Returns:
        The number of frames that were written.
    """
    vidcap = cv2.VideoCapture(video_path)
    extracted_frames = 0
    try:
        if not vidcap.isOpened():
            print(f"Could not open video: {video_path}")
            return extracted_frames

        # Convert the desired time interval into a frame stride. Clamp to 1 so
        # a low or missing FPS value cannot cause a modulo-by-zero below.
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        frame_skip = max(1, int(fps * TIME_INTERVAL))

        # Only frames between 45% and 50% of the video are candidates
        total_frames = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
        start_frame = int(total_frames * 0.45)
        end_frame = int(total_frames * 0.50)

        frame_count = 0
        success, image = vidcap.read()
        # Stop as soon as the window is passed or enough frames were saved;
        # decoding the rest of the video could not produce any more output.
        while success and frame_count <= end_frame and extracted_frames < NUM_FRAMES:
            if frame_count >= start_frame and frame_count % frame_skip == 0:
                frame_name = f"{video_name}_{extracted_frames + 1}.png"
                frame_path = os.path.join(frames_path, frame_name)
                cv2.imwrite(frame_path, image)

                # OpenCV frames are BGR; convert the in-memory frame once for
                # matplotlib instead of re-reading the PNG (which is already
                # RGB and would end up with red and blue swapped).
                plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                plt.show()

                extracted_frames += 1
            success, image = vidcap.read()
            frame_count += 1
    finally:
        # Always free the decoder / file handle, even if something failed
        vidcap.release()
    return extracted_frames


# Loop through all directories and subdirectories
for subdir, _dirs, files in os.walk(path):
    # Skip directories that contain no files
    if not files:
        continue
    # The last path component is used as the video number
    video_number = os.path.basename(subdir)
    # Sort files so the 1-based index in the name is deterministic
    files.sort()
    for index, file in enumerate(files, start=1):
        video_name = f"{video_number}_{index}_mp4"
        video_names.append(video_name)
        _extract_frames(os.path.join(subdir, file), video_name)


## Segmentation

In [None]:
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent (alpha 0.6) color overlay.

    The last two dimensions of ``mask`` are taken as height and width. The
    overlay color is a fixed blue unless ``random_color`` is true, in which
    case a random RGB color is drawn.
    """
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4) to get an RGBA image
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    ``coords`` is indexed as ``[:, 0]`` for x and ``[:, 1]`` for y, and
    ``labels`` selects the rows belonging to each group.
    """
    # Label-1 points are drawn first, then label-0 points
    for label_value, marker_color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=marker_color,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Draw ``box`` (indexed as x0, y0, x1, y1) on ``ax`` as a green outline."""
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)


In [None]:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

# PyTorch is imported only to probe for a usable GPU
import torch

# Checkpoint file and matching model-registry key for the SAM model
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Prefer the GPU, but fall back to the CPU so this cell does not fail on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model from the checkpoint and move it to the selected device
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Predictor used by the segmentation cells below
predictor = SamPredictor(sam)

In [None]:
# Directory that receives the masked images written by process_image
masks_path = 'AWP/SAM_Clustering/masks'
# Make sure the masks directory exists
os.makedirs(masks_path, exist_ok=True)

In [None]:
def process_image(image_path, input_point=None, input_label=None):
    """Segment one image with the SAM predictor and save the masked result.

    The image is segmented twice: a first multi-mask prediction from the
    prompt points, then a second single-mask prediction that feeds the logits
    of the best-scoring first-pass mask back in as ``mask_input``. The final
    mask is displayed over the image and the masked image is written to
    ``masks_path`` as ``mask_<original file name>``.

    Args:
        image_path: Path of the image file to segment.
        input_point: Optional array of (x, y) prompt points. Defaults to the
            three points originally hard-coded in this notebook.
        input_label: Optional array with one label per prompt point. Defaults
            to 1 for every point.

    Returns:
        None. Unreadable images are reported and skipped.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure with None instead of raising
        print(f"Skipping unreadable image: {image_path}")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    predictor.set_image(image)

    if input_point is None:
        input_point = np.array([[1000, 630], [950, 570], [1060, 600]])
    if input_label is None:
        input_label = np.ones(len(input_point), dtype=int)

    # First pass: several candidate masks with their scores and logits
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

    # Second pass: a single mask, prompted with the best first-pass logits
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )

    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.show()

    # Apply the (single) mask to the image; pixels outside the mask become 0
    masked_image = image * masks[0][:, :, None]

    # Convert the masked image back to BGR color scheme for saving
    masked_image = cv2.cvtColor(masked_image.astype(np.uint8), cv2.COLOR_RGB2BGR)

    # Save the masked image; cv2.imwrite returns False instead of raising
    output_path = os.path.join(masks_path, f'mask_{os.path.basename(image_path)}')
    if not cv2.imwrite(output_path, masked_image):
        print(f"Could not write masked image: {output_path}")


In [None]:
# The directory where the frames are stored
# NOTE(review): the extraction cell writes frames to `frames_path`
# ('AWP/SAM_Clustering/frames'), not to this location — confirm both refer to
# the same directory in the environment this notebook runs in.
frame_dir = '/content/frames'

# File extensions treated as images; anything else in the directory is skipped
# so stray files (hidden files, checkpoints, ...) are never sent to the model.
image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

# Process each image in the frame directory, in a deterministic (sorted) order
for image_file in sorted(os.listdir(frame_dir)):
    image_path = os.path.join(frame_dir, image_file)
    if os.path.isfile(image_path) and image_file.lower().endswith(image_extensions):
        process_image(image_path)

## Fine-tuning

## Clustering

In [None]:
%matplotlib inline
# for loading/processing the images
from tensorflow.keras.utils import load_img
from tensorflow.keras.utils import img_to_array
from keras.applications.vgg16 import preprocess_input
import cv2

# models
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.applications import EfficientNetB7

# clustering and dimension reduction
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
from sklearn.cluster import AgglomerativeClustering

# image augmentation
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# for everything else
import os
import numpy as np
import matplotlib.pyplot as plt
# from random import randint
# import pandas as pd
# import pickle
# from sklearn.neighbors import NearestNeighbors

# Directory holding the masked images to cluster.
# NOTE(review): process_image writes to `masks_path`
# ('AWP/SAM_Clustering/masks') — confirm it is the same directory as this one.
path = r"/content/masks"
# Work from inside that directory so images can be opened by bare file name
os.chdir(path)

# Collect the names of all PNG files in the directory
with os.scandir(path) as files:
    flowers = [entry.name for entry in files if entry.name.endswith('.png')]

# Feature extractor: EfficientNetB7 without its classification head, with
# average pooling, initialised from ImageNet weights
model = EfficientNetB7(include_top=False, pooling='avg', weights='imagenet')

# Data generator; every augmentation option below is currently disabled
datagen = ImageDataGenerator(
    # rotation_range=2,
    # width_shift_range=0.1,
    # height_shift_range=0.1,
    # horizontal_flip=True,
    # brightness_range=[0.7, 1.3],
)

def load_images(file):
    """Load ``file`` resized to 600x600 and return it as a one-sample batch.

    The result has shape (1, 600, 600, 3): (samples, height, width, channels).
    """
    pixels = np.array(load_img(file, target_size=(600, 600)))
    return pixels.reshape(1, 600, 600 ,3)

# Load all images into memory. Keep a parallel list of the files that actually
# loaded so that features are later stored under the right file name even when
# some images fail to load.
image_list = []
loaded_flowers = []
for flower in flowers:
    try:
        img = load_images(flower)
    except Exception as err:  # best effort: report the file and keep going
        print("Error loading image: ", flower, err)
        continue
    image_list.append(img)
    loaded_flowers.append(flower)

if not image_list:
    raise ValueError(f"No images could be loaded from {path}")

images = np.vstack(image_list)

# Prepare iterator; shuffle=False keeps batches in the same order as `images`
iterator = datagen.flow(images, batch_size=1, shuffle=False)

# Extract features for each image.
# The generator yields batches indefinitely, so iteration is bounded by the
# list of loaded file names (zip stops when that list is exhausted).
# `preprocess_input` is the EfficientNet variant: that import comes after the
# VGG16 one and therefore overrides it.
data = {}
for flower, batch in zip(loaded_flowers, iterator):
    img = batch[0]  # we have batch_size=1, so there's only one image in the batch
    img = preprocess_input(img)
    # `use_multiprocessing` is not passed: it only concerned generator inputs
    # and is no longer accepted by newer Keras versions.
    features = model.predict(np.array([img]))
    data[flower] = features


# get a list of the filenames
filenames = np.array(list(data.keys()))

# get a list of just the features
feat = np.array(list(data.values()))

# flatten to one feature row per image; the row length is inferred from the
# model output instead of being hard-coded
feat = feat.reshape(len(filenames), -1)

# reduce the amount of dimensions in the feature vector.
# PCA cannot keep more components than min(n_samples, n_features), so the
# requested 12 components are clamped for small image sets.
n_components = min(12, *feat.shape)
pca = PCA(n_components=n_components, random_state=22)
pca.fit(feat)
x = pca.transform(feat)

# cluster feature vectors; never ask for more clusters than there are samples
n_clusters = min(6, len(feat))
hclust = AgglomerativeClustering(n_clusters=n_clusters)
hclust.fit(x)

# holds the cluster id and the images { id: [images] }
groups = {}
for file, cluster in zip(filenames, hclust.labels_):
    groups.setdefault(cluster, []).append(file)

# function that lets you view a cluster (based on identifier)
def view_cluster(cluster):
    """Plot the images of one cluster on a 10x10 subplot grid.

    At most 30 images are shown; larger clusters are clipped and a message is
    printed. Images are opened by the file names stored in ``groups``, so the
    current working directory must be the directory that contains them.

    Args:
        cluster: Key into the module-level ``groups`` dict.

    Raises:
        KeyError: If ``cluster`` is not a key of ``groups``.
    """
    max_images = 30
    plt.figure(figsize=(25, 25))
    # gets the list of filenames for a cluster
    files = groups[cluster]
    # only allow up to max_images images to be shown at a time
    if len(files) > max_images:
        print(f"Clipping cluster size from {len(files)} to {max_images}")
        files = files[:max_images]
    # plot each image in the cluster
    for index, file in enumerate(files):
        plt.subplot(10, 10, index + 1)
        img = np.array(load_img(file))
        plt.imshow(img)
        plt.axis('off')

# Render every cluster in turn, one figure per cluster
for cluster_id in groups:
    view_cluster(cluster_id)
    plt.show()
