In [1]:
import cv2
import sys, os
import numpy as np
import torch
from matplotlib import pyplot as plt
import pandas as pd
from collections import defaultdict
from segment_anything import SamPredictor, sam_model_registry
#from sam_segment import predict_masks_with_sam
from utils import load_img_to_array, save_array_to_img

In [2]:
image_folder = 'data/images/'
label_folder = 'data/labels/'
mask_output_folder = 'data/masks/'  # 마스크를 저장할 폴더 경로

In [3]:
file_names = [os.path.splitext(f)[0] for f in os.listdir(image_folder) if f.endswith('.jpg') or f.endswith('.png')]
print(len(file_names))

253


In [4]:
def normalized_coordinate_to_absolute(norm_x, norm_y, file_name):
    
    image_width, image_height=1280,720
    if len(file_name) > 40:
        image_width, image_height= 1920, 1200
        
    abs_x = int(norm_x * image_width)
    abs_y = int(norm_y * image_height)
    return [abs_x, abs_y]

In [5]:
def yolo_to_xyxy(norm_x_center, norm_y_center, norm_width, norm_height, file_name):
    abs_center_x, abs_center_y = normalized_coordinate_to_absolute(norm_x_center, norm_y_center, file_name)
    abs_width, abs_height = normalized_coordinate_to_absolute(norm_width, norm_height, file_name)

    # Calculate the top-left and bottom-right coordinates
    x_min = abs_center_x - abs_width // 2
    y_min = abs_center_y - abs_height // 2
    x_max = abs_center_x + abs_width // 2
    y_max = abs_center_y + abs_height // 2

    return [x_min, y_min, x_max, y_max]

In [6]:
def extract_coordinates(txt_folder_path, img_folder_path):
    # 이미지 폴더에서 모든 파일명을 가져옴
    file_names = [os.path.splitext(f)[0] for f in os.listdir(img_folder_path) if f.endswith('.jpg') or f.endswith('.png')]

    # 데이터를 저장할 리스트 초기화
    data = []

    for file_name in file_names:
        txt_file = os.path.join(txt_folder_path, file_name + '.txt')

        # 해당 .txt 파일이 존재하는지 확인
        if os.path.exists(txt_file):
            with open(txt_file, 'r') as file:
                lines = file.readlines()
                for line in lines:
                    try:
                        class_id, x_center, y_center, width, height = line.strip().split()
                        data.append({
                            "file_id": file_name,
                            "x_center": float(x_center),
                            "y_center": float(y_center),
                            "width": float(width),
                            "height": float(height),
                            "label": int(class_id)
                        })
                    except ValueError:
                        print(f"Line parsing error in file {file_name}: {line}")
        else:
            print(f"No annotation for image {file_name}")

    # 데이터프레임 생성
    return pd.DataFrame(data)

In [7]:
df = extract_coordinates(txt_folder_path, img_folder_path)


In [17]:
tp = df[1010:1020]
tp

Unnamed: 0,file_id,x_center,y_center,width,height,label
1010,20231031_4b0462d1-da22-483d-92e3-f5769d89c144,0.474242,0.799283,0.089063,0.156667,14
1011,20231031_4b0462d1-da22-483d-92e3-f5769d89c144,0.486758,0.820398,0.151042,0.178333,13
1012,20231031_4b0462d1-da22-483d-92e3-f5769d89c144,0.518173,0.959274,0.125521,0.080833,23
1013,20231031_4b0462d1-da22-483d-92e3-f5769d89c144,0.972164,0.217613,0.055208,0.1525,33
1014,20231031_4b0462d1-da22-483d-92e3-f5769d89c144,0.047656,0.469692,0.095312,0.23,33
1015,2023-04-21_59.mp4#t=140,0.402481,0.870213,0.173673,0.257353,23
1016,2023-04-21_59.mp4#t=140,0.353894,0.722541,0.11785,0.216912,0
1017,2023-04-21_59.mp4#t=140,0.366644,0.348154,0.136458,0.29902,0
1018,2023-04-21_59.mp4#t=140,0.388697,0.590744,0.099242,0.232958,2
1019,2023-04-21_59.mp4#t=140,0.576844,0.76911,0.062026,0.265931,0


In [9]:
input_img = '../wim_data/train/images/2023-04-21_48.mp4#t=0.jpg'
point_labels = [1]
sam_model_type = "vit_h"
sam_ckpt = './pretrained_models/sam_vit_h_4b8939.pth'

In [10]:
len(df)

17236

In [15]:
df1 = df.iloc[[0]].copy()
df1

Unnamed: 0,file_id,x_center,y_center,width,height,label
0,2023-04-24_45.mp4#t=280,0.462414,0.574525,0.167586,0.25,23


