U2-net cloth segmentation model #77

Closed
lavandaboy opened this issue Jul 25, 2022 · 13 comments

@lavandaboy
Tutorial Select: Prepare Custom Model

Feedback:

Hi PlayTorch community,

I am trying to implement this model. It is based on U2-net but performs clothes segmentation. I converted it in the same way as the regular U2-net model, using the tutorial I previously posted.

I am using the U2-net snack as the base, which works perfectly on my device in the PlayTorch app. Then I change the path to the converted model (https://cdn-128.anonfiles.com/v5l75ez1yc/431fccf2-1658318807/cloth_segm_live.ptl) in ImageMask.ts. When I take a picture, nothing happens; I just see the camera UI.

Here is the link to my Expo snack for the cloth segmentation model.

I would appreciate any help with this issue.

@liuyinglao
Contributor

Hi @lavandaboy

I tried your example and got this error:

Format error
Exception raised from _load_for_mobile at /Users/distiller/project/torch/csrc/jit/mobile/import.cpp:623 (most recent call first):
frame #0: _ZN8facebook5react11JSIExecutor21defaultTimeoutInvokerERKNSt3__18functionIFvvEEENS3_IFNS2_12basic_stringIcNS2_11char_traitsIcEENS2_9allocatorIcEEEEvEEE + 1912560 (0x103e1a704 in PlayTorch)
frame #1: _ZN2at6native19requantize_from_intIN3c106qint32EEET_dxx + 3037040 (0x103acf500 in PlayTorch)
frame #2: _ZN2at6native19requantize_from_intIN3c106qint32EEET_dxx + 3036236 (0x103acf1dc in PlayTorch)
frame #3: xerbla_ + 978924 (0x103f55c10 in PlayTorch)
frame #4: xerbla_ + 986472 (0x103f5798c in PlayTorch)
frame #5: _ZN8facebook5react11JSIExecutor21defaultTimeoutInvokerERKNSt3__18functionIFvvEEENS3_IFNS2_12basic_stringIcNS2_11char_traitsIcEENS2_9allocatorIcEEEEvEEE + 1894016 (0x103e15e94 in PlayTorch)
frame #6: _ZN8facebook5react11JSIExecutor21defaultTimeoutInvokerERKNSt3__18functionIFvvEEENS3_IFNS2_12basic_stringIcNS2_11char_traitsIcEENS2_9allocatorIcEEEEvEEE + 1895952 (0x103e16624 in PlayTorch)
frame #7: _pthread_start + 148 (0x20fb989ac in libsystem_pthread.dylib)
frame #8: thread_start + 8 (0x20fb97e68 in libsystem_pthread.dylib)

The exception seems to be thrown when you load the model with torch.jit._loadForMobile(). May I ask how the model was exported? To use torch.jit._loadForMobile(), the binary file (suffixed with *.ptl) needs to be exported with the PyTorch API like this:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
...
scriptified_module = torch.jit.script(ModuleExtendingNNModule())
optimized_model = optimize_for_mobile(scriptified_module, preserved_methods=["..."])
optimized_model._save_for_lite_interpreter("modulename.ptl")

@lavandaboy
Author

Hi @liuyinglao,

Thanks for the prompt reply.

I re-converted the model. Here is the link to it (https://cdn-144.anonfiles.com/85n1Ge0by6/622106f0-1658829111/cloth_segm_live.ptl), and I used the following code in Colab:

%cd /content/
!rm -rf cloth-segmentation
!git clone https://github.com/levindabhi/cloth-segmentation.git
%cd cloth-segmentation
!gdown --id 1mhF3yqd7R-Uje092eypktNl-RoZNuiCJ
!mkdir input_images
!mkdir output_images

import os
# from tqdm import tqdm
from tqdm.notebook import tqdm
from PIL import Image
import numpy as np

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.mobile_optimizer import optimize_for_mobile
from pathlib import Path

from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu

from networks import U2NET
device = 'cuda'

image_dir = 'input_images'
result_dir = 'output_images'
checkpoint_path = 'cloth_segm_u2net_latest.pth'

def get_palette(num_cls):
    """ Returns the color map for visualizing the segmentation mask.
    Args:
        num_cls: Number of classes
    Returns:
        The color map
    """
    n = num_cls
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i += 1
            lab >>= 3
    return palette

transforms_list = []
transforms_list += [transforms.ToTensor()]
transforms_list += [Normalize_image(0.5, 0.5)]
transform_rgb = transforms.Compose(transforms_list)

net = U2NET(in_ch=3, out_ch=4)
net = load_checkpoint_mgpu(net, checkpoint_path)
net = net.to(device)
net = net.eval()

scripted_model = torch.jit.script(net)
optimized_model = optimize_for_mobile(scripted_model)
optimized_model._save_for_lite_interpreter("cloth_segm_live.ptl")

print("model successfully exported")

I also updated the path to the model in my Expo snack.

Let me know if I need to share anything else with you. Thanks!

@chrisklaiber
Contributor

@lavandaboy your model export code looks good.

The "Format Error" in the log is coming from this line, which has failed to recognize a compatible model format, like zip: https://github.com/pytorch/pytorch/blob/v1.12.0/torch/csrc/jit/mobile/import.cpp#L623

That makes sense because the model URL used in the snack is redirecting to an HTML page:

% curl -I 'https://cdn-144.anonfiles.com/85n1Ge0by6/622106f0-1658829111/cloth_segm_live.ptl'
HTTP/1.1 301 Moved Permanently
Server: nginx
Date: Tue, 26 Jul 2022 18:56:05 GMT
Content-Type: text/html
Connection: close
Location: https://anonfiles.com/85n1Ge0by6        <== redirect
X-Cache-Host: filecache-01
X-Cache-Disk: ssd01
Accept-Ranges: bytes

% curl -I 'https://anonfiles.com/85n1Ge0by6'
HTTP/1.1 200 OK
Server: nginx
Date: Tue, 26 Jul 2022 18:56:16 GMT
Content-Type: text/html; charset=UTF-8            <== HTML page
Connection: keep-alive
Vary: Accept-Encoding
x-vdc: Yes
cache-control: public, max-age=60
x-oe: N
accept-ranges: bytes
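
A quick way to confirm this locally (a minimal sketch, not from the original report; the URL below is a placeholder) is to download the file in Python and check whether it is actually a zip archive, which is what a lite interpreter .ptl file should be:

import io
import urllib.request
import zipfile

# Placeholder URL; substitute the actual model link used in the snack.
url = "https://example.com/cloth_segm_live.ptl"

with urllib.request.urlopen(url) as resp:
    data = resp.read()

# A lite interpreter .ptl file is a zip archive, so this is a cheap sanity check.
if zipfile.is_zipfile(io.BytesIO(data)):
    print("Looks like a valid .ptl (zip) archive,", len(data), "bytes")
else:
    print("Not a zip archive; the server probably returned an HTML page:")
    print(data[:200])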

@lavandaboy
Author

Hi Chris,

Thanks for your message. OK, I changed the URL to a direct Dropbox file (https://www.dl.dropboxusercontent.com/s/k9mm1b0c5xewpd7/cloth_segm_live.ptl).

Still no changes in the app. It seems to me that the app needs some time to download the model (176 MB). I also added an alert to my snack to see when the model is loaded, because I do not know how to access the logs. Unfortunately, I never get this alert even after waiting for 10 minutes :) So I think the model may be too big for the PlayTorch app.

What are your thoughts about this?

Thanks!

@raedle
Contributor

raedle commented Jul 27, 2022

Hi @lavandaboy, the Dropbox download didn't work for me. I uploaded the model as a GitHub asset:

https://github.com/raedle/test-some/releases/download/v0.0.2.0/cloth_segm_live.ptl

The model downloads, but it fails with an error on loading:

Format error
Exception raised from _load_for_mobile at /Users/distiller/project/torch/csrc/jit/mobile/import.cpp:623 (most recent call first):
frame #0: _ZN8facebook5react11JSIExecutor21defaultTimeoutInvokerERKNSt3__18functionIFvvEEENS3_IFNS2_12basic_stringIcNS2_11char_traitsIcEENS2_9allocatorIcEEEEvEEE + 1912560 (0x101c2a704 in PlayTorch)
frame #1: _ZN2at6native19requantize_from_intIN3c106qint32EEET_dxx + 3037040 (0x1018df500 in PlayTorch)


The errors from the lite interpreter runtime can be opaque at times, and loading the model into the lite interpreter runtime in Python can help surface the real error.
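
For example, a minimal check along these lines (a sketch; the file name is a placeholder) reproduces the load step outside the app:

import torch
from torch.jit.mobile import _load_for_lite_interpreter

# Load the exported .ptl with the same lite interpreter runtime used on device;
# a broken export surfaces the underlying error here instead of the opaque one.
model = _load_for_lite_interpreter("cloth_segm_live.ptl")
print(model)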

I created a Google Colab notebook, which downloads the model from GitHub assets and tries to load it into the lite interpreter runtime in Python.

Loading the model fails there too, and the error says:

RuntimeError: No CUDA GPUs are available ()


This means that the model was exported on the cuda device. Changing the device from cuda to cpu (or removing the cuda device entirely) when exporting the model might work.

You can test with the Google Colab notebook; if the model loads there in the lite interpreter runtime, it should also load in PlayTorch.

@raedle
Contributor

raedle commented Jul 28, 2022

@lavandaboy, the following script exports a TorchScript model for the lite interpreter that loads in PlayTorch and returns inference results (e.g., run this in Google Colab):

%cd /content/
!rm -rf cloth-segmentation
!git clone https://github.com/levindabhi/cloth-segmentation.git
%cd cloth-segmentation
!gdown --id 1mhF3yqd7R-Uje092eypktNl-RoZNuiCJ

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

from utils.saving_utils import load_checkpoint_mgpu
from networks import U2NET

checkpoint_path = 'cloth_segm_u2net_latest.pth'

net = U2NET(in_ch=3, out_ch=4)
net = load_checkpoint_mgpu(net, checkpoint_path)
net = net.eval()

scripted_model = torch.jit.script(net)
optimized_model = optimize_for_mobile(scripted_model)
optimized_model._save_for_lite_interpreter("cloth_segm_live.ptl")

print("model successfully exported")

The output is a tuple of 7 rank-4 tensors (e.g., [{"dtype":"float32","shape":[1,4,224,224]},{"dtype":"float32","shape":[1,4,224,224]},{"dtype":"float32","shape":[1,4,224,224]},{"dtype":"float32","shape":[1,4,224,224]},{"dtype":"float32","shape":[1,4,224,224]},{"dtype":"float32","shape":[1,4,224,224]},{"dtype":"float32","shape":[1,4,224,224]}], where 224,224 is the H,W).

EDIT: The 4 channels are upper body clothes, lower body clothes, full body clothes, and background.

U2NET: This project uses the amazing U2NET as its deep learning model. Instead of the 1-channel output U2NET produces for the typical salient object detection task, it outputs 4 channels, each representing upper body cloth, lower body cloth, full body cloth, and background. Only categorical cross-entropy loss is used for the given version of the checkpoint.

Source: https://github.com/levindabhi/cloth-segmentation

I also uploaded the exported model to GitHub assets: https://github.com/raedle/test-some/releases/download/v0.0.2.0/cloth_segm_live_cpu.ptl

Hope that helps

@lavandaboy
Author

Hi @raedle,

Thanks a lot for your help. As I understand it, I need to modify the code below to convert the output tensors into an actual clothes mask for the image captured by the camera?

def get_palette(num_cls):
    """Returns the color map for visualizing the segmentation mask.
    Args:
        num_cls: Number of classes
    Returns:
        The color map
    """
    n = num_cls
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
            palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
            palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
            i += 1
            lab >>= 3
    return palette

palette = get_palette(4)
do_palette = True  # apply the color palette to the saved masks (referenced below)

images_list = sorted(os.listdir(image_dir))
pbar = tqdm(total=len(images_list))
for image_name in images_list:
    img = Image.open(os.path.join(image_dir, image_name)).convert("RGB")
    image_tensor = transform_rgb(img)
    image_tensor = torch.unsqueeze(image_tensor, 0)

    output_tensor = net(image_tensor.to(device))
    output_tensor = F.log_softmax(output_tensor[0], dim=1)
    output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_arr = output_tensor.cpu().numpy()

    output_img = Image.fromarray(output_arr.astype("uint8"), mode="L")
    if do_palette:
        output_img.putpalette(palette)
    output_img.save(os.path.join(result_dir, image_name[:-3] + "png"))

    pbar.update(1)

pbar.close()

@raedle
Contributor

raedle commented Aug 1, 2022

@lavandaboy, yes, that's correct!

One caveat is that PlayTorch doesn't yet support all of the ops used in the post-processing for this cloth segmentation model. The good news is that you can still use the model if you create a model wrapper in Python that post-processes the model output and returns tensors that can be transformed in PlayTorch.

I prepared a model that works and created a simple demo in PlayTorch:

RPReplay_Final1659329628.MP4

Note: I haven't optimized anything and the model inference can take several seconds depending on your device. In my case, I used an iPhone 11 Pro, which takes ~11s for inference and another ~6s to convert masks to images.

Wrapped Model Export

The following shows at a high level what I did to prepare the model for PlayTorch. The full export is in this Google Colab: https://colab.research.google.com/drive/1pTLlcv2fSSQuO6ARdFACWfrttc5lZGQv?usp=sharing#scrollTo=ZwNRQ-38YhXg

Example model wrapper:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def get_tensor(self, output_tensor: torch.Tensor) -> torch.Tensor:
        output_tensor = F.log_softmax(output_tensor, dim=1)
        output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
        output_tensor = torch.squeeze(output_tensor, dim=0)
        return torch.squeeze(output_tensor, dim=0)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        return [self.get_tensor(res) for res in self.model(x)]

model = ModelWrapper(net)

TorchScript, optimize, and export the wrapped model:

from torch.utils.mobile_optimizer import optimize_for_mobile

scripted_model = torch.jit.script(model)
optimized_model = optimize_for_mobile(scripted_model)
optimized_model._save_for_lite_interpreter("cloth_segm_live_wrapped.ptl")

The exported model can be used in PlayTorch and returns a list of 7 tensors, which are image masks with the predicted clothes in white.
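
As a quick sanity check (a sketch, not part of the export above; the 224x224 input size is taken from the shapes reported earlier), the wrapped model can be loaded back with the lite interpreter in Python and run on a dummy image:

import torch
from torch.jit.mobile import _load_for_lite_interpreter

# Load the wrapped model with the lite interpreter runtime.
model = _load_for_lite_interpreter("cloth_segm_live_wrapped.ptl")

# Dummy RGB input; 224x224 matches the shapes reported above.
x = torch.rand(1, 3, 224, 224)
masks = model(x)

# Expect a list of 7 mask tensors, one per U2NET side output.
print(len(masks), masks[0].shape)  # 7 torch.Size([224, 224])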

Hope that helps!

@lavandaboy
Author

Hi @raedle,

Huge thanks for your help! Everything works on my side. I am actually amazed at how easily PyTorch models can be deployed on mobile devices.

I did not know about wrapping the model; that makes exporting models much easier. Regarding optimization, I have already contacted the author of the model about converting it with u2net_p (the small 8 MB model). I hope it will run much faster without a big drop in quality. As soon as I get an update from him, I will let you know in this thread.

Thanks once more @raedle!

@raedle
Contributor

raedle commented Dec 1, 2022

Closing this issue due to inactivity. Please reopen if this is still ongoing!

@raedle raedle closed this as completed Dec 1, 2022
@lavandaboy
Author

Hi @raedle,

I took a break from this project and have now come back to it. I am trying to get the coordinates of the area colored white in the output image. Is there a way to get these coordinates with PlayTorch?

Thanks in advance!

@raedle
Contributor

raedle commented Jan 5, 2023

@lavandaboy, what are you trying to achieve? Depending on your goal there might be other ways to approach the problem (e.g., if you want to subtract backgrounds and only keep salient objects). Are you looking for a way to get the bounding box, a convex hull, or something else?
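
For instance, a bounding box of the white region can be computed from a mask tensor with basic tensor ops (a minimal sketch, not from this thread; the mask is assumed to be one of the wrapped model's 2D outputs):

import torch

def mask_bounding_box(mask: torch.Tensor):
    """Return (x_min, y_min, x_max, y_max) of the nonzero region, or None if empty."""
    ys, xs = torch.nonzero(mask, as_tuple=True)
    if ys.numel() == 0:
        return None
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

# Example with a dummy mask; in practice pass one of the 7 masks
# returned by the wrapped cloth segmentation model.
mask = torch.zeros(224, 224, dtype=torch.int64)
mask[50:100, 30:80] = 1
print(mask_bounding_box(mask))  # (30, 50, 79, 99)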

@lavandaboy
Author

Hi @raedle,

Thanks for your message. The idea is to get polygons from the coordinates and then display new styles of clothes in place of the original ones.
