In [1]:
import torch
%load_ext autoreload
%autoreload 2


## models.base.Base & models.bijectors.CircularShift Test

Problem: nflows generates samples whose `requires_grad` is `False`.

In [2]:
# Build the flow model and its energy function from a stored experiment config.
from models.model import make_model
from experiments.configs import get_config
# 16 is the argument that selects the config; its meaning is defined in
# experiments.configs (not visible from this notebook).
config = get_config(16)
# The two positional arguments are -pi and pi (presumably the angular domain
# bounds — confirm against make_model); all other settings come from the
# config's model kwargs.
model, energy_fn = make_model(
    -torch.pi, torch.pi, **config.model['kwargs'])


In [3]:
from models.transforms import Dihedral2Coord
from models.energy import Energy
# Turn on autograd anomaly detection. This is a global switch intended for
# debugging only (it slows execution down).
torch.autograd.set_detect_anomaly(True)
# Build the transform from the molecule and the torsion-angle index table
# stored on the model's base distribution.
trans = Dihedral2Coord(
    mol=model._distribution.mol,
    angles=model._distribution.torsion_angles)

In [4]:
# Display the torsion-angle index table passed to Dihedral2Coord: one row of
# four integer indices per angle (presumably atom indices — confirm).
model._distribution.torsion_angles

tensor([[ 5,  0,  1,  2],
        [ 0,  1,  2,  3],
        [ 7,  4,  0,  5],
        [ 8,  5,  0,  1],
        [12,  7,  4, 14],
        [11, 10,  5,  8],
        [13, 11, 10,  5],
        [15, 14,  4,  7]])

In [5]:
# Draw 128 rows of test angles, one column per row of the torsion-angle table.
# torch.rand samples uniformly from [0, 1); requires_grad=True makes this a
# leaf tensor whose .grad is filled in by backward().
angles_sample = torch.rand(size=(128, model._distribution.torsion_angles.shape[0]), requires_grad=True)
# Push the sampled angles through the Dihedral2Coord transform.
coord_sample = trans(angles_sample)

In [6]:
# Evaluate the energy of the generated coordinates (Energy is invoked through
# `.apply`, so it is presumably a custom torch.autograd.Function — confirm).
energy_sample = Energy.apply(coord_sample, model._distribution.mol)
# Debug aid: uncomment to print the gradient that reaches coord_sample.
# coord_sample.register_hook(lambda grad: print(grad))
# Backpropagate the batch-mean energy to check that gradients flow back to
# `angles_sample`.
loss = energy_sample.mean()
loss.backward()


In [7]:
# Sanity check: one energy value per sampled row (128 in total).
energy_sample.shape

torch.Size([128])

In [12]:
# A populated (non-None) gradient here shows that the energy is differentiable
# with respect to the input angles through the transform.
print(angles_sample.grad)


tensor([[-25.0157,   5.5373,   1.7646,  ...,  -9.2669,  -1.0254,  -0.1668],
        [ -3.5710,  -1.8144,   2.1803,  ...,   3.2451,   4.0137, -19.7645],
        [  1.0107,  -1.0144,   0.3451,  ...,  -2.6927,  -1.5392,  23.0678],
        ...,
        [ 23.6450,  25.6512,   0.2802,  ...,  11.0567,  -3.1502,  -6.1331],
        [  2.5039,   9.6690,  -0.4413,  ...,   1.9024,   0.6822,  -1.0110],
        [-17.2410,  13.7924,   1.1858,  ...,   6.0832,  -6.6854, -10.3659]])


In [17]:
# Request 128 draws from the model.
# NOTE(review): `loss` rebinds the scalar loss defined in the earlier energy
# cell; here it holds the first return value of sample_and_log_prob (presumably
# the samples, with `logdet` being the log-probabilities — confirm and rename).
loss, logdet = model.sample_and_log_prob(128)


In [22]:
# `logdet[:, None]` adds a trailing axis so the per-sample term broadcasts
# against the first output (assumes `loss` is 2-D, batch first — TODO confirm);
# the mean is then taken over every element of the broadcast sum.
real_loss = (loss + logdet[:, None]).mean()


In [24]:
# Backpropagate through the sampling path to test whether the flow's
# parameters receive gradients.
real_loss.backward()

In [48]:
# Inspect the gradient of one conditioner weight inside the flow; a tensor
# (rather than None) means real_loss.backward() reached the model parameters.
model._transform._transforms[0]._transforms[1].conditioner.linear1.weight.grad


