In [1]:
import functools
import math
import os
import warnings

import torch
import torch.nn.functional as F

# Collective-variable grid: (min, max) range of each axis (CV1, CV2, CV3), in the
# units of the binned data (presumably nm — TODO confirm).
xr, yr, zr = [0.1, 3.5], [0.1, 3.5], [0.1, 3.5]
nx = ny = nz = 100  # number of bins per axis

# Bin widths: the span of each axis divided by its bin count.
# NOTE(review): the linspace grids below hold n points *including both ends*, so
# their spacing is span / (n - 1), not dx/dy/dz — confirm which one is intended.
dx, dy, dz = [(bounds[1] - bounds[0]) / count for bounds, count in ((xr, nx), (yr, ny), (zr, nz))]

# Standard deviation of the Gaussian smoothing kernel. smooth_tensor measures it
# in grid bins, not physical units. Smoothing is currently disabled (the call on
# logP_tensor was commented out).
sigma = 0.3

# Grid points along each axis, endpoints included.
x, y, z = [torch.linspace(lo, hi, count) for (lo, hi), count in ((xr, nx), (yr, ny), (zr, nz))]


def gaussian_kernel(size: int, sigma: float) -> torch.Tensor:
    """
    Create a normalized 1D Gaussian kernel using PyTorch.

    Args:
        size: number of taps; the kernel is centred at index (size - 1) / 2.
        sigma: standard deviation, in units of taps.

    Returns:
        float64 tensor of shape (size,) whose entries sum to 1.
    """
    offsets = torch.arange(size, dtype=torch.float64) - (size - 1) / 2.0
    kernel = torch.exp(-0.5 * (offsets / sigma) ** 2)
    return kernel / kernel.sum()


def smooth_tensor(tensor: torch.Tensor, sigma: float) -> torch.Tensor:
    """
    Smooth a 1D, 2D or 3D tensor with a separable Gaussian kernel.

    Edges use replicate padding, so the output has exactly the input's shape.

    Args:
        tensor: 1D, 2D or 3D tensor (e.g. a binned logP grid). Integer tensors
            are promoted to the default floating dtype.
        sigma: Gaussian standard deviation in units of grid bins (not in the
            physical units of the binned coordinate).

    Returns:
        Smoothed tensor with the same shape and device as ``tensor``, and the same
        dtype for floating-point input.

    Raises:
        ValueError: if ``tensor`` is not 1D/2D/3D or ``sigma`` is not positive.
    """
    ndim = tensor.ndim
    if ndim not in (1, 2, 3):
        raise ValueError("Tensor dimensionality not supported for smoothing.")
    if not sigma > 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    # conv*d needs a floating input.
    if not tensor.is_floating_point():
        tensor = tensor.to(torch.get_default_dtype())

    # A half-width of ceil(3*sigma) bins covers at least the +/-3 sigma core of the Gaussian.
    kernel_size = int(2 * math.ceil(3 * sigma) + 1)
    kernel_1d = gaussian_kernel(kernel_size, sigma)

    # N-D kernel = outer product of the 1D kernel along every axis.
    kernel_nd = kernel_1d
    for _ in range(ndim - 1):
        kernel_nd = kernel_nd.unsqueeze(-1) * kernel_1d
    kernel_nd = kernel_nd / kernel_nd.sum()

    # The conv weight must share the input's dtype and device. Using the float64 kernel
    # against a float32 grid makes conv*d raise a dtype-mismatch error.
    weight = kernel_nd.to(device=tensor.device, dtype=tensor.dtype)[None, None]  # [1, 1, *kernel_shape]

    # F.pad takes one (before, after) pair per padded dim, last dim first, so every
    # spatial dim needs its own pair. Otherwise padding fails (odd length) or
    # leaves an axis unpadded, which shrinks the convolution output.
    half = kernel_size // 2
    batched = tensor[None, None]  # add batch and channel dims
    padded = F.pad(batched, (half, half) * ndim, mode="replicate")

    conv = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}[ndim]
    # Index out batch/channel instead of squeeze() so size-1 spatial dims survive.
    return conv(padded, weight)[0, 0]

