In [1]:
from typing import List, Tuple, Union
import chex
import e3nn_jax as e3nn
import jax.numpy as jnp
import sys
sys.path.append('/u/danaru/moleculib/moleculib') 
from nucleic.datum import *
import jax
from jax.tree_util import tree_map, tree_flatten, tree_flatten_with_path
from einops import rearrange, repeat
import haiku as hk



In [2]:
import plotly.graph_objects as go
import re
from functools import partial
from typing import Callable, Tuple, Dict, List
from collections import defaultdict

In [3]:
@chex.dataclass
class InternalState:
    irreps_array: e3nn.IrrepsArray  # (seq_len, irreps) - irreps array
    mask_irreps_array: jnp.ndarray  # (seq_len,) - mask for irreps array
    coord: jnp.ndarray  # (seq_len, 3) - coordinates of C3' atom
    mask_coord: jnp.ndarray  # (seq_len,) - mask for coordinates

    @property
    def irreps(self) -> e3nn.Irreps:
        return self.irreps_array.irreps

    @property
    def seq_len(self) -> int:
        return self.irreps_array.shape[0]

    @property
    def mask(self) -> jnp.ndarray:
        return self.mask_irreps_array & self.mask_coord

In [4]:
def knn(coord: jnp.ndarray, mask: jnp.ndarray, k:int):
    n, d = coord.shape
    assert mask.shape == (n,)
    k = min(k, n-1)
    distance_matrix = jnp.sum(jnp.square(coord[:, None,:] - coord[None, :, :]), axis = -1)
    assert distance_matrix.shape == (n,n)
    matrix_mask = mask[:,None] & mask[None, :]
    assert matrix_mask.shape == (n,n)

    distance_matrix = jnp.where(matrix_mask, distance_matrix, jnp.inf)
    neg_dist, neighbors = jax.lax.top_k(-distance_matrix, k+1)

    #discard the nearest point which is the point itself
    mask = neg_dist[:, 1:] != -jnp.inf
    neighbors = neighbors[:,1:] 
    assert neighbors.shape == (n,k)
    assert mask.shape == (n,k)
    return neighbors, mask



In [5]:
class Scheduler:
    def __call__(self, step: int) -> float:
        raise NotImplementedError


class LinearScheduler(Scheduler):
    def __init__(self, init_value, end_value, transition_steps, transition_begin):
        self.init_value = init_value
        self.end_value = end_value
        self.transition_steps = transition_steps
        self.transition_begin = transition_begin

    def __call__(self, step):
        return jax.optax.schedule.linear(
            init_value=self.init_value,
            end_value=self.end_value,
            transition_steps=self.transition_steps,
            transition_begin=self.transition_begin,
        )(step)

class CyclicAnnealingScheduler(Scheduler):
    def __init__(
        self,
        init_value,
        end_value,
        transition_steps,
        transition_begin,
    ):
        self.init_value = init_value
        self.end_value = end_value
        self.transition_steps = transition_steps
        self.transition_begin = transition_begin

    def __call__(self, step):
        eff_step_raw = step - self.transition_begin
        eff_step = eff_step_raw % (self.transition_steps)
        value = jnp.where(eff_step >= self.transition_steps / 2, self.end_value, 0.0)
        value = jnp.where(
            eff_step < self.transition_steps / 2,
            (self.end_value - self.init_value) * eff_step / (self.transition_steps / 2)
            + self.init_value,
            value,
        )
        value = jnp.where(eff_step_raw < 0, self.init_value, value)
        return value

In [6]:
@chex.dataclass
class ProbParams:
    state: InternalState
    mu: jnp.ndarray
    sigma: jnp.ndarray
    sigma_basis: e3nn.Irreps
    sigma_flat: jnp.ndarray


@chex.dataclass
class ProbPair:
    prior: ProbParams
    posterior: ProbParams
    mask: jnp.ndarray
    sample: jnp.ndarray
    
@chex.dataclass
class ModelOutput:
    logits: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
    datum: NucleicDatum
    encoder_internals: List[InternalState]
    decoder_internals: List[InternalState]
    probs: Union[List[ProbPair], None]
    atom_perm_loss: jnp.ndarray

In [7]:
from typing import Callable
from typing import Union



