In [2]:
!pip install fair-esm

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [3]:
!pip install biopython

Collecting biopython
  Downloading biopython-1.84-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading biopython-1.84-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.84


In [4]:
# Imported here as well: in notebook order the main import cell comes after
# this one, so without it a fresh "Restart & Run All" fails with NameError.
import esm

# Load the pretrained MSA Transformer checkpoint together with its alphabet
# (downloaded to the torch hub cache on first use, as the output below shows).
msa_transformer, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
# Put the model in evaluation mode before using it for inference.
msa_transformer = msa_transformer.eval()
# Converter obtained from the alphabet; kept for preparing model inputs.
msa_batch_converter = msa_alphabet.get_batch_converter()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm_msa1b_t12_100M_UR50S.pt" to /root/.cache/torch/hub/checkpoints/esm_msa1b_t12_100M_UR50S.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm_msa1b_t12_100M_UR50S-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm_msa1b_t12_100M_UR50S-contact-regression.pt


In [5]:
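# Debugging aid: uncomment the next line to list the name and type of every submodule of msa_transformer.
#[(n, type(m)) for n, m in msa_transformer.named_modules()]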

In [1]:
# Start with importing libraries
import os
import random
import pathlib
import itertools
import string
from typing import List, Tuple
import warnings

import tqdm
import peft

import numpy as np
import pandas as pd
from numpy.random import default_rng
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr

import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import cm

from patsy import dmatrices
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=FutureWarning)
    import statsmodels.api as sm

import esm
import torch

from Bio import SeqIO
from Bio import Phylo

In [2]:
import sys

# Directory that holds the project-local modules imported below.
LOCAL_MODULES_DIR = '/content/drive/MyDrive/data'
# Guard the append so that re-running this cell does not keep adding
# duplicate entries to sys.path.
if LOCAL_MODULES_DIR not in sys.path:
    sys.path.append(LOCAL_MODULES_DIR)

# NOTE(review): `data` and `utils` are generic names and the directory is
# appended (lowest priority) — confirm no installed package shadows them.
import model_finetune
import data
import utils

In [3]:
# Pfam family used for the single-family example cells.
pfam_family = "PF00004"
# Upper bound on the number of sequences read from an MSA file
# (passed to read_msa as `nseq`).
MAX_DEPTH = 600
# NOTE(review): presumably the layer/head counts of the MSA Transformer — confirm.
n_layers = 12
n_heads = 12

# Both input folders live under the same data directory.
_data_root = pathlib.Path("/content/drive/MyDrive/data")
msas_folder = _data_root / "subsampled_msa"
dists_folder = _data_root / "distance_matrix"

# Translation table that deletes every lowercase ASCII letter as well as "."
# and "*" in a single pass: a character mapped to None is dropped by
# str.translate.
deletekeys = {char: None for char in string.ascii_lowercase}
deletekeys.update({".": None, "*": None})
translation = str.maketrans(deletekeys)

def remove_insertions(sequence: str) -> str:
    """Strip insertion characters from an aligned sequence.

    Lowercase letters, "." and "*" are removed; every other character
    (uppercase residues, "-" gaps, ...) is kept in its original order.
    Needed so that all sequences loaded from an MSA share the same columns.
    """
    cleaned = sequence.translate(translation)
    return cleaned

def read_msa(filename: str, nseq: int) -> List[Tuple[str, str]]:
    """Load at most `nseq` records from an MSA file in FASTA format.

    Returns a list of (description, sequence) tuples in file order, where
    each sequence has been passed through remove_insertions.
    """
    first_records = itertools.islice(SeqIO.parse(filename, "fasta"), nseq)
    msa = []
    for record in first_records:
        msa.append((record.description, remove_insertions(str(record.seq))))
    return msa

In [6]:
# Split ratios handed to data.train_val_test_split further down
# (their exact meaning is defined by that helper).
ratio_train_test = 0.8
ratio_train_val = 0.1

# Load the example family: at most MAX_DEPTH sequences of its alignment and
# the matching .npy array from dists_folder.
family_stem = f"{pfam_family}_subtree"
msa_family = read_msa(msas_folder / f"{family_stem}.fasta", MAX_DEPTH)
dists_family = np.load(dists_folder / f"{family_stem}.npy")

In [8]:
# Pfam families passed to the train/validation/test split.
pfam_families = [
    "PF00004", "PF00005", "PF00041", "PF00072", "PF00076",
    "PF00096", "PF00153", "PF00271", "PF00397", "PF00512",
    "PF00595", "PF01535", "PF02518", "PF07679", "PF13354",
]
# Build the three data splits from the listed families. Arguments are
# positional, in the order data.train_val_test_split expects them.
train_data, val_data, test_data = data.train_val_test_split(
    pfam_families,
    ratio_train_test,
    ratio_train_val,
    MAX_DEPTH,
    msas_folder,
    dists_folder,
)

In [10]:
# Wrap the training split in the project's dataset class.
# NOTE(review): the second argument (32) is presumably a batch size — confirm
# against data.CustomDataset.
dataset = data.CustomDataset(train_data, 32)
# batch_size=None disables the DataLoader's automatic batching, so each item
# produced by the dataset is yielded as-is.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=None)

In [24]:
# Smoke test: build the fine-tuning model and run a single batch through it.
model = model_finetune.FineTuneMSATransformer()

model = model.eval()
with torch.no_grad():
    # Fetch exactly one batch. Unlike a for-loop with an immediate `break`,
    # next(iter(...)) raises StopIteration on an empty loader instead of
    # silently leaving `prediction` undefined (or stale from an earlier run).
    batch_sequences, batch_dists = next(iter(data_loader))
    prediction = model(batch_sequences)

In [19]:
# Ask the project helper for two collections of module names; they are passed
# to peft.LoraConfig as `target_modules` and `modules_to_save` respectively.
store_target_modules, store_modules_to_save = utils.get_target_save_modules(model)

In [20]:
# Display the module names that will be passed to LoraConfig as `modules_to_save`.
store_modules_to_save

['layers.finetune_linear_0',
 'layers.finetune_linear_1',
 'layers.finetune_linear_2',
 'layers.finetune_linear_3']

In [23]:
# LoRA configuration: rank-8 adapters are attached to `target_modules`, while
# the modules listed in `modules_to_save` are kept trainable and saved in full.
config = peft.LoraConfig(
    r=8,
    target_modules=store_target_modules,
    modules_to_save=store_modules_to_save,
)
# Wrap the base model so that it is trained according to the config above.
peft_model = peft.get_peft_model(model, config)

In [22]:
# Report the trainable vs. total parameter counts of the PEFT-wrapped model.
peft_model.print_trainable_parameters()

trainable params: 2,360,745 || all params: 117,978,628 || trainable%: 2.0010
