In [None]:
!pip install --upgrade --target=temp_pip git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch

In [None]:
!pip install --upgrade --target=temp_pip commentjson

In [None]:
import cv2
from math import log10, sqrt
from skimage.metrics import structural_similarity as ssim

def PSNR(original, compressed):
    """Peak signal-to-noise ratio, in dB, for images whose peak value is 1.0.

    Args:
        original: Reference image (array-like).
        compressed: Reconstructed image, broadcast-compatible with ``original``.

    Returns:
        ``20 * log10(1.0 / sqrt(mse))``, or the sentinel ``100`` when the two
        inputs are identical (the true value would be infinite).
    """
    # Work in float64: subtracting unsigned-integer arrays wraps around instead
    # of going negative, and low-precision floats lose accuracy when squared.
    original = np.asarray(original, dtype=np.float64)
    compressed = np.asarray(compressed, dtype=np.float64)

    mse = np.mean((original - compressed) ** 2)
    if mse == 0:
        return 100
    max_pixel = 1.0
    return 20 * log10(max_pixel / sqrt(mse))

def SSIM(original, compressed, data_range=1.0):
    """Structural similarity between two channel-last images.

    Args:
        original: Reference image with channels on the last axis.
        compressed: Reconstructed image of the same shape.
        data_range: Distance between the smallest and largest possible pixel
            value. Defaults to 1.0, the same peak value ``PSNR`` assumes.

    Returns:
        The mean structural similarity index reported by scikit-image.
    """
    # data_range must be given explicitly for floating-point images: depending
    # on the scikit-image version it is otherwise guessed from the dtype
    # (as 2, i.e. [-1, 1]) or the call is rejected outright.
    return ssim(original, compressed, channel_axis=-1, data_range=data_range)

In [None]:
import sys

# Search the pip --target directory first so the locally installed
# tinycudann extension can be imported.
sys.path.insert(0, "/global/u1/j/jswomley/tiny-cuda-nn/temp_pip")

try:
    import tinycudann as tcnn
except ImportError:
    # Print the install hint and stop.
    print("\n".join([
        "This sample requires the tiny-cuda-nn extension for PyTorch.",
        "You can install it by running:",
        "============================================================",
        "tiny-cuda-nn$ cd bindings/torch",
        "tiny-cuda-nn/bindings/torch$ python setup.py install",
        "============================================================",
    ]))
    sys.exit()

# Directory of tiny-cuda-nn's helper scripts (presumably where the `common`
# module imported in a later cell lives). Hard-coded for this notebook; the
# script form would be:
# os.path.join(os.path.dirname(os.path.dirname(__file__)), "scripts")
SCRIPTS_DIR = "/global/u1/j/jswomley/tiny-cuda-nn/scripts"
sys.path.insert(0, SCRIPTS_DIR)
print(sys.path)
    
import argparse
import commentjson as json
import numpy as np
import os
import sys
import torch
import time

In [None]:
from common import read_image, write_image, ROOT_DIR

In [None]:
class original_Image(torch.nn.Module):
    """Samples a whole image at normalised coordinates with bilinear filtering.

    ``read_image(filename)`` is expected to yield an array indexed as
    (row, column, channel); it is kept on ``device`` as a float tensor.
    """

    def __init__(self, filename, device):
        super().__init__()
        pixels = read_image(filename)
        self.shape = pixels.shape
        self.data = torch.from_numpy(pixels).float().to(device)

    def forward(self, xs):
        """Look up pixel values for ``xs``.

        Args:
            xs: Tensor of shape (N, 2) holding (x, y) positions, where 0..1
                spans the image width and height respectively.

        Returns:
            Tensor with one row of channel values per input position.
        """
        n_rows, n_cols = self.shape[0], self.shape[1]

        with torch.no_grad():
            # Bilinearly filtered lookup from the image. Not super fast,
            # but less than ~20% of the overall runtime of this example.
            pixel_pos = xs * torch.tensor([n_cols, n_rows], device=xs.device).float()
            cell = pixel_pos.long()
            frac = pixel_pos - cell.float()
            wx = frac[:, 0:1]
            wy = frac[:, 1:2]
            inv_wx = 1.0 - wx
            inv_wy = 1.0 - wy

            # Corner indices of the surrounding 2x2 cell, kept inside the image.
            x0 = cell[:, 0].clamp(min=0, max=n_cols - 1)
            y0 = cell[:, 1].clamp(min=0, max=n_rows - 1)
            x1 = (x0 + 1).clamp(max=n_cols - 1)
            y1 = (y0 + 1).clamp(max=n_rows - 1)

            top_left = self.data[y0, x0] * inv_wx * inv_wy
            top_right = self.data[y0, x1] * wx * inv_wy
            bottom_left = self.data[y1, x0] * inv_wx * wy
            bottom_right = self.data[y1, x1] * wx * wy
            return top_left + top_right + bottom_left + bottom_right