class Residual(hk.Module):
    def __init__(
        self,
        function: Callable[[InternalState], InternalState],
        name: Union[str, None] = None,
    ):
        super().__init__(name)
        self.function = function

    def __call__(self, state: InternalState) -> InternalState:
        assert state.irreps_array.ndim == 2
        assert state.mask.ndim == 1
        assert state.coord.ndim == 2

        new_state = self.function(state)

        seq_len = state.irreps_array.shape[0]
        new_seq_len = new_state.irreps_array.shape[0]

        if new_seq_len > seq_len:
            raise ValueError("Residual block cannot increase sequence length")

        if new_seq_len < seq_len:
            if (seq_len - new_seq_len) % 2 != 0:
                raise ValueError(
                    "Residual block cannot decrease sequence length by odd number"
                )

            pad = (seq_len - new_seq_len) // 2
            state = jax.tree_util.tree_map(lambda x: x[pad:-pad], state)

        if state.irreps_array.shape == new_state.irreps_array.shape:
            features = state.irreps_array + new_state.irreps_array
        else:
            features = e3nn.haiku.Linear(new_state.irreps_array.irreps)(
                e3nn.concatenate(
                    [
                        state.irreps_array * state.mask_irreps_array[:, None],
                        new_state.irreps_array,
                    ]
                )
            )

        return new_state.replace(irreps_array=features)

In [14]:
class LossFunction:
    def __init__(
        self, weight: float = 1.0, start_step: int = 0, scheduler: Scheduler = None
    ):
        self.weight = weight
        self.start_step = start_step
        self.scheduler = scheduler

    def _call(
        self, model_output: ModelOutput, NucleicDatum: Dict
    ) -> Tuple[ModelOutput, jnp.ndarray, Dict[str, float]]:
        raise NotImplementedError

    def __call__(
        self,
        rng_key,
        model_output: ModelOutput,
        batch: NucleicDatum,
        step: int,
    ) -> Tuple[ModelOutput, jnp.ndarray, Dict[str, float]]:
        output, loss, metrics = self._call(rng_key, model_output, batch)
        is_activated = jnp.array(self.start_step <= step).astype(loss.dtype)
        loss = loss * is_activated
        if self.scheduler is not None:
            scheduler_weight = self.scheduler(step)
            loss = loss * scheduler_weight
            loss_name = re.sub(r"(?<!^)(?=[A-Z])", "_", type(self).__name__).lower()
            metrics[loss_name + "_scheduler"] = scheduler_weight
        return output, self.weight * loss, metrics
        
class CrossVectorLoss(LossFunction):
    def __init__(
        self,
        weight=1.0,
        start_step=0,
        max_radius: float = 32.0,
        max_error: float = 800.0,
        norm_only=False,
    ):
        super().__init__(weight=weight, start_step=start_step)
        self.norm_only = norm_only
        self.max_radius = max_radius
        self.max_error = max_error

    def _call(
        self, rng_key, model_output: ModelOutput, ground: NucleicDatum
    ) -> Tuple[ModelOutput, jnp.ndarray, Dict[str, float]]:
        coords = model_output.datum.atom_coord
        all_atom_coords = rearrange(coords, "r a c -> (r a) c")
        all_atom_coords_ground = rearrange(ground.atom_coord, "r a c -> (r a) c")
        all_atom_mask = rearrange(ground.atom_mask, "r a -> (r a)")

        vector_map = lambda x: rearrange(x, "i c -> i () c") - rearrange(
            x, "j c -> () j c"
        ) #This new array contains the differences between the coordinates of all possible pairs of atoms within each nucleotide

        cross_mask = rearrange(all_atom_mask, "i -> i ()") & rearrange(
            all_atom_mask, "j -> () j"
        )

        vector_maps = vector_map(all_atom_coords)
        vector_maps_ground = vector_map(all_atom_coords_ground)
        cross_mask = cross_mask & (safe_norm(vector_maps_ground) < self.max_radius)

        if self.norm_only:
            vector_maps = safe_norm(vector_maps)[..., None]
            vector_maps_ground = safe_norm(vector_maps_ground)[..., None]

        se = jnp.square(vector_maps - vector_maps_ground).mean(-1)
        if self.max_error > 0.0:
            se = jnp.clip(se, 0.0, self.max_error)
        mse = (se * cross_mask.astype(se.dtype)).sum((-1, -2)) / (
            cross_mask.sum((-1, -2)) + 1e-6
        )
        mse = mse.mean()

        return model_output, mse, {"cross_vector_loss": mse}


In [20]:

