<a href="https://colab.research.google.com/github/kgh1234/3DGS/blob/main/dust3r_gs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# dust3r 설치
!git clone --recursive https://github.com/naver/dust3r.git
%cd dust3r
!git submodule update --init --recursive

# 경로 설정
import sys
sys.path.append("/content/dust3r")


Cloning into 'dust3r'...
remote: Enumerating objects: 550, done.[K
remote: Counting objects: 100% (405/405), done.[K
remote: Compressing objects: 100% (208/208), done.[K
remote: Total 550 (delta 293), reused 197 (delta 197), pack-reused 145 (from 1)[K
Receiving objects: 100% (550/550), 732.88 KiB | 14.09 MiB/s, done.
Resolving deltas: 100% (319/319), done.
Submodule 'croco' (https://github.com/naver/croco) registered for path 'croco'
Cloning into '/content/dust3r/croco'...
remote: Enumerating objects: 124, done.        
remote: Counting objects: 100% (44/44), done.        
remote: Compressing objects: 100% (22/22), done.        
remote: Total 124 (delta 28), reused 22 (delta 22), pack-reused 80 (from 1)        
Receiving objects: 100% (124/124), 384.39 KiB | 3.53 MiB/s, done.
Resolving deltas: 100% (50/50), done.
Submodule path 'croco': checked out '743ee71a2a9bf57cea6832a9064a70a0597fcfcb'
/content/dust3r


In [2]:
# 수정된 is_symmetrized 적용 (한 줄로 덮어쓰기!)
!echo -e "def is_symmetrized(gt1, gt2):\n    return False" > /content/dust3r/dust3r/utils/misc.py


In [9]:
# Reload dust3r.utils.misc so this kernel uses the current on-disk version of the file.
# NOTE(review): modules that already did `from dust3r.utils.misc import ...`
# (presumably dust3r.model) keep their old bindings; reload() does not update them.
import dust3r.utils.misc
import importlib

importlib.reload(dust3r.utils.misc)


<module 'dust3r.utils.misc' from '/content/dust3r/dust3r/utils/misc.py'>

In [30]:
# GaussianOnly 모델 정의
import torch
import torch.nn as nn
from dust3r.model import AsymmetricCroCo3DStereo

class DUSt3R_GaussianOnly(nn.Module):
    """Frozen DUSt3R backbone plus small linear heads that emit Gaussian parameters.

    The backbone's ``pts3d`` output is used directly as the Gaussian centres; linear
    heads applied to the same points predict scale, opacity and colour.

    Args:
        model_name: checkpoint id passed to ``AsymmetricCroCo3DStereo.from_pretrained``.
    """

    def __init__(self, model_name="naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"):
        super().__init__()
        self.backbone = AsymmetricCroCo3DStereo.from_pretrained(model_name)
        # The backbone always runs under no_grad, so freeze it explicitly; an
        # optimizer built from self.parameters() then only updates the heads.
        self.backbone.requires_grad_(False)
        # Input size is 3 because nn.Linear acts on the last axis of pts3d (xyz).
        # NOTE(review): head_mu is not used in forward (mu = pts3d); it is kept so
        # existing state_dicts still load.
        self.head_mu = nn.Linear(3, 3)
        self.head_sigma = nn.Linear(3, 3)
        self.head_alpha = nn.Linear(3, 1)
        self.head_rgb = nn.Linear(3, 3)

    def forward(self, image1, image2):
        """Predict Gaussian parameters for an image pair.

        Args:
            image1, image2: image batches of shape (B, C, H, W), presumably
                preprocessed the way DUSt3R expects.

        Returns:
            dict with ``mu`` (the backbone's ``pts3d``, assumed (B, H, W, 3) —
            TODO confirm), ``sigma`` (exp of a head output, > 0), and ``alpha`` /
            ``rgb`` (sigmoid of head outputs, in (0, 1)).
        """
        # Keep the frozen backbone in eval mode even if .train() is called on the wrapper.
        self.backbone.eval()
        view1 = self._make_view(image1, "view1")
        view2 = self._make_view(image2, "view2")
        with torch.no_grad():
            out_dict, _ = self.backbone(view1, view2)
        feats = out_dict['pts3d']
        # The heads run outside no_grad so they receive gradients when trained.
        mu = feats  # pts3d already holds 3D coordinates
        sigma = torch.exp(self.head_sigma(feats))  # strictly positive scales
        alpha = torch.sigmoid(self.head_alpha(feats))
        rgb = torch.sigmoid(self.head_rgb(feats))
        return {
            'mu': mu,
            'sigma': sigma,
            'alpha': alpha,
            'rgb': rgb
        }

    @staticmethod
    def _make_view(image, tag):
        """Build a DUSt3R view dict for a (B, C, H, W) image batch.

        ``instance`` holds one distinct id per sample. DUSt3R's stock
        ``is_symmetrized`` presumably reads it (verify against upstream), so the
        ids keep it from raising a KeyError and from mistaking the pair for a
        symmetrized batch.
        """
        batch = image.shape[0]
        return {
            "img": image,
            "mask": torch.ones_like(image[:, :1, :, :]),
            "instance": [f"{tag}_{i}" for i in range(batch)],
        }


In [31]:
# 이미지 로딩 및 시간 측정
from PIL import Image
import torchvision.transforms as T
import time

def load_image(path, size=(224, 224), device="cuda"):
    """Load an RGB image as a (1, 3, H, W) tensor normalised to [-1, 1].

    DUSt3R's own image loader normalises with mean = std = 0.5 (pixel range
    [-1, 1]); raw ``ToTensor`` output in [0, 1] is off-distribution for the
    pretrained backbone.

    Args:
        path: image file path.
        size: (height, width) passed to ``T.Resize``. Presumably both should be
            multiples of the ViT patch size (16) — TODO confirm for other sizes.
        device: device the returned tensor is moved to.

    Returns:
        Float tensor of shape (1, 3, *size) on ``device``.
    """
    # Context manager closes the underlying file; convert() returns an independent copy.
    with Image.open(path) as pil_image:
        image = pil_image.convert("RGB")
    transform = T.Compose([
        T.Resize(size),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return transform(image).unsqueeze(0).to(device)

@torch.no_grad()
def measure_inference_time_gs(model, img1, img2, repeat=10, warmup=3):
    """Return the mean wall-clock time in seconds of one ``model(img1, img2)`` call.

    Args:
        model: module to benchmark; it is switched to eval mode.
        img1, img2: inputs forwarded unchanged to ``model``.
        repeat: number of timed forward passes; must be >= 1.
        warmup: number of untimed passes run before timing starts.

    Returns:
        Mean seconds per call over the ``repeat`` timed calls.

    Raises:
        ValueError: if ``repeat`` < 1.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    def _sync():
        # CUDA launches are asynchronous: wait for queued kernels so timings are real.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    model.eval()
    for _ in range(warmup):
        model(img1, img2)
    # Without this, leftover warm-up work would be billed to the first timed call.
    _sync()

    times = []
    for _ in range(repeat):
        start = time.perf_counter()
        model(img1, img2)
        _sync()
        times.append(time.perf_counter() - start)
    return sum(times) / repeat


In [32]:
# Inspect the raw DUSt3R output keys for one image pair.
# This cell sits above the one that builds `gs_model`, so create the inputs here
# when they are missing; that keeps "Restart & Run All" working.
if "img1" not in globals() or "img2" not in globals():
    img1 = load_image("/content/000002.jpg")
    img2 = load_image("/content/000003.jpg")
if "gs_model" not in globals():
    gs_model = DUSt3R_GaussianOnly().cuda().eval()

# Inspection only: don't build an autograd graph through the ViT-Large backbone.
# Distinct 'instance' ids let DUSt3R's stock is_symmetrized() run without a KeyError
# (presumably — verify against upstream).
with torch.no_grad():
    out_dict, _ = gs_model.backbone({
        "img": img1,
        "mask": torch.ones_like(img1[:, :1, :, :]),
        "instance": [f"view1_{i}" for i in range(img1.shape[0])],
    }, {
        "img": img2,
        "mask": torch.ones_like(img2[:, :1, :, :]),
        "instance": [f"view2_{i}" for i in range(img2.shape[0])],
    })

out_dict.keys()


dict_keys(['pts3d', 'conf'])


In [33]:
# Load the image pair, build the Gaussian model, and report its mean forward-pass time.
img1, img2 = (load_image(p) for p in ("/content/000002.jpg", "/content/000003.jpg"))

gs_model = DUSt3R_GaussianOnly().cuda()
gs_model.eval()

gs_time = measure_inference_time_gs(gs_model, img1, img2)
print(f"✅ GaussianOnly 평균 추론 시간: {gs_time:.4f}초")


✅ GaussianOnly 평균 추론 시간: 0.2227초
