### Install dependencies

In [None]:
import os
from google.colab import files
import re
import hashlib
import random

from sys import version_info
python_version = f"{version_info.major}.{version_info.minor}"

os.system("pip install biopython")
from Bio.PDB import *

USE_AMBER = True
USE_TEMPLATES = False
PYTHON_VERSION = python_version

if not os.path.isfile("COLABFOLD_READY"):
  print("installing colabfold...")
  os.system("pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'")
  if os.environ.get('TPU_NAME', False) != False:
    os.system("pip uninstall -y jax jaxlib")
    os.system("pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold")
  os.system("touch COLABFOLD_READY")

if USE_AMBER or USE_TEMPLATES:
  if not os.path.isfile("CONDA_READY"):
    print("installing conda...")
    os.system("wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh")
    os.system("bash Miniforge3-Linux-x86_64.sh -bfp /usr/local")
    os.system("mamba config --set auto_update_conda false")
    os.system("touch CONDA_READY")

if USE_TEMPLATES and not os.path.isfile("HH_READY") and USE_AMBER and not os.path.isfile("AMBER_READY"):
  print("installing hhsuite and amber...")
  os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python='{PYTHON_VERSION}' pdbfixer")
  os.system("touch HH_READY")
  os.system("touch AMBER_READY")
else:
  if USE_TEMPLATES and not os.path.isfile("HH_READY"):
    print("installing hhsuite...")
    os.system(f"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python='{PYTHON_VERSION}'")
    os.system("touch HH_READY")
  if USE_AMBER and not os.path.isfile("AMBER_READY"):
    print("installing amber...")
    os.system(f"mamba install -y -c conda-forge openmm=7.7.0 python='{PYTHON_VERSION}' pdbfixer")
    os.system("touch AMBER_READY")

if not os.path.exists("rna3db-mmcifs"):
  print("Downloading RNA3Db...")
  # Donwload RNA3Db structure files
  os.system('wget https://github.com/marcellszi/rna3db/releases/download/incremental-update/rna3db-mmcifs.v2.tar.xz')
  print("Extracting structure files...")
  os.system('sudo tar -xf rna3db-mmcifs.v2.tar.xz')
  os.system('rm rna3db-mmcifs.v2.tar.xz')

  # Donwload RNA3Db sequence files
  os.system('wget https://github.com/marcellszi/rna3db/releases/download/incremental-update/rna3db-jsons.tar.gz')
  print("Extracting sequence files...")
  os.system('tar -xzf rna3db-jsons.tar.gz')
  os.system('rm rna3db-jsons.tar.gz')
seq_path = "/content/rna3db-jsons/cluster.json"
struct_path = "/rna3db-mmcifs/"

installing colabfold...
installing conda...
installing amber...
Downloading RNA3Db...
Extracting structure files...
Extracting sequence files...


In [2]:
def add_hash(x,y):
  return x+"_"+hashlib.sha1(y.encode()).hexdigest()[:5]

### Converter Class
This class is essentially just a Transformer network. Input: list of list of ints, ex [[0], [3], [1],...].
Possible future routes of inquiry: Mamba modifications to the Transformer, other types of RNN

In [4]:
import torch
import torch.nn as nn
import math
import numpy as np

class Converter(nn.Module):
    def __init__(self, max_seq_len=150, d_model=64, nhead=8, num_layers=6, dim_feedforward=256, dropout=0.1):
        super(Converter, self).__init__()

        self.d_model = d_model

        self.input_embedding = nn.Linear(1, d_model)

        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=max_seq_len)

        self.transformer = nn.Transformer(d_model=d_model,
                                    nhead=nhead,
                                    dim_feedforward=dim_feedforward,
                                    num_encoder_layers=num_layers, num_decoder_layers=num_layers)

        self.output_linear = nn.Linear(d_model, 20)
        self.softmax = nn.Softmax(dim=-1)


    def forward(self, x, src_key_padding_mask=None):
        # x shape: (seq_len, batch_size, 1)

        x = self.input_embedding(x)  # Now: (seq_len, batch_size, d_model)

        x = self.pos_encoder(x)

        x = self.transformer(x, x, src_key_padding_mask=src_key_padding_mask)

        x = self.output_linear(x)  # Now: (seq_len, batch_size, 20)
        x = self.softmax(x)

        # Convert softmaxxed matrices into one-dimensional indeces
        with torch.no_grad():
            out = []
            for i in range(len(x)):
                out.append([])
                for j in range(len(x[i])):
                    out[-1].append((torch.argmax(x[i][j].detach().cpu())).item())
        return out

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

