In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from non_rigid.datasets.proc_cloth_flow import DeformablePlacementDataset
from omegaconf import OmegaConf
from pathlib import Path

# Directory passed to the dataset as its root.
dataset_root = Path(
    "/home/beisner/datasets/tax3d_data/proccloth/cloth=single-fixed anchor=single-random hole=single"
)

# Options handed to the dataset as an OmegaConf config.
dataset_options = {
    "train_size": 1,
    "scene": False,
    "sample_size_action": 512,
    "sample_size_anchor": 512,
    "world_frame": False,
    "source": "dedo",
    "anchor_occlusion": False,
    "rotation_variance": 0.0,
    "translation_variance": 0.0,
    "action_transform_type": "identity",
    "anchor_transform_type": "identity",
    "center_type": "anchor_center",
    "action_context_center_type": "center",
    "downsample_type": "fps",
}

dataset = DeformablePlacementDataset(
    root=dataset_root,
    dataset_cfg=OmegaConf.create(dataset_options),
    split="train_tax3d",
)
print(dataset)

In [None]:
# Pull the first sample; later cells inspect its entries and plot them.
data = dataset[0]

In [None]:
# Summarise each entry of the sample: its shape when it has one, otherwise
# the raw value.
for name, entry in data.items():
    summary = entry.shape if hasattr(entry, "shape") else entry
    print(f"{name}: {summary}")

In [None]:
from rpad.visualize_3d.plots import segmentation_fig
import torch 

# The three clouds shown together: the action points, the anchor points, and
# the action points displaced by the ground-truth flow.
action_pts = data["pc_action"]
anchor_pts = data["pc_anchor"]
displaced_pts = data["pc_action"] + data["flow"]

num_action = action_pts.shape[0]
num_anchor = anchor_pts.shape[0]

# One integer label per point, in the same order as the concatenated clouds.
point_labels = torch.cat([
    torch.zeros(num_action),
    torch.ones(num_anchor),
    torch.ones(num_action) * 2,
]).int()

fig = segmentation_fig(
    data=torch.cat([action_pts, anchor_pts, displaced_pts]),
    labels=point_labels,
    labelmap={
        0: "action",
        1: "anchor",
        2: "action + flow",
    },
)
fig

In [None]:
from non_rigid.models.regression import RegressionModule, RegressionNetwork
from non_rigid.models.tax3d import CrossDisplacementModule, DiffusionTransformerNetwork

NUM_TRAINING_STEPS = 50000
REGRESSION = False


def _module_cfg(model_cfg):
    """Build the config object passed to the module as ``cfg``.

    The regression and diffusion branches use the same mode, prediction type
    and training settings; only the ``model`` section differs between them.
    """
    return OmegaConf.create({
        "mode": "train",
        "prediction_type": "flow",
        "model": model_cfg,
        "training": {
            "lr": 1e-4,
            "weight_decay": 1e-5,
            "num_training_steps": NUM_TRAINING_STEPS,
            "lr_warmup_steps": 100,
            "additional_train_logging_period": 1000,
            "batch_size": 1,
            "val_batch_size": 1,
            "sample_size": None,
            "sample_size_anchor": None,
            "num_wta_trials": 10,
        },
    })


if REGRESSION:
    # Regression variant: no learned sigma and no x0 encoder.
    model_cfg = OmegaConf.create({
        "name": "regression",
        "type": "flow",
        "size": "xS",
        "rotary": False,
        "center_noise": False,
        "in_channels": 3,
        "learn_sigma": False,
        "x_encoder": "mlp",
        "y_encoder": "mlp",
        "x0_encoder": None,
        "diff_train_steps": 100,
    })
    network = RegressionNetwork(model_cfg=model_cfg)
    model = RegressionModule(network, cfg=_module_cfg(model_cfg))
else:
    # Diffusion variant: learned sigma, an MLP x0 encoder, and the diffusion
    # schedule settings.
    model_cfg = OmegaConf.create({
        "name": "df_cross",
        "type": "flow",
        "size": "xS",
        "rotary": False,
        "center_noise": False,
        "in_channels": 3,
        "learn_sigma": True,
        "x_encoder": "mlp",
        "y_encoder": "mlp",
        "x0_encoder": "mlp",
        "diff_train_steps": 100,
        "diff_inference_steps": 100,
        "diff_noise_scale": 1,
        "diff_noise_schedule": "linear",
        "diff_type": "gaussian",
    })
    network = DiffusionTransformerNetwork(model_cfg=model_cfg)
    model = CrossDisplacementModule(network, cfg=_module_cfg(model_cfg))


                                

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