def interpolate_logP(coords, logP_tensor, value_range=(0.1, 3.5)):
    """
    Linearly interpolate a 1D logP profile at arbitrary coordinates.

    ``logP_tensor`` is flattened to a single row of L values. Its first and last
    entries sit exactly at ``value_range[0]`` and ``value_range[1]``
    (``align_corners=True``), i.e. a uniform grid of L points spanning the range.
    Bilinear ``F.grid_sample`` on this 1 x L "image" reduces to plain 1D linear
    interpolation along the row. Coordinates outside the range are clamped, so
    they receive the nearest edge value.

    Args:
        coords: scalar, sequence, ndarray or tensor of query coordinates. Any shape
            is accepted and flattened. A floating-point tensor that requires grad
            keeps its autograd graph.
        logP_tensor: tabulated logP values. Any shape is flattened in row-major
            order, so this presumably holds a 1D profile — TODO confirm before
            passing 2D/3D tables, which are NOT interpolated per axis.
        value_range: (min, max) coordinate of the first/last table entry. The
            default (0.1, 3.5) is the range previously hard-coded here.

    Returns:
        1D tensor with one interpolated value per coordinate. It has the dtype and
        device of ``logP_tensor`` (float32 if the table is not floating-point).

    Raises:
        ValueError: if ``value_range`` does not satisfy min < max.
    """
    lo, hi = value_range
    span = hi - lo
    if not span > 0:
        raise ValueError(f"value_range must satisfy min < max, got {value_range}")

    # grid_sample needs a floating input and a grid of the same dtype/device.
    # Forcing float32 here used to break float64 tables.
    if not logP_tensor.is_floating_point():
        logP_tensor = logP_tensor.to(torch.float32)
    coords = torch.as_tensor(coords, dtype=logP_tensor.dtype, device=logP_tensor.device)

    # Map [lo, hi] onto grid_sample's normalized [-1, 1] and clamp out-of-range points.
    normalized = ((coords - lo) / span) * 2.0 - 1.0
    normalized = torch.clamp(normalized, -1.0, 1.0)

    # View the table as a 1 x L image and sample along its width (x) with y = 0.
    # With align_corners=True and H = 1, y = 0 is that single row.
    table = logP_tensor.reshape(1, 1, 1, -1)               # [N=1, C=1, H=1, W=L]
    x_norm = normalized.reshape(1, 1, -1, 1)               # [1, 1, Q, 1]
    grid = torch.cat([x_norm, torch.zeros_like(x_norm)], dim=-1)  # [1, 1, Q, 2] as (x, y)
    sampled = F.grid_sample(table, grid, mode="bilinear", align_corners=True)
    return sampled.reshape(-1)

In [2]:
# Directory of the per-residue potential files, and the peptide sequence
# (one-letter codes) whose residues those files describe.
pot_dir = './One_Ca'
sequence = 'KLVFFAE'
sequence_list = [*sequence]  # one entry per residue, in sequence order

@functools.lru_cache(maxsize=None)
def _load_logP_table(path: str) -> torch.Tensor:
    """Load a residue's logP table from ``path`` once per session, as float32."""
    # The cached tensor is shared between calls; callers must not modify it in place.
    return torch.load(path).float()


def main(data, res):
    """
    Evaluate the 1D logP potential of residue ``res`` at the given coordinate(s).

    Args:
        data: scalar or array-like of query coordinates (presumably the
            residue–surface distance in nm — TODO confirm against the units of
            the ``*_logP1D.pt`` tables).
        res: index into ``sequence_list``. It selects the potential file
            ``<pot_dir>/<residue letter>_logP1D.pt``.

    Returns:
        1D float32 tensor of interpolated logP values, one per (flattened)
        coordinate, attached to an autograd graph through ``data``.
    """
    residue = sequence_list[res]
    pot_name = os.path.join(pot_dir, f"{residue}_logP1D.pt")
    # Memoized: the scan calls this once per (frame, shift, residue), and repeated
    # residues (e.g. the two F's) share a file.
    logP_tensor = _load_logP_table(pot_name)
    data_torch = torch.tensor(data, dtype=torch.float32, requires_grad=True)
    return interpolate_logP(data_torch, logP_tensor)

In [7]:
import numpy as np
import os

