## Import Libarries

In [137]:
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
import time

from ultralytics import YOLO
import torch
import torchvision.transforms as transforms
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator


## path and model declaration

In [2]:
video_path  = os.path.join(os.getcwd(), 'src','classroom.mp4')

## checkpoints for sam
sam_checkpoints = "checkpoints"
vit_h = "sam_vit_h_4b8939.pth"
vit_b = "sam_vit_b_01ec64.pth"
vit_l = "sam_vit_l_0b3195.pth"

## check for device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [58]:
## yolo model
yolo_model = YOLO('yolov8x.pt').to(device)

## sam model
model_type = "vit_l"
sam = sam_model_registry[model_type](checkpoint=os.path.join(sam_checkpoints, vit_l))
sam = sam.to(device)
predictor = SamPredictor(sam)

In [140]:
def augment_image(img):
    augmented_images = [
        img,
        ## apply gaussian blur
        transforms.GaussianBlur(kernel_size=3)(img),
        ## change the hue of the image
        transforms.ColorJitter(hue=0.5)(img),
        ## change the saturation of the image
        transforms.ColorJitter(saturation=0.5)(img),
        ## change, make the image blur, then apply gaussian blur
        transforms.GaussianBlur(kernel_size=3)(transforms.ColorJitter(brightness=0.5)(img)),
    ]
    
    #img_batch = torch.stack(augmented_images)
    
    augmented_images = [np.array(img.permute(2,1,0)) for img in augmented_images]
    return augmented_images
    

In [141]:
image_path = os.path.join(os.getcwd(), 'src','test2.jpg')
img = cv2.imread(image_path)


## resize image to 416x416
#img = cv2.resize(img, (1024, 2336), interpolation=cv2.INTER_LINEAR)
## convert image to tensor
img = torch.from_numpy(img).permute(2, 1, 0).float()

image_list = augment_image(img)

In [142]:
resutls = yolo_model(image_list, conf=0.10, classes=[0])


0: 640x448 2 persons, 1: 640x448 3 persons, 2: 640x448 2 persons, 3: 640x448 (no detections), 4: 640x448 (no detections), 166.0ms
Speed: 265.0ms preprocess, 33.2ms inference, 1.6ms postprocess per image at shape (1, 3, 640, 448)


In [38]:
def process_frame(frame):
    
    results = yolo_model(frame, conf=0.25, classes=[0])
    
    ## Process results
    for result in results:
        boxes = result.boxes
        
    bbox = boxes.xyxy
    print('bbox shape: ', bbox.shape)
    #confidences = boxes.conf
    #classes = boxes.cls 
    #predictor = SamPredictor(sam)
    predictor.set_image(frame)
    
    input_boxes = bbox.to(predictor.device)
    transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, frame.shape[:2])
    
    masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
    )
    
    return masks

In [39]:
def optimized_mask2img(mask):
    palette = {
        0: (0, 0, 0),
        1: (255, 0, 0),
        2: (0, 255, 0),
        3: (0, 0, 255),
        4: (0, 255, 255),
    }
    items = mask.shape[0]
    rows = mask.shape[1]
    cols = mask.shape[2]
    image = np.zeros((items, rows, cols, 3), dtype=np.uint8)
    image[:, :, :, 0] = mask * palette[1][0]
    image[:, :, :, 1] = mask * palette[1][1]
    image[:, :, :, 2] = mask * palette[1][2]
    return image

def optimized_show_mask(masks):
    masks = np.squeeze(masks, axis = 1)
    separate_rgb_masks = optimized_mask2img(masks)
    combined_mask = np.sum(separate_rgb_masks, axis = 0)
    return combined_mask


In [41]:
## Load YOLO
## yolo model
yolo_model = YOLO('yolov8x.pt').to(device)

## sam model
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=os.path.join(sam_checkpoints, vit_h))
sam = sam.to(device)
predictor = SamPredictor(sam)
    
    
    
## Load video
video_path  = os.path.join(os.getcwd(), 'src','classroom.mp4')
cap = cv2.VideoCapture(video_path)
#cap = cv2.VideoCapture(0)
    
if cap.isOpened() == False:
    print("Error in loading the video")
    
i = 0
    
# # Get the video properties
frame_width = int(cap.get(3))
frame_height = int(cap.get(4))
fps = int(cap.get(5))

# # Define the output video path
output_path_contours = os.path.join(os.getcwd(), 'output', 'classroom_c.mp4')
output_path_segmentation = os.path.join(os.getcwd(), 'output', 'classroom_s.mp4')

# # Create a VideoWriter object to save the processed frames
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out_c = cv2.VideoWriter(output_path_contours, fourcc, fps, (frame_width, frame_height))
out_s = cv2.VideoWriter(output_path_segmentation, fourcc, fps, (frame_width, frame_height))