class Image(torch.nn.Module):
    """Bilinear image sampler that only answers for a rectangular region.

    The image is addressed with coordinates normalised over the *whole* image;
    samples whose containing pixel lies outside the region of interest (ROI)
    are dropped from the result.

    Args:
        filename: Image file passed to ``read_image``; the result is expected
            to be indexable as (row, column, channel).
        device: Torch device that holds the pixel tensor.
        x_coord, y_coord: Top-left pixel of the ROI (counted from the image's
            top-left corner).
        width, height: ROI size in pixels.
    """

    def __init__(self, filename, device, x_coord, y_coord, width, height):
        super(Image, self).__init__()
        self.data = read_image(filename)
        self.shape = self.data.shape
        self.data = torch.from_numpy(self.data).float().to(device)
        self.x_coord = x_coord
        self.y_coord = y_coord
        self.width = width
        self.height = height

    def forward(self, xs):
        """Sample the image at ``xs`` and keep only the samples inside the ROI.

        Args:
            xs: Tensor of shape (N, 2) holding (x, y) positions, where 0..1
                spans the full image width and height respectively.

        Returns:
            ``(values, coords)``: the bilinearly interpolated pixel values of
            the samples that fall inside the ROI, and the matching rows of the
            unmodified input ``xs`` (same order as in the input).
        """
        with torch.no_grad():
            # Bilinearly filtered lookup from the image. Not super fast,
            # but less than ~20% of the overall runtime of this example.
            shape = self.shape

            coords = xs  # untouched normalised coordinates, returned to the caller
            xs = xs * torch.tensor([shape[1], shape[0]], device=xs.device).float()
            indices = xs.long()
            lerp_weights = xs - indices.float()
            wx = lerp_weights[:, 0:1]
            wy = lerp_weights[:, 1:2]

            x0 = indices[:, 0].clamp(min=0, max=shape[1]-1)
            y0 = indices[:, 1].clamp(min=0, max=shape[0]-1)
            x1 = (x0 + 1).clamp(max=shape[1]-1)
            y1 = (y0 + 1).clamp(max=shape[0]-1)

            out = (
                self.data[y0, x0] * (1.0 - wx) * (1.0 - wy) +
                self.data[y0, x1] * wx * (1.0 - wy) +
                self.data[y1, x0] * (1.0 - wx) * wy +
                self.data[y1, x1] * wx * wy
            )

            # ROI membership is decided by the pixel the sample falls in
            # (x0, y0). Testing the clamped neighbour x1/y1 against the upper
            # bound would also admit the image's last column/row whenever the
            # ROI ends exactly one pixel before the image edge.
            inside = (
                (x0 >= self.x_coord) & (x0 < self.x_coord + self.width)
                & (y0 >= self.y_coord) & (y0 < self.y_coord + self.height)
            )

            return (out[inside, :], coords[inside, :])

In [None]:
# Ask tiny-cuda-nn to release its temporary memory before any model is built
# (NOTE(review): what exactly gets freed is defined by tcnn - confirm).
tcnn.free_temporary_memory()

In [None]:
# # Real ROI script

# x_coord = 0; y_coord = 0; width=1000; height=1000

