In [1]:
#@title Install dependencies
import os
from google.colab import files
import re
import hashlib
import random

from sys import version_info
python_version = f"{version_info.major}.{version_info.minor}"

os.system("pip install biopython")
from Bio.PDB import *

USE_AMBER = True
USE_TEMPLATES = False
PYTHON_VERSION = python_version

if not os.path.isfile("COLABFOLD_READY"):
  print("installing colabfold...")
  os.system("pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'")
  if os.environ.get('TPU_NAME', False) != False:
    os.system("pip uninstall -y jax jaxlib")
    os.system("pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold")
  os.system("touch COLABFOLD_READY")

if USE_AMBER or USE_TEMPLATES:
  if not os.path.isfile("CONDA_READY"):
    print("installing conda...")
    os.system("wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh")
    os.system("bash Miniforge3-Linux-x86_64.sh -bfp /usr/local")
    os.system("mamba config --set auto_update_conda false")
    os.system("touch CONDA_READY")

if USE_TEMPLATES and not os.path.isfile("HH_READY") and USE_AMBER and not os.path.isfile("AMBER_READY"):
  print("installing hhsuite and amber...")
  os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python='{PYTHON_VERSION}' pdbfixer")
  os.system("touch HH_READY")
  os.system("touch AMBER_READY")
else:
  if USE_TEMPLATES and not os.path.isfile("HH_READY"):
    print("installing hhsuite...")
    os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python='{PYTHON_VERSION}'")
    os.system("touch HH_READY")
  if USE_AMBER and not os.path.isfile("AMBER_READY"):
    print("installing amber...")
    os.system(f"mamba install -y -c conda-forge openmm=7.7.0 python='{PYTHON_VERSION}' pdbfixer")
    os.system("touch AMBER_READY")

if not os.path.exists("rna3db-mmcifs"):
  print("Downloading RNA3Db...")
  # Donwload RNA3Db structure files
  os.system('wget https://github.com/marcellszi/rna3db/releases/download/incremental-update/rna3db-mmcifs.v2.tar.xz')
  print("Extracting structure files...")
  os.system('sudo tar -xf rna3db-mmcifs.v2.tar.xz')
  os.system('rm rna3db-mmcifs.v2.tar.xz')

  # Donwload RNA3Db sequence files
  os.system('wget https://github.com/marcellszi/rna3db/releases/download/incremental-update/rna3db-jsons.tar.gz')
  print("Extracting sequence files...")
  os.system('tar -xzf rna3db-jsons.tar.gz')
  os.system('rm rna3db-jsons.tar.gz')
seq_path = "/content/rna3db-jsons/cluster.json"
struct_path = "/rna3db-mmcifs/"

installing colabfold...
installing conda...
installing amber...
Downloading RNA3Db...
Extracting structure files...
Extracting sequence files...


In [2]:
def add_hash(x,y):
  return x+"_"+hashlib.sha1(y.encode()).hexdigest()[:5]

In [3]:
import torch
import torch.nn as nn
import math
import numpy as np

class Converter(nn.Module):
    def __init__(self, max_seq_len, d_model=64, nhead=8, num_layers=6, dim_feedforward=256, dropout=0.1):
        super(Converter, self).__init__()

        self.d_model = d_model

        self.input_embedding = nn.Linear(1, d_model)

        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=max_seq_len)

        encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                    dim_feedforward=dim_feedforward,
                                                    dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

        self.output_linear = nn.Linear(d_model, 20)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, src_key_padding_mask=None):
        # x shape: (seq_len, batch_size, 1)

        x = self.input_embedding(x)  # Now: (seq_len, batch_size, d_model)

        x = self.pos_encoder(x)

        x = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)

        x = self.output_linear(x)  # Now: (seq_len, batch_size, 20)
        x = self.softmax(x)

        # Convert softmaxxed matrices into one-dimensional indeces
        with torch.no_grad():
            out = []
            for i in range(len(x)):
                out.append([])
                for j in range(len(x[i])):
                    out[-1].append((torch.argmax(x[i][j].detach().cpu())).item())
        return out

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

def create_padding_mask(sequences, pad_value=0):
    # sequences shape: (seq_len, batch_size, 1)
    return (sequences.squeeze(-1) == pad_value).t()  # (batch_size, seq_len)