def create_padding_mask(sequences, pad_value=0):
    # sequences shape: (seq_len, batch_size, 1)
    return (sequences.squeeze(-1) == pad_value).t()  # (batch_size, seq_len)

### Parse json files for sequence information

In [5]:
import json
import sys

def parse_json(path, a, b, max_len=150):
    num = -1
    seqs = {}
    comps = []
    macros = []
    f = open(path)
    data = json.load(f)
    for i in data:
        for j in data[i]:
          for k in data[i][j]:
            num = num + 1
            if data[i][j][k]["length"]>max_len:
                continue
            if num>=a and num<=b:
                seqs[k]=data[i][j][k]["sequence"]
                comps.append(i)
                macros.append(j)
            if num>b:
                break
    f.close()
    return seqs, comps, macros


### Data Methods

In [6]:
seqs = {} # All sequences - may get quite large

# Used for file tree searching
components = []
macro_tags = []


# Index to amino acid dictionary
# Largely arbitrary, but must stay consistent for any given converter
AA_DICT = {
    0: "A",
    1: "R",
    2: "N",
    3: "D",
    4: "C",
    5: "Q",
    6: "E",
    7: "G",
    8: "H",
    9: "I",
    10: "L",
    11: "K",
    12: "M",
    13: "F",
    14: "P",
    15: "S",
    16: "T",
    17: "W",
    18: "Y",
    19: "V"
}

def load_data(path, a=0, b=float('inf'), max_len=150):
    # Load up sequences, components, and macro-tags
    seqs, components, macro_tags=parse_json(path, a, b, max_len=max_len)
    print(f"Found {len(seqs)} usable RNA strands...")
    return seqs, components, macro_tags

def batch_data(iterable, n=1):
    # Random data batching function
    l = len(iterable)
    iter = [(t, s) for t, s in list(iterable.items())]
    random.shuffle(iter)
    for ndx in range(0, l, n):
        yield iter[ndx:min(ndx + n, l)]

def encode_rna(seq):
    # Convert RNA sequence to nums to feed into Converter
    out = []
    for i in seq:
        if i=="A":
            out.append([0])
        elif i=="U":
            out.append([1])
        elif i=="C":
            out.append([2])
        elif i=="G":
            out.append([3])
    return out

def write_fastas(seqs):
    # Write a dict of {tag: seq} to as many FASTA files as needed
    os.makedirs('FASTAs', exist_ok=True)
    for tag, seq in list(seqs.items()):
        if os.path.exists(f'/content/FASTAs/{tag}.fasta'):
            continue
        f = open(f"/content/FASTAs/{tag}.fasta", "w+")
        f.write(f">{tag}\n{seq}")
        f.close()

def empty_dir(path, delete=True):
    # Empty any directory
    for f in os.listdir(path):
        if os.path.isfile(os.path.join(path, f)):
          os.remove(os.path.join(path, f))
        else:
          empty_dir(os.path.join(path, f))
    if delete:
      os.rmdir(path)

### Search the RNA3Db filetree for structure mmcifs

In [7]:
def get_structure(tag, path):
    # Return the structure of an RNA molecule given its tag and the path to the structure directory
    # File directory:
    # root
    #  -- colabfold.dir
    #  -- train_protify.ipynb
    #  -- data.dir
    #  ---- component 1.dir
    #  ------ tag 1.dir
    #  -------- tag 1a.cif
    #  -------- tag 1b.cif
    # ...
    index = list(seqs.keys()).index(tag)
    component = components[index]
    macro_tag = macro_tags[index]

    path = f"/content/{path}/train_set/{component}/{macro_tag}/{tag}.cif"
    return path

