In [1]:
import json
import numpy as np
import torch
import os, sys, shutil

from tensorboardX import SummaryWriter
from time import sleep
from torch import optim
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils import data
from tqdm import tqdm
import torch.distributions as D

import data as data_
import nn as nn_
import utils
from mfp_utils import *

from experiments import cutils
from nde import distributions, flows, transforms

import klampt
from klampt.plan import cspace, robotplanning
from klampt.plan.robotcspace import RobotCSpace
from klampt.model import collide
from klampt.model.trajectory import RobotTrajectory
from klampt.io import resource

from mfg_models import construct_Q

In [2]:
import argparse
import sys
parser = argparse.ArgumentParser()
# 模拟命令行输入
sys.argv = ['--exp_name 2D --linear_transform_type lu_no_perm --reg_OT_dir gen --num_training_steps 1000 --base_transform_type rq-coupling --tail_bound 5 --OT_part block_CL_no_perm --lbd_OT 2e-1 --gaussian_multi_dim 2 --num_train_data 2000 --LU_last false --NF_loss jeffery --lr_schedule adaptive --train_batch_size 2048 --learning_rate 1e-3']
# data
parser.add_argument('--exp_name', type=str, default='2D')
parser.add_argument('--dataset_name', type=str, default='gaussian_mixture',
                    choices=['gaussian_mixture', 'crowd_motion_gaussian', 'crowd_motion_gaussian_close'
                            'crowd_motion_gaussian_nonsmooth_obs', 'moons', 
                            'gaussian', '2spirals', 'checkerboard',
                            'power', 'gas', 'hepmass', 'miniboone', 'bsds300',
                            'robot_1'],
                    help='Name of dataset to use.')
parser.add_argument('--val_frac', type=float, default=1.,
                    help='Fraction of validation set to use.')

# optimization
parser.add_argument('--learning_rate', type=float, default=1e-3,
                    help='Learning rate for optimizer.')
parser.add_argument('--num_training_steps', type=int, default=1000,
                    help='Number of total training steps.')
# parser.add_argument('--anneal_learning_rate', type=int, default=1,
#                     choices=[0, 1],
#                     help='Whether to anneal the learning rate.')
parser.add_argument('--lr_schedule', type=str, default='adaptive',
                    choices=['none', 'cyclic', 'adaptive'])
parser.add_argument('--grad_norm_clip_value', type=float, default=5.,
                    help='Value by which to clip norm of gradients.')
parser.add_argument('--lbd_reg', type=float, default=0)


# flow details
parser.add_argument('--base_transform_type', type=str, default='rq-coupling',
                    choices=['affine-coupling', 'quadratic-coupling', 'rq-coupling',
                             'affine-autoregressive', 'quadratic-autoregressive',
                             'rq-autoregressive'],
                    help='Type of transform to use between linear layers.')
parser.add_argument('--linear_transform_type', type=str, default='lu_no_perm',
                    choices=['permutation', 'lu', 'svd', 'lu_no_perm'],
                    help='Type of linear transform to use.')
parser.add_argument('--num_flow_steps', type=int, default=7,
                    help='Number of blocks to use in flow.')
parser.add_argument('--hidden_features', type=int, default=256,
                    help='Number of hidden features to use in coupling/autoregressive nets.')
parser.add_argument('--tail_bound', type=float, default=5,
                    help='Box is on [-bound, bound]^2')
parser.add_argument('--num_bins', type=int, default=8,
                    help='Number of bins to use for piecewise transforms.')
parser.add_argument('--num_transform_blocks', type=int, default=2,
                    help='Number of blocks to use in coupling/autoregressive nets.')
parser.add_argument('--use_batch_norm', type=int, default=0,
                    choices=[0, 1],
                    help='Whether to use batch norm in coupling/autoregressive nets.')
parser.add_argument('--dropout_probability', type=float, default=0.25,
                    help='Dropout probability for coupling/autoregressive nets.')
parser.add_argument('--apply_unconditional_transform', type=int, default=1,
                    choices=[0, 1],
                    help='Whether to unconditionally transform \'identity\' '
                         'features in coupling layer.')
parser.add_argument('--base_net_act', type=str, default='relu',
                    choices=['relu', 'tanh'])

# logging and checkpoints
parser.add_argument('--monitor_interval', type=int, default=250,
                    help='Interval in steps at which to report training stats.')

# reproducibility
parser.add_argument('--seed', type=int, default=1638128,
                    help='Random seed for PyTorch and NumPy.')

# MFG
parser.add_argument('--gaussian_multi_dim', type=int, default=2)
parser.add_argument('--gaussian_multi_a',   type=float, default=10.)
parser.add_argument('--num_train_data',     type=int, default=2000)
parser.add_argument('--num_val_data',       type=int, default=10000)
parser.add_argument('--num_test_data',      type=int, default=10000)
parser.add_argument('--train_batch_size',   type=int, default=2048)
parser.add_argument('--val_batch_size',     type=int, default=512)
parser.add_argument('--test_batch_size',    type=int, default=512)
parser.add_argument('--lbd_OT',             type=float, default=2e-1)
parser.add_argument('--lbd_F',              type=float, default=0)
parser.add_argument('--lbd_F_E',            type=float, default=0.01)
parser.add_argument('--lbd_F_P',            type=float, default=1)
parser.add_argument('--reg_OT_dir',         type=str, default='gen', choices=['gen', 'norm'])
parser.add_argument('--OT_comp',            type=str, default='trajectory', choices=['trajectory', 'monge'])
parser.add_argument('--OT_part',            type=str, default='block_CL_no_perm', choices=['block', 'block_CL_no_perm', 'module'])
parser.add_argument('--interaction',        type=lambda x: (str(x).lower() == 'true'), default=False)
parser.add_argument('--LU_last',            type=lambda x: (str(x).lower() == 'true'), default=False)
parser.add_argument('--NF_loss',            type=str, default='jeffery', choices=[
                                            'KL_sampling', 'KL_density', 'jeffery'])
