Fix train_mlp_nerf and save the model at the end of training #177

Merged
131 changes: 79 additions & 52 deletions examples/train_mlp_nerf.py
@@ -13,7 +13,12 @@
 import torch.nn.functional as F
 import tqdm
 from radiance_fields.mlp import VanillaNeRFRadianceField
-from utils import render_image, set_random_seed
+from utils import (
+    MIPNERF360_UNBOUNDED_SCENES,
+    NERF_SYNTHETIC_SCENES,
+    render_image,
+    set_random_seed,
+)
 
 from nerfacc import ContractionType, OccupancyGrid
 
@@ -34,23 +39,17 @@
     choices=["train", "trainval"],
     help="which train split to use",
 )
+parser.add_argument(
+    "--model_path",
+    type=str,
+    default=None,
+    help="the path of the pretrained model",
+)
 parser.add_argument(
     "--scene",
     type=str,
     default="lego",
-    choices=[
-        # nerf synthetic
-        "chair",
-        "drums",
-        "ficus",
-        "hotdog",
-        "lego",
-        "materials",
-        "mic",
-        "ship",
-        # mipnerf360 unbounded
-        "garden",
-    ],
+    choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
     help="which scene to use",
 )
 parser.add_argument(
@@ -74,11 +73,47 @@
 
 render_n_samples = 1024
 
-# setup the scene bounding box.
+# setup the dataset
+train_dataset_kwargs = {}
+test_dataset_kwargs = {}
+
+if args.scene in MIPNERF360_UNBOUNDED_SCENES:
+    from datasets.nerf_360_v2 import SubjectLoader
+
+    print("Using unbounded rendering")
+    target_sample_batch_size = 1 << 16
+    train_dataset_kwargs["color_bkgd_aug"] = "random"
+    train_dataset_kwargs["factor"] = 4
+    test_dataset_kwargs["factor"] = 4
+    grid_resolution = 128
+
+elif args.scene in NERF_SYNTHETIC_SCENES:
+    from datasets.nerf_synthetic import SubjectLoader
+
+    target_sample_batch_size = 1 << 16
+    grid_resolution = 128
+
+train_dataset = SubjectLoader(
+    subject_id=args.scene,
+    root_fp=args.data_root,
+    split=args.train_split,
+    num_rays=target_sample_batch_size // render_n_samples,
+    device=device,
+    **train_dataset_kwargs,
+)
+
+test_dataset = SubjectLoader(
+    subject_id=args.scene,
+    root_fp=args.data_root,
+    split="test",
+    num_rays=None,
+    device=device,
+    **test_dataset_kwargs,
+)
+
 if args.unbounded:
-    print("Using unbounded rendering")
     contraction_type = ContractionType.UN_BOUNDED_SPHERE
     # contraction_type = ContractionType.UN_BOUNDED_TANH
     scene_aabb = None
     near_plane = 0.2
     far_plane = 1e4
@@ -110,44 +145,22 @@
     gamma=0.33,
 )
 
-# setup the dataset
-train_dataset_kwargs = {}
-test_dataset_kwargs = {}
-if args.scene == "garden":
-    from datasets.nerf_360_v2 import SubjectLoader
-
-    target_sample_batch_size = 1 << 16
-    train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4}
-    test_dataset_kwargs = {"factor": 4}
-    grid_resolution = 128
-else:
-    from datasets.nerf_synthetic import SubjectLoader
-
-    target_sample_batch_size = 1 << 16
-    grid_resolution = 128
-
-train_dataset = SubjectLoader(
-    subject_id=args.scene,
-    root_fp=args.data_root,
-    split=args.train_split,
-    num_rays=target_sample_batch_size // render_n_samples,
-    **train_dataset_kwargs,
-)
-
-test_dataset = SubjectLoader(
-    subject_id=args.scene,
-    root_fp=args.data_root,
-    split="test",
-    num_rays=None,
-    **test_dataset_kwargs,
-)
 
 occupancy_grid = OccupancyGrid(
     roi_aabb=args.aabb,
     resolution=grid_resolution,
     contraction_type=contraction_type,
 ).to(device)
 
+if args.model_path is not None:
+    checkpoint = torch.load(args.model_path)
+    radiance_field.load_state_dict(checkpoint["radiance_field_state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+    occupancy_grid.load_state_dict(checkpoint["occupancy_grid_state_dict"])
+    step = checkpoint["step"]
+else:
+    step = 0
+
 # training
-step = 0
 tic = time.time()
@@ -204,14 +217,28 @@
         if step % 5000 == 0:
             elapsed_time = time.time() - tic
             loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
+            psnr = -10.0 * torch.log(loss) / np.log(10.0)
             print(
                 f"elapsed_time={elapsed_time:.2f}s | step={step} | "
                 f"loss={loss:.5f} | "
                 f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
-                f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
+                f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} | "
+                f"psnr={psnr:.2f}"
             )
 
         if step > 0 and step % max_steps == 0:
+            model_save_path = str(pathlib.Path.cwd() / f"mlp_nerf_{step}")
+            torch.save(
+                {
+                    "step": step,
+                    "radiance_field_state_dict": radiance_field.state_dict(),
+                    "optimizer_state_dict": optimizer.state_dict(),
+                    "scheduler_state_dict": scheduler.state_dict(),
+                    "occupancy_grid_state_dict": occupancy_grid.state_dict(),
+                },
+                model_save_path,
+            )
+
             # evaluation
             radiance_field.eval()
 
@@ -230,8 +257,8 @@
                         rays,
                         scene_aabb,
                         # rendering options
-                        near_plane=None,
-                        far_plane=None,
+                        near_plane=near_plane,
+                        far_plane=far_plane,
                         render_step_size=render_step_size,
                         render_bkgd=render_bkgd,
                         cone_angle=args.cone_angle,
@@ -246,7 +273,7 @@
                    #     ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
                    # )
                    # imageio.imwrite(
-                    #     "rgb_test.png",
+                    #     f"rgb_test_{i}.png",
                    #     (rgb.cpu().numpy() * 255).astype(np.uint8),
                    # )
                    # break
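For reference, a minimal sketch of restoring the checkpoint this change writes. The state-dict keys match the torch.save call in the diff above; the filename "mlp_nerf_50000", the device choice, and the use of the default VanillaNeRFRadianceField constructor are assumptions (the script saves to pathlib.Path.cwd() / f"mlp_nerf_{step}", so the actual name depends on max_steps):

# Hedged sketch: reload a model saved by examples/train_mlp_nerf.py.
# "mlp_nerf_50000" is a hypothetical path; adjust to the file the run produced.
import torch
from radiance_fields.mlp import VanillaNeRFRadianceField

device = "cuda:0" if torch.cuda.is_available() else "cpu"
radiance_field = VanillaNeRFRadianceField().to(device)

checkpoint = torch.load("mlp_nerf_50000", map_location=device)
radiance_field.load_state_dict(checkpoint["radiance_field_state_dict"])
radiance_field.eval()
print(f"restored model from step {checkpoint['step']}")

To resume training instead of just evaluating, pass the same file back via the new --model_path flag, which also restores the optimizer, scheduler, occupancy grid, and step counter.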