class SpatialConvolution(hk.Module):
    def __init__(
        self,
        irreps_out: e3nn.Irreps,
        *,
        k: int,
        radial_cut: float,
        radial_bins: int = 32,
        radial_basis: str = "gaussian",
        edge_irreps: e3nn.Irreps = e3nn.Irreps("0e + 1e + 2e"),
        norm: bool = True,
        activation: Callable = jax.nn.silu,
    ):
        super().__init__()
        self.irreps_out = e3nn.Irreps(irreps_out)
        self.radial_cut = radial_cut
        self.radial_bins = radial_bins
        self.radial_basis = radial_basis
        self.edge_irreps = e3nn.Irreps(edge_irreps)
        self.norm = norm
        self.activation = activation
        self.k = k

    def embedding(
        self, state: InternalState, nei_indices: jnp.ndarray, nei_mask: jnp.ndarray
    ) -> Tuple:
        (seq_len, _) = state.irreps_array.shape
        k = nei_indices.shape[1]
        assert nei_indices.shape == (seq_len, k)
        assert nei_mask.shape == (seq_len, k, 1)

        vectors = state.coord[nei_indices, :] - state.coord[:, None, :]
        norm_sqr = jnp.sum(vectors**2, axis=-1)
        norm = jnp.sqrt(jnp.where(norm_sqr == 0.0, 1.0, norm_sqr))
        assert norm.shape == (seq_len, k)

        # Angular embedding:
        ang_embed = nei_mask * e3nn.spherical_harmonics(
            self.edge_irreps, vectors, True, "component"
        )
        assert ang_embed.shape == (seq_len, k, self.edge_irreps.dim)

        # Radial embedding:
        rad_embed = nei_mask * e3nn.soft_one_hot_linspace(
            norm,
            start=0.0,
            end=self.radial_cut,
            number=self.radial_bins,
            basis=self.radial_basis,
            cutoff=True,
        )
        assert rad_embed.shape == (seq_len, k, self.radial_bins)

        # Envelope:
        envelope = (
            nei_mask
            * e3nn.soft_envelope(
                norm, x_max=self.radial_cut, arg_multiplicator=5.0, value_at_origin=1.0
            )[:, :, None]
        )
        assert envelope.shape == (seq_len, k, 1)

        return ang_embed, rad_embed, envelope

    def _call(self, state: InternalState) -> InternalState:
        seq_len = state.irreps_array.shape[0]
        assert state.irreps_array.shape == (seq_len, state.irreps_array.irreps.dim)
        assert state.mask.shape == (seq_len,)
        assert state.coord.shape == (seq_len, 3)

        if seq_len == 1:
            return state

        # k nearest neighbors:
        nei_indices, nei_mask = knn(state.coord, state.mask_coord, k=self.k)
        k = nei_indices.shape[1]
        assert nei_indices.shape == (seq_len, k)
        assert nei_mask.shape == (seq_len, k)

        # Embeddings:
        ang_embed, rad_embed, envelope = self.embedding(
            state, nei_indices, nei_mask[:, :, None]
        )
        assert ang_embed.shape == (seq_len, k, self.edge_irreps.dim)
        assert rad_embed.shape == (seq_len, k, self.radial_bins)
        assert envelope.shape == (seq_len, k, 1)

        # Get messages:
        nei_states = nei_mask[:, :, None] * state.irreps_array[nei_indices, :]
        messages = e3nn.haiku.Linear(nei_states.irreps)(nei_states)
        assert messages.shape == (seq_len, k, nei_states.irreps.dim)

        # Angular part:
        ang_embed = e3nn.haiku.Linear(messages.irreps)(ang_embed)
        messages = messages + ang_embed

        assert messages.shape == (seq_len, k, messages.irreps.dim)

        # Radial part:
        mix = e3nn.haiku.MultiLayerPerceptron(
            [self.radial_bins, messages.irreps.num_irreps],
            self.activation,
            output_activation=False,
        )(rad_embed)
        assert mix.shape == (seq_len, k, messages.irreps.num_irreps)

        # Sum over neighbors:
        features = e3nn.sum(envelope * messages * mix, axis=1) / k
        features = e3nn.haiku.Linear(self.irreps_out)(features)
        assert features.shape == (seq_len, features.irreps.dim)

        # Update coordinates:
        update = 1e-3 * e3nn.haiku.Linear("1e")(features).array
        new_coord = state.coord + update

        # Normalization multiplicity-wise:
        if self.norm:
            features = EquivariantLayerNorm()(features)

        return state.replace(irreps_array=features, coord=new_coord)

    # def __call__(self, state: InternalState) -> InternalState:
    #     return Residual(self._call)(state)

In [10]:
#fetch dna:
dna_datum = NucleicDatum.fetch_pdb_id('5JZQ')

#get central atom and its indices along the atom arrays:
center_atom_token = atom_index("C3'") #center atom
mask_center = dna_datum.atom_token == center_atom_token

#expand to match the shape of coords: (12,24,3) , not needed actually
mask_center_3D = (mask_center[..., np.newaxis] * np.ones(3)).astype(bool)