# class Args:
#     pass
# args = Args()
# args.image="/global/u1/j/jswomley/tiny-cuda-nn/data/images/hubble.jpg"
# args.image_roi="/global/u1/j/jswomley/tiny-cuda-nn/hubble_roi.jpg"
# args.config="/global/u1/j/jswomley/tiny-cuda-nn/config_hash.json"
# args.n_steps=10000
# args.result_filename="/global/u1/j/jswomley/tiny-cuda-nn/result.png"


# device = torch.device("cuda")

# with open(args.config) as config_file:
#     config = json.load(config_file)

# image = Image(args.image, device, x_coord, y_coord, width, height)
# n_channels = image.data.shape[2]

# model = tcnn.NetworkWithInputEncoding(n_input_dims=2, n_output_dims=n_channels, encoding_config=config["encoding"], network_config=config["network"]).to(device)
# # print(model)

# #===================================================================================================
# # The following is equivalent to the above, but slower. Only use "naked" tcnn.Encoding and
# # tcnn.Network when you don't want to combine them. Otherwise, use tcnn.NetworkWithInputEncoding.
# #===================================================================================================
# # encoding = tcnn.Encoding(n_input_dims=2, encoding_config=config["encoding"])
# # network = tcnn.Network(n_input_dims=encoding.n_output_dims, n_output_dims=n_channels, network_config=config["network"])
# # model = torch.nn.Sequential(encoding, network)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# # Variables for saving/displaying image results
# resolution = image.data.shape[0:2]

# cut_resolution = torch.Size([height, width])
# img_shape = cut_resolution + torch.Size([image.data.shape[2]])
# real_img_shape = resolution + torch.Size([image.data.shape[2]])

# n_pixels = resolution[0] * resolution[1]

# half_dx =  0.5 / resolution[0]
# half_dy =  0.5 / resolution[1]
# xs = torch.linspace(half_dx, 1-half_dx, resolution[0], device=device)
# ys = torch.linspace(half_dy, 1-half_dy, resolution[1], device=device)
# xv, yv = torch.meshgrid([xs, ys])

# xy = torch.stack((yv.flatten(), xv.flatten())).t()

# path = f"reference.jpg"
# print(f"Writing '{path}'... ", end="")
# write_image(path, image(xy)[0].reshape(img_shape).detach().cpu().numpy())
# print("done.")

# prev_time = time.perf_counter()

# batch_size = 2**18
# interval = 10

# print(f"Beginning optimization with {args.n_steps} training steps.")

# try:
#     batch = torch.rand([batch_size, 2], device=device, dtype=torch.float32)
#     traced_image = torch.jit.trace(image, batch)
# except:
#     # If tracing causes an error, fall back to regular execution
#     print(f"WARNING: PyTorch JIT trace failed. Performance will be slightly worse than regular.")
#     traced_image = image


# for i in range(args.n_steps):
#     batch = torch.rand([batch_size, 2], device=device, dtype=torch.float32)
#     targets = traced_image(batch)[0]
#     output = model(traced_image(batch)[1])

#     relative_l2_error = (output - targets.to(output.dtype))**2 / (output.detach()**2 + 0.01)
#     loss = relative_l2_error.mean()

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     if i % interval == 0:
#         loss_val = loss.item()
#         torch.cuda.synchronize()
#         elapsed_time = time.perf_counter() - prev_time
#         print(f"Step#{i}: loss={loss_val} time={int(elapsed_time*1000000)}[µs]")

#         path = f"{i}.jpg"
#         print(f"Writing '{path}'... ", end="")
#         with torch.no_grad():
#             write_image(path, model(image(xy)[1]).reshape(img_shape).clamp(0.0, 1.0).detach().cpu().numpy())
#         print("done.")

#         # Ignore the time spent saving the image
#         prev_time = time.perf_counter()

#         if i > 0 and interval < 1000:
#             interval *= 10

In [None]:
# #Original

# class Args:
#     pass
# args = Args()
# args.image="/global/u1/j/jswomley/tiny-cuda-nn/data/images/hubble.jpg"
# args.image_roi="/global/u1/j/jswomley/tiny-cuda-nn/hubble_roi.jpg"
# args.config="/global/u1/j/jswomley/tiny-cuda-nn/config_hash.json"
# args.n_steps=10000
# args.result_filename="/global/u1/j/jswomley/tiny-cuda-nn/result.png"