### Advanced settings

In [None]:
model_type = "auto"
num_recycles = "3"
recycle_early_stop_tolerance = "auto"
relax_max_iterations = 200 
pairing_strategy = "greedy"

max_msa = "auto" #@param ["auto", "512:1024", "256:512", "64:128", "32:64", "16:32"]
num_seeds = 1 #@param [1,2,4,8,16] {type:"raw"}
use_dropout = False #@param {type:"boolean"}

num_recycles = None if num_recycles == "auto" else int(num_recycles)
recycle_early_stop_tolerance = None if recycle_early_stop_tolerance == "auto" else float(recycle_early_stop_tolerance)
if max_msa == "auto": max_msa = None

### MSA options
For now, keep as single_sequence - otherwise, the API will be flooded with requests

In [None]:
msa_mode = "single_sequence" #@param ["mmseqs2_uniref_env", "mmseqs2_uniref","single_sequence","custom"]
pair_mode = "unpaired_paired" #@param ["unpaired_paired","paired","unpaired"] {type:"string"}
#@markdown - "unpaired_paired" = pair sequences from same species + unpaired MSA, "unpaired" = seperate MSA for each chain, "paired" - only use paired sequences.

### Loss functions

In [11]:
def RMSD(p1, p2):
    if len(p1)>len(p2):
      loss = torch.sqrt(torch.mean((p1[:len(p2)] - p2)**2))
    else:
      loss = torch.sqrt(torch.mean((p1 - p2[:len(p1)])**2))
    return loss

def tm_score(p1, p2, lt):
    d0 = lambda l: 1.24 * torch.power(l-15, 3) - 1.8
    loss = torch.mean(1/(1+torch.power(torch.abs(torch.norm(p1-p2))/d0(lt),2)))
    return loss

### RNA and Protein MMCIF/PDB Parsers

In [12]:
def parse_rna(path):
    try:
        parser = MMCIFParser()
        structure = parser.get_structure("RNA", path)
        data = []
        for model in structure:
          for chain in model:
              for residue in chain:
                  for atom in residue:
                    if residue.get_resname() in ['A', 'U', 'C', 'G']:
                      datum = list(atom.get_vector())
                      temp = (datum[0], datum[1], datum[2], atom.get_name())
                      data.append(temp)

        points = []
        angle_points = []
        norms = []

        correction_factor = torch.zeros(3, dtype=torch.float32, requires_grad=False)

        for x, y, z, atom in data:
            x = float(x)
            y = float(y)
            z = float(z)

            point = np.add(np.array([x,y,z]), correction_factor)

            if atom == "P":
              if (correction_factor==torch.zeros(3)).all():
                correction_factor = torch.tensor([-x, -y, -z])
              points.append(point)
              angle_points.append(point)
            elif atom == "\"C1'\"":
                angle_points.append(point)
            elif atom == "\"C4'\"":
                angle_points.append(point)
                v1 = angle_points[-1]-angle_points[-2]
                v2 = angle_points[-3]-angle_points[-2]
                norms.append(np.cross(v1, v2))
                angle_points = []
        points = np.array(points)
        norms = np.array(norms)
        return torch.tensor(points, requires_grad=True, dtype=torch.float32), torch.tensor(norms, requires_grad=True, dtype=torch.float32)

    except Exception as e:
        print("Oops. %s" % e)
        sys.exit(1)

