In [1]:
# imports
import torch
import torch.nn.functional as F
import torchexplain # pip install -i https://test.pypi.org/simple/ torchexplain
import cv2
import matplotlib.pyplot as plt
import numpy as np
import c3d

In [2]:
# define model
mean=[90.0, 98.0, 102.0]
mdl = c3d.C3D(101,False,range=(-max(mean),255-min(mean)),embedding=False,train=False)
weights = torch.load("c3d.pth")
mdl.load_state_dict(weights)
mdl = mdl.cuda()

In [3]:
# get input
reader = cv2.VideoCapture("boxing.avi")
fourcc = cv2.VideoWriter_fourcc(*'DIVX')
writer = cv2.VideoWriter("explanation.avi",fourcc,16,(320,240))
frames = []
# get all frames from video
while True:
    ret, frame = reader.read()
    if ret:
        # the model takes input frames in channel-first order, expecting image size to be (112,112)
        frame = cv2.resize(frame, (112,112))
        frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
        frame = frame.transpose(2,0,1)
        frames.append(torch.from_numpy(frame))
    else:
        break
reader.release()

In [4]:
print(len(frames))
# the model takes input of 16 frames at a time, we need to pad the video to make sure it can be split into 16 frame clips
while (len(frames) % 16) != 0:
    frames.append(frames[-1])

frames_tmp = []
for fidx in range(0,len(frames),16):
    frames_tmp.append(torch.stack(frames[fidx:fidx+ 16]))

samples = [sample.permute(1,0,2,3) for sample in frames_tmp]

252


In [5]:
for fidx in range(len(samples)):
    samples[fidx] = samples[fidx].float()
    for c in range(3):
        samples[fidx][c,...] -= mean[c]
    samples[fidx].requires_grad_()
    samples[fidx] = samples[fidx].cuda().unsqueeze(0)

In [6]:
out = []
for sample in samples:
    out.append(mdl(sample).cpu())

In [7]:
grad = []
for sample, o in zip(samples,out):
    target = o.topk(1,1)[1] # this can be any class, for this example we will explain the models first choice
    t_out = torch.zeros_like(o)
    t_out[:,target] += 1 # this is used as a mask when backpropagating to only explain for this class
    # grad returns a gradient for each input to the graph, which is normally input, weights, bias
    grad.append(torch.autograd.grad(o,sample,t_out)[0]) # we just want the input gradient, which is the relevance of the input

In [8]:
# at this point we have the LRP output, next we visualise
cmap = plt.get_cmap("seismic") # this is the trademark colour map used by LRP, DeepTaylor, and Discriminative Relevance, but any (or none) is fine

# we define the sobel operator for dR here for efficiency
sob_t = torch.tensor([[[1, 2, 1], [2, 4, 2], [1, 2, 1]], [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
                          [[-1, -2, -1], [-2, -4, -2], [-1, -2, -1]]]) # you can find this at https://en.wikipedia.org/wiki/Sobel_operator, we also define it in the paper
sob_t = sob_t.reshape((1, 1, 3, 3, 3)).float().cuda() # (batch,nchannels,t,x,y)
sigma = 2 # how many standard deviations the derivative of the relevance needs to be above to be selected by dR
for sample, relevance in zip(samples,grad):
    relevance = relevance.sum(dim=(0,1)) # we don't care about batch-wise or colour-channel-wise relevance
    norm_rel = relevance / abs(relevance).max()
    norm_rel_deriv_in_t = F.conv3d(norm_rel[None][None], sob_t, padding=1)[0,0,...]
    # comment this out for normal lrp
    norm_rel = norm_rel * (norm_rel_deriv_in_t > (norm_rel_deriv_in_t.std() * sigma)).float() # this is the dR step, filter out any relevance that doesn't have a derivative in t, over sigma standard deviations
    # seismic takes numbers in [0,1] range, with 0 being blue and 1 red
    # post normalisation, minimum relevance would be -1, and maximum 1, therefore we scale the relevance to this
    # so that no relevance is scaled 0.5, which is white in seismic
    norm_rel = (norm_rel - -1) / (1 - -1)
    
    # from here on, we process the explanations so that they can be saved as images
    for fidx in range(16):
        # get input frame
        frame = sample[0,:,fidx,...]
        print(frame.shape)
        for c in range(3):
            frame[c,...] += mean[c]
        frame = frame.detach().cpu().numpy().astype(np.uint8)
        frame = frame.transpose(1,2,0)
        frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)

        expl_frame = norm_rel[fidx,...].cpu().numpy()
        expl_frame = cmap(expl_frame)
        expl_frame *= 255
        print(expl_frame.shape)
        expl_frame = expl_frame.astype(np.uint8)
        expl_frame = cv2.cvtColor(expl_frame,cv2.COLOR_RGBA2BGR) # seismic adds an alpha channel (RGBA)
        # overlay the explanation on the frame
        overlay = cv2.addWeighted(expl_frame,0.5,frame,0.5,0)
        overlay = cv2.resize(overlay,(320,240))
#         cv2.imshow("overlay",overlay)
#         cv2.waitKey(0)
#         cv2.destroyAllWindows()
        # write to video
        writer.write(overlay)
writer.release()

torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
torch.Size([3, 112, 112])
(112, 112, 4)
