In [None]:
from segment_anything import SamPredictor, sam_model_registry
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import scipy.io as sio
import os 
import random 
import json
from torch.utils.data import Dataset, DataLoader

# Restrict this process to the GPU with index 1.
# NOTE(review): this runs after `import torch`; it presumably still takes effect
# because nothing above touches CUDA yet — confirm, or move it above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before being coloured.
        ax: object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: when true, use a random RGB colour with alpha 0.6;
            otherwise use a fixed blue with alpha 0.6.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    rows, cols = mask.shape[-2:]
    # Broadcasting (rows, cols, 1) against (1, 1, 4) yields one RGBA value per pixel.
    overlay = mask.reshape(rows, cols, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: array indexed as ``coords[mask][:, 0]`` / ``[:, 1]`` for x / y.
        labels: array compared against 1 and 0 to select rows of ``coords``.
        ax: object exposing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: marker size passed to ``scatter`` as ``s``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    # Positive points are drawn first, then negative ones.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)
    
def show_box(box, ax):
    """Outline ``box`` on ``ax`` with a green, unfilled rectangle.

    Args:
        box: sequence read as ``[x0, y0, x1, y1]`` (two opposite corners).
        ax: object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    left, top = box[0], box[1]
    box_width = box[2] - box[0]
    box_height = box[3] - box[1]
    outline = plt.Rectangle((left, top), box_width, box_height,
                            edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)

In [None]:
class DefectDataset(Dataset):
    """Dataset pairing ``.mat`` frame stacks with mask and bounding-box labels.

    Each line of ``data_root`` is the path of a ``.mat`` file whose ``'data'``
    entry is indexed as (rows, cols, frames). For a file named ``<basename>.*``
    the labels are read from ``<label_dir>/<basename>_label.png`` (mask) and
    ``<label_dir>/<basename>_label.json`` (a ``'shapes'`` list whose entries
    carry ``'points'``).

    Args:
        data_root: text file listing one sample path per line.
        label_dir: directory holding the ``_label.png`` / ``_label.json`` files.
        height, width: size every frame and mask is resized to.
        heating_num: index of the first frame kept from each stack; at most
            160 frames starting there are considered.
        batchsize: number of frames drawn at random (without replacement)
            from each stack per item.
        bbox_shift: maximum random outward shift, in pixels, applied to each
            side of every bounding box.
    """

    def __init__(self, data_root='defect/test_set.txt', label_dir='defect/data/labels_coco/labels',
                 height=1024, width=1024, heating_num=50, batchsize=8, bbox_shift=10):
        # Blank lines are skipped so a trailing newline in the list file does
        # not create an empty sample path that can never be loaded.
        with open(data_root, 'r') as file:
            self.data_path = [line.strip() for line in file if line.strip()]

        self.label_dir = label_dir
        self.height = height
        self.width = width
        self.heating_num = heating_num
        self.batchsize = batchsize

        self.bbox_shift = bbox_shift
        print(f"number of samples: {len(self.data_path)}")

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.data_path)

    @staticmethod
    def _scaled_bbox(points, scale_x, scale_y):
        """Return ``(x_min, y_min, x_max, y_max)`` of ``points`` scaled to the output size.

        The extremes are taken over all points, so the result does not depend
        on the order in which the corners were annotated.
        """
        xs = [point[0] for point in points]
        ys = [point[1] for point in points]
        return (int(min(xs) * scale_x), int(min(ys) * scale_y),
                int(max(xs) * scale_x), int(max(ys) * scale_y))

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(masks, frames, bboxes, basename)`` where, with
            ``B = batchsize`` and ``S`` the number of annotated shapes:
            ``masks`` is a float tensor (B, S, height, width) holding the
            binary label restricted to each shape's box, ``frames`` is a float
            tensor (B, 3, height, width) with every sampled frame repeated on
            three channels, ``bboxes`` is a float tensor (B, S, 4) of
            ``[x_min, y_min, x_max, y_max]`` and ``basename`` is the sample's
            file name up to its first dot.

        Raises:
            FileNotFoundError: if the label image cannot be read.
        """
        file_path = self.data_path[index]
        basename = file_path.split('/')[-1].split('.')[0]
        label_path = os.path.join(self.label_dir, '{}_label.png'.format(basename))

        label_img = cv2.imread(label_path, 0)
        if label_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('could not read label image: {}'.format(label_path))
        # Scale factors from annotation coordinates to the resized image.
        scale_x = self.width / label_img.shape[1]
        scale_y = self.height / label_img.shape[0]
        label_img = cv2.resize(label_img, (self.width, self.height))
        # Binarise: grey levels <= 120 become 1.0, everything brighter 0.0.
        label_img = (label_img <= 120).astype(np.float64)

        data = sio.loadmat(file_path)['data']
        t_len = data.shape[2]
        last_frame = data[:, :, -1]
        data = data[:, :, self.heating_num:min(t_len, self.heating_num + 160)]
        # Subtract the final frame of the stack from every kept frame.
        data = data - last_frame[:, :, np.newaxis]

        random_indices = np.random.choice(data.shape[2], size=self.batchsize, replace=False)
        data = data[:, :, random_indices]
        data = cv2.resize(data, (self.width, self.height))
        if data.ndim == 2:
            # cv2.resize drops the channel axis when a single frame is sampled.
            data = data[:, :, np.newaxis]
        data = np.transpose(data, (2, 0, 1))

        # Same mask for every sampled frame; frames repeated on 3 channels.
        labels = np.tile(label_img[np.newaxis, :, :], (data.shape[0], 1, 1))
        data = np.tile(data[:, np.newaxis, :, :], (1, 3, 1, 1))

        label_json = os.path.join(self.label_dir, '{}_label.json'.format(basename))
        with open(label_json, 'r') as fp:
            label_coord = json.load(fp)

        shapes = label_coord['shapes']
        num_classes = len(shapes)
        masked_image = np.zeros((labels.shape[0], num_classes, labels.shape[1], labels.shape[2]))
        bboxes = []
        for i, shape in enumerate(shapes):
            x_min, y_min, x_max, y_max = self._scaled_bbox(shape['points'], scale_x, scale_y)
            # Randomly enlarge each side by up to bbox_shift pixels, clamped to the image.
            x_min = max(0, x_min - random.randint(0, self.bbox_shift))
            x_max = min(self.width, x_max + random.randint(0, self.bbox_shift))
            y_min = max(0, y_min - random.randint(0, self.bbox_shift))
            y_max = min(self.height, y_max + random.randint(0, self.bbox_shift))
            bboxes.append([x_min, y_min, x_max, y_max])
            masked_image[:, i, y_min:y_max, x_min:x_max] = labels[:, y_min:y_max, x_min:x_max]

        # reshape keeps a (0, 4) array when no shape is annotated, so tiling still works.
        bboxes = np.array(bboxes).reshape(-1, 4)
        bboxes = np.tile(bboxes[np.newaxis, :, :], (data.shape[0], 1, 1))

        return (
            torch.tensor(masked_image).float(),
            torch.tensor(data).float(),
            torch.tensor(bboxes).float(),
            basename,
        )

In [None]:
# --- Evaluation configuration ---
width = 1024        # size every frame / mask is resized to
height = 1024
heating_num = 50    # index of the first frame kept from each stack
sample_rate = 4     # frames sampled per stack (passed as the dataset's `batchsize`)
batchsize = 1       # stacks per DataLoader batch
# Fall back to the CPU so the notebook still runs without a usable GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam_checkpoint = "weights/sam_vit_b_01ec64.pth"
model_type = "vit_b"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)
test_dataset = DefectDataset(data_root='defect/test_set.txt', height=height, width=width,
                             heating_num=heating_num, batchsize=sample_rate)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batchsize,
    shuffle=False,
    num_workers=4,
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory=(device == "cuda"),
)

