Part 2: Fit a Neural Radiance Field from Multi-view Images

In [1]:
from PIL import Image
import mediapy as media
from pprint import pprint
from tqdm import tqdm

import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel

# For downloading web images
import requests
from io import BytesIO
import numpy as np
import matplotlib.pyplot as plt
import os
import torch.nn as nn
import tqdm
 

device = 'cuda'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def data_load(data):
    data = np.load(f"lego_200x200.npz")

    # Training images: [100, 200, 200, 3]
    images_train = data["images_train"] / 255.0

    # Cameras for the training images 
    # (camera-to-world transformation matrix): [100, 4, 4]
    c2ws_train = data["c2ws_train"]

    # Validation images: 
    images_val = data["images_val"] / 255.0

    # Cameras for the validation images: [10, 4, 4]
    # (camera-to-world transformation matrix): [10, 200, 200, 3]
    c2ws_val = data["c2ws_val"]

    # Test cameras for novel-view video rendering: 
    # (camera-to-world transformation matrix): [60, 4, 4]
    c2ws_test = data["c2ws_test"]

    # Camera focal length
    focal = data["focal"]  # float
    return focal, images_train, c2ws_train, images_val, c2ws_val, c2ws_test

Part 2.1: Create Rays from Cameras

Part 2.2: Sampling

Part 2.3: Putting the Dataloading All Together

In [3]:
def transform(c2w, x_c):
    if x_c.dim() == 2:
        x_c = x_c.unsqueeze(-1)
    c2w=torch.tensor(c2w)
    xc=torch.cat([x_c, torch.ones_like(x_c[:, :1, :])], dim=1)

    c2w=c2w.to(device)
    xc=xc.to(device)
    x_w=torch.bmm(c2w.float(),xc.float())
    
    return x_w.squeeze(-1)

def transform_check(c2w,x_c):
    c2w_inv = torch.inverse(c2w)
    return x_c == transform(c2w_inv, transform(c2w, x_c))

def pixel_to_camera(K, uv, s):
    ox=K[0,2]
    oy=K[1,2]
    
    xc = s * (uv[:,0] - ox)/K[0,0]
    yc = s * (uv[:,1] - oy) /K[1,1]

    return torch.stack([xc,yc, torch.tensor(s*torch.ones_like(xc))],dim=1)

def pixel_to_ray_batch(K, c2w, uv):

    c2w=torch.tensor(c2w)
    w2c=torch.inverse(c2w)
    r=w2c[:,:3,:3]
    t=w2c[:,:3,3]

    
    ray_o= torch.bmm(-torch.inverse(torch.tensor(r)), t.unsqueeze(-1)) 
    ray_o=ray_o.squeeze(-1)

    K=torch.tensor(K)
    uv=torch.tensor(uv)

    xw=pixel_to_camera(K,uv,1)
    xw=transform(c2w,xw).squeeze(-1)
    xw=xw[:,:3]

    xw=xw.to(device)
    ray_o=ray_o.to(device)
    temp=xw-ray_o
    ray_d= (temp)/torch.norm(temp,dim=1,keepdim=True)
    return ray_o, ray_d

def pixel_to_ray(K, c2w, uv):
    w2c=np.linalg.inv(c2w)
    r=w2c[:2,:2]
    r1=np.linalg.inv(r)
    t=w2c[0,3]
    r1=torch.tensor(r1).to(device)
    
    
    ray_o=torch.bmm(-r1,t.unsqueeze(-1)).squeeze(-1)
    xw=pixel_to_camera(K,uv,1)
    xw=transform(c2w,xw).squeeze(-1)
    # print(xw.shape)


    ray_d= (xw-ray_o)/torch.norm(xw-ray_o,dim=1)
    return ray_o, ray_d


def Ray_Sample_im(images, K, c2ws, m_im,n_samples, offset=.5):
    # irst sample M images, and then sample N // M rays from every image
    m,x,y,z=images.shape
    rand_ims = np.random.randint(0, m, m_im)
    rays_os = []
    rays_ds = []
    for i in range(m_im):
        im = images[rand_ims[i]]
        c2w = c2ws[rand_ims[i]]
        H,W = im.shape[:2]
        uv = torch.rand(n_samples, 2)
        uv[:,0] = (uv[:,0]+x+offset) * W
        uv[:,1] = (uv[:,1]+y+offset) * H
        uv = uv.long()
        rays_o, rays_d = pixel_to_ray(K, c2w, uv)
        rays_o = rays_o.to(device)
        rays_d = rays_d.to(device)
        rays_os.append(rays_o)
        rays_ds.append(rays_d)
    return torch.cat(rays_os, dim=0), torch.cat(rays_ds, dim=0)
        