# Collate the first sample by hand into a batch of one; the manual training
# loop and the prediction cells reuse this fixed batch.
batch = default_collate([dataset[0]])

# Loader over the same dataset, used by the Lightning trainer.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

device = "cuda"
model = model.to(device)

# Move tensor entries onto the device; anything else is passed through.
batch_on_device = {}
for name, entry in batch.items():
    batch_on_device[name] = entry.to(device) if isinstance(entry, torch.Tensor) else entry
batch = batch_on_device

optimizers, schedulers = model.configure_optimizers()



In [None]:
# Using the trainer...
import lightning as L
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    # precision="16-mixed",
    precision="32-true",
    # Epochs stand in for steps here: presumably the loader yields a single
    # batch per epoch (train_size=1, batch_size=1) -- confirm len(dataset).
    max_epochs=NUM_TRAINING_STEPS,
    logger=False,
    # Validation is switched off via limit_val_batches=0. The previous
    # check_val_every_n_epoch=0 is not a supported "off" value: the epoch
    # check takes a modulo by it, which divides by zero as soon as a
    # validation loader is present.
    limit_val_batches=0,
    # log_every_n_steps=2, # TODO: MOVE THIS TO TRAINING CFG
    log_every_n_steps=1,
    gradient_clip_val=1.0,
)

trainer.fit(model, dataloader)


In [None]:
# Display the module's repr to inspect its structure.
model

In [None]:
from tqdm import tqdm

# Manual optimisation loop on the single fixed batch, as an alternative to the
# Lightning trainer.
#
# This cell may run after trainer.fit(...) or after a prediction cell, so the
# module is put back on the batch's device and into train mode first.
model = model.to(device)
model.train()

losses = []
with tqdm(range(NUM_TRAINING_STEPS)) as pbar:
    for i in pbar:
        optimizers[0].zero_grad()
        loss = model.training_step(batch)
        loss.backward()

        # Read the scalar once; it feeds both the history and the progress bar.
        loss_value = loss.item()
        losses.append(loss_value)

        # Clip gradients (same max norm as gradient_clip_val in the Trainer cell)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizers[0].step()

        # Scheduler step
        schedulers[0].step()

        if i % 10 == 0:
            # Show the plain number rather than the tensor repr.
            pbar.set_description(f"Step {i}, loss: {loss_value:.6g}")
    

In [None]:
# Keep the loss history of the most recent manual training run under a name
# that records whether gradient clipping was used. Exactly one of the two
# lines should be active, and it is switched by hand between runs.
# NOTE(review): the manual loop as written calls clip_grad_norm_, so a run made
# with it unchanged is a *clipped* run -- make sure the active line matches.
# NOTE(review): this binds the same list object as `losses`, not a copy.
# with_clipping_losses = losses
without_clipping_losses = losses

In [None]:
import matplotlib.pyplot as plt

# Compare loss curves with and without gradient clipping.
# Each history only exists once the matching run has been made and stored, so
# look the names up defensively instead of raising NameError when only one of
# the two experiments has been run in this kernel.
loss_histories = {
    "Without clipping": globals().get("without_clipping_losses"),
    "With clipping": globals().get("with_clipping_losses"),
}

num_curves = 0
for curve_label, history in loss_histories.items():
    if history is None:
        continue
    plt.plot(history, label=curve_label)
    num_curves += 1

plt.xlabel("Training step")
plt.ylabel("Loss")
if num_curves:
    # legend() warns when there is nothing labelled to show.
    plt.legend()

plt.show()

In [None]:
# Check which device the module is currently on; the next cell moves it to
# CUDA explicitly before predicting.
model.device

In [None]:
# Make a prediction
# Move the module to CUDA explicitly rather than assuming earlier cells left
# it there, and switch to eval mode for inference.
model.cuda()
model.eval()
with torch.no_grad():
    prediction = model.predict(batch, num_samples=10)

pred_flow = prediction["flow"]["pred"].cpu()
pred_point = prediction["point"]["pred"].cpu()

# Clouds to overlay, split along the first dimension of the prediction
# (presumably one cloud per sample -- confirm against model.predict).
pred_clouds = [pred_point[i] for i in range(pred_point.shape[0])]
# Alternative: rebuild the clouds from the predicted flow instead.
# pred_clouds = [data["pc_action"] + pred_flow[i] for i in range(pred_flow.shape[0])]