# device = torch.device("cuda")

# with open(args.config) as config_file:
#     config = json.load(config_file)

# image = Image(args.image, device)
# n_channels = image.data.shape[2]

# model = tcnn.NetworkWithInputEncoding(n_input_dims=2, n_output_dims=n_channels, encoding_config=config["encoding"], network_config=config["network"]).to(device)
# # print(model)

# #===================================================================================================
# # The following is equivalent to the above, but slower. Only use "naked" tcnn.Encoding and
# # tcnn.Network when you don't want to combine them. Otherwise, use tcnn.NetworkWithInputEncoding.
# #===================================================================================================
# # encoding = tcnn.Encoding(n_input_dims=2, encoding_config=config["encoding"])
# # network = tcnn.Network(n_input_dims=encoding.n_output_dims, n_output_dims=n_channels, network_config=config["network"])
# # model = torch.nn.Sequential(encoding, network)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# # Variables for saving/displaying image results
# resolution = image.data.shape[0:2]
# img_shape = resolution + torch.Size([image.data.shape[2]])
# n_pixels = resolution[0] * resolution[1]

# half_dx =  0.5 / resolution[0]
# half_dy =  0.5 / resolution[1]
# xs = torch.linspace(half_dx, 1-half_dx, resolution[0], device=device)
# ys = torch.linspace(half_dy, 1-half_dy, resolution[1], device=device)
# xv, yv = torch.meshgrid([xs, ys])

# xy = torch.stack((yv.flatten(), xv.flatten())).t()

# path = f"reference.jpg"
# print(f"Writing '{path}'... ", end="")
# write_image(path, image(xy).reshape(img_shape).detach().cpu().numpy())
# print("done.")

# prev_time = time.perf_counter()

# batch_size = 2**18
# interval = 10

# print(f"Beginning optimization with {args.n_steps} training steps.")

# try:
#     batch = torch.rand([batch_size, 2], device=device, dtype=torch.float32)
#     traced_image = torch.jit.trace(image, batch)
# except:
#     # If tracing causes an error, fall back to regular execution
#     print(f"WARNING: PyTorch JIT trace failed. Performance will be slightly worse than regular.")
#     traced_image = image


# for i in range(args.n_steps):
#     batch = torch.rand([batch_size, 2], device=device, dtype=torch.float32)
#     targets = traced_image(batch)
#     output = model(batch)

#     relative_l2_error = (output - targets.to(output.dtype))**2 / (output.detach()**2 + 0.01)
#     loss = relative_l2_error.mean()

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     if i % interval == 0:
#         loss_val = loss.item()
#         torch.cuda.synchronize()
#         elapsed_time = time.perf_counter() - prev_time
#         print(f"Step#{i}: loss={loss_val} time={int(elapsed_time*1000000)}[µs]")

#         path = f"{i}.jpg"
#         print(f"Writing '{path}'... ", end="")
#         with torch.no_grad():
#             write_image(path, model(xy).reshape(img_shape).clamp(0.0, 1.0).detach().cpu().numpy())
#         print("done.")

#         # Ignore the time spent saving the image
#         prev_time = time.perf_counter()

#         if i > 0 and interval < 1000:
#             interval *= 10

In [None]:
# Naive Cropped
#
# Notebook-level settings for the naive-crop experiment. naive_crop() below
# reads `args`, `config`, `device` and `image` as globals, so this part of the
# cell must run before it is called.

# Plain attribute container for the run settings.
class Args:
    pass
args = Args()
args.image="/global/u1/j/jswomley/tiny-cuda-nn/data/images/hubble.jpg"
args.image_roi="/global/u1/j/jswomley/tiny-cuda-nn/hubble_roi.jpg"
args.config="/global/u1/j/jswomley/tiny-cuda-nn/config_hash.json"
args.n_steps=10000
args.result_filename="/global/u1/j/jswomley/tiny-cuda-nn/result.png"

device = torch.device("cuda")

