In [1]:

import sys
from transformers import T5Tokenizer, T5Config
from model.msa_augmentor import MSAT5
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from torch.utils.data import DataLoader, Dataset
from data.utils_inference import DataCollatorForMSA
from data.iteration import fetch_best_generation
import torch
from Bio import SeqIO
import itertools
import string
import glob
import torch
from tqdm import tqdm
import os 
import logging
import datetime
import torch.nn.functional as F
import argparse

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The model configuration and the tokenizer are both read from the same local directory.
config_dir = './config/'
config = T5Config.from_pretrained(config_dir)
tokenizer = T5Tokenizer.from_pretrained(config_dir)

In [6]:
# Load the pretrained MSAT5 weights from a local checkpoint directory.
# NOTE(review): the leading '.r/' looks like it may be a typo for './' — confirm the checkpoint location.
checkpoint_dir = '.r/checkpoints/msat5-base/checkpoint-740000'
model = MSAT5.from_pretrained(checkpoint_dir)

In [3]:
class MSADataSet(Dataset):
    """In-memory collection of MSAs read from ``.a3m`` files.

    Each MSA is stored as a list of ``(description, aligned_sequence)`` tuples
    with insertion columns (lowercase letters, ``.`` and ``*``) stripped, keyed
    by the file name up to its first ``.``. Items can be fetched either by
    that name or by integer position.
    """

    def __init__(self,data_dir=None,data_path_list=None,only_poor_msa=None,max_seq_num=9999):
        """
        data_dir: read all .a3m files in this directory (used only when
            data_path_list is None)
        data_path_list: read all .a3m files in this list
        only_poor_msa: if truthy, keep only MSAs with fewer than this many sequences
        max_seq_num: read at most this many sequences from each file

        Raises ValueError when neither data_dir nor data_path_list is given, and
        AssertionError when the sequences of an MSA differ in length (an MSA
        with no sequences at all also fails this check).
        """
        self.setup_traslation()
        if data_path_list is None:
            if data_dir is None:
                raise ValueError("either data_dir or data_path_list must be provided")
            # glob returns files in arbitrary order; sort so that integer
            # indices map to the same MSA on every run.
            data_path_list=sorted(glob.glob(os.path.join(data_dir,'*.a3m')))

        # Files sharing the same name stem overwrite each other (last one wins).
        self.msa_data={self._msa_name(msa_file_path):self.read_msa(msa_file_path,max_seq_num)
                       for msa_file_path in data_path_list}
        bad_msa=[name for name,msa in self.msa_data.items() if not self.check_same_len(msa)]
        assert not bad_msa,f"all sequence in a msa should have the same length (offending msa: {bad_msa})"
        if only_poor_msa:
            self.msa_data={k:v for k,v in self.msa_data.items() if len(v)<only_poor_msa}
        for k,v in self.msa_data.items():
            print(f"{k}: unique seq num:{len({x[1] for x in v})} | total seq num:{len(v)}")

    @staticmethod
    def _msa_name(msa_file_path):
        """Return the file name of msa_file_path up to its first '.'."""
        # basename understands the platform's path separator, unlike split("/").
        return os.path.basename(msa_file_path).split(".")[0]

    def __getitem__(self, index):
        """Return {'msa_name': ..., 'msa': ...} for an MSA name or an integer position.

        Raises KeyError for an unknown name and IndexError for an out-of-range
        position (the latter is what terminates plain ``for`` iteration).
        """
        if index in self.msa_data:
            return {'msa_name':index,'msa':self.msa_data[index]}
        if isinstance(index,str):
            # Without this guard a misspelled name would reach the list lookup
            # below and fail with a misleading "list indices must be integers".
            raise KeyError(index)
        key=list(self.msa_data.keys())[index]
        return {'msa_name':key,'msa':self.msa_data[key]}

    def __len__(self):
        """Number of MSAs held by the dataset."""
        return len(self.msa_data)

    def setup_traslation(self):
        """Build the translation table that deletes lowercase letters, '.' and '*'."""
        deletekeys = dict.fromkeys(string.ascii_lowercase)
        deletekeys["."] = None
        deletekeys["*"] = None
        self.translation = str.maketrans(deletekeys)

    def remove_insertions(self,sequence) :
        """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
        return sequence.translate(self.translation)

    def read_msa(self,filename, nseq) :
        """ Reads the first nseq sequences from an MSA file, automatically removes insertions."""
        # Open the file ourselves: islice may stop before the parser is
        # exhausted, and the handle must still be closed in that case.
        with open(filename) as handle:
            return [(record.description, self.remove_insertions(str(record.seq)))
                    for record in itertools.islice(SeqIO.parse(handle, "fasta"), nseq)]

    def check_same_len(self,msa:List[Tuple[str,str]]):
        """Return True when the MSA is non-empty and all its sequences have the same length."""
        return len({len(x[1]) for x in msa})==1

In [11]:
# Read every .a3m file under the CASP15 directory, keeping only MSAs with
# fewer than 20 sequences, then report how many MSAs remain.
dataset = MSADataSet(
    data_dir='./dataset/casp15/cfdb/',
    only_poor_msa=20,
)
print(len(dataset))

T1113: unique seq num:10 | total seq num:13
T1119: unique seq num:2 | total seq num:5
T1122: unique seq num:1 | total seq num:3
T1125: unique seq num:15 | total seq num:18
T1129s2: unique seq num:16 | total seq num:18
T1130: unique seq num:1 | total seq num:2
T1131: unique seq num:1 | total seq num:2
T1178: unique seq num:15 | total seq num:16
T1194: unique seq num:14 | total seq num:16
9


In [108]:
# Preview the first MSA only; msa_name and msa remain bound after the loop.
for a3m_data in itertools.islice(dataset, 1):
    msa_name = a3m_data['msa_name']
    msa = a3m_data['msa']
    print(msa_name)
    print(msa)

T1113
[('9', 'MTAVNYPFVDTMDKFDKITKGLIFEHQAEGESETMISHELSILDNDGVVHSLHFSQITSLIDTITGKHPSLELPPQLFLITQYLLEDLKEVGEKGFVITEYFIDVLPTGNKAIFRGTLAHKSTVDGHPDFDPSSTISKKEFEFSLNQFSILQQIALSHCIANLHEECAGFRGTFDVEYTFHWTPFAFNVKFSE'), ('A0A4D6BFJ2\t348\t1.00\t6.892E-106\t0\t192\t193\t0\t192\t193', 'MTAVNYPFVDTMDKFDKITKGLIFEHQAEGESETMISHELSILDNDGVVHSLHFSQITSLIDTITGKHPSLELPPQLFLITQYLLEDLKEVGEKGFVITEYFIDVLPTGNKAIFRGTLAHKSTVDGHPDFDPSSTISKKEFEFSLNQFSILQQIALSHCIANLHEECAGFRGTFDVEYTFHWTPFAFNVKFSE'), ('B3FJX8\t222\t0.306\t3.338E-62\t0\t188\t193\t0\t172\t179', 'MTLIAKPAVPDTEVLNHIGRDLLLKD--EDGSFKLQRH-LKILTEEGLVHQVSFAQVDGLLNILDSTRETPPCSPLQYLITHYDLKDLVELGKDGWLVPEYQVVVMHSSKTVRFEGKLTRVGSID-------------KEFTFALGGFDFIQQLSLARCIASLGKEFEQVIGTFDCTYVFKTGPDGISV----'), ('A0A7S6RAB0\t217\t0.315\t2.007E-60\t33\t177\t193\t29\t170\t194', '---------------------------------TGLRPSLNITTNDGVTRNVTFDQLENFIKVIKATHDTQTVGPLQWCIYRYA----PNLVKARYIATSWKITVDHVKKTLNFQGTLNHEGTVIDHPNYDPDITARELMFDETITDFSRFEQIALSGNIANFKEEFKMMKYSFDFSY---------------'), ('F8S

In [111]:
# Build the MSA collator around the tokenizer, with max_len set to 512.
msadata_collator = DataCollatorForMSA(
    tokenizer,
    max_len=512,
)


In [115]:
# Convert the previewed MSA with the collator and display the size of the result.
msa_batch = msadata_collator.msa_batch_convert(msa)
msa_batch.size()

torch.Size([1, 13, 194])