-
Notifications
You must be signed in to change notification settings - Fork 114
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add optical flow track assist colab demo.
PiperOrigin-RevId: 570295744 Change-Id: Ibdabdd8820b879b3fddc5366eca7d30b9fd98996
- Loading branch information
Showing
1 changed file
with
362 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,362 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "dXPeGxzWPA73" | ||
}, | ||
"source": [ | ||
"# Demo for annotate a point track with optical flow\n", | ||
"\n", | ||
"This notebook illustrates how we use optical flow to facilitate human annotation on point tracking. Note that it is very hard to annotate a point track extensively along a whole video sequence. However we find dense optical flow estimation these days are fast and accurate. In this demo, we utilize [RAFT](https://pytorch.org/vision/stable/auto_examples/plot_optical_flow.html) to compute the dense optical flow for us.\n", | ||
"\n", | ||
"We then ask the annotater to select a point in the starting frame and the corresponding point location in the ending frame. A dynamic programming algorithm is used to optimize the estimated tracks given starting and ending point location. Note that the algorithm here differs from what we use in the original annotation system (dijkstra algorithm).\n", | ||
"\n", | ||
"The dynamic programming algorithm here requires large matrix computation. Hence running on GPU will be a lot faster." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "Wv9x5NjJzm54" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install mediapy mako flow_vis" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "ZxsjTxUqzrwV" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# @title Imports {form-width: \"25%\"}\n", | ||
"\n", | ||
"import copy\n", | ||
"import io\n", | ||
"import flow_vis\n", | ||
"import functools\n", | ||
"import gc\n", | ||
"import IPython\n", | ||
"import mediapy as media\n", | ||
"import numpy as np\n", | ||
"from PIL import Image\n", | ||
"from google.colab import html\n", | ||
"import base64\n", | ||
"from mako import template\n", | ||
"import torch\n", | ||
"import torchvision\n", | ||
"from tqdm import tqdm" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "p3FnMfsszsMb" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# If you can, run this example on a GPU, it will be a lot faster.\n", | ||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | ||
"\n", | ||
"torch.set_grad_enabled(False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "5eKOeizEztfG" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# @title Load an Exemplar Video {form-width: \"25%\"}\n", | ||
"\n", | ||
"!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4\n", | ||
"\n", | ||
"video = media.read_video('tapnet/examplar_videos/horsejump-high.mp4')\n", | ||
"video = media.resize_video(video, (480, 768))\n", | ||
"height, width = video.shape[1:3]\n", | ||
"media.show_video(video, fps=10)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "m_LydwEBzu15" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# @title Predict Optical Flows with RAFT {form-width: \"25%\"}\n", | ||
"\n", | ||
"from torchvision.models.optical_flow import raft_large\n", | ||
"from torchvision.models.optical_flow import Raft_Large_Weights\n", | ||
"\n", | ||
"model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)\n", | ||
"model = model.eval()\n", | ||
"\n", | ||
"optical_flows = []\n", | ||
"for i in tqdm(range(video.shape[0] - 1)):\n", | ||
" image1 = video[i].astype(np.float32) / 127.5 - 1.0\n", | ||
" image1 = image1.transpose(2, 0, 1)[None]\n", | ||
" image2 = video[i + 1].astype(np.float32) / 127.5 - 1.0\n", | ||
" image2 = image2.transpose(2, 0, 1)[None]\n", | ||
" flow = model(torch.tensor(image1).to(device), torch.tensor(image2).to(device))\n", | ||
" flow = flow[-1][0].cpu().numpy()\n", | ||
" flow = flow.transpose(1, 2, 0)\n", | ||
" optical_flows.append(flow)\n", | ||
"optical_flows = np.stack(optical_flows)\n", | ||
"\n", | ||
"# Release Memory after Prediction\n", | ||
"del model\n", | ||
"gc.collect()\n", | ||
"torch.cuda.empty_cache()\n", | ||
"\n", | ||
"print(optical_flows.shape)\n", | ||
"print(np.abs(optical_flows).max())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "_tzMX-Yizwy9" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# @title Visualize Optical Flows {form-width: \"25%\"}\n", | ||
"\n", | ||
"flow_viz = []\n", | ||
"for i in range(optical_flows.shape[0]):\n", | ||
" flow_viz.append(flow_vis.flow_to_color(optical_flows[i]))\n", | ||
"flow_viz = np.stack(flow_viz)\n", | ||
"\n", | ||
"media.show_video(flow_viz, fps=10)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "5u_YkBmrzyOL" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# @title HTML Template {form-width: \"25%\"}\n", | ||
"\n", | ||
"class Img(html.Element):\n", | ||
" def __init__(self, src=None, show=False):\n", | ||
" super(Img, self).__init__('img')\n", | ||
" if src is not None:\n", | ||
" self.src = src\n", | ||
" self.set_attribute('style', ('display:block;' if show else 'display:none;')+'margin:0px;')\n", | ||
"\n", | ||
" @property\n", | ||
" def src(self):\n", | ||
" return self.get_property('src')\n", | ||
"\n", | ||
" @src.setter\n", | ||
" def src(self, value):\n", | ||
" content = self._to_jpeg(value)\n", | ||
" url = 'data:image/jpeg;base64,' + base64.b64encode(content).decode('utf-8')\n", | ||
" self.set_property('src', url)\n", | ||
"\n", | ||
" def _to_jpeg(self, np_image):\n", | ||
" img = Image.fromarray(np_image)\n", | ||
" buf = io.BytesIO()\n", | ||
" img.save(buf, format=\"JPEG\")\n", | ||
" return buf.getvalue()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "mIbQG5Ycz5xu" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# @title Dynamic Programming Algorithm {form-width: \"25%\"}\n", | ||
"\n", | ||
"def interpolate(flows, frame1, click1, frame2, click2, radius=20):\n", | ||
" x1, y1 = click2idx(click1)\n", | ||
" x2, y2 = click2idx(click2)\n", | ||
"\n", | ||
" window = 2 * radius + 1\n", | ||
" x, y = np.meshgrid(np.arange(-radius, radius + 1), np.arange(-radius, radius + 1))\n", | ||
" offset_cost = np.stack([x, y], axis=-1)\n", | ||
" offset_cost = torch.tensor(offset_cost).to(device)\n", | ||
"\n", | ||
" num_frames, height, width = flows.shape[0:3]\n", | ||
"\n", | ||
" forward_i = np.zeros((num_frames + 1, height, width), dtype=np.int32)\n", | ||
" forward_j = np.zeros((num_frames + 1, height, width), dtype=np.int32)\n", | ||
"\n", | ||
" forward_cost = torch.ones((height, width)).to(device) * 1e10\n", | ||
" forward_cost[y1, x1] = 0\n", | ||
"\n", | ||
" for t in range(frame1, frame2):\n", | ||
" cost_pad = torch.nn.functional.pad(forward_cost, (radius, radius, radius, radius), 'constant', value=1e10)\n", | ||
" cost_unfold = cost_pad.unfold(0, window, 1).unfold(1, window, 1)\n", | ||
" del cost_pad\n", | ||
" gc.collect()\n", | ||
" torch.cuda.empty_cache()\n", | ||
"\n", | ||
" flow_cuda = torch.tensor(flows[t]).to(device)\n", | ||
" flow_pad = torch.nn.functional.pad(flow_cuda, (0, 0, radius, radius, radius, radius), 'constant', value=1e10)\n", | ||
" flow_unfold = flow_pad.unfold(0, window, 1).unfold(1, window, 1).permute(0, 1, 3, 4, 2)\n", | ||
" del flow_cuda, flow_pad\n", | ||
" gc.collect()\n", | ||
" torch.cuda.empty_cache()\n", | ||
"\n", | ||
" cost = cost_unfold + torch.abs(-offset_cost[None, None] - flow_unfold).sum(axis=-1)\n", | ||
" cost = cost.reshape(height, width, -1)\n", | ||
" forward_cost, argmin_indices = torch.min(cost, axis=-1)\n", | ||
" del cost\n", | ||
" gc.collect()\n", | ||
" torch.cuda.empty_cache()\n", | ||
"\n", | ||
" argmin_indices = argmin_indices.cpu().numpy()\n", | ||
" forward_i_min, forward_j_min = argmin_indices // (window), argmin_indices % (window)\n", | ||
" forward_i[t] = forward_i_min + np.arange(height)[:, None] - radius\n", | ||
" forward_j[t] = forward_j_min + np.arange(width)[None] - radius\n", | ||
"\n", | ||
" last_cost = torch.ones((height, width)).to(device) * 1e10\n", | ||
" last_cost[y2, x2] = 0\n", | ||
" forward_cost += last_cost\n", | ||
" min_cost = torch.min(forward_cost).cpu().numpy()\n", | ||
"\n", | ||
" argmin_indices = torch.argmin(forward_cost).item()\n", | ||
" min_i, min_j = argmin_indices // width, argmin_indices % width\n", | ||
" min_ij = [(min_j, min_i)]\n", | ||
"\n", | ||
" for t in range(frame2 - 1, frame1 - 1, -1):\n", | ||
" min_i, min_j = forward_i[t, min_i, min_j], forward_j[t, min_i, min_j]\n", | ||
" min_ij.insert(0, (min_j, min_i))\n", | ||
"\n", | ||
" del forward_cost\n", | ||
" gc.collect()\n", | ||
" torch.cuda.empty_cache()\n", | ||
"\n", | ||
" return np.stack(min_ij), min_cost" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "SklmS5UGz6S1" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# @title Reset the Annotated Trajectories {form-width: \"25%\"}\n", | ||
"\n", | ||
"clicks=[None for i in range(video.shape[0])]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "LCguhULPz8qT" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# @title Start Annotation {form-width: \"25%\"}\n", | ||
"\n", | ||
"def mouse_position(event, frame_id):\n", | ||
" x = event['clientX']\n", | ||
" y = event['clientY']\n", | ||
" clicks[frame_id]=[x, y]\n", | ||
" print('\\r', 'Please re-run this cell ...', end='')\n", | ||
"\n", | ||
"def click2idx(click):\n", | ||
" x, y = click\n", | ||
" x = int(round(x))\n", | ||
" y = int(round(y))\n", | ||
" return x, y\n", | ||
"\n", | ||
"cur_pos = None\n", | ||
"frames2 = []\n", | ||
"all_pos = np.zeros([video.shape[0], 2], dtype=int)\n", | ||
"last_click = None\n", | ||
"for i in range(video.shape[0]):\n", | ||
" if clicks[i] and last_click:\n", | ||
" all_pos[last_click[0]:i+1, :], forward_cost = interpolate(optical_flows, last_click[0], last_click[1], i, clicks[i])\n", | ||
"\n", | ||
" if clicks[i]:\n", | ||
" cur_pos = copy.copy(clicks[i])\n", | ||
" last_click = (i, clicks[i])\n", | ||
" if cur_pos:\n", | ||
" x, y = click2idx(cur_pos)\n", | ||
"\n", | ||
" y = min(max(y, 0), height - 1)\n", | ||
" x = min(max(x, 0), width - 1)\n", | ||
" all_pos[i,0] = x\n", | ||
" all_pos[i,1] = y\n", | ||
" if i \u003c optical_flows.shape[0]:\n", | ||
" cur_pos[0] += optical_flows[i, y, x, 0]\n", | ||
" cur_pos[1] += optical_flows[i, y, x, 1]\n", | ||
"\n", | ||
"for i in range(video.shape[0]):\n", | ||
" fr = np.copy(video[i])\n", | ||
" x, y = all_pos[i] - 5\n", | ||
" fr[y-2:y+3,x-2:x+3,0] = 255 if clicks[i] else 0\n", | ||
" fr[y-2:y+3,x-2:x+3,1] = 0 if clicks[i] else 255\n", | ||
" fr[y-2:y+3,x-2:x+3,2] = 0 if clicks[i] else 255\n", | ||
" frames2.append(fr)\n", | ||
"\n", | ||
"imgs=[]\n", | ||
"img_ids=\"[\"\n", | ||
"for i in range(len(frames2)):\n", | ||
" img = Img(src=frames2[i], show=i==0)\n", | ||
" img.add_event_listener('click', functools.partial(mouse_position, frame_id=i))\n", | ||
" imgs.append(img)\n", | ||
" img_ids += \"\\\"\" + str(img._guid) + \"\\\",\"\n", | ||
"img_ids += \"]\"\n", | ||
"\n", | ||
"MAKO_TEMPLATE=\"\"\"\n", | ||
"\u003cinput type=\"range\" min=\"0\" max=\"${num_frames-1}\" value=\"0\" class=\"slider\" id=\"myRange\"\u003e\n", | ||
"\u003cscript\u003e\n", | ||
"img_ids=${img_ids}\n", | ||
"slider=document.getElementById(\"myRange\");\n", | ||
"cur_frame=0\n", | ||
"slider.oninput = function() {\n", | ||
" idx = this.value;\n", | ||
" for (var i = 0; i\u003c${num_frames}; i++){\n", | ||
" document.getElementById(img_ids[i]).style.display=\"none\"\n", | ||
" }\n", | ||
" document.getElementById(img_ids[idx]).style.display=\"block\"\n", | ||
"}\n", | ||
"\u003c/script\u003e\n", | ||
"\"\"\"\n", | ||
"viz_tpl = template.Template(MAKO_TEMPLATE, strict_undefined=True)\n", | ||
"script = viz_tpl.render(num_frames=len(frames2),img_ids=img_ids)\n", | ||
"\n", | ||
"display(IPython.display.HTML(\" \".join([img._repr_html_() for img in imgs])+script))\n", | ||
"\n", | ||
"################################################################################\n", | ||
"# Instructions:\n", | ||
"#\n", | ||
"# 1) click anywhere on the first frame to get a point to track.\n", | ||
"# 2) re-run this cell to see where it goes\n", | ||
"# 3) click a point on any other frame, and the demo will find the shortest path.\n", | ||
"################################################################################" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"private_outputs": true, | ||
"provenance": [] | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |