In [1]:
import os, sys
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

from run_nerf_helpers import *

from load_llff import load_llff_data

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from: numpy (ray shuffling, pytest paths),
# Python's `random`, and torch (stratified sampling, density noise, reshuffles).
# Note: some CUDA kernels are still nondeterministic, so GPU runs may differ slightly.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# When True, render_rays reports NaN/Inf values found in its outputs.
DEBUG = False

In [2]:
# takeaway function and not decompose it
def batchify(fn, chunk):
    """Wrap `fn` so it is evaluated over its input in slices of `chunk` rows.

    Args:
      fn: callable mapping a tensor to a tensor; it is applied independently
        to consecutive slices taken along dimension 0.
      chunk: maximum number of rows per call, or None to return `fn` as-is.

    Returns:
      A callable whose output is the per-slice results of `fn` concatenated
      along dimension 0.
    """
    if chunk is None:
        return fn

    def chunked(inputs):
        pieces = []
        for start in range(0, inputs.shape[0], chunk):
            pieces.append(fn(inputs[start:start + chunk]))
        return torch.cat(pieces, 0)

    return chunked

def run_network(inputs, view_direction, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """Embed sample positions (and optionally view directions), then query `fn`.

    Args:
      inputs: tensor whose last dimension holds point coordinates; all leading
        dimensions are flattened before querying the network.
      view_direction: per-ray directions, or None. Each direction is broadcast
        along dimension 1 of `inputs` via expand, so inputs is assumed to be
        [N_rays, N_samples, C] with view_direction [N_rays, C].
      fn: network evaluated on the embedded features.
      embed_fn: encoding applied to the flattened points.
      embeddirs_fn: encoding applied to the flattened directions.
      netchunk: maximum rows per network call (see batchify).

    Returns:
      Network output reshaped to inputs.shape[:-1] + [output channels].
    """
    point_dim = inputs.shape[-1]
    flat_points = torch.reshape(inputs, [-1, point_dim])
    features = embed_fn(flat_points)

    if view_direction is not None:
        # Give every sample on a ray the same viewing direction.
        dirs = view_direction[:, None].expand(inputs.shape)
        flat_dirs = torch.reshape(dirs, [-1, dirs.shape[-1]])
        dir_features = embeddirs_fn(flat_dirs)
        features = torch.cat([features, dir_features], -1)

    flat_out = batchify(fn, netchunk)(features)
    out_shape = list(inputs.shape[:-1]) + [flat_out.shape[-1]]
    return torch.reshape(flat_out, out_shape)


def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """Render a flat ray batch through render_rays in chunks to bound memory use.

    Args:
      rays_flat: ray tensor whose first dimension indexes rays.
      chunk: maximum number of rays handed to render_rays per call.
      **kwargs: forwarded unchanged to render_rays.

    Returns:
      dict mapping every key returned by render_rays to the per-chunk tensors
      concatenated along dimension 0.
    """
    collected = {}
    for start in range(0, rays_flat.shape[0], chunk):
        chunk_out = render_rays(rays_flat[start:start + chunk], **kwargs)
        for key, value in chunk_out.items():
            collected.setdefault(key, []).append(value)

    return {key: torch.cat(values, 0) for key, values in collected.items()}




def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
                  near=0., far=1.,
                  use_viewdirs=False, c2w_staticcam=None,
                  **kwargs):
    """Render rays, either for a full camera image or for a given ray batch.

    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      K: [3, 3] camera intrinsics matrix. K[0][0] is used as the focal length
        when converting rays to NDC.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: pair (rays_o, rays_d), e.g. an array of shape [2, batch_size, 3].
        Used only when c2w is None.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix. When
        given, rays for the whole H x W image are generated with get_rays and
        `rays` is ignored.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float (or tensor broadcastable to [num_rays, 1]). Nearest distance
        for every ray.
      far: float (or tensor broadcastable to [num_rays, 1]). Farthest distance
        for every ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation
        matrix for the camera while using the other c2w argument for viewing directions.
      **kwargs: forwarded to render_rays (via batchify_rays).

    Returns:
      list [rgb_map, disp_map, acc_map, extras]. Each map keeps the leading
      shape of rays_d (e.g. [H, W, 3] and [H, W] for a full image):
        rgb_map: predicted RGB values for rays.
        disp_map: disparity map (inverse of depth).
        acc_map: accumulated opacity (alpha) along a ray.
        extras: dict with every other key returned by render_rays().

    Raises:
      ValueError: if neither c2w nor rays is provided.
    """
    if c2w is not None:
        # Special case: render a full image from a camera pose.
        rays_o, rays_d = get_rays(H, W, K, c2w)
    elif rays is not None:
        # Use the provided ray batch.
        rays_o, rays_d = rays
    else:
        raise ValueError("render() requires either `c2w` or `rays`.")

    if use_viewdirs:
        # Provide (unit-length) ray directions as network input.
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # Special case to visualize effect of viewdirs: cast rays from a
            # fixed camera while keeping the original viewing directions.
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1, 3]).float()

    sh = rays_d.shape  # [..., 3]; used to restore the output shape below
    if ndc:
        # For forward-facing scenes.
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

    # Pack each ray as [origin(3) | direction(3) | near | far | viewdir(3)?].
    rays_o = torch.reshape(rays_o, [-1, 3]).float()
    rays_d = torch.reshape(rays_d, [-1, 3]).float()

    near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1])
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)

    # Render in chunks, then restore the leading shape of rays_d.
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]




def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                return_raw=False,
                linear_depend_inverse_depth=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_background=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):
    """Volumetric rendering of a batch of rays (coarse pass + optional fine pass).

    Args:
      ray_batch: tensor of shape [N_rays, 8] or [N_rays, 11], packed as
        [origin(3) | direction(3) | near | far | viewdir(3, optional)].
      network_fn: coarse model predicting RGB and density at each point.
      network_query_fn: function used for passing queries to a network, called
        as network_query_fn(pts, view_direction, network).
      N_samples: int. Number of stratified samples along each ray.
      return_raw: bool. If True, include the model's raw, unprocessed predictions.
      linear_depend_inverse_depth: bool. If True, sample linearly in inverse
        depth rather than in depth.
      perturb: float. If > 0, each ray is sampled at stratified random points
        inside each interval; if 0, importance sampling is deterministic too.
      N_importance: int. Number of additional samples along each ray, drawn from
        the coarse weights. These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn; network_fn is
        reused when this is None.
      white_background: bool. If True, assume a white background.
      raw_noise_std: float. Standard deviation of Gaussian noise added to the
        raw density before compositing; 0 disables it.
      verbose: bool. Currently unused.
      pytest: bool. If True, replace random draws with numpy values seeded at 0.

    Returns:
      dict with keys:
        'rgb_map': [N_rays, 3]. Estimated RGB colour (fine model if N_importance > 0).
        'disp_map': [N_rays]. Disparity map, 1 / depth.
        'acc_map': [N_rays]. Accumulated opacity along each ray.
        'raw': [N_rays, N_samples(+N_importance), C]. Only if return_raw.
        'rgb0', 'disp0', 'acc0': coarse-model outputs. Only if N_importance > 0.
        'z_std': [N_rays]. Std of the importance samples. Only if N_importance > 0.
    """
    N_rays = ray_batch.shape[0]
    rays_origin, rays_direction = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
    view_direction = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
    bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]  # [N_rays, 1] each

    # Coarse sample positions, evenly spaced in depth or in inverse depth.
    t_vals = torch.linspace(0., 1., steps=N_samples).to(device)
    if not linear_depend_inverse_depth:
        z_vals = near * (1. - t_vals) + far * (t_vals)
    else:
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples]).to(device)

    if perturb > 0.:
        # Get intervals between samples.
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        # Stratified samples in those intervals.
        t_rand = torch.rand(z_vals.shape).to(device)

        # Pytest: overwrite with numpy's fixed random numbers.
        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand).to(device)

        z_vals = lower + (upper - lower) * t_rand

    # Points along each ray: origin + t * direction.
    pts = rays_origin[..., None, :] + rays_direction[..., None, :] * z_vals[..., :, None]  # [N_rays, N_samples, 3]

    raw = network_query_fn(pts, view_direction, network_fn)
    rgb_map, disparity_map, accumulated_opacity_map, weights, depth_map = raw2outputs(
        raw, z_vals, rays_direction, raw_noise_std, white_background, pytest=pytest)

    if N_importance > 0:
        # Keep the coarse results; they are returned as rgb0/disp0/acc0.
        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disparity_map, accumulated_opacity_map

        # Draw extra samples from the coarse weights (interior bins only).
        z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.), pytest=pytest)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_origin[..., None, :] + rays_direction[..., None, :] * z_vals[..., :, None]  # [N_rays, N_samples + N_importance, 3]

        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, view_direction, run_fn)

        rgb_map, disparity_map, accumulated_opacity_map, weights, depth_map = raw2outputs(
            raw, z_vals, rays_direction, raw_noise_std, white_background, pytest=pytest)

    ret = {'rgb_map': rgb_map, 'disp_map': disparity_map, 'acc_map': accumulated_opacity_map}
    if return_raw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]

    # Test DEBUG first: the isnan/isinf reductions force a host-device sync,
    # so they must not run on every chunk when debugging is off.
    if DEBUG:
        for k, v in ret.items():
            if torch.isnan(v).any() or torch.isinf(v).any():
                print(f"! [Numerical Error] {k} contains nan or inf.")

    return ret