parser.add_argument('--val_score',          type=str, default='loss', choices=[
                                            'loss', 'L', 'G', 'F'])
parser.add_argument('--mixture_base',       type=str, default='gaussian', choices=[
                                            'gaussian', 'gaussian_mixture'])
parser.add_argument('--mixture_weight',     type=str, default='identical', choices=[
                                            'identical', 'undersample_one'])  
parser.add_argument('--F_ld_weight',        type=str, default='identical', choices=['identical'])                                                                          
parser.add_argument('--disc_scheme',        type=str, default='forward', choices=[
                                            'forward', 'centered', 'forward_2nd',
                                            'FD4_simp', 'FD1_simp', 'FD4_simp_symmetric'])
parser.add_argument('--NF_model',           type=str, default='default', choices=[
                                            'default', 'single_flow'])                                     
parser.add_argument('--obs_nonsmooth_val',  type=float, default=100.)
parser.add_argument('--interp_hist',        type=lambda x: (str(x).lower() == 'true'), default=False)
parser.add_argument('--n_interp',           type=int, default=5, 
                    help='Number of interpolated points inserted between flow points to better approximate MFG costs')
## robotics
parser.add_argument('--robot_init_pos',     type=str, default='default', choices=['default', 
                                            'under_table', 'under_table_2', 'under_table_3', 'under_table_4',
                                            'under_table_hard']) 
parser.add_argument('--robot_term_pos',     type=str, default='cup', choices=['cup']) 
parser.add_argument('--robot_var',          type=float, default=1e-5)
parser.add_argument('--robot_obs_val',      type=float, default=1e2)
parser.add_argument('--robot_1_obs',        type=str, default='thick_sigmoid_B=256')
parser.add_argument('--robot_1_obs_l',      type=int, default=3)
parser.add_argument('--robot_1_obs_act',    type=str, default='relu', choices=['relu', 'tanh'])
parser.add_argument('--robot_1_base_dist',  type=str, default='default', choices=['default', 'two_init'])
parser.add_argument('--obs_robot_sig',      type=lambda x: (str(x).lower() == 'true'), default=True) 

# misc.
parser.add_argument('--plotting_subset',    type=int, default=10000)
parser.add_argument('--load_best_val',      type=lambda x: (str(x).lower() == 'true'), default=False)
parser.add_argument('--compute_lip_bound',  type=lambda x: (str(x).lower() == 'true'), default=False)
parser.add_argument('--save_train_traj',    type=lambda x: (str(x).lower() == 'true'), default=False)
parser.add_argument('--syn_noise',          type=float, default=0.1)
parser.add_argument('--marker_size',        type=float, default=5)
parser.add_argument('--color',              type=str, default='order', choices=[
                                            'order', 'radius'])
parser.add_argument('--tabular_subset',     type=lambda x: (str(x).lower() == 'true'), default=False)
parser.add_argument('--tensor_type',        type=str, default='float', choices=['float', 'double']) 

args = parser.parse_args()
args = sanitize_args(args)

In [3]:
# =================================================================================== #
#                                       Meta                                          #
# =================================================================================== #

os.environ['DATAROOT'] = 'experiments/dataset/data/'
os.environ['SLURM_JOB_ID'] = '1'

torch.manual_seed(args.seed)
np.random.seed(args.seed)

assert torch.cuda.is_available()
device = torch.device('cuda')
torch.set_default_tensor_type(torch.cuda.FloatTensor)

  _C._set_default_tensor_type(t)


In [4]:
import torch

# 检查CUDA是否可用，然后选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Current device:", device)

import torch

# 假设 'cuda' 可用
if torch.cuda.is_available():
    device = torch.device("cuda")
    generator = torch.Generator(device=device)
else:
    device = torch.device("cpu")
    generator = torch.Generator(device=device)

# 使用这个生成器与 torch.randperm
rand_indices = torch.randperm(1, generator=generator)


Current device: cuda


In [5]:
# =================================================================================== #
#                                       Dataset                                       #
# =================================================================================== #
target_dist = None
space = None
num_mixtures = 8
weight = D.Categorical(torch.ones(num_mixtures,).to(device))
X_train, _, _, train_loader, val_loader, test_loader, target_dist = make_gaussian_mixture_data(args.mixture_base, args.gaussian_multi_dim, args.num_train_data, \
    args.num_val_data, args.num_test_data, args.train_batch_size, args.val_batch_size, args.test_batch_size, weight=weight)
train_generator = data_.batch_generator(train_loader)
train_loader
test_batch      = next(iter(train_loader)).to(device)
features        = args.gaussian_multi_dim

In [6]:
# =================================================================================== #
#                                       Model                                         #
# =================================================================================== #