def parse_protein(path):
    try:
        parser = PDBParser()
        structure = parser.get_structure("Protein", path)
        data = []
        for model in structure:
          for chain in model:
              for residue in chain:
                  for atom in residue:
                      datum = list(atom.get_vector())
                      temp = (datum[0], datum[1], datum[2], atom.get_name())
                      data.append(temp)

        points = []
        angle_points = []
        norms = []

        correction_factor = torch.zeros(3, dtype=torch.float32, requires_grad=False)

        for x, y, z, atom in data:
            x = float(x)
            y = float(y)
            z = float(z)

            point = np.add(np.array([x,y,z]), correction_factor)
            if atom == "CA":
              if (correction_factor==torch.zeros(3)).all():
                correction_factor = torch.tensor([-x, -y, -z])
              points.append(point)
              angle_points.append(point)
            elif atom == "N":
                angle_points.append(point)
            elif atom == "C":
                angle_points.append(point)
                v1 = angle_points[-1]-angle_points[-2]
                v2 = angle_points[-3]-angle_points[-2]
                norms.append(np.cross(v1, v2))
                angle_points = []

        points = np.array(points)
        norms = np.array(norms)
        return torch.tensor(points, requires_grad=True), torch.tensor(norms, requires_grad=True)

    except Exception as e:
        print("Oops. %s" % e)
        sys.exit(1)

### RNA/Protein Comparison function + Protein coordinate correction

In [13]:
def protein_to_rna(protein, rna_path, corrector, tm=False):
    prot_points, _ = parse_protein(protein)
    rna_points, _ = parse_rna(rna_path)
    prot_points = correct_protein_coords(prot_points, corrector)
    if tm:
        return tm_score(prot_points, rna_points)
    return RMSD(prot_points, rna_points)

def correct_protein_coords(points, corrector):
    correction_factor = corrector.unsqueeze(0)

    # Calculate vector differences between consecutive points
    vectors = points[1:] - points[:-1]
    norms = torch.norm(vectors, dim=1, keepdim=True)
    normalized_vectors = vectors / norms

    # Apply correction factor
    corrected_vectors = normalized_vectors * correction_factor

    corrected_points = torch.zeros_like(points)
    corrected_points[0] = points[0]
    corrected_points[1:] = points[:-1] + corrected_vectors

    return corrected_points

### Training Loop

In [None]:
import sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from Bio import BiopythonDeprecationWarning
warnings.simplefilter(action='ignore', category=BiopythonDeprecationWarning)
from pathlib import Path
from colabfold.download import download_alphafold_params, default_data_dir
from colabfold.utils import setup_logging
from colabfold.batch import get_queries, run, set_model_type
from colabfold.plot import plot_msa_v2
import os
from google.colab import drive
from tqdm import tqdm
import shutil

def input_features_callback(input_features):
  pass

def prediction_callback(protein_obj, length,
                        prediction_result, input_features, mode):
  model_name, relaxed = mode
  pass

