In [None]:
#@title Install dependencies
import os
from google.colab import files
import re
import hashlib
import random

from sys import version_info
# "major.minor" of the running interpreter; used to pin the python version in
# the mamba installs below and to locate site-packages later in the notebook.
python_version = f"{version_info.major}.{version_info.minor}"

# Feature switches controlling which optional toolchains get installed.
USE_AMBER = True
USE_TEMPLATES = False
PYTHON_VERSION = python_version

# Each install step is guarded by an empty marker file (e.g. COLABFOLD_READY)
# that is `touch`ed after the step, so re-running the cell skips finished steps.
if not os.path.isfile("COLABFOLD_READY"):
  print("installing colabfold...")
  os.system("pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'")
  # When the TPU_NAME environment variable is set, jax/jaxlib are removed and
  # reinstalled at pinned versions.
  if os.environ.get('TPU_NAME', False) != False:
    os.system("pip uninstall -y jax jaxlib")
    os.system("pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  # Symlink the installed packages into the working directory.
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold")
  os.system("touch COLABFOLD_READY")

# Miniforge (conda/mamba) is only needed for the optional amber/template tools.
if USE_AMBER or USE_TEMPLATES:
  if not os.path.isfile("CONDA_READY"):
    print("installing conda...")
    os.system("wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh")
    os.system("bash Miniforge3-Linux-x86_64.sh -bfp /usr/local")
    os.system("mamba config --set auto_update_conda false")
    os.system("touch CONDA_READY")

# If both toolchains are requested and neither marker exists, install them in a
# single mamba call; otherwise install whichever one is still missing.
if USE_TEMPLATES and not os.path.isfile("HH_READY") and USE_AMBER and not os.path.isfile("AMBER_READY"):
  print("installing hhsuite and amber...")
  os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python='{PYTHON_VERSION}' pdbfixer")
  os.system("touch HH_READY")
  os.system("touch AMBER_READY")
else:
  if USE_TEMPLATES and not os.path.isfile("HH_READY"):
    print("installing hhsuite...")
    os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python='{PYTHON_VERSION}'")
    os.system("touch HH_READY")
  if USE_AMBER and not os.path.isfile("AMBER_READY"):
    print("installing amber...")
    os.system(f"mamba install -y -c conda-forge openmm=7.7.0 python='{PYTHON_VERSION}' pdbfixer")
    os.system("touch AMBER_READY")

# NOTE(review): unlike the steps above, this install is unguarded and unpinned,
# so it runs on every execution of the cell.
os.system("pip install biopython")

# Download RNA3Db structure files.
# The archive is xz-compressed, so it must be unpacked with -J (xz); the gzip
# flag -z fails on a .tar.xz file.
os.system('wget https://github.com/marcellszi/rna3db/releases/download/incremental-update/rna3db-mmcifs.v2.tar.xz')
os.system('tar -xJf rna3db-mmcifs.v2.tar.xz')
os.system('rm rna3db-mmcifs.v2.tar.xz')
# Directory passed to get_structure() as the root of the structure file tree.
struct_path = "rna3db-mmcifs.v2/rna3db-mmcifs"

# Download RNA3Db sequence files (gzip-compressed tarball).
os.system('wget https://github.com/marcellszi/rna3db/releases/download/incremental-update/rna3db-jsons.tar.gz')
os.system('tar -xvzf rna3db-jsons.tar.gz')
os.system('rm rna3db-jsons.tar.gz')
# JSON file handed to load_data() / parse_json().
seq_path = "rna3db-jsons/split.json"

ModuleNotFoundError: No module named 'google.colab'

In [None]:
def add_hash(x,y):
    """Return ``x`` suffixed with ``_`` and the first five hex digits of SHA-1(``y``)."""
    digest = hashlib.sha1(y.encode()).hexdigest()
    short_digest = digest[:5]
    return "_".join((x, short_digest))

In [None]:
import torch
import torch.nn as nn
import math
import numpy as np

class Converter(nn.Module):
    """Transformer encoder that maps a numeric sequence to class indices in [0, 20).

    Each input position is a single scalar feature. It is linearly embedded to
    ``d_model`` dimensions, positionally encoded, passed through a stack of
    transformer encoder layers and projected onto 20 output classes.

    Args:
        max_seq_len: Length of the positional-encoding table.
        d_model: Embedding width used throughout the encoder.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward sub-block.
        dropout: Dropout probability for the positional encoding and encoder layers.
    """

    def __init__(self, max_seq_len, d_model=64, nhead=8, num_layers=6, dim_feedforward=256, dropout=0.1):
        super().__init__()

        self.d_model = d_model

        # One scalar feature per position -> d_model-dimensional embedding.
        self.input_embedding = nn.Linear(1, d_model)

        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=max_seq_len)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.output_linear = nn.Linear(d_model, 20)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, src_key_padding_mask=None):
        """Return the most probable class index for every position.

        Args:
            x: Float tensor of shape (seq_len, batch_size, 1).
            src_key_padding_mask: Optional mask forwarded to the transformer encoder.

        Returns:
            Nested Python list ``out[i][j]`` holding the argmax class index (int)
            for sequence position ``i`` and batch element ``j``.

        Note:
            The result is a plain list of ints computed under ``torch.no_grad()``,
            so no gradient can flow from the returned indices back to the model
            parameters.
        """
        x = self.input_embedding(x)  # (seq_len, batch_size, d_model)

        x = self.pos_encoder(x)

        x = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)

        x = self.output_linear(x)  # (seq_len, batch_size, 20)
        x = self.softmax(x)

        # Collapse the class probabilities into indices with a single vectorized
        # argmax; tolist() moves the result to host memory in one transfer
        # instead of one .item() call per element.
        with torch.no_grad():
            return torch.argmax(x, dim=-1).tolist()

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence-first tensor.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = exp(-2k * ln(10000) / d_model)``. The table is stored as a
    non-trainable buffer ``pe`` of shape (max_len, 1, d_model).

    Args:
        d_model: Feature width of the tensors that will be encoded.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so trim the angle table to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Insert a broadcastable batch axis: (max_len, 1, d_model).
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions, then apply dropout.

        Args:
            x: Tensor whose first dimension is the sequence position.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