# methods for chaining together the flow transformations
def create_linear_transform():
    return transforms.LULinear(features, identity_init=True)


def create_base_transform(i):
    act = F.relu


    return transforms.PiecewiseRationalQuadraticCouplingTransform(
        mask=utils.create_alternating_binary_mask(features, even=(i % 2 == 0)),
        transform_net_create_fn=lambda in_features, out_features: nn_.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=args.hidden_features,
            context_features=None,
            num_blocks=args.num_transform_blocks,
            activation=act,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm
        ),
        num_bins=args.num_bins,
        tails='linear',
        tail_bound=args.tail_bound,
        apply_unconditional_transform=args.apply_unconditional_transform
    )

def create_transform():
    flows = [
        transforms.CompositeTransform([
            create_linear_transform(),
            create_base_transform(i),
            create_base_transform(i+1)
        ]) for i in range(0, 2*args.num_flow_steps, 2)
    ]

    # flows = [
    #     transforms.CompositeTransform([
    #         create_linear_transform(),
    #         create_base_transform(1)
    #     ])
    # ]

    print((flows[0]))
    K = args.num_flow_steps
    transform = transforms.CompositeTransform(flows)

    return transform, K

# base dist

cov          = 0.0625 * torch.eye(features).to(device)
mean         = torch.zeros(features).to(device)
distribution = distributions.MultivarNormal((features,), mean=mean, cov=cov)

# create flows
transform, K = create_transform()
flow = flows.Flow(transform, distribution).to(device)
n_params = utils.get_num_parameters(flow)
print('There are {} trainable parameters in this model.'.format(n_params))

# create optimizer
optimizer = optim.Adam(flow.parameters(), lr=args.learning_rate, weight_decay=args.lbd_reg)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, verbose=True)



CompositeTransform(
  (_transforms): ModuleList(
    (0): LULinear()
    (1-2): 2 x PiecewiseRationalQuadraticCouplingTransform(
      (transform_net): ResidualNet(
        (initial_layer): Linear(in_features=1, out_features=256, bias=True)
        (blocks): ModuleList(
          (0-1): 2 x ResidualBlock(
            (linear_layers): ModuleList(
              (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)
            )
            (dropout): Dropout(p=0.25, inplace=False)
          )
        )
        (final_layer): Linear(in_features=256, out_features=23, bias=True)
      )
      (unconditional_transform): PiecewiseRationalQuadraticCDF()
    )
  )
)
There are 3774638 trainable parameters in this model.


In [7]:
from torch import nn
# KKKKKKKK = nn.ModuleList(transforms)
len(list(transform._transforms)[::-1])
print(list(transform._transforms))