def train(seqs, epochs=50, batch_size=32,tm_score=False, max_seq_len=150, converter=None, pp_dist=6.8):
    !mkdir -p "/content/drive/My Drive/ConverterWeights"
    try:
        K80_chk = os.popen('nvidia-smi | grep "Tesla K80" | wc -l').read()
    except:
        K80_chk = "0"
        pass
    if "1" in K80_chk:
        print("WARNING: found GPU Tesla K80: limited to total length < 1000")
        if "TF_FORCE_UNIFIED_MEMORY" in os.environ:
            del os.environ["TF_FORCE_UNIFIED_MEMORY"]
        if "XLA_PYTHON_CLIENT_MEM_FRACTION" in os.environ:
            del os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]

    # For some reason we need that to get pdbfixer to import
    if f"/usr/local/lib/python{python_version}/site-packages/" not in sys.path:
        sys.path.insert(0, f"/usr/local/lib/python{python_version}/site-packages/")

    if converter==None:
      conv = Converter(max_seq_len=max_seq_len)
    else:
      conv = converter
    conv.train()
    corrector = [nn.Parameter(torch.tensor(pp_dist, requires_grad=True, dtype=torch.float32))] # Can't be bothered to do research, so I'll just regress it
    optimizer = torch.optim.AdamW(conv.parameters(), lr=1e-2)
    dist_optimizer = torch.optim.AdamW(corrector, lr=1e-2)

    model_type = set_model_type(False, "auto")
    download_alphafold_params(model_type, Path("."))
    for epoch in range(epochs):
        for batch in batch_data(seqs, batch_size):
            optimizer.zero_grad()
            dist_optimizer.zero_grad()
            # batch: ([(tag, seq), (tag, seq),...])

            # LAYER 1: RNA-AMINO CONVERSION
            tags = [s[0] for s in batch]

            # Preprocessing sequences
            processed_seqs = [torch.tensor(np.transpose(np.array(encode_rna(s[1])), (0,1)), requires_grad=False, dtype=torch.float32) for s in batch] # (batch, seq, base)

            # Send sequences through the converter
            aa_seqs = [conv(s) for s in processed_seqs][0] # (seq, batch, aa)
            temp = []

            # Reconvert to letter representation
            for i in range(len(aa_seqs)):
                temp.append(''.join([AA_DICT[n] for n in aa_seqs[i]]))

            aa_seqs = temp # (seq: String, batch)

            final_seqs = {} # {tag: seq}
            for i in range(len(tags)):
                final_seqs[tags[i]] = aa_seqs[i]
            write_fastas(final_seqs)

            num_relax = 1 #@param [0, 1, 5] {type:"raw"}
            #@markdown - specify how many of the top ranked structures to relax using amber
            template_mode = "none" #@param ["none", "pdb100","custom"]
            #@markdown - `none` = no template information is used. `pdb100` = detect templates in pdb100 (see [notes](#pdb100)). `custom` - upload and search own templates (PDB or mmCIF format, see [notes](#custom_templates))

            use_amber = num_relax > 0
            use_cluster_profile = True

            if template_mode == "pdb100":
                use_templates = True
                custom_template_path = None
            elif template_mode == "custom":
                custom_template_path = os.path.join(jobname,f"template")
                os.makedirs(custom_template_path, exist_ok=True)
                uploaded = files.upload()
                use_templates = True
                for fn in uploaded.keys():
                    os.rename(fn,os.path.join(custom_template_path,fn))
            else:
                custom_template_path = None
                use_templates = False

            loss = []
            lengths = 0

            for i in tqdm(range(len(final_seqs))):
              lengths=lengths+list(final_seqs.values())[i]
              with torch.no_grad():
                queries, _ = get_queries(f'/content/FASTAs/{list(final_seqs.keys())[i]}.fasta')
                jobname = add_hash(list(final_seqs.keys())[i], list(final_seqs.values())[i])
                results = run(
                    queries=queries,
                    result_dir=jobname,
                    use_templates=USE_TEMPLATES,
                    custom_template_path=None,
                    num_relax=num_relax,
                    msa_mode=msa_mode,
                    model_type=model_type,
                    num_models=1,
                    num_recycles=num_recycles,
                    relax_max_iterations=relax_max_iterations,
                    recycle_early_stop_tolerance=recycle_early_stop_tolerance,
                    num_seeds=num_seeds,
                    use_dropout=use_dropout,
                    model_order=[1,2,3,4,5],
                    is_complex=False,
                    data_dir=Path("."),
                    keep_existing_results=False,
                    rank_by="auto",
                    pair_mode=pair_mode,
                    pairing_strategy=pairing_strategy,
                    stop_at_score=float(100),
                    prediction_callback=prediction_callback,
                    dpi=200,
                    zip_results=False,
                    save_all=False,
                    max_msa=max_msa,
                    use_cluster_profile=use_cluster_profile,
                    input_features_callback=input_features_callback,
                    save_recycles=False,
                    user_agent="colabfold/google-colab-main",
                )
                path = ""
                for file in os.listdir(f"/content/{jobname}"):
                  if file.endswith(".pdb"):
                    path = os.path.join(f"/content/{jobname}", file)
                    break
              temp_loss = (protein_to_rna(path, get_structure(list(final_seqs.keys())[i], struct_path), corrector[0], tm=tm_score))

                # Download generated/actual for qualitative comp
                # shutil.copy(path, "/content/generated.pdb")
                # shutil.copy(get_structure(list(final_seqs.keys())[i], struct_path), "/content/actual.cif")
              with torch.no_grad():
                loss.append(temp_loss)
                empty_dir(f"/content/{jobname}")

            lengths = lengths/batch_size

            empty_dir("FASTAs", delete=False)
            loss = torch.mean(torch.stack(loss))
            print(f"\n\nCurrent Loss: {loss}")
            print(f"Average Loss per Residue: {loss/lengths}")
            print(f"Correction factor: {corrector}\n\n")
            loss.backward()
            optimizer.step()
            dist_optimizer.step()
        torch.save(conv, f'/content/drive/My Drive/ConverterWeights/converter_epoch_{epoch}.pt')
        torch.save(conv.state_dict(), f'/content/drive/My Drive/ConverterWeights/converter_params_epoch_{epoch}.pt')
        torch.save(corrector, f'/content/drive/My Drive/ConverterWeights/corrector_epoch_{epoch}.pt')