In [None]:

# for step, (labels, data, bboxes, names_temp) in enumerate(test_dataloader):
#     data = torch.flatten(data, start_dim=0, end_dim=1)
#     labels = torch.flatten(labels, start_dim=0, end_dim=1)
#     bboxes = torch.flatten(bboxes, start_dim=0, end_dim=1)
#     print(data.shape, labels.shape, bboxes.shape)

#     _, axs = plt.subplots(1, 2, figsize=(25, 25))
#     idx = random.randint(0, 4)
#     img = data[idx].cpu().permute(1, 2, 0).numpy()
#     axs[0].imshow(img/255.)
#     axs[0].axis("off")
#     boxes = bboxes[idx].cpu().numpy()
#     for box in boxes:
#         show_box(box, axs[0])
#     # set title
#     axs[0].set_title(names_temp[0])

#     img = labels[idx].cpu().permute(1, 2, 0).numpy()
#     axs[1].imshow(img[:, :, 0])
#     for box in boxes:
#         show_box(box, axs[1])
#     axs[1].axis("off")
#     # set title
#     axs[1].set_title(names_temp[0]+'label')
   
#     plt.tight_layout()
#     plt.show()
#     # plt.subplots_adjust(wspace=0.01, hspace=0)
#     # plt.savefig("./defect/data_sanitycheck_0.png", bbox_inches="tight", dpi=300)
#     # plt.close()
#     # show the example
#     print('ok')
#     break