In [4]:
import json
import sys

def parse_json(path, a, b, max_len=150):
    num = -1
    seqs = {}
    comps = []
    macros = []
    f = open(path)
    data = json.load(f)
    for i in data:
        for j in data[i]:
          for k in data[i][j]:
            num = num + 1
            if data[i][j][k]["length"]>max_len:
                continue
            if num>=a and num<=b:
                seqs[k]=data[i][j][k]["sequence"]
                comps.append(i)
                macros.append(j)
            if num>b:
                break
    f.close()
    return seqs, comps, macros


In [5]:
seqs = {} # All sequences - may get quite large

# Used for file tree searching
components = []
macro_tags = []


# Index to amino acid dictionary
# Largely arbitrary, but must stay consistent for any given converter
AA_DICT = {
    0: "A",
    1: "R",
    2: "N",
    3: "D",
    4: "C",
    5: "Q",
    6: "E",
    7: "G",
    8: "H",
    9: "I",
    10: "L",
    11: "K",
    12: "M",
    13: "F",
    14: "P",
    15: "S",
    16: "T",
    17: "W",
    18: "Y",
    19: "V"
}

def load_data(path, a=0, b=float('inf'), max_len=150):
    # Load up sequences, components, and macro-tags
    seqs, components, macro_tags=parse_json(path, a, b, max_len=max_len)
    print(f"Found {len(seqs)} usable RNA strands...")
    return seqs, components, macro_tags

def batch_data(iterable, n=1):
    # Random data batching function
    l = len(iterable)
    iter = [(t, s) for t, s in list(iterable.items())]
    for ndx in range(0, l, n):
        yield iter[ndx:min(ndx + n, l)]

def encode_rna(seq):
    # Convert RNA sequence to nums to feed into Converter
    out = []
    for i in seq:
        if i=="A":
            out.append([0])
        elif i=="U":
            out.append([1])
        elif i=="C":
            out.append([2])
        elif i=="G":
            out.append([3])
    return out

def write_fastas(seqs):
    # Write a dict of {tag: seq} to as many FASTA files as needed
    os.makedirs('FASTAs', exist_ok=True)
    for tag, seq in list(seqs.items()):
        if os.path.exists(f'/content/FASTAs/{tag}.fasta'):
            continue
        f = open(f"/content/FASTAs/{tag}.fasta", "w+")
        f.write(f">{tag}\n{seq}")
        f.close()

def empty_dir(path):
    # Empty any directory
    for f in os.listdir(path):
        if os.path.isfile(os.path.join(path, f)):
          os.remove(os.path.join(path, f))
        else:
          empty_dir(os.path.join(path, f))
    os.rmdir(path)

In [6]:
def get_structure(tag, path):
    # Return the structure of an RNA molecule given its tag and the path to the structure directory
    # File directory:
    # root
    #  -- colabfold.dir
    #  -- train_protify.ipynb
    #  -- data.dir
    #  ---- component 1.dir
    #  ------ tag 1.dir
    #  -------- tag 1a.cif
    #  -------- tag 1b.cif
    # ...
    index = list(seqs.keys()).index(tag)
    component = components[index]
    macro_tag = macro_tags[index]

    path = f"/content/{path}/train_set/{component}/{macro_tag}/{tag}.cif"
    return path

In [7]:
#@markdown ### Advanced settings
model_type = "auto" #@param ["auto", "alphafold2_ptm", "alphafold2_multimer_v1", "alphafold2_multimer_v2", "alphafold2_multimer_v3", "deepfold_v1"]
#@markdown - if `auto` selected, will use `alphafold2_ptm` for monomer prediction and `alphafold2_multimer_v3` for complex prediction.
#@markdown Any of the mode_types can be used (regardless if input is monomer or complex).
num_recycles = "3" #@param ["auto", "0", "1", "3", "6", "12", "24", "48"]
#@markdown - if `auto` selected, will use `num_recycles=20` if `model_type=alphafold2_multimer_v3`, else `num_recycles=3` .
recycle_early_stop_tolerance = "auto" #@param ["auto", "0.0", "0.5", "1.0"]
#@markdown - if `auto` selected, will use `tol=0.5` if `model_type=alphafold2_multimer_v3` else `tol=0.0`.
relax_max_iterations = 200 #@param [0, 200, 2000] {type:"raw"}
#@markdown - max amber relax iterations, `0` = unlimited (AlphaFold2 default, can take very long)
pairing_strategy = "greedy" #@param ["greedy", "complete"] {type:"string"}
#@markdown - `greedy` = pair any taxonomically matching subsets, `complete` = all sequences have to match in one line.