tensor([[ 0.0239, -0.3383,  0.3175,  ..., -0.1333, -0.3284, -0.4040],
        [-0.1827,  1.1365, -0.5813,  ..., -0.3947,  0.3445,  0.7744],
        [-0.0460,  0.5030, -0.5148,  ..., -0.1440,  0.2224,  0.9553],
        ...,
        [ 0.9095, -0.1586,  0.2059,  ...,  0.0287, -0.6542, -0.1489],
        [-0.2702, -0.0838,  0.3339,  ..., -0.3490,  0.1917, -0.4530],
        [-0.1773,  0.2773, -0.5189,  ..., -0.1973, -0.0033,  0.5414]])

: 

In [12]:
# Three independent uniform (4, 17) tensors, drawn in the order j, k, l.
j, k, l = (torch.rand(size=(4, 17)) for _ in range(3))
ls = [j, k, l]
# torch.stack inserts a new leading axis, so the result has shape (3, 4, 17).
torch.stack(ls).shape

torch.Size([3, 4, 17])

In [32]:
import torch
# Random test inputs: `i` has shape (5, 4); `j`, `k` and `l` each have shape
# (4, 17) and are passed as the three parameter tensors of the spline.
i=torch.rand(size=(5,4))
j=torch.rand(size=(4, 17))
k=torch.rand(size=(4, 17))
l=torch.rand(size=(4, 17))
from models.spline import _rational_quadratic_spline_fwd
# NOTE(review): `j` and `k` are unsorted random values. If the spline expects
# increasing knot positions (presumably it does — confirm), some bins get a
# negative slope, which would explain the NaNs in the second returned tensor.
_rational_quadratic_spline_fwd(i,j,k,l)

(tensor([[0.8167, 0.5797, 3.3655, 0.3703],
         [0.7785, 2.3855, 0.4710, 2.0032],
         [0.9357, 2.4022, 0.5809, 1.4372],
         [0.9357, 2.4022, 0.4772, 0.6444],
         [0.8134, 0.5167, 0.5041, 1.9168]]),
 tensor([[-2.0675, -1.5534,     nan, -0.5470],
         [-2.0675,     nan, -1.2066,     nan],
         [    nan,     nan, -0.1418,     nan],
         [    nan,     nan, -1.2066, -1.1884],
         [-2.0675, -1.5534, -0.1418,     nan]]))