In [23]:

df1

Unnamed: 0,file_id,x_center,y_center,width,height,label
0,2023-04-24_45.mp4#t=280,0.462414,0.574525,0.167586,0.250000,23
1,2023-04-24_45.mp4#t=280,0.620000,0.470358,0.124138,0.127451,1
2,2023-04-24_45.mp4#t=280,0.625862,0.374157,0.084828,0.121324,23
3,2023-04-24_45.mp4#t=280,0.646897,0.282246,0.053793,0.111520,25
4,2023-04-24_45.mp4#t=280,0.518276,0.311045,0.080690,0.147059,12
...,...,...,...,...,...,...
17231,2023-04-24_64.mp4#t=140,0.637330,0.193738,0.101226,0.210913,8
17232,2023-04-24_64.mp4#t=140,0.604080,0.061949,0.155527,0.117286,0
17233,2023-04-24_64.mp4#t=140,0.465657,0.217663,0.124204,0.243011,13
17234,2023-04-24_64.mp4#t=140,0.358363,0.383565,0.097756,0.130002,3


In [11]:
ob_bd_path = '../wim_data/objects/images_bounding1/'
ob_path = '../wim_data/objects/images_object1/'
mask_object = '../wim_data/objects/masks_object1/'
mask_full_path = '../wim_data/objects/masks_full1/'

if not os.path.exists(ob_bd_path):
    os.makedirs(ob_bd_path)
if not os.path.exists(ob_path):
    os.makedirs(ob_path)
if not os.path.exists(mask_object):
    os.makedirs(mask_object)
if not os.path.exists(mask_full_path):
    os.makedirs(mask_full_path)

In [13]:
#point_labels = [1]
sam_model_type = "vit_h"
sam_ckpt = './pretrained_models/sam_vit_h_4b8939.pth'
device = "cuda" if torch.cuda.is_available() else "cpu"

In [18]:
boxes_and_labels_per_image  = defaultdict(list)
#latest_coords = normalized_coordinate_to_absolute(x_center, y_center, row.file_id)


for index, row in tp.iterrows():
    x_center, y_center = row.x_center, row.y_center
    norm_width, norm_height = row.width, row.height
    
    input_box = yolo_to_xyxy(x_center, y_center, norm_width, norm_height, row.file_id)
    boxes_and_labels_per_image[row.file_id].append((input_box, row.label))

In [19]:
for file_id, boxes_and_labels  in boxes_and_labels_per_image.items():
    print(file_id)
    for box, label in boxes_and_labels:
        print(box)
        print(label)    

20231031_4b0462d1-da22-483d-92e3-f5769d89c144
[825, 865, 995, 1053]
14
[789, 877, 1079, 1091]
13
[874, 1103, 1114, 1199]
23
[1813, 170, 1919, 352]
33
[0, 425, 182, 701]
33
2023-04-21_59.mp4#t=140
[404, 534, 626, 718]
23
[377, 442, 527, 598]
0
[382, 143, 556, 357]
0
[434, 342, 560, 508]
2
[699, 458, 777, 648]
0


In [16]:
sam = sam_model_registry[sam_model_type](checkpoint=sam_ckpt).to(device=device)
mask_predictor = SamPredictor(sam)

