### **Applying STCN to 13 videos of the Something-Something dataset with first-frame annotations given by LOST**

As first-frame annotations, we use the bounding-box predictions obtained with LOST.

Then we apply the STCN method using the authors' implementation (https://github.com/hkchengrex/STCN), and we visualise the results by adapting the visualisation code of STM (https://github.com/seoungwugoh/STM).

In [None]:
# Access to dataset through Drive.
# Mount Google Drive so the checkpoint, dataset and output folders used by the
# following cells are reachable.
import os
from google.colab import drive

drive.mount('/content/drive/', force_remount=True)
# Every other path in this notebook is spelled /content/drive/MyDrive/...,
# so change into that same directory for consistency.
os.chdir('/content/drive/MyDrive/')

Mounted at /content/drive/


In [None]:
import sys
sys.path.insert(0, '/content/drive/MyDrive/STCN/STCN')

import os
from os import path
from argparse import ArgumentParser

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image

from model.eval_network import STCN
from dataset.generic_test_dataset import GenericTestDataset
from util.tensor_util import unpad
from inference_core import InferenceCore
from helpers import overlay_davis

from progressbar import progressbar
import time

In [None]:
"""
Arguments loading
"""
parser = ArgumentParser()
parser.add_argument('--model', default='/content/drive/MyDrive/STCN/STCN/saves/stcn.pth')
parser.add_argument('--data_path', default='/content/drive/MyDrive/STCN/LOST-bounding-box')
parser.add_argument('--output', default='/content/drive/MyDrive/STCN/experiment/LOST-Box')
parser.add_argument('--top', type=int, default=20)
parser.add_argument('--amp_off', action='store_true')
parser.add_argument('--mem_every', default=5, type=int)
parser.add_argument('--include_last', help='include last frame as temporary memory?', action='store_true')
args, _ = parser.parse_known_args()

data_path = args.data_path
out_path = args.output
args.amp = not args.amp_off
palette = Image.open(path.expanduser('/content/drive/MyDrive/STCN/DAVIS/2017' + '/trainval/Annotations/480p/blackswan/00000.png')).getpalette()
VIZ = True    # Visualisation

# Simple setup
os.makedirs(out_path, exist_ok=True)
torch.autograd.set_grad_enabled(False)

# Setup Dataset
test_dataset = GenericTestDataset(data_root=data_path)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

# Load our checkpoint
top_k = args.top
prop_model = STCN().cuda().eval()

# Performs input mapping such that stage 0 model can be loaded
prop_saved = torch.load(args.model)
for k in list(prop_saved.keys()):
    if k == 'value_encoder.conv1.weight':
        if prop_saved[k].shape[1] == 4:
            pads = torch.zeros((64,1,7,7), device=prop_saved[k].device)
            prop_saved[k] = torch.cat([prop_saved[k], pads], 1)
prop_model.load_state_dict(prop_saved)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

<All keys matched successfully>

In [None]:
import subprocess

total_process_time = 0
total_frames = 0

# ImageNet statistics; used below to invert (x - mean) / std for visualisation.
mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None].cuda()
std = torch.tensor([0.229, 0.224, 0.225])[:, None, None].cuda()

# Overlay frames and videos are written next to the mask output folder
# (default: .../experiment/LOST-Box -> .../experiment/LOST-Box-viz).
viz_root = os.path.normpath(out_path) + '-viz'


def _encode_video(frame_pattern, video_path, fps=10):
    """Encode numbered JPEG frames (e.g. f0.jpg, f1.jpg, ...) into an H.264 MP4.

    ffmpeg applies per-file options to the *next* file on its command line,
    so all output options must come before the output path; options that
    trail the last output are ignored. The scale filter rounds both
    dimensions down to even numbers, which libx264 needs for yuv420p.
    Failures are reported but do not stop the evaluation loop.
    """
    cmd = [
        'ffmpeg', '-nostats', '-loglevel', 'error', '-y',
        '-framerate', str(fps), '-i', frame_pattern,
        '-vf', 'scale=trunc(iw/2)*2:trunc(ih/2)*2',
        '-vcodec', 'libx264', '-crf', '10', '-pix_fmt', 'yuv420p',
        video_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError:
        print('Warning: ffmpeg not found; skipping', video_path)
        return
    if result.returncode != 0:
        print('Warning: ffmpeg failed for {}: {}'.format(video_path, result.stderr.strip()))


# Start eval
for data in progressbar(test_loader, max_value=len(test_loader), redirect_stdout=True):

    # Only the inference itself runs under mixed precision.
    with torch.cuda.amp.autocast(enabled=args.amp):
        rgb = data['rgb'].cuda()
        msk = data['gt'][0].cuda()
        info = data['info']
        name = info['name'][0]
        num_objects = len(info['labels'][0])
        size = rgb.shape[-2:]

        torch.cuda.synchronize()
        process_begin = time.time()

        processor = InferenceCore(prop_model, rgb, num_objects, top_k=top_k,
                                  mem_every=args.mem_every, include_last=args.include_last)
        # Start from the frame-0 annotation and run up to frame rgb.shape[1].
        processor.interact(msk[:, 0], 0, rgb.shape[1])

        # Do unpad -> upsample to original size, then take the per-pixel argmax label.
        out_masks = torch.zeros((processor.t, 1, *size), dtype=torch.uint8, device='cuda')
        for ti in range(processor.t):
            prob = unpad(processor.prob[:, ti], processor.pad)
            prob = F.interpolate(prob, size, mode='bilinear', align_corners=False)
            out_masks[ti] = torch.argmax(prob, dim=0)

        out_masks = out_masks.detach().cpu().numpy()[:, 0].astype(np.uint8)

        torch.cuda.synchronize()
        total_process_time += time.time() - process_begin
        total_frames += out_masks.shape[0]

    # Save the results: one palettised PNG label map per frame.
    this_out_path = path.join(out_path, name)
    os.makedirs(this_out_path, exist_ok=True)
    for f, mask in enumerate(out_masks):
        img_E = Image.fromarray(mask)
        img_E.putpalette(palette)
        img_E.save(path.join(this_out_path, '{:05d}.png'.format(f)))

    # Visualisation, adapted from the github of STM.
    if VIZ:
        viz_path = path.join(viz_root, name)
        os.makedirs(viz_path, exist_ok=True)

        for f, mask in enumerate(out_masks):
            # Invert the normalisation; clamp so the uint8 cast cannot wrap around.
            im = (rgb[0, f] * std + mean).clamp(0, 1)
            pF = (im.permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8)
            canvas = Image.fromarray(overlay_davis(pF, mask, palette))
            canvas.save(path.join(viz_path, 'f{}.jpg'.format(f)))

        _encode_video(path.join(viz_path, 'f%d.jpg'),
                      path.join(viz_root, '{}.mp4'.format(name)))

    # Release GPU tensors before the next video.
    del rgb, msk, processor


print('Total processing time: ', total_process_time)
print('Total processed frames: ', total_frames)
if total_process_time > 0:
    print('FPS: ', total_frames / total_process_time)
else:
    print('FPS:  n/a (nothing was processed)')

100% (13 of 13) |########################| Elapsed Time: 0:03:44 Time:  0:03:44


Total processing time:  20.008266925811768
Total processed frames:  577
FPS:  28.838079886651162