# `json` is commentjson here; config["encoding"] and config["network"] are
# later passed to tcnn.NetworkWithInputEncoding.
with open(args.config) as config_file:
    config = json.load(config_file)

# Full source image; naive_crop() slices its crop out of image.data.
image = original_Image(args.image, device)

def naive_crop(x_coord, y_coord, width, height):
    """Fit a tiny-cuda-nn model to a crop that is first saved as its own image.

    The crop ``image.data[y_coord:y_coord+height, x_coord:x_coord+width]`` is
    written to ``args.image_roi``, read back as a standalone image, and a fresh
    network is trained on it with coordinates normalised to the crop itself.
    Relies on the notebook-level globals ``image``, ``args``, ``config`` and
    ``device``. Also writes ``reference.jpg`` and ``args.result_filename``.

    Args:
        x_coord, y_coord: Top-left pixel of the crop.
        width, height: Crop size in pixels.

    Returns:
        ``[steps, recon_psnr, recon_ssim, recon_loss, final_loss, total_time]``
        where the first four are per-log-step lists (``recon_psnr`` holds
        PSNR / 100), ``final_loss`` is the loss of the last training step and
        ``total_time`` is the wall time in seconds spent on JIT tracing and
        training, excluding the metric evaluation at log steps.
    """
    image_roi_np = image.data[y_coord:y_coord+height, x_coord:x_coord+width, :].detach().cpu().numpy()
    write_image(args.image_roi, image_roi_np)
    image_roi = original_Image(args.image_roi, device)

    n_channels = image_roi.data.shape[2]

    # NetworkWithInputEncoding combines encoding and network; building a
    # separate tcnn.Encoding + tcnn.Network pair is equivalent but slower.
    model = tcnn.NetworkWithInputEncoding(n_input_dims=2, n_output_dims=n_channels, encoding_config=config["encoding"], network_config=config["network"]).to(device)
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Variables for saving/displaying image results
    resolution = image_roi.data.shape[0:2]
    img_shape = resolution + torch.Size([n_channels])

    # Pixel-centre coordinates of the crop. `xs` runs over rows and `ys` over
    # columns; the stack below puts the column (x) coordinate first.
    half_dx =  0.5 / resolution[0]
    half_dy =  0.5 / resolution[1]
    xs = torch.linspace(half_dx, 1-half_dx, resolution[0], device=device)
    ys = torch.linspace(half_dy, 1-half_dy, resolution[1], device=device)
    xv, yv = torch.meshgrid([xs, ys])

    xy = torch.stack((yv.flatten(), xv.flatten())).t()

    def reconstruct():
        # Inference only: skip autograd bookkeeping for the full-grid pass.
        with torch.no_grad():
            return model(xy).reshape(img_shape).clamp(0.0, 1.0).detach().cpu().numpy()

    path = "reference.jpg"
    print(f"Writing '{path}'... ", end="")
    write_image(path, image_roi(xy).reshape(img_shape).detach().cpu().numpy())
    print("done.")

    prev_time = time.perf_counter()

    batch_size = 2**18
    interval = 10

    print(f"Beginning optimization with {args.n_steps} training steps.")

    try:
        batch = torch.rand([batch_size, 2], device=device, dtype=torch.float32)
        traced_image = torch.jit.trace(image_roi, batch)
    except Exception:
        # If tracing causes an error, fall back to regular execution.
        # (Not a bare except: a keyboard interrupt must still stop the cell.)
        print("WARNING: PyTorch JIT trace failed. Performance will be slightly worse than regular.")
        traced_image = image_roi

    steps = []; recon_psnr = []; recon_ssim = []; recon_loss = []; total_time = 0
    loss = None

    for i in range(args.n_steps):
        batch = torch.rand([batch_size, 2], device=device, dtype=torch.float32)
        targets = traced_image(batch)
        output = model(batch)

        relative_l2_error = (output - targets.to(output.dtype))**2 / (output.detach()**2 + 0.01)
        loss = relative_l2_error.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % interval == 0:
            loss_val = loss.item()
            torch.cuda.synchronize()
            elapsed_time = time.perf_counter() - prev_time
            total_time += elapsed_time
            print(f"Step#{i}: loss={loss_val} time={int(elapsed_time*1000000)}[µs]")

            # One reconstruction feeds both metrics.
            recon = reconstruct()
            recon_psnr.append(PSNR(image_roi_np, recon)/100)
            recon_ssim.append(SSIM(image_roi_np, recon))
            recon_loss.append(loss_val)
            steps.append(i)

            # Ignore the time spent evaluating the metrics
            prev_time = time.perf_counter()

            # Log at steps 0, 10, 100, 1000, then every 1000 steps.
            if i > 0 and interval < 1000:
                interval *= 10

    total_time += time.perf_counter() - prev_time
    # No training step ran (n_steps == 0): there is no loss to report.
    final_loss = loss.item() if loss is not None else float("nan")
    if args.result_filename:
        print(f"Writing '{args.result_filename}'... ", end="")
        write_image(args.result_filename, reconstruct())
        print("done.")

    tcnn.free_temporary_memory()
    return [steps, recon_psnr, recon_ssim, recon_loss, final_loss, total_time]