print('process start')
# Process each image
for file_id, boxes_and_labels in boxes_and_labels_per_image.items():
    labels = []
    input_boxs =[]
    
    input_img_path = img_folder_path + file_id + '.jpg'
    img = load_img_to_array(input_img_path)  # Load the image
    print('1')
    for box, label in boxes_and_labels:
        labels.append(label)
        input_boxs.append(box)
    input_box_tensor = torch.tensor(input_boxs, device=device)
    #input_box = yolo_to_xyxy(x_center, y_center, norm_width, norm_height)
    print('2')
    mask_predictor.set_image(img)
    print('transforming...')
    transformed_boxes = mask_predictor.transform.apply_boxes_torch(input_box_tensor, img.shape[:2])
    
    print("Start SAM")
    masks, scores, _ = mask_predictor.predict_torch(
    boxes = transformed_boxes,
    multimask_output=True,
    point_coords=None,
    point_labels=None
    )
    print("Ends SAM")
    for i in range(masks.shape[0]): ## object 개수
        object_masks = masks[i]  # Masks for the current object
        object_scores = scores[i]  # Scores for the current object
        # Find the mask index with the highest score
        print(f'scores: {scores[i]}')
        best_mask_index = torch.argmax(object_scores).item()

        # Extract the best mask
        best_mask = object_masks[best_mask_index].cpu().numpy().astype(np.uint8) * 255

        # masks = masks.astype(np.uint8)* 255
        # mask_full = masks[2]
    
        extracted_object = cv2.bitwise_and(img, img, mask=best_mask)
        mask_full_cropped = cv2.bitwise_and(best_mask, best_mask)
        
        x, y, w, h = cv2.boundingRect(best_mask)
        cropped_image = extracted_object[y:y+h, x:x+w]
        cropped_mask = mask_full_cropped[y:y+h, x:x+w]
        
        common_name = f"{file_id}_{labels[i]}_{i:02}.png"
        
        cropped_object = cv2.bitwise_and(cropped_image, cropped_image, mask=cropped_mask)
        
        cropped_bd_object = cv2.cvtColor(extracted_object[y:y+h, x:x+w], cv2.COLOR_BGR2RGB)
        
        # 배경을 투명하게 만들기 위해 알파 채널 추가
        b_channel, g_channel, r_channel = cv2.split(cropped_object)
        alpha_channel = np.where(cropped_mask==255, 255, 0).astype(np.uint8)  # 마스크에 따라 알파 채널 설정
        rgba_image = cv2.merge((b_channel, g_channel, r_channel, alpha_channel))
        
        save_array_to_img(cropped_mask, mask_object+common_name)
        save_array_to_img(best_mask, mask_full_path+common_name)
        save_array_to_img(rgba_image, ob_path+common_name)
        save_array_to_img(cropped_bd_object, ob_bd_path+common_name)

Start SAM
Ends SAM
scores: tensor([0.9680, 0.9983, 0.9984], device='cuda:0')
scores: tensor([0.9218, 0.9281, 0.9206], device='cuda:0')
scores: tensor([0.8777, 0.8774, 0.9091], device='cuda:0')
scores: tensor([0.9843, 0.9928, 0.9874], device='cuda:0')


In [None]:
 x_center, y_center = row.x_center, row.y_center
    norm_width, norm_height = row.width, row.height
     
    
    input_box = yolo_to_xyxy(x_center, y_center, norm_width, norm_height)
    
    #latest_coords = normalized_coordinate_to_absolute(x_center, y_center, row.file_id)
    sam = sam_model_registry[sam_model_type](checkpoint=sam_ckpt).to(device=device)
    mask_predictor = SamPredictor(sam)
    mask_predictor.set_image(img)
    
    transformed_boxes = mask_predictor.transform.apply_boxes_torch(.boxes.xyx, img.shape[:2])
    
    masks = masks.astype(np.uint8)* 255
    mask_full = masks[2]
    
    extracted_object = cv2.bitwise_and(img, img, mask=mask_full)
    mask_full_cropped = cv2.bitwise_and(mask_full, mask_full)
    
    x, y, w, h = cv2.boundingRect(mask_full)
    cropped_image = extracted_object[y:y+h, x:x+w]
    cropped_mask = mask_full_cropped[y:y+h, x:x+w]
    
    common_name = f"{row.file_id}_{row.label}_{row.x_center}_{row.y_center}.png"
    
    cropped_object = cv2.bitwise_and(cropped_image, cropped_image, mask=cropped_mask)
    
    cropped_bd_object = cv2.cvtColor(extracted_object[y:y+h, x:x+w], cv2.COLOR_BGR2RGB)
    
    # 배경을 투명하게 만들기 위해 알파 채널 추가
    b_channel, g_channel, r_channel = cv2.split(cropped_object)
    alpha_channel = np.where(cropped_mask==255, 255, 0).astype(np.uint8)  # 마스크에 따라 알파 채널 설정
    rgba_image = cv2.merge((b_channel, g_channel, r_channel, alpha_channel))
    
    save_array_to_img(cropped_mask, mask_object+common_name)
    save_array_to_img(mask_full, mask_full_path+common_name)
    save_array_to_img(rgba_image, ob_path+common_name)
    save_array_to_img(cropped_bd_object, ob_bd_path+common_name)

In [None]:
image_bgr = cv2.imread("{}/{}".format(HOME, os.path.basename(IMAGE_PATH)), cv2.IMREAD_COLOR)

transformed_boxes = mask_predictor.transform.apply_boxes_torch(results[0].boxes.xyxy, image_bgr.shape[:2])

mask_predictor.set_image(image_bgr)

masks, scores, logits = mask_predictor.predict_torch(
    boxes = transformed_boxes,
    multimask_output=False,
    point_coords=None,
    point_labels=None
)
masks = np.array(masks.cpu())

In [None]:


masks, _, _ = predict_masks_with_sam(
        img,
        [latest_coords],
        point_labels,
        model_type=sam_model_type,
        ckpt_p=sam_ckpt,
        device=device,
    )