#@markdown #### Sample settings
#@markdown -  enable dropouts and increase number of seeds to sample predictions from uncertainty of the model.
#@markdown -  decrease `max_msa` to increase uncertainity
max_msa = "auto" #@param ["auto", "512:1024", "256:512", "64:128", "32:64", "16:32"]
num_seeds = 1 #@param [1,2,4,8,16] {type:"raw"}
use_dropout = False #@param {type:"boolean"}

num_recycles = None if num_recycles == "auto" else int(num_recycles)
recycle_early_stop_tolerance = None if recycle_early_stop_tolerance == "auto" else float(recycle_early_stop_tolerance)
if max_msa == "auto": max_msa = None

In [8]:
#@markdown ### MSA options (custom MSA upload, single sequence, pairing mode)
msa_mode = "mmseqs2_uniref_env" #@param ["mmseqs2_uniref_env", "mmseqs2_uniref","single_sequence","custom"]
pair_mode = "unpaired_paired" #@param ["unpaired_paired","paired","unpaired"] {type:"string"}
#@markdown - "unpaired_paired" = pair sequences from same species + unpaired MSA, "unpaired" = seperate MSA for each chain, "paired" - only use paired sequences.

def a3ms(jobname):
    # decide which a3m to use
    if "mmseqs2" in msa_mode:
        a3m_file = os.path.join(jobname,f"{jobname}.a3m")

    elif msa_mode == "custom":
        a3m_file = os.path.join(jobname,f"{jobname}.custom.a3m")
        if not os.path.isfile(a3m_file):
            custom_msa_dict = files.upload()
            custom_msa = list(custom_msa_dict.keys())[0]
            header = 0
            import fileinput
            for line in fileinput.FileInput(custom_msa,inplace=1):
                if line.startswith(">"):
                    header = header + 1
                if not line.rstrip():
                    continue
                if line.startswith(">") == False and header == 1:
                    query_sequence = line.rstrip()
                print(line, end='')

            os.rename(custom_msa, a3m_file)
            queries_path=a3m_file
            print(f"moving {custom_msa} to {a3m_file}")

    else:
        a3m_file = os.path.join(jobname,f"{jobname}.single_sequence.a3m")
        with open(a3m_file, "w") as text_file:
            text_file.write(">1\n%s" % query_sequence)

In [9]:
import copy

