Reduce GPU memory usage
ameuleman committed Nov 4, 2023
1 parent b1b14e7 · commit 3905e39
Showing 3 changed files with 29 additions and 20 deletions.
16 changes: 7 additions & 9 deletions localTensoRF/local_tensorfs.py
@@ -109,7 +109,7 @@ def __init__(

# Setup radiance fields
self.tensorfs = torch.nn.ParameterList()
- self.rf_optimizers, self.rf_iter = [], []
+ self.rf_iter = []
self.world2rf = torch.nn.ParameterList()
self.append_rf()

@@ -143,7 +143,7 @@ def append_rf(self, n_added_frames=1):
grad_vars = self.tensorfs[-1].get_optparam_groups(
self.rf_lr_init, self.rf_lr_basis
)
- self.rf_optimizers.append(torch.optim.Adam(grad_vars, betas=(0.9, 0.99)))
+ self.rf_optimizer = (torch.optim.Adam(grad_vars, betas=(0.9, 0.99)))

def append_frame(self):
if len(self.r_c2w) == 0:
@@ -237,16 +237,14 @@ def optimizer_step(self, loss, optimize_poses):
self.intrinsic_optimizer.zero_grad()

# tensorfs
- for optimizer, iteration in zip(self.rf_optimizers, self.rf_iter):
-     if iteration < self.n_iters:
-         optimizer.zero_grad()
+ self.rf_optimizer.zero_grad()

loss.backward()

# Optimize RFs
- self.rf_optimizers[-1].step()
+ self.rf_optimizer.step()
if self.is_refining:
- for param_group in self.rf_optimizers[-1].param_groups:
+ for param_group in self.rf_optimizer.param_groups:
param_group["lr"] = param_group["lr"] * self.lr_factor

# Increase RF resolution
@@ -260,10 +258,10 @@ def optimizer_step(self, loss, optimize_poses):
grad_vars = self.tensorfs[-1].get_optparam_groups(
self.rf_lr_init, self.rf_lr_basis
)
- self.rf_optimizers[-1] = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
+ self.rf_optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

# Update alpha mask
- if iteration in self.update_AlphaMask_list:
+ if self.rf_iter[-1] in self.update_AlphaMask_list:
reso_mask = (self.tensorfs[-1].gridSize / 2).int()
self.tensorfs[-1].updateAlphaMask(tuple(reso_mask))

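The changes above replace the list of per-RF Adam optimizers with a single optimizer covering only the radiance field that is still being refined. Below is a minimal sketch of that pattern with hypothetical class and method names, not the repository's real ones: Adam keeps two moment tensors per parameter, so holding an optimizer for every frozen local radiance field keeps GPU memory alive for parameters that will never be updated again.

    import torch

    class LocalRFs(torch.nn.Module):
        """Toy stand-in for a growing list of local radiance fields."""

        def __init__(self):
            super().__init__()
            self.rfs = torch.nn.ModuleList()
            self.rf_optimizer = None  # one optimizer, for the active RF only

        def append_rf(self):
            rf = torch.nn.Linear(64, 64)  # placeholder for a TensoRF model
            self.rfs.append(rf)
            # Re-creating the optimizer drops the Adam state of the previous RF
            # (roughly two extra tensors per trainable parameter).
            self.rf_optimizer = torch.optim.Adam(rf.parameters(), betas=(0.9, 0.99))

        def optimizer_step(self, loss):
            self.rf_optimizer.zero_grad()
            loss.backward()
            self.rf_optimizer.step()  # only the newest RF receives updates

    model = LocalRFs()
    model.append_rf()
    loss = model.rfs[-1](torch.randn(4, 64)).pow(2).mean()
    model.optimizer_step(loss)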
29 changes: 20 additions & 9 deletions localTensoRF/models/tensorBase.py
@@ -42,7 +42,7 @@ def __init__(self, device, aabb, alpha_volume):

self.aabb = torch.nn.Parameter(aabb.to(self.device), requires_grad=False)
self.aabbSize = self.aabb[1] - self.aabb[0]
- self.invgridSize = 1.0/self.aabbSize * 2
+ self.invgridSize = torch.nn.Parameter(1.0/self.aabbSize * 2, requires_grad=False)
self.alpha_volume = torch.nn.Parameter(
alpha_volume.view(1,1,*alpha_volume.shape[-3:]), requires_grad=False
)
@@ -56,6 +56,10 @@ def sample_alpha(self, xyz_sampled):

def normalize_coord(self, xyz_sampled):
return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1

+ def to(self, device):
+     self.device = torch.device(device)
+     return super(AlphaGridMask, self).to(device)

class MLPRender_Fea(torch.nn.Module):
def __init__(self, inChanel, viewpe=6, feape=6, featureC=128):
@@ -497,24 +501,25 @@ def feature2density(self, density_features):
@torch.no_grad()
def getDenseAlpha(self,gridSize=None):
gridSize = self.gridSize if gridSize is None else gridSize

- samples = torch.stack(torch.meshgrid(
+ dense_xyz = torch.stack(torch.meshgrid(
torch.linspace(0, 1, gridSize[0]),
torch.linspace(0, 1, gridSize[1]),
torch.linspace(0, 1, gridSize[2]),
- ), -1).to(self.device)
- dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples
+ ), -1)
+ dense_xyz = self.aabb[0] * (1-dense_xyz) + self.aabb[1] * dense_xyz

alpha = torch.zeros_like(dense_xyz[...,0])
for i in range(gridSize[0]):
alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2]))
- return alpha, dense_xyz
+ return alpha

@torch.no_grad()
def updateAlphaMask(self, gridSize=(200,200,200)):
torch.cuda.empty_cache()
- alpha, dense_xyz = self.getDenseAlpha(gridSize)
- dense_xyz = dense_xyz.transpose(0,2).contiguous()
+ device = self.device
+ self.to("cpu")
+ alpha = self.getDenseAlpha(gridSize)
alpha = alpha.clamp(0,1).transpose(0,2).contiguous()[None,None]
total_voxels = gridSize[0] * gridSize[1] * gridSize[2]

@@ -523,9 +528,12 @@ def updateAlphaMask(self, gridSize=(200,200,200)):
alpha[alpha>=self.alphaMask_thres] = 1
alpha[alpha<self.alphaMask_thres] = 0

- self.alphaMask = AlphaGridMask(self.device, self.aabb, alpha)
+ self.alphaMask = AlphaGridMask("cpu", self.aabb, alpha)
+ self.alphaMask = self.alphaMask.to(self.device)
print(f"alpha rest %%%f"%(torch.sum(alpha)/total_voxels*100))
torch.cuda.empty_cache()
+ self.to(device)
+ torch.cuda.empty_cache()

def compute_alpha(self, xyz_locs, length=1):

@@ -551,6 +559,9 @@ def compute_alpha(self, xyz_locs, length=1):

def to(self, device):
self.device = torch.device(device)
+ self.stepSize = self.stepSize.to(device)
+ if self.alphaMask is not None:
+     self.alphaMask = self.alphaMask.to(device)
return super(TensorBase, self).to(device)

def forward(
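The tensorBase.py changes follow one pattern: move the whole field to the CPU while the dense alpha grid is built, register device-dependent tensors as non-trainable Parameters so Module.to() carries them along, and empty the CUDA cache afterwards. Below is a minimal sketch of that offload pattern with a toy module and made-up names rather than TensorBase itself; the random grid stands in for the real getDenseAlpha() pass.

    import torch

    class Field(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            self.device = torch.device(device)
            # Non-trainable Parameters (or buffers) are moved by Module.to(),
            # unlike plain tensor attributes, so they follow the model between devices.
            self.aabb = torch.nn.Parameter(
                torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]), requires_grad=False
            )

        def to(self, device):
            # Track the current device so later allocations follow the module.
            self.device = torch.device(device)
            return super().to(device)

        @torch.no_grad()
        def update_alpha_mask(self, grid_size=(64, 64, 64)):
            device = self.device
            self.to("cpu")                 # free the GPU copies of the field
            alpha = torch.rand(grid_size)  # placeholder for a dense alpha evaluation
            mask = alpha >= 0.5            # threshold into an occupancy mask
            self.to(device)                # restore the original training device
            torch.cuda.empty_cache()       # hand cached blocks back to the allocator
            return mask

    field = Field("cuda" if torch.cuda.is_available() else "cpu")
    mask = field.update_alpha_mask()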
4 changes: 2 additions & 2 deletions localTensoRF/train.py
@@ -477,12 +477,12 @@ def reconstruction(args):

writer.add_scalar(
"train/density_app_plane_lr",
- local_tensorfs.rf_optimizers[-1].param_groups[0]["lr"],
+ local_tensorfs.rf_optimizer.param_groups[0]["lr"],
global_step=iteration,
)
writer.add_scalar(
"train/basis_mat_lr",
- local_tensorfs.rf_optimizers[-1].param_groups[4]["lr"],
+ local_tensorfs.rf_optimizer.param_groups[4]["lr"],
global_step=iteration,
)

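train.py only follows the rename: the learning rates are now read from local_tensorfs.rf_optimizer instead of the last entry of the old optimizer list. A small sketch of logging per-param-group learning rates to TensorBoard is shown below; the group indices in the diff (0 for the density/appearance planes, 4 for the basis matrix) are an assumption that depends on the order in which get_optparam_groups builds the groups, so the sketch simply logs every group.

    import torch
    from torch.utils.tensorboard import SummaryWriter

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(
        [
            {"params": [model.weight], "lr": 2e-2},  # stand-in for plane/line tensors
            {"params": [model.bias], "lr": 1e-3},    # stand-in for the basis matrix
        ],
        betas=(0.9, 0.99),
    )
    writer = SummaryWriter()
    for iteration in range(10):
        for i, group in enumerate(optimizer.param_groups):
            writer.add_scalar(f"train/lr_group_{i}", group["lr"], global_step=iteration)
    writer.close()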
