### Reload .py files whenever there are changes

In [1]:
%load_ext autoreload
%autoreload 2

### Local imports from this project

In [2]:
from mdf.samplers import *
from mdf.data import *
from mdf.diff import *
from mdf.ops import *
from mdf import *

### External imports from dependencies

In [3]:
from pyvista.plotting.plotter import Plotter
from torch.utils.data import default_collate
from torch.utils.data import RandomSampler
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.distributions import Normal
from torch import Generator
from torch import normal
from torch import Tensor
import torch.nn as nn
import pyvista as pv
import torch.linalg
import typing as t
import numpy as np
import einops
import torch

### Environment setup

In [4]:
# Seed, backend and device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
rng: Generator = torch.manual_seed(42)
pv.set_jupyter_backend('trame')
np.random.seed(42)

# Train Settings
k_evecs = 5
batch_size = 8
num_workers = 8
pin_memory = True
num_samples = 4000
prefetch_factor = 4

### Data Loading

In [5]:
# Handle various types of data and unify them
data: DataManager = DataManager(data_dir=data_dir, cache_dir=cache_dir)

# Fix a manifold and allow sampling function datapoints f : M -> Y
dataset: ManifoldDataset = data.dataset('stanford-bunny', 'weather', n_samples=num_samples, k_evecs=k_evecs, device=device)
min_emb: float = math.sqrt(dataset.sampler.verts.size(0)) * dataset.sampler.lbo_embedder.evecs.min().item()
max_emb: float = math.sqrt(dataset.sampler.verts.size(0)) * dataset.sampler.lbo_embedder.evecs.max().item()

# Split into train & test
train_dataset, valid_dataset, test_dataset = random_split(dataset, lengths=[0.75, 0.15, 0.10], generator=rng)

# Perform batched training by using DataLoaders
train_loader = DataLoader(
    train_dataset,
    drop_last=True,
    batch_size=batch_size,
    pin_memory=pin_memory,
    num_workers=num_workers,
    collate_fn=default_collate,
    prefetch_factor=prefetch_factor,
    sampler=RandomSampler(train_dataset, replacement=False, generator=rng),
)
valid_loader = DataLoader(
    train_dataset,
    shuffle=False,
    drop_last=True,
    pin_memory=pin_memory,
    batch_size=batch_size,
    num_workers=num_workers,
    collate_fn=default_collate,
    prefetch_factor=prefetch_factor,
)
test_loader = DataLoader(
    train_dataset,
    shuffle=False,
    pin_memory=pin_memory,
    batch_size=batch_size,
    num_workers=num_workers,
    collate_fn=default_collate,
    prefetch_factor=prefetch_factor,
)



### Training Algorithm

In [6]:
lr= 1e-4
epochs = 75
timesteps = 1000
schedule = 'linear'
loss_fn = torch.nn.MSELoss()
normal = Normal(torch.tensor(0.), torch.tensor(1.))
sampler = DDPM_Sampler(num_timesteps=timesteps, schedule=schedule)
model = DiffusionModel(
    dim_mlp=2048,
    dim_time=64,
    num_points=num_samples,
    sampler=sampler,
    dim_signal=3,
    n_heads=4,
    k_evecs=k_evecs,
    schedule=schedule,
    dropout=0.1,
    num_timesteps=timesteps,
    num_encoder_layers=6,
).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=lr)
sum(p.numel() for p in model.parameters() if p.requires_grad)

3853651

In [9]:
model = model.train()
train_loss = []

for epoch in range(epochs):
    for batch, samples in enumerate(train_loader):
        # f : M -> Y, where M = pos_embedding and Y = signal
        emb: Tensor = samples['pos_embedding'].to(device)
        sig: Tensor = samples['signal'].to(device)

        # Scale data to [-1, 1]
        emb = 2 * ((emb - min_emb) / (max_emb - min_emb)) - 1
        sig = 2 * sig - 1

        # Split batch in half for context and query
        c_emb, q_emb = torch.chunk(emb, chunks=2, dim=1)
        c_sig, q_sig = torch.chunk(sig, chunks=2, dim=1)

        # Sample timestep and noise vectors for context and query
        t: Tensor = torch.randint(0, timesteps, size=(batch_size,), device=device)

        # Create noisy context & query at timestep t
        c_t, c_sig_z, q_t, q_sig_z = model.encode(c_emb, c_sig, q_emb, q_sig, t)

        # Predict the noise for the score network
        q_sig_z_m = model.noise(c_t, q_t, t)

        # Compute the loss between the real and predicted noise values
        loss: Tensor = loss_fn(q_sig_z_m, q_sig_z)
        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()

        # Track progress
        train_loss.append(loss.detach().cpu().numpy())
        print('[epoch: {}/{}]-[batch: {}/{}]-[loss: {}]'.format(epoch, epochs, batch, len(train_loader), train_loss[-1]))

    if epoch % 10 == 0:
        torch.save(model.state_dict(), '../../res/weight{}s.pt'.format(epoch + 40))