class Monomer:
    def __init__(self, macro):
        self.atoms = dict()
        self.name = ""
        self.macro = macro

    def add_atom(self, x, y, z, e):
        self.atoms[e] = np.array([x, y, z])

    def element(self, e):
        if e[0]=="\"":
            return e[1]
        else:
            return e[0]

    def get_atoms(self):
        return self.atoms

    def add_name(self, name):
        self.name = name

    def apply_transformation(self, x, y, z):
        out = self
        for atom in out.atoms:
            out.atoms[atom] += [x,y,z]
        return out

    def calculate_normal(self):
        # Get the triangle vertices
        c4_pos = np.array(self.atoms['"C4\'"'])
        c1_pos = np.array(self.atoms['"C1\'"'])
        translation = -np.array(self.atoms['P'])

        # Calculate triangle vectors
        p_to_c4 = c4_pos + translation  # Vector from P to C4'
        p_to_c1 = c1_pos + translation  # Vector from P to C1'

        # Calculate normal to triangle
        normal = np.cross(p_to_c4, p_to_c1)
        normal = normal / np.linalg.norm(normal)
        return normal

    def align_triangle_to_xy(self):
        """
        Aligns the triangle formed by C4', C1', and N1/N9 atoms to the positive xy plane.
        """
        out = copy.deepcopy(self)
        # Get the coordinates of the three atoms forming the triangle
        c4_coords = np.array(out.get_atom_coordinates('\"C4\'\"'))
        c1_coords = np.array(out.get_atom_coordinates('\"C1\'\"'))
        base_coords = np.array(out.get_atom_coordinates("P"))

        if c4_coords is None or c1_coords is None or base_coords is None:
            raise ValueError("Could not find required atoms for alignment")

        # Create vectors from C4' to C1' and C4' to N1/N9
        v1 = c1_coords - c4_coords
        v2 = base_coords - c4_coords

        # Calculate the normal vector of the triangle
        normal = np.cross(v1, v2)
        normal_magnitude = np.linalg.norm(normal)

        if normal_magnitude < 1e-10:
            raise ValueError("Colinear points cannot form a triangle")

        normal = normal / normal_magnitude

        # Calculate rotation matrix to align normal vector with z-axis
        z_axis = np.array([0, 0, 1])
        rotation_axis = np.cross(normal, z_axis)
        rotation_axis_magnitude = np.linalg.norm(rotation_axis)

        if rotation_axis_magnitude < 1e-10:
            # If vectors are parallel, no rotation needed or rotate 180° if antiparallel
            if normal[2] < 0:
                # If normal points in negative z, rotate 180° around x-axis
                rotation_matrix = np.array([
                    [1, 0, 0],
                    [0, -1, 0],
                    [0, 0, -1]
                ])
            else:
                return  # Already aligned correctly
        else:
            rotation_axis = rotation_axis / rotation_axis_magnitude
            angle = np.arccos(np.clip(np.dot(normal, z_axis), -1.0, 1.0))

            # Create rotation matrix using Rodrigues' rotation formula
            K = np.array([
                [0, -rotation_axis[2], rotation_axis[1]],
                [rotation_axis[2], 0, -rotation_axis[0]],
                [-rotation_axis[1], rotation_axis[0], 0]
            ])
            rotation_matrix = (np.eye(3) + np.sin(angle) * K +
                            (1 - np.cos(angle)) * np.matmul(K, K))

        # Apply rotation to all atoms
        for atom in out.atoms.keys():
            coords = np.array(out.atoms[atom]) - c4_coords  # Center around C4'
            rotated_coords = np.dot(rotation_matrix, coords)
            atom.set_coordinates(rotated_coords + c4_coords)  # Move back to original position

        # After first rotation, calculate the angle in xy plane between C4'-C1' vector and x-axis
        c4_coords = np.array(out.get_atom_coordinates('C4\''))
        c1_coords = np.array(out.get_atom_coordinates('C1\''))
        v1_xy = c1_coords[:2] - c4_coords[:2]  # Only consider x and y components
        v1_xy_magnitude = np.linalg.norm(v1_xy)

        if v1_xy_magnitude < 1e-10:
            return  # Vector is vertical, no need for xy rotation

        cos_theta = np.clip(np.dot(v1_xy, [1, 0]) / v1_xy_magnitude, -1.0, 1.0)
        theta = np.arccos(cos_theta)

        # Determine if we need to rotate clockwise or counterclockwise
        if v1_xy[1] < 0:
            theta = -theta

        # Create rotation matrix around z-axis
        rotation_matrix_z = np.array([
            [np.cos(theta), -np.sin(theta), 0],
            [np.sin(theta), np.cos(theta), 0],
            [0, 0, 1]
        ])

        # Apply second rotation to all atoms
        for atom in out.atoms:
            coords = np.array(atom.get_coordinates()) - c4_coords
            rotated_coords = np.dot(rotation_matrix_z, coords)
            atom.set_coordinates(rotated_coords + c4_coords)

        # Final check to ensure the molecule is in the positive xy plane
        # If the base atom is in the negative x region, rotate 180° around y-axis
        base_coords = np.array(out.get_atom_coordinates(out.base_atom))
        if base_coords[0] - c4_coords[0] < 0:
            rotation_matrix_y = np.array([
                [-1, 0, 0],
                [0, 1, 0],
                [0, 0, -1]
            ])
            for atom in out.atoms:
                coords = np.array(atom.get_coordinates()) - c4_coords
                rotated_coords = np.dot(rotation_matrix_y, coords)
                atom.set_coordinates(rotated_coords + c4_coords)


    def align_to_normal(self, target_normal):
        """
        Rotates the monomer so that the normal vector of its P-C1'-C4' triangle
        aligns with the given target normal vector.

        Args:
            target_normal (np.ndarray): The target normal vector to align with (should be normalized)

        Returns:
            Monomer: A new Monomer instance with the rotated coordinates
        """
        out = copy.deepcopy(self)

        try:
            # Get current normal vector
            current_normal = self.calculate_normal()

            # Normalize target vector
            target_normal = target_normal / np.linalg.norm(target_normal)

            # Calculate rotation axis and angle
            rotation_axis = np.cross(current_normal, target_normal)

            # If vectors are parallel (or anti-parallel), rotation axis will be zero
            if np.linalg.norm(rotation_axis) < 1e-10:
                # If normals are anti-parallel, rotate 180° around any perpendicular axis
                if np.dot(current_normal, target_normal) < 0:
                    # Find a perpendicular vector to rotate around
                    if abs(current_normal[0]) < abs(current_normal[1]):
                        rotation_axis = np.cross(current_normal, [1, 0, 0])
                    else:
                        rotation_axis = np.cross(current_normal, [0, 1, 0])
                    angle = np.pi
                else:
                    # Vectors are already aligned
                    return out
            else:
                # Calculate rotation angle
                angle = np.arccos(np.clip(np.dot(current_normal, target_normal), -1.0, 1.0))

            # Normalize rotation axis
            rotation_axis = rotation_axis / np.linalg.norm(rotation_axis)

            # Create rotation matrix using Rodrigues' rotation formula
            cos_theta = np.cos(angle)
            sin_theta = np.sin(angle)
            K = np.array([
                [0, -rotation_axis[2], rotation_axis[1]],
                [rotation_axis[2], 0, -rotation_axis[0]],
                [-rotation_axis[1], rotation_axis[0], 0]
            ])
            R = np.eye(3) + sin_theta * K + (1 - cos_theta) * np.dot(K, K)

            # Apply rotation to all atoms
            for atom in out.atoms:
                out.atoms[atom] = np.dot(R, out.atoms[atom])

            return out

        except KeyError as e:
            raise KeyError(f"Required atom {e} not found in this monomer")


    def load_template(self, n):
        if n=="A": path = "templates\\Adenine_template.cif"
        elif n=="C": path = "templates\\Cytosine_template.cif"
        elif n=="G": path = "templates\\Guanine_template.cif"
        elif n=="U": path = "templates\\Uracil_template.cif"
        atoms = []
        atom_xs = []
        atom_ys = []
        atom_zs = []
        try:
          parser = MMCIFParser()
          structure = parser.get_structure(n, path)
          data = []
          for model in structure:
            for chain in model:
                for residue in chain:
                    for atom in residue:
                        datum = list(atom.get_vector())
                        temp = (datum[0], datum[1], datum[2], atom.get_name())
                        data.append(temp)
        except Exception as e:
            print("Oops. %s" % e)
        for i in range(len(data)):
            self.add_atom(data[i])


    def __str__(self, start=1):
        #Return what this monomer would look like in an mmCIF file#
        out = ""
        c = start
        for i in self.atoms:
            out += f"ATOM {c}\t{self.element(i)}\t{i}\t{self.name}\t. A 1 1\t?\t{round(self.atoms[i][0],3)}\t{round(self.atoms[i][1],3)}\t{round(self.atoms[i][2],3)}\n"
            c += 1
        out += "\b\b"
        return out

    def print(self, start=1, number=1):
        #Return what this monomer would look like in an mmCIF file#
        out = ""
        c = start
        for i in self.atoms:
            out += f"ATOM {c}\t{self.element(i)}\t{i}\t{self.name}\t. A 1 {number}\t?\t{round(self.atoms[i][0],3)}\t{round(self.atoms[i][1],3)}\t{round(self.atoms[i][2],3)}\n"
            c += 1
        return out

    def __len__(self):
        return len(self.atoms)