### Main function

In [15]:
seqs, components, macro_tags = load_data(seq_path, 0, 1615, max_len=100)
drive.mount('/content/drive')
c = Converter(max_seq_len=200)
corrector = [nn.Parameter(torch.tensor(0, requires_grad=True, dtype=torch.float32))]
# try:
#   c = torch.load('/content/drive/My Drive/ConverterWeights/converter.pt')
#   corrector = torch.load('/content/drive/My Drive/ConverterWeights/corrector.pt')
#   print("Loaded model from checkpoint")
# except:
#   c = Converter(max_seq_len=200)
#   ckpt = torch.load('/content/drive/My Drive/ConverterWeights/converter_params.pt', weights_only=True)
#   corrector = corrector = torch.load('/content/drive/My Drive/ConverterWeights/corrector.pt')
#   c.load_state_dict(ckpt['model_state_dict'])
#   print("Loaded model parameters from checkpoint")
train(seqs, epochs=1, batch_size=4, max_seq_len=100, converter=c)#, pp_dist=float(corrector[0]))

Found 695 usable RNA strands...
Mounted at /content/drive


Downloading alphafold2_ptm weights to .: 100%|██████████| 3.47G/3.47G [02:39<00:00, 23.4MB/s]
100%|██████████| 4/4 [02:33<00:00, 38.33s/it]




Current Loss: 47.17895296083071
Correction factor: [Parameter containing:
tensor(6.8000, requires_grad=True)]




100%|██████████| 4/4 [02:16<00:00, 34.15s/it]




Current Loss: 39.04995134218149
Correction factor: [Parameter containing:
tensor(6.7893, requires_grad=True)]




100%|██████████| 4/4 [01:54<00:00, 28.73s/it]




Current Loss: 48.61932220953027
Correction factor: [Parameter containing:
tensor(6.7786, requires_grad=True)]




100%|██████████| 4/4 [02:37<00:00, 39.46s/it]




Current Loss: 50.61705792991525
Correction factor: [Parameter containing:
tensor(6.7680, requires_grad=True)]




100%|██████████| 4/4 [02:43<00:00, 40.96s/it]




Current Loss: 42.66410344176859
Correction factor: [Parameter containing:
tensor(6.7576, requires_grad=True)]




100%|██████████| 4/4 [02:54<00:00, 43.67s/it]




Current Loss: 35.5614499047273
Correction factor: [Parameter containing:
tensor(6.7471, requires_grad=True)]




100%|██████████| 4/4 [02:27<00:00, 36.85s/it]




Current Loss: 46.01149324720523
Correction factor: [Parameter containing:
tensor(6.7365, requires_grad=True)]




100%|██████████| 4/4 [02:20<00:00, 35.24s/it]




Current Loss: 40.22920174075317
Correction factor: [Parameter containing:
tensor(6.7259, requires_grad=True)]




100%|██████████| 4/4 [02:28<00:00, 37.02s/it]




Current Loss: 49.02489395287125
Correction factor: [Parameter containing:
tensor(6.7152, requires_grad=True)]




