In [None]:
import os
import numpy as np
import torch
import torch.utils.data
from PIL import Image
def txtToArray(path):
    """Parse a text file of comma-separated integers into an int64 array.

    The file is split on whitespace; every resulting token is split on ','
    and becomes one row of the result.

    Parameters
    ----------
    path : str
        Path of the text file to read.

    Returns
    -------
    numpy.ndarray
        int64 array with one row per whitespace-separated token.

    Raises
    ------
    OSError
        If the file cannot be opened.
    ValueError
        If a field is not an integer or the rows have different lengths.
    """
    # The context manager closes the handle even when parsing fails; the
    # previous version left the file open until garbage collection.
    with open(path) as f:
        rows = [token.split(',') for token in f.read().split()]
    return np.array(rows, dtype=np.int64)
class HandsDataset(torch.utils.data.Dataset):
    """Video frames with bounding boxes and instance masks stored under ``root``.

    A flat index is mapped to a 1-based video number ``v`` and frame number
    ``f`` (``FRAMES_PER_VIDEO`` frames per video). The files read are:

    * ``DATA_IMAGES/Image{v}_{f}.jpg``   - the frame itself
    * ``DATA_BOXES/Box{v}_{f}.txt``      - one row per box slot; columns 0-3 are
      read as ``xmin, ymax, width, height`` and ``width == 0`` marks an unused slot
    * ``DATA_MASKS/Mask{v}_{f}_{k}.jpg`` - mask belonging to box row ``k``

    ``__getitem__`` returns ``(img, target)`` where ``target`` is a dict with
    the keys ``boxes``, ``labels``, ``masks``, ``image_id``, ``area`` and
    ``iscrowd``.
    """

    FRAMES_PER_VIDEO = 100   # frames stored per video
    MAX_BOXES = 4            # number of box rows inspected per frame

    def __init__(self,root,transforms=None):
        """Remember ``root``/``transforms`` and list the three data folders.

        Raises ``OSError`` if one of the folders does not exist.
        """
        self.root = root
        self.transforms = transforms
        self.imgs = list(os.listdir(os.path.join(root,"DATA_IMAGES")))
        self.masks = list(os.listdir(os.path.join(root,"DATA_MASKS")))
        self.boxes = list(os.listdir(os.path.join(root,"DATA_BOXES")))

    def __getitem__(self,idx):
        """Return ``(img, target)`` for frame ``idx``.

        A frame without any bounding box cannot be used as a sample, so the
        next frame that has one is returned instead (wrapping around at the
        end of the dataset). ``target["image_id"]`` holds the index of the
        frame that was actually loaded.

        Raises
        ------
        IndexError
            If the dataset contains no images.
        RuntimeError
            If no frame in the whole dataset has a bounding box.
        """
        idx = int(idx)
        n_frames = len(self)
        if n_frames == 0:
            raise IndexError("HandsDataset is empty")
        # Bounded scan: wraps at the real dataset size (previously a
        # hard-coded 4800) and cannot recurse forever on a box-free dataset.
        for _ in range(n_frames):
            sample = self._load_frame(idx)
            if sample is not None:
                return sample
            idx = (idx + 1) % n_frames
        raise RuntimeError("no frame in the dataset has a bounding box")

    def _load_frame(self,idx):
        """Load frame ``idx``; return ``(img, target)`` or ``None`` if it has no boxes."""
        video = 1 + idx // self.FRAMES_PER_VIDEO     # which video
        frame = 1 + idx % self.FRAMES_PER_VIDEO      # which frame within that video
        stem = str(video) + "_" + str(frame)
        img_path = os.path.join(self.root, "DATA_IMAGES", "Image" + stem + ".jpg")
        box_path = os.path.join(self.root, "DATA_BOXES", "Box" + stem + ".txt")
        box_array = txtToArray(box_path)

        boxes = []
        masks = []
        for boxid in range(self.MAX_BOXES):
            if box_array[boxid,2] == 0:              # zero width: slot holds no box
                continue
            # NOTE(review): column 1 is used as ymax and the height is
            # subtracted to obtain ymin - confirm this matches the convention
            # of the annotation files.
            xmin = box_array[boxid,0]
            ymax = box_array[boxid,1]
            xmax = xmin + box_array[boxid,2]
            ymin = ymax - box_array[boxid,3]
            boxes.append([xmin,ymin,xmax,ymax])

            mask_name = "Mask" + stem + "_" + str(boxid) + ".jpg"
            mask_path = os.path.join(self.root, "DATA_MASKS", mask_name)
            with Image.open(mask_path) as mask_img:
                mask = np.array(mask_img.convert("L"))
            # Binarize in one vectorized step: every non-zero pixel becomes 1.
            masks.append((mask != 0).astype(np.uint8))

        num_objs = len(boxes)
        if num_objs == 0:
            return None

        img = Image.open(img_path).convert("RGB")
        boxes = torch.as_tensor(boxes,dtype=torch.float32)
        area = (boxes[:,3]-boxes[:,1])*(boxes[:,2]-boxes[:,0])
        labels = torch.ones((num_objs,),dtype=torch.int64)      # single foreground class
        # Stack first: converting a list of arrays element by element is slow.
        masks = torch.as_tensor(np.stack(masks),dtype=torch.uint8)
        image_id = torch.tensor([idx])
        iscrowd = torch.zeros((num_objs,),dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        """Number of files found in ``DATA_IMAGES``."""
        return len(self.imgs)

In [None]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
def get_model_instance_segmentation(num_classes):
    """Return a Mask R-CNN (ResNet-50 FPN) with fresh heads for ``num_classes`` classes.

    The network is created with ``pretrained=True``; its box predictor and
    mask predictor are then replaced by newly initialised ones sized for
    ``num_classes`` outputs.
    """
    net = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    heads = net.roi_heads

    # Box head: keep the input width of the existing classifier layer.
    box_inputs = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_inputs, num_classes)

    # Mask head: keep the input channel count, reduce to 256 hidden channels.
    mask_inputs = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_inputs, 256, num_classes)

    return net