def sample_along_rays(rayo,rayd,near=2, far=6, n_sampels=64, perturb=True):
    # t = t + (np.random.rand(t.shape) * t_width)
    t = np.linspace(near, far, n_sampels)
    t_width= (far-near)/n_sampels
    # print(t_width)
    # print(t.shape)
    if perturb:
        ts=t.shape
        ts=int(ts[0])
        # print(ts)
        t = t + (np.random.rand(ts) * t_width)
    t = torch.tensor(t).float().to(device)
    t=t.expand(rayo.shape[0], n_sampels).unsqueeze(-1)
    rayo=torch.tensor(rayo).to(device)
    rayd=torch.tensor(rayd).to(device)
    return rayo.unsqueeze(1)+ rayd.unsqueeze(1) *t


def get_K(im, focal):
    h,w,z=im.shape
    return torch.tensor(np.array([[focal, 0, w/2], [0, focal, h/2], [0, 0, 1]]))




In [4]:
from torch.utils.data import Dataset
class RaysData(Dataset):
    def __init__(self, images,K, c2ws, focal, n_samples=64, near=2, far=6, offset=0.5):
        self.images = images
        self.c2ws = c2ws
        self.focal = focal
        self.n_samples = n_samples
        self.near = near
        self.offset = offset
        self.far = far
        self.K = get_K(images[0], focal)
    def uvs(self):
        H,W = self.images.shape[1],self.images.shape[2]
        # print(H.type,W.type)  


        uv = torch.stack(torch.meshgrid(torch.arange(self.images.shape[0]), torch.arange(H), torch.arange(W)), dim=-1).to(device).float()
        uv[..., 0] += self.offset
        uv[..., 1] += self.offset
        
        uv = uv.reshape(-1, 2)


        # uv = torch.stack(torch.meshgrid(torch.arange(self.images.shape[0]), torch.arange(H), torch.arange(W)), dim=-1).to(device).float()
        # uv[:,1] = (uv[:,1] + self.offset) 
        # uv[:,2] = (uv[:,2] + self.offset)
        # # uv = uv.long()
        # print(uv.shape)
        # print(uv.type)

        return uv
    def pixels(self):

        H, W = self.images.shape[1], self.images.shape[2]
        all_uv = []
        for i in range(self.images.shape[0]): 
            uv = torch.stack(torch.meshgrid(torch.arange(H), torch.arange(W)), dim=-1).to(device).float()
            uv[:,: 0] += self.offset
            uv[:,:, 1] += self.offset 
            uv = uv.reshape(-1, 2)
            all_uv.append(uv)

            uv_int = uv.long()
            # im=self.images.numpy()
            # uv_int=uv_int.cpu().numpy()
            # image_pixels = self.images[i, uv_int[:, 1], uv_int[:, 0]]  
            all_uv.append(uv_int)
        # all_uv = torch.tensor(all_uv).to(device)
        # all_uv =all_uv.cpu()
        all_uv = torch.cat(all_uv, dim=0)
        
        im_pixels = self.images[:, all_uv[:, 1], all_uv[:, 0]]
        all_uv =np.array(all_uv)
        # print(all_uv.shape)
        return im_pixels
        # H, W = self.images.shape[1], self.images.shape[2]
        # all_uv = []

        # for i in range(self.images.shape[0]):
        #     uv = torch.stack(torch.meshgrid(torch.arange(H), torch.arange(W)), dim=-1).to(device).float()
        #     uv[..., 0] += self.offset
        #     uv[..., 1] += self.offset
        #     uv = uv.reshape(-1, 2)
        #     all_uv.append(uv)
        
        # return torch.cat(all_uv, dim=0) 
        # print(self.images.reshape(-1,3).shape)
        # pix=self.images.reshape(-1,3)
        # print(pix.reshape(-1).shape)
        # return self.images.reshape(-1,3)
    def rays(self):
        uv = self.uvs()
        ray_o, ray_d = pixel_to_ray_batch(self.K, self.c2ws, uv)
        return ray_o, ray_d

    def sample_rays(self, n_samples, offset=.5):

        H, W = self.images.shape[1], self.images.shape[2]
        all_uv = []

        for i in range(self.images.shape[0]):
            uv = torch.stack(torch.meshgrid(torch.arange(H), torch.arange(W)), dim=-1).to(device).float()
            # print("UV",uv.shape)
            uv[:,: 0] += offset
            uv[:,:, 1] += offset
            uv = uv.reshape(-1, 2)
            all_uv.append(uv)
        
        all_uv = torch.cat(all_uv, dim=0) 
        # print("UVALL",all_uv.shape)

        sampled_indices = torch.randint(0, all_uv.shape[0], (n_samples,))
        im_flat=self.images.reshape(-1,3)
        sampled_pixels = im_flat[sampled_indices]

        sampled_uv = all_uv[sampled_indices]  
        sampled_c2ws = self.c2ws[sampled_indices // (H * W)]
        # print(sampled_indices.shape)
        # print(self.c2ws.shape)
        # print(sampled_uv.shape)
        ray_o, ray_d = pixel_to_ray_batch(self.K, sampled_c2ws, sampled_uv)

        return ray_o, ray_d, torch.tensor(sampled_pixels).to(device)



In [5]:


# import viser, time  # pip install viser
# import numpy as np

# focal, images_train, c2ws_train, images_val, c2ws_val, c2ws_test= data_load("lego_200x200.npz")
# # images_train=images_train[:10]
# # c2ws_train=c2ws_train[:10]
# # print(images_train[0].shape)
# K=get_K(images_train[0], focal)
# dataset = RaysData(images_train, K, c2ws_train,focal)
# rays_o, rays_d, pixels = dataset.sample_rays(100) # Should expect (B, 3)
# points = sample_along_rays(rays_o, rays_d, perturb=True)
# points = points.cpu().numpy()
# rays_o = rays_o.cpu().numpy()
# K = K.cpu().numpy()
# rays_d = rays_d.cpu().numpy()
# pixels = pixels.cpu().numpy()
# H, W = images_train.shape[1:3]
# # ---------------------------------------

# # server = viser.ViserServer(share=True)

# # for i, (image, c2w) in enumerate(zip(images_train, c2ws_train)):
# #     server.add_camera_frustum(
# #         f"/cameras/{i}",

# #         fov=2 * np.arctan2(H / 2, K[0, 0]),
#         # aspect=W / H,
# #         scale=0.15,
# #         wxyz=viser.transforms.SO3.from_matrix(c2w[:3, :3]).wxyz,
# #         position=c2w[:3, 3],
# #         image=image
# #     )
# # for i, (o, d) in enumerate(zip(rays_o, rays_d)):
# #     server.add_spline_catmull_rom(
# #         f"/rays/{i}", positions=np.stack((o, o + d * 6.0)),
# #     )
# # server.add_point_cloud(
# #     f"/samples",
# #     colors=np.zeros_like(points).reshape(-1, 3),
# #     points=points.reshape(-1, 3),
# #     point_size=0.02,
# # )
# # time.sleep(1000)



In [6]:


# # Visualize Cameras, Rays and Samples
# import viser, time
# import numpy as np

# # --- You Need to Implement These ------
# focal, images_train, c2ws_train, images_val, c2ws_val, c2ws_test= data_load("lego_200x200.npz")
# K=get_K(images_train[0], focal)

# dataset = RaysData(images_train, K, c2ws_train,focal)

# # This will check that your uvs aren't flipped
# uvs_start = 0
# uvs_end = 40_000
# sample_uvs = dataset.uvs()[uvs_start:uvs_end] # These are integer coordinates of widths / heights (xy not yx) of all the pixels in an image
# # uvs are array of xy coordinates, so we need to index into the 0th image tensor with [0, height, width], so we need to index with uv[:,1] and then uv[:,0]


# # sample_uvs = sample_uvs.cpu().numpy().astype(np.int32)
# print(sample_uvs.shape)
# dataset_pixels = dataset.pixels()
# dataset_pixels=dataset_pixels[uvs_start:uvs_end]
# print(dataset_pixels.shape) 
# print(images_train[0, sample_uvs[:,1], sample_uvs[:,0]].shape)
# print(images_train[0,dataset_pixels[:,1], dataset_pixels[:,0]])
# print(images_train[0, sample_uvs[:,1], sample_uvs[:,0]] )
# assert np.all(images_train[0, sample_uvs[:,1], sample_uvs[:,0]] == dataset_pixels)    

# # # Uncoment this to display random rays from the first image
# indices = np.random.randint(low=0, high=40_000, size=100)

# # # Uncomment this to display random rays from the top left corner of the image
# # indices_x = np.random.randint(low=100, high=200, size=100)
# # indices_y = np.random.randint(low=0, high=100, size=100)
# # indices = indices_x + (indices_y * 200)

# data = {"rays_o": dataset.rays_o[indices], "rays_d": dataset.rays_d[indices]}
# points = sample_along_rays(data["rays_o"], data["rays_d"], random=True)
# # ---------------------------------------

# server = viser.ViserServer(share=True)
# for i, (image, c2w) in enumerate(zip(images_train, c2ws_train)):
#   server.add_camera_frustum(
#     f"/cameras/{i}",
#     fov=2 * np.arctan2(H / 2, K[0, 0]),
#     aspect=W / H,
#     scale=0.15,
#     wxyz=viser.transforms.SO3.from_matrix(c2w[:3, :3]).wxyz,
#     position=c2w[:3, 3],
#     image=image
#   )
# for i, (o, d) in enumerate(zip(data["rays_o"], data["rays_d"])):
#   positions = np.stack((o, o + d * 6.0))
#   server.add_spline_catmull_rom(
#       f"/rays/{i}", positions=positions,
#   )
# server.add_point_cloud(
#     f"/samples",
#     colors=np.zeros_like(points).reshape(-1, 3),
#     points=points.reshape(-1, 3),
#     point_size=0.03,
# )
# time.sleep(1000)



In [7]:
def volrend(sigmas, rgbs, step_size=(6.0 - 2.0) / 64):
    # x=torch.exp(-sigmas*step_size)
    # x=1-x
    # T=torch.cumprod(x,dim=2)
    # T=T.to(device)
    # x=x.to(device)
    # rgbs=rgbs.to(device)

    # T = torch.cat([torch.ones((T.shape[0], 1, 1), device=device), T[:, :-1]], dim=1)    
    # c=T*x*rgbs
    # return torch.sum(c,dim=1)
    x = 1-torch.exp(-sigmas *step_size)
    T = torch.cumprod(1 - x, dim=1) 
    
    T = T.to(device)
    x = x.to(device)
    rgbs = rgbs.to(device)
    T = torch.cat([torch.ones((T.shape[0], 1, 1), device=device), T[:, :-1]], dim=1)
    
    c = T * x * rgbs  
    
    ret = torch.sum(c, dim=1)  
    
    return ret







# import torch
# torch.manual_seed(42)
# sigmas = torch.rand((10, 64, 1))
# rgbs = torch.rand((10, 64, 3))
# step_size = (6.0 - 2.0) / 64
# rendered_colors = volrend(sigmas, rgbs, step_size)

# correct = torch.tensor([
#     [0.5006, 0.3728, 0.4728],
#     [0.4322, 0.3559, 0.4134],
#     [0.4027, 0.4394, 0.4610],
#     [0.4514, 0.3829, 0.4196],
#     [0.4002, 0.4599, 0.4103],
#     [0.4471, 0.4044, 0.4069],
#     [0.4285, 0.4072, 0.3777],
#     [0.4152, 0.4190, 0.4361],
#     [0.4051, 0.3651, 0.3969],
#     [0.3253, 0.3587, 0.4215]
#   ])
# rendered_colors = rendered_colors.cpu()
# print(rendered_colors)
# correct = correct.cpu()
# print(correct.shape)

# assert torch.allclose(rendered_colors, correct, rtol=1e-4, atol=1e-4)



In [8]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

class MLP(nn.Module):
    def __init__(self, input_size=2, L=10, L1=4, device='cuda'):
        super(MLP, self).__init__()

        self.L = L
        self.L1 = L1
        self.device = device

        self.l1 = nn.Linear(3 * (2 * L + 1), 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 256)
        self.l4 = nn.Linear(256, 256)
        # CONCAT

        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 256)
        self.l7 = nn.Linear(256, 256)
        self.l8 = nn.Linear(256, 256)
        # SPLIT OFF

        self.l9 = nn.Linear(256, 1)
        # DENSITY

        self.l10 = nn.Linear(256, 256)
        self.l11 = nn.Linear(3 * (2 * L1 + 1)+256, 128)
        self.l12 = nn.Linear(128, 3)
        # COLOR

        self.relu=nn.ReLU()
        self.sigmoid=nn.Sigmoid()
    
    def forward(self, x,rd):
        x = x.to(self.device)
        rd = rd.to(self.device)
        xpe=PE(x,self.L)
        # print("XPE",xpe.shape)
        rdpe=PE(rd,self.L1)
        # print("RDPE",rdpe.shape)
        xpe=xpe.to(self.device)
        rdpe=rdpe.to(self.device)
        a = self.relu(self.l1(xpe))
        a = self.relu(self.l2(a))
        a = self.relu(self.l3(a))
        a = self.relu(self.l4(a))
        # CONCAT

        a = self.relu(self.l5(a))
        a = self.relu(self.l6(a))
        a = self.relu(self.l7(a))
        # a = self.l8(a)
        a = self.relu(self.l8(a))
        # SPLIT OFF
        d= self.relu(self.l9(a))
        # DENSITY
        a = self.l10(a)
        # print("A",a.shape)
        # print("RDPE",rdpe.shape)

        # A torch.Size([1000, 64, 256])
        # RDPE torch.Size([1000, 27])

        # a=torch.cat([a,rdpe],dim=-1)    # RuntimeError: Tensors must have same number of dimensions: got 3 and 2
        # a=torch.cat([a,rdpe.unsqueeze(1)],dim=-1)
        a=torch.cat((a,rdpe),dim=2)
        # print("RGB",a.shape)
        a = self.relu(self.l11(a))
        # print("AAAAAAAAAAAAAAAAAAAAAAAAAAAA128")
        a = self.sigmoid(self.l12(a))
        # print("AAAAAAAAAAAAAAAAAAAAAAAA3")
        return a.float(), d.float()


    

