In [1]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import backbones_unet
import cv2
from time import time

In [2]:
import torch
model = torch.load("../ckpts/convnext_base_ckpt5.pth")

In [3]:
def jaccard_index(predicted, target):
    intersection = (predicted * target).sum()
    union = (predicted + target).sum() - intersection
    return (intersection / (union + 1e-6)).item()

def preprocess(img):
    img = img.resize((224, 224)) 
    img = torch.Tensor(np.array(img, dtype=np.uint8).transpose((2, 0, 1)))
    img = img.float() / 255.0
    return img

def get_video_frame(in_path):
    cap = cv2.VideoCapture(in_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    print(fps)
    frames = []
    start_time = time()
    try:
        while True:
            # Read a frame from the video
            ret, frame = cap.read()
            if not ret:
                break  # Break the loop if there are no frames left
            
            # Optionally resize or preprocess the frame here if necessary
            frames.append(frame)
    finally:
        cap.release()  # Make sure to release the video capture object
    
    return frames, fps

frames, fps = get_video_frame("./videos/fast.mp4")

segment_every_x = 20

def batched_segmentation(model,frames,segment_every_x):
    chosen = frames[::segment_every_x]
    print(len(chosen))
    loss = []
    chosen_converted = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in chosen]
    chosen_preprocessed = [preprocess(Image.fromarray(rgb_frame)) for rgb_frame in chosen_converted]
    batch = torch.stack(chosen_preprocessed,dim=0)
    print(batch.shape)
    #input to model
    batch = batch.to('cuda')
    model = model.to('cuda')
    with torch.no_grad():
        output = model(batch)
    output = output.clamp(0,1)
    print(output.clamp(0,1))
    for i in range(output.shape[0]-1):
        loss.append(jaccard_index(output[i][0],output[i+1][0]))
    print(loss)
    return output

output = batched_segmentation(model,frames,segment_every_x)
tensor_shape = output[0].shape
extended_masks = output.unsqueeze(1).repeat(1, segment_every_x, 1, 1, 1).view(-1, *tensor_shape)
print(extended_masks.shape)

print(len(frames))
chosen_converted_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
chosen_preprocessed_frames = [preprocess(Image.fromarray(rgb_frame)) for rgb_frame in chosen_converted_frames]

30.0
9
torch.Size([9, 3, 224, 224])
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          

In [22]:
def single_pass(model,input):
    input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
    input_preprocessed = preprocess(Image.fromarray(input))
    input_preprocessed = input_preprocessed[None,...]
    input_preprocessed = input_preprocessed.to('cuda')
    model = model.to('cuda')
    with torch.no_grad():
        output = model(input_preprocessed)
    output = output.clamp(0,1)
    return output[0]

print(single_pass(model,frames[0]).shape)

def frame_scheduling(frames,masks,output,segment_every_x,mode='linear'):
    loss = []
    cutoff = 0.8
    upper_cutoff = 0.90
    for i in range(output.shape[0]-1):
        loss.append(jaccard_index(output[i][0],output[i+1][0]))
    print(masks.shape)
    print(len(loss))
    if mode == 'linear':
        for i in range(len(loss)): 
            if(loss[i] < cutoff):
                print("Linear Interpolating...")
                for j in range(segment_every_x):
                    alpha = j / (segment_every_x - 1)
                    print(alpha)
                    masks[i*segment_every_x + j] = torch.lerp(masks[i*segment_every_x],masks[i*segment_every_x + segment_every_x],alpha)
                    
    elif mode == 'trust region':
        print("reached!")
        interval = segment_every_x
        for i in range(len(loss)):
            print(output.shape)
            diff = jaccard_index(masks[i*segment_every_x][0],masks[i*segment_every_x+interval][0])
            print(f"Loss:{diff}")
            if(diff < cutoff) and interval > 2:
                interval = interval // 2
                print(f"Adding new interpolations,interval = {interval}")
            elif interval*2 < segment_every_x and diff > upper_cutoff : 
                interval = interval * 2
                print(f"Combining subintervals, interval = {interval}")
            for subinterval in range(0,segment_every_x,interval):
                index = i * segment_every_x + subinterval
                new_mask = single_pass(model,frames[index])
                masks[i*segment_every_x - interval + subinterval: i*segment_every_x + subinterval] = new_mask

    return masks

masks = frame_scheduling(frames,extended_masks,output,segment_every_x,mode='trust region')

print(len(masks))
    
    

torch.Size([1, 224, 224])
torch.Size([180, 1, 224, 224])
8
reached!
torch.Size([9, 1, 224, 224])
Loss:0.7825303077697754
Adding new interpolations,interval = 10
torch.Size([9, 1, 224, 224])
Loss:0.8113927245140076
torch.Size([9, 1, 224, 224])
Loss:0.5983411073684692
Adding new interpolations,interval = 5
torch.Size([9, 1, 224, 224])
Loss:0.7189475297927856
Adding new interpolations,interval = 2
torch.Size([9, 1, 224, 224])
Loss:0.8793796300888062
torch.Size([9, 1, 224, 224])
Loss:0.9252275228500366
Combining subintervals, interval = 4
torch.Size([9, 1, 224, 224])
Loss:0.7772756814956665
Adding new interpolations,interval = 2
torch.Size([9, 1, 224, 224])
Loss:0.9273473620414734
Combining subintervals, interval = 4
180


In [23]:
def getVideo(frames,out_path,masks):
    
    H, W, C = frames[0].shape
    segmented_frames = []
    cnt = 0
    #print(len(frames))
    val_every_x_frame = 5
    segmented_frames = []
    print(len(frames))
    print(len(masks))
    for i in range(len(frames)):
        segmented_frames.append(frames[i].to('cuda') * masks[i])
    fourcc = cv2.VideoWriter_fourcc(*'MP4V')  # or (*'XVID') depending on the desired output format
    out = cv2.VideoWriter(out_path, fourcc, 30.0, (224,224))
    assert(len(segmented_frames) == len(frames))
    for frame in segmented_frames:
        frame = frame.permute(1, 2, 0)
        frame = frame.cpu().numpy()
        if frame.dtype != np.uint8:
            frame = (frame * 255).astype(np.uint8)
        #plt.imshow(frame)
        #plt.show()
        #print(frame.shape)
        # Convert your processed frame back to BGR from RGB if necessary
        bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        out.write(bgr_frame)
    out.release()  # Release everything if job is finished
    end_time = time()
    
getVideo(chosen_preprocessed_frames,"./videos/fast_out_trust_region.mp4",extended_masks)

165
180