In [None]:

# for step, (labels, data, bboxes, names_temp) in enumerate(test_dataloader):
#     data = torch.flatten(data, start_dim=0, end_dim=1)
#     labels = torch.flatten(labels, start_dim=0, end_dim=1)
#     bboxes = torch.flatten(bboxes, start_dim=0, end_dim=1)
#     print(data.shape, labels.shape, bboxes.shape)
#     # boxes_np = bboxes.detach().cpu().numpy()
#     labels, data = labels.to(device), data.to(device)
#     bboxes = bboxes.to(device)
#     # data = data.permute(0, 2, 3, 1)
#     print(data.shape)
#     batched_output = []
#     for i in range(data.shape[0]):
#         print(data[i].shape)
#         predictor.set_image(data[i])
#         transformed_boxes = predictor.transform.apply_boxes_torch(bboxes[i], data[i].shape[:2])
#         masks, _, _ = predictor.predict_torch(
#             point_coords=None,
#             point_labels=None,
#             boxes=transformed_boxes,
#             multimask_output=False,
#         )
#         batched_output.append(masks)

#     break


In [None]:
# Visualization: run SAM with box prompts on one batch and overlay the predicted masks.

from segment_anything.utils.transforms import ResizeLongestSide
# Transform constructed with the SAM image encoder's `img_size`.
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

def prepare_image(image, transform, device):
    """Apply ``transform`` to an image and return it as a contiguous CHW tensor.

    Args:
        image: image passed to ``transform.apply_image``; the result is
            treated as channel-last (its axes are permuted with (2, 0, 1)).
        transform: object exposing ``apply_image`` (here a ResizeLongestSide).
        device: despite its name, usually an object exposing a ``.device``
            attribute (the SAM model); a plain device string or
            ``torch.device`` is accepted as well.

    Returns:
        The transformed image as a tensor on the target device, channel-first.
    """
    # Accept both "something that lives on a device" and the device itself.
    target_device = getattr(device, 'device', device)
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=target_device)
    print('before permute', image.shape)
    return image.permute(2, 0, 1).contiguous()


# Run SAM with box prompts on the first batch only and overlay the masks
# predicted for its first frame.
for step, (labels, data, bboxes, names_temp) in enumerate(test_dataloader):
    # Merge the DataLoader batch axis with the per-stack frame axis.
    data = torch.flatten(data, start_dim=0, end_dim=1)
    labels = torch.flatten(labels, start_dim=0, end_dim=1)
    bboxes = torch.flatten(bboxes, start_dim=0, end_dim=1)
    print(data.shape, labels.shape, bboxes.shape)
    boxes_np = bboxes.detach().cpu().numpy()
    bboxes = bboxes.to(device)
    # Channel-last uint8 copies, the layout handed to the resize transform.
    data_np = data.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    print(data.shape)

    batched_input = []
    for i in range(data.shape[0]):
        # (H, W) of the channel-last image; data[i] is channel-first, so its
        # leading two dimensions are not the image size.
        original_size = data_np[i].shape[:2]
        img_tmp = prepare_image(data_np[i], resize_transform, sam)
        print(img_tmp.shape)
        box_tmp = resize_transform.apply_boxes_torch(bboxes[i], original_size)
        print(box_tmp.shape)
        # Reuse the tensors computed above instead of transforming twice.
        batched_input.append({'image': img_tmp,
                              'boxes': box_tmp,
                              'original_size': original_size})
    batched_output = sam(batched_input, multimask_output=False)
    print('mask shape:', batched_output[0]['masks'].shape)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    img = data[0].cpu().numpy().transpose(1, 2, 0)
    ax.imshow(img/255.)

    for mask in batched_output[0]['masks']:
        show_mask(mask.cpu().numpy(), ax, random_color=False)

    for box in boxes_np[0]:
        show_box(box, ax)
    ax.axis("off")

    plt.tight_layout()
    plt.show()
    break