In [10]:
def RMSD(p1, p2):
    loss = torch.sqrt(torch.mean((p1[:len(p2)] - p2)**2))
    return loss

def tm_score(p1, p2, lt):
    d0 = lambda l: 1.24 * torch.power(l-15, 3) - 1.8
    loss = torch.mean(1/(1+torch.power(torch.abs(torch.norm(p1-p2))/d0(lt),2)))
    return loss

In [11]:
def parse_rna(path):
    try:
        parser = MMCIFParser()
        structure = parser.get_structure("RNA", path)
        data = []
        for model in structure:
          for chain in model:
              for residue in chain:
                  for atom in residue:
                    if residue.get_resname() in ['A', 'U', 'C', 'G']:
                      datum = list(atom.get_vector())
                      temp = (datum[0], datum[1], datum[2], atom.get_name())
                      data.append(temp)

        points = []
        angle_points = []
        norms = []

        for x, y, z, atom in data:
            x = float(x)
            y = float(y)
            z = float(z)

            if atom == "P":
                points.append(np.array([x, y, z]))
                angle_points.append(np.array([x, y, z]))
            elif atom == "\"C1'\"":
                angle_points.append(np.array([x, y, z]))
            elif atom == "\"C4'\"":
                angle_points.append(np.array([x, y, z]))
                v1 = angle_points[-1]-angle_points[-2]
                v2 = angle_points[-3]-angle_points[-2]
                norms.append(np.cross(v1, v2))
                angle_points = []
        points = np.array(points)
        norms = np.array(norms)
        return torch.tensor(points, requires_grad=True, dtype=torch.float32), torch.tensor(norms, requires_grad=True, dtype=torch.float32)

    except Exception as e:
        print("Oops. %s" % e)
        sys.exit(1)

