# DSSIM
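Validating the DSSIM loss on a gradient-descent task: starting from random noise, optimize an image toward a target view using DSSIM as the objective.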

In [None]:
import cv2
import os
import random
import matplotlib.pyplot as plt
from pathlib import Path
import torch

from gaussian_splatting.model import View
from gaussian_splatting.colmap import parse_cameras, parse_images, parse_points3d, clean_text
from gaussian_splatting.model.loss import DSSIM, L1

# Seed torch as well as Python's RNG: the noise image in the DSSIM check below is
# drawn with torch.rand_like, which random.seed alone does not affect.
# (torch.manual_seed comes first so the cell's last expression returns None
# instead of echoing a Generator object in the notebook.)
torch.manual_seed(42)
random.seed(42)

In [None]:
dataset: str = "cat"  # scene folder name, resolved under ../data when loading

In [None]:
# Load the scene's images and parse the COLMAP text output
# (cameras.txt, points3D.txt, images.txt) into `cameras`, `points3d` and `images`.
import warnings

base_path = Path("../data") / dataset
image_dir = base_path / "images"

raw_images = {}
for image_name in sorted(os.listdir(image_dir)):
    bgr = cv2.imread(str(image_dir / image_name))
    if bgr is None:
        # cv2.imread returns None instead of raising for entries it cannot decode
        # (stray non-image files, sub-directories); skip them explicitly.
        warnings.warn(f"skipping unreadable image file: {image_dir / image_name}")
        continue
    # OpenCV returns channels in BGR order: reverse to RGB and scale to [0, 1].
    raw_images[image_name] = bgr[:, :, ::-1] / 255

with open(base_path / "cameras.txt", "r", encoding="utf-8") as f:
    cameras = parse_cameras(clean_text(f.readlines()))

with open(base_path / "points3D.txt", "r", encoding="utf-8") as f:
    points3d = parse_points3d(clean_text(f.readlines()))

with open(base_path / "images.txt", "r", encoding="utf-8") as f:
    # parse_images combines the per-image COLMAP records with the loaded pixels.
    images = parse_images(clean_text(f.readlines()), cameras, points3d, raw_images)
    

In [None]:
# Build one View per parsed image record, in the order `images` yields them.
views = list(map(View.from_image, images.values()))

In [None]:
# Copy each view's pixels into a float32 CUDA tensor and move the last axis to
# the front. view.image is presumably (H, W, C), making the result (C, H, W);
# confirm against View.from_image.
images = [
    torch.tensor(v.image, dtype=torch.float32, device="cuda").permute(2, 0, 1)
    for v in views
]


In [None]:
# Sanity-check the DSSIM loss: starting from uniform noise, gradient descent on
# the loss should pull `approx` toward the target image if the loss and its
# gradients behave well.


def show_chw(tensor):
    """Display a channel-first (C, H, W) image tensor with matplotlib."""
    # detach() is required because .numpy() refuses tensors that require grad;
    # the tensors here live on the GPU, so .cpu() already copies and the old
    # extra clone() was redundant.
    plt.imshow(tensor.detach().cpu().numpy().transpose(1, 2, 0))
    plt.show()


_lambda = 1  # mix weight: 1 -> pure DSSIM, 0 -> pure L1
learning_rate = 0.01
n_epochs = 100
log_every = 10

image = images[0].clone()
approx = torch.rand_like(image, dtype=torch.float, device="cuda", requires_grad=True)

show_chw(image)
show_chw(approx)


def criterion(img1, img2):
    """Blend of L1 and DSSIM losses, weighted by the global `_lambda`."""
    return (1 - _lambda) * L1(img1, img2) + _lambda * DSSIM(img1, img2)


optimizer = torch.optim.Adam([approx], learning_rate)
for epoch in range(n_epochs):
    optimizer.zero_grad()
    loss = criterion(image, approx)  # loss before this epoch's update
    loss.backward()
    optimizer.step()
    # Log periodically and on the last epoch; .item() prints a plain number
    # instead of the tensor repr (device, grad_fn).
    if epoch % log_every == 0 or epoch == n_epochs - 1:
        print(f"epoch {epoch}, loss: {loss.item()}")

show_chw(approx)