def raw2outputs(raw, z_vals, rays_direction, raw_noise_std=0, white_background=False, pytest=False):
    """Transforms model's predictions to semantically meaningful values.

    Works for any leading batch shape; tensors created here are placed on the
    device of the inputs (no reliance on a global device).

    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
            Channels 0-2 are RGB logits (sigmoid applied), channel 3 is raw
            density (ReLU applied).
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_direction: [num_rays, 3]. Direction of each ray (need not be unit length).
        raw_noise_std: std of Gaussian noise added to the raw density; 0 disables it.
        white_background: if True, composite the result onto a white background.
        pytest: if True, replace the random noise with numpy values seeded at 0.
    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disparity_map: [num_rays]. Disparity map. Inverse of depth map.
        accumulated_opacity_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
        depth_map: [num_rays]. Estimated distance to object.
    """
    def raw2alpha(sigma, dists, act_fn=F.relu):
        return 1. - torch.exp(-act_fn(sigma) * dists)

    # Spacing between adjacent samples; the last sample gets an effectively
    # infinite interval.
    distances = z_vals[..., 1:] - z_vals[..., :-1]
    distances = torch.cat([distances, torch.full_like(distances[..., :1], 1e10)], -1)  # [N_rays, N_samples]

    # Convert from ray-parameter units to world distance.
    distances = distances * torch.norm(rays_direction[..., None, :], dim=-1)

    rgb = torch.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[..., 3].shape).to(raw.device) * raw_noise_std

        # Overwrite randomly sampled data if pytest.
        if pytest:
            np.random.seed(0)
            noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std
            noise = torch.Tensor(noise).to(raw.device)

    alpha = raw2alpha(raw[..., 3] + noise, distances)  # [N_rays, N_samples]

    # Transmittance T_i = prod_{j<i} (1 - alpha_j): exclusive cumulative product,
    # built by prepending ones and dropping the last entry. 1e-10 avoids exact zeros.
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1), -1)[..., :-1]
    weights = alpha * transmittance
    rgb_map = torch.sum(weights[..., None] * rgb, -2)  # [N_rays, 3]

    depth_map = torch.sum(weights * z_vals, -1)
    accumulated_opacity_map = torch.sum(weights, -1)
    disparity_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / accumulated_opacity_map)

    if white_background:
        rgb_map = rgb_map + (1. - accumulated_opacity_map[..., None])

    return rgb_map, disparity_map, accumulated_opacity_map, weights, depth_map


def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
    """Render one image per camera pose, optionally saving each frame as a PNG.

    Args:
      render_poses: iterable of camera poses; the top-left [3, 4] of each is
        passed to render() as c2w.
      hwf: (H, W, focal) of the full-resolution camera.
      K: [3, 3] intrinsics matrix matching hwf.
      chunk: maximum rays per batch, forwarded to render().
      render_kwargs: extra keyword arguments for render().
      gt_imgs: currently unused; kept for interface compatibility.
      savedir: if not None, frame i is written there as '{i:03d}.png'.
      render_factor: if non-zero, render at 1/render_factor resolution.

    Returns:
      (rgbs, disps): numpy arrays stacking the per-pose RGB images and
      disparity maps along axis 0.
    """
    H, W, _ = hwf

    if render_factor != 0:
        # Render downsampled for speed. The intrinsics must shrink with the
        # image (focal length and principal point scale by 1/render_factor);
        # otherwise rays would be cast for the full-resolution camera.
        H = H // render_factor
        W = W // render_factor
        K = np.array(K, dtype=np.float64)
        K[:2] = K[:2] / render_factor

    rgbs = []
    disps = []

    t = time.time()
    for i, c2w in enumerate(tqdm(render_poses)):
        print(i, time.time() - t)
        t = time.time()
        rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3, :4], **render_kwargs)
        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())
        if i == 0:
            print(rgb.shape, disp.shape)

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)

    return rgbs, disps