def PE(x, L):


    freqs = 2.0 ** torch.arange(L).float().to(device)
    x_input = x.unsqueeze(-1) * freqs * 2 * torch.pi
    pe = torch.cat([torch.sin(x_input), torch.cos(x_input)], dim=-1)
    
    oe = torch.cat([x, pe.reshape(*x.shape[:-1], -1)], dim=-1) 
    
    return oe


def compute_psnr(pred, target):
    mse = torch.mean((pred - target) ** 2)
    if mse == 0:
        return 100  
    return 10 * torch.log10(1.0 / mse)

def rand_sample(im, N):
    x, y, z = im.shape
    points = []
    rand_x = np.random.randint(0, x, N)
    rand_y = np.random.randint(0, y, N)
    colors = []
    
    for i in range(N):
        points.append([rand_y[i] / y, rand_x[i] / x])  # Normalized coordinates
        c = im[rand_x[i], rand_y[i]]
        c = c / 255  # Normalize color values to [0, 1]
        colors.append(c)
    # print("points",points)
    return np.array(points), np.array(colors)

def create_image(ima, model, L):
    h, w, _ = ima.shape
    h, w = h // 3, w // 3  
    # print(h,w)
    x = torch.linspace(0, w-1, w).repeat(h, 1) / w  
    # print(x.shape)
    y = torch.linspace(0, h-1, h).repeat(w, 1).transpose(0, 1) / h  
    # print(y.shape)
    # Stack x and y to get (x, y) pairs and reshape them into a single tensor
    all_coords = torch.stack([x, y], dim=-1).view(-1, 2)
    # print(all_coords.shape)
    # print(all_coords)

    
    with torch.no_grad():
        predicted_pixels = model(all_coords)
    
    predicted_pixels_cpu = predicted_pixels.cpu()

    predicted_image = predicted_pixels_cpu.reshape(h, w, 3).numpy()

    predicted_image = np.clip(predicted_image * 255, 0, 255).astype(np.uint8)
    # predicted_image=np.rot90(predicted_image,3)
    
    return predicted_image


