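# Playground

Scratch checks for the `uda` models. The notebook builds the vanilla and UDA U-Net configurations and the UDA VAE configuration. For each one it prints the config and its trainable-parameter count. It then pushes a random batch of 128³ single-channel volumes through the model to confirm that the output shape matches the input shape. Note: the VAE cells were executed in a separate kernel session (the execution count restarts at `In [1]`), which is why they re-import their dependencies.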

In [1]:
import torch
from uda import UNet
from uda.models import vanilla_unet, uda_unet
from pprint import pprint

In [2]:
# Vanilla U-Net preset. The two positional 1s presumably mean input/output
# channels (both are 1 here); dim=3 selects the volumetric variant.
config = vanilla_unet(1, 1, dim=3)
model = UNet(config)

# Dump every attribute stored on the config object.
pprint(vars(config))

# Only count trainable tensors (those with requires_grad set).
n_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f"# parameters: {n_params:,}")

{'concat_hidden': True,
 'decoder_blocks': [[1024, 512, 512],
                    [512, 256, 256],
                    [256, 128, 128],
                    [128, 64, 64]],
 'dim': 3,
 'encoder_blocks': [[1, 64, 64],
                    [64, 128, 128],
                    [128, 256, 256],
                    [256, 512, 512],
                    [512, 1024, 1024]],
 'out_channels': 1,
 'use_pooling': True}
# parameters: 90,306,113


In [3]:
# U-Net built from the `uda_unet` preset, with the same positional arguments
# as the vanilla cell so the two parameter counts are directly comparable.
config = uda_unet(1, 1, dim=3)
model = UNet(config)

# Dump every attribute stored on the config object.
pprint(vars(config))

# Only count trainable tensors (those with requires_grad set).
n_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f"# parameters: {n_params:,}")

{'concat_hidden': False,
 'decoder_blocks': [[128, 64, 64, 64],
                    [64, 32, 32, 32],
                    [32, 16, 16, 16],
                    [16, 8, 8, 8]],
 'dim': 3,
 'encoder_blocks': [[1, 8],
                    [8, 16, 16, 16],
                    [16, 32, 32, 32],
                    [32, 64, 64, 64],
                    [64, 128, 128, 128]],
 'out_channels': 1,
 'use_pooling': False}
# parameters: 2,277,977


In [4]:
# Smoke test: run a random batch through the U-Net and check that the output
# has the same shape as the input.
torch.manual_seed(0)  # make the random input reproducible across re-runs

# Fall back to CPU so the cell still runs on a kernel without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Presumably (batch, channels, D, H, W) — TODO confirm against the UNet docs.
x = torch.randn(2, 1, 128, 128, 128)

# Inference only: without no_grad, autograd would keep every intermediate
# 3-D activation alive for a backward pass that never happens.
with torch.no_grad():
    x_ = model(x.to(device)).cpu()

assert x_.shape == x.shape, f"output shape {tuple(x_.shape)} != input shape {tuple(x.shape)}"
x_.shape

torch.Size([2, 1, 128, 128, 128])

In [1]:
import torch
from uda import VAE
from uda.models import uda_vae
from pprint import pprint

In [3]:
# VAE preset for 128³ volumes; the spatial size ends up as `input_size` on the
# config, as the printed dict shows.
config = uda_vae((128, 128, 128), 1, dim=3)
model = VAE(config)

# Dump every attribute stored on the config object.
pprint(vars(config))

# Only count trainable tensors (those with requires_grad set).
n_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f"# parameters: {n_params:,}")

{'decoder_blocks': [[256, 128, 128, 128],
                    [128, 64, 64, 64],
                    [64, 32, 32, 32],
                    [32, 16, 16, 16],
                    [16, 8, 8, 8]],
 'dim': 3,
 'encoder_blocks': [[1, 8],
                    [8, 16, 16, 16],
                    [16, 32, 32, 32],
                    [32, 64, 64, 64],
                    [64, 128, 128, 128],
                    [128, 256, 256, 256]],
 'input_size': (128, 128, 128),
 'latent_dim': 1024,
 'use_pooling': False}
# parameters: 59,480,409


In [4]:
# Smoke test: run a random batch through the VAE and check that the output
# has the same shape as the input.
torch.manual_seed(0)  # make the random input reproducible across re-runs

# Fall back to CPU so the cell still runs on a kernel without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Must match the `input_size` the VAE config was built with (128, 128, 128).
x = torch.randn(2, 1, 128, 128, 128)

# Inference only: without no_grad, autograd would keep every intermediate
# 3-D activation alive for a backward pass that never happens.
with torch.no_grad():
    x_ = model(x.to(device)).cpu()

assert x_.shape == x.shape, f"output shape {tuple(x_.shape)} != input shape {tuple(x.shape)}"
x_.shape

torch.Size([2, 1, 128, 128, 128])