In [39]:
from torch import Tensor
from typing import Tuple
def _rqs_fwd_single(x: Tensor,
                                   x_pos: Tensor,
                                   y_pos: Tensor,
                                   knot_slopes: Tensor) -> Tuple[Tensor, Tensor]:
  """Applies a rational-quadratic spline to a scalar.
  Args:
    x: a scalar (0-dimensional array). The scalar `x` can be any real number; it
      will be transformed by the spline if it's in the closed interval
      `[x_pos[0], x_pos[-1]]`, and it will be transformed linearly if it's
      outside that interval.
    x_pos: array of shape [num_bins + 1], the bin boundaries on the x axis.
    y_pos: array of shape [num_bins + 1], the bin boundaries on the y axis.
    knot_slopes: array of shape [num_bins + 1], the slopes at the knot points.
  Returns:
    A tuple of two scalars: the output of the transformation and the log of the
    absolute first derivative at `x`.
  """
  # Search to find the right bin. NOTE: The bins are sorted, so we could use
  # binary search, but this is more GPU/TPU friendly.
  # The following implementation avoids indexing for faster TPU computation.
  below_range = x <= x_pos[0]
  above_range = x >= x_pos[-1]
  correct_bin = torch.logical_and(x >= x_pos[:-1], x < x_pos[1:])
  any_bin_in_range = torch.any(correct_bin)
  first_bin = torch.concat([torch.tensor([1], dtype=bool),
                               torch.zeros(len(correct_bin)-1, dtype=bool)])
  # If y does not fall into any bin, we use the first spline in the following
  # computations to avoid numerical issues.
  correct_bin = torch.where(any_bin_in_range, correct_bin, first_bin)
  # Dot product of each parameter with the correct bin mask.
  params = torch.stack([x_pos, y_pos, knot_slopes], axis=1)
  params_bin_left = torch.sum(correct_bin[:, None] * params[:-1], axis=0)
  params_bin_right = torch.sum(correct_bin[:, None] * params[1:], axis=0)

  x_pos_bin = (params_bin_left[0], params_bin_right[0])
  y_pos_bin = (params_bin_left[1], params_bin_right[1])
  knot_slopes_bin = (params_bin_left[2], params_bin_right[2])

  bin_width = x_pos_bin[1] - x_pos_bin[0]
  bin_height = y_pos_bin[1] - y_pos_bin[0]
  bin_slope = bin_height / bin_width

  z = (x - x_pos_bin[0]) / bin_width
  # `z` should be in range [0, 1] to avoid NaNs later. This can happen because
  # of small floating point issues or when x is outside of the range of bins.
  # To avoid all problems, we restrict z in [0, 1].
  z = torch.clip(z, 0., 1.)
  sq_z = z * z
  z1mz = z - sq_z  # z(1-z)
  sq_1mz = (1. - z) ** 2
  slopes_term = knot_slopes_bin[1] + knot_slopes_bin[0] - 2. * bin_slope
  numerator = bin_height * (bin_slope * sq_z + knot_slopes_bin[0] * z1mz)
  denominator = bin_slope + slopes_term * z1mz
  y = y_pos_bin[0] + numerator / denominator

  # Compute log det Jacobian.
  # The logdet is a sum of 3 logs. It is easy to see that the inputs of the
  # first two logs are guaranteed to be positive because we ensured that z is in
  # [0, 1]. This is also true of the log(denominator) because:
  # denominator
  # == bin_slope + (knot_slopes_bin[1] + knot_slopes_bin[0] - 2 * bin_slope) *
  # z*(1-z)
  # >= bin_slope - 2 * bin_slope * z * (1-z)
  # >= bin_slope - 2 * bin_slope * (1/4)
  # == bin_slope / 2
  logdet = 2. * torch.log(bin_slope) + torch.log(
      knot_slopes_bin[1] * sq_z + 2. * bin_slope * z1mz +
      knot_slopes_bin[0] * sq_1mz) - 2. * torch.log(denominator)

  # If x is outside the spline range, we default to a linear transformation.
  y = torch.where(below_range, (x - x_pos[0]) * knot_slopes[0] + y_pos[0], y)
  y = torch.where(above_range, (x - x_pos[-1]) * knot_slopes[-1] + y_pos[-1], y)
  logdet = torch.where(below_range, torch.log(knot_slopes[0]), logdet)
  logdet = torch.where(above_range, torch.log(knot_slopes[-1]), logdet)
  return y, logdet


In [9]:
from torch.autograd import Function
# Inherit from Function
class LinearFunction(Function):
    """Custom autograd Function computing ``input @ weight.T (+ bias)``.

    Both passes are written by hand so the gradient flow through a
    user-defined Function can be inspected.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute the affine map.

        Args:
            ctx: autograd context used to pass tensors to ``backward``.
            input: 2-D tensor multiplied by ``weight.t()``.
            weight: 2-D tensor.
            bias: optional 1-D tensor added to every row of the product.

        Returns:
            The product ``input.mm(weight.t())``, plus ``bias`` when given.
        """
        # Everything backward needs is saved up front; `bias` may be None.
        ctx.save_for_backward(input, weight, bias)
        result = input.mm(weight.t())
        if bias is None:
            return result
        # Broadcast the bias over the rows and add it in place.
        result += bias.unsqueeze(0).expand_as(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias)``.

        There is a single forward output, hence a single incoming gradient.
        A gradient is computed only for arguments that need one; the others
        stay ``None``. Trailing ``None`` values are ignored by autograd, so
        three values are returned even when ``bias`` was not supplied.
        """
        input, weight, bias = ctx.saved_tensors
        needs = ctx.needs_input_grad

        grad_input = grad_output.mm(weight) if needs[0] else None
        grad_weight = grad_output.t().mm(input) if needs[1] else None
        # `needs[2]` is only looked up when a bias was actually passed.
        grad_bias = (
            grad_output.sum(0) if (bias is not None and needs[2]) else None
        )

        return grad_input, grad_weight, grad_bias

# Smoke-test the custom Function without a bias: both tensors are leaves that
# require gradients, and the mean of the output is backpropagated.
# NOTE(review): `input` shadows the Python builtin of the same name.
weight = torch.rand(size=(2,2), requires_grad=True)
input = torch.rand(size=(1,2), requires_grad=True)
output = LinearFunction.apply(input, weight)
loss = output.mean()
loss.backward()

In [17]:
# `output` is the result of an operation, not a leaf tensor, so autograd does
# not keep its .grad after backward() by default — hence None. Calling
# output.retain_grad() before backward() would make it available.
print(output.grad)


None