def train_mod(images_train, c2ws_train, images_val, c2ws_val,focal, model, N, L, iterations, lr=0.005):
    # images_train=images_train.to(device)
    # c2ws_train=c2ws_train.to(device)
    # images_val=images_val.to(device)
    # c2ws_val=c2ws_val.to(device)
    
    # im = np.array(im)
    model.to(device)
    loss_func = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr)
    train_data=RaysData(images_train, get_K(images_train[0], focal), c2ws_train, focal)
    val_data=RaysData(images_val, get_K(images_val[0], focal), c2ws_val, focal)


    loss_list = []
    psnr_list = []
    predicted_images = []
    model.train()
    val_loss =[]
    val_psnr=[]
    # print("TRAIN DATA",images_train[].shape)
    for i in range(iterations):
        rayo,rayd, pixels = train_data.sample_rays(N)
        # print(rayo.shape)
        xw=sample_along_rays(rayo,rayd,perturb=True)
        rayd=rayd.unsqueeze(1).expand(-1, xw.shape[1], -1)
        # print("RAYSD", rayd.type)
        optimizer.zero_grad()
        xw = xw.to(device).float()
        rayd = rayd.to(device).float()
           
        # print("XWTUPE",xw.dtype)
        # print(rayd.dtype)
        # print("rad_sahpe",rayd.shape)

        rgb, d = model(xw, rayd)
        
        predrgb = volrend(d, rgb)
        # print("PRED RGB", predrgb.shape)
        # print("PIXELS", pixels.shape)
        # colors=images_train.reshape(-1,3)
        # colors=colors.cpu().numpy()
        # colors=colors[pixels]
        loss = loss_func(predrgb.float(), pixels.float()).float()
        # loss = sum(loss for l.float() in loss.values())
        loss_list.append(loss.item())
        
        psnr = compute_psnr(predrgb, pixels)    
        psnr_list.append(psnr)
        loss.backward()
        optimizer.step()
        if i % 20 ==0:
            with torch.no_grad():
                model.eval()
                rayo, rayd, pixels = val_data.sample_rays(1000)
                xw = sample_along_rays(rayo, rayd, perturb=False)
                xw = xw.to(device).float()
                rayd=rayd.unsqueeze(1).expand(-1, xw.shape[1], -1)

                rayd = rayd.to(device).float()
                rgb, d = model(xw, rayd)
                predrgb = volrend(d, rgb)
                loss = loss_func(predrgb, pixels)
                psnr = compute_psnr(predrgb, pixels)
                val_loss.append(loss.item())
                val_psnr.append(psnr)
                print(f"Validation Loss: {loss.item()}, PSNR: {psnr}")
                model.train()
    return model, loss_list, psnr_list, predicted_images, val_loss, val_psnr