[CompositeTransform(
  (_transforms): ModuleList(
    (0): LULinear()
    (1-2): 2 x PiecewiseRationalQuadraticCouplingTransform(
      (transform_net): ResidualNet(
        (initial_layer): Linear(in_features=1, out_features=256, bias=True)
        (blocks): ModuleList(
          (0-1): 2 x ResidualBlock(
            (linear_layers): ModuleList(
              (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)
            )
            (dropout): Dropout(p=0.25, inplace=False)
          )
        )
        (final_layer): Linear(in_features=256, out_features=23, bias=True)
      )
      (unconditional_transform): PiecewiseRationalQuadraticCDF()
    )
  )
), CompositeTransform(
  (_transforms): ModuleList(
    (0): LULinear()
    (1-2): 2 x PiecewiseRationalQuadraticCouplingTransform(
      (transform_net): ResidualNet(
        (initial_layer): Linear(in_features=1, out_features=256, bias=True)
        (blocks): ModuleList(
          (0-1): 2 x ResidualBlock(
            (li

In [8]:
print(K)
flow
# """
# Flow 类是一个包含变换和基础分布的模型。
# 它通常用于定义一个可逆的变换链（CompositeTransform），
# 以及一个简单的概率分布（_distribution），在这个例子中是多变量正态分布（MultivarNormal()）
# """

# """
# CompositeTransform：这是一个复合变换，它将多个变换组合在一起。
# 在此例子中，它包含了一个模块列表（ModuleList），其中每个模块也是一个 CompositeTransform。

# ModuleList (0-9)：ModuleList 包含了10个 CompositeTransform 对象，这表明模型中有10个相似的流程步骤或层，每个都执行相似的变换。
# """

# """ 
# 每个 CompositeTransform 包括：

# LULinear()：这是一个线性变换，通常用于进行可逆的线性操作，如LU分解。

# PiecewiseRationalQuadraticCouplingTransform (1-2)：包括两个此类型的变换，这类变换通常用于在保持一部分输入不变的同时，通过一种复杂的函数（在这里是由 ResidualNet 实现）修改输入的其余部分。

#     transform_net：一个 ResidualNet 网络，用于根据输入特征生成变换的参数。这个网络包括：
#         initial_layer：输入层，一个线性变换，将输入特征从1维扩展到256维。
#         blocks：包含多个 ResidualBlock，每个块包括两个线性层和一个Dropout层，用于增加模型的非线性和防止过拟合。
#         final_layer：输出层，另一个线性变换，将特征从256维压缩到23维。
#     unconditional_transform：一个 PiecewiseRationalQuadraticCDF，用于在变换中实现无条件的、复杂的可逆函数。
# """

7


Flow(
  (_transform): CompositeTransform(
    (_transforms): ModuleList(
      (0-6): 7 x CompositeTransform(
        (_transforms): ModuleList(
          (0): LULinear()
          (1-2): 2 x PiecewiseRationalQuadraticCouplingTransform(
            (transform_net): ResidualNet(
              (initial_layer): Linear(in_features=1, out_features=256, bias=True)
              (blocks): ModuleList(
                (0-1): 2 x ResidualBlock(
                  (linear_layers): ModuleList(
                    (0-1): 2 x Linear(in_features=256, out_features=256, bias=True)
                  )
                  (dropout): Dropout(p=0.25, inplace=False)
                )
              )
              (final_layer): Linear(in_features=256, out_features=23, bias=True)
            )
            (unconditional_transform): PiecewiseRationalQuadraticCDF()
          )
        )
      )
    )
  )
  (_distribution): MultivarNormal()
)

In [9]:
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import numpy as np


# 设置图形
def setup_figure(rows, cols, figsize=(18, 6)):
    fig, axs = plt.subplots(rows, cols, figsize=figsize)
    return fig, axs

# 绘制矢量场
def plot_quiver(ax, data, last_data, title):
    U = data[:, 0] - last_data[:, 0]
    V = data[:, 1] - last_data[:, 1]
    ax.quiver(last_data[:, 0], last_data[:, 1], U, V, scale=10, scale_units='width', color='b', width=0.005)
    ax.set_title(title)
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)

# 绘制散点图
def plot_scatter(ax, data, title):
    ax.plot(data[:, 0], data[:, 1], '.')
    ax.set_title(title)
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)

# 绘制密度图
def plot_density(ax, data, title):
    sns.kdeplot(x=data[:, 0], y=data[:, 1], fill=True, thresh=0, levels=100, cmap="viridis", ax=ax)
    ax.set_title(title)
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)

# 主绘图函数
def main_plot(part, z_0):
    fig1, ax1 = setup_figure(2, 5)
    fig2, ax2 = setup_figure(2, 5)
    fig3, ax3 = setup_figure(2, 5)

    last_x = z_0.data.cpu().numpy()
    for i in range(8):
        x_plot = part[:, i, :].data.cpu().numpy()
        plot_quiver(ax1.flatten()[i], x_plot, last_x, f'Time step {i}')
        plot_scatter(ax2.flatten()[i], x_plot, f'Time step {i}')
        plot_density(ax3.flatten()[i], x_plot, f'Time step {i}')
        last_x = x_plot

    for fig in [fig1, fig2, fig3]:
        fig.tight_layout()
        fig.show()


# z_K, ld_gen, OT_cost_gen, hist_gen, hist_ld_gen, z_0 = flow.sample(args.train_batch_size)
# hist = hist_gen
# part = torch.cat(hist).reshape(len(hist), hist[0].shape[0], hist[0].shape[-1]).permute(1,0,2)
# K = part.shape[1]
# I = torch.arange(0, K, step=4)
# part = part[:,I,:]

# main_plot(part, z_0)

In [10]:
# z_K, ld_gen, OT_cost_gen, hist_gen, hist_ld_gen, z_0 = flow.sample(args.train_batch_size)
# data = hist_gen[28]
# data.requires_grad_(True)

# for N in range(1,30):
#     # 初始化结果为 logprob
#     log_density = torch.zeros_like(hist_ld_gen[0])

#     # 累加操作
#     for idx in range(N):
#         log_density += hist_ld_gen[idx] # 用 squeeze() 去除单一维度

#     log_density = torch.sum(log_density)

#     # log_density = torch.sum(torch.sum(torch.stack(hist_ld_norm[:N]), dim=0))
#     # # return log_prob + logabsdet, log_prob, logabsdet, hist, hist_ld, OT_cost, noise
#     grad = torch.autograd.grad(outputs=log_density, inputs=data, create_graph=True, grad_outputs=torch.ones_like(log_density), allow_unused=True)[0]
#     if grad is not None:
#         print('????????????', N)


In [88]:
# data = points
z_K, ld_gen, OT_cost_gen, hist_gen, hist_ld_gen, z_0 = flow.sample(args.train_batch_size)
data = z_K
print(data.shape)
# data.requires_grad_(True)
_, log_prob, _, hist_norm, hist_ld_norm, OT_cost_norm, noise = flow.log_prob(data)
# print("????????????", (hist_ld_norm[2]).requires_grad )
i = 15
print(len(hist_gen), hist_gen[0].shape, len(hist_norm), hist_norm[0].shape)
print(len(hist_ld_gen), hist_ld_gen[0].shape, len(hist_ld_norm), hist_ld_norm[0].shape)
print('sample', torch.sum(hist_gen[-1]-z_K), torch.sum(hist_gen[0]-z_0))
print('inverse', torch.sum(hist_norm[-1]-noise), torch.sum(hist_norm[0]-z_K))
for i in range(0, 29, 1):
    # print('noise, i',i , torch.sum(hist_gen[i]-hist_norm[28-i]))
    # print('log, i',i , torch.sum(hist_ld_gen[28-i]), torch.sum(hist_ld_norm[0+i]))
    print('log, i',i , torch.sum(hist_ld_gen[i]), torch.sum(hist_ld_norm[28-i]))
    # print(i, torch.mean(torch.sum(torch.stack(hist_ld_gen[-i:]), dim=0) + torch.sum(torch.stack(hist_ld_norm[:i]), dim=0)))
print('noise',torch.mean(noise-hist_norm[-1]))

# print(log_prob.shape,hist_ld_norm[0].shape)


data = hist_norm[4]
# data.requires_grad_(True)
print(hist_ld_gen.requires_grad)
for N in range(1,30):
    # 初始化结果为 logprob
    # log_density = torch.zeros_like(log_prob).requires_grad_(True)
    
    # # 累加操作
    # for idx in range(N):
    #     if hist_ld_norm[idx].requires_grad:
    #         print(idx)
    #         log_density += hist_ld_norm[idx].squeeze()  # 用 squeeze() 去除单一维度

    log_density = torch.mean(torch.sum(torch.stack(hist_ld_gen[:28]), dim=0))
    
    # log_density = torch.sum(torch.sum(torch.stack(hist_ld_norm[:N]), dim=0))
    # # return log_prob + logabsdet, log_prob, logabsdet, hist, hist_ld, OT_cost, noise
    # print('????????????', N, torch.sum(hist_ld_norm[N]))
    grad = torch.autograd.grad(outputs=log_density, inputs=data, create_graph=True, grad_outputs=torch.ones_like(log_density), allow_unused=True)[0]
    if grad is not None:
        print('????????????', N)

torch.Size([2048, 2])
29 torch.Size([2048, 2]) 29 torch.Size([2048, 2])
29 torch.Size([2048, 1]) 29 torch.Size([2048, 1])
sample tensor(0., grad_fn=<SumBackward0>) tensor(0.)
inverse tensor(0., grad_fn=<SumBackward0>) tensor(0., grad_fn=<SumBackward0>)
log, i 0 tensor(0.) tensor(318.8028, grad_fn=<SumBackward0>)
log, i 1 tensor(0.) tensor(941.0545, grad_fn=<SumBackward0>)
log, i 2 tensor(-318.6759, grad_fn=<SumBackward0>) tensor(0., grad_fn=<SumBackward0>)
log, i 3 tensor(-941.1310, grad_fn=<SumBackward0>) tensor(0.)
log, i 4 tensor(0., grad_fn=<SumBackward0>) tensor(-35.8481, grad_fn=<SumBackward0>)
log, i 5 tensor(0.) tensor(217.1418, grad_fn=<SumBackward0>)
log, i 6 tensor(35.5407, grad_fn=<SumBackward0>) tensor(0., grad_fn=<SumBackward0>)
log, i 7 tensor(-217.0289, grad_fn=<SumBackward0>) tensor(0.)
log, i 8 tensor(0., grad_fn=<SumBackward0>) tensor(-56.9320, grad_fn=<SumBackward0>)
log, i 9 tensor(0.) tensor(-459.4680, grad_fn=<SumBackward0>)
log, i 10 tensor(56.9140, grad_fn=<Sum

AttributeError: 'list' object has no attribute 'requires_grad'

In [54]:
z_K, ld_gen, OT_cost_gen, hist_gen, hist_ld_gen, z_0 = flow.sample(args.train_batch_size)
# 假设 hist_ld_norm 是一个包含张量的列表
for i, tensor in enumerate(hist_gen):
    print(i, tensor.requires_grad)
    # 只有当张量需要梯度时才尝试计算梯度
    if tensor.requires_grad:
        # 确保对应的 hist_norm[i] 也是张量，并且已经设置了 requires_grad
        for j in range(29):
            if hist_gen[j].requires_grad:  # 检查 hist_gen[j] 是否可以计算梯度
                grad = torch.autograd.grad(outputs=tensor, inputs=hist_gen[j], create_graph=True, grad_outputs=torch.ones_like(tensor), allow_unused=True)[0]
                if grad is not None:
                    print('Gradient available for index:', i, j, torch.sum(grad))
    # else:
    #     print('Tensor at index', i, 'does not require grad or does not have a grad_fn')

0 False
1 False
2 True
Gradient available for index: 2 2 tensor(4096.)
3 True
Gradient available for index: 3 2 tensor(3361.1338, grad_fn=<SumBackward0>)
Gradient available for index: 3 3 tensor(4096.)
4 True
Gradient available for index: 4 2 tensor(3361.1338, grad_fn=<SumBackward0>)
Gradient available for index: 4 3 tensor(4096., grad_fn=<SumBackward0>)
Gradient available for index: 4 4 tensor(4096.)
Gradient available for index: 4 5 tensor(4096.)
5 True
Gradient available for index: 5 2 tensor(3361.1338, grad_fn=<SumBackward0>)
Gradient available for index: 5 3 tensor(4096., grad_fn=<SumBackward0>)
Gradient available for index: 5 4 tensor(4096.)
Gradient available for index: 5 5 tensor(4096.)
6 True
Gradient available for index: 6 2 tensor(3222.2886, grad_fn=<SumBackward0>)
Gradient available for index: 6 3 tensor(4020.2825, grad_fn=<SumBackward0>)
Gradient available for index: 6 4 tensor(4020.2825, grad_fn=<SumBackward0>)
Gradient available for index: 6 5 tensor(4020.2825, grad_fn=<

In [81]:
z_K, ld_gen, OT_cost_gen, hist_gen, hist_ld_gen, z_0 = flow.sample(args.train_batch_size)
for i in range(28):
    print(i, i+1, torch.sum(hist_gen[i]-hist_gen[i+1].squeeze()))

0 1 tensor(0.)
1 2 tensor(-1872.8127, grad_fn=<SumBackward0>)
2 3 tensor(2225.5869, grad_fn=<SumBackward0>)
3 4 tensor(0., grad_fn=<SumBackward0>)
4 5 tensor(0., grad_fn=<SumBackward0>)
5 6 tensor(-41.2097, grad_fn=<SumBackward0>)
6 7 tensor(1384.4688, grad_fn=<SumBackward0>)
7 8 tensor(0., grad_fn=<SumBackward0>)
8 9 tensor(0., grad_fn=<SumBackward0>)
9 10 tensor(4238.3594, grad_fn=<SumBackward0>)
10 11 tensor(-1215.6812, grad_fn=<SumBackward0>)
11 12 tensor(0., grad_fn=<SumBackward0>)
12 13 tensor(0., grad_fn=<SumBackward0>)
13 14 tensor(-349.7755, grad_fn=<SumBackward0>)
14 15 tensor(560.6802, grad_fn=<SumBackward0>)
15 16 tensor(0., grad_fn=<SumBackward0>)
16 17 tensor(0., grad_fn=<SumBackward0>)
17 18 tensor(-2484.1509, grad_fn=<SumBackward0>)
18 19 tensor(382.2468, grad_fn=<SumBackward0>)
19 20 tensor(0., grad_fn=<SumBackward0>)
20 21 tensor(0., grad_fn=<SumBackward0>)
21 22 tensor(-2142.1882, grad_fn=<SumBackward0>)
22 23 tensor(-850.0787, grad_fn=<SumBackward0>)
23 24 tensor(0.

In [86]:
_, log_prob, _, hist_norm, hist_ld_norm, OT_cost_norm, noise = flow.log_prob(data)
print(len(hist_ld_norm))
for i in range(28):
    print(i, i+1, torch.sum(hist_norm[i]-hist_norm[i+1].squeeze()))

29
0 1 tensor(0., grad_fn=<SumBackward0>)
1 2 tensor(0., grad_fn=<SumBackward0>)
2 3 tensor(466.9605, grad_fn=<SumBackward0>)
3 4 tensor(741.4856, grad_fn=<SumBackward0>)
4 5 tensor(0., grad_fn=<SumBackward0>)
5 6 tensor(0., grad_fn=<SumBackward0>)
6 7 tensor(822.1868, grad_fn=<SumBackward0>)
7 8 tensor(1926.5005, grad_fn=<SumBackward0>)
8 9 tensor(0., grad_fn=<SumBackward0>)
9 10 tensor(0., grad_fn=<SumBackward0>)
10 11 tensor(-173.1436, grad_fn=<SumBackward0>)
11 12 tensor(2578.5776, grad_fn=<SumBackward0>)
12 13 tensor(0., grad_fn=<SumBackward0>)
13 14 tensor(0., grad_fn=<SumBackward0>)
14 15 tensor(216.7531, grad_fn=<SumBackward0>)
15 16 tensor(410.6828, grad_fn=<SumBackward0>)
16 17 tensor(0., grad_fn=<SumBackward0>)
17 18 tensor(0., grad_fn=<SumBackward0>)
18 19 tensor(985.2815, grad_fn=<SumBackward0>)
19 20 tensor(-3719.9143, grad_fn=<SumBackward0>)
20 21 tensor(0., grad_fn=<SumBackward0>)
21 22 tensor(0., grad_fn=<SumBackward0>)
22 23 tensor(-1192.4768, grad_fn=<SumBackward0>)


In [84]:
_, log_prob, _, hist_norm, hist_ld_norm, OT_cost_norm, noise = flow.log_prob(data)
# 假设 hist_ld_norm 是一个包含张量的列表
for i, tensor in enumerate(hist_ld_norm):
    # tensor = log_prob
    print(i, tensor.requires_grad)
    # 只有当张量需要梯度时才尝试计算梯度
    if tensor.requires_grad:
        # 确保对应的 hist_norm[i] 也是张量，并且已经设置了 requires_grad
        for j in range(29):
            grad = torch.autograd.grad(outputs=tensor, inputs=hist_norm[j], create_graph=True, grad_outputs=torch.ones_like(tensor), allow_unused=True)[0]
            if grad is not None:
                print('Gradient available for index:', i, j, torch.sum(grad))
    # else:
    #     print('Tensor at index', i, 'does not require grad or does not have a grad_fn')

0 False
1 False
2 True
3 True
Gradient available for index: 3 0 tensor(-654.1703, grad_fn=<SumBackward0>)
Gradient available for index: 3 1 tensor(-654.1703, grad_fn=<SumBackward0>)
Gradient available for index: 3 2 tensor(-654.1703, grad_fn=<SumBackward0>)
4 True
Gradient available for index: 4 0 tensor(-349.1279, grad_fn=<SumBackward0>)
Gradient available for index: 4 1 tensor(-349.1279, grad_fn=<SumBackward0>)
Gradient available for index: 4 2 tensor(-349.1279, grad_fn=<SumBackward0>)
Gradient available for index: 4 3 tensor(-1559.9524, grad_fn=<SumBackward0>)
5 False
6 True
7 True
Gradient available for index: 7 0 tensor(647.9644, grad_fn=<SumBackward0>)
Gradient available for index: 7 1 tensor(647.9644, grad_fn=<SumBackward0>)
Gradient available for index: 7 2 tensor(647.9644, grad_fn=<SumBackward0>)
Gradient available for index: 7 3 tensor(1392.3075, grad_fn=<SumBackward0>)
Gradient available for index: 7 4 tensor(3165.7761, grad_fn=<SumBackward0>)
Gradient available for index: 7

In [45]:
log_density = torch.mean(torch.sum(torch.stack(hist_ld_norm[:28]), dim=0))
print(log_density.requires_grad)
# log_density = hist_ld_norm[24]
data = hist_norm[4]
# data.requires_grad_(True)
for N in range(0,29):
    data = hist_norm[i]
    data.requires_grad_(True)  # 确保data可以计算梯度
    # log_density = torch.mean(torch.sum(torch.stack(hist_ld_gen[:28]), dim=0))
    grad = torch.autograd.grad(outputs=log_density, inputs=data, create_graph=True, grad_outputs=torch.ones_like(log_density), allow_unused=True)[0]
    if grad is not None:
        print('????????????', N)


True


In [83]:
samples, log_prob, logabsdet,  hist, hist_ld, noise = flow.sample1(args.train_batch_size)
# 计算 log_prob 相对于 noise 的梯度
grad = torch.autograd.grad(outputs=log_prob, inputs=noise, create_graph=True, grad_outputs=torch.ones_like(log_prob), allow_unused=True)[0]

# 输出梯度
print(grad)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [78]:
torch.sum(torch.stack(hist_ld_norm[:i]), dim=0).shape

# Running cost
hist   = hist_gen
part = torch.cat(hist).reshape(len(hist), hist[0].shape[0], hist[0].shape[-1]).permute(1,0,2) # B x K x d
K    = part.shape[1]
I = torch.arange(0, K, step=4)
print(I)
part = part[:,I,:]
K = part.shape[1]-1
I = torch.arange(K)


tensor([ 0,  4,  8, 12, 16, 20, 24, 28])


In [77]:
# =================================================================================== #
#                                      Training                                       #
# =================================================================================== #
beta = 0.5
def Energy(x, logp):
    energy = 0

    potential = (x[:,0] - 1)**2 + (x[:,1] - 1)**2
    energy = torch.mean(potential) + 1/beta * torch.mean(logp)
    return energy

def GradF2(x, logp):
    x = torch.squeeze(x)
    N = x.shape[0]
    energy = Energy(x, logp) * N
    grad = torch.autograd.grad(outputs=energy, inputs=x, create_graph=True, grad_outputs=torch.ones_like(energy))[0]
    gradF2 = (torch.sum(grad ** 2) * 1/N)
    return gradF2

def compute_loss(flow,  args, data=None, mode='train'):
    # forward
    # log_density, _, _, hist_norm, hist_ld_norm, OT_cost_norm, _ = flow.log_prob(data)
    # # Distribution.log_prob, Flow._log_prob, Flow._transform(inputs, context=context), CouplingTransform.forward
    # # return log_prob + logabsdet, log_prob, logabsdet, hist, hist_ld, OT_cost, noise
    # log_density = torch.mean(log_density)

    # sample if needed
    # z_K, logp0, logdet, hist_gen, hist_ld_gen, z_0 = flow.sample1(args.train_batch_size)
    # samples, log_prob, logabsdet,  hist, hist_ld, noise
    # grad = torch.autograd.grad(outputs=logp0-logdet, inputs=z_K, create_graph=True, grad_outputs=torch.ones_like(ld_gen))[0]
    # print('????????????',grad)
    # summed_tensor = torch.sum(torch.stack(hist_ld_gen), dim=0)
    # print('print(torch.mean(summed_tensor-ld_gen))',torch.mean(summed_tensor-ld_gen))
    # print('print(torch.sum(hist_ld_gen[0]))',torch.sum(hist_ld_gen[0]))
    # print('print(len(hist_ld_gen))',len(hist_ld_gen), hist_ld_gen[0].shape)
    # Distribution.sample, Flow._sample, Flow._transform.inverse, CouplingTransform.inverse
    # distribution matching / terminal cost (G)

    z_K, ld_gen, OT_cost_gen, hist_gen, hist_ld_gen, z_0 = flow.sample(args.train_batch_size)


    
    KL_density  = torch.Tensor([0]).to(device)
    KL_sampling = torch.Tensor([0]).to(device)
    
    log_prob_1  = torch.mean(target_dist.log_prob(data))
    log_prob_0   = torch.mean(distribution.log_prob(z_0))
    log_prob_gen = torch.mean(target_dist.log_prob(z_K))
    ld_gen       = torch.mean(ld_gen)
    KL_density   = - (log_prob_gen + ld_gen - log_prob_0)
    G_cost       = log_prob_1 - log_density + KL_density
    G_cost =  Energy(z_K, log_density) - log_density

    # OT regularization (L)
    L_cost = torch.Tensor([0]).to(device)


    # sample a batch from the base to compute the OT cost
    
    # Running cost
    hist   = hist_gen
    part = torch.cat(hist).reshape(len(hist), hist[0].shape[0], hist[0].shape[-1]).permute(1,0,2) # B x K x d
    K    = part.shape[1]
    I = torch.arange(0, K, step=4)
    print(I)
    part = part[:,I,:]
    K = part.shape[1]-1
    I = torch.arange(K)
    
    Energy_seq = torch.zeros(K+1)
    Energy_seq[0]=Energy(part[:,0,:], logp0)   
    arclength_seq = torch.zeros(K)
    action_seq = torch.zeros(K)
    cosvalue = torch.zeros(K)
    running_loss = 0
    for i in range(1,K+1):
        # print(i)
        # print(part.shape,part[:,i,:].shape)
        Energy_seq[i] = Energy(part[:,i,:])
        arclength2 = torch.mean(torch.norm(part[:,i,:] - part[:,i-1,:], dim=[1])**2)
        arclength_seq[i-1] =arclength2
        gradF2 = GradF2(part[:,i,:])
        running_loss += torch.sqrt(gradF2 * arclength2)
        action_seq[i-1] = torch.sqrt(gradF2 * arclength2)
        cosvalue[i-1] = torch.sqrt((Energy_seq[i]-Energy_seq[i-1])**2 / (gradF2 * arclength2) )
    L = torch.mean(K * torch.norm(part[:,I,:] - part[:,I+1,:], dim=[1,2])**2)
    differences = Energy_seq[1:] - Energy_seq[:-1]
    energydiff_loss = (torch.max(differences)/torch.min(differences)-1)**2
    arclendiff_loss = (torch.max(arclength_seq)/torch.min(arclength_seq)-1)**2
    arcactdiff_loss = (torch.max(action_seq)/torch.min(action_seq)-1)**2
    # interaction (F)
    F_P    = torch.Tensor([0]).to(device)
    F_E    = torch.Tensor([0]).to(device)
    F_cost = torch.Tensor([0]).to(device)
    
    # Overall loss
    loss = G_cost + args.lbd_OT * (L)

    # return loss, G_cost, L_cost, F_cost
    return {'loss': loss, 'Terminal': G_cost, 'Running_OT': L, 'Running_action':running_loss, 'arc': (arcactdiff_loss), 'F': F_cost, 'Energy_seq': Energy_seq, 'arclength_seq': arclength_seq, 'action_seq': action_seq, 'cosvalue': cosvalue}

# main loop
# 初始化存储损失值的字典
losses = {'loss': [], 'Terminal': [], 'Running_OT': [], 'Running_action': [], 'arc': []}

args.num_training_steps = 10
for step in tqdm(range(args.num_training_steps), desc='Training Progress', unit='iter'):
    flow.train()
    optimizer.zero_grad()

    # grab data
    batch = next(train_generator).to(device)
    loss_dict = compute_loss(flow, args, mode='train')
    # 保存loss值
    losses['loss'].append(loss_dict['loss'].item())
    losses['Terminal'].append(loss_dict['Terminal'].item())
    losses['Running_OT'].append(loss_dict['Running_OT'].item())
    losses['Running_action'].append(loss_dict['Running_action'].item())
    losses['arc'].append(loss_dict['arc'].item())


    loss      = loss_dict['loss']

    loss.backward()
    clip_grad_norm_(flow.parameters(), args.grad_norm_clip_value)
    optimizer.step()


# 画图部分
fig, axs = plt.subplots(1, 4, figsize=(20, 5))
loss_types = ['loss', 'Terminal', 'Running_OT', 'Running_action','arc']
titles = ['loss', 'Terminal', 'Running_OT', 'Running_action','arc']

for i, ax in enumerate(axs):
    ax.plot(losses[loss_types[i]], label=loss_types[i])
    ax.set_title(titles[i])
    ax.set_xlabel('Iterations')
    ax.set_ylabel('Value')
    ax.legend()

    # 标注最后一点的值
    final_value = losses[loss_types[i]][-1]
    ax.annotate(f'{final_value:.2f}', # 格式化为两位小数
                xy=(len(losses[loss_types[i]]) - 1, final_value),
                xytext=(8, 0), 
                textcoords='offset points',
                ha='center', va='bottom')

plt.tight_layout()
plt.show()

print('running_loss',losses['Running_action'][-1])

loss_dict = compute_loss(flow, points, args, mode='train')
loss_types = ['Energy_seq', 'arclength_seq', 'action_seq', 'cosvalue']
titles = ['Energy Sequence', 'Arclength Sequence', 'Action Sequence', 'cosvalue']

# 创建4个子图
fig, axs = plt.subplots(1, 4, figsize=(20, 5))

# 遍历每个子图并绘制相应的loss数据
for i, ax in enumerate(axs):
    ax.plot(loss_dict[loss_types[i]].data.cpu().numpy(), marker='o')
    ax.set_title(titles[i])
    ax.set_xlabel('Iterations')
    ax.legend()

plt.tight_layout()
plt.show()

print('energy difference', loss_dict['Energy_seq'].max()-loss_dict['Energy_seq'].min())

z_K, ld_gen, OT_cost_gen, hist_gen, hist_ld_gen, z_0 = flow.sample(args.train_batch_size)
hist = hist_gen
part = torch.cat(hist).reshape(len(hist), hist[0].shape[0], hist[0].shape[-1]).permute(1,0,2)
K = part.shape[1]
I = torch.arange(0, K, step=4)
part = part[:,I,:]

main_plot(part, z_0)

Training Progress:   0%|          | 0/10 [00:00<?, ?iter/s]


ValueError: The value argument to log_prob must be a Tensor