In [1]:
import math
import copy
import json

import numpy as np
import torch
from torch import nn
import wandb

from modules import MLP, PositionalEncoding
from tqdm import tqdm
from utils import load_params, init_variables, get_linear_warmup_cos_annealing, get_dataset, initialize_per_timestep, get_loss, get_frame

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
def _rescale_to_unit_range(values: torch.Tensor) -> torch.Tensor:
    """Linearly rescale every column of ``values`` to the interval [-1, 1].

    The column-wise minimum is subtracted first and the result is divided by
    the column-wise maximum of the shifted values.
    NOTE(review): a constant column yields a division by zero — confirm the
    inputs never contain one.
    """
    shifted = values - values.min(dim=0).values
    return (2. * shifted / shifted.max(dim=0).values) - 1.


def train(seq: str):
    """Train the deformation MLP for sequence ``seq`` and write a preview GIF.

    The function
      1. loads ``./data/{seq}/train_meta.json`` and ``params.pth``,
      2. trains an ``MLP`` that predicts per-point offsets for ``means`` and
         ``rotations`` for each of ``seq_len`` timesteps,
      3. logs the mean loss of every timestep's dataset to wandb, and
      4. renders the canonical frame plus one frame per timestep into
         ``temp_result.gif``.

    Args:
        seq: Name of the sequence; used to locate the metadata file and passed
            through to ``get_dataset``.

    Side effects:
        Logs to the active wandb run and writes ``temp_result.gif`` to the
        current working directory. Requires a CUDA device.
    """
    # `Image` was never imported in the notebook. Importing it before any work
    # starts makes a missing Pillow install fail immediately instead of after
    # the full training run.
    from PIL import Image

    with open(f"./data/{seq}/train_meta.json", 'r') as meta_file:
        md = json.load(meta_file)

    seq_len = 5
    iterations = 500_000
    delta_scale = 0.01  # factor applied to the raw network output before it is added

    params = load_params('params.pth')
    variables = init_variables(params)

    mlp = MLP(100, 128, seq_len, 6).cuda()
    mlp_optimizer = torch.optim.Adam(params=mlp.parameters(), lr=2e-3)
    scheduler = get_linear_warmup_cos_annealing(mlp_optimizer, warmup_iters=10_000, total_iters=iterations)

    pos_mean = PositionalEncoding(L=10)
    pos_smol = PositionalEncoding(L=4)

    # These encoded features are constant network inputs that are reused in
    # every iteration, so they are built without autograd: nothing here is
    # optimised, and no graph has to survive repeated backward passes.
    with torch.no_grad():
        means_norm = pos_mean(_rescale_to_unit_range(params['means']))
        rotations_norm = pos_smol(_rescale_to_unit_range(params['rotations']))
        encoded_features = torch.cat((means_norm, rotations_norm), dim=1)
    num_points = means_norm.shape[0]

    def encode_time(step: int) -> torch.Tensor:
        """Encode timestep ``step`` (1-based) as ``step / seq_len``, one row per point."""
        fraction = torch.tensor(step / seq_len).view(1, 1).repeat(num_points, 1).cuda()
        return pos_smol(fraction)

    def deform(step: int):
        """Return a copy of ``params`` with the MLP offsets for ``step`` applied."""
        raw_features = torch.cat((params['means'], params['rotations']), dim=1)
        delta = mlp(raw_features, encoded_features, encode_time(step))

        deformed = copy.deepcopy(params)
        # The first three output columns offset the means, the rest the rotations.
        # The base values are detached so gradients only reach the MLP output.
        deformed['means'] = deformed['means'].detach() + delta[:, :3] * delta_scale
        deformed['rotations'] = deformed['rotations'].detach() + delta[:, 3:] * delta_scale
        return deformed

    ## Random Training
    # dataset[k] holds the data of timestep k + 1.
    dataset = [get_dataset(step, md, seq) for step in range(1, seq_len + 1)]

    for i in tqdm(range(iterations)):
        p = i / iterations
        alpha = 2. / (1. + math.exp(-6 * p)) - 1

        di = i % seq_len  # timesteps are visited in order, cyclically
        si = torch.randint(0, len(dataset[0]), (1,))

        if di == 0:
            # A new pass over the sequence starts from the canonical parameters.
            variables = initialize_per_timestep(params, variables)

        X = dataset[di][si]
        updated_params = deform(di + 1)

        loss = get_loss(updated_params, X, variables, alpha)

        variables = initialize_per_timestep(updated_params, variables) # sets previous state to updated state

        wandb.log({
            'loss-random': loss.item(),
            'lr': mlp_optimizer.param_groups[0]['lr']
        })

        loss.backward()

        mlp_optimizer.step()
        scheduler.step()
        mlp_optimizer.zero_grad()

    ## Evaluate
    # Each timestep is scored with its own deformation; reusing the parameters
    # left over from the last training iteration would score every timestep
    # against a single timestep's geometry.
    with torch.no_grad():
        for step, views in enumerate(dataset, start=1):
            step_params = deform(step)
            losses = [get_loss(step_params, X, variables, alpha=1.).item() for X in views]

            wandb.log({
                'mean-losses-new': sum(losses) / len(losses)
            })

    ## Visualize
    with torch.no_grad():
        canon_batch = get_dataset(0, md=md, seq=seq)[0]

        # Canonical frame first, then one frame per trained timestep. Each
        # timestep uses the same data/time-code pairing as during training.
        frames = [get_frame(params, canon_batch)]
        for step in range(1, seq_len + 1):
            X = dataset[step - 1][0]
            frames.append(get_frame(deform(step), X))

        images = []
        for frame in frames:
            # Clip to [0, 1], move the first axis last, and convert to 8-bit.
            frame_np = (frame.detach().cpu().clip(min=0.0, max=1.0).permute(1,2,0).numpy()*255).astype(np.uint8)
            images.append(Image.fromarray(frame_np))

        print('Writing image...')
        images[0].save('temp_result.gif', save_all=True, optimize=False, append_images=images[1:], loop=0)


In [4]:
# Credentials are resolved by wandb itself (the WANDB_API_KEY environment
# variable or the entry stored in ~/.netrc), so no API key is kept in the
# notebook. Any key that was ever committed to this file should be revoked.
wandb.login()
wandb.init(
    project="new-dynamic-gaussians",
    entity="myasincifci",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmyasincifci[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/yasin/.netrc


In [5]:
# Start the full training run on the "basketball" sequence.
train(seq='basketball')