# Test the disaster identifying model on a drone video
#### Specify path of a test video and output video

In [33]:
from __future__ import print_function, division
import matplotlib.pyplot as plt
import numpy as np
import os
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import copy
import cv2
from collections import deque
cudnn.benchmark = True
plt.ion()   # interactive mode
from PIL import Image

from omnixai.data.image import Image
from omnixai.explainers.vision.specific.gradcam.pytorch.gradcam import GradCAM
import plotly.io as pio
pio.renderers.default = "png"
from IPython.display import display # to display images

In [29]:
# The preprocessing model
transform = transforms.Compose([
    transforms.Resize((224,224)),
    #transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
preprocess = lambda ims: torch.stack([transform(im.to_pil()) for im in ims])

In [30]:
model = torch.load('trained_model.pth')


In [31]:

test_video = '/kaggle/input/test-video/gettyimages-1207220141-640_adpp.mp4'
output_video = 'output.mp4'

In [32]:
CLASSES = {0:"Collapsed Building", 1:"Fire", 2:"Flood", 3:"Normal"}
BATCH_SIZE = 8
IMG_SIZE = (224, 224)
TRANSFORM_IMG = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(IMG_SIZE),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225] )
    ])


preprocess = lambda ims: torch.stack([TRANSFORM_IMG(im.to_pil()) for im in ims])
explainer = GradCAM(
    model=model,
    target_layer=model.layer4[-1],
    preprocess_function=preprocess
)



# model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

videoCapture = cv2.VideoCapture(test_video)
fps = videoCapture.get(cv2.CAP_PROP_FPS)
print(fps)
size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
fourcc = cv2.VideoWriter_fourcc(*'DIVX')
videoWriter = cv2.VideoWriter(output_video, fourcc, fps, size)
Q = deque(maxlen=int(fps))
t1 = time.time()
c=0
success, frame = videoCapture.read()
while success:
    c+=1
    frame_copy = copy.deepcopy(frame) 
    frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
    image_gc = Image(copy.deepcopy(frame_copy))


    # Explain the top label
    explanations = explainer.explain(image_gc)
    image_tensor = TRANSFORM_IMG(frame_copy)

    image_tensor = image_tensor.unsqueeze(0) 
    test_input = image_tensor.to(device)
    outputs = model(test_input)
    _, predicted = torch.max(outputs, 1)
    probability =  F.softmax(outputs, dim=1)
    top_probability, top_class = probability.topk(1, dim=1)
    predicted = predicted.cpu().detach().numpy()
    predicted = predicted.tolist()[0]
    Q.append(predicted)

    results = np.array(Q).mean(axis=0)
    #i = np.argmax(results)
    #print(Q, results, CLASSES[np.round(results)])
    label =CLASSES[np.round(results)]
    top_probability = top_probability.cpu().detach().numpy()
    top_probability = top_probability.tolist()[0][0]
    percentage = top_probability
    top_probability = '%.2f%%' % (top_probability * 100)
    #if confidence is low, set label as normal
    if percentage < 0.30:
        label="Normal"
    heatmap = explanations.explanations[0]['scores']
    heatmap = cv2.resize(heatmap, (frame.shape[1], frame.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    superimposed_img  = cv2.addWeighted(heatmap, 0.3, frame, 0.8, 0)
    #plt.imshow(frame+heatmap)
    #print(superimposed_img.shape)
    if label is "Normal":
        color = (0, 150, 0)
        frame = cv2.putText(frame, label+': '+top_probability, (50, 50), 2, 0.5 ,color, 1 )
        videoWriter.write(frame)
    else:
        color = (0, 0, 150)
        superimposed_img = cv2.putText(superimposed_img, label+': '+top_probability, (50, 50), 2, 0.5 ,color, 1 )
        videoWriter.write(superimposed_img)



    success, frame = videoCapture.read()
videoWriter.release()

t2=time.time()
print('done', t2-t1, c)






25.0


OpenCV: FFMPEG: tag 0x58564944/'DIVX' is not supported with codec id 12 and format 'mp4 / MP4 (MPEG-4 Part 14)'
OpenCV: FFMPEG: fallback to use tag 0x7634706d/'mp4v'


done 10.840048789978027 375
