In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tava.models.basic.mlp import StraightMLP
from tava.models.basic.posi_enc import PositionalEncoder, TCNNHashPositionalEncoder


class SkinWeightsNetMLP(nn.Module):
    """Skinning-weight network: sinusoidal positional encoding feeding an MLP.

    Builds two submodules:
      * ``posi_enc`` -- a ``PositionalEncoder`` over 3-D inputs with the raw
        input appended to the encoding.
      * ``net`` -- a ``StraightMLP`` whose input width is taken from
        ``posi_enc.out_dim`` and whose output width is ``n_transforms``.

    NOTE(review): no ``forward`` is defined, so calling this module directly
    hits ``nn.Module``'s unimplemented forward and raises. Presumably callers
    use ``posi_enc`` and ``net`` directly -- confirm.

    Args:
        n_transforms: Output width of the MLP (one output per transform).
        net_depth: ``net_depth`` passed to ``StraightMLP`` (default 4).
        net_width: ``net_width`` passed to ``StraightMLP`` (default 128).
        max_deg: ``max_deg`` passed to ``PositionalEncoder`` (default 4).
    """

    def __init__(
        self,
        n_transforms: int,
        net_depth: int = 4,
        net_width: int = 128,
        max_deg: int = 4,
    ):
        super().__init__()
        # Positional encoding of 3-D points, starting at degree 0; the raw
        # input is appended (append_identity=True).
        self.posi_enc = PositionalEncoder(
            in_dim=3, min_deg=0, max_deg=max_deg, append_identity=True
        )
        # The MLP input width is read from the encoder so that changing
        # max_deg cannot desynchronize the two modules.
        self.net = StraightMLP(
            net_depth=net_depth,
            net_width=net_width,
            input_dim=self.posi_enc.out_dim,
            output_dim=n_transforms,
        )


class SkinWeightsNetNGP(nn.Module):
    """Skinning-weight network: hash-grid positional encoding feeding a small MLP.

    Builds two submodules:
      * ``posi_enc`` -- a ``TCNNHashPositionalEncoder`` over 3-D inputs,
        restricted to ``bounding_box``.
      * ``net`` -- a ``StraightMLP`` whose input width is taken from
        ``posi_enc.out_dim`` and whose output width is ``n_transforms``.

    NOTE(review): no ``forward`` is defined, so calling this module directly
    hits ``nn.Module``'s unimplemented forward and raises. Presumably callers
    use ``posi_enc`` and ``net`` directly -- confirm.

    Args:
        n_transforms: Output width of the MLP (one output per transform).
        bounding_box: Six-number box handed to the encoder. When None (the
            default), a new list built from ``DEFAULT_BOUNDING_BOX`` is used.
        n_levels: ``n_levels`` passed to the encoder (default 4).
        per_level_scale: ``per_level_scale`` passed to the encoder (default 1.5).
        net_depth: ``net_depth`` passed to ``StraightMLP`` (default 1).
        net_width: ``net_width`` passed to ``StraightMLP`` (default 64).
    """

    # NOTE(review): the layout is presumably (x_min, y_min, z_min,
    # x_max, y_max, z_max) in the subject's coordinate units -- confirm
    # against TCNNHashPositionalEncoder. The values look tuned for one
    # specific subject.
    DEFAULT_BOUNDING_BOX = (-0.07, -0.25, -0.05, 0.07, 0.25, 0.40)

    def __init__(
        self,
        n_transforms: int,
        bounding_box=None,
        n_levels: int = 4,
        per_level_scale: float = 1.5,
        net_depth: int = 1,
        net_width: int = 64,
    ):
        super().__init__()
        if bounding_box is None:
            # Build a fresh list (the type the original passed) so the
            # encoder never shares or mutates the class-level default.
            bounding_box = list(self.DEFAULT_BOUNDING_BOX)
        self.posi_enc = TCNNHashPositionalEncoder(
            bounding_box=bounding_box,
            in_dim=3, n_levels=n_levels, per_level_scale=per_level_scale
        )
        # The MLP input width is read from the encoder so the two modules
        # stay in sync when the encoder settings change.
        self.net = StraightMLP(
            net_depth=net_depth,
            net_width=net_width,
            input_dim=self.posi_enc.out_dim,
            output_dim=n_transforms,
        )

In [None]:
from tava.datasets.animal_parser import SubjectParser

# Build the parser for one animal subject, read from the local data root.
parser = SubjectParser(
    "Hare_male_full_RM",
    "./data/animal",
)
