<a href="https://colab.research.google.com/github/jhlee508/Colab/blob/master/Cloth_Segm_u2net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Start from the Colab workspace root and delete any previous checkout so the clone below is fresh.
%cd /content/
!rm -rf cloth-segmentation
!git clone https://github.com/levindabhi/cloth-segmentation.git
%cd cloth-segmentation
# Working folders used by later cells: images to segment, rendered masks, model checkpoint.
!mkdir input_images
!mkdir output_images
!mkdir trained_checkpoint
%cd trained_checkpoint/
# Download the checkpoint from Google Drive (by file id) into trained_checkpoint/.
!gdown --id 1mhF3yqd7R-Uje092eypktNl-RoZNuiCJ

/content
Cloning into 'cloth-segmentation'...
remote: Enumerating objects: 62, done.[K
remote: Counting objects: 100% (62/62), done.[K
remote: Compressing objects: 100% (58/58), done.[K
remote: Total 62 (delta 5), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (62/62), done.
/content/cloth-segmentation
/content/cloth-segmentation/trained_checkpoint
Downloading...
From: https://drive.google.com/uc?id=1mhF3yqd7R-Uje092eypktNl-RoZNuiCJ
To: /content/cloth-segmentation/trained_checkpoint/cloth_segm_u2net_latest.pth
177MB [00:01, 145MB/s]


# Upload input images into the *input_images* folder.



In [None]:
from google.colab import files
# Relative path: this assumes the working directory is still trained_checkpoint/ from the setup cell.
%cd ../input_images
# Opens the Colab upload dialog; the chosen files are saved into the current directory.
# NOTE(review): `uploaded` is not read again in the visible cells.
uploaded = files.upload()

/content/cloth-segmentation/input_images


Saving KakaoTalk_20210909_155907338.jpg to KakaoTalk_20210909_155907338.jpg


In [None]:
# Return to the repository root; the following cells use paths relative to it (input_images, output_images, ...).
%cd /content/cloth-segmentation

/content/cloth-segmentation


In [None]:
# Run the infer.py script found in the current directory as a separate process.
!python infer.py

# Infer Code

In [None]:
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import cv2

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu

from networks import U2NET

# Device the network and the input tensors are moved to.
device = "cuda"

# Folder scanned for input images, and folder the rendered masks are written to.
image_dir = "input_images"
result_dir = "output_images"
# Checkpoint file passed to load_checkpoint_mgpu when the network is built.
checkpoint_path = os.path.join("trained_checkpoint", "cloth_segm_u2net_latest.pth")
# When True, the inference loop calls putpalette(palette) on each output image.
do_palette = False


def get_palette(num_cls):
    """Build a flat RGB colour table for rendering class-index masks.

    The bits of each class id are dealt out three at a time to the R, G and B
    channels, filling every channel from its most significant bit downwards.

    Args:
        num_cls: Number of classes

    Returns:
        A flat list of ``3 * num_cls`` ints laid out as [R0, G0, B0, R1, ...].
    """
    palette = []
    for class_id in range(num_cls):
        channels = [0, 0, 0]
        remaining = class_id
        bit_pos = 7
        while remaining:
            # Lowest three bits of `remaining` go to R, G, B respectively.
            for ch in range(3):
                channels[ch] |= ((remaining >> ch) & 1) << bit_pos
            bit_pos -= 1
            remaining >>= 3
        palette.extend(channels)
    return palette


# Preprocessing pipeline: convert the image to a tensor, then apply the project's Normalize_image(0.5, 0.5).
transforms_list = []
transforms_list += [transforms.ToTensor()]
transforms_list += [Normalize_image(0.5, 0.5)]
transform_rgb = transforms.Compose(transforms_list)

# Build U2NET with 3 input and 4 output channels, pass it through load_checkpoint_mgpu
# together with the checkpoint path, move it to `device` and switch it to eval mode.
net = U2NET(in_ch=3, out_ch=4)
net = load_checkpoint_mgpu(net, checkpoint_path)
net = net.to(device)
net = net.eval()

# Colour table with one RGB triple per class; only used when do_palette is True.
palette = get_palette(4)

# Entries of image_dir in sorted name order; the progress bar is advanced manually by the loop.
images_list = sorted(os.listdir(image_dir))
pbar = tqdm(total=len(images_list))

# Segment every image in `image_dir` and write a colour-coded mask to `result_dir`.
# `img` and `x` deliberately stay bound to the last processed image: later cells read them.
for image_name in images_list:
    image_path = os.path.join(image_dir, image_name)
    if not os.path.isfile(image_path):
        # Sub-directories (e.g. .ipynb_checkpoints) cannot be opened as images; skip them.
        pbar.update(1)
        continue

    img = Image.open(image_path).convert("RGB")
    image_tensor = transform_rgb(img)
    image_tensor = torch.unsqueeze(image_tensor, 0)

    # Inference only: no_grad avoids recording an autograd graph for the forward pass.
    with torch.no_grad():
        output_tensor = net(image_tensor.to(device))
        output_tensor = F.log_softmax(output_tensor[0], dim=1)
        # Index of the highest-scoring channel per pixel, i.e. the predicted class id.
        output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_arr = output_tensor.cpu().numpy()
    # Raw 2-D class-index map, kept for the cells further down.
    x = output_arr

    if do_palette:
        # putpalette() is only valid on "L"/"P" images, so attach the palette to the
        # single-channel class map and convert to RGB afterwards for JPEG output.
        output_img = Image.fromarray(output_arr.astype("uint8"), mode="L")
        output_img.putpalette(palette)
        output_img = output_img.convert("RGB")
    else:
        # Change to RGB: class 1 -> [0, 0, 255], class 2 -> [255, 0, 0];
        # other class ids keep their small grey value.
        output_arr = np.repeat(output_arr[:, :, np.newaxis], 3, axis=2)
        output_arr[np.where((output_arr==[1, 1, 1]).all(axis=2))] = [0, 0, 255]
        output_arr[np.where((output_arr==[2, 2, 2]).all(axis=2))] = [255, 0, 0]
        output_img = Image.fromarray(output_arr.astype("uint8"), mode="RGB")

    # splitext handles extensions of any length (".jpeg", ".png", none at all).
    output_name = os.path.splitext(image_name)[0] + ".jpg"
    output_img.save(os.path.join(result_dir, output_name))

    pbar.update(1)

pbar.close()

----checkpoints loaded from path: trained_checkpoint/cloth_segm_u2net_latest.pth----


  "See the documentation of nn.Upsample for details.".format(mode)
100%|██████████| 1/1 [00:00<00:00, 10.40it/s]


In [None]:
# Delete the .ipynb_checkpoints folder from input_images so the folder holds only image files.
!rm -r input_images/.ipynb_checkpoints

# Output Array

In [None]:
# `x` is the 2-D class-index map of the last image handled by the inference loop.
output_np = x

In [None]:
# Display the dimensions of the class-index map.
output_np.shape

(500, 500)

# Output Masks

In [None]:
%mkdir masks

In [None]:
# background
back_np = output_np
back_np = np.where(back_np != 0, 42, back_np)
back_np = np.where(back_np == 0, 1, back_np)
back_np = np.where(back_np == 42, 0, back_np)
back_np = np.repeat(back_np[:, :, np.newaxis], 3, axis=2)

masked_background = Image.fromarray(back_np.astype("uint8"), mode="RGB")
masked_background.save("masks/masked_background.jpg")

In [None]:
# class 1
class1_np = output_np
class1_np = np.where(class1_np != 1, 42, class1_np)
class1_np = np.where(class1_np == 1, 1, class1_np)
class1_np = np.where(class1_np == 42, 0, class1_np)
class1_np = np.repeat(class1_np[:, :, np.newaxis], 3, axis=2)

masked_class1 = Image.fromarray(class1_np.astype("uint8"), mode="RGB")
masked_class1.save("masks/masked_class1.jpg")

In [None]:
# class 2
class2_np = output_np
class2_np = np.where(class2_np != 2, 42, class2_np)
class2_np = np.where(class2_np == 2, 1, class2_np)
class2_np = np.where(class2_np == 42, 0, class2_np)
class2_np = np.repeat(class2_np[:, :, np.newaxis], 3, axis=2)

masked_class2 = Image.fromarray(class2_np.astype("uint8"), mode="RGB")
masked_class2.save("masks/masked_class2.jpg")

In [None]:
# class 3
class3_np = output_np
class3_np = np.where(class3_np != 3, 42, class3_np)
class3_np = np.where(class3_np == 3, 1, class3_np)
class3_np = np.where(class3_np == 42, 0, class3_np)
class3_np = np.repeat(class3_np[:, :, np.newaxis], 3, axis=2)

masked_class3 = Image.fromarray(class1_np.astype("uint8"), mode="RGB")
masked_class3.save("masks/masked_class3.jpg")

# Save Original Image

In [None]:
# `img` is the last input image opened by the inference loop (converted to RGB there).
origin = img
origin.save("origin.jpg")

# Get Segmentation Images

In [None]:
!mkdir imgs

In [None]:
# Background Seg Images
back_imgs = back_np * img

background_seg_imgs = Image.fromarray(back_imgs.astype("uint8"), mode="RGB")
background_seg_imgs.save("imgs/background.jpg")

In [None]:
# Class1 Seg Images
class1_imgs = class1_np * img

class1_seg_imgs = Image.fromarray(class1_imgs.astype("uint8"), mode="RGB")
class1_seg_imgs.save("imgs/class1.jpg")

In [None]:
# Class2 Seg Images
class2_imgs = class2_np * img

class2_seg_imgs = Image.fromarray(class2_imgs.astype("uint8"), mode="RGB")
class2_seg_imgs.save("imgs/class2.jpg")

In [None]:
# Class3 Seg Images
class3_imgs = class3_np * img

class3_seg_imgs = Image.fromarray(class3_imgs.astype("uint8"), mode="RGB")
class3_seg_imgs.save("imgs/class3.jpg")

### 단색에서 HSV 값 추출

In [None]:
%mkdir one_color

In [None]:
# Read a reference colour swatch and take the HSV value of its centre pixel as the target colour.
one_color = cv2.imread('check_color/purl.jpg') # load the image file in colour
height, width = one_color.shape[:2] # image height and width: shape[0] is rows (height), shape[1] is columns (width)

one_hsv = cv2.cvtColor(one_color, cv2.COLOR_BGR2HSV) # convert to the HSV colour space with cvtColor
center_h = (height // 2)
center_w = (width // 2)

# Hue, saturation and value of the centre pixel.
one_h = one_hsv[center_h][center_w][0]
one_s = one_hsv[center_h][center_w][1]
one_v = one_hsv[center_h][center_w][2]

print(one_h, one_s, one_v)

140 255 154


### 단색 입히기

In [None]:
# Recolour the class-1 region: replace its hue with the reference hue and shift
# saturation/value by the gap between the reference S/V and the region's mean S/V.
test = class1_imgs
hsv = cv2.cvtColor(test.astype("uint8"), cv2.COLOR_RGB2HSV)

(h, s, v) = cv2.split(hsv)

# Mean S and V over the non-zero pixels only (masked-out pixels are 0);
# max(..., 1) avoids dividing by zero when a channel is entirely 0.
eval_s = s.sum() // max(np.count_nonzero(s), 1)
eval_v = v.sum() // max(np.count_nonzero(v), 1)

# Signed offsets: subtracting the unsigned numpy scalars directly wraps around when
# the reference is below the mean, which pushed underflowing pixels to 255.
shift_s = int(one_s) - int(eval_s)
shift_v = int(one_v) - int(eval_v)

h[:, :] = one_h
# Shift in a wider signed type, then clamp to the valid 8-bit range on both ends.
s[:, :] = np.clip(s.astype(np.int32) + shift_s, 0, 255).astype(np.uint8)
v[:, :] = np.clip(v.astype(np.int32) + shift_v, 0, 255).astype(np.uint8)

hsv = cv2.merge((h, s, v))

rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
# Keep the recoloured pixels only inside class 1, then add the background segment.
test = rgb * class1_np
test = test + background_seg_imgs
test = Image.fromarray(test.astype("uint8"), mode="RGB")
test.save("imgs/one_color_changed.jpg")

## 체크 무늬 입히기

In [None]:
!mkdir check_color

In [None]:
# Paste the pattern image directly onto the class-1 region and add the background segment.
# NOTE(review): check_color/elastic_trans.jpg is written by the Elastic Transformation
# cell further down, so that cell has to run before this one.
check_color = cv2.imread('check_color/elastic_trans.jpg') 
check_rgb = cv2.cvtColor(check_color, cv2.COLOR_BGR2RGB)

# Element-wise product with the 0/1 class-1 selector; assumes the pattern image has
# the same height and width as the mask — TODO confirm.
check_np = class1_np * check_rgb
check_person = check_np + background_seg_imgs

check_seg_imgs = Image.fromarray(check_person.astype("uint8"), mode="RGB")
check_seg_imgs.save("imgs/test.jpg")

In [None]:
# Reload the pattern image and convert it to HSV for the per-pixel recolouring cell below.
check_color = cv2.imread('check_color/elastic_trans.jpg')
height, width = check_color.shape[:2]

check_hsv = cv2.cvtColor(check_color, cv2.COLOR_BGR2HSV)

In [None]:
# Display the dimensions of the HSV pattern image.
check_hsv.shape

(500, 500, 3)

In [None]:
# Paint the pattern onto the class-1 region in HSV space: take the hue from the
# pattern and shift saturation/value by (pattern S/V - mean S/V of the region).
test = class1_imgs
hsv = cv2.cvtColor(test.astype("uint8"), cv2.COLOR_RGB2HSV)

(h, s, v) = cv2.split(hsv)

# Mean S and V over the non-zero pixels only; max(..., 1) avoids dividing by zero.
aver_s = s.sum() // max(np.count_nonzero(s), 1)
aver_v = v.sum() // max(np.count_nonzero(v), 1)

# Despite the names these are plain means over all pixels; nothing in this cell reads them.
max_s = s.mean()
max_v = v.mean()

# Vectorised replacement for the per-pixel loop. Only the part of the pattern that
# overlaps the image is used, matching the index range the loop visited.
rows, cols = h.shape
pattern = check_hsv[:rows, :cols].astype(np.int32)

h[:, :] = pattern[:, :, 0]
# Work in signed integers and clamp to [0, 255]: on the original unsigned values the
# subtraction wrapped around and abs() had no effect, so underflowing pixels became 255.
s[:, :] = np.clip(s.astype(np.int32) + (pattern[:, :, 1] - int(aver_s)), 0, 255).astype(np.uint8)
v[:, :] = np.clip(v.astype(np.int32) + (pattern[:, :, 2] - int(aver_v)), 0, 255).astype(np.uint8)

hsv = cv2.merge((h, s, v))

rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
# Keep the recoloured pixels only inside class 1, then add the background segment.
test = rgb * class1_np
test = test + background_seg_imgs
test = Image.fromarray(test.astype("uint8"), mode="RGB")
test.save("imgs/test2.jpg")

## Elastic Transformation

In [None]:
import numpy as np
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter

def elastic_transform(image, alpha, sigma, random_state=None):
    """Elastic deformation of images as described in [Simard2003]_.

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.

    Args:
        image: 3-D array indexed as (row, column, channel). Square and
            non-square images are both supported.
        alpha: Scale factor applied to the smoothed random displacement fields.
        sigma: Standard deviation of the Gaussian filter that smooths the
            random fields.
        random_state: Optional ``np.random.RandomState``; a fresh unseeded one
            is created when omitted.

    Returns:
        The deformed image, with the same shape as ``image``.
    """
    if random_state is None:
        random_state = np.random.RandomState(None)

    shape = image.shape
    # Uniform noise in [-1, 1), smoothed and scaled. dx is drawn before dy so a
    # seeded random_state reproduces the same fields as before.
    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha

    # indexing="ij" makes the grids share the image's own (rows, cols, channels)
    # layout; the default "xy" swaps the first two axes, which only lines up with
    # dx/dy when the image is square.
    rows, cols, chans = np.meshgrid(
        np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]), indexing="ij"
    )
    # dy displaces along the row axis and dx along the column axis; channels are not displaced.
    coordinates = np.array([rows + dy, cols + dx, chans])

    # The output takes the shape of the coordinate grids, i.e. image.shape.
    return map_coordinates(image, coordinates, order=1, mode='reflect')

In [None]:
# Warp the pattern image elastically and save it as check_color/elastic_trans.jpg.
image = cv2.imread('check_color/check_clothe.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# No random_state is passed, so each run produces a different deformation.
transformed = elastic_transform(image, alpha=700, sigma=9)
transformed_image = Image.fromarray(transformed.astype("uint8"), mode="RGB")
transformed_image.save("check_color/elastic_trans.jpg")

(500, 500, 3)


In [None]:
# Display the dimensions of the pattern image.
check_color.shape

(500, 500, 3)

# Make Style

In [None]:
# Turn the class-index map into a 0/255 mask of class 1: everything except class 1
# becomes 0, then the remaining non-zero pixels become 255.
# NOTE(review): not idempotent — running this cell a second time zeroes the whole
# mask, because the 255 written below is itself != 1.
x = np.where(x != 1, 0, x)
x = np.where(x != 0, 255, x)

In [None]:
# Spot check: distinct values in row 3 of the mask.
np.unique(x[3])

array([  0, 255])

In [None]:
# Copy the 2-D mask onto three channels. `x` is rebound, so re-running this cell
# operates on the already 3-channel array.
x = np.repeat(x[:, :, np.newaxis], 3, axis=2)

In [None]:
# Display the dimensions of the 3-channel mask.
x.shape

(456, 456, 3)

In [None]:
# From here on `x` is a PIL image rather than a numpy array.
x = Image.fromarray(x.astype("uint8"), mode="RGB")
x.save("person.jpg")

# Upload Clothing Image

In [None]:
# Read the clothing image with OpenCV. NOTE(review): no colour conversion is applied
# before it is later saved through PIL as RGB — confirm the channel order is intended.
clothe = cv2.imread('clothe3.jpg')

In [None]:
# Copy into a numpy array (presumably redundant, as cv2.imread is expected to return one already — confirm).
clothe = np.array(clothe)

In [None]:
# Display the dimensions of the clothing image.
clothe.shape

(456, 456, 3)

# Add Style & Background

In [None]:
# NOTE(review): `seg` and `seg_np` are not defined in any cell visible in this notebook;
# this cell only runs if they are left over in the kernel from earlier work.
added_style = clothe + seg

back_image = Image.fromarray(seg_np.astype("uint8"), mode="RGB")
back_image.save("seg_background.jpg")

added_style = Image.fromarray(added_style.astype("uint8"), mode="RGB")
added_style.save("added_style.jpg")

In [None]:
# Mask the clothing image with `x` (the 0/255 person mask, a PIL image at this point).
# NOTE(review): if both operands are uint8, multiplying by 255 wraps around instead of
# acting as a keep/drop mask — confirm the dtypes and consider a 0/1 mask.
added_image = clothe * x

# NOTE(review): `seg_np` is not defined in any cell visible in this notebook.
added_style = added_image + seg_np

added_image = Image.fromarray(added_image.astype("uint8"), mode="RGB")
added_image.save("added_image.jpg")

# Overwrites the added_style.jpg written by the previous cell.
added_style = Image.fromarray(added_style.astype("uint8"), mode="RGB")
added_style.save("added_style.jpg")

# Blended Image

In [None]:
# 50/50 mix of the input image and the masked clothing image
# (Image.blend requires both images to have the same mode and size).
blended = Image.blend(img, added_image, alpha=0.5)
blended_path = "blended.jpg"
blended.save(blended_path)

# Infer WebCam

In [None]:
import os

from tqdm import tqdm
from PIL import Image
import numpy as np

import warnings
import cv2

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu

from networks import U2NET

# Device the network and the input tensors are moved to.
device = "cuda"

# Checkpoint file passed to load_checkpoint_mgpu when the network is built.
checkpoint_path = os.path.join("trained_checkpoint", "cloth_segm_u2net_latest.pth")
# NOTE(review): set here, but nothing in the webcam loop of this cell reads it.
do_palette = True


def get_palette(num_cls):
    """Build a flat RGB colour table for rendering class-index masks.

    The bits of each class id are dealt out three at a time to the R, G and B
    channels, filling every channel from its most significant bit downwards.

    Args:
        num_cls: Number of classes

    Returns:
        A flat list of ``3 * num_cls`` ints laid out as [R0, G0, B0, R1, ...].
    """
    palette = []
    for class_id in range(num_cls):
        channels = [0, 0, 0]
        remaining = class_id
        bit_pos = 7
        while remaining:
            # Lowest three bits of `remaining` go to R, G, B respectively.
            for ch in range(3):
                channels[ch] |= ((remaining >> ch) & 1) << bit_pos
            bit_pos -= 1
            remaining >>= 3
        palette.extend(channels)
    return palette


# Preprocessing pipeline: convert the frame to a tensor, then apply the project's Normalize_image(0.5, 0.5).
transforms_list = []
transforms_list += [transforms.ToTensor()]
transforms_list += [Normalize_image(0.5, 0.5)]
transform_rgb = transforms.Compose(transforms_list)

# Build U2NET with 3 input and 4 output channels, pass it through load_checkpoint_mgpu
# together with the checkpoint path, move it to `device` and switch it to eval mode.
net = U2NET(in_ch=3, out_ch=4)
net = load_checkpoint_mgpu(net, checkpoint_path)
net = net.to(device)
net = net.eval()

# Colour table with one RGB triple per class; not read by the webcam loop in this cell.
palette = get_palette(4)

# Segment frames from camera 0 and show the clothing pixels until 'q' is pressed
# or no frame can be read.
VideoSignal = cv2.VideoCapture(0)

try:
    while True:
        ret, frame = VideoSignal.read()
        # `frame == None` compares element-wise once a real ndarray frame arrives and
        # cannot be used in `if`; check the success flag and identity instead.
        if not ret or frame is None:
            print("fail")
            break
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        image_tensor = transform_rgb(img)
        image_tensor = torch.unsqueeze(image_tensor, 0)

        # Inference only: no_grad avoids recording an autograd graph for every frame.
        with torch.no_grad():
            output_tensor = net(image_tensor.to(device))
            output_tensor = F.log_softmax(output_tensor[0], dim=1)
            # Index of the highest-scoring channel per pixel, i.e. the predicted class id.
            output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
        output_tensor = torch.squeeze(output_tensor, dim=0)
        output_tensor = torch.squeeze(output_tensor, dim=0)
        output_arr = output_tensor.cpu().numpy()
        output_arr = np.array(output_arr, dtype=np.uint8)

        # Use a 0/1 mask of all non-background classes: multiplying the uint8 frame by
        # the raw class ids (2, 3) overflows and corrupts the colours.
        cloth_mask = (output_arr > 0).astype(np.uint8)
        result = frame * np.repeat(cloth_mask[:, :, np.newaxis], 3, axis=2)

        cv2.imshow('frame', result)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # When everything done, release the capture — also when an exception interrupts the loop.
    VideoSignal.release()

    cv2.destroyAllWindows()

----checkpoints loaded from path: trained_checkpoint/cloth_segm_u2net_latest.pth----
fail


# Download results from *output_images*

In [None]:
# Delete everything inside output_images (the folder itself is kept).
# NOTE(review): the heading above mentions downloading results, but this cell only removes them.
!rm -rf output_images/*