def create_padding_mask(sequences, pad_value=0):
    """Build a boolean mask that is True wherever ``sequences`` equals ``pad_value``.

    Args:
        sequences: Tensor of shape (seq_len, batch_size, 1).
        pad_value: Value that marks a padded position.

    Returns:
        Boolean tensor of shape (batch_size, seq_len).
    """
    without_feature_axis = sequences.squeeze(-1)  # (seq_len, batch_size)
    is_padding = without_feature_axis == pad_value
    return is_padding.t()

In [None]:
import json
import sys

def parse_json(path, a, b, max_len=150):
    """Collect sequences from the ``train_set`` section of a split JSON file.

    The file is expected to have the nesting
    ``data["train_set"][component][macro_tag][tag] = {"length": ..., "sequence": ...}``.
    Entries are numbered from 0 in iteration order; the counter also advances
    for entries that are later rejected for being too long.

    Args:
        path: Path to the JSON file.
        a: First entry index to keep (inclusive).
        b: Last entry index to keep (inclusive); may be ``float('inf')``.
        max_len: Entries whose ``length`` exceeds this value are skipped.

    Returns:
        Tuple ``(seqs, comps, macros)``: ``seqs`` maps tag -> sequence, while
        ``comps`` and ``macros`` list the component and macro tag of every kept
        entry in the order the entries were kept.
    """
    seqs = {}
    comps = []
    macros = []

    # The context manager closes the file even if the JSON is malformed.
    with open(path) as f:
        train_set = json.load(f)["train_set"]

    num = -1
    for component, macro_groups in train_set.items():
        for macro_tag, entries in macro_groups.items():
            for tag, entry in entries.items():
                num += 1
                # Indices only grow, so nothing past b can ever be selected;
                # leave all three loops at once instead of scanning the rest.
                if num > b:
                    return seqs, comps, macros
                if entry["length"] > max_len:
                    continue
                if num >= a:
                    seqs[tag] = entry["sequence"]
                    comps.append(component)
                    macros.append(macro_tag)
    return seqs, comps, macros


In [None]:
# Tag -> sequence for every loaded entry (may get quite large).
seqs = dict()

# Bookkeeping used when searching the structure file tree.
components = list()
macro_tags = list()


# Index -> one-letter amino-acid code.
# The ordering is largely arbitrary, but it must stay the same for any given
# converter.
AA_DICT = dict(enumerate("ARNDCQEGHILKMFPSTWYV"))

