Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Running SAM backbone on frontend #6019

Merged
merged 23 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cvat-canvas/src/typescript/interactionHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ export class InteractionHandlerImpl implements InteractionHandler {
'shape-rendering': 'geometricprecision',
'pointer-events': 'none',
opacity: 0.5,

// always fit masks to background size
// in general mask can be smaller what is useful in optimization purposes
width: geometry.image.width,
height: geometry.image.height,
}).addClass('cvat_canvas_interact_intermediate_shape');
image.move(this.geometry.offset, this.geometry.offset);
this.drawnIntermediateShape = image;
Expand Down
14 changes: 12 additions & 2 deletions cvat-core/src/plugins.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,18 @@ export default class PluginRegistry {
static async apiWrapper(wrappedFunc, ...args) {
// I have to optimize the wrapper
const pluginList = await PluginRegistry.list();
const aggregatedOptions = {
preventMethodCall: false,
};

for (const plugin of pluginList) {
const pluginDecorators = plugin.functions.filter((obj) => obj.callback === wrappedFunc)[0];
if (pluginDecorators && pluginDecorators.enter) {
try {
await pluginDecorators.enter.call(this, plugin, ...args);
const options = await pluginDecorators.enter.call(this, plugin, ...args);
if (options?.preventMethodCall) {
aggregatedOptions.preventMethodCall = true;
}
} catch (exception) {
if (exception instanceof PluginError) {
throw exception;
Expand All @@ -24,7 +31,10 @@ export default class PluginRegistry {
}
}

let result = await wrappedFunc.implementation.call(this, ...args);
let result = null;
if (!aggregatedOptions.preventMethodCall) {
result = await wrappedFunc.implementation.call(this, ...args);
}

for (const plugin of pluginList) {
const pluginDecorators = plugin.functions.filter((obj) => obj.callback === wrappedFunc)[0];
Expand Down
3 changes: 3 additions & 0 deletions cvat-ui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"dependencies": {
"@ant-design/icons": "^4.6.3",
"@types/lodash": "^4.14.172",
"@types/lru-cache": "^7.10.10",
"@types/platform": "^1.3.4",
"@types/react": "^16.14.15",
"@types/react-color": "^3.0.5",
Expand All @@ -41,8 +42,10 @@
"dotenv-webpack": "^8.0.1",
"error-stack-parser": "^2.0.6",
"lodash": "^4.17.21",
"lru-cache": "^9.1.1",
"moment": "^2.29.2",
"mousetrap": "^1.6.5",
"onnxruntime-web": "^1.14.0",
"platform": "^1.3.6",
"prop-types": "^15.7.2",
"react": "^16.14.0",
Expand Down
262 changes: 262 additions & 0 deletions cvat-ui/plugins/sam_plugin/src/ts/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
// Copyright (C) 2023 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT

import { PluginEntryPoint, ComponentBuilder } from 'components/plugins-entrypoint';
import { InferenceSession, Tensor } from 'onnxruntime-web';
import { LRUCache } from 'lru-cache';

interface SAMPlugin {
name: string;
description: string;
cvat: {
lambda: {
call: {
enter: (
plugin: SAMPlugin,
taskID: number,
model: any,
args: any,
) => Promise<null | { preventMethodCall: boolean }>;
bsekachev marked this conversation as resolved.
Show resolved Hide resolved
leave: (
plugin: SAMPlugin,
result: any,
taskID: number,
model: any,
args: any,
) => Promise<any>;
};
};
};
data: {
modelID: string;
modelURL: string;
embeddings: LRUCache<string, Tensor>;
lowResMasks: LRUCache<string, Tensor>;
session: InferenceSession | null;
};
callbacks: {
onStatusChange: ((status: string) => void) | null;
};
}

interface ONNXInput {
image_embeddings: Tensor;
point_coords: Tensor;
point_labels: Tensor;
orig_im_size: Tensor;
mask_input: Tensor;
has_mask_input: Tensor;
readonly [name: string]: Tensor;
}

interface ClickType {
clickType: -1 | 0 | 1,
height: number | null,
width: number | null,
x: number,
y: number,
}

function getModelScale(): { height: number, width: number, samScale: number } {
// Input images to SAM must be resized so the longest side is 1024
const LONG_SIDE_LENGTH = 1024;
const w = +window.document.getElementsByClassName('cvat_masks_canvas_wrapper')[0].style.width.slice(0, -2);
const h = +window.document.getElementsByClassName('cvat_masks_canvas_wrapper')[0].style.height.slice(0, -2);
const samScale = LONG_SIDE_LENGTH / Math.max(h, w);
return { height: h, width: w, samScale };
}

function modelData(
{
clicks, tensor, modelScale, maskInput,
}: {
clicks: ClickType[];
tensor: Tensor;
modelScale: ReturnType<typeof getModelScale>;
maskInput: Tensor | null;
},
): ONNXInput {
const imageEmbedding = tensor;

const n = clicks.length;
// If there is no box input, a single padding point with
// label -1 and coordinates (0.0, 0.0) should be concatenated
// so initialize the array to support (n + 1) points.
const pointCoords = new Float32Array(2 * (n + 1));
const pointLabels = new Float32Array(n + 1);

// Add clicks and scale to what SAM expects
for (let i = 0; i < n; i++) {
pointCoords[2 * i] = clicks[i].x * modelScale.samScale;
pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale;
pointLabels[i] = clicks[i].clickType;
}

// Add in the extra point/label when only clicks and no box
// The extra point is at (0, 0) with label -1
pointCoords[2 * n] = 0.0;
pointCoords[2 * n + 1] = 0.0;
pointLabels[n] = -1.0;

// Create the tensor
const pointCoordsTensor = new Tensor('float32', pointCoords, [1, n + 1, 2]);
const pointLabelsTensor = new Tensor('float32', pointLabels, [1, n + 1]);
const imageSizeTensor = new Tensor('float32', [
modelScale.height,
modelScale.width,
]);

const prevMask = maskInput ||
new Tensor('float32', new Float32Array(256 * 256), [1, 1, 256, 256]);
const hasMaskInput = new Tensor('float32', [maskInput ? 1 : 0]);

return {
image_embeddings: imageEmbedding,
point_coords: pointCoordsTensor,
point_labels: pointLabelsTensor,
orig_im_size: imageSizeTensor,
mask_input: prevMask,
has_mask_input: hasMaskInput,
};
}

const samPlugin: SAMPlugin = {
name: 'Segmeny Anything',
description: 'Plugin handles non-default SAM serverless function output',
cvat: {
lambda: {
call: {
async enter(
plugin: SAMPlugin,
taskID: number,
model: any, { frame }: { frame: number },
): Promise<null | { preventMethodCall: boolean }> {
if (model.id === plugin.data.modelID) {
if (!plugin.data.session) {
throw new Error('SAM plugin is not ready, session was not initialized');
}

const key = `${taskID}_${frame}`;
if (plugin.data.embeddings.has(key)) {
return { preventMethodCall: true };
}
}

return null;
},

async leave(
plugin: SAMPlugin,
result: any,
taskID: number,
model: any,
{ frame, pos_points, neg_points }: {
frame: number, pos_points: number[][], neg_points: number[][],
},
): Promise<{ mask: number[][]; orig_size: [number, number]; }> {
const key = `${taskID}_${frame}`;
if (model.id !== plugin.data.modelID) {
return result;
}

if (result) {
const bin = window.atob(result.blob);
const uint8Array = new Uint8Array(bin.length);
for (let i = 0; i < bin.length; i++) {
uint8Array[i] = bin.charCodeAt(i);
}
const float32Arr = new Float32Array(uint8Array.buffer);
plugin.data.embeddings.set(key, new Tensor('float32', float32Arr, [1, 256, 64, 64]));
}

const modelScale = getModelScale();
const composedClicks = [...pos_points, ...neg_points].map(([x, y], index) => ({
clickType: index < pos_points.length ? 1 : 0 as 0 | 1 | -1,
height: null,
width: null,
x,
y,
}));

const feeds = modelData({
clicks: composedClicks,
tensor: plugin.data.embeddings.get(key) as Tensor,
modelScale,
maskInput: plugin.data.lowResMasks.has(key) ? plugin.data.lowResMasks.get(key) as Tensor : null,
});

function toMatImage(input: number[], width: number, height: number): number[][] {
const image = Array(height).fill(0);
for (let i = 0; i < image.length; i++) {
image[i] = Array(width).fill(0);
}

for (let i = 0; i < input.length; i++) {
const row = Math.floor(i / width);
const col = i % width;
image[row][col] = input[i] > 0.0 ? 255 : 0;
}

return image;
}

function onnxToImage(input: any, width: number, height: number): number[][] {
return toMatImage(input, width, height);
}

const data = await (plugin.data.session as InferenceSession).run(feeds);
const { masks, low_res_masks: lowResMasks } = data;
const imageData = onnxToImage(masks.data, masks.dims[3], masks.dims[2]);
plugin.data.lowResMasks.set(key, lowResMasks);

return {
mask: imageData,
orig_size: [modelScale.width, modelScale.height],
};
},
},
},
},
data: {
modelID: 'pth-facebookresearch-sam-vit-h',
modelURL: '/api/lambda/sam_detector.onnx',
embeddings: new LRUCache({
// float32 tensor [256, 64, 64] is 4 MB, max 512 MB
max: 128,
updateAgeOnGet: true,
updateAgeOnHas: true,
}),
lowResMasks: new LRUCache({
// float32 tensor [1, 256, 256] is 0.25 MB, max 32 MB
max: 128,
updateAgeOnGet: true,
updateAgeOnHas: true,
}),
session: null,
},
callbacks: {
onStatusChange: null,
},
};

const SAMModelPlugin: ComponentBuilder = ({ core }) => {
InferenceSession.create(samPlugin.data.modelURL).then((session) => {
samPlugin.data.session = session;
core.plugins.register(samPlugin);
});

return {
name: 'Segment Anything model',
destructor: () => {},
};
};

function register(): void {
if (Object.prototype.hasOwnProperty.call(window, 'cvatUI')) {
(window as any as { cvatUI: { registerComponent: PluginEntryPoint } })
.cvatUI.registerComponent(SAMModelPlugin);
}
}

window.addEventListener('plugins.ready', register, { once: true });
10 changes: 7 additions & 3 deletions cvat-ui/react_nginx.conf
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@ server {
location / {
# Any route that doesn't exist on the server (e.g. /devices)
try_files $uri $uri/ /index.html;
add_header Cache-Control: "no-cache, no-store, must-revalidate";
add_header Pragma: "no-cache";
add_header Expires: 0;
add_header Cache-Control "no-cache, no-store, must-revalidate";
add_header Pragma "no-cache";
add_header Cross-Origin-Opener-Policy "same-origin";
add_header Cross-Origin-Embedder-Policy "credentialless";
add_header Expires 0;
}

location /assets {
expires 1y;
add_header Cache-Control "public";
add_header Cross-Origin-Embedder-Policy "require-corp";

access_log off;
}
}
Loading