def parse_protein(path):
    try:
        parser = PDBParser()
        structure = parser.get_structure("Protein", path)
        data = []
        for model in structure:
          for chain in model:
              for residue in chain:
                  for atom in residue:
                      datum = list(atom.get_vector())
                      temp = (datum[0], datum[1], datum[2], atom.get_name())
                      data.append(temp)

        points = []
        angle_points = []
        norms = []
        for x, y, z, atom in data:
            x = float(x)
            y = float(y)
            z = float(z)

            if atom == "CA":
                points.append(np.array([x, y, z]))
                angle_points.append(np.array([x, y, z]))
            elif atom == "N":
                angle_points.append(np.array([x, y, z]))
            elif atom == "C":
                angle_points.append(np.array([x, y, z]))
                v1 = angle_points[-1]-angle_points[-2]
                v2 = angle_points[-3]-angle_points[-2]
                norms.append(np.cross(v1, v2))
                angle_points = []

        points = np.array(points)
        norms = np.array(norms)
        return torch.tensor(points, requires_grad=True), torch.tensor(norms, requires_grad=True)

    except Exception as e:
        print("Oops. %s" % e)
        sys.exit(1)

In [12]:
def protein_to_rna(protein, rna_path, tm=False):
    prot_points, _ = parse_protein(protein)
    rna_points, _ = parse_rna(rna_path)
    prot_points = correct_protein_coords(prot_points)
    if tm:
        return tm_score(prot_points, rna_points)
    return RMSD(prot_points, rna_points)

def correct_protein_coords(points):
  # Apply correction factor to the protein coordinates to account for bond lengths
  correction_factor = torch.zeros(3, dtype=torch.float32, requires_grad=True) # Delta correction factor
  pp_dist = 6.8 # Approximated this value from what ChatGPT tells me - will look for rigorous results
                # Also haha pp
  # Create a copy of the points to avoid in-place modification
  new_points = points.clone()
  for i in range(1,len(points)):
    v = points[i]-points[i-1] # Vector between two points
    pt = pp_dist*v/torch.norm(v) # Vector to new point
    correction_factor = correction_factor + pt-v
    # Update the copy instead of the original points
    new_points[i] = points[i]+correction_factor # Apply correction factor
  return new_points

In [13]:
import sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from Bio import BiopythonDeprecationWarning
warnings.simplefilter(action='ignore', category=BiopythonDeprecationWarning)
from pathlib import Path
from colabfold.download import download_alphafold_params, default_data_dir
from colabfold.utils import setup_logging
from colabfold.batch import get_queries, run, set_model_type
from colabfold.plot import plot_msa_v2
import os
from google.colab import drive

def input_features_callback(input_features):
  pass

def prediction_callback(protein_obj, length,
                        prediction_result, input_features, mode):
  model_name, relaxed = mode
  pass

