In [18]:
import cv2
import torch
import torch.nn.functional as F
from torchvision import transforms
import os

# Fetch DeepLabV3 (ResNet-101 backbone) with pretrained weights through torch.hub
model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
# Class index kept by the mask: only people are segmented (silhouette extraction)
people_class = 15

# Put the network in inference mode
model.eval()
print("Model has been loaded.")

# 3x3 smoothing kernel whose weights sum to 1, shaped (1, 1, 3, 3) for F.conv2d
blur = torch.tensor(
    [[[[1.0, 2.0, 1.0],
       [2.0, 4.0, 2.0],
       [1.0, 2.0, 1.0]]]],
    dtype=torch.float32,
) / 16.0

# Run on the GPU when one is available, for better performance
if torch.cuda.is_available():
    model = model.cuda()
    blur = blur.cuda()

# Per-channel normalization applied to every input frame
preprocess = transforms.Compose(
    [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

# Function to create segmentation mask
def makeSegMask(img, threshold=127):
    """Build a binary person mask for a single image.

    The per-pixel class prediction selects the pixels labelled
    `people_class`; that hard mask is smoothed with three passes of the
    3x3 `blur` kernel, weighted by a term computed from output channel 0,
    and finally binarized.

    Relies on the module-level `model`, `blur`, `preprocess` and
    `people_class` objects defined earlier in the notebook.

    Args:
        img: Image as an (H, W, 3) array-like with values in 0-255.
            NOTE(review): the torchvision pretrained weights are documented
            for RGB input while OpenCV frames are BGR -- confirm which
            channel order callers pass in.
        threshold: Cut-off on the 0-255 soft mask. Values strictly greater
            than it become 255, all others 0. Defaults to 127.

    Returns:
        uint8 NumPy array of shape (H, W, 3) containing only 0 and 255.

    Raises:
        ValueError: If `img` is None or is not shaped (H, W, 3).
    """
    if img is None:
        raise ValueError("img is None (did the frame/image fail to load?)")

    # Scale input frame to [0, 1]
    frame_data = torch.as_tensor(img, dtype=torch.float32) / 255.0
    if frame_data.dim() != 3 or frame_data.shape[2] != 3:
        raise ValueError(
            "expected an image of shape (H, W, 3), got %s" % (tuple(frame_data.shape),)
        )

    # Channels-first, normalized, plus a leading batch dimension of 1
    input_batch = preprocess(frame_data.permute(2, 0, 1)).unsqueeze(0)

    # Use GPU if supported, for better performance
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')

    with torch.no_grad():
        output = model(input_batch)['out'][0]  # (num_classes, H, W)

    # Most likely class per pixel, shape (H, W)
    segmentation = output.argmax(0)

    # Channel 0 of the output, kept as (1, H, W); the original code treats it
    # as the background score and turns it into a soft weight in [0, 2]
    bgOut = output[0:1]
    a = (1.0 - F.relu(torch.tanh(bgOut * 0.30 - 1.0))).pow(0.5) * 2.0

    # Hard person mask as a (1, 1, H, W) float tensor, the layout F.conv2d expects
    people = (segmentation == people_class).float()[None, None]

    # Three blur passes soften the mask edges; padding=1 keeps H and W unchanged
    for _ in range(3):
        people = F.conv2d(people, blur, stride=1, padding=1)

    # Combine both masks and clamp to [0, 1]. Explicit indexing (rather than a
    # bare squeeze()) keeps the layout correct even when H or W equals 1.
    combined_mask = F.relu(F.hardtanh(a * people[0, 0].pow(1.5)))  # (1, H, W)
    combined_mask = combined_mask.expand(3, -1, -1)  # replicate to 3 channels

    res = (combined_mask * 255.0).cpu().byte().permute(1, 2, 0).contiguous().numpy()
    im_bw = cv2.threshold(res, threshold, 255, cv2.THRESH_BINARY)[1]

    return im_bw

Using cache found in C:\Users\h4rip/.cache\torch\hub\pytorch_vision_v0.6.0


Model has been loaded.


In [25]:
import cv2

import os

# Save every FRAME_STEP-th frame of the sample video into the dataset folder
FRAME_STEP = 4

frames = './test_dataset/frames/'
masks = './test_dataset/masks/'

# cv2.imwrite does not create missing directories; it only returns False,
# so make sure the output folders exist before writing
os.makedirs(frames, exist_ok=True)
os.makedirs(masks, exist_ok=True)

vidcap = cv2.VideoCapture('./samples.mp4')
count = 0
try:
    # Fail loudly instead of silently producing no frames
    if not vidcap.isOpened():
        raise IOError("Could not open video file './samples.mp4'")

    success, image = vidcap.read()
    while success:
        if count % FRAME_STEP == 0:
            # save frame as JPEG file; imwrite reports failure via its return value
            if cv2.imwrite(frames + "%d.jpg" % count, image):
                # cv2.imwrite(masks+str(count)+'.png', makeSegMask(cv2.resize((image), (256,256))))
                print("Done : %s" % count)
            else:
                print("Failed to write frame : %s" % count)
        success, image = vidcap.read()
        count += 1
finally:
    # Release the capture handle even if reading or writing raised
    vidcap.release()

Done : 0
Done : 4
Done : 8
Done : 12
Done : 16
Done : 20
Done : 24
Done : 28
Done : 32
Done : 36
Done : 40
Done : 44
Done : 48
Done : 52
Done : 56
Done : 60
Done : 64
Done : 68
Done : 72
Done : 76
Done : 80
Done : 84
Done : 88
Done : 92
Done : 96
Done : 100
Done : 104
Done : 108
Done : 112
Done : 116
Done : 120
Done : 124
Done : 128
Done : 132
Done : 136
Done : 140
Done : 144
Done : 148
Done : 152
Done : 156
Done : 160
Done : 164
Done : 168
Done : 172
Done : 176
Done : 180
Done : 184
Done : 188
Done : 192
Done : 196
Done : 200
Done : 204
Done : 208
Done : 212
Done : 216
Done : 220
Done : 224
Done : 228
Done : 232
Done : 236
Done : 240
Done : 244
Done : 248
Done : 252
Done : 256
Done : 260
Done : 264
Done : 268
Done : 272
Done : 276
Done : 280
Done : 284
Done : 288
Done : 292
Done : 296
Done : 300
Done : 304
Done : 308
Done : 312
Done : 316
Done : 320
Done : 324
Done : 328
Done : 332
Done : 336
Done : 340
