<a href="https://colab.research.google.com/github/eldarhac/CustomProteinSequenceGeneration/blob/main/CustomProteinSequenceGeneration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A Novel Transformer-Based Model Architecture for Function-Based Protein Sequence Generation

## Design and Preparation

### Components:

1. Transformer-based Generative DNA sequencing model
2. AlphaFold 2.0
3. PDB to graph
4. Graph2Vec
5. Regression XGBoost


### Installations

In [None]:
!pip install transformers
!pip install biopandas -q
!pip install -q --no-warn-conflicts "colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold"
!pip install -q jax -f https://storage.googleapis.com/jax-releases/jax_releases.html

### Imports

In [None]:
import torch
import transformers
import numpy as np
from biopandas.pdb import PandasPdb
from google.colab import files
import os.path
import re
import hashlib
import random
from sys import version_info 
import sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from colabfold.download import download_alphafold_params, default_data_dir
from colabfold.utils import setup_logging
from colabfold.batch import get_queries, run, set_model_type
import os
# Best-effort detection of Tesla K80 GPUs: count the lines of `nvidia-smi`
# output that mention "Tesla K80". K80_chk keeps the raw command output.
try:
  K80_chk = os.popen('nvidia-smi | grep "Tesla K80" | wc -l').read()
except Exception:
  # Could not run the probe at all; treat it as "no K80 found".
  K80_chk = "0"

# Parse the count as a number: a substring test such as `"1" in K80_chk`
# misses counts like 2 or 3.
try:
  k80_count = int(K80_chk.strip())
except ValueError:
  k80_count = 0

if k80_count > 0:
  print("WARNING: found GPU Tesla K80: limited to total length < 1000")
  # Drop these memory-related settings when a K80 is present.
  # NOTE(review): presumably they are unsuitable for this GPU - confirm.
  if "TF_FORCE_UNIFIED_MEMORY" in os.environ:
    del os.environ["TF_FORCE_UNIFIED_MEMORY"]
  if "XLA_PYTHON_CLIENT_MEM_FRACTION" in os.environ:
    del os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]
from colabfold.colabfold import plot_protein
from pathlib import Path
from typing import Union, List

## Amino-Acid Sequence Generative Model

### Prepare Data

In [None]:
def load_sequence_from_fasta(filename: str) -> Union[List[str], str]:
  """Read the sequence(s) stored in a FASTA file.

  Lines starting with ">" are record headers and are never included in the
  returned data. Every other line is stripped of surrounding whitespace and
  concatenated onto the current record, so the returned sequences contain no
  line breaks.

  Args:
    filename: Path of the FASTA file to read.

  Returns:
    A list with one string per record when the file holds more than one
    record; otherwise the single sequence as a plain string (the empty
    string when the file contains no sequence data).
  """
  sequences = []
  chunks = []  # pieces of the record currently being read
  with open(filename, "r") as f:
    for line in f:
      if line.startswith(">"):
        # A header closes the previous record (if any); the header text
        # itself is not sequence data and is skipped.
        if chunks:
          sequences.append("".join(chunks))
          chunks = []
      else:
        piece = line.strip()
        if piece:
          chunks.append(piece)
  # The last record has no header after it, so flush it explicitly.
  if chunks:
    sequences.append("".join(chunks))

  if len(sequences) > 1:
    return sequences
  return sequences[0] if sequences else ''

def load_dataset(folder_path: str):
  """Load every `.fasta` file located directly inside `folder_path`.

  Sub-directories are not searched, and entries that are not regular files
  or do not end in '.fasta' are skipped. Files are visited in the order
  returned by os.listdir.

  Args:
    folder_path: Directory containing the FASTA files.

  Returns:
    A list with one load_sequence_from_fasta() result per matching file.
  """
  dataset = []
  for entry in os.listdir(folder_path):
    full_path = os.path.join(folder_path, entry)
    if not os.path.isfile(full_path):
      continue
    if not entry.endswith('.fasta'):
      continue
    dataset.append(load_sequence_from_fasta(full_path))
  return dataset


## AlphaFold 2.0

### Prepare AlphaFold Model

In [None]:
# Running interpreter version formatted as "<major>.<minor>".
python_version = f"{version_info.major}.{version_info.minor}"

# Prediction options. use_amber, custom_template_path and use_templates are
# handed to colabfold's run() by run_alphafold_model; template_mode is only
# defined here.
use_amber = False
template_mode = "none"
custom_template_path = None
use_templates = False

# MSA and pairing modes, also handed to run().
msa_mode = "MMseqs2 (UniRef+Environmental)"
pair_mode = "unpaired+paired"

# Advanced settings
num_recycles = 3
save_to_google_drive = False

# Passed to run() as its dpi argument.
dpi = 200

# Configure logging once; the flag keeps a re-run of this cell from calling
# setup_logging() again.
if 'logging_setup' not in globals():
    setup_logging(Path(".").joinpath("log.txt"))
    logging_setup = True

# Separate parameter folders for the multimer ("complex") and single-chain
# ("simple") model types. exist_ok makes a re-run harmless.
COMPLEX_PARAMS_PATH = "./complex"
SIMPLE_PARAMS_PATH = "./simple"
os.makedirs(COMPLEX_PARAMS_PATH, exist_ok=True)
os.makedirs(SIMPLE_PARAMS_PATH, exist_ok=True)