In [None]:
# Real ROI function

def recon_roi(x_coord, y_coord, width, height):
    """Fit a tiny-cuda-nn model to a rectangular region of the full image.

    The region keeps its position in the full image's normalised coordinate
    frame: the network is trained on (x, y) inputs in 0..1 of the whole image,
    restricted to the region of interest (ROI), with bilinearly interpolated
    pixel values from ``Image`` as targets. Writes ``reference.jpg`` and the
    final reconstruction to the hard-coded result path.

    Args:
        x_coord, y_coord: Top-left pixel of the ROI (counted from the image's
            top-left corner).
        width, height: ROI size in pixels.

    Returns:
        ``[steps, recon_psnr, recon_ssim, recon_loss, final_loss, total_time]``
        where the first four are per-log-step lists (``recon_psnr`` holds
        PSNR / 100), ``final_loss`` is the loss of the last training step and
        ``total_time`` is the wall time in seconds spent on JIT tracing and
        training, excluding the metric evaluation at log steps.

    Raises:
        ValueError: If the ROI is empty or not fully inside the image.
    """
    #coordinates count from top left

    class Args:
        pass
    args = Args()
    args.image="/global/u1/j/jswomley/tiny-cuda-nn/data/images/hubble.jpg"
    args.image_roi="/global/u1/j/jswomley/tiny-cuda-nn/hubble_roi.jpg"
    args.config="/global/u1/j/jswomley/tiny-cuda-nn/config_hash.json"
    args.n_steps=10000
    args.result_filename="/global/u1/j/jswomley/tiny-cuda-nn/result.png"

    device = torch.device("cuda")

    with open(args.config) as config_file:
        config = json.load(config_file)

    image = Image(args.image, device, x_coord, y_coord, width, height)
    img_h, img_w = image.data.shape[0:2]
    n_channels = image.data.shape[2]

    # An ROI that leaves the image yields fewer than width*height pixels and
    # would only fail later, at the reshape, with an unhelpful message.
    if (width <= 0 or height <= 0 or x_coord < 0 or y_coord < 0
            or x_coord + width > img_w or y_coord + height > img_h):
        raise ValueError(
            f"ROI x={x_coord}, y={y_coord}, width={width}, height={height} "
            f"is empty or outside the {img_w}x{img_h} image"
        )

    image_roi_np = image.data[y_coord:y_coord+height, x_coord:x_coord+width, :].detach().cpu().numpy()

    # NetworkWithInputEncoding combines encoding and network; building a
    # separate tcnn.Encoding + tcnn.Network pair is equivalent but slower.
    model = tcnn.NetworkWithInputEncoding(n_input_dims=2, n_output_dims=n_channels, encoding_config=config["encoding"], network_config=config["network"]).to(device)
    # print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Shape of the reconstructed ROI image.
    img_shape = torch.Size([height, width]) + torch.Size([n_channels])

    # Pixel-centre coordinates of the ROI in full-image normalised units.
    # Slicing the full-image linspaces gives exactly the values (and row-major
    # order) that masking a full-image grid would, without materialising and
    # filtering every pixel of the image.
    half_dx =  0.5 / img_h
    half_dy =  0.5 / img_w
    row_centres = torch.linspace(half_dx, 1-half_dx, img_h, device=device)[y_coord:y_coord+height]
    col_centres = torch.linspace(half_dy, 1-half_dy, img_w, device=device)[x_coord:x_coord+width]
    xv, yv = torch.meshgrid([row_centres, col_centres])
    roi_grid = torch.stack((yv.flatten(), xv.flatten())).t()

    # Evaluate the sampler once; `roi_xy` is reused for every reconstruction.
    reference, roi_xy = image(roi_grid)

    def reconstruct():
        # Inference only: skip autograd bookkeeping for the full-ROI pass.
        with torch.no_grad():
            return model(roi_xy).reshape(img_shape).clamp(0.0, 1.0).detach().cpu().numpy()

    path = "reference.jpg"
    print(f"Writing '{path}'... ", end="")
    write_image(path, reference.reshape(img_shape).detach().cpu().numpy())
    print("done.")

    prev_time = time.perf_counter()

    batch_size = 2**18
    interval = 10

    # Draw training positions uniformly inside the ROI. Drawing them over the
    # whole image and masking afterwards keeps only about
    # batch_size * roi_area / image_area samples, which for a small ROI is a
    # handful or none at all (and the mean of an empty loss tensor is NaN).
    # Image.forward still masks, so stray samples on the ROI border are dropped.
    roi_origin = torch.tensor([x_coord / img_w, y_coord / img_h], device=device, dtype=torch.float32)
    roi_extent = torch.tensor([width / img_w, height / img_h], device=device, dtype=torch.float32)

    def sample_batch():
        return roi_origin + roi_extent * torch.rand([batch_size, 2], device=device, dtype=torch.float32)

    print(f"Beginning optimization with {args.n_steps} training steps.")

    try:
        traced_image = torch.jit.trace(image, sample_batch())
    except Exception:
        # If tracing causes an error, fall back to regular execution.
        # (Not a bare except: a keyboard interrupt must still stop the cell.)
        print("WARNING: PyTorch JIT trace failed. Performance will be slightly worse than regular.")
        traced_image = image

    steps = []; recon_psnr = []; recon_ssim = []; recon_loss = []; total_time = 0
    loss = None

    for i in range(args.n_steps):
        # One sampler call yields both the targets and the coordinates they
        # belong to, so the two always stay aligned.
        targets, roi_batch = traced_image(sample_batch())
        output = model(roi_batch)

        relative_l2_error = (output - targets.to(output.dtype))**2 / (output.detach()**2 + 0.01)
        loss = relative_l2_error.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % interval == 0:
            loss_val = loss.item()
            torch.cuda.synchronize()
            elapsed_time = time.perf_counter() - prev_time
            total_time += elapsed_time
            print(f"Step#{i}: loss={loss_val} time={int(elapsed_time*1000000)}[µs]")

            # One reconstruction feeds both metrics.
            recon = reconstruct()
            recon_psnr.append(PSNR(image_roi_np, recon)/100)
            recon_ssim.append(SSIM(image_roi_np, recon))
            recon_loss.append(loss_val)
            steps.append(i)

            # Ignore the time spent evaluating the metrics
            prev_time = time.perf_counter()

            # Log at steps 0, 10, 100, 1000, then every 1000 steps.
            if i > 0 and interval < 1000:
                interval *= 10

    total_time += time.perf_counter() - prev_time
    # No training step ran (n_steps == 0): there is no loss to report.
    final_loss = loss.item() if loss is not None else float("nan")
    if args.result_filename:
        print(f"Writing '{args.result_filename}'... ", end="")
        write_image(args.result_filename, reconstruct())
        print("done.")

    tcnn.free_temporary_memory()
    return [steps, recon_psnr, recon_ssim, recon_loss, final_loss, total_time]