100%|██████████| 4/4 [01:51<00:00, 27.99s/it]




Current Loss: 44.27498864615886
Correction factor: [Parameter containing:
tensor(6.7045, requires_grad=True)]




100%|██████████| 4/4 [01:59<00:00, 29.98s/it]




Current Loss: 39.82487112158808
Correction factor: [Parameter containing:
tensor(6.6942, requires_grad=True)]




100%|██████████| 4/4 [02:00<00:00, 30.12s/it]




Current Loss: 47.878161987285296
Correction factor: [Parameter containing:
tensor(6.6842, requires_grad=True)]




100%|██████████| 4/4 [02:07<00:00, 31.83s/it]




Current Loss: 44.06607922199578
Correction factor: [Parameter containing:
tensor(6.6742, requires_grad=True)]




100%|██████████| 4/4 [02:28<00:00, 37.14s/it]




Current Loss: 47.1476896014951
Correction factor: [Parameter containing:
tensor(6.6642, requires_grad=True)]




100%|██████████| 4/4 [02:12<00:00, 33.10s/it]




Current Loss: 50.90345010544664
Correction factor: [Parameter containing:
tensor(6.6542, requires_grad=True)]




100%|██████████| 4/4 [02:34<00:00, 38.66s/it]




Current Loss: 44.39965157721785
Correction factor: [Parameter containing:
tensor(6.6444, requires_grad=True)]




100%|██████████| 4/4 [02:03<00:00, 30.95s/it]




Current Loss: 38.2050279641886
Correction factor: [Parameter containing:
tensor(6.6344, requires_grad=True)]




100%|██████████| 4/4 [02:11<00:00, 32.86s/it]




Current Loss: 44.27729844124853
Correction factor: [Parameter containing:
tensor(6.6244, requires_grad=True)]




100%|██████████| 4/4 [02:07<00:00, 31.91s/it]




Current Loss: 36.69503401567894
Correction factor: [Parameter containing:
tensor(6.6148, requires_grad=True)]




100%|██████████| 4/4 [02:30<00:00, 37.68s/it]




Current Loss: 41.74346670041062
Correction factor: [Parameter containing:
tensor(6.6054, requires_grad=True)]




100%|██████████| 4/4 [02:08<00:00, 32.15s/it]




Current Loss: 53.983660381045155
Correction factor: [Parameter containing:
tensor(6.5959, requires_grad=True)]




100%|██████████| 4/4 [02:11<00:00, 32.92s/it]




Current Loss: 42.837888798814625
Correction factor: [Parameter containing:
tensor(6.5864, requires_grad=True)]




100%|██████████| 4/4 [02:15<00:00, 33.97s/it]




Current Loss: 41.895656433131904
Correction factor: [Parameter containing:
tensor(6.5770, requires_grad=True)]




100%|██████████| 4/4 [02:31<00:00, 37.99s/it]




Current Loss: 42.96268997115444
Correction factor: [Parameter containing:
tensor(6.5670, requires_grad=True)]




100%|██████████| 4/4 [02:12<00:00, 33.06s/it]




Current Loss: 37.856721998004524
Correction factor: [Parameter containing:
tensor(6.5569, requires_grad=True)]




100%|██████████| 4/4 [02:13<00:00, 33.35s/it]




Current Loss: 39.87874548816668
Correction factor: [Parameter containing:
tensor(6.5466, requires_grad=True)]




100%|██████████| 4/4 [01:57<00:00, 29.39s/it]




Current Loss: 47.45169509096484
Correction factor: [Parameter containing:
tensor(6.5365, requires_grad=True)]




100%|██████████| 4/4 [02:04<00:00, 31.13s/it]




Current Loss: 45.73169420760474
Correction factor: [Parameter containing:
tensor(6.5260, requires_grad=True)]




100%|██████████| 4/4 [01:55<00:00, 28.88s/it]




Current Loss: 45.389442360893156
Correction factor: [Parameter containing:
tensor(6.5156, requires_grad=True)]




