In [2]:
# more advanced EGNN layers

import os, torch

from tqdm.auto import tqdm
from torch import nn, einsum, broadcast_tensors

from einops import rearrange, repeat

In [3]:
# some helper functions
def exists(val):
    """Return True if ``val`` is anything other than ``None``.

    Falsy values such as ``0``, ``''`` or an empty list still count as existing.
    """
    return not (val is None)

def save_div(num, den, eps=1e-8):
    """Element-wise ``num / den`` that yields 0 wherever ``den == 0``.

    The name is presumably a misspelling of "safe_div"; it is kept unchanged
    so existing callers keep working.

    Args:
        num: numerator tensor.
        den: denominator tensor, broadcastable against ``num``.
        eps: lower bound applied to ``den`` before dividing.

    Returns:
        ``num / den.clamp(min=eps)`` with every position where ``den == 0``
        set to 0. The result has the broadcast shape of ``num`` and ``den``.

    Note:
        ``den`` is clamped from below, so negative denominators are also
        replaced by ``eps``. This presumably assumes ``den`` is non-negative
        (e.g. a norm). TODO: confirm with callers.
    """
    res = num.div(den.clamp(min=eps))
    # masked_fill is out-of-place: its result must be returned, otherwise the
    # zero-denominator entries would stay at num / eps instead of 0.
    return res.masked_fill(den == 0, 0.0)

def batched_index_select(values, indices, dim=1):
    """Gather entries of ``values`` along axis ``dim`` using batched ``indices``.

    The first ``dim`` axes are treated as shared batch axes. With ``b``
    spanning those axes and ``v`` spanning the axes of ``values`` after
    ``dim``, the result satisfies::

        out[b, *rest, *v] = values[b, indices[b, *rest], *v]

    Args:
        values: tensor of shape ``(*batch, N, *value_dims)`` with ``len(batch) == dim``.
        indices: int64 tensor of shape ``(*batch, *rest)`` holding positions in
            ``[0, N)``. It is assumed to have at least ``dim + 1`` axes.
        dim: axis of ``values`` to index into; also the number of batch axes.

    Returns:
        Tensor of shape ``(*indices.shape, *value_dims)``.
    """
    value_dims = values.shape[(dim + 1):]
    indices_rank = indices.dim()

    # Append one singleton axis per trailing value dim, then broadcast, so the
    # index tensor becomes (*indices.shape, *value_dims) as gather requires.
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * indices_rank), *value_dims)

    # Insert singleton axes into `values` just before axis `dim` so its rank
    # matches the expanded index tensor.
    value_expand_len = indices_rank - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    # Broadcast those inserted axes to the sizes of the matching index axes.
    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    # The original axis `dim` was shifted right by the inserted axes.
    dim += value_expand_len
    return values.gather(dim, indices)

# fourier encoding distance
def fourier_encode_dist(x, num_encodings=4, include_self=True):
    """Expand a tensor of distances into sinusoidal Fourier features.

    A trailing axis is added to ``x``. Each value is divided by the scales
    ``2**0, 2**1, ..., 2**(num_encodings - 1)``, and the sines and cosines of
    the scaled values are concatenated. The raw value is optionally appended.

    Args:
        x: floating-point tensor of any shape.
        num_encodings: number of power-of-two scales.
        include_self: when True, append the unscaled value as the last feature.

    Returns:
        Tensor of shape ``(*x.shape, 2 * num_encodings + int(include_self))``
        laid out as ``[sin..., cos..., x]``.
    """
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x
    scales = 2 ** torch.arange(num_encodings, device=device, dtype=dtype)
    x = x / scales
    x = torch.cat([x.sin(), x.cos()], dim=-1)
    x = torch.cat([x, orig_x], dim=-1) if include_self else x
    return x

def embedd_token(x, dims, layers):
    """Replace the trailing categorical-code columns of ``x`` with embeddings.

    The last ``len(dims)`` columns of ``x`` are cast to ``long``. Code column
    ``i`` is passed through ``layers[i]``. The leading (non-code) columns are
    kept, and every layer output is appended after them in order along the
    last axis. If ``layers`` is empty, ``x`` is returned unchanged.

    Assumes ``x`` is laid out as (rows, features); TODO confirm. Columns are
    sliced on axis 1 and results are concatenated on the last axis.
    """
    n_codes = len(dims)
    codes = x[:, -n_codes:].long()
    embedded = [layer(codes[:, col]) for col, layer in enumerate(layers)]
    if not embedded:
        # no layers: nothing gets replaced
        return x
    # drop the code columns, keep the prefix, then append each embedding
    return torch.cat([x[:, :-n_codes], *embedded], dim=-1)


class Swish_(nn.Module):
    """Swish activation ``x * sigmoid(x)``; fallback when ``nn.SiLU`` is missing."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

# prefer the built-in module whenever this torch version provides it
SiLU = getattr(nn, 'SiLU', Swish_)

In [4]:
# Normalization layers
class CoorsNorm(nn.Module):
    """Rescale coordinate vectors to unit L2 length times a learnable scalar.

    The norm is taken over the last axis. Norms below ``eps`` are clamped to
    ``eps``, so a zero vector maps to zero rather than NaN.

    Args:
        eps: lower bound on the norm used as the divisor.
        scale_init: initial value of the single learnable scale parameter.
    """

    def __init__(self, eps=1e-8, scale_init=1.0):
        super().__init__()
        self.eps = eps
        # one learnable gain, stored in torch's default floating dtype
        self.scale = nn.Parameter(torch.empty(1).fill_(scale_init))

    def forward(self, coors):
        safe_norm = coors.norm(dim=-1, keepdim=True).clamp_min(self.eps)
        return (coors / safe_norm) * self.scale
    


In [None]:
# Global Linear Attention
