In [1]:
import cv2
import numpy as np
import torch
from torch import nn
from models import LinkNet34
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image, ImageFilter
import time
import sys


In [2]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = LinkNet34()
# The identity map_location hands back each storage as it was deserialized,
# so the checkpoint can be read even when it was saved from another device;
# the model is moved to `device` afterwards.
model.load_state_dict(
    torch.load(
        'linknet.pth',
        map_location=lambda storage, loc: storage,
    )
)
# Switch to inference mode, then move the weights to the selected device.
model.eval().to(device)
1

1

In [3]:

class CaptureFrames():
    """Run a segmentation model on a video stream and preview the result.

    Frames are read with OpenCV, passed through ``model``, and the pixels
    whose prediction is not above 0.8 are faded towards white. The composited
    frame is shown in a window named ``'mask'`` when ``show_mask`` is true.

    Args:
        model: Callable network; it is switched to eval mode before use and
            called with a ``(1, 3, 256, 256)`` float tensor on ``self.device``.
            NOTE(review): assumed to already live on that device and to return
            a single-channel score map -- confirm against the model definition.
        source: Default stream passed to ``cv2.VideoCapture`` when the
            instance is called without an argument.
        show_mask: Whether to display the composited frames.
    """

    def __init__(self, model, source, show_mask=False):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.source = source
        self.show_mask = show_mask

    def __call__(self, source=None):
        """Process ``source``; fall back to the source given at construction.

        ``None`` (not falsiness) selects the fallback, so ``0`` is still
        passed through unchanged.
        """
        self.capture_frames(self.source if source is None else source)

    def capture_frames(self, source):
        """Read frames from ``source`` until the stream ends or a key is pressed.

        Every 30 processed frames the measured throughput is written to
        stdout. The capture device and any OpenCV windows are always released
        on exit, including when processing a frame raises.

        Args:
            source: Anything accepted by ``cv2.VideoCapture``.
        """
        img_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.frames_count = 0
        camera = cv2.VideoCapture(source)
        try:
            # Bail out before sleeping or reading when the stream never opened.
            if not camera.isOpened():
                print("Unable to read video")
                return

            time.sleep(2)  # presumably lets a webcam settle -- TODO confirm
            self.model.eval()
            # Probe read: its frame is not processed, it only tells us whether
            # the stream delivers frames at all.
            grabbed, _ = camera.read()
            camera.set(cv2.CAP_PROP_FPS, 25.0)
            time.sleep(2)

            time_1 = time.time()
            while grabbed:
                grabbed, orig = camera.read()
                if not grabbed:
                    break

                mask = self._predict_mask(orig, img_transform)
                rgb = self._fade_background(orig, mask)

                if self.show_mask:
                    cv2.imshow('mask', rgb)
                # waitKey also services the GUI event loop so the window repaints.
                k = cv2.waitKey(1)

                if self.frames_count % 30 == 29:
                    time_2 = time.time()
                    sys.stdout.write(f'\rFPS: {30/(time_2-time_1)}')
                    sys.stdout.flush()
                    time_1 = time.time()

                # Any key press stops the loop; cleanup happens once in `finally`.
                if k != -1:
                    break
                self.frames_count += 1
        finally:
            self.terminate(camera)

    def _predict_mask(self, orig, img_transform):
        """Return a boolean array, at the frame's resolution, of scores above 0.8."""
        height, width = orig.shape[0:2]
        frame = cv2.cvtColor(orig, cv2.COLOR_BGR2RGB)
        # `interpolation` must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, not the interpolation flag.
        frame = cv2.resize(frame, (256, 256), interpolation=cv2.INTER_LINEAR)

        batch = img_transform(Image.fromarray(frame)).unsqueeze(0)
        batch = batch.to(dtype=torch.float, device=self.device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            pred = self.model(batch)
            pred = torch.nn.functional.interpolate(pred, size=[height, width])
        return pred.cpu().numpy().squeeze() > 0.8

    def _fade_background(self, orig, mask):
        """Composite ``orig`` over white, lowering alpha by 180 where ``mask`` is false."""
        rgba = cv2.cvtColor(orig, cv2.COLOR_BGR2BGRA)
        ind = np.where(mask == 0)
        rgba[ind] = rgba[ind] - [0, 0, 0, 180]

        canvas = Image.new('RGBA', (rgba.shape[1], rgba.shape[0]), (255, 255, 255, 255))
        layer = Image.fromarray(rgba)
        # The layer doubles as its own paste mask (its alpha channel is used).
        canvas.paste(layer, mask=layer)
        return cv2.cvtColor(np.array(canvas), cv2.COLOR_BGRA2BGR)

    def terminate(self, camera):
        """Close all OpenCV windows and release ``camera``."""
        cv2.destroyAllWindows()
        camera.release()




In [4]:
# Set path = 0 for the default webcam, or to the path of a video file.
path = 0
# Pass `path` (not a literal 0) so the source stored on the object always
# matches the stream that is actually processed; True enables the preview.
c = CaptureFrames(model, path, True)
c(path)