In [None]:
focal, images_train, c2ws_train, images_val, c2ws_val, c2ws_test= data_load("lego_200x200.npz")
model=MLP(input_size=3,L=10,L1=4,device='cuda')
modela, loss_list, psnr_list, predicted_images, val_loss, val_psnr = train_mod(images_train, c2ws_train, images_val, c2ws_val,focal, model, 1000, 20, 20000, lr=0.005)

  ray_o= torch.bmm(-torch.inverse(torch.tensor(r)), t.unsqueeze(-1))
  K=torch.tensor(K)
  uv=torch.tensor(uv)
  return torch.stack([xc,yc, torch.tensor(s*torch.ones_like(xc))],dim=1)
  c2w=torch.tensor(c2w)
  rayo=torch.tensor(rayo).to(device)
  rayd=torch.tensor(rayd).to(device)


Validation Loss: 0.0944333215028575, PSNR: 10.248747346070335
Validation Loss: 0.05872927333862456, PSNR: 12.311453726795685
Validation Loss: 0.05740318819301266, PSNR: 12.410639860657732
Validation Loss: 0.05974273653474681, PSNR: 12.237148882815937
Validation Loss: 0.06189474071016124, PSNR: 12.083462520718665
Validation Loss: 0.0653786896173978, PSNR: 11.845637882162862
Validation Loss: 0.0647819777176382, PSNR: 11.885457976233702
Validation Loss: 0.06556418993203035, PSNR: 11.833333002782387
Validation Loss: 0.06517419074684715, PSNR: 11.859243526622908
Validation Loss: 0.06793577691136263, PSNR: 11.679014536552172
Validation Loss: 0.06022603737191307, PSNR: 12.202157106983622
Validation Loss: 0.04860600645081062, PSNR: 13.13310059806425
Validation Loss: 0.053113837503427144, PSNR: 12.747923194365466
Validation Loss: 0.04316395941979967, PSNR: 13.648787244316436
Validation Loss: 0.03995101132993718, PSNR: 13.9847222237122
Validation Loss: 0.03662661569357901, PSNR: 14.3620320837629