[epoch: 0/75]-[batch: 0/643]-[loss: 0.03193941339850426]
[epoch: 0/75]-[batch: 1/643]-[loss: 0.06879963725805283]
[epoch: 0/75]-[batch: 2/643]-[loss: 0.10785198211669922]
[epoch: 0/75]-[batch: 3/643]-[loss: 0.004457989241927862]
[epoch: 0/75]-[batch: 4/643]-[loss: 0.014005173929035664]
[epoch: 0/75]-[batch: 5/643]-[loss: 0.04817329719662666]
[epoch: 0/75]-[batch: 6/643]-[loss: 0.13354746997356415]
[epoch: 0/75]-[batch: 7/643]-[loss: 0.04752251133322716]
[epoch: 0/75]-[batch: 8/643]-[loss: 0.010667912662029266]
[epoch: 0/75]-[batch: 9/643]-[loss: 0.015980258584022522]
[epoch: 0/75]-[batch: 10/643]-[loss: 0.0811275914311409]
[epoch: 0/75]-[batch: 11/643]-[loss: 0.07009480893611908]
[epoch: 0/75]-[batch: 12/643]-[loss: 0.07133609801530838]
[epoch: 0/75]-[batch: 13/643]-[loss: 0.07956581562757492]
[epoch: 0/75]-[batch: 14/643]-[loss: 0.04597074165940285]
[epoch: 0/75]-[batch: 15/643]-[loss: 0.06465107947587967]
[epoch: 0/75]-[batch: 16/643]-[loss: 0.06893786787986755]
[epoch: 0/75]-[batch:

KeyboardInterrupt: 

### Sampling Algorithm

In [14]:
model.load_state_dict(torch.load('../../res/weight41s.pt'))
model.eval()

DiffusionModel(
  (sampler): DDPM_Sampler()
  (time_mlp): Sequential(
    (0): SinusoidalPosEmb()
    (1): Linear(in_features=64, out_features=256, bias=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=256, out_features=64, bias=True)
  )
  (c_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=72, out_features=72, bias=True)
        )
        (linear1): Linear(in_features=72, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=72, bias=True)
        (norm1): LayerNorm((72,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((72,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (q_encoder): TransformerEncoder(
    

Sample points from manifold

In [16]:
n_samples = 10
emb = []
pts = []

for i in range(n_samples):
    pos, faces, coefs = dataset.sampler.mesh_sampler.sample(n=4000)
    pos_embedding = dataset.sampler.lbo_embedder(faces, coefs)
    emb.append(pos_embedding)
    pts.append(pos)

p_pts = torch.stack(pts)
p_emb = torch.stack(emb).to(device)

Use those points as query and perform inference

In [17]:
with torch.inference_mode():
    q_emb = 2 * ((p_emb - min_emb) / (max_emb - min_emb)) - 1
    q_sig = model.decode(q_emb, subset=q_emb.size(1))

Timestep: 999
Timestep: 998
Timestep: 997
Timestep: 996
Timestep: 995
Timestep: 994
Timestep: 993
Timestep: 992
Timestep: 991
Timestep: 990
Timestep: 989
Timestep: 988
Timestep: 987
Timestep: 986
Timestep: 985
Timestep: 984
Timestep: 983
Timestep: 982
Timestep: 981
Timestep: 980
Timestep: 979
Timestep: 978
Timestep: 977
Timestep: 976
Timestep: 975
Timestep: 974
Timestep: 973
Timestep: 972
Timestep: 971
Timestep: 970
Timestep: 969
Timestep: 968
Timestep: 967
Timestep: 966
Timestep: 965
Timestep: 964
Timestep: 963
Timestep: 962
Timestep: 961
Timestep: 960
Timestep: 959
Timestep: 958
Timestep: 957
Timestep: 956
Timestep: 955
Timestep: 954
Timestep: 953
Timestep: 952
Timestep: 951
Timestep: 950
Timestep: 949
Timestep: 948
Timestep: 947
Timestep: 946
Timestep: 945
Timestep: 944
Timestep: 943
Timestep: 942
Timestep: 941
Timestep: 940
Timestep: 939
Timestep: 938
Timestep: 937
Timestep: 936
Timestep: 935
Timestep: 934
Timestep: 933
Timestep: 932
Timestep: 931
Timestep: 930
Timestep: 929
Timest

Show unconditional function sampling over the manifold

In [18]:
p = Plotter(shape=(1, n_samples))

for i in range(n_samples):
    p.subplot(0, i)
    mesh = pv.PolyData(p_pts[i].numpy())
    mesh.texture_map_to_plane(inplace=True)
    p.add_mesh(mesh.copy(), scalars=0.5 * (1 + q_sig[i]).cpu().numpy(), rgb=True)
    p.camera.tight()

p.show()

Widget(value='<iframe src="http://localhost:35823/index.html?ui=P_0x7a8819fd1050_1&reconnect=auto" class="pyvi…

### Data sample

In [7]:
p = Plotter(shape=(1, 3))
sample = dataset[5225]

# Show positional embedding
p.subplot(0, 0)
mesh = pv.PolyData(sample['pos'].numpy())
mesh.texture_map_to_plane(inplace=True)
p.add_mesh(mesh.copy(), scalars=sample['pos_embedding'][:, 1], clim=[-1, 1], cmap='RdBu_r')
p.camera.tight()

# Show signal
p.subplot(0, 1)
p.add_mesh(mesh.copy(), scalars=sample['signal'], rgb=True)
p.camera.tight()

# Show signal
p.subplot(0, 2)
p.add_mesh(pv.Plane(), texture=pv.Texture(sample['image'].numpy()))
p.camera.tight()

p.show()

Widget(value='<iframe src="http://localhost:40259/index.html?ui=P_0x7774e56cec50_0&reconnect=auto" class="pyvi…