def train(seqs, epochs=50, batch_size=32,tm_score=False, max_seq_len=150):
    drive.mount('/content/drive')
    !mkdir -p "/content/drive/My Drive/ConverterWeights"
    try:
        K80_chk = os.popen('nvidia-smi | grep "Tesla K80" | wc -l').read()
    except:
        K80_chk = "0"
        pass
    if "1" in K80_chk:
        print("WARNING: found GPU Tesla K80: limited to total length < 1000")
        if "TF_FORCE_UNIFIED_MEMORY" in os.environ:
            del os.environ["TF_FORCE_UNIFIED_MEMORY"]
        if "XLA_PYTHON_CLIENT_MEM_FRACTION" in os.environ:
            del os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]

    # For some reason we need that to get pdbfixer to import
    if f"/usr/local/lib/python{python_version}/site-packages/" not in sys.path:
        sys.path.insert(0, f"/usr/local/lib/python{python_version}/site-packages/")

    conv = Converter(max_seq_len=max_seq_len)
    conv.train()
    optimizer = torch.optim.Adam(conv.parameters(), lr=1e-3)


    model_type = set_model_type(False, "auto")
    download_alphafold_params(model_type, Path("."))
    for epoch in range(epochs):
        epoch_loss = 0
        for batch in batch_data(seqs, batch_size):
            optimizer.zero_grad()
            # batch: ([(tag, seq), (tag, seq),...])

            # LAYER 1: RNA-AMINO CONVERSION
            tags = [s[0] for s in batch]

            # Check that structure files exist
            # if not os.path.isfile(get_structure(tags[0])):
            #     continue

            # Preprocessing sequences
            processed_seqs = [torch.tensor(np.transpose(np.array(encode_rna(s[1])), (0,1)), requires_grad=False, dtype=torch.float32) for s in batch] # (batch, seq, base)

            # Send sequences through the converter
            aa_seqs = [conv(s) for s in processed_seqs][0] # (seq, batch, aa)
            temp = []

            # Reconvert to letter representation
            for i in range(len(aa_seqs)):
                temp.append(''.join([AA_DICT[n] for n in aa_seqs[i]]))

            aa_seqs = temp # (seq: String, batch)

            final_seqs = {} # {tag: seq}
            for i in range(len(tags)):
                final_seqs[tags[i]] = aa_seqs[i]
            write_fastas(final_seqs)

            num_relax = 0 #@param [0, 1, 5] {type:"raw"}
            #@markdown - specify how many of the top ranked structures to relax using amber
            template_mode = "none" #@param ["none", "pdb100","custom"]
            #@markdown - `none` = no template information is used. `pdb100` = detect templates in pdb100 (see [notes](#pdb100)). `custom` - upload and search own templates (PDB or mmCIF format, see [notes](#custom_templates))

            use_amber = num_relax > 0
            use_cluster_profile = True

            if template_mode == "pdb100":
                use_templates = True
                custom_template_path = None
            elif template_mode == "custom":
                custom_template_path = os.path.join(jobname,f"template")
                os.makedirs(custom_template_path, exist_ok=True)
                uploaded = files.upload()
                use_templates = True
                for fn in uploaded.keys():
                    os.rename(fn,os.path.join(custom_template_path,fn))
            else:
                custom_template_path = None
                use_templates = False

            loss = torch.tensor([], dtype=torch.float32, requires_grad=True)

            for i in range(len(final_seqs)):
                queries, _ = get_queries(f'/content/FASTAs/{list(final_seqs.keys())[i]}.fasta')
                jobname = add_hash(list(final_seqs.keys())[i], list(final_seqs.values())[i])
                results =  run(
                    queries=queries,
                    result_dir=jobname,
                    use_templates=USE_TEMPLATES,
                    custom_template_path=None,
                    num_relax=num_relax,
                    msa_mode=msa_mode,
                    model_type=model_type,
                    num_models=1,
                    num_recycles=num_recycles,
                    relax_max_iterations=relax_max_iterations,
                    recycle_early_stop_tolerance=recycle_early_stop_tolerance,
                    num_seeds=num_seeds,
                    use_dropout=use_dropout,
                    model_order=[1,2,3,4,5],
                    is_complex=False,
                    data_dir=Path("."),
                    keep_existing_results=False,
                    rank_by="auto",
                    pair_mode=pair_mode,
                    pairing_strategy=pairing_strategy,
                    stop_at_score=float(100),
                    prediction_callback=prediction_callback,
                    dpi=200,
                    zip_results=False,
                    save_all=False,
                    max_msa=max_msa,
                    use_cluster_profile=use_cluster_profile,
                    input_features_callback=input_features_callback,
                    save_recycles=False,
                    user_agent="colabfold/google-colab-main",
                )
                path = ""
                for file in os.listdir(f"/content/{jobname}"):
                  if file.endswith(".pdb"):
                    path = os.path.join(f"/content/{jobname}", file)
                    break
                temp_loss = (protein_to_rna(path, get_structure(list(final_seqs.keys())[i], struct_path), tm=tm_score))
                epoch_loss+=temp_loss
                empty_dir(f"/content/{jobname}")
                torch.cat((loss, torch.tensor([temp_loss], dtype=torch.float32)), 0)

            loss = torch.mean(loss)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch}: Loss {epoch_loss/len(seqs)}")
        torch.save(conv, f'/content/drive/My Drive/ConverterWeights/converter_epoch_{epoch}.pt')
        torch.save(conv.state_dict(), f'/content/drive/My Drive/ConverterWeights/converter_params_epoch_{epoch}.pt')


