-
Notifications
You must be signed in to change notification settings - Fork 6
/
__init__.py
80 lines (69 loc) · 2.4 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import huggingface_hub
import torch
import onnxruntime as rt
import numpy as np
import cv2
def get_mask(img: np.ndarray, s=1024):
    """Run the module-level ``rmbg_model`` on ``img`` and return its mask.

    The image is scaled so its longer side equals ``s``, centred on a
    zero-filled ``s`` x ``s`` canvas, passed to the ONNX session, and the
    prediction is cropped and resized back to the input resolution.

    Args:
        img: Array of shape ``(height, width, channels)`` with values on a
            0-255 scale and three or four channels. Only the first three
            channels are fed to the model, so an alpha channel is ignored.
        s: Side length in pixels of the square model input.

    Returns:
        A float array of shape ``(height, width, 1)``, assuming the model
        emits a single-channel map (the caller uses it as an alpha channel).
    """
    # The parameter is a NumPy array, not a torch tensor: ``astype`` below is
    # ndarray-only. Slicing to three channels keeps RGBA input from failing
    # when it is written into the 3-channel canvas.
    img = (img[..., :3] / 255).astype(np.float32)
    h0, w0 = img.shape[:-1]

    # Fit the longer side to ``s``. The shorter side is clamped to at least
    # one pixel so that extreme aspect ratios do not ask cv2 for a 0-sized
    # resize.
    if h0 > w0:
        h, w = s, max(1, int(s * w0 / h0))
    else:
        h, w = max(1, int(s * h0 / w0)), s

    # Offsets that centre the resized image inside the square canvas.
    top, left = (s - h) // 2, (s - w) // 2

    img_input = np.zeros([s, s, 3], dtype=np.float32)
    img_input[top:top + h, left:left + w] = cv2.resize(img, (w, h))

    # HWC -> CHW, then add a leading batch axis for the session.
    img_input = np.transpose(img_input, (2, 0, 1))[np.newaxis, :]

    mask = rmbg_model.run(None, {'img': img_input})[0][0]

    # CHW -> HWC, undo the padding, and return to the original resolution.
    mask = np.transpose(mask, (1, 2, 0))
    mask = mask[top:top + h, left:left + w]
    return cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
# Execution providers handed to ONNX Runtime, listed in order of preference:
# CUDA first, CPU as the fallback.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

# Fetch the segmentation weights from the Hugging Face Hub and open an
# inference session on them. This runs once, when the module is imported.
model_path = huggingface_hub.hf_hub_download(
    repo_id="skytnt/anime-seg",
    filename="isnetis.onnx",
)
rmbg_model = rt.InferenceSession(model_path, providers=providers)
def rmbg_fn(img):
    """Composite ``img`` over white and append the predicted mask as alpha.

    Args:
        img: Array of shape ``(height, width, 3)`` or ``(height, width, 4)``
            with values on a 0-255 scale. A fourth (alpha) channel is
            dropped, because the output alpha comes from the predicted mask.

    Returns:
        A ``uint8`` array of shape ``(height, width, 4)``: the three colour
        channels blended onto a white background, followed by the mask
        scaled to 0-255 as the fourth channel.
    """
    # Work on colour channels only; passing RGBA through would yield a
    # 5-channel result after the concatenation below.
    colour = img[..., :3]

    # Clip defensively: mask values outside [0, 1] would not survive the
    # uint8 casts below intact.
    mask = np.clip(get_mask(colour), 0.0, 1.0)

    # Foreground keeps its colour; background fades to white (255).
    blended = (mask * colour + 255 * (1 - mask)).astype(np.uint8)
    alpha = (mask * 255).astype(np.uint8)

    # Both parts are already uint8, so no explicit dtype is needed here.
    return np.concatenate([blended, alpha], axis=2)
class RemoveImageBackgroundabg:
    """Node that strips the background from images via ``rmbg_fn``.

    The layout (``INPUT_TYPES``, ``RETURN_TYPES``, ``FUNCTION``,
    ``CATEGORY``) matches the ComfyUI custom-node convention: ``FUNCTION``
    names the method to invoke and ``RETURN_TYPES`` describes its outputs.
    """

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "abg_remover"
    CATEGORY = "image"

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's single required input, an ``IMAGE``."""
        return {
            "required": {
                "image": ("IMAGE",),
            },
        }

    def abg_remover(self, image: torch.Tensor):
        """Remove the background from every image in ``image``.

        Args:
            image: Float tensor laid out as ``(batch, height, width,
                channels)``. A 3-D tensor is treated as a single image.

        Returns:
            A one-element tuple holding a tensor of shape ``(batch, height,
            width, 4)``; the fourth channel is the predicted mask.

        Raises:
            ValueError: If the batch is empty.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)

        # The helpers operate on one image at a time, so process the batch
        # frame by frame instead of relying on a length-1 batch axis.
        processed = []
        for frame in image:
            pixels = image2nparray(frame.unsqueeze(0))
            processed.append(nparray2image(rmbg_fn(pixels)))

        if not processed:
            raise ValueError("abg_remover received an empty image batch")

        return (torch.cat(processed, dim=0),)
def image2nparray(image: torch.Tensor) -> np.ndarray:
    """Convert a float image tensor to a ``uint8`` array with reversed colours.

    Values are scaled by 255 and clipped to ``[0, 255]``. The first three
    channels are reversed (index order ``2, 1, 0``); a fourth channel, if
    present, stays in place.

    Args:
        image: Tensor whose last axis holds 3 or 4 channels. A leading batch
            axis of length one is removed.

    Returns:
        A ``uint8`` array, ``(height, width, channels)`` for a single image.
    """
    # NOTE(review): the channel reversal looks like an RGB -> BGR swap for
    # OpenCV; confirm this is the order the segmentation model expects.
    array = image.detach().cpu().numpy()

    # Drop only a length-1 batch axis. A bare ``squeeze()`` would also
    # collapse a height or width of 1 and break the (H, W, C) layout.
    if array.ndim == 4 and array.shape[0] == 1:
        array = array[0]

    array = np.clip(255. * array, 0, 255).astype(np.uint8)

    if array.shape[-1] == 4:
        return array[..., [2, 1, 0, 3]]  # keep alpha last
    return array[..., [2, 1, 0]]
def nparray2image(narray: np.ndarray) -> torch.Tensor:
    """Convert a 0-255 image array to a batched float tensor.

    The first three channels are reversed (index order ``2, 1, 0``), a fourth
    channel stays in place, values are divided by 255, and a leading batch
    axis of length one is added.

    Args:
        narray: Array whose last axis holds 3 or 4 channels.

    Returns:
        A ``float32`` tensor with one more leading axis than ``narray``.
    """
    # No diagnostic output here: this runs on every node execution and the
    # shape is available to callers from the returned tensor.
    channel_order = [2, 1, 0, 3] if narray.shape[-1] == 4 else [2, 1, 0]
    scaled = narray[..., channel_order] / 255.
    return torch.from_numpy(scaled).float().unsqueeze(0)
# Registry that exposes this module's node class under its display name.
NODE_CLASS_MAPPINGS = {
    "Remove Image Background (abg)": RemoveImageBackgroundabg,
}