while(cap.isOpened()):
    ret, frame = cap.read()
    try:
        # frame2 = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        #frame2 = cv2.cvtColor((cv2.GaussianBlur(frame2, (3, 3), 0)), cv2.COLOR_GRAY2RGB)
        masks = process_frame(frame, yolo_model, predictor)
            #dispaly frame and colour mask in same window
            
         
        if masks is not None:        
            frame = ((frame/np.max(frame))*255).astype(np.uint8)
            colour_mask = optimized_show_mask(masks.detach().cpu().numpy())
       
            colour_mask = cv2.addWeighted(colour_mask.astype(np.uint8), 0.3, frame, 0.7, 0, dtype=cv2.CV_8U)#colour_mask.astype(np.uint8))
                
            #-----------for contours -------
            masks = np.squeeze(masks.detach().cpu().numpy(), axis = 1).astype(np.uint8)
            #print('masks shape: ', masks.shape, masks.shape[0], np.unique(masks))
            for dim in range(masks.shape[0]):
                #print('in shape: ', masks[dim, :, :].shape)
                contours, hierarchy = cv2.findContours(image = masks[dim, :, :], mode = cv2.RETR_TREE, method = cv2.CHAIN_APPROX_NONE)
                cv2.drawContours(image = frame, contours=contours, contourIdx=-1, color=(0, 255, 0), thickness=1, lineType=cv2.LINE_AA)
        
            #cv2.imshow('frame', frame)
            #cv2.imshow('frame', colour_mask)
        
            # Write the combined frame to the output video
            out_c.write(frame)
            out_s.write(colour_mask)
        else:
            out_c.write(frame)
            out_s.write(frame)
    
        # if cv2.waitKey(25) & 0xFF == ord('q'):
                #break
    
        i = i + 1
    ## save frame and make video
    except:
        i = i + 1
        out_c.release()
        out_s.release()
        break
       

cap.release()
cv2.destroyAllWindows()


0: 384x640 8 persons, 69.0ms
Speed: 7.0ms preprocess, 69.0ms inference, 12.0ms postprocess per image at shape (1, 3, 384, 640)


bbox shape:  torch.Size([8, 4])



0: 384x640 7 persons, 6.0ms
Speed: 2.0ms preprocess, 6.0ms inference, 3.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (8, 480, 852) 8 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([7, 4])



0: 384x640 7 persons, 8.0ms
Speed: 2.0ms preprocess, 8.0ms inference, 2.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (7, 480, 852) 7 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([7, 4])



0: 384x640 8 persons, 6.0ms
Speed: 2.0ms preprocess, 6.0ms inference, 3.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (7, 480, 852) 7 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([8, 4])



0: 384x640 7 persons, 6.0ms
Speed: 1.0ms preprocess, 6.0ms inference, 3.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (8, 480, 852) 8 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([7, 4])



0: 384x640 7 persons, 6.0ms
Speed: 1.0ms preprocess, 6.0ms inference, 3.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (7, 480, 852) 7 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([7, 4])



0: 384x640 7 persons, 6.0ms
Speed: 1.0ms preprocess, 6.0ms inference, 4.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (7, 480, 852) 7 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([7, 4])



0: 384x640 7 persons, 6.0ms
Speed: 2.0ms preprocess, 6.0ms inference, 3.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (7, 480, 852) 7 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([7, 4])



0: 384x640 7 persons, 6.0ms
Speed: 2.0ms preprocess, 6.0ms inference, 3.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (7, 480, 852) 7 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([7, 4])



0: 384x640 8 persons, 6.0ms
Speed: 2.0ms preprocess, 6.0ms inference, 3.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (7, 480, 852) 7 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([8, 4])



0: 384x640 8 persons, 6.0ms
Speed: 2.0ms preprocess, 6.0ms inference, 3.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (8, 480, 852) 8 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([8, 4])



0: 384x640 7 persons, 6.0ms
Speed: 2.0ms preprocess, 6.0ms inference, 1.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (8, 480, 852) 8 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([7, 4])



0: 384x640 8 persons, 6.0ms
Speed: 2.0ms preprocess, 6.0ms inference, 2.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (7, 480, 852) 7 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([8, 4])



0: 384x640 8 persons, 6.0ms
Speed: 2.0ms preprocess, 6.0ms inference, 2.0ms postprocess per image at shape (1, 3, 384, 640)


masks shape:  (8, 480, 852) 8 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
bbox shape:  torch.Size([8, 4])
masks shape:  (8, 480, 852) 8 [0 1]
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
in shape:  (480, 852)