fig = segmentation_fig(
    data=torch.cat([
        data["pc_action"],
        data["pc_anchor"],
        data["pc_action"] + data["flow"],
        *pred_clouds,
    ]),
    labels=torch.cat([
        torch.zeros(data["pc_action"].shape[0]),
        torch.ones(data["pc_anchor"].shape[0]),
        torch.ones(data["pc_action"].shape[0]) * 2,
        # Size every label run from the cloud it labels, so the labels always
        # line up with the plotted points even if the predicted clouds differ
        # in count or size from pred_flow / pc_action.
        *[torch.ones(cloud.shape[0]) * (3 + i) for i, cloud in enumerate(pred_clouds)],
    ]).int(),
    labelmap={
        0: "action",
        1: "anchor",
        2: "action + gt flow",
        **{3 + i: f"action + pred flow {i}" for i in range(len(pred_clouds))},
    },
)
fig

In [None]:
# Get Average RMSE of the predictions
# Pure evaluation: run it without autograd tracking, like the predict call in
# the prediction cell.
with torch.no_grad():
    prediction = model.predict_wta(batch, num_samples=10)
print(f"RMSE: {prediction['rmse'].cpu()}")
print(f"RMSE wta: {prediction['rmse_wta'].item()}")

In [None]:
import wandb

from non_rigid.utils.script_utils import create_model, load_checkpoint_config_from_wandb

# Model ID to verify against
model_id = "kr93ivph"

# Local config handed to load_checkpoint_config_from_wandb together with the
# run's entity / project / id.
# NOTE(review): the wandb section says project "non-rigid" while the checkpoint
# reference and the positional argument below use "non_rigid" -- confirm which
# spelling is intended.
checkpoint_run_cfg = OmegaConf.create({
    "mode": "train",
    "wandb": {
        "entity": "r-pad",
        "project": "non-rigid",
        "artifact_dir": "/home/beisner/artifacts",
    },
    "checkpoint": {
        "reference": "r-pad/non_rigid/model-kr93ivph:v0",
    },
    "model": model_cfg,
    "training": {
        "lr": 1e-4,
        "weight_decay": 1e-5,
        "num_training_steps": NUM_TRAINING_STEPS,
        "lr_warmup_steps": 100,
        "additional_train_logging_period": 1000,
        "batch_size": 1,
        "val_batch_size": 1,
        "sample_size": None,
        "sample_size_anchor": None,
        "num_wta_trials": 10,
    },
    "dataset": {
        "data_dir": "/home/beisner/datasets/tax3d_data/proccloth/cloth=single-fixed anchor=single-random hole=single",
        "train_size": 1,
        "scene": False,
        "sample_size_action": 512,
        "sample_size_anchor": 512,
        "world_frame": False,
        "source": "dedo",
        "anchor_occlusion": False,
        "rotation_variance": 0.0,
        "translation_variance": 0.0,
        "action_transform_type": "identity",
        "anchor_transform_type": "identity",
        "center_type": "anchor_center",
        "action_context_center_type": "center",
        "downsample_type": "fps",
    },
})

cfg = load_checkpoint_config_from_wandb(
    checkpoint_run_cfg,
    {},
    "r-pad",
    "non_rigid",
    model_id,
)

network, model = create_model(cfg)


def _drop_first_key_segment(state_dict):
    """Return ``state_dict`` with everything up to the first "." removed from each key."""
    return {name.partition(".")[2]: tensor for name, tensor in state_dict.items()}


# get checkpoint file (for now, this does not log a run)
checkpoint_reference = cfg.checkpoint.reference
if checkpoint_reference.startswith(cfg.wandb.entity):
    # References that start with the entity name are fetched as wandb artifacts.
    api = wandb.Api()
    artifact_dir = cfg.wandb.artifact_dir
    artifact = api.artifact(checkpoint_reference, type="model")
    ckpt_file = artifact.get_path("model.ckpt").download(root=artifact_dir)
else:
    # Anything else is used directly as the checkpoint file path.
    ckpt_file = checkpoint_reference

# Load the network weights.
# NOTE(review): depending on the torch version's weights_only default,
# torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
ckpt = torch.load(ckpt_file, map_location=device)
network.load_state_dict(_drop_first_key_segment(ckpt["state_dict"]))

# set model to eval mode
network.eval()
model.eval()

In [None]:
# Evaluate the checkpoint-loaded model on the same fixed batch.
model = model.to(device)
# Pure evaluation: run it without autograd tracking, like the predict call in
# the prediction cell.
with torch.no_grad():
    ckpt_model_preds = model.predict_wta(batch, num_samples=10)
print(f"RMSE: {ckpt_model_preds['rmse'].cpu()}")
print(f"RMSE wta: {ckpt_model_preds['rmse_wta'].item()}")