In [1]:
from utils.load_dataset import *
from utils.custom_utils import *

In [8]:
def load_model(path, device=None):
    """Build a Faster R-CNN (ResNet-50 FPN) detector and load fine-tuned weights.

    Args:
        path: checkpoint file produced by ``torch.save``; it must be a dict whose
            ``'model'`` entry is the model ``state_dict``.
        device: optional ``torch.device`` or device string. When omitted, uses
            ``cuda:1`` if at least two GPUs are visible, ``cuda:0`` if only one is,
            and the CPU otherwise.

    Returns:
        ``(model, device)``: the model in eval mode on ``device``, and that device.
    """
    if device is None:
        if torch.cuda.is_available():
            # cuda:1 was the original target; fall back to cuda:0 on single-GPU
            # machines instead of failing with an invalid device ordinal.
            device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
        else:
            device = torch.device('cpu')
    else:
        device = torch.device(device)

    # Architecture only: no pre-trained weights are requested because every
    # parameter is overwritten by the (strict) load_state_dict call below.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)

    # 3 classes = background (0) + the two object labels drawn by
    # visualize_prediction (1 = "no fallen", 2 = "fallen").
    num_classes = 3

    # Replace the box-predictor head so its output size matches num_classes; the
    # checkpoint's head weights presumably have this shape, so this must happen
    # before load_state_dict.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model'])

    model.eval()
    model.to(device)
    return model, device

In [9]:
from torchvision import transforms

def visualize_prediction(img, model, thr=0.7):
    """Run the detector on one image and return a copy with kept detections drawn.

    Args:
        img: image accepted by ``transforms.ToTensor`` (a PIL image or an
            ``H x W x C`` uint8 numpy array).
        model: detection model in eval mode; called as ``model([tensor])`` and
            expected to return a list of dicts with ``'boxes'``, ``'labels'`` and
            ``'scores'`` (the format ``take_prediction`` consumes).
        thr: detections with score strictly greater than ``thr`` are drawn.

    Returns:
        numpy array of the image with every kept detection drawn. Label 1 gets a
        green box ("no fallen"); any other non-zero label gets a blue box
        ("fallen"). Label 0 is take_prediction's "nothing found" placeholder and
        draws nothing.
    """
    img_tensor = transforms.ToTensor()(img)

    # Run on whatever device the model lives on rather than a hard-coded GPU, so
    # a CPU-loaded model works too.
    device = next(model.parameters()).device
    with torch.no_grad():
        prediction = model([img_tensor.to(device)])

    # Build the canvas once, outside the loop, so all boxes accumulate on the same
    # image. Round before the uint8 cast so the round trip through ToTensor's /255
    # does not truncate pixel values by one.
    im = Image.fromarray(img_tensor.mul(255).round().permute(1, 2, 0).byte().numpy())
    draw = ImageDraw.Draw(im)

    for bb, label, score in take_prediction(prediction[0], thr):
        if label == 0:
            continue
        elif label == 1:
            color = "green"
            text = f"no fallen: {score:.3f}"
        else:
            color = "blue"
            text = f"fallen: {score:.3f}"
        x0, y0, x1, y1 = bb
        draw.rectangle(((x0, y0), (x1, y1)), outline=color, width=3)
        draw.text((x0, y0), text, fill=(0, 0, 0, 0))

    return np.array(im)
    
def take_prediction(prediction, threshold):
    """Filter one detector output down to detections scoring above ``threshold``.

    Args:
        prediction: mapping with ``'boxes'``, ``'labels'`` and ``'scores'`` entries,
            each supporting ``.tolist()`` (e.g. tensors).
        threshold: detections are kept only if their score is strictly greater.

    Returns:
        List of ``(box, label, score)`` tuples in the original order. When nothing
        survives (or there were no detections at all), returns the single
        placeholder ``[([0, 0, 0, 0], 0, 0.)]`` so callers always get one entry.
    """
    detections = zip(
        prediction['boxes'].tolist(),
        prediction['labels'].tolist(),
        prediction['scores'].tolist(),
    )
    kept = [(box, label, score) for box, label, score in detections if score > threshold]
    # An empty list is falsy, so this also covers the "no detections" case.
    return kept or [([0, 0, 0, 0], 0, 0.)]

In [10]:
# Checkpoint path; load_model expects a dict holding the weights under 'model'.
model_path = "./models/checkpoint_train_real_over_virtual_with_lr_2.pth"
model_p, _ = load_model(path=model_path)

In [5]:
def play_video_p(path, out, model, thr=0.9, fps=None):
    """Run the detector on every frame of a video and write an annotated copy.

    Args:
        path: input video path (anything ``cv2.VideoCapture`` can open).
        out: output video path, encoded with the ``mp4v`` FourCC.
        model: detection model passed to ``visualize_prediction``.
        thr: score threshold forwarded to ``visualize_prediction``.
        fps: output frame rate. Defaults to the source's reported rate, or 30 when
            OpenCV cannot report one.

    Raises:
        IOError: if the input video cannot be opened.
    """
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Cannot open video: {path}")
    try:
        frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                      int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        if fps is None:
            # Reuse the source rate so the output plays at the original speed;
            # some containers report 0, hence the fallback.
            fps = cap.get(cv2.CAP_PROP_FPS) or 30
        # Lowercase 'mp4v' is the tag FFmpeg accepts in an .mp4 container; 'MP4V'
        # triggers a "tag not supported" warning and a silent fallback.
        writer = cv2.VideoWriter(out, cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    print("Can't receive frame (stream end?). Exiting ...")
                    break
                # OpenCV decodes frames as BGR. Convert to RGB for the model (assumes
                # it was trained on RGB images, as torchvision/PIL pipelines are --
                # TODO confirm against utils.load_dataset) and so PIL draws the
                # intended colors; convert back to BGR for the writer.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                annotated = visualize_prediction(rgb_frame, model, thr)
                writer.write(cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
        finally:
            writer.release()
    finally:
        cap.release()

In [12]:
# Annotate sample2.mp4 with the loaded detector and save the result.
play_video_p(path="sample2.mp4", out="output_sample.mp4", model=model_p)

Can't receive frame (stream end?). Exiting ...