# Fetch the parameters of both model types into their own folder.
complex_model_type = "AlphaFold2-multimer-v2"
download_alphafold_params(complex_model_type, Path(COMPLEX_PARAMS_PATH))
simple_model_type = "AlphaFold2-ptm"
download_alphafold_params(simple_model_type, Path(SIMPLE_PARAMS_PATH))

### AlphaFold Model Runner

In [None]:
# Folder that receives the prediction outputs; created on first use.
result_dir = "./results"
os.makedirs(result_dir, exist_ok=True)

def _remove_files_in(folder):
  """Delete every regular file directly inside `folder`; keep sub-directories."""
  for filename in os.listdir(folder):
    path = os.path.join(folder, filename)
    if os.path.isfile(path):
      os.remove(path)


def clean_folders(jobname):
  """Delete the working files left behind by a finished prediction job.

  Removes every regular file located directly in the current working
  directory and in `result_dir`, then empties and removes the
  `<result_dir>/<jobname>_env` directory when it exists.

  WARNING: all regular files in the current working directory are deleted,
  not only the ones belonging to `jobname`. Do not call this from a
  directory that holds files you want to keep.

  Args:
    jobname: Job identifier used to locate the `<jobname>_env` directory.

  Raises:
    OSError: If `<jobname>_env` still contains sub-directories, because
      os.rmdir only removes empty directories.
  """
  for folder in (".", result_dir):
    _remove_files_in(folder)

  jobname_envpath = os.path.join(result_dir, f"{jobname}_env")
  # The env directory may be absent; skip it instead of failing in
  # os.listdir with FileNotFoundError.
  if os.path.isdir(jobname_envpath):
    _remove_files_in(jobname_envpath)
    os.rmdir(jobname_envpath)


def get_hash(y):
  """Return the first five hex characters of the SHA-1 digest of string `y`."""
  digest = hashlib.sha1(y.encode()).hexdigest()
  return digest[:5]

def _unique_jobname(sequence_text):
  """Return a get_hash() job name for which no `<jobname>.csv` exists yet."""
  jobname = get_hash(sequence_text)
  while os.path.isfile(f"{jobname}.csv"):
    # Name already taken: hash a random permutation of the characters instead.
    jobname = get_hash(''.join(random.sample(sequence_text, len(sequence_text))))
  return jobname


def run_alphafold_model(query_sequences: Union[str, List[str]]):
  """Predict a structure for the given sequence(s) and store the PDB file.

  The sequences are written to a temporary `<jobname>.csv` query file in the
  current working directory, parsed with get_queries() and passed to run().
  Afterwards `<result_dir>/<jobname>.result.zip` is unzipped,
  `<jobname>_unrelaxed_rank_1_model_5.pdb` is copied into
  `./protein_structures`, and clean_folders(jobname) is called.

  Args:
    query_sequences: A single sequence string, or a list (or tuple) of
      sequence strings. All whitespace inside each sequence is removed
      first. The caller's list is not modified.

  Raises:
    TypeError: If `query_sequences` is neither a string nor a list/tuple.
  """
  if isinstance(query_sequences, str):
    sequence = "".join(query_sequences.split())  # remove whitespaces
    jobname = _unique_jobname(sequence)
    rows = [f"{jobname},{sequence}"]
  elif isinstance(query_sequences, (list, tuple)):
    # Build a cleaned copy rather than overwriting the caller's list.
    sequences = ["".join(sequence.split()) for sequence in query_sequences]
    jobname = _unique_jobname("".join(sequences))
    # NOTE(review): each row gets its own id (hash of that sequence) while
    # the archive unzipped below is looked up under `jobname` - confirm the
    # result zip is really named after `jobname` for list input.
    rows = [f"{get_hash(sequence)},{sequence}" for sequence in sequences]
  else:
    raise TypeError(
        "query_sequences must be a str or a list of str, got "
        f"{type(query_sequences).__name__}")

  queries_path = f"{jobname}.csv"
  with open(queries_path, "w") as text_file:
    text_file.write("id,sequence\n" + "\n".join(rows))

  # The CSV is only needed for parsing; remove it even if parsing fails.
  try:
    queries, is_complex = get_queries(queries_path)
  finally:
    os.remove(queries_path)

  run(
      queries=queries,
      result_dir=result_dir,
      use_templates=use_templates,
      custom_template_path=custom_template_path,
      use_amber=use_amber,
      msa_mode=msa_mode,
      # Model type and parameter folder follow what get_queries() reported.
      model_type=complex_model_type if is_complex else simple_model_type,
      num_models=1,
      num_recycles=num_recycles,
      # Only model 5 is requested; the PDB file name copied below relies on it.
      model_order=[5],
      is_complex=is_complex,
      data_dir=Path(COMPLEX_PARAMS_PATH if is_complex else SIMPLE_PARAMS_PATH),
      keep_existing_results=False,
      recompile_padding=1.0,
      rank_by="auto",
      pair_mode=pair_mode,
      stop_at_score=float(100),
      dpi=dpi,
      zip_results=True
  )

  directory = "./protein_structures"
  os.makedirs(directory, exist_ok=True)
  # Extract the results into the working directory and keep only the PDB;
  # everything else is deleted by clean_folders().
  os.system(f"unzip {result_dir}/{jobname}.result.zip")
  os.system(f"cp {jobname}_unrelaxed_rank_1_model_5.pdb {directory}")
  clean_folders(jobname)