In [None]:
# Calculate IoU between SAM's box-prompted masks and the ground-truth labels.

from segment_anything.utils.transforms import ResizeLongestSide
# Transform constructed with the SAM image encoder's `img_size`.
# NOTE(review): this import and assignment repeat an earlier cell.
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

def prepare_image(image, transform, device):
    """Apply ``transform`` to an image and return it as a contiguous CHW tensor.

    Args:
        image: image passed to ``transform.apply_image``; the result is
            treated as channel-last (its axes are permuted with (2, 0, 1)).
        transform: object exposing ``apply_image`` (here a ResizeLongestSide).
        device: despite its name, usually an object exposing a ``.device``
            attribute (the SAM model); a plain device string or
            ``torch.device`` is accepted as well.

    Returns:
        The transformed image as a tensor on the target device, channel-first.
    """
    # Accept both "something that lives on a device" and the device itself.
    target_device = getattr(device, 'device', device)
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=target_device)
    return image.permute(2, 0, 1).contiguous()


def calculate_iou(y_hat, y):
    """Return the intersection-over-union of two masks.

    Args:
        y_hat: predicted mask, combined element-wise with ``y`` via logical and/or.
        y: reference mask of a shape compatible with ``y_hat``.

    Returns:
        ``|y_hat AND y| / |y_hat OR y|``. When both masks are empty the union
        is zero; the masks then agree exactly, so 1.0 is returned instead of
        the ``nan`` (and RuntimeWarning) a 0/0 division would produce.
    """
    intersection = np.logical_and(y_hat, y)
    union = np.logical_or(y_hat, y)
    union_area = np.sum(union)
    if union_area == 0:
        return 1.0
    return np.sum(intersection) / union_area