def load_data(path, a=0, b=float('inf'), max_len=150):
    """Load sequences, components and macro tags from the JSON file at ``path``.

    Thin wrapper around ``parse_json``. The results are returned to the caller;
    no module-level variable is modified here.

    Args:
        path: Path to the JSON file.
        a: First entry index to keep (inclusive).
        b: Last entry index to keep (inclusive).
        max_len: Maximum allowed sequence length.

    Returns:
        Tuple ``(seqs, components, macro_tags)``.
    """
    parsed = parse_json(path, a, b, max_len=max_len)
    loaded_seqs, loaded_components, loaded_macro_tags = parsed
    return loaded_seqs, loaded_components, loaded_macro_tags

def batch_data(iterable, n=1):
    """Yield the ``(key, value)`` pairs of a dict in consecutive chunks of size ``n``.

    Args:
        iterable: Mapping whose items are batched.
        n: Maximum number of pairs per batch; the last batch may be shorter.

    Yields:
        Lists of ``(key, value)`` tuples.
    """
    total = len(iterable)
    pairs = list(iterable.items())
    for start in range(0, total, n):
        stop = start + n
        if stop > total:
            stop = total
        yield pairs[start:stop]

def encode_rna(seq):
    """Convert an RNA string into a list of single-element integer lists.

    Mapping: A -> [0], U -> [1], C -> [2], G -> [3]. Any other character is
    dropped, so the result can be shorter than ``seq``.
    """
    base_codes = {"A": 0, "U": 1, "C": 2, "G": 3}
    return [[base_codes[base]] for base in seq if base in base_codes]

def write_fastas(tags, seqs=None):
    """Write one ``FASTAs/<tag>.fasta`` file per sequence.

    This function used to be defined twice, with the second definition silently
    replacing the first. Both call forms are supported here:

    * ``write_fastas({tag: seq, ...})`` - a single mapping of tag to sequence.
    * ``write_fastas(tags, seqs)`` - two parallel, index-aligned sequences.

    Files that already exist are left untouched.

    Args:
        tags: Either a ``{tag: seq}`` mapping (when ``seqs`` is omitted) or an
            indexable collection of tags.
        seqs: Sequences aligned with ``tags``; omit when ``tags`` is a mapping.
    """
    if seqs is None:
        records = list(tags.items())
    else:
        records = [(tags[i], seqs[i]) for i in range(len(tags))]

    os.makedirs('FASTAs', exist_ok=True)
    for tag, seq in records:
        fasta_path = f"FASTAs/{tag}.fasta"
        if os.path.exists(fasta_path):
            continue
        # The context manager guarantees the handle is closed even if the
        # write fails.
        with open(fasta_path, "w") as f:
            f.write(f">{tag}\n{seq}")

