The experiment reproduction question for 3D semantic segmentation on the 3D-OVS #20

fine1998 opened this issue Feb 25, 2024 · 3 comments

@fine1998

Thank you so much for sharing your great work.

When I used your code and pre-trained weights to reproduce the segmentation results for the "sofa" scene, my mIoU came out at just over 70. My current method is to compute mIoU separately for each of the three relevancy maps (one per feature level) generated by each positive query, and then average them. I am not sure whether this calculation is correct.
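
Concretely, the averaging I tried looks roughly like this (a minimal sketch with placeholder names, not my actual script):

```python
import torch
from torchmetrics import JaccardIndex

jaccard = JaccardIndex(task="binary")

def averaged_miou(level_masks, gt_mask):
    # level_masks: three (H, W) binary tensors, one per feature level;
    # gt_mask: (H, W) ground-truth mask with values in {0, 1}.
    ious = [jaccard(m, gt_mask.int()) for m in level_masks]
    return torch.stack(ious).mean()  # average the per-map IoUs
```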

In addition, the segmentation map generated for 'grey sofa' in the 'sofa' scene is completely black. Do you have any suggestions to help me improve my results?

@minghanqin
Owner

Hi,

Regarding the calculation of mIoU for the "sofa" scene: we compute the average relevancy score within the mask region of each relevancy map and select the map with the highest average response as the final predicted binary map.
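
In sketch form, the selection works like this (illustrative names and threshold, not the released code):

```python
import torch

def select_best_map(relevancy_maps, threshold=0.6):
    # relevancy_maps: one (H, W) relevancy tensor per feature level.
    best_mask, best_score = None, float("-inf")
    for rel in relevancy_maps:
        mask = (rel > threshold).float()         # binarize this level's map
        if mask.sum() == 0:                      # no response at this level
            continue
        score = (rel * mask).sum() / mask.sum()  # mean relevancy inside mask
        if score > best_score:
            best_score, best_mask = score, mask
    return best_mask  # binary map of the best-responding level
```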

Regarding the black segmentation maps you mentioned, could you provide more information? I'm currently unclear about the specific problem you're encountering.

For more details on the implementation aspects of our work, please refer to the supplementary materials of our paper.

@fine1998
Author

```python
import argparse
import glob
import os

import torch
import torchvision.transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from autoencoder.dataset import seg_dataset
from autoencoder.model import Autoencoder
from preprocess import OpenCLIPNetwork
from dataclasses import dataclass, field
from typing import Tuple, Type
from PIL import Image
from torchmetrics import JaccardIndex

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_path', type=str,
                        default='/home/amax/LangSplat/pretrained_model/langsplat_ckpt/sofa')
    parser.add_argument('--dataset_name', type=str, default='sofa')
    parser.add_argument('--gt_path', type=str, default='/data/langsplat/sofa/segmentations/')
    parser.add_argument('--encoder_dims', nargs='+', type=int,
                        default=[256, 128, 64, 32, 3])
    parser.add_argument('--decoder_dims', nargs='+', type=int,
                        default=[16, 32, 64, 128, 256, 256, 512])
    args = parser.parse_args()

    dataset_name = args.dataset_name
    encoder_hidden_dims = args.encoder_dims
    decoder_hidden_dims = args.decoder_dims
    ckpt_path = "/data/langsplat/sofa/sofa/autoencoder/ckpt/best_ckpt.pth"

    data_dir = f'{args.output_path}/{args.dataset_name}'
    output_dir = f"{args.output_path}/seg"

    test_views = glob.glob(f"{args.gt_path}/*")
    print("This is ckpt_path: {}".format(ckpt_path))
    test_views_str = sorted([item.split('/')[-1] for item in test_views if os.path.isdir(item)])

    checkpoint = torch.load(ckpt_path)
    train_dataset = seg_dataset(data_dir, test_views)

    test_loader = DataLoader(
        dataset=train_dataset,
        batch_size=256 * 6075,
        shuffle=False,
        num_workers=16,
        drop_last=False,
    )
    jaccard = JaccardIndex(task="binary")
    model = Autoencoder(encoder_hidden_dims, decoder_hidden_dims).to("cuda:0")
    model.load_state_dict(checkpoint)
    model.eval()

    @dataclass
    class OpenCLIPNetworkConfig:
        _target: Type = field(default_factory=lambda: OpenCLIPNetwork)
        clip_model_type: str = "ViT-B-16"
        clip_model_pretrained: str = "laion2b_s34b_b88k"
        clip_n_dims: int = 512
        negatives: Tuple[str, ...] = ("object", "things", "stuff", "texture")
        positives: Tuple[str, ...] = ('Pikachu',
                                      'a stack of UNO cards',
                                      'a red Nintendo Switch joy-con controller',
                                      'Gundam',
                                      'Xbox wireless controller',
                                      'grey sofa')

    CLIP_model = OpenCLIPNetwork(OpenCLIPNetworkConfig)
    num_positives = len(OpenCLIPNetworkConfig.positives)
    num_views = len(test_views_str)

    # Pass 1: decode the compressed features, compute one relevancy map per
    # (view, positive) pair, and binarize it at a fixed threshold.
    relevancy_masks = []
    relevancys = []
    gt_masks = []
    for idx, feature in tqdm(enumerate(test_loader)):
        data = feature.to("cuda:0")
        with torch.no_grad():
            features = model.decode(data)
        for positive_id in range(num_positives):
            h = train_dataset.data_dic[idx % num_views][0]
            w = train_dataset.data_dic[idx % num_views][1]
            relevancy = CLIP_model.get_relevancy(features, positive_id).view(h, w, 2).permute(2, 0, 1)

            gt_mask_path = f'{args.gt_path}/{test_views_str[idx % num_views]}/{OpenCLIPNetworkConfig.positives[positive_id]}.png'
            gt_mask = torchvision.transforms.ToTensor()(Image.open(gt_mask_path)).cuda()
            gt_mask = 0.2989 * gt_mask[0] + 0.5870 * gt_mask[1] + 0.1140 * gt_mask[2]  # RGB -> grayscale
            gt_mask[gt_mask > 0.4] = 1
            gt_masks.append(gt_mask.cpu())

            relevancy_mask = torch.zeros_like(relevancy[0])
            relevancy_mask[relevancy[0] > 0.6] = 1  # all-zero if relevancy never exceeds 0.6
            relevancy_masks.append(relevancy_mask.cpu())
            relevancys.append(relevancy[0].cpu())

            relevancy_img = torchvision.transforms.ToPILImage()(relevancy_mask)
            os.makedirs(f'render/{idx // 5}/sofa/{test_views_str[idx % num_views]}', exist_ok=True)
            relevancy_img.save(f'render/{idx // 5}/sofa/{test_views_str[idx % num_views]}/{OpenCLIPNetworkConfig.positives[positive_id]}.png')

    # Pass 2: the lists above are ordered by (feature level, view, positive),
    # so the entry for view i, positive p at level l sits at index
    # i * num_positives + p + num_positives * num_views * l.
    num_imgs = num_views
    MIoU = []
    ACC = []
    for i in tqdm(range(num_imgs)):
        for positive_id in range(num_positives):
            base = i * num_positives + positive_id
            offset = num_positives * num_imgs
            relevancy_1 = relevancys[base]
            relevancy_2 = relevancys[base + offset]
            relevancy_3 = relevancys[base + offset * 2]

            relevancy_mask_1 = relevancy_masks[base]
            relevancy_mask_2 = relevancy_masks[base + offset]
            relevancy_mask_3 = relevancy_masks[base + offset * 2]

            # Mean relevancy inside each level's own mask; keep the level with
            # the highest mean response.
            score_1 = torch.sum(relevancy_1 * relevancy_mask_1) / torch.sum(relevancy_mask_1)
            score_2 = torch.sum(relevancy_2 * relevancy_mask_2) / torch.sum(relevancy_mask_2)
            score_3 = torch.sum(relevancy_3 * relevancy_mask_3) / torch.sum(relevancy_mask_3)

            final_mask = [relevancy_mask_1, relevancy_mask_2, relevancy_mask_3][
                torch.argmax(torch.stack([score_1, score_2, score_3]))]
            # Resize to the GT resolution (assumes all GT masks share one size).
            relevancy_mask = torch.nn.functional.interpolate(
                final_mask.unsqueeze(0).unsqueeze(0), gt_masks[i].shape, mode='nearest')
            relevancy_img = torchvision.transforms.ToPILImage()(relevancy_mask.squeeze(0))
            os.makedirs(f'render/sofa/{test_views_str[i]}', exist_ok=True)
            relevancy_img.save(f'render/sofa/{test_views_str[i]}/{OpenCLIPNetworkConfig.positives[positive_id]}.png')

            gt = gt_masks[base]
            gt_img = torchvision.transforms.ToPILImage()(gt.unsqueeze(0))
            os.makedirs(f'render/gt/{test_views_str[i]}', exist_ok=True)
            gt_img.save(f'render/gt/{test_views_str[i]}/{OpenCLIPNetworkConfig.positives[positive_id]}.png')

            miou = jaccard(relevancy_mask.squeeze(0), gt.unsqueeze(0))
            acc = torch.sum(relevancy_mask.squeeze(0) == gt.unsqueeze(0)) / (gt.shape[0] * gt.shape[1])
            MIoU.append(miou)
            ACC.append(acc)
            print(miou)
            print(acc)

    print(sum(MIoU) / len(MIoU))
    print(sum(ACC) / len(ACC))
```

This is my testing code; I believe the calculation follows the method in your paper. Still, I get an all-black segmentation for 'grey sofa'.
[Attached image: "grey sofa" segmentation result]
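
For reference, an all-black PNG from the script above just means the binarized mask is all zeros, i.e. the relevancy for that query never crosses the 0.6 threshold anywhere. A quick check (a hypothetical helper, not part of the script):

```python
import torch

def mask_is_all_black(relevancy_map: torch.Tensor, threshold: float = 0.6) -> bool:
    # True when no pixel clears the threshold, so the saved mask PNG is black.
    return bool((relevancy_map > threshold).sum() == 0)
```

If this returns True for every view of 'grey sofa', the fixed threshold, rather than the rendering itself, is what produces the black maps.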

@minghanqin
Owner

Hi,
We have open-sourced the evaluation part of the code!