# Per-frame IoU, collected separately for samples whose name is listed in
# R_type and for all other samples.
IOU_plane = []
IOU_R = []
R_type = ['036g', '029g', '035g', '012g']
for step, (labels, data, bboxes, names_temp) in enumerate(test_dataloader):
    # Merge the DataLoader batch axis with the per-stack frame axis.
    data = torch.flatten(data, start_dim=0, end_dim=1)
    labels = torch.flatten(labels, start_dim=0, end_dim=1)
    bboxes = torch.flatten(bboxes, start_dim=0, end_dim=1)
    print(data.shape, labels.shape, bboxes.shape)
    bboxes = bboxes.to(device)
    # Channel-last uint8 copies, the layout handed to the resize transform.
    data_np = data.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    batched_input = []
    for i in range(data.shape[0]):
        # (H, W) of the channel-last image; data[i] is channel-first, so its
        # leading two dimensions are not the image size.
        original_size = data_np[i].shape[:2]
        img_tmp = prepare_image(data_np[i], resize_transform, sam)
        print(img_tmp.shape)
        box_tmp = resize_transform.apply_boxes_torch(bboxes[i], original_size)
        print(box_tmp.shape)
        # Reuse the tensors computed above instead of transforming twice.
        batched_input.append({'image': img_tmp,
                              'boxes': box_tmp,
                              'original_size': original_size})
    batched_output = sam(batched_input, multimask_output=False)
    print('mask shape:', batched_output[0]['masks'].shape)

    labels = labels.cpu().numpy()
    # After flattening, consecutive frames belong to the same named sample.
    frames_per_sample = max(1, data.shape[0] // len(names_temp))
    for i in range(data.shape[0]):
        # Union over all annotated shapes of *this* frame's label.
        y = labels[i].astype(bool).any(axis=0)
        # Union over all predicted masks (one per box prompt).
        pre_img = batched_output[i]['masks'].squeeze(1).cpu().numpy().astype(bool)
        y_hat = pre_img.any(axis=0)
        IOU = calculate_iou(y_hat, y)
        if names_temp[i // frames_per_sample] in R_type:
            IOU_R.append(IOU)
        else:
            IOU_plane.append(IOU)


# Guard the averages: either group may be empty for a given test set.
print('avg plane IOU', sum(IOU_plane)/len(IOU_plane) if IOU_plane else float('nan'))
print('avg R IOU', sum(IOU_R)/len(IOU_R) if IOU_R else float('nan'))


In [None]:
# Generate images: save every sampled frame of the test set as a PNG.
from segment_anything.utils.transforms import ResizeLongestSide
# Transform constructed with the SAM image encoder's `img_size`.
# NOTE(review): this import and assignment repeat earlier cells.
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

def prepare_image(image, transform, device):
    """Apply ``transform`` to an image and return it as a contiguous CHW tensor.

    Args:
        image: image passed to ``transform.apply_image``; the result is
            treated as channel-last (its axes are permuted with (2, 0, 1)).
        transform: object exposing ``apply_image`` (here a ResizeLongestSide).
        device: despite its name, usually an object exposing a ``.device``
            attribute (the SAM model); a plain device string or
            ``torch.device`` is accepted as well.

    Returns:
        The transformed image as a tensor on the target device, channel-first.
    """
    # Accept both "something that lives on a device" and the device itself.
    target_device = getattr(device, 'device', device)
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=target_device)
    print('before permute', image.shape)
    return image.permute(2, 0, 1).contiguous()


# Save every sampled frame of the test set as a PNG (mask / box overlays are
# currently disabled below).
output_dir = 'defect/output/original'
os.makedirs(output_dir, exist_ok=True)  # savefig does not create directories

for step, (labels, data, bboxes, names_temp) in enumerate(test_dataloader):
    # Merge the DataLoader batch axis with the per-stack frame axis.
    data = torch.flatten(data, start_dim=0, end_dim=1)
    labels = torch.flatten(labels, start_dim=0, end_dim=1)
    bboxes = torch.flatten(bboxes, start_dim=0, end_dim=1)
    print(data.shape, labels.shape, bboxes.shape)
    boxes_np = bboxes.detach().cpu().numpy()
    bboxes = bboxes.to(device)
    # Channel-last uint8 copies, the layout handed to the resize transform.
    data_np = data.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    print(data.shape)

    batched_input = []
    for i in range(data.shape[0]):
        # (H, W) of the channel-last image; data[i] is channel-first, so its
        # leading two dimensions are not the image size.
        original_size = data_np[i].shape[:2]
        img_tmp = prepare_image(data_np[i], resize_transform, sam)
        print(img_tmp.shape)
        box_tmp = resize_transform.apply_boxes_torch(bboxes[i], original_size)
        print(box_tmp.shape)
        # Reuse the tensors computed above instead of transforming twice.
        batched_input.append({'image': img_tmp,
                              'boxes': box_tmp,
                              'original_size': original_size})
    batched_output = sam(batched_input, multimask_output=False)
    print('mask shape:', batched_output[0]['masks'].shape)

    # After flattening, consecutive frames belong to the same named sample.
    frames_per_sample = max(1, data.shape[0] // len(names_temp))
    for i in range(data.shape[0]):
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111)
        img = data[i].cpu().numpy().transpose(1, 2, 0)
        ax.imshow(img/255.)

        # for mask in batched_output[i]['masks']:
        #     show_mask(mask.cpu().numpy(), ax, random_color=False)

        # for box in boxes_np[i]:
        #     show_box(box, ax)
        ax.axis("off")

        fig.tight_layout()
        sample_name = names_temp[i // frames_per_sample]
        frame_index = i % frames_per_sample
        fig.savefig(os.path.join(output_dir, '{}_{}.png'.format(sample_name, frame_index)))
        # One figure per frame over the whole test set: close each one so
        # they do not accumulate in memory.
        plt.close(fig)
 


In [None]:
# NOTE(review): as far as this cell is visible, the loop only flattens each
# batch, builds a uint8 channel-last copy and prints shapes — confirm whether
# the body continues beyond this point.
for step, (labels, data, bboxes, names_temp) in enumerate(test_dataloader):
    # Merge the first two axes of every tensor in the batch.
    data = torch.flatten(data, start_dim=0, end_dim=1)
    labels = torch.flatten(labels, start_dim=0, end_dim=1)
    bboxes = torch.flatten(bboxes, start_dim=0, end_dim=1)
    print(data.shape, labels.shape, bboxes.shape)
    # Channel-last NumPy copy cast to uint8.
    data_np = data.permute(0, 2, 3, 1)
    data_np = data_np.cpu().numpy()
    data_np = data_np.astype(np.uint8)
    print(data.shape)