In [1]:
import argparse
import os
import torch

from utils.validation_utils import get_video
from utils.preprocessing import initialize_tensors, make_parameter
from train import train

In [2]:
def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _parse_img_size(value):
    """Parse ``'256,256'``, ``'256x256'`` or ``'256'`` into a tuple of two ints.

    A single number is expanded to a square size. ``type=tuple`` would instead
    split the raw string into characters (``'256'`` -> ``('2', '5', '6')``).

    Raises:
        argparse.ArgumentTypeError: if the value is not one or two positive integers.
    """
    if isinstance(value, (tuple, list)):
        parts = list(value)
    else:
        parts = str(value).lower().replace('x', ',').split(',')
    try:
        dims = tuple(int(part) for part in parts)
    except (TypeError, ValueError):
        raise argparse.ArgumentTypeError(
            f'expected an image size such as 256,256, got {value!r}'
        ) from None
    if len(dims) == 1:
        dims = dims * 2
    if len(dims) != 2 or any(dim <= 0 for dim in dims):
        raise argparse.ArgumentTypeError(
            f'expected two positive integers for the image size, got {value!r}'
        )
    return dims


def get_args(argv=None):
    """Build the run configuration from command-line style options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace holding every option declared below plus ``device``,
        which is ``'cuda'`` when a CUDA device is available and ``'cpu'`` otherwise.

    Unrecognised options are ignored (``parse_known_args``), so extra arguments
    injected by the host process do not abort parsing.
    """
    parser = argparse.ArgumentParser(conflict_handler='resolve')

    # Schedule of the optimisation loop.
    parser.add_argument('--first_iter', default=0, type=int)
    parser.add_argument('--iterations', default=5000, type=int)
    parser.add_argument('--refinement_iter', default=1, type=int)
    parser.add_argument('--densification_iter', default=250, type=int)
    parser.add_argument('--logging_iter', default=100, type=int)
    parser.add_argument('--img_size', default=(256, 256), type=_parse_img_size)
    parser.add_argument('--gaussian_init_scale', '-gs', default=5, type=int)

    # Point-count limits and thresholds.
    parser.add_argument('--scale_factor', default=1.6, type=float)
    parser.add_argument('--start_points_number', default=1000, type=int)
    parser.add_argument('--limit_points_number', default=10000, type=int)
    parser.add_argument('--grad_threshold', default=0.002, type=float)
    parser.add_argument('--gauss_threshold', default=0.05, type=float)

    # Optimiser settings.
    parser.add_argument('--learning_rate', '-lr', default=0.01, type=float)
    # `lambda` is a Python keyword, so this value must be read with
    # getattr(args, 'lambda') or vars(args)['lambda']; the dest is kept
    # unchanged so existing readers keep working.
    parser.add_argument('--lambda', default=0.2, type=float)

    # Input / output.
    parser.add_argument('--image_path', default='images/mikki.jpg', type=str)
    parser.add_argument('--output_folder', default=None, type=str)
    parser.add_argument('--get_video', default=False, type=_str_to_bool)

    args, _ = parser.parse_known_args(argv)
    # Fall back to the CPU instead of naming a device that does not exist.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    return args

In [3]:
# Parse the run configuration once; the cells below pass `args` on to
# initialize_tensors and train.
args = get_args()

### Load image and initialize necessary tensors

In [6]:
# initialize_tensors(args) hands back seven values; they are unpacked in the
# order it returns them.
(
    gt_image,
    current_points_places,
    scaling,
    rotation,
    colors,
    alphas,
    points_locations,
) = initialize_tensors(args)

# Output of this run is written under experiments/<exp_name>, with an
# images/ sub-folder next to it.
# NOTE(review): args.output_folder is not consulted here — confirm intended.
exp_name = 'test_exp'
output_path = f'experiments/{exp_name}'
for _run_dir in (output_path, f'{output_path}/images'):
    os.makedirs(_run_dir, exist_ok=True)

In [7]:
# Join scaling with rotation, and alphas with colors, along dim 1 so each
# pair is handled as a single tensor.
cov_matrix = torch.cat((scaling, rotation), dim=1)
rgba_matrix = torch.cat((alphas, colors), dim=1)

# make_parameter receives the three tensors as one list and returns three
# objects in the same order.
(
    cov_matrix_tensor,
    rgba_matrix_tensor,
    points_locs_rensor,
) = make_parameter([cov_matrix, rgba_matrix, points_locations])

### Run training

In [9]:
# Start the training run; all arguments are positional, so their order matters.
train(
    args,
    gt_image,
    cov_matrix_tensor,
    rgba_matrix_tensor,
    points_locs_rensor,
    current_points_places,
    output_path,
)

### Get video

In [3]:
# Hand the output directory of the run above to get_video.
# NOTE(review): this cell relies on `output_path` being defined by an earlier
# cell, and the `--get_video` option in `args` is not checked here — confirm
# that running it unconditionally is intended.
exp_path = output_path
get_video(exp_path, video_name='test_video')