In [None]:
!pip install tensorboard

In [1]:
import torch
import os
import sys
from models.autoencoder import AutoEncoder
import time
from datetime import datetime
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from models.utils import  AverageMeter, str2bool
from dataset.dataset import CompressDataset
from args.shapenet_args import parse_shapenet_args
from args.semantickitti_args import parse_semantickitti_args
from torch.optim.lr_scheduler import StepLR
from models.Chamfer3D.dist_chamfer_3D import chamfer_3DDist
chamfer_dist = chamfer_3DDist()

Loaded compiled 3D CUDA chamfer distance


In [2]:
def train(args):
    start = time.time()

    if args.batch_size > 1:
        print('The performance will degrade if batch_size is larger than 1!')

    if args.compress_normal == True:
        args.in_fdim = 6

    # load data
    train_dataset = CompressDataset(data_path=args.train_data_path, cube_size=args.train_cube_size, batch_size=args.batch_size)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, batch_size=args.batch_size)

    val_dataset = CompressDataset(data_path=args.val_data_path, cube_size=args.val_cube_size, batch_size=args.batch_size)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=args.batch_size)

    # set up folders for checkpoints
    str_time = datetime.now().isoformat()
    print('Experiment Time:', str_time)
    checkpoint_dir = os.path.join(args.output_path, str_time, 'ckpt')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # create the model
    model = AutoEncoder(args)
    model = model.cuda()
    print('Training Arguments:', args)
    print('Model Architecture:', model)

    # optimizer for autoencoder
    parameters = set(p for n, p in model.named_parameters() if not n.endswith(".quantiles"))
    optimizer = optim.Adam(parameters, lr=args.lr)
    # lr scheduler
    scheduler_steplr = StepLR(optimizer, step_size=args.lr_decay_step, gamma=args.gamma)
    # optimizer for entropy bottleneck
    aux_parameters = set(p for n, p in model.named_parameters() if n.endswith(".quantiles"))
    aux_optimizer = optim.Adam(aux_parameters, lr=args.aux_lr)

    # best validation metric
    best_val_chamfer_loss = float('inf')

   # train
    for epoch in range(args.epochs):
        epoch_loss = AverageMeter()
        epoch_chamfer_loss = AverageMeter()
        epoch_density_loss = AverageMeter()
        epoch_pts_num_loss = AverageMeter()
        epoch_latent_xyzs_loss = AverageMeter()
        epoch_normal_loss = AverageMeter()
        epoch_bpp_loss = AverageMeter()
        epoch_aux_loss = AverageMeter()
        
        model.train()

        for i, input_dict in enumerate(train_loader):
            # input: (b, n, c)
            input = input_dict['xyzs'].cuda()
            # input: (b, c, n)
            input = input.permute(0, 2, 1).contiguous()

            # compress normal
            if args.compress_normal == True:
                normals = input_dict['normals'].cuda().permute(0, 2, 1).contiguous()
                input = torch.cat((input, normals), dim=1)

            # model forward
            decompressed_xyzs, loss, loss_items, bpp = model(input)
            epoch_loss.update(loss.item())
            epoch_chamfer_loss.update(loss_items['chamfer_loss'])
            epoch_density_loss.update(loss_items['density_loss'])
            epoch_pts_num_loss.update(loss_items['pts_num_loss'])
            epoch_latent_xyzs_loss.update(loss_items['latent_xyzs_loss'])
            epoch_normal_loss.update(loss_items['normal_loss'])
            epoch_bpp_loss.update(loss_items['bpp_loss'])

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update the parameters of entropy bottleneck
            aux_loss = model.feats_eblock.loss()
            if args.quantize_latent_xyzs == True:
                aux_loss += model.xyzs_eblock.loss()
            epoch_aux_loss.update(aux_loss.item())

            aux_optimizer.zero_grad()
            aux_loss.backward()
            aux_optimizer.step()

            # print loss
            if (i+1) % args.print_freq == 0:
                print("train epoch: %d/%d, iters: %d/%d, loss: %f, avg chamfer loss: %f, "
                      "avg density loss: %f, avg pts num loss: %f, avg latent xyzs loss: %f, "
                      "avg normal loss: %f, avg bpp loss: %f, avg aux loss: %f" %
                      (epoch+1, args.epochs, i+1, len(train_loader), epoch_loss.get_avg(), epoch_chamfer_loss.get_avg(),
                       epoch_density_loss.get_avg(), epoch_pts_num_loss.get_avg(), epoch_latent_xyzs_loss.get_avg(),
                       epoch_normal_loss.get_avg(), epoch_bpp_loss.get_avg(), epoch_aux_loss.get_avg()))

        scheduler_steplr.step()

        # print loss
        interval = time.time() - start
        print("train epoch: %d/%d, time: %d mins %.1f secs, loss: %f, avg chamfer loss: %f, "
              "avg density loss: %f, avg pts num loss: %f, avg latent xyzs loss: %f, "
              "avg normal loss: %f, avg bpp loss: %f, avg aux loss: %f" %
              (epoch+1, args.epochs, interval/60, interval%60, epoch_loss.get_avg(), epoch_chamfer_loss.get_avg(),
               epoch_density_loss.get_avg(), epoch_pts_num_loss.get_avg(), epoch_latent_xyzs_loss.get_avg(),
               epoch_normal_loss.get_avg(), epoch_bpp_loss.get_avg(), epoch_aux_loss.get_avg()))


        # validation
        model.eval()
        val_chamfer_loss = AverageMeter()
        val_normal_loss = AverageMeter()
        val_bpp = AverageMeter()
        with torch.no_grad():
            for input_dict in val_loader:
                # xyzs: (b, n, c)
                input = input_dict['xyzs'].cuda()
                # (b, c, n)
                input = input.permute(0, 2, 1).contiguous()

                # compress normal
                if args.compress_normal == True:
                    normals = input_dict['normals'].cuda().permute(0, 2, 1).contiguous()
                    input = torch.cat((input, normals), dim=1)
                    args.in_fdim = 6

                # gt_xyzs
                gt_xyzs = input[:, :3, :].contiguous()

                # model forward
                decompressed_xyzs, loss, loss_items, bpp = model(input)
                # calculate val loss and bpp
                d1, d2, _, _ = chamfer_dist(gt_xyzs.permute(0, 2, 1).contiguous(),
                                            decompressed_xyzs.permute(0, 2, 1).contiguous())
                chamfer_loss = d1.mean() + d2.mean()
                val_chamfer_loss.update(chamfer_loss.item())
                val_normal_loss.update(loss_items['normal_loss'])
                val_bpp.update(bpp.item())

        # print loss
        print("val epoch: %d/%d, val bpp: %f, val chamfer loss: %f, val normal loss: %f" %
              (epoch+1, args.epochs, val_bpp.get_avg(), val_chamfer_loss.get_avg(), val_normal_loss.get_avg()))

        # save checkpoint
        cur_val_chamfer_loss = val_chamfer_loss.get_avg()
        if  cur_val_chamfer_loss < best_val_chamfer_loss or (epoch+1) % args.save_freq == 0:
            model_name = 'ckpt-best.pth' if cur_val_chamfer_loss < best_val_chamfer_loss else 'ckpt-epoch-%02d.pth' % (epoch+1)
            model_path = os.path.join(checkpoint_dir, model_name)
            torch.save(model.state_dict(), model_path)
            # update best val chamfer loss
            if cur_val_chamfer_loss < best_val_chamfer_loss:
                best_val_chamfer_loss = cur_val_chamfer_loss