def empty_dir(path):
    """Delete the directory at ``path`` together with everything inside it.

    The previous implementation called ``os.remove`` on every entry, which
    raises as soon as the directory contains a subdirectory. ``shutil.rmtree``
    removes nested content as well.

    Args:
        path: Directory to delete.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    import shutil  # local import: only this helper needs it

    shutil.rmtree(path)

In [None]:
def get_structure(tag, path):
    """Return the path of the ``.cif`` structure file belonging to ``tag``.

    The component and macro tag of the entry are looked up in the module-level
    ``components`` / ``macro_tags`` lists at the position ``tag`` occupies in
    the module-level ``seqs`` dict.

    Expected file layout below ``path``:
        <path>/<component>/<macro_tag>/<tag>.cif

    Args:
        tag: Key of the entry in ``seqs``.
        path: Root directory of the structure files.

    Returns:
        The joined file path as a string. The file is not checked for existence.

    Raises:
        ValueError: If ``tag`` is not a key of ``seqs``.
    """
    index = list(seqs.keys()).index(tag)
    component = components[index]
    macro_tag = macro_tags[index]

    # Build the path with the platform separator; the former hard-coded
    # backslashes are ordinary filename characters on Linux/Colab.
    return os.path.join(path, component, macro_tag, f"{tag}.cif")

In [None]:
#@markdown ### Advanced settings
model_type = "auto" #@param ["auto", "alphafold2_ptm", "alphafold2_multimer_v1", "alphafold2_multimer_v2", "alphafold2_multimer_v3", "deepfold_v1"]
#@markdown - if `auto` selected, will use `alphafold2_ptm` for monomer prediction and `alphafold2_multimer_v3` for complex prediction.
#@markdown Any of the mode_types can be used (regardless if input is monomer or complex).
num_recycles = "3" #@param ["auto", "0", "1", "3", "6", "12", "24", "48"]
#@markdown - if `auto` selected, will use `num_recycles=20` if `model_type=alphafold2_multimer_v3`, else `num_recycles=3` .
recycle_early_stop_tolerance = "auto" #@param ["auto", "0.0", "0.5", "1.0"]
#@markdown - if `auto` selected, will use `tol=0.5` if `model_type=alphafold2_multimer_v3` else `tol=0.0`.
relax_max_iterations = 200 #@param [0, 200, 2000] {type:"raw"}
#@markdown - max amber relax iterations, `0` = unlimited (AlphaFold2 default, can take very long)
pairing_strategy = "greedy" #@param ["greedy", "complete"] {type:"string"}
#@markdown - `greedy` = pair any taxonomically matching subsets, `complete` = all sequences have to match in one line.


#@markdown #### Sample settings
#@markdown -  enable dropouts and increase number of seeds to sample predictions from uncertainty of the model.
#@markdown -  decrease `max_msa` to increase uncertainty
max_msa = "auto" #@param ["auto", "512:1024", "256:512", "64:128", "32:64", "16:32"]
num_seeds = 1 #@param [1,2,4,8,16] {type:"raw"}
use_dropout = False #@param {type:"boolean"}

# Normalise the form values: "auto" becomes None, numeric strings become numbers.
# These names are read as module-level globals by train() when it calls run().
num_recycles = None if num_recycles == "auto" else int(num_recycles)
recycle_early_stop_tolerance = None if recycle_early_stop_tolerance == "auto" else float(recycle_early_stop_tolerance)
if max_msa == "auto": max_msa = None

In [None]:
#@markdown ### MSA options (custom MSA upload, single sequence, pairing mode)
# msa_mode and pair_mode are module-level settings read by a3ms() and train().
msa_mode = "mmseqs2_uniref_env" #@param ["mmseqs2_uniref_env", "mmseqs2_uniref","single_sequence","custom"]
pair_mode = "unpaired_paired" #@param ["unpaired_paired","paired","unpaired"] {type:"string"}
#@markdown - "unpaired_paired" = pair sequences from same species + unpaired MSA, "unpaired" = separate MSA for each chain, "paired" - only use paired sequences.

def a3ms(jobname, query_sequence=None):
    """Decide which A3M file a job uses, creating it when necessary.

    The behaviour depends on the module-level ``msa_mode``:

    * contains ``"mmseqs2"``: returns ``<jobname>/<jobname>.a3m`` without
      touching the filesystem.
    * ``"custom"``: if ``<jobname>/<jobname>.custom.a3m`` does not exist yet,
      asks for an upload via ``files.upload()``, strips blank lines from the
      first uploaded file and moves it to that location.
    * anything else: writes a single-sequence A3M containing ``query_sequence``
      to ``<jobname>/<jobname>.single_sequence.a3m``.

    Args:
        jobname: Name of the job; used as directory and file-name stem.
        query_sequence: Sequence written in single-sequence mode. Previously
            this name was never bound in that branch, so the branch always
            failed; it is now an explicit, optional argument.

    Returns:
        Path of the selected A3M file (earlier versions returned nothing).

    Raises:
        ValueError: In single-sequence mode when ``query_sequence`` is missing.
    """
    if "mmseqs2" in msa_mode:
        return os.path.join(jobname, f"{jobname}.a3m")

    if msa_mode == "custom":
        a3m_file = os.path.join(jobname, f"{jobname}.custom.a3m")
        if not os.path.isfile(a3m_file):
            import fileinput

            custom_msa_dict = files.upload()
            custom_msa = list(custom_msa_dict.keys())[0]

            # With inplace=True, print() writes back into the uploaded file,
            # so this loop rewrites it without its blank lines.
            with fileinput.FileInput(custom_msa, inplace=True) as msa_lines:
                for line in msa_lines:
                    if not line.rstrip():
                        continue
                    print(line, end='')

            parent = os.path.dirname(a3m_file)
            if parent:
                os.makedirs(parent, exist_ok=True)
            os.rename(custom_msa, a3m_file)
            print(f"moving {custom_msa} to {a3m_file}")
        return a3m_file

    # Single-sequence mode.
    if query_sequence is None:
        raise ValueError("query_sequence is required when msa_mode is not mmseqs2/custom")
    a3m_file = os.path.join(jobname, f"{jobname}.single_sequence.a3m")
    parent = os.path.dirname(a3m_file)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(a3m_file, "w") as text_file:
        text_file.write(">1\n%s" % query_sequence)
    return a3m_file

In [None]:
import sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from Bio import BiopythonDeprecationWarning
warnings.simplefilter(action='ignore', category=BiopythonDeprecationWarning)
from pathlib import Path
from colabfold.download import download_alphafold_params, default_data_dir
from colabfold.utils import setup_logging
from colabfold.batch import get_queries, run, set_model_type
from colabfold.plot import plot_msa_v2
import os
from google.colab import drive

def input_features_callback(input_features):
    """No-op hook passed to ``run`` as ``input_features_callback``; ignores its argument."""
    return None

def prediction_callback(protein_obj, length,
                        prediction_result, input_features, mode):
    """No-op hook passed to ``run`` as ``prediction_callback``.

    ``mode`` is unpacked into two values (model name and relaxed flag), so it
    must be a two-element iterable; nothing else is done with the arguments.
    """
    _model_name, _relaxed = mode
    return None

def train(epochs=50, batch_size=32,tm_score=False, max_seq_len=150):
    """Run the Converter -> ColabFold -> structure-comparison training loop.

    For every batch taken from the module-level ``seqs`` dict, each RNA
    sequence is converted to an amino-acid string by the Converter, written to
    ``FASTAs/<tag>.fasta``, folded with ColabFold's ``run`` and scored against
    its reference structure with ``protein_to_rna``. After every epoch the
    model and its state dict are saved to Google Drive.

    Relies on module-level settings defined in earlier cells (``msa_mode``,
    ``pair_mode``, ``num_recycles``, ``struct_path``, ``USE_TEMPLATES`` ...).

    Args:
        epochs: Number of passes over ``seqs``.
        batch_size: Number of sequences per optimizer step.
        tm_score: Forwarded to ``protein_to_rna`` as ``tm``.
        max_seq_len: Size of the Converter's positional-encoding table; longer
            sequences cannot be encoded.
    """
    # Project-local scoring helper, imported lazily so the cell defining
    # train() can run before the module is available.
    from protein_to_rna import protein_to_rna

    # Mount Drive and create the checkpoint folder below the SAME mount point
    # that torch.save uses at the end of each epoch.
    drive.mount('/content/gdrive')
    weights_dir = '/content/gdrive/My Drive/ConverterWeights'
    os.makedirs(weights_dir, exist_ok=True)

    try:
        K80_chk = os.popen('nvidia-smi | grep "Tesla K80" | wc -l').read()
    except Exception:
        K80_chk = "0"
    if "1" in K80_chk:
        print("WARNING: found GPU Tesla K80: limited to total length < 1000")
        if "TF_FORCE_UNIFIED_MEMORY" in os.environ:
            del os.environ["TF_FORCE_UNIFIED_MEMORY"]
        if "XLA_PYTHON_CLIENT_MEM_FRACTION" in os.environ:
            del os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]

    # For some reason we need that to get pdbfixer to import
    site_packages = f"/usr/local/lib/python{python_version}/site-packages/"
    if site_packages not in sys.path:
        sys.path.insert(0, site_packages)

    conv = Converter(max_seq_len=max_seq_len)
    conv.train()
    optimizer = torch.optim.Adam(conv.parameters(), lr=1e-3)

    model_type = set_model_type(False, "auto")
    download_alphafold_params(model_type, Path("."))

    # Prediction options. They do not change between batches, so they are
    # resolved once up front.
    num_relax = 0 #@param [0, 1, 5] {type:"raw"}
    #@markdown - specify how many of the top ranked structures to relax using amber
    template_mode = "none" #@param ["none", "pdb100","custom"]
    #@markdown - `none` = no template information is used. `pdb100` = detect templates in pdb100 (see [notes](#pdb100)). `custom` - upload and search own templates (PDB or mmCIF format, see [notes](#custom_templates))

    use_cluster_profile = True

    if template_mode == "pdb100":
        use_templates = True
        custom_template_path = None
    elif template_mode == "custom":
        # The job name is not known yet at this point, so uploaded templates
        # go into one shared directory.
        custom_template_path = "template"
        os.makedirs(custom_template_path, exist_ok=True)
        uploaded = files.upload()
        for fn in uploaded.keys():
            os.rename(fn, os.path.join(custom_template_path, fn))
        use_templates = True
    else:
        custom_template_path = None
        use_templates = False

    for epoch in range(epochs):
        for batch in batch_data(seqs, batch_size):
            optimizer.zero_grad()
            # batch: [(tag, seq), (tag, seq), ...]

            # LAYER 1: RNA-AMINO CONVERSION
            final_seqs = {} # {tag: amino-acid string}
            for tag, rna_seq in batch:
                encoded = encode_rna(rna_seq)
                if not encoded:
                    # Nothing encodable (no A/U/C/G characters) - skip the entry.
                    continue
                # (seq_len, 1) -> (seq_len, batch=1, 1), the layout Converter expects.
                rna_tensor = torch.tensor(encoded, dtype=torch.float).unsqueeze(1)
                aa_indices = conv(rna_tensor) # aa_indices[pos][0] is the class index
                final_seqs[tag] = ''.join(AA_DICT[row[0]] for row in aa_indices)

            if not final_seqs:
                continue

            # write_fastas keeps files that already exist, so drop FASTAs left
            # over from a previous epoch before writing the new sequences.
            for tag in final_seqs:
                stale_fasta = f"FASTAs/{tag}.fasta"
                if os.path.exists(stale_fasta):
                    os.remove(stale_fasta)
            write_fastas(list(final_seqs.keys()), list(final_seqs.values()))

            # LAYER 2: STRUCTURE PREDICTION + SCORING (one job per sequence)
            batch_losses = []
            for tag, aa_seq in final_seqs.items():
                queries, _ = get_queries(f'FASTAs/{tag}.fasta')
                # String job name that is stable across runs; the builtin hash()
                # returns an int and is randomised per process for strings.
                jobname = add_hash(tag, aa_seq)
                run(
                    queries=queries,
                    result_dir=jobname,
                    use_templates=USE_TEMPLATES or use_templates,
                    custom_template_path=custom_template_path,
                    num_relax=num_relax,
                    msa_mode=msa_mode,
                    model_type=model_type,
                    num_models=1,
                    num_recycles=num_recycles,
                    relax_max_iterations=relax_max_iterations,
                    recycle_early_stop_tolerance=recycle_early_stop_tolerance,
                    num_seeds=num_seeds,
                    use_dropout=use_dropout,
                    model_order=[1,2,3,4,5],
                    is_complex=False,
                    data_dir=Path("."),
                    keep_existing_results=False,
                    rank_by="auto",
                    pair_mode=pair_mode,
                    pairing_strategy=pairing_strategy,
                    stop_at_score=float(100),
                    prediction_callback=prediction_callback,
                    dpi=200,
                    zip_results=False,
                    save_all=False,
                    max_msa=max_msa,
                    use_cluster_profile=use_cluster_profile,
                    input_features_callback=input_features_callback,
                    save_recycles=False,
                    user_agent="colabfold/google-colab-main",
                )

                # Results are written to the relative directory `jobname`
                # (the result_dir given to run), not to "/<jobname>".
                pdb_path = None
                for file in os.listdir(jobname):
                    if file.endswith(".pdb"):
                        pdb_path = os.path.join(jobname, file)
                        break
                if pdb_path is None:
                    print(f"WARNING: no .pdb produced for {tag}; skipping")
                    empty_dir(jobname)
                    continue

                # presumably protein_to_rna returns a numeric score - TODO confirm
                batch_losses.append(
                    protein_to_rna(pdb_path, get_structure(tag, struct_path), tm=tm_score)
                )
                empty_dir(jobname)

            if not batch_losses:
                continue

            loss = torch.Tensor([sum(batch_losses) / len(batch_losses)])
            # FIXME(review): this tensor is built from plain Python numbers, and
            # Converter.forward returns argmax indices computed under no_grad,
            # so the loss is not connected to conv's parameters. backward()
            # cannot produce gradients for them as written; a differentiable
            # surrogate objective is needed for this step to train anything.
            loss.backward()
            optimizer.step()
        torch.save(conv, os.path.join(weights_dir, f'converter_epoch_{epoch}.pt'))
        # state_dict must be CALLED; saving the bound method stores no parameters dict.
        torch.save(conv.state_dict(), os.path.join(weights_dir, f'converter_params_epoch_{epoch}.pt'))

            

In [None]:
# load_data only RETURNS its results; bind them to the module-level names that
# train() and get_structure() read, otherwise training iterates over an empty dict.
seqs, components, macro_tags = load_data(seq_path, 0, 1615)
train()