In [3]:
def train():
    """Train a coarse + fine NeRF on the LLFF "fern" scene.

    Hyperparameters are hard-coded in the configuration section below and
    mirror `python run_nerf.py --config configs/fern.txt`. Steps:
      0. load the LLFF images/poses and split train/test views,
      1. build the encoders, coarse/fine NeRF models and Adam optimizer,
         resuming from the newest checkpoint in basedir/expname if present,
      2. precompute and shuffle every training ray with its target colour,
      3. optimise, periodically saving checkpoints, a spiral video and
         renders of the held-out views.

    Raises:
      NotImplementedError: if no_batching is True (only random ray batching
        across all training images is implemented).
    """
    #########################################################################
    # Configuration (python run_nerf.py --config configs/fern.txt)
    N_importance = 64
    N_rand = 1024
    N_samples = 64
    basedir = './logs'
    chunk = 32768
    config = 'configs/fern.txt'
    datadir = './data/nerf_llff_data/fern'
    dataset_type = 'llff'
    expname = 'fern_test'
    factor = 8
    ft_path = None
    half_res = False
    i_embed = 0
    i_img = 500
    i_print = 100
    i_testset = 5000
    i_video = 5000
    i_weights = 10000
    lindisp = False
    llffhold = 8
    lrate = 0.0005
    lrate_decay = 250
    multires = 10
    multires_views = 4
    netchunk = 65536
    netdepth = 8
    netdepth_fine = 8
    netwidth = 256
    netwidth_fine = 256
    no_batching = False
    no_ndc = False
    no_reload = False
    perturb = 1.0
    precrop_frac = 0.5
    precrop_iters = 0
    raw_noise_std = 1.0
    render_factor = 0
    render_only = False
    render_test = False
    shape = 'greek'
    spherify = False
    testskip = 8
    use_viewdirs = True
    white_background = False
    ################################################################################

    if no_batching:
        # The per-image sampling path was never implemented: the training loop
        # would fail with a NameError on batch_rays/target_s. Fail fast instead.
        raise NotImplementedError("Training with no_batching=True is not implemented.")

    images, poses, bds, render_poses, i_test = load_llff_data(datadir, factor, recenter=True, bd_factor=.75, spherify=spherify)
    print("================[0][load_data]=================")
    # For fern at factor 8 (see the run log): images (20, 378, 504, 3),
    # poses (20, 3, 5), bds (20, 2), render_poses (120, 3, 5).
    # The last pose column holds [H, W, focal]; the first four columns are the
    # 3x4 camera-to-world matrix.
    hwf = poses[0, :3, -1]
    poses = poses[:, :3, :4]
    print(poses[0])
    print('Loaded llff', images.shape, render_poses.shape, hwf, datadir)

    if not isinstance(i_test, list):
        i_test = [i_test]

    if llffhold > 0:
        # Hold out every llffhold-th image for testing.
        print('Auto LLFF holdout,', llffhold)
        i_test = np.arange(images.shape[0])[::llffhold]

    i_val = i_test
    i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                        (i not in i_test and i not in i_val)])

    print('DEFINING BOUNDS')
    if no_ndc:
        near = np.ndarray.min(bds) * .9
        far = np.ndarray.max(bds) * 1.
    else:
        near = 0.
        far = 1.
    print('NEAR FAR', near, far)

    # Cast intrinsics to the right types.
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    K = np.array([
        [focal, 0, 0.5*W],
        [0, focal, 0.5*H],
        [0, 0, 1]
    ])
    print(K)

    os.makedirs(os.path.join(basedir, expname), exist_ok=True)

    print("================[1][create_nerf()]=================")
    # Positional encoding for points and (optionally) view directions.
    embed_fn, input_ch = get_embedder(multires, i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(multires_views, i_embed)
    output_ch = 5 if N_importance > 0 else 4
    skips = [4]

    model = NeRF(D=netdepth, W=netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=use_viewdirs).to(device)
    grad_vars = list(model.parameters())

    model_fine = None
    if N_importance > 0:
        model_fine = NeRF(D=netdepth_fine, W=netwidth_fine,
                          input_ch=input_ch, output_ch=output_ch, skips=skips,
                          input_ch_views=input_ch_views, use_viewdirs=use_viewdirs).to(device)
        grad_vars += list(model_fine.parameters())

    network_query_fn = lambda inputs, viewdirs, network_fn: run_network(inputs, viewdirs, network_fn,
                                                                        embed_fn=embed_fn,
                                                                        embeddirs_fn=embeddirs_fn,
                                                                        netchunk=netchunk)

    optimizer = torch.optim.Adam(params=grad_vars, lr=lrate, betas=(0.9, 0.999))

    start = 0

    # Load checkpoints: an explicit ft_path, else every '*tar*' file in the
    # experiment directory (sorted by name, newest step last).
    if ft_path is not None and ft_path != 'None':
        ckpts = [ft_path]
    else:
        ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]

    print('Found ckpts', ckpts)
    if len(ckpts) > 0 and not no_reload:
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        # map_location lets a GPU-trained checkpoint be reloaded on a CPU-only machine.
        ckpt = torch.load(ckpt_path, map_location=device)

        start = ckpt['global_step']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])

    render_kwargs_train = {
        'network_query_fn': network_query_fn,
        'perturb': perturb,
        'N_importance': N_importance,
        'network_fine': model_fine,
        'N_samples': N_samples,
        'network_fn': model,
        'use_viewdirs': use_viewdirs,
        'white_background': white_background,
        'raw_noise_std': raw_noise_std,
    }

    # NDC only good for LLFF-style forward facing data.
    if dataset_type != 'llff' or no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        # render_rays names this option `linear_depend_inverse_depth`; passing
        # it as `lindisp` would raise a TypeError inside render_rays.
        render_kwargs_train['linear_depend_inverse_depth'] = lindisp

    # Test-time rendering: no stratified jitter and no density noise.
    render_kwargs_test = {k: render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    print("================[2][get training data, validation data]=================")

    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU.
    render_poses = torch.Tensor(render_poses).to(device)

    # Random ray batching: precompute every training ray with its target colour.
    print('get rays')
    rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:, :3, :4]], 0)  # [N, ro+rd, H, W, 3]
    print('done, concats')
    rays_rgb = np.concatenate([rays, images[:, None]], 1)  # [N, ro+rd+rgb, H, W, 3]
    rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
    rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)  # train images only
    rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])  # [N_train*H*W, ro+rd+rgb, 3]
    rays_rgb = rays_rgb.astype(np.float32)
    print('shuffle rays')
    np.random.shuffle(rays_rgb)
    print('done')
    i_batch = 0

    # Move training data to GPU.
    images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = 10000 + 1
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    start = start + 1
    for i in trange(start, N_iters):
        # Take the next N_rand rays; reshuffle after a full pass over the data.
        batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, ro+rd+rgb, 3]
        batch = torch.transpose(batch, 0, 1)
        batch_rays, target_s = batch[:2], batch[2]

        i_batch += N_rand
        if i_batch >= rays_rgb.shape[0]:
            print("Shuffle data after an epoch!")
            rand_idx = torch.randperm(rays_rgb.shape[0])
            rays_rgb = rays_rgb[rand_idx]
            i_batch = 0

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H, W, K, chunk=chunk, rays=batch_rays, verbose=i < 10, **render_kwargs_train)

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            # The coarse network's prediction is supervised as well.
            loss = loss + img2mse(extras['rgb0'], target_s)

        loss.backward()
        optimizer.step()

        # Exponential learning-rate decay: lrate * 0.1 ** (global_step / (lrate_decay * 1000)).
        decay_rate = 0.1
        decay_steps = lrate_decay * 1000
        new_lrate = lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate

        # Rest is logging.
        if i % i_weights == 0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save({
                'global_step': global_step,
                'network_fn_state_dict': model.state_dict(),
                # model_fine is None when N_importance == 0.
                'network_fine_state_dict': model_fine.state_dict() if model_fine is not None else None,
                'optimizer_state_dict': optimizer.state_dict(),
            }, path)
            print('Saved checkpoints at', path)

        if i % i_video == 0 and i > 0:
            # Turn on testing mode.
            with torch.no_grad():
                rgbs, disps = render_path(render_poses, hwf, K, chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

        if i % i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            # poses/images already live on `device`; re-wrapping them with
            # torch.Tensor(...) raises for CUDA tensors, so index them directly.
            test_idx = torch.as_tensor(i_test, dtype=torch.long, device=poses.device)
            test_poses = poses[test_idx]
            print('test poses shape', test_poses.shape)
            with torch.no_grad():
                render_path(test_poses, hwf, K, chunk, render_kwargs_test, gt_imgs=images[test_idx], savedir=testsavedir)
            print('Saved test set')

        if i % i_print == 0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")

        global_step += 1

In [None]:
# Entry point: run the full training loop. In a notebook __name__ is
# "__main__", so this always executes there.
if __name__ == "__main__":
    train()

[[ 0.99569476 -0.02079598 -0.09033062 -0.3081002 ]
 [ 0.02503342  0.9986262   0.04603354  0.1346772 ]
 [ 0.0892492  -0.04809664  0.99484736  0.03989876]]
Loaded llff (20, 378, 504, 3) (120, 3, 5) [378.     504.     407.5658] ./data/nerf_llff_data/fern
Auto LLFF holdout, 8
DEFINING BOUNDS
NEAR FAR 0.0 1.0
[[407.5657959   0.        252.       ]
 [  0.        407.5657959 189.       ]
 [  0.          0.          1.       ]]
Found ckpts []
get rays
done, concats
shuffle rays
done
Begin
TRAIN views are [ 1  2  3  4  5  6  7  9 10 11 12 13 14 15 17 18 19]
TEST views are [ 0  8 16]
VAL views are [ 0  8 16]


  1%|▍                                      | 101/10000 [00:16<26:37,  6.20it/s]

[TRAIN] Iter: 100 Loss: 0.050077032297849655  PSNR: 16.044286727905273


  2%|▊                                      | 201/10000 [00:32<26:14,  6.22it/s]

[TRAIN] Iter: 200 Loss: 0.04124170541763306  PSNR: 16.878141403198242


  3%|█▏                                     | 301/10000 [00:48<26:51,  6.02it/s]

[TRAIN] Iter: 300 Loss: 0.03489493206143379  PSNR: 17.686418533325195


  4%|█▌                                     | 401/10000 [01:05<25:52,  6.18it/s]

[TRAIN] Iter: 400 Loss: 0.030470171943306923  PSNR: 18.277976989746094


  5%|█▉                                     | 501/10000 [01:22<26:05,  6.07it/s]

[TRAIN] Iter: 500 Loss: 0.028397828340530396  PSNR: 18.594423294067383


  6%|██▎                                    | 601/10000 [01:39<26:48,  5.84it/s]

[TRAIN] Iter: 600 Loss: 0.02775314450263977  PSNR: 18.61324691772461


  7%|██▋                                    | 701/10000 [01:56<26:03,  5.95it/s]

[TRAIN] Iter: 700 Loss: 0.025676608085632324  PSNR: 18.926496505737305


  8%|███                                    | 801/10000 [02:12<24:56,  6.15it/s]

[TRAIN] Iter: 800 Loss: 0.020007766783237457  PSNR: 20.03154182434082


  9%|███▌                                   | 901/10000 [02:28<25:03,  6.05it/s]

[TRAIN] Iter: 900 Loss: 0.02025044709444046  PSNR: 20.038225173950195


 10%|███▊                                  | 1001/10000 [02:45<24:40,  6.08it/s]

[TRAIN] Iter: 1000 Loss: 0.02114073559641838  PSNR: 19.832273483276367


 11%|████▏                                 | 1101/10000 [03:01<24:12,  6.13it/s]

[TRAIN] Iter: 1100 Loss: 0.018139991909265518  PSNR: 20.633419036865234


 12%|████▌                                 | 1201/10000 [03:18<24:40,  5.94it/s]

[TRAIN] Iter: 1200 Loss: 0.01878197118639946  PSNR: 20.39267921447754


 13%|████▉                                 | 1301/10000 [03:35<24:58,  5.81it/s]

[TRAIN] Iter: 1300 Loss: 0.018991604447364807  PSNR: 20.26565933227539


 14%|█████▎                                | 1401/10000 [03:52<24:04,  5.95it/s]

[TRAIN] Iter: 1400 Loss: 0.01930193416774273  PSNR: 20.25137710571289


 15%|█████▋                                | 1501/10000 [04:09<22:50,  6.20it/s]

[TRAIN] Iter: 1500 Loss: 0.016956571489572525  PSNR: 20.67484474182129


 16%|██████                                | 1601/10000 [04:25<23:19,  6.00it/s]

[TRAIN] Iter: 1600 Loss: 0.018471121788024902  PSNR: 20.367103576660156


 17%|██████▍                               | 1701/10000 [04:42<22:46,  6.07it/s]

[TRAIN] Iter: 1700 Loss: 0.016976643353700638  PSNR: 20.807165145874023


 18%|██████▊                               | 1801/10000 [04:59<22:09,  6.17it/s]

[TRAIN] Iter: 1800 Loss: 0.01617090404033661  PSNR: 20.90141487121582


 19%|███████▏                              | 1901/10000 [05:16<23:25,  5.76it/s]

[TRAIN] Iter: 1900 Loss: 0.01662593148648739  PSNR: 20.826297760009766


 20%|███████▌                              | 2001/10000 [05:33<21:40,  6.15it/s]

[TRAIN] Iter: 2000 Loss: 0.015796680003404617  PSNR: 21.067794799804688


 21%|███████▉                              | 2101/10000 [05:50<23:12,  5.67it/s]

[TRAIN] Iter: 2100 Loss: 0.012486386112868786  PSNR: 22.10917091369629


 22%|████████▎                             | 2201/10000 [06:07<22:34,  5.76it/s]

[TRAIN] Iter: 2200 Loss: 0.015159999951720238  PSNR: 21.195871353149414


 23%|████████▋                             | 2301/10000 [06:23<21:22,  6.00it/s]

[TRAIN] Iter: 2300 Loss: 0.014521108940243721  PSNR: 21.375022888183594


 24%|█████████                             | 2401/10000 [06:40<22:19,  5.67it/s]

[TRAIN] Iter: 2400 Loss: 0.014933507889509201  PSNR: 21.44516944885254


 25%|█████████▌                            | 2501/10000 [06:58<21:46,  5.74it/s]

[TRAIN] Iter: 2500 Loss: 0.013198855333030224  PSNR: 21.785236358642578


 26%|█████████▉                            | 2601/10000 [07:15<21:35,  5.71it/s]

[TRAIN] Iter: 2600 Loss: 0.012501506134867668  PSNR: 22.035245895385742


 27%|██████████▎                           | 2701/10000 [07:32<19:27,  6.25it/s]

[TRAIN] Iter: 2700 Loss: 0.013689625076949596  PSNR: 21.567167282104492


 28%|██████████▋                           | 2801/10000 [07:49<21:27,  5.59it/s]

[TRAIN] Iter: 2800 Loss: 0.013848110102117062  PSNR: 21.603015899658203


 29%|███████████                           | 2901/10000 [08:06<21:23,  5.53it/s]

[TRAIN] Iter: 2900 Loss: 0.012332865968346596  PSNR: 22.168115615844727


 30%|███████████▍                          | 3001/10000 [08:24<20:56,  5.57it/s]

[TRAIN] Iter: 3000 Loss: 0.015457745641469955  PSNR: 21.185955047607422


 31%|███████████▊                          | 3101/10000 [08:41<19:16,  5.97it/s]

[TRAIN] Iter: 3100 Loss: 0.012559417635202408  PSNR: 21.876768112182617


 32%|████████████                          | 3163/10000 [08:51<19:18,  5.90it/s]

Shuffle data after an epoch!


 32%|████████████▏                         | 3201/10000 [08:58<20:41,  5.48it/s]

[TRAIN] Iter: 3200 Loss: 0.012911895290017128  PSNR: 22.12264633178711


 33%|████████████▌                         | 3301/10000 [09:15<18:36,  6.00it/s]

[TRAIN] Iter: 3300 Loss: 0.014125118032097816  PSNR: 21.56333351135254


 34%|████████████▉                         | 3401/10000 [09:32<19:53,  5.53it/s]

[TRAIN] Iter: 3400 Loss: 0.012072659097611904  PSNR: 22.177509307861328


 35%|█████████████▎                        | 3501/10000 [09:49<17:53,  6.05it/s]

[TRAIN] Iter: 3500 Loss: 0.011682344600558281  PSNR: 22.181602478027344


 36%|█████████████▋                        | 3601/10000 [10:06<19:03,  5.60it/s]

[TRAIN] Iter: 3600 Loss: 0.012595503591001034  PSNR: 21.962404251098633


 37%|██████████████                        | 3701/10000 [10:24<18:51,  5.57it/s]

[TRAIN] Iter: 3700 Loss: 0.01421038992702961  PSNR: 21.519792556762695


 38%|██████████████▍                       | 3801/10000 [10:41<18:46,  5.50it/s]

[TRAIN] Iter: 3800 Loss: 0.010580244474112988  PSNR: 22.909259796142578


 39%|██████████████▊                       | 3901/10000 [11:00<18:16,  5.56it/s]

[TRAIN] Iter: 3900 Loss: 0.012613527476787567  PSNR: 22.050128936767578


 40%|███████████████▏                      | 4001/10000 [11:17<17:54,  5.58it/s]

[TRAIN] Iter: 4000 Loss: 0.01221007201820612  PSNR: 22.1895751953125


 41%|███████████████▌                      | 4101/10000 [11:36<17:50,  5.51it/s]

[TRAIN] Iter: 4100 Loss: 0.013459865003824234  PSNR: 21.900571823120117


 42%|███████████████▉                      | 4201/10000 [11:53<17:28,  5.53it/s]

[TRAIN] Iter: 4200 Loss: 0.01218983344733715  PSNR: 22.132488250732422


 43%|████████████████▎                     | 4301/10000 [12:11<16:50,  5.64it/s]

[TRAIN] Iter: 4300 Loss: 0.011006811633706093  PSNR: 22.481454849243164


 44%|████████████████▋                     | 4401/10000 [12:29<16:24,  5.69it/s]

[TRAIN] Iter: 4400 Loss: 0.01203480176627636  PSNR: 22.33251190185547


 45%|█████████████████                     | 4501/10000 [12:46<15:32,  5.89it/s]

[TRAIN] Iter: 4500 Loss: 0.011568516492843628  PSNR: 22.30573844909668


 46%|█████████████████▍                    | 4601/10000 [13:04<15:19,  5.87it/s]

[TRAIN] Iter: 4600 Loss: 0.011076179333031178  PSNR: 22.60428237915039


 47%|█████████████████▊                    | 4701/10000 [13:23<15:55,  5.54it/s]

[TRAIN] Iter: 4700 Loss: 0.012651467695832253  PSNR: 22.08138084411621


 48%|██████████████████▏                   | 4801/10000 [13:40<14:54,  5.81it/s]

[TRAIN] Iter: 4800 Loss: 0.01045955065637827  PSNR: 22.89117431640625


 49%|██████████████████▌                   | 4901/10000 [13:58<16:15,  5.22it/s]

[TRAIN] Iter: 4900 Loss: 0.010553423315286636  PSNR: 22.80103874206543


 50%|██████████████████▉                   | 4999/10000 [14:15<14:45,  5.64it/s]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


0 0.0029077529907226562



  1%|▎                                          | 1/120 [00:12<25:27, 12.84s/it][A

torch.Size([378, 504, 3]) torch.Size([378, 504])
1 12.839423179626465



  2%|▋                                          | 2/120 [00:25<24:36, 12.51s/it][A

2 12.280085802078247



  2%|█                                          | 3/120 [00:37<24:37, 12.62s/it][A

3 12.761374711990356



  3%|█▍                                         | 4/120 [00:50<24:36, 12.72s/it][A

4 12.876814603805542



  4%|█▊                                         | 5/120 [01:03<24:29, 12.77s/it][A

5 12.862267971038818



  5%|██▏                                        | 6/120 [01:16<24:19, 12.81s/it][A

6 12.866382598876953



  6%|██▌                                        | 7/120 [01:28<23:51, 12.67s/it][A

7 12.387670755386353



  7%|██▊                                        | 8/120 [01:41<23:36, 12.64s/it][A

8 12.593445777893066



  8%|███▏                                       | 9/120 [01:53<22:58, 12.42s/it][A

9 11.92132306098938



  8%|███▌                                      | 10/120 [02:06<23:07, 12.62s/it][A

10 13.059399366378784



  9%|███▊                                      | 11/120 [02:18<22:51, 12.58s/it][A

11 12.493321418762207



 10%|████▏                                     | 12/120 [02:31<22:40, 12.60s/it][A

12 12.654136657714844



 11%|████▌                                     | 13/120 [02:44<22:25, 12.58s/it][A

13 12.515276670455933



 12%|████▉                                     | 14/120 [02:56<22:06, 12.52s/it][A

14 12.376149654388428



 12%|█████▎                                    | 15/120 [03:08<21:46, 12.44s/it][A

15 12.259974241256714



 13%|█████▌                                    | 16/120 [03:21<21:33, 12.44s/it][A

16 12.430172681808472



 14%|█████▉                                    | 17/120 [03:33<21:17, 12.41s/it][A

17 12.333553075790405



 15%|██████▎                                   | 18/120 [03:46<21:21, 12.57s/it][A

18 12.9485924243927



 16%|██████▋                                   | 19/120 [03:58<20:59, 12.47s/it][A

19 12.232546091079712



 17%|███████                                   | 20/120 [04:11<20:53, 12.54s/it][A

20 12.702680110931396



 18%|███████▎                                  | 21/120 [04:23<20:34, 12.47s/it][A

21 12.303544282913208



 18%|███████▋                                  | 22/120 [04:35<20:16, 12.42s/it][A

22 12.297260522842407



 19%|████████                                  | 23/120 [04:48<20:17, 12.55s/it][A

23 12.853339195251465



 20%|████████▍                                 | 24/120 [05:00<19:47, 12.37s/it][A

24 11.940113544464111



 21%|████████▊                                 | 25/120 [05:13<19:58, 12.61s/it][A

25 13.19567322731018



 22%|█████████                                 | 26/120 [05:26<19:56, 12.73s/it][A

26 12.997981548309326



 22%|█████████▍                                | 27/120 [05:39<19:44, 12.73s/it][A

27 12.746290683746338



 23%|█████████▊                                | 28/120 [05:52<19:34, 12.77s/it][A

28 12.855650186538696



 24%|██████████▏                               | 29/120 [06:05<19:24, 12.80s/it][A

29 12.853298425674438



 25%|██████████▌                               | 30/120 [06:17<19:04, 12.71s/it][A

30 12.524933338165283



 26%|██████████▊                               | 31/120 [06:31<19:02, 12.84s/it][A

31 13.130508422851562



 27%|███████████▏                              | 32/120 [06:43<18:41, 12.74s/it][A

32 12.511590003967285



 28%|███████████▌                              | 33/120 [06:56<18:33, 12.80s/it][A

33 12.935600280761719



 28%|███████████▉                              | 34/120 [07:09<18:28, 12.89s/it][A

34 13.089622735977173



 29%|████████████▎                             | 35/120 [07:22<18:05, 12.77s/it][A

35 12.490678548812866



 30%|████████████▌                             | 36/120 [07:34<17:54, 12.79s/it][A

36 12.857465505599976



 31%|████████████▉                             | 37/120 [07:47<17:39, 12.76s/it][A

37 12.681424140930176



 32%|█████████████▎                            | 38/120 [07:59<17:15, 12.63s/it][A

38 12.320218324661255



 32%|█████████████▋                            | 39/120 [08:12<17:07, 12.68s/it][A

39 12.812465190887451



 33%|██████████████                            | 40/120 [08:25<16:55, 12.69s/it][A

40 12.708309888839722



 34%|██████████████▎                           | 41/120 [08:38<16:42, 12.69s/it][A

41 12.683748245239258



 35%|██████████████▋                           | 42/120 [08:50<16:27, 12.67s/it][A

42 12.612451791763306



 36%|███████████████                           | 43/120 [09:03<16:19, 12.71s/it][A

43 12.827600717544556



 37%|███████████████▍                          | 44/120 [09:15<15:58, 12.61s/it][A

44 12.36441707611084



 38%|███████████████▊                          | 45/120 [09:29<15:56, 12.75s/it][A

45 13.089098691940308



 38%|████████████████                          | 46/120 [09:41<15:42, 12.74s/it][A

46 12.705365180969238



 39%|████████████████▍                         | 47/120 [09:54<15:29, 12.73s/it][A

47 12.7131667137146



 40%|████████████████▊                         | 48/120 [10:07<15:14, 12.70s/it][A

48 12.634749412536621



 41%|█████████████████▏                        | 49/120 [10:19<14:59, 12.66s/it][A

49 12.569421529769897



 42%|█████████████████▌                        | 50/120 [10:32<14:46, 12.66s/it][A

50 12.66685438156128



 42%|█████████████████▊                        | 51/120 [10:45<14:34, 12.68s/it][A

51 12.719392538070679



 43%|██████████████████▏                       | 52/120 [10:57<14:21, 12.67s/it][A

52 12.632096767425537



 44%|██████████████████▌                       | 53/120 [11:10<14:08, 12.67s/it][A

53 12.666780948638916



 45%|██████████████████▉                       | 54/120 [11:23<13:58, 12.71s/it][A

54 12.795579195022583



 46%|███████████████████▎                      | 55/120 [11:35<13:32, 12.51s/it][A

55 12.045909643173218



 47%|███████████████████▌                      | 56/120 [11:48<13:32, 12.70s/it][A

56 13.135114669799805



 48%|███████████████████▉                      | 57/120 [12:00<13:13, 12.60s/it][A

57 12.375718593597412



 48%|████████████████████▎                     | 58/120 [12:13<13:01, 12.60s/it][A

58 12.59277081489563



 49%|████████████████████▋                     | 59/120 [12:25<12:43, 12.52s/it][A

59 12.336529731750488



 50%|█████████████████████                     | 60/120 [12:38<12:32, 12.55s/it][A

60 12.616849184036255



 51%|█████████████████████▎                    | 61/120 [12:50<12:12, 12.41s/it][A

61 12.084375143051147



 52%|█████████████████████▋                    | 62/120 [13:03<12:07, 12.54s/it][A

62 12.833566427230835



 52%|██████████████████████                    | 63/120 [13:15<11:51, 12.48s/it][A

63 12.335754156112671



 53%|██████████████████████▍                   | 64/120 [13:28<11:39, 12.49s/it][A

64 12.532201528549194



 54%|██████████████████████▊                   | 65/120 [13:40<11:28, 12.51s/it][A

65 12.560791492462158



 55%|███████████████████████                   | 66/120 [13:53<11:13, 12.48s/it][A

66 12.405386686325073



 56%|███████████████████████▍                  | 67/120 [14:05<11:02, 12.50s/it][A

67 12.548506736755371



 57%|███████████████████████▊                  | 68/120 [14:18<10:53, 12.57s/it][A

68 12.716951131820679



 57%|████████████████████████▏                 | 69/120 [14:30<10:41, 12.58s/it][A

69 12.616660594940186



 58%|████████████████████████▌                 | 70/120 [14:43<10:25, 12.51s/it][A

70 12.35849404335022



 59%|████████████████████████▊                 | 71/120 [14:55<10:15, 12.55s/it][A

71 12.641762971878052



 60%|█████████████████████████▏                | 72/120 [15:08<10:01, 12.53s/it][A

72 12.493408679962158



 61%|█████████████████████████▌                | 73/120 [15:20<09:46, 12.47s/it][A

73 12.32627272605896



 62%|█████████████████████████▉                | 74/120 [15:33<09:39, 12.59s/it][A

74 12.860302209854126



 62%|██████████████████████████▎               | 75/120 [15:45<09:22, 12.50s/it][A

75 12.287018060684204



 63%|██████████████████████████▌               | 76/120 [15:58<09:16, 12.64s/it][A

76 12.968978643417358



 64%|██████████████████████████▉               | 77/120 [16:11<09:03, 12.65s/it][A

77 12.659067392349243



 65%|███████████████████████████▎              | 78/120 [16:23<08:44, 12.48s/it][A

78 12.086493968963623



 66%|███████████████████████████▋              | 79/120 [16:36<08:38, 12.65s/it][A

79 13.055558443069458



 67%|████████████████████████████              | 80/120 [16:48<08:22, 12.55s/it][A

80 12.318454265594482



 68%|████████████████████████████▎             | 81/120 [17:01<08:13, 12.64s/it][A

81 12.858228206634521



 68%|████████████████████████████▋             | 82/120 [17:13<07:54, 12.49s/it][A

82 12.137702703475952



 69%|█████████████████████████████             | 83/120 [17:26<07:47, 12.63s/it][A

83 12.943719863891602



 70%|█████████████████████████████▍            | 84/120 [17:39<07:35, 12.64s/it][A

84 12.671106576919556



 71%|█████████████████████████████▊            | 85/120 [17:51<07:16, 12.47s/it][A

85 12.07816743850708



 72%|██████████████████████████████            | 86/120 [18:04<07:10, 12.67s/it][A

86 13.149263858795166



 72%|██████████████████████████████▍           | 87/120 [18:17<06:56, 12.62s/it][A

87 12.47725796699524



 73%|██████████████████████████████▊           | 88/120 [18:29<06:42, 12.57s/it][A

88 12.461882591247559



 74%|███████████████████████████████▏          | 89/120 [18:42<06:32, 12.67s/it][A

89 12.89389681816101



 75%|███████████████████████████████▌          | 90/120 [18:54<06:16, 12.55s/it][A

90 12.268596172332764



 76%|███████████████████████████████▊          | 91/120 [19:07<06:06, 12.65s/it][A

91 12.89754843711853



 77%|████████████████████████████████▏         | 92/120 [19:20<05:51, 12.56s/it][A

92 12.3559091091156



 78%|████████████████████████████████▌         | 93/120 [19:33<05:41, 12.66s/it][A

93 12.869072437286377



 78%|████████████████████████████████▉         | 94/120 [19:45<05:27, 12.61s/it][A

94 12.509981870651245



 79%|█████████████████████████████████▎        | 95/120 [19:58<05:16, 12.66s/it][A

95 12.776466131210327



 80%|█████████████████████████████████▌        | 96/120 [20:10<05:00, 12.51s/it][A

96 12.171952962875366



 81%|█████████████████████████████████▉        | 97/120 [20:23<04:50, 12.62s/it][A

97 12.86966848373413



 82%|██████████████████████████████████▎       | 98/120 [20:36<04:38, 12.67s/it][A

98 12.782849073410034



 82%|██████████████████████████████████▋       | 99/120 [20:48<04:25, 12.63s/it][A

99 12.547441482543945



 83%|██████████████████████████████████▏      | 100/120 [21:01<04:11, 12.55s/it][A

100 12.356653451919556



 84%|██████████████████████████████████▌      | 101/120 [21:14<04:00, 12.67s/it][A

101 12.953816890716553



 85%|██████████████████████████████████▊      | 102/120 [21:27<03:50, 12.80s/it][A

102 13.096081972122192



 86%|███████████████████████████████████▏     | 103/120 [21:39<03:36, 12.71s/it][A

103 12.489439249038696



 87%|███████████████████████████████████▌     | 104/120 [21:52<03:22, 12.65s/it][A

104 12.517631769180298



 88%|███████████████████████████████████▉     | 105/120 [22:04<03:10, 12.68s/it][A

105 12.747545957565308



 88%|████████████████████████████████████▏    | 106/120 [22:17<02:58, 12.74s/it][A

106 12.872883796691895



 89%|████████████████████████████████████▌    | 107/120 [22:30<02:44, 12.67s/it][A

107 12.497098207473755



 90%|████████████████████████████████████▉    | 108/120 [22:42<02:31, 12.67s/it][A

108 12.664837837219238



 91%|█████████████████████████████████████▏   | 109/120 [22:55<02:17, 12.50s/it][A

109 12.10818886756897



 92%|█████████████████████████████████████▌   | 110/120 [23:08<02:06, 12.67s/it][A

110 13.067282676696777



 92%|█████████████████████████████████████▉   | 111/120 [23:20<01:53, 12.64s/it][A

111 12.586446046829224



 93%|██████████████████████████████████████▎  | 112/120 [23:33<01:40, 12.62s/it][A

112 12.579440832138062



 94%|██████████████████████████████████████▌  | 113/120 [23:45<01:27, 12.53s/it][A

113 12.307660818099976



 95%|██████████████████████████████████████▉  | 114/120 [23:58<01:15, 12.59s/it][A

114 12.722549676895142



 96%|███████████████████████████████████████▎ | 115/120 [24:10<01:03, 12.62s/it][A

115 12.68341326713562



 97%|███████████████████████████████████████▋ | 116/120 [24:23<00:50, 12.64s/it][A

116 12.693565130233765



 98%|███████████████████████████████████████▉ | 117/120 [24:36<00:37, 12.57s/it][A

117 12.41724967956543



 98%|████████████████████████████████████████▎| 118/120 [24:48<00:25, 12.63s/it][A

118 12.758829355239868



 99%|████████████████████████████████████████▋| 119/120 [25:01<00:12, 12.65s/it][A

119 12.683225631713867



100%|█████████████████████████████████████████| 120/120 [25:14<00:00, 12.62s/it][A


Done, saving (120, 378, 504, 3) (120, 378, 504)




test poses shape torch.Size([3, 3, 4])



  0%|                                                     | 0/3 [00:00<?, ?it/s][A

0 0.0024840831756591797
torch.Size([378, 504, 3]) torch.Size([378, 504])



 33%|███████████████                              | 1/3 [00:12<00:25, 12.84s/it][A

1 12.840915441513062



 67%|██████████████████████████████               | 2/3 [00:25<00:12, 12.91s/it][A

2 12.951102018356323



100%|█████████████████████████████████████████████| 3/3 [00:38<00:00, 12.83s/it][A
 50%|████████████████▌                | 5001/10000 [40:10<453:25:55, 326.54s/it]

Saved test set
[TRAIN] Iter: 5000 Loss: 0.012368401512503624  PSNR: 22.055004119873047


 51%|███████████████████▍                  | 5101/10000 [40:27<13:40,  5.97it/s]

[TRAIN] Iter: 5100 Loss: 0.011255143210291862  PSNR: 22.595870971679688


 52%|███████████████████▊                  | 5201/10000 [40:44<13:33,  5.90it/s]

[TRAIN] Iter: 5200 Loss: 0.01055305078625679  PSNR: 22.93203353881836


 53%|████████████████████▏                 | 5301/10000 [41:01<13:10,  5.94it/s]

[TRAIN] Iter: 5300 Loss: 0.010738449171185493  PSNR: 22.5752010345459


 54%|████████████████████▌                 | 5401/10000 [41:18<12:39,  6.05it/s]

[TRAIN] Iter: 5400 Loss: 0.010279020294547081  PSNR: 23.082746505737305


 55%|████████████████████▉                 | 5501/10000 [41:35<12:43,  5.89it/s]

[TRAIN] Iter: 5500 Loss: 0.010268410667777061  PSNR: 22.79636001586914


 56%|█████████████████████▎                | 5601/10000 [41:52<12:59,  5.64it/s]

[TRAIN] Iter: 5600 Loss: 0.01105324737727642  PSNR: 22.578779220581055


 57%|█████████████████████▋                | 5701/10000 [42:09<11:53,  6.03it/s]

[TRAIN] Iter: 5700 Loss: 0.009769359603524208  PSNR: 23.163131713867188


 58%|██████████████████████                | 5801/10000 [42:26<12:26,  5.62it/s]

[TRAIN] Iter: 5800 Loss: 0.010118504986166954  PSNR: 23.283203125


 59%|██████████████████████▍               | 5901/10000 [42:43<11:28,  5.95it/s]

[TRAIN] Iter: 5900 Loss: 0.010267267003655434  PSNR: 23.23287582397461


 60%|██████████████████████▊               | 6001/10000 [43:00<11:22,  5.86it/s]

[TRAIN] Iter: 6000 Loss: 0.010047484189271927  PSNR: 23.26374053955078


 61%|███████████████████████▏              | 6101/10000 [43:17<10:42,  6.07it/s]

[TRAIN] Iter: 6100 Loss: 0.009225692600011826  PSNR: 23.507343292236328


 62%|███████████████████████▌              | 6201/10000 [43:34<10:52,  5.83it/s]

[TRAIN] Iter: 6200 Loss: 0.010416384786367416  PSNR: 23.054367065429688


 63%|███████████████████████▉              | 6301/10000 [43:51<10:09,  6.07it/s]

[TRAIN] Iter: 6300 Loss: 0.011058477684855461  PSNR: 22.705442428588867


 63%|████████████████████████              | 6326/10000 [43:55<10:24,  5.88it/s]

Shuffle data after an epoch!


 64%|████████████████████████▎             | 6401/10000 [44:07<10:11,  5.88it/s]

[TRAIN] Iter: 6400 Loss: 0.011600823141634464  PSNR: 22.529869079589844


 65%|████████████████████████▋             | 6501/10000 [44:24<10:15,  5.68it/s]

[TRAIN] Iter: 6500 Loss: 0.010250515304505825  PSNR: 22.883777618408203


 66%|█████████████████████████             | 6601/10000 [44:41<09:33,  5.92it/s]

[TRAIN] Iter: 6600 Loss: 0.009660433977842331  PSNR: 23.39360809326172


 67%|█████████████████████████▍            | 6701/10000 [44:58<09:04,  6.05it/s]

[TRAIN] Iter: 6700 Loss: 0.010124973952770233  PSNR: 23.11298179626465


 68%|█████████████████████████▊            | 6801/10000 [45:15<08:42,  6.12it/s]

[TRAIN] Iter: 6800 Loss: 0.011566266417503357  PSNR: 22.47652816772461


 69%|██████████████████████████▏           | 6901/10000 [45:32<09:00,  5.74it/s]

[TRAIN] Iter: 6900 Loss: 0.011742553673684597  PSNR: 22.48265838623047


 70%|██████████████████████████▌           | 7001/10000 [45:49<08:13,  6.08it/s]

[TRAIN] Iter: 7000 Loss: 0.009661424905061722  PSNR: 23.233057022094727


 71%|██████████████████████████▉           | 7101/10000 [46:06<08:22,  5.76it/s]

[TRAIN] Iter: 7100 Loss: 0.009099001064896584  PSNR: 23.776704788208008


 72%|███████████████████████████▎          | 7201/10000 [46:23<07:29,  6.23it/s]

[TRAIN] Iter: 7200 Loss: 0.010432284325361252  PSNR: 23.014881134033203


 73%|███████████████████████████▋          | 7301/10000 [46:40<08:03,  5.59it/s]

[TRAIN] Iter: 7300 Loss: 0.009828804060816765  PSNR: 23.222047805786133


 74%|████████████████████████████          | 7401/10000 [46:57<07:49,  5.53it/s]

[TRAIN] Iter: 7400 Loss: 0.010899966582655907  PSNR: 22.806350708007812


 75%|████████████████████████████▌         | 7501/10000 [47:13<06:41,  6.23it/s]

[TRAIN] Iter: 7500 Loss: 0.007836705073714256  PSNR: 24.11112403869629


 76%|████████████████████████████▉         | 7601/10000 [47:30<06:49,  5.85it/s]

[TRAIN] Iter: 7600 Loss: 0.008381135761737823  PSNR: 23.789772033691406


 77%|█████████████████████████████▎        | 7701/10000 [47:47<06:15,  6.13it/s]

[TRAIN] Iter: 7700 Loss: 0.009790447540581226  PSNR: 23.186330795288086


 78%|█████████████████████████████▋        | 7801/10000 [48:04<06:24,  5.72it/s]

[TRAIN] Iter: 7800 Loss: 0.011077772825956345  PSNR: 22.65816879272461


 79%|██████████████████████████████        | 7901/10000 [48:21<06:23,  5.48it/s]

[TRAIN] Iter: 7900 Loss: 0.010064703412353992  PSNR: 22.92862892150879


 80%|██████████████████████████████▍       | 8001/10000 [48:40<05:49,  5.72it/s]

[TRAIN] Iter: 8000 Loss: 0.00998847745358944  PSNR: 23.114898681640625


 81%|██████████████████████████████▊       | 8101/10000 [48:58<05:15,  6.02it/s]

[TRAIN] Iter: 8100 Loss: 0.010057785548269749  PSNR: 23.074663162231445


 82%|███████████████████████████████▏      | 8201/10000 [49:15<05:32,  5.41it/s]

[TRAIN] Iter: 8200 Loss: 0.008635971695184708  PSNR: 23.566207885742188


 83%|███████████████████████████████▌      | 8301/10000 [49:32<04:50,  5.85it/s]

[TRAIN] Iter: 8300 Loss: 0.010032523423433304  PSNR: 23.083147048950195


 84%|███████████████████████████████▉      | 8401/10000 [49:49<04:28,  5.95it/s]

[TRAIN] Iter: 8400 Loss: 0.009883929044008255  PSNR: 23.373659133911133


 85%|████████████████████████████████▎     | 8501/10000 [50:06<04:03,  6.15it/s]

[TRAIN] Iter: 8500 Loss: 0.010200629010796547  PSNR: 22.991281509399414


 86%|████████████████████████████████▋     | 8601/10000 [50:23<04:11,  5.56it/s]

[TRAIN] Iter: 8600 Loss: 0.009019998833537102  PSNR: 23.6365909576416


 87%|█████████████████████████████████     | 8701/10000 [50:40<03:38,  5.95it/s]

[TRAIN] Iter: 8700 Loss: 0.00982271134853363  PSNR: 23.272703170776367


 88%|█████████████████████████████████▍    | 8801/10000 [50:57<03:15,  6.12it/s]

[TRAIN] Iter: 8800 Loss: 0.009359730407595634  PSNR: 23.372072219848633


 89%|█████████████████████████████████▊    | 8901/10000 [51:13<03:01,  6.06it/s]

[TRAIN] Iter: 8900 Loss: 0.008702588267624378  PSNR: 23.927644729614258


 90%|██████████████████████████████████▏   | 9001/10000 [51:31<02:53,  5.77it/s]

[TRAIN] Iter: 9000 Loss: 0.009388039819896221  PSNR: 23.335432052612305


 91%|██████████████████████████████████▌   | 9101/10000 [51:47<02:26,  6.12it/s]

[TRAIN] Iter: 9100 Loss: 0.009682826697826385  PSNR: 23.251266479492188


 92%|██████████████████████████████████▉   | 9201/10000 [52:05<02:18,  5.76it/s]

[TRAIN] Iter: 9200 Loss: 0.008599881082773209  PSNR: 23.770185470581055


 93%|███████████████████████████████████▎  | 9301/10000 [52:21<01:56,  6.02it/s]

[TRAIN] Iter: 9300 Loss: 0.008780429139733315  PSNR: 23.85822868347168


 94%|███████████████████████████████████▋  | 9401/10000 [52:38<01:37,  6.14it/s]

[TRAIN] Iter: 9400 Loss: 0.009809847921133041  PSNR: 23.394330978393555


 95%|████████████████████████████████████  | 9489/10000 [52:53<01:22,  6.22it/s]

Shuffle data after an epoch!


 95%|████████████████████████████████████  | 9501/10000 [52:55<01:20,  6.17it/s]

[TRAIN] Iter: 9500 Loss: 0.009415103122591972  PSNR: 23.45441246032715


 96%|████████████████████████████████████▍ | 9601/10000 [53:12<01:09,  5.71it/s]

[TRAIN] Iter: 9600 Loss: 0.008371840231120586  PSNR: 23.89464569091797


 97%|████████████████████████████████████▊ | 9701/10000 [53:29<00:50,  5.98it/s]

[TRAIN] Iter: 9700 Loss: 0.009737253189086914  PSNR: 23.08979606628418


 98%|█████████████████████████████████████▏| 9801/10000 [53:46<00:34,  5.79it/s]

[TRAIN] Iter: 9800 Loss: 0.009470894932746887  PSNR: 23.491466522216797


 99%|█████████████████████████████████████▌| 9901/10000 [54:03<00:19,  5.08it/s]

[TRAIN] Iter: 9900 Loss: 0.00947138387709856  PSNR: 23.256811141967773


100%|█████████████████████████████████████▉| 9999/10000 [54:20<00:00,  5.89it/s]

Saved checkpoints at ./logs/fern_test/010000.tar



  0%|                                                   | 0/120 [00:00<?, ?it/s][A

0 0.0036568641662597656



  1%|▎                                          | 1/120 [00:12<25:38, 12.93s/it][A

torch.Size([378, 504, 3]) torch.Size([378, 504])
1 12.929940223693848



  2%|▋                                          | 2/120 [00:25<24:58, 12.70s/it][A

2 12.542730569839478



  2%|█                                          | 3/120 [00:38<24:37, 12.63s/it][A

3 12.542699575424194



  3%|█▍                                         | 4/120 [00:50<24:27, 12.65s/it][A

4 12.690171718597412



  4%|█▊                                         | 5/120 [01:02<23:51, 12.45s/it][A

5 12.083553314208984



  5%|██▏                                        | 6/120 [01:15<23:55, 12.59s/it][A

6 12.873130798339844



  6%|██▌                                        | 7/120 [01:28<23:45, 12.61s/it][A

7 12.648749589920044



  7%|██▊                                        | 8/120 [01:40<23:15, 12.46s/it][A

8 12.137655019760132



  8%|███▏                                       | 9/120 [01:53<23:21, 12.62s/it][A

9 12.98322582244873



  8%|███▌                                      | 10/120 [02:05<22:47, 12.43s/it][A

10 12.010024070739746



  9%|███▊                                      | 11/120 [02:18<22:57, 12.64s/it][A

11 13.111683368682861



 10%|████▏                                     | 12/120 [02:30<22:38, 12.57s/it][A

12 12.421513557434082



 11%|████▌                                     | 13/120 [02:43<22:33, 12.65s/it][A

13 12.820330381393433



 12%|████▉                                     | 14/120 [02:57<22:44, 12.87s/it][A

14 13.388191938400269



 12%|█████▎                                    | 15/120 [03:09<22:27, 12.84s/it][A

15 12.75386929512024



 13%|█████▌                                    | 16/120 [03:22<22:19, 12.88s/it][A

16 12.96521806716919



 14%|█████▉                                    | 17/120 [03:36<22:21, 13.03s/it][A

17 13.377270460128784



 15%|██████▎                                   | 18/120 [03:50<22:33, 13.26s/it][A

18 13.820388555526733



 16%|██████▋                                   | 19/120 [04:03<22:16, 13.24s/it][A

19 13.169184684753418



 17%|███████                                   | 20/120 [04:17<22:26, 13.46s/it][A

20 13.994834423065186



 18%|███████▎                                  | 21/120 [04:31<22:25, 13.59s/it][A

21 13.898075580596924



 18%|███████▋                                  | 22/120 [04:43<21:35, 13.22s/it][A

22 12.336710214614868



 19%|████████                                  | 23/120 [04:56<21:15, 13.15s/it][A

23 12.994619846343994



 20%|████████▍                                 | 24/120 [05:09<20:58, 13.11s/it][A

24 13.012344598770142



 21%|████████▊                                 | 25/120 [05:22<20:30, 12.95s/it][A

25 12.573134422302246



 22%|█████████                                 | 26/120 [05:34<20:08, 12.86s/it][A

26 12.6465482711792



 22%|█████████▍                                | 27/120 [05:47<19:55, 12.85s/it][A

27 12.840025424957275



 23%|█████████▊                                | 28/120 [06:00<19:45, 12.89s/it][A

28 12.981485605239868



 24%|██████████▏                               | 29/120 [06:13<19:33, 12.89s/it][A

29 12.88927960395813



 25%|██████████▌                               | 30/120 [06:26<19:20, 12.89s/it][A

30 12.905054807662964



 26%|██████████▊                               | 31/120 [06:39<19:08, 12.90s/it][A

31 12.90826964378357



 27%|███████████▏                              | 32/120 [06:51<18:43, 12.77s/it][A

32 12.453830480575562



 28%|███████████▌                              | 33/120 [07:04<18:35, 12.82s/it][A

33 12.945426225662231



 28%|███████████▉                              | 34/120 [07:16<18:10, 12.68s/it][A

34 12.339314699172974



 29%|████████████▎                             | 35/120 [07:29<18:01, 12.72s/it][A

35 12.832141876220703



 30%|████████████▌                             | 36/120 [07:42<17:45, 12.68s/it][A

36 12.587762355804443



 31%|████████████▉                             | 37/120 [07:55<17:49, 12.88s/it][A

37 13.346715211868286



 32%|█████████████▎                            | 38/120 [08:08<17:41, 12.94s/it][A

38 13.091004371643066



 32%|█████████████▋                            | 39/120 [08:21<17:32, 12.99s/it][A

39 13.111460208892822



 33%|██████████████                            | 40/120 [08:34<17:20, 13.01s/it][A

40 13.032986640930176



 34%|██████████████▎                           | 41/120 [08:47<17:05, 12.98s/it][A

41 12.912123918533325



 35%|██████████████▋                           | 42/120 [09:00<16:48, 12.93s/it][A

42 12.825309991836548



 36%|███████████████                           | 43/120 [09:13<16:36, 12.95s/it][A

43 12.984002113342285



 37%|███████████████▍                          | 44/120 [09:26<16:27, 12.99s/it][A

44 13.099226236343384



 38%|███████████████▊                          | 45/120 [09:39<16:18, 13.05s/it][A

45 13.176846504211426



 38%|████████████████                          | 46/120 [09:53<16:06, 13.06s/it][A

46 13.081445932388306



 39%|████████████████▍                         | 47/120 [10:05<15:47, 12.98s/it][A

47 12.784143447875977



 40%|████████████████▊                         | 48/120 [10:18<15:32, 12.95s/it][A

48 12.8732750415802



 41%|█████████████████▏                        | 49/120 [10:31<15:18, 12.94s/it][A

49 12.918467998504639



 42%|█████████████████▌                        | 50/120 [10:44<15:12, 13.04s/it][A

50 13.272547483444214



 42%|█████████████████▊                        | 51/120 [10:58<15:04, 13.11s/it][A

51 13.270893812179565



 43%|██████████████████▏                       | 52/120 [11:11<15:00, 13.25s/it][A

52 13.56804370880127



 44%|██████████████████▌                       | 53/120 [11:24<14:43, 13.19s/it][A

53 13.068780660629272



 45%|██████████████████▉                       | 54/120 [11:38<14:31, 13.20s/it][A

54 13.226862668991089



 46%|███████████████████▎                      | 55/120 [11:50<14:04, 12.99s/it][A

55 12.504000663757324



 47%|███████████████████▌                      | 56/120 [12:03<13:54, 13.03s/it][A

56 13.120784044265747



 48%|███████████████████▉                      | 57/120 [12:16<13:42, 13.05s/it][A

57 13.086625576019287



 48%|████████████████████▎                     | 58/120 [12:29<13:19, 12.90s/it][A

58 12.556792974472046



 49%|████████████████████▋                     | 59/120 [12:42<13:03, 12.84s/it][A

59 12.688363313674927



 50%|█████████████████████                     | 60/120 [12:54<12:47, 12.79s/it][A

60 12.686896562576294



 51%|█████████████████████▎                    | 61/120 [13:07<12:40, 12.89s/it][A

61 13.119068622589111



 52%|█████████████████████▋                    | 62/120 [13:20<12:24, 12.84s/it][A

62 12.719262599945068



 52%|██████████████████████                    | 63/120 [13:33<12:11, 12.83s/it][A

63 12.817808389663696



 53%|██████████████████████▍                   | 64/120 [13:46<12:02, 12.89s/it][A

64 13.036176443099976



 54%|██████████████████████▊                   | 65/120 [13:58<11:41, 12.76s/it][A

65 12.444355487823486



 55%|███████████████████████                   | 66/120 [14:12<11:39, 12.96s/it][A

66 13.435872793197632



 56%|███████████████████████▍                  | 67/120 [14:26<11:39, 13.20s/it][A

67 13.746738195419312



 57%|███████████████████████▊                  | 68/120 [14:39<11:33, 13.33s/it][A

68 13.643450021743774



 57%|████████████████████████▏                 | 69/120 [14:53<11:27, 13.48s/it][A

69 13.814877271652222



 58%|████████████████████████▌                 | 70/120 [15:06<11:12, 13.45s/it][A

70 13.380188465118408



 59%|████████████████████████▊                 | 71/120 [15:20<11:06, 13.60s/it][A

71 13.963529586791992



 60%|█████████████████████████▏                | 72/120 [15:33<10:44, 13.43s/it][A

72 13.037840366363525