In [38]:
state_coorda = np.array([[1.0, 2.0], [3.0, 4.0]])

nei_indicesa = np.array([[0, 1]])
vectors = state_coorda[nei_indicesa, :] - state_coorda[:, None, :]
vectors


array([[[ 0.,  0.],
        [ 2.,  2.]],

       [[-2., -2.],
        [ 0.,  0.]]])

In [11]:
res_center_coords = dna_datum.atom_coord[..., mask_center, :]

In [12]:
#get the coords of central atom:
all_atom_coord = dna_datum.atom_coord
res_center_coord = dna_datum.atom_coord[..., mask_center, :]
res_center_coord[...,None,:].shape

(12, 1, 3)

In [13]:
#subtract center coord from each atom's coordinates in all_atom_coord to get relative vectors:
relative_vectors = all_atom_coord - res_center_coord[:, np.newaxis,:]
relative_vectors.shape
relative_vectors_flat = np.reshape(relative_vectors, (12,-1))#"... a e -> ... (a e)")
relative_vectors_flat.shape

(12, 72)

In [22]:
nuc_token_reshaped = dna_datum.nuc_token[:, jnp.newaxis]
print(nuc_token_reshaped.shape, relative_vectors_flat.shape)
irreps_array = jnp.concatenate((nuc_token_reshaped , relative_vectors_flat), axis = -1)

print(irreps_array.shape)
irreps_array = e3nn.IrrepsArray('1x0e + 24x1o',irreps_array)
mask_irr = jnp.array(dna_datum.nuc_mask)
internal = InternalState(irreps_array=irreps_array, mask_irreps_array=mask_irr, coord=res_center_coords, mask_coord=mask_irr,)


(12, 1) (12, 72)
(12, 73)


In [24]:
res_center_coords.shape

(12, 3)

In [25]:
#gaussian input:
size1 = (12, 1)
size2 = (12, 72)
size3 = (12,3)

# Generate Gaussian noise arrays
gaussian_noise1 = np.random.normal(loc=0.0, scale=1.0, size=size1)
gaussian_noise2 = np.random.normal(loc=0.0, scale=1.0, size=size2)
gaussian_noise3 = np.random.normal(loc=0.0, scale=1.0, size=size3)

gaussian_irreps = jnp.concatenate((gaussian_noise1 , gaussian_noise2), axis = -1)

print(irreps_array.shape)
gaussian_irreps = e3nn.IrrepsArray('1x0e + 24x1o',gaussian_irreps)
gaussian_state = InternalState(irreps_array=gaussian_irreps, mask_irreps_array=mask_irr, coord=gaussian_noise3, mask_coord=mask_irr,)



(12, 73)


In [19]:
# Define a function to create the SpatialConvolution instance
def create_spatial_convolution(irreps_array, rc, k):
    return SpatialConvolution(e3nn.IrrepsArray('1x0e + 24x1o'), radial_cut=rc, k=k)


# Define a Haiku transform function that includes the SpatialConvolution
@hk.transform
def my_model(x):
    rc = jnp.sqrt(4.5 * 16)
    spatial_convolution = create_spatial_convolution(irreps_array, rc, k=16)
    return spatial_convolution(x)

model = my_model()
#training:
epochs=1000
input_state = gaussian_state
final_desired_state = internal
rc = jnp.sqrt(4.5 * 16 ) 
    
for epoch in range(epochs):
    model_output_state = model(input_state)
    loss = final_desired_state, model_output_state

    #update weights?


ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

In [11]:
tent_coord = internal.irreps_array.filter("1o").array
# tent_coord = np.array(tent_coord)[:,:71]
#getting the original coordinates back:
atom_coord = internal.coord[..., None, :]  + rearrange(
    tent_coord, "... (h c) -> ... h c", c=3
)


In [12]:
internal.irreps_array.shape

(12, 73)

In [13]:
atom_coord.shape

(12, 24, 3)

In [32]:
internal.coord[0,None,:]

array([[ 7.65799999, 24.06299973, 18.35400009]])

In [34]:
(internal.coord[0,:,None]-internal.coord[0,None,:])

array([[  0.        , -16.40499973, -10.6960001 ],
       [ 16.40499973,   0.        ,   5.70899963],
       [ 10.6960001 ,  -5.70899963,   0.        ]])

In [15]:
import plotly.graph_objs as go
import numpy as np