# Load C-alpha coordinates, indexed below as [frame, residue, xyz]
# (assumed shape: n_frames, 7, 3 — TODO confirm).
coordinates = np.load("Ca_residues_1_to_7.npy")

# Convert from Å to nm
coordinates = coordinates * 0.1

# Thermal energy kB*T in kcal/mol (0.593 kcal/mol corresponds to ~298 K)
kBT_kcal = 0.593

# Rigid z-shifts to test (nm): 21 values from -0.6 to +0.6 in 0.06 nm steps
adjust_z_range = np.linspace(-0.6, 0.6, 21)

# NOTE(review): `Au` is not defined anywhere in this notebook. It must come from a
# deleted or out-of-order cell (note the jump to In [7]). Residue z-coordinates are
# subtracted from it, so it is presumably the gold surface plane as a scalar in nm.
# TODO: define it in the configuration cell. A missing `Au` used to be swallowed by
# a blanket `except`, giving zero energy for every shift. Fail loudly instead.
if "Au" not in globals():
    raise NameError("`Au` (surface reference z, nm) must be defined before running this cell")

n_frames = coordinates.shape[0]
n_residues = len(sequence_list)  # one potential per residue in the sequence

# z of every residue for every (frame, shift) pair: (n_frames, n_shifts, n_residues).
# Cast back to the coordinate dtype, matching a shift applied to the stored coordinate.
shifted_z = (
    coordinates[:, None, :n_residues, 2] + adjust_z_range[None, :, None]
).astype(coordinates.dtype, copy=False)

# Per-residue logP for every (frame, shift); NaN marks residues that could not be
# evaluated. This makes one vectorized call per residue instead of one per
# (frame, shift, residue).
residue_logp = np.full(shifted_z.shape, np.nan)
for res_idx in range(n_residues):
    distances = Au - shifted_z[..., res_idx]  # (n_frames, n_shifts)
    try:
        logp = main(distances, res_idx)
    except (OSError, RuntimeError) as err:
        # Best effort: a residue whose potential cannot be loaded or evaluated is
        # left out of the energy sum, but no longer silently.
        warnings.warn(f"skipping residue {res_idx} ({sequence_list[res_idx]}): {err}")
        continue
    residue_logp[..., res_idx] = logp.detach().numpy().reshape(distances.shape)

# Energy per (frame, shift) in kcal/mol: E = -kBT * sum over residues of logP.
# NaN residues are ignored; an all-NaN row sums to 0, as before.
total_energy = -kBT_kcal * np.nansum(residue_logp, axis=2)

# Lowest-energy shift per frame; np.argmin keeps the first shift on ties. NaN
# energies are never selected. A frame with no usable energy gets shift NaN and
# energy inf.
ranked = np.where(np.isnan(total_energy), np.inf, total_energy)
best_idx = np.argmin(ranked, axis=1)
best_energies = ranked[np.arange(n_frames), best_idx]  # shape: (n_frames,), kcal/mol
best_adjusts = np.where(best_energies < np.inf, adjust_z_range[best_idx], np.nan)  # shape: (n_frames,)

# Print and save results
print("Best z-shift per frame (nm):", best_adjusts)
np.save('z_adjust',best_adjusts)

Best z-shift per frame (nm): [-0.12 -0.06 -0.06  0.06 -0.06  0.06 -0.12 -0.18 -0.06  0.    0.    0.12
  0.06 -0.12  0.06 -0.06 -0.24  0.06  0.06 -0.06  0.06  0.    0.   -0.06
  0.    0.12  0.   -0.06 -0.06 -0.18  0.12 -0.06  0.   -0.12  0.06 -0.06
  0.   -0.06  0.   -0.06 -0.12  0.12 -0.12  0.12  0.06  0.06 -0.12  0.
  0.06 -0.12 -0.06 -0.06  0.   -0.12 -0.06  0.06 -0.18  0.12  0.   -0.06
  0.06  0.06  0.   -0.06 -0.06  0.06  0.06 -0.06  0.06 -0.12 -0.12 -0.18
 -0.24 -0.06  0.06 -0.12  0.    0.06  0.06 -0.18 -0.06  0.    0.   -0.18
  0.   -0.06 -0.06  0.06  0.06  0.12  0.12  0.    0.    0.18 -0.06 -0.12
  0.06  0.    0.   -0.18]
