In [1]:
import torch

In [2]:
DEVICE = "cpu"

In [3]:
from improved_diffusion.unet import UNetModel
from improved_diffusion.losses import ODEFlowMatchingLoss
from improved_diffusion.functional import ode_euler_integration


model = UNetModel(
    in_channels=198, # should be equal to num_features (input features) 
    dims=1, #this states, that we are using 1D U-Net
    condition_dims=1, # num_condition_features
    model_channels=256, # inner model features
    out_channels=198, # should be equal to num_features (input features) 
    num_res_blocks=10, # idk
    attention_resolutions=("16",) # idk
)

x_0 = torch.rand(23, 64, 198) # our input [batch_size, num_atoms, num_features]
#num_atoms should be a 2 to some power
t = torch.rand(23) # our time [batch_size]
y = torch.rand(23, 1) * 10 # features to condition on [batch_size, num_condition_features]

model(
    x=x_0, 
    timesteps=t, 
    y=y
).shape # torch.Size([23, 64, 198]), which matches x.shape torch.Size([23, 64, 198])

torch.Size([23, 64, 198])

In [4]:
loss = ODEFlowMatchingLoss(reduction="mean")

In [5]:
x_1 = torch.rand(23, 64, 198) # our input [batch_size, num_atoms, num_features]

loss(model, x_0, x_1, t, y)

tensor(0.1671, grad_fn=<MseLossBackward0>)

In [None]:
x_pred = ode_euler_integration(model, x_0, y, DEVICE)

In [None]:
x_pred.shape

In [1]:
from improved_diffusion.unet import UNetRegressor

model = UNetRegressor(
    in_channels=3, # should be equal to num_features (input features)
    out_dims=1, # regressor dims
    dims=1, #this states, that we are using 1D U-Net
    condition_dims=256,
    model_channels=128, # inner model features
    num_res_blocks=10, # idk
    attention_resolutions=("16",) # idk
)

model.eval()

UNetRegressor(
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
  )
  (label_emb): Linear(in_features=256, out_features=512, bias=True)
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv1d(3, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    )
    (1-10): 10 x TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
        )
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=512, out_features=128, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0, inplace=False)
          (3): Conv1d(128, 128, kernel_size=(3

In [2]:
import torch

x_cond = torch.zeros(23, 256)
x_coords = torch.nn.Parameter(torch.rand((23, 64, 3), requires_grad=True))

In [None]:
# Optimization 

optimizer = torch.optim.SGD([{'params': x_coords, 'lr':1e-3}])

for epoch in range(10):
    loss = model(x_coords, y=x_cond).mean()
    print(loss.shape)
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(1)

torch.Size([23, 8192])
torch.Size([])
1
torch.Size([23, 8192])
torch.Size([])
1
torch.Size([23, 8192])
torch.Size([])
1
torch.Size([23, 8192])
torch.Size([])
1
torch.Size([23, 8192])
torch.Size([])
1
torch.Size([23, 8192])
torch.Size([])
1
torch.Size([23, 8192])
torch.Size([])