In [None]:
### Tokyo_cropped
# quarter_metrics = recon_roi(0, 0, 6250, 7000)
# half_metrics = recon_roi(0, 0, 12500, 7000)
# three_quarter_metrics = recon_roi(0, 0, 18750, 7000)
# full_metrics = recon_roi(0, 0, 25000, 7000)

### Hubble
# quarter_metrics = recon_roi(0,0,1695,7071)
# half_metrics = recon_roi(0,0,3390,7071)
# three_quarter_metrics = recon_roi(0,0,5085,7071)
# full_metrics = recon_roi(0,0,6780,7071)


In [None]:
# tcnn.free_temporary_memory()
# Train both variants on the same 20x20-pixel region whose top-left corner is
# at x=2000, y=2000 (arguments are x_coord, y_coord, width, height) so that
# their metrics can be compared.
real_metrics = recon_roi(2000, 2000, 20, 20)
tcnn.free_temporary_memory()
naive_metrics = naive_crop(2000, 2000, 20, 20)

In [None]:
import matplotlib.pyplot as plt

def view_metrics(metrics):
    """Plot the training curves of one run and print its final numbers.

    Args:
        metrics: ``[steps, recon_psnr, recon_ssim, recon_loss, final_loss,
            total_time]`` as returned by ``recon_roi`` / ``naive_crop``. Both
            producers store PSNR divided by 100 so that it shares an axis with
            SSIM, and the labels below say so.
    """
    steps, recon_psnr, recon_ssim, recon_loss, final_loss, total_time = metrics

    _, (quality_ax, loss_ax) = plt.subplots(1, 2)

    quality_ax.plot(steps, recon_psnr, label="PSNR / 100")
    quality_ax.plot(steps, recon_ssim, label="SSIM")
    quality_ax.set_xlabel("recon step")
    quality_ax.legend()

    loss_ax.plot(steps, recon_loss, label="loss")
    loss_ax.set_xlabel("recon step")
    loss_ax.legend()
    plt.show()

    # A run with no logged step has no "last" metric to report.
    last_psnr = recon_psnr[-1] if recon_psnr else None
    last_ssim = recon_ssim[-1] if recon_ssim else None
    print("final loss: {}\nfinal PSNR / 100: {}\nfinal SSIM: {}\ntotal time: {} seconds".format(final_loss, last_psnr, last_ssim, total_time))