In [None]:
import transforms as T
from engine import train_one_epoch, evaluate
import utils
def get_transform(train):
    """Compose the sample transforms: tensor conversion, plus a random
    horizontal flip (probability 0.5) when ``train`` is truthy."""
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

In [None]:
# Train on the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
num_classes = 2    # HandsDataset labels every object 1; class 0 is the background

# Both datasets read from the current working directory (root '').
# NOTE(review): the training set is built with train=False, so the random
# horizontal flip is never applied - confirm this is intentional.
dataset = HandsDataset('', get_transform(train=False))
dataset_test = HandsDataset('', get_transform(train=False))

indices = np.arange(2600)
dataset = torch.utils.data.Subset(dataset, indices[:2400])     # train on first 24 videos
dataset_test = torch.utils.data.Subset(dataset_test, indices)  # have all images available for testing (use last 200)

# NOTE(review): shuffle=False feeds consecutive frames of the same video into
# each training batch - confirm this is intentional.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=10, shuffle=False, num_workers=1,
    collate_fn=utils.collate_fn)
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=2, shuffle=False, num_workers=1,
    collate_fn=utils.collate_fn)

model = get_model_instance_segmentation(num_classes)
#model.load_state_dict(torch.load('model2400.pt'))   # uncomment to resume from a saved checkpoint
model.to(device)

# Optimise only the parameters that require gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

num_epochs = 1     # training for 1 epoch at 2400 training took over 4 hours
for epoch in range(num_epochs):   # previously range(1), which ignored num_epochs
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=1)
    #evaluate(model, data_loader_test, device=device)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def getMaskedIMG(img1,boxes1,masks1,idx1):
    """Overlay masks (red) and boxes (blue) on a frame and save it as a PNG.

    Parameters
    ----------
    img1 : array indexed (row, column, channel)
        The frame. It is multiplied by 256 before display, so values are
        presumably floats in [0, 1] - TODO confirm.
    boxes1 : sequence of [x1, y1, x2, y2]
        Candidate boxes; only the first half is drawn.
    masks1 : array of shape (N, rows, columns)
        One mask per detection; all masks are summed into a single overlay.
        N may be 0.
    idx1 : int
        Frame counter used in the output name ``video_test/img-XXX.png``.

    The figure is written to disk and closed; nothing is returned.
    """
    finalImg = np.array(np.asarray(img1)*256,dtype=np.int32)

    masks1 = np.asarray(masks1)
    if len(masks1) > 0:
        m = masks1.sum(axis=0)
    else:
        # No detections: use an all-zero overlay instead of failing on masks1[0].
        m = np.zeros(finalImg.shape[:2])

    # Add the combined mask to the red channel only. The overlay takes its
    # size from the inputs rather than assuming 720x1280.
    finalImg[...,0] += np.array(m*256,dtype=np.int32)
    # Clip explicitly to the displayable range; imshow would otherwise clip
    # the out-of-range integers itself and emit a warning.
    finalImg = np.clip(finalImg,0,255).astype(np.uint8)

    fig,ax = plt.subplots(1)
    ax.imshow(finalImg)
    # A model that isn't fully trained produces too many boxes, so only the
    # first half (the ones with the highest scores) is drawn.
    for box in boxes1[:len(boxes1)//2]:
        boxW = box[2] - box[0]
        boxH = box[3] - box[1]
        # NOTE(review): the rectangle is anchored at (x1, y2), so it spans
        # y2 .. y2+boxH rather than y1 .. y2. Kept as is because it may be
        # paired with the box convention used when the training targets were
        # built - confirm before changing the anchor to (x1, y1).
        rect = patches.Rectangle((box[0],box[3]),boxW,boxH,linewidth=0.5,edgecolor='b',facecolor='none')
        ax.add_patch(rect)

    os.makedirs("video_test",exist_ok=True)
    # Zero-pad to three digits so the files sort in frame order.
    outputStr = "video_test/img-{:03d}.png".format(idx1)
    #plt.show()               #if output desired in console
    fig.savefig(outputStr)    #saves to specified folder for video creation
    plt.close(fig)            #release the figure; this function is called once per frame

In [None]:
outputorder = 0    #used to specify ordering of images in video
model.eval()       #inference mode; setting it once before the loop is enough
for idx in indices[2400:2500]:             #frames 2400-2499, i.e. video 25 (the first one not used for training)
    img, target = dataset_test[idx]
    with torch.no_grad():
        prediction = model([img.to(device)])
    #reorder (channel, row, column) -> (row, column, channel) for plotting;
    #same result as the transpose + 3x rot90 + fliplr sequence used before
    img1 = img.numpy().transpose(1, 2, 0)
    boxes1 = prediction[0]['boxes'].cpu().numpy()
    masks1 = prediction[0]['masks'].cpu().numpy()
    #flatten to (N, rows, columns) using the prediction's own size instead of a hard-coded 720x1280
    masks1 = masks1.reshape((len(masks1),) + masks1.shape[-2:])
    idx1 = outputorder
    outputorder = outputorder + 1
    getMaskedIMG(img1,boxes1,masks1,idx1)       #this function combines image with masks and boxes,