In [14]:
seqs, components, macro_tags = load_data(seq_path, 0, 10)
train(seqs, epochs=1, batch_size=3, max_seq_len=200)

Found 10 usable RNA strands...
Mounted at /content/drive


Downloading alphafold2_ptm weights to .: 100%|██████████| 3.47G/3.47G [00:34<00:00, 109MB/s]
PENDING:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]ERROR:colabfold.colabfold:Sleeping for 7s. Reason: PENDING
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:08 remaining: 00:00]


tensor(187.6265, dtype=torch.float64, grad_fn=<SqrtBackward0>)


PENDING:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]ERROR:colabfold.colabfold:Sleeping for 10s. Reason: PENDING
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:11 remaining: 00:00]


tensor(200.3732, dtype=torch.float64, grad_fn=<SqrtBackward0>)


PENDING:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]ERROR:colabfold.colabfold:Sleeping for 9s. Reason: PENDING
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:10 remaining: 00:00]


tensor(105.2179, dtype=torch.float64, grad_fn=<SqrtBackward0>)


PENDING:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]ERROR:colabfold.colabfold:Sleeping for 5s. Reason: PENDING
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:06 remaining: 00:00]


tensor(198.2539, dtype=torch.float64, grad_fn=<SqrtBackward0>)


PENDING:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]ERROR:colabfold.colabfold:Sleeping for 6s. Reason: PENDING
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:07 remaining: 00:00]


tensor(205.4314, dtype=torch.float64, grad_fn=<SqrtBackward0>)


PENDING:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]ERROR:colabfold.colabfold:Sleeping for 9s. Reason: PENDING
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:10 remaining: 00:00]


tensor(179.1743, dtype=torch.float64, grad_fn=<SqrtBackward0>)


PENDING:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]ERROR:colabfold.colabfold:Sleeping for 9s. Reason: PENDING
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:10 remaining: 00:00]


tensor(45.4130, dtype=torch.float64, grad_fn=<SqrtBackward0>)


PENDING:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]ERROR:colabfold.colabfold:Sleeping for 7s. Reason: PENDING
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:08 remaining: 00:00]


tensor(39.4401, dtype=torch.float64, grad_fn=<SqrtBackward0>)


PENDING:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]ERROR:colabfold.colabfold:Sleeping for 6s. Reason: PENDING
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:07 remaining: 00:00]


tensor(44.0015, dtype=torch.float64, grad_fn=<SqrtBackward0>)


PENDING:   0%|          | 0/150 [elapsed: 00:00 remaining: ?]ERROR:colabfold.colabfold:Sleeping for 7s. Reason: PENDING
COMPLETE: 100%|██████████| 150/150 [elapsed: 00:08 remaining: 00:00]


tensor(59.2969, dtype=torch.float64, grad_fn=<SqrtBackward0>)
Epoch 1: Loss nan