In [None]:
# Curves and final numbers for the ROI-in-full-image run, then the naive-crop run.
view_metrics(real_metrics)
view_metrics(naive_metrics)

In [None]:
# enc = torch.Tensor.cpu(encoding.params.data)
# print(enc, enc.shape, enc.dtype)
# weights = torch.Tensor.cpu(network.params.data)
# print(f"{enc.shape[0]*4/1e6} MB + {weights.shape[0]*4/1e6} MB")

In [None]:
# encoding = tcnn.Encoding(n_input_dims=2, encoding_config=config["encoding"],seed=0)
# network = tcnn.Network(n_input_dims=encoding.n_output_dims, n_output_dims=n_channels, network_config=config["network"],seed=0)
# model1 = torch.nn.Sequential(encoding, network)

In [None]:
# print(f"{image.shape[0]*image.shape[1]*image.shape[2]/1e6} MB")

In [None]:
# a = model(xy).reshape(img_shape).clamp(0.0, 1.0).detach().cpu().numpy().astype(np.float32)

In [None]:
# b = image.data.detach().cpu().numpy()

In [None]:
# print(np.max(a), np.max(b))
# print(np.min(a), np.min(b))

In [None]:
# c=a-b
# print(np.max(c), np.min(c))

In [None]:
# plt.subplot(121)
# plt.imshow(b); plt.title("original"); 
# plt.subplot(122)
# plt.imshow(a); plt.title("compressed");
# plt.show()

In [None]:
# plt.subplot(121)
# plt.imshow(b[2500:3000,2600:2900]); plt.title("original"); 
# plt.subplot(122)
# plt.imshow(a[2500:3000,2600:2900]); plt.title("compressed");
# plt.show()

In [None]:
# write_image("compressed.png",a)

In [None]:
# write_image("compressed_zoom.png",a[2500:3000,2600:2900])

In [None]:
# write_image("original_zoom.png",b[2500:3000,2600:2900])