**Applying the STCN method on the Davis2017 dataset**

We apply the STCN method (https://arxiv.org/pdf/2106.05210.pdf) to the training videos of the DAVIS 2017 dataset, using the authors' implementation (https://github.com/hkchengrex/STCN).

We then visualise the results by adapting the visualisation code from the STM GitHub repository: https://github.com/seoungwugoh/STM.


In [None]:
# Access to dataset through Drive.
# This cell only works inside Google Colab: it mounts the user's Google Drive
# and then moves the working directory into it.
import os
from google.colab import drive
# Mount Drive at /content/drive/. force_remount=True is passed so the call also
# works when re-running the cell (presumably it remounts an existing mount —
# confirm against the google.colab documentation).
drive.mount('/content/drive/', force_remount=True)
# Switch the working directory to the 'My Drive' folder under the mount point.
os.chdir('/content/drive/My Drive/')

Mounted at /content/drive/


In [None]:
%cd '/content/drive/MyDrive/STCN/STCN'

/content/drive/MyDrive/STCN/STCN


In [None]:
import os
import subprocess
import time
from argparse import ArgumentParser
from os import path

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from PIL import Image
from progressbar import progressbar
from torch.utils.data import DataLoader

from model.eval_network import STCN
from dataset.davis_test_dataset import DAVISTestDataset
from util.tensor_util import unpad
from inference_core import InferenceCore

In [None]:
# Arguments loading
#
# The options are declared through argparse so the cell mirrors a command-line
# script; parse_known_args() is used so that extra arguments injected by the
# notebook kernel are ignored (collected in `unknown`) instead of raising.


def str2bool(value):
    """Convert a command-line string into a real boolean.

    argparse's ``type=bool`` simply calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) would be parsed as ``True``. This
    parser understands the usual spellings instead.

    Args:
        value: the raw argument (a string, or already a bool).

    Returns:
        The parsed boolean. The empty string maps to ``False``, as it did with
        ``type=bool``.

    Raises:
        ValueError: if the value is not a recognised boolean spelling; argparse
            turns this into a regular usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ValueError('expected a boolean value, got {!r}'.format(value))


parser = ArgumentParser()
parser.add_argument('--model', default='/content/drive/MyDrive/STCN/STCN/saves/stcn.pth')
parser.add_argument('--davis_path', default='/content/drive/MyDrive/STCN/DAVIS/2017')
parser.add_argument('--output', default='/content/drive/MyDrive/STCN/experiment/Davis2017/train')
parser.add_argument('--split', help='val/testdev', default='val')
parser.add_argument('--top', type=int, default=20)
parser.add_argument('--amp', action='store_true')
parser.add_argument('--mem_every', default=5, type=int)
parser.add_argument('--include_last', help='include last frame as temporary memory?', action='store_true')
parser.add_argument('--visualisation', default=True, type=str2bool)  # Save the visualisation

args, unknown = parser.parse_known_args()

davis_path = args.davis_path
out_path = args.output
VIZ = args.visualisation

# Simple setup
os.makedirs(out_path, exist_ok=True)
# Borrow the colour palette from one reference annotation so that the saved
# masks use the same indexed colours. The file is opened in a `with` block so
# its handle is released as soon as the palette has been read.
with Image.open(path.expanduser(davis_path + '/trainval/Annotations/480p/blackswan/00000.png')) as reference_mask:
    palette = reference_mask.getpalette()

In [None]:
# Inference only: switch autograd off globally for the rest of the notebook.
torch.autograd.set_grad_enabled(False)

# Setup Dataset
# Each split maps to a dataset root and an image-set file. Note that the 'val'
# split is wired to the 2017 *train* image set here (imset: train or val).
if args.split == 'val':
    dataset_root, image_set = davis_path + '/trainval', '2017/train.txt'
elif args.split == 'testdev':
    dataset_root, image_set = davis_path + '/test-dev', '2017/test-dev.txt'
else:
    raise NotImplementedError(
        "Unsupported --split {!r}; expected 'val' or 'testdev'".format(args.split))

test_dataset = DAVISTestDataset(dataset_root, imset=image_set)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)


# The model and every tensor below are moved with .cuda(), so a GPU is
# mandatory. Fail here with an explicit message rather than announcing a CPU
# run and crashing on the first .cuda() call.
use_cuda = torch.cuda.is_available()
if use_cuda:
    print('Using GPU')
else:
    raise RuntimeError('No CUDA device available: this notebook requires a GPU runtime.')


# Load our checkpoint
top_k = args.top
prop_model = STCN().cuda().eval()

# Performs input mapping such that stage 0 model can be loaded:
# if the first value-encoder convolution was saved with 4 input channels,
# append one zero-initialised input channel so it has 5.
prop_saved = torch.load(args.model)
conv1_key = 'value_encoder.conv1.weight'
conv1_weight = prop_saved.get(conv1_key)
if conv1_weight is not None and conv1_weight.shape[1] == 4:
    # Derive the padding shape from the stored weight instead of hard-coding it.
    pads = torch.zeros(
        (conv1_weight.shape[0], 1) + tuple(conv1_weight.shape[2:]),
        dtype=conv1_weight.dtype, device=conv1_weight.device)
    prop_saved[conv1_key] = torch.cat([conv1_weight, pads], 1)
# Left as the last expression so the cell displays the key-matching report.
prop_model.load_state_dict(prop_saved)

Using GPU


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

<All keys matched successfully>

In [None]:
# Root folder for the visualisation frames (one sub-folder per video) and the
# encoded videos.
VIZ_ROOT = '/content/drive/MyDrive/STCN/experiment/Davis2017/viz-train/'


def _save_masks(masks, mask_dir, mask_palette):
    """Save every mask in `masks` as a palettised PNG named '<5-digit index>.png'.

    `masks` is indexed by frame along its first axis; each entry is handed to
    PIL as-is.
    """
    os.makedirs(mask_dir, exist_ok=True)
    for f in range(masks.shape[0]):
        img_E = Image.fromarray(masks[f])
        img_E.putpalette(mask_palette)
        img_E.save(os.path.join(mask_dir, '{:05d}.png'.format(f)))


def _save_overlays(frames, masks, frame_dir, mean, std, mask_palette, overlay_fn):
    """Save one overlay JPEG ('f<index>.jpg') per frame into `frame_dir`.

    Each frame is de-normalised with `std`/`mean`, converted to uint8 and
    blended with its mask by `overlay_fn`.
    Assumes `frames[f]` is channel-first (it is permuted with (1, 2, 0) before
    being handed to the overlay function) — TODO confirm against the dataset.
    """
    os.makedirs(frame_dir, exist_ok=True)
    for f in range(masks.shape[0]):
        im = frames[f] * std + mean
        pixels = im.permute(1, 2, 0).cpu().numpy() * 255.
        # Round and clip before the uint8 cast: a plain cast truncates, so a
        # value that comes back as e.g. 199.99998 would be stored as 199.
        pF = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
        canvas = Image.fromarray(overlay_fn(pF, masks[f], mask_palette))
        canvas.save(os.path.join(frame_dir, 'f{}.jpg'.format(f)))


def _encode_video(frame_pattern, video_path, framerate=10):
    """Encode the numbered frames matching `frame_pattern` into `video_path`.

    Best effort: returns True on success and prints a message (without raising)
    when ffmpeg is missing or exits with an error.
    """
    # ffmpeg applies an option to the next file on the command line, so the
    # encoder settings must come BEFORE the output path to take effect.
    command = [
        'ffmpeg', '-y', '-nostats', '-loglevel', '0',
        '-framerate', str(framerate), '-i', frame_pattern,
        '-vcodec', 'libx264', '-crf', '10', '-pix_fmt', 'yuv420p',
        video_path,
    ]
    try:
        # Argument list, no shell: paths need no quoting or escaping.
        completed = subprocess.run(command)
    except OSError as error:
        print('Could not run ffmpeg ({}); skipping {}'.format(error, video_path))
        return False
    if completed.returncode != 0:
        print('ffmpeg exited with status {} for {}'.format(completed.returncode, video_path))
        return False
    return True


total_process_time = 0
total_frames = 0

# Per-channel normalisation constants, shaped (3, 1, 1) so they broadcast over
# a channel-first image; used to undo the input normalisation for display.
mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None].cuda()
std = torch.tensor([0.229, 0.224, 0.225])[:, None, None].cuda()

# Visualisation, adapted from the github of STM.
# Import the overlay helper once, and only when it is actually needed.
if VIZ:
    from helpers import overlay_davis


# Start evaluation
for data in progressbar(test_loader, max_value=len(test_loader), redirect_stdout=True):

    with torch.cuda.amp.autocast(enabled=args.amp):
        rgb = data['rgb'].cuda()
        msk = data['gt'][0].cuda()
        info = data['info']
        name = info['name'][0]
        k = len(info['labels'][0])
        size = info['size_480p']

        # Synchronise around the timed region so pending GPU work is not
        # attributed to the wrong video.
        torch.cuda.synchronize()
        process_begin = time.time()

        processor = InferenceCore(prop_model, rgb, k, top_k=top_k,
                        mem_every=args.mem_every, include_last=args.include_last)
        processor.interact(msk[:,0], 0, rgb.shape[1])

        # Do unpad -> upsample to original size
        out_masks = torch.zeros((processor.t, 1, *size), dtype=torch.uint8, device='cuda')
        for ti in range(processor.t):
            prob = unpad(processor.prob[:,ti], processor.pad)
            prob = F.interpolate(prob, size, mode='bilinear', align_corners=False)
            # The winning index along dim 0 becomes the label of each pixel.
            out_masks[ti] = torch.argmax(prob, dim=0)

        out_masks = (out_masks.detach().cpu().numpy()[:,0]).astype(np.uint8)

        torch.cuda.synchronize()
        total_process_time += time.time() - process_begin
        total_frames += out_masks.shape[0]

    # Save the results (file output only, so it is kept out of the timed region)
    this_out_path = path.join(out_path, name)
    _save_masks(out_masks, this_out_path, palette)

    if VIZ:
        # visualize results: one overlay image per frame, then a video of them
        viz_path = os.path.join(VIZ_ROOT, name)
        _save_overlays(rgb[0], out_masks, viz_path, mean, std, palette, overlay_davis)

        vid_path = os.path.join(VIZ_ROOT, '{}.mp4'.format(name))
        frame_path = os.path.join(VIZ_ROOT, name, 'f%d.jpg')
        _encode_video(frame_path, vid_path, framerate=10)

    # Drop the references to this video's tensors before the next one is loaded.
    del rgb
    del msk
    del processor


print('Total processing time: ', total_process_time)
print('Total processed frames: ', total_frames)
# Guard against an empty run, where no time was accumulated.
if total_process_time > 0:
    print('FPS: ', total_frames / total_process_time)
else:
    print('FPS: ', 'n/a (nothing was processed)')

100% (60 of 60) |########################| Elapsed Time: 0:34:31 Time:  0:34:31


Total processing time:  638.2924301624298
Total processed frames:  4209
FPS:  6.594156222295966