100%|██████████| 4/4 [01:57<00:00, 29.30s/it]




Current Loss: 39.33420107457361
Correction factor: [Parameter containing:
tensor(6.5054, requires_grad=True)]




100%|██████████| 4/4 [02:12<00:00, 33.20s/it]




Current Loss: 43.11103125082031
Correction factor: [Parameter containing:
tensor(6.4957, requires_grad=True)]




100%|██████████| 4/4 [01:55<00:00, 28.99s/it]




Current Loss: 41.005447382473676
Correction factor: [Parameter containing:
tensor(6.4858, requires_grad=True)]




100%|██████████| 4/4 [02:03<00:00, 30.88s/it]




Current Loss: 55.548762357178994
Correction factor: [Parameter containing:
tensor(6.4758, requires_grad=True)]




100%|██████████| 4/4 [02:00<00:00, 30.00s/it]




Current Loss: 38.887389934746764
Correction factor: [Parameter containing:
tensor(6.4655, requires_grad=True)]




100%|██████████| 4/4 [02:19<00:00, 34.97s/it]




Current Loss: 31.290606322672332
Correction factor: [Parameter containing:
tensor(6.4549, requires_grad=True)]




100%|██████████| 4/4 [02:18<00:00, 34.50s/it]




Current Loss: 57.378527914526806
Correction factor: [Parameter containing:
tensor(6.4441, requires_grad=True)]




100%|██████████| 4/4 [01:51<00:00, 27.94s/it]




Current Loss: 43.25209638923713
Correction factor: [Parameter containing:
tensor(6.4331, requires_grad=True)]




100%|██████████| 4/4 [02:00<00:00, 30.09s/it]




Current Loss: 46.466522060725325
Correction factor: [Parameter containing:
tensor(6.4223, requires_grad=True)]




100%|██████████| 4/4 [02:23<00:00, 35.93s/it]




Current Loss: 50.35959143724527
Correction factor: [Parameter containing:
tensor(6.4113, requires_grad=True)]




100%|██████████| 4/4 [03:17<00:00, 49.42s/it]




Current Loss: 40.593850687496975
Correction factor: [Parameter containing:
tensor(6.4007, requires_grad=True)]




100%|██████████| 4/4 [02:16<00:00, 34.16s/it]




Current Loss: 42.792528917373616
Correction factor: [Parameter containing:
tensor(6.3900, requires_grad=True)]




100%|██████████| 4/4 [02:10<00:00, 32.56s/it]




Current Loss: 49.845047082073904
Correction factor: [Parameter containing:
tensor(6.3797, requires_grad=True)]




100%|██████████| 4/4 [01:58<00:00, 29.61s/it]




Current Loss: 43.30407825552454
Correction factor: [Parameter containing:
tensor(6.3689, requires_grad=True)]




100%|██████████| 4/4 [02:03<00:00, 30.84s/it]




Current Loss: 43.82769929886865
Correction factor: [Parameter containing:
tensor(6.3584, requires_grad=True)]




  0%|          | 0/4 [00:02<?, ?it/s]


KeyboardInterrupt: 

### Save model (params)

In [16]:
torch.save(c, f'/content/drive/My Drive/ConverterWeights/converter.pt')
torch.save(c.state_dict(), f'/content/drive/My Drive/ConverterWeights/converter_params.pt')
torch.save(corrector, f'/content/drive/My Drive/ConverterWeights/corrector.pt')

### Test model on dummy sequence

In [None]:
# Predict
s = "AUGCGGGAAAAAUUCG"
processed_seq = torch.tensor(np.transpose(np.array(encode_rna(s)), (0,1)), requires_grad=False, dtype=torch.float32) # (batch, seq, base)
# Send sequences through the converter
aa_seqs = c(processed_seq)[0] # (seq, batch, aa)
temp = []

# Reconvert to letter representation
temp.append(''.join([AA_DICT[n] for n in aa_seqs]))

aa_seqs = temp # (seq: String, batch)
print(aa_seqs[0])

HPGPPDNPHGHHHPPH