def reset_model_args(train_args, model_args):
    for arg in vars(train_args):
        setattr(model_args, arg, getattr(train_args, arg))


In [11]:
class Arguments:
    def __init__(self):
        self.dataset = 'shapenet'
        self.lr = 0.001
        self.aux_lr = 0.001
        self.weight_decay = 0.001
        self.betas = (0.9, 0.999)
        self.lr_decay_step = 20
        self.gamma = 0.5
        self.train_data_path = './data/shapenet/shapenet_train_cube_size_22.pkl'
        self.train_cube_size = 22
        self.val_data_path = './data/shapenet/shapenet_val_cube_size_22.pkl'
        self.val_cube_size = 22
        self.test_data_path = './data/shapenet/shapenet_test_cube_size_22.pkl'
        self.test_cube_size = 22
        self.peak = 30
        self.epochs = 100
        self.batch_size = 32
        self.print_freq = 10
        self.save_freq = 5
        self.output_path = './output'
        self.compress_normal = False
        self.in_fdim = 3
        self.k = 32
        self.downsample_rate = [1/2, 1/2, 1/2]
        self.max_upsample_num = [4, 4, 4]
        self.layer_num = 3
        self.dim = 64
        self.hidden_dim = 64
        self.ngroups = 8
        self.quantize_latent_xyzs = True
        self.latent_xyzs_conv_mode = 'mlp'
        self.sub_point_conv_mode = 'mlp'
        self.chamfer_coe = 1.0
        self.pts_num_coe = 1.0
        self.normal_coe = 1.0
        self.bpp_lambda = 1e-3
        self.mean_distance_coe = 1.0
        self.density_coe = 1.0
        self.latent_xyzs_coe = 1.0
        self.model_path = './output/model.pth'
        self.density_radius = 0.05
        self.dist_coe = 1.0
        self.omega_xyzs = 1.0
        self.omega_normals = 1.0

train_args = Arguments()


In [12]:
train(train_args)

The performance will degrade if batch_size is larger than 1!
Experiment Time: 2023-08-09T10:40:49.171484
Training Arguments: <__main__.Arguments object at 0x7f93a5f34650>
Model Architecture: AutoEncoder(
  (pre_conv): Sequential(
    (0): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
    (1): GroupNorm(8, 64, eps=1e-05, affine=True)
    (2): ReLU()
    (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  )
  (encoder): Encoder(
    (encoder_layers): ModuleList(
      (0): DownsampleLayer(
        (pre_conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
        (feats_agg_nn): PointTransformerLayer(
          (w_qs): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
          (w_ks): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
          (w_vs): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
          (conv_delta): Sequential(
            (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): GroupNorm(8, 64, eps=1e-05, affine=True)
            (2): ReLU(inplace=True)
      