def _scatter(name, ca_coord, atom_coord, atom_mask, color, visible=True):
    centers = np.array(ca_coord)
    arrows = np.array(atom_coord).reshape((:,-1))
    bb_x, bb_y, bb_z = centers.T
    sc_x, sc_y, sc_z = arrows.T
    print(arrows.T.shape)
    print(centers.T.shape)
    data = [
        go.Scatter3d(
            name=name + " coord",
            x=bb_x,
            y=bb_y,
            z=bb_z,
            mode="markers",
            marker=dict(size=2, color=color),
            line=dict(color=color, width=4),
            visible="legendonly" if not visible else True,
        ),
        go.Scatter3d(
            name=name + " vecs",
            x=sc_x,
            y=sc_y,
            z=sc_z,
            mode="lines",
            marker=dict(size=4, color='red'),
            line=dict(color='red', width=2),
            visible="legendonly"if not visible else True,
        ),
    ]
    return data



SyntaxError: invalid syntax (1958227518.py, line 6)

In [None]:
scatter_data = _scatter(name='visualize', ca_coord = internal.coord, atom_coord=atom_coord, atom_mask=np.array(internal.mask_coord).astype(bool), color='blue', visible=True)
fig = go.Figure(data=scatter_data)

fig.show()

(3, 24, 12)
(3, 12)


In [None]:
start_coords.shape

(12, 3)

In [49]:
# Starting and ending coordinates of the line
# scatter_data = _scatter(name='visualize', ca_coord = internal.coord, atom_coord=atom_coord, atom_mask=np.array(internal.mask_coord).astype(bool), color='blue', visible=True)
start_coords = internal.coord
end_coords = atom_coord

# Create a 3D scatter plot
fig = go.Figure()
for coord in range(len(start_coords)):
    for atom in range(24):
        fig.add_trace(go.Scatter3d(x=[start_coords[coord][0], end_coords[coord][atom][0]], 
                            y=[start_coords[coord][1], end_coords[coord][atom][1]], 
                            z=[start_coords[coord][2], end_coords[coord][atom][2]], 
                            mode='lines', line=dict(color='blue', width=2)))

# Update layout if needed (e.g., axis labels, title, etc.)
fig.update_layout(scene=dict(xaxis=dict(title='X-axis'), yaxis=dict(title='Y-axis'), zaxis=dict(title='Z-axis')), title='Line in 3D Space')

# Show the plot
fig.show()

In [18]:
#einops tutorial:
a = np.array([[1,2,3,11],[4,5,6, 44],[7,8,9, 77]])
a.shape, a

((3, 4),
 array([[ 1,  2,  3, 11],
        [ 4,  5,  6, 44],
        [ 7,  8,  9, 77]]))

In [16]:
b = rearrange(a, 'i j -> () j i')
b.shape, b

((1, 4, 3),
 array([[[ 1,  4,  7],
         [ 2,  5,  8],
         [ 3,  6,  9],
         [11, 44, 77]]]))

In [19]:
b = rearrange(a, 'i j -> () (j i)')
b.shape, b

((1, 12), array([[ 1,  4,  7,  2,  5,  8,  3,  6,  9, 11, 44, 77]]))

In [28]:
k = np.array([[[ 1,  4,  7],
         [ 2,  5,  8],
         [ 3,  6,  9],
         [11, 44, 77]]])
b = rearrange(k, ' m (i1 i2) j -> m i1 (i2 j)', i1=2) #1,4,3 -> 1,2,6
b.shape, b

((1, 2, 6),
 array([[[ 1,  4,  7,  2,  5,  8],
         [ 3,  6,  9, 11, 44, 77]]]))

In [9]:
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
c = np.array([7, 8, 9])
d = np.array([[1,2,3],[4,5,6],[7,8,9]])
print(d, d.shape)
# Stack the arrays along axis 0 (rows)
stacked_array = np.stack(d, axis=1)
print(stacked_array)

[[1 2 3]
 [4 5 6]
 [7 8 9]] (3, 3)
[[1 4 7]
 [2 5 8]
 [3 6 9]]


In [None]:
from typing import Tuple, Union
import haiku as hk

from model.base.decoder import Decoder
from model.base.encoder import Encoder

from .base.utils import ModelOutput
from moleculib.protein.datum import ProteinDatum


class Ophiuchus(hk.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder):
        super().__init__()
        self.encoder = encoder()
        self.decoder = decoder()

    def __call__(
        self,
        datum: Union[ProteinDatum, None] = None,
        is_training: bool = False,
        conditionals: Tuple[bool] = None,
    ):
    
        
        return ModelOutput(
            datum=datum_out,
            logits=None,
            encoder_internals=None,
            decoder_internals=None,
            probs=None,
            atom_perm_loss=None,
        )