In [None]:
from transformers import T5Tokenizer, T5Config, AutoConfig, T5ForConditionalGeneration

from model.msa_shlab import MSAT5,T5Stack
from typing import Sequence, Tuple, List, Union
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, RandomSampler, Dataset
from torch.utils.data.distributed import DistributedSampler


import torch
import torch.nn as nn
import torch.nn.functional as F
# textokenizer=T5Tokenizer.from_pretrained('t5-small')

In [2]:
from data.msa_dataset import MSADataSet

## Loading & Processing Data


In [None]:
import matplotlib.pyplot as plt
import torch
import os
from Bio import SeqIO
import itertools
from typing import Sequence, Tuple, List, Union
import string
import glob

In [None]:
class MSADataSet(Dataset):
    def __init__(self,data_path,seq_per_msa=3,num_files=1000):
        deletekeys = dict.fromkeys(string.ascii_lowercase)
        deletekeys["."] = None
        deletekeys["*"] = None
        self.translation = str.maketrans(deletekeys)
        
        data_path_list=glob.glob(data_path+'/*/*.a3m')[:num_files]
        msa_data=[self.read_msa(data_path,seq_per_msa*2) for data_path in data_path_list]
        # filter out msa dosen't meet seq_per_msa requirement
        msa_data=[i for i in msa_data if (len(i)==seq_per_msa*2 and self.check_len(i))]
        self.src = [msa[:seq_per_msa]  for msa in msa_data]
        self.tgt = [msa[seq_per_msa:]  for msa in msa_data]
       
    def __getitem__(self, index):
        return {"src":self.src[index],"tgt":self.tgt[index]}
        # return self.msa_data[index]
    def __len__(self):
        return len(self.src)
    def remove_insertions(self,sequence) :
        """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
        return sequence.translate(self.translation)
    def read_msa(self,filename, nseq) :
        """ Reads the first nseq sequences from an MSA file, automatically removes insertions."""
        return [(record.description, self.remove_insertions(str(record.seq)))
                    for record in itertools.islice(SeqIO.parse(filename, "fasta"), nseq)]
    def check_len(self,msa):
        #check if all sequence in a msa has the same length
        l=set([len(x[1]) for x in msa])
        return len(l)==1


config=T5Config.from_pretrained('config/')
tokenizer=T5Tokenizer.from_pretrained('config/')
seq_per_msa=15
config=T5Config.from_pretrained('./config/')
tokenizer=T5Tokenizer.from_pretrained('./config/')
epochs=50
data_path='/user/sunsiqi/zl/T5/CASP_msa'
train_dataset=MSADataSet(data_path,num_files=10000,seq_per_msa=seq_per_msa)
# eval_dataset=MSADataSet(data_path,num_files=100,seq_per_msa=seq_per_msa)


In [None]:
class MSADataSet(Dataset):
    def __init__(self,data_path=None,src_seq=None,num_files=None,total_seq=None,recur=True,data_path_list=None):
        deletekeys = dict.fromkeys(string.ascii_lowercase)
        deletekeys["."] = None
        deletekeys["*"] = None
        self.translation = str.maketrans(deletekeys)
        if data_path_list is None:
            if num_files is not None:
                if recur:
                    data_path_list=glob.glob(data_path+'/*/*.a3m')[:num_files]
                else:
                    data_path_list=glob.glob(data_path+'/*.a3m')[:num_files]
            else:
                # logger.warning('Train on all msa data within given data path')
                if recur:
                    data_path_list=glob.glob(data_path+'/*/*.a3m')
                else:
                    data_path_list=glob.glob(data_path+'/*.a3m')
                    print(data_path_list)
        msa_data=[self.read_msa(data_path,total_seq) for data_path in data_path_list]
        # filter out msa dosen't meet seq_per_msa requirement
        msa_data=[i for i in msa_data if (len(i)==total_seq and self.check_len(i))]
        self.src = [msa[:src_seq]  for msa in msa_data]
        self.tgt = [msa[src_seq:]  for msa in msa_data]
       
    def __getitem__(self, index):
        return {"src":self.src[index],"tgt":self.tgt[index]}
        # return self.msa_data[index]
    def __len__(self):
        return len(self.src)
    def remove_insertions(self,sequence) :
        """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
        return sequence.translate(self.translation)
    def read_msa(self,filename, nseq) :
        """ Reads the first nseq sequences from an MSA file, automatically removes insertions."""
        return [(record.description, self.remove_insertions(str(record.seq)))
                    for record in itertools.islice(SeqIO.parse(filename, "fasta"), nseq)]
    def check_len(self,msa):
        #check if all sequence in a msa has the same length
        l=set([len(x[1]) for x in msa])
        return len(l)==1
dataset=MSADataSet('/user/sunsiqi/zl/T5/AF2TEST/CASP14/input/src_3_all_13/test/',src_seq=3,num_files=None,total_seq=13,recur=False)
len(dataset)

In [None]:
dataset[0]['src']

In [None]:
dataset=MSADataSet(src_seq=3,num_files=None,total_seq=13,recur=False,data_path_list=['/user/sunsiqi/zl/T5/AF2TEST/CASP14/input/src_3_all_13/total/T1046s2-D1_all.a3m'])

In [None]:
dataset[0]['src'][1]

In [None]:
from collections import Counter
src_c=Counter()
tgt_c=Counter()
for msa in dataset:
    for seq in msa['src']:
        # print(seq[1],len(seq[1]))
        text=seq[1]
        # print(Counter(text))
        src_c+=Counter(text)
    for seq in msa['tgt']:
        text=seq[1]
        tgt_c+=Counter(text)

    

In [None]:
import numpy as np
src_c=sorted(src_c,key=lambda x:x[1],reverse=True)
tgt_c=sorted(tgt_c,key=lambda x:x[1],reverse=True)
labels, values =zip(*src_c)
indexes = np.arange(len(labels))
plt.xticks(indexes , labels)
plt.bar(indexes, values)
plt.xlabel('class')
plt.ylabel('number')
# plt.title(" {} MSA's src files({} sequences)".format(len(dataset),seq_per_msa))
plt.show()

labels, values =zip(*tgt_c)
indexes = np.arange(len(labels))
plt.bar(indexes, values)
plt.xticks(indexes , labels)
plt.xlabel('class')
plt.ylabel('number')
# plt.title("  {} MSA's tgt files({} sequences)".format(len(train_dataset),seq_per_msa))
plt.show()


In [None]:
RawMSA = Sequence[Tuple[str, str]]
class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.
    """

    def __init__(self, tokenizer,max_len=512):
        self.max_len=max_len-1
        self.tokenizer = tokenizer
    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        # RoBERTa uses an eos token, while ESM-1 does not.
        batch_size = len(raw_batch)
        batch_labels, seq_str_list = zip(*raw_batch)
        seq_encoded_list = [self.tokenizer(self._tokenize(seq_str),truncation=True,max_length=self.max_len+1).input_ids for seq_str in seq_str_list]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        tokens = torch.empty(
            (
                batch_size,
                max_len 
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.tokenizer.pad_token_id)
        labels = []
        strs = []
        for i, (label, seq_str, seq_encoded) in enumerate(
            zip(batch_labels, seq_str_list, seq_encoded_list)
        ):
            labels.append(label)
            strs.append(seq_str)
            seq = torch.tensor(seq_encoded, dtype=torch.int64)
            tokens[i,0:len(seq_encoded)] = seq

        return labels, strs, tokens
    def _tokenize(self,sequence):
        return ' '.join(list(sequence)) 
class DataCollatorForMSA(BatchConverter):
    def msa_batch_convert(self, inputs: Union[Sequence[RawMSA], RawMSA]):
        # RawMSA: Sequence[label:str,acid_seq:str]
        if isinstance(inputs[0][0], str):
            # Input is a single MSA
            raw_batch: Sequence[RawMSA] = [inputs]  # type: ignore
        else:
            raw_batch = inputs  # type: ignore

        batch_size = len(raw_batch)
        max_alignments = max(len(msa) for msa in raw_batch) #MSA的数量
        max_seqlen_msa = max(len(msa[0][1]) for msa in raw_batch) # MSA的每个序列长度
        max_seqlen=min(max_seqlen_msa,self.max_len)+1 #加一是为了凑齐每句话结尾有一个/s
        tokens = torch.empty(
            (
                batch_size,
                max_alignments,
                max_seqlen,
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.tokenizer.pad_token_id)
        labels = []
        strs = []

        for i, msa in enumerate(raw_batch):
            msa_seqlens = set(len(seq) for _, seq in msa)
            if not len(msa_seqlens) == 1:
                raise RuntimeError(
                    "Received unaligned sequences for input to MSA, all sequence "
                    "lengths must be equal."
                )
            msa_labels, msa_strs, msa_tokens = super().__call__(msa)
            msa_len=msa_tokens.size(1)
            msa_tokens=msa_tokens[:,:min(msa_len,max_seqlen)]
            labels.append(msa_labels)
            strs.append(msa_strs)
            tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens
        return tokens
    def __call__(self,batch): 
        input_ids=self.msa_batch_convert([example["src"] for example in batch])
        attention_mask=input_ids.ne(self.tokenizer.pad_token_id).type_as(input_ids)
        labels=self.msa_batch_convert([example["tgt"] for example in batch])
        decoder_attention_mask=labels.ne(self.tokenizer.pad_token_id).type_as(input_ids)
        labels[labels==self.tokenizer.pad_token_id]=-100
        # labels[labels==128]=-100
        return {'input_ids':input_ids,'labels':labels,"attention_mask":attention_mask,"decoder_attention_mask":decoder_attention_mask}
msadata_collator=DataCollatorForMSA(tokenizer,max_len=512)
batch_size=2
msa_dataloader = DataLoader(train_dataset, batch_size=batch_size,collate_fn=msadata_collator)

In [None]:
for i in train_dataset:
    fl=[]
    for i in i['src']:
        fl.append(i[0])
        fl.append(i[1])

    print (fl)
    break

In [None]:
model=MSAT5(config).to('cuda:0')

In [None]:
src=train_dataset[0]['src']
src=msadata_collator.msa_batch_convert(src).to('cuda:0')
src.size(2)

In [None]:
train_dataset[0]['src']

In [None]:
output=model.generate(src,do_sample=True,top_k=5,top_p=0.95,max_length=src.size(2))

In [None]:
output.shape

In [None]:
tokenizer.decode(output[0][0],skip_special_tokens=True)

In [None]:
print(src,src.shape)
output=model.generate(src)
print(output.shape)
for i in output[0]:
    print(tokenizer.decode(i))
    break

In [None]:
train_dataset[0]['src'][0],src,src.shape,tokenizer.decode(src[0][0],skip_special_tokens=True)

In [None]:
s=tokenizer.decode(src[0][0],skip_special_tokens=True).replace(' ','')
s

#### 测试encoder decoder的extended_attention_mask的创建异同

In [None]:
import copy
config.axial_attention=True
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder=True
decoder=T5Stack(decoder_config)
encoder_config=copy.deepcopy(config)
encoder=T5Stack(encoder_config)

In [None]:
msa_input_ids=msa_batch_converter.msa_batch_convert(msa_dataset[0]['src'])[0][:2,100:105]
attention_mask=msa_input_ids.ne(tokenizer.pad_token_id)
msa_input_ids.shape,attention_mask.shape

In [None]:
decoder.is_decoder=False
extended_attention_mask_decoder=decoder.get_extended_attention_mask(attention_mask,msa_input_ids.shape,'cpu')
extended_attention_mask_encoder=encoder.get_extended_attention_mask(attention_mask,msa_input_ids.shape,'cpu')
print('-'*20,'input','-'*20)
print('input_ids shape: {}\ninput_ids:\n{}'.format(msa_input_ids.shape,msa_input_ids))
print('-'*20,'for decoder','-'*20)
print('decoder extended attention mask shape: {}\ndecoder extended attention mask value:\n{}'.format(extended_attention_mask_decoder.shape,extended_attention_mask_decoder))
print('-'*20,'for encoder','-'*20)
print('encoder extended attention mask shape: {}\nencoder extended attention mask value:\n{}'.format(extended_attention_mask_encoder.shape,extended_attention_mask_encoder))

generate时，input_ids为1，则生成非caulsal的(类似encoder)

In [None]:
decoder.get_extended_attention_mask(attention_mask,torch.randint(10,(2,1)).shape,'cpu')

In [None]:
import random
class MSADataSet(Dataset):
    def __init__(self,data_path,num_msa_files=None,num_alignments=3):
        deletekeys = dict.fromkeys(string.ascii_lowercase)
        deletekeys["."] = None
        deletekeys["*"] = None
        src_seq_per_msa='random'
        src_seq_per_msa_l=2
        src_seq_per_msa_u=10
        total_seq_per_msa=30
        self.translation = str.maketrans(deletekeys)
        if num_msa_files is not None:
            # train on small dataset
            data_path_list=glob.glob(data_path+'/*/*.a3m')[:num_msa_files]
        else:
            # train on full dataset
            data_path_list=glob.glob(data_path+'/*/*.a3m')
        
        msa_data=[self.read_msa(data_path,total_seq_per_msa) for data_path in data_path_list]
        msa_data=[i for i in msa_data if (len(i)==total_seq_per_msa and self.check_len(i))]
        print(len(msa_data))
        self.src = [msa[:src_seq_per_msa if isinstance(src_seq_per_msa, int) else random.randint(src_seq_per_msa_l,src_seq_per_msa_u)] for msa in msa_data]
        print(len(self.src))
        tgt_seq_num_list=[len(src) for src in self.src]
        self.tgt = [msa[tgt_seq_per_msa:] for msa,tgt_seq_per_msa in zip(msa_data,tgt_seq_num_list)]     
    def __getitem__(self, index):
        return {"src":self.src[index],"tgt":self.tgt[index]}

    def __len__(self):
        return len(self.src)
    def remove_insertions(self,sequence) :
        """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
        return sequence.translate(self.translation)
    def read_msa(self,filename, nseq) :
        """ Reads the first nseq sequences from an MSA file, automatically removes insertions."""
        return [(record.description, self.remove_insertions(str(record.seq)))
                    for record in itertools.islice(SeqIO.parse(filename, "fasta"), nseq)]
    def check_len(self,msa):
        #check if all sequence in a msa has the same length
        l=set([len(x[1]) for x in msa])
        return len(l)==1
data_path='dataset/'
msa_dataset=MSADataSet(data_path,num_msa_files=100)

In [None]:
train_size=int(0.2*len(msa_dataset))
test_size=len(msa_dataset)-train_size
train,test=torch.utils.data.random_split(msa_dataset,[train_size,test_size])

In [None]:
for i in range(len(msa_dataset)):
    print(len(msa_dataset[i]['src']),len(msa_dataset[i]['tgt']))
    assert len(msa_dataset[i]['src'])+len(msa_dataset[i]['tgt'])==30

In [None]:
msa_dataset[4]

In [None]:
from torch.nn import CrossEntropyLoss
import torch
l=CrossEntropyLoss()

In [None]:
a=torch.randn(3,2,205,130)
b=torch.randint(130,(3,2,205))
b=torch.argmax(a,-1)
c=torch.randint(10,(2,5))

In [None]:
torch.sum(b==torch.argmax(a,-1))/(b.size(0)*b.size(1)*b.size(2))

In [None]:
l(a.view(-1,130),b.view(-1))

In [None]:
tl=0
for i,j in zip(a,b):
    print(i.shape,j.shape)
    lo=l(i.view(-1,130),j.view(-1))
    print(lo)
    tl+=lo
print(tl/3)

In [None]:
torch.sum(c).item()
a.shape

In [None]:
print('-'*75)
print('|',' '*29,'New forward step',' '*24,'|')
print('-'*75)
a=torch.randn(2,43,4,5,6)
print('| %-.25s'%'dsaddsdssadasdsainput_ids shape is:','%47s'%'{}|'.format(a.shape if a is not None else None))
print('| %-25s'%' is:','%47s'%'{}|'.format(c.shape))

# 测试logitsProcessor

In [None]:
from transformers.generation_logits_process import LogitsProcessorList,TopKLogitsWarper,TopPLogitsWarper
import torch

In [None]:
top_k=5
top_p=0.92
num_beams=1
warpers = LogitsProcessorList()
warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
logits=torch.randn(2,2,1,130)
next_token_logits = logits[:,:, -1, :]
next_token_scores=next_token_logits
next_token_scores = warpers(input_ids, next_token_scores)

### TOPK

In [None]:
a=torch.randn(2,4,1,130)
b=torch.randint(10,(2,3))
b,torch.topk(b,2),torch.topk(b,2)[0],torch.topk(b,2)[0][...,-1,None]

In [None]:
indices=a < torch.topk(a, 5)[0][..., -1, None]
indices[0][0]

### TOPP

In [None]:
scores=torch.randn(2,2,5)
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
print('sorted_indices: \n',sorted_indices,sorted_indices.shape)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
print('cumulative_probs: \n',cumulative_probs,cumulative_probs.shape)
sorted_indices_to_remove = cumulative_probs > 0.92
print('sorted_indices_to_remove: \n',sorted_indices_to_remove,sorted_indices_to_remove.shape)
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.stack([i.scatter(1, j, i) for (i,j) in zip(sorted_indices_to_remove,sorted_indices)])
print('indices_to_remove: \n',indices_to_remove,indices_to_remove.shape)

In [None]:
import torch.nn as nn
probs = nn.functional.softmax(scores, dim=-1)
torch.stack([torch.multinomial(i,num_samples=1) for i in probs ]).shape

In [None]:
torch.randn(2,2,1)[:,:,None].shape

In [None]:
probs=torch.randn(1,10,10)+10
next_tokens=torch.stack([torch.multinomial(i,num_samples=1) for i in probs ],dim=0)
next_tokens.shape

In [None]:
torch.multinomial(probs[0],num_samples=1).shape

In [None]:
import torch
a=torch.tensor([1,2,3,2,5])
mask=(a!=2) & (a!=3)
a[mask]


In [None]:
import os
import glob

In [None]:
casp14_fasta_path='/share/wangsheng/train_test_data/CASP_RawData/CASP14_RawData/CASP14DM_SEQ/'
casp14_name_list=[file.split('.')[0] for file in os.listdir(casp14_fasta_path)]
all_msa_path='/user/sunsiqi/zl/T5/CASP_msa/allDM_msa/'
for file_name in casp14_name_list:
    file_path=all_msa_path+file_name+'.a3m'
    

In [None]:
from Bio import SeqIO
a=SeqIO.parse('/user/sunsiqi/zl/T5/CASP_msa/allDM_msa/T1093-D3.a3m', "fasta")

In [None]:
with open('/user/sunsiqi/zl/T5/CASP_msa/allDM_msa/T1093-D3.a3m','r') as f:
    context=f.readlines()[:4]
    print(context,len(context))
    print("".join(context))
    with open('/user/sunsiqi/zl/T5/AF2TEST/x.a3m','w') as fw:
        fw.write("".join(context))

In [None]:
pred = "/user/sunsiqi/zl/T5/AF2TEST/CASP14/output/src_3_all_13/source"
for dir in sorted(os.listdir(pred)):

In [None]:
import os
pred = "/user/sunsiqi/zl/T5/AF2TEST/CASP14/output/src_3_all_13/source"
for dir in sorted(os.listdir(pred)):
    print(dir.split('_')[0])
    dsa



In [None]:
a='AF2TEST/CASP14/src_3_all_13/generate_10'
b=a.split('/')
b.insert(2,'output')
b=os.path.join(*b)+'/'
b

In [None]:
generated_msa_dir='/user/sunsiqi/zl/T5/AF2TEST/CASP14/src_3_all_13/generate_10/'
pdboutdir=generated_msa_dir.split('/')
print(pdboutdir)
pdboutdir.insert(7,'output')
pdboutdir=os.path.join(*pdboutdir)
pdboutdir

In [None]:

def eval_iddt(pred,outdir):
    lddt = "/share/wangsheng/GitBucket/Fast_lDDT/Fast_lDDT"
    seqdir = "/share/wangsheng/train_test_data/CASP_RawData/allDM_SEQ/"
    native = "/share/wangsheng/train_test_data/CASP_RawData/allDM_Native/"
    
    # outdir = "pred-lddt-iter3-log/"
    # outdir="/user/sunsiqi/zl/T5/AF2TEST/CASP14/output/src_3_all_13/pred-lddt/"
    #pred = "/share/liyu/hl/fold-result-tmp/tmp_res_casp13_fasta/1000000/"
    # pred = "/user/sunsiqi/zl/T5/AF2TEST/CASP14/output/src_3_all_13/generate_10/"

    os.makedirs(outdir,exist_ok=True)
    with open(os.path.join(outdir,"{}.csv".format(pred.split('/')[-2])), 'w+') as fp:
        fp.write("name,result\n")
        for dir in sorted(os.listdir(pred)):
            dir_noext=dir.split('_')[0]
            p = subprocess.Popen(
                [lddt, 
                '-i', os.path.join(seqdir, "%s.seq"%dir_noext), 
                '-n', os.path.join(native, "%s.pdb"%dir_noext), 
                '-m', os.path.join(pred, dir, 'ranked_0.pdb'),
                '-v', '1'],
                #shell=True, 
                stdout=subprocess.PIPE, 
                stderr=subprocess.STDOUT
            )
            res = p.stdout.readlines()[-1].decode().split(' ')[2]
            fp.write("%s,%s\n"%(dir, res))
            ret = p.wait()
    print("finish %s"%(dir))

In [None]:
import subprocess

eval_iddt(outdir='/user/sunsiqi/zl/T5/AF2TEST/CASP14/output/src_5_all_15/pred-lddt/source_sorted/',pred='/user/sunsiqi/zl/T5/AF2TEST/CASP14/output/src_5_all_15/source_sorted/')

# 测试多轮增强

In [19]:
import os
import json
import numpy as np

In [34]:
result_dir='/user/sunsiqi/zl/T5/AF2TEST/CASP14/output/msa_l1_u50/predict/Gtime08-19-10:48_Rpen1_Gtimes1_f_0/'

Gsteps=os.listdir(result_dir)
#target {t1026:gstep1;t1034:gstep2}
caspfile_score={}
caspfiles=os.listdir(os.path.join(result_dir,Gsteps[0]))
keys=[example for example in caspfiles]
for key in keys:
    caspfile_score.update({key:[]})
for gstep in Gsteps:
    Gstep_path=os.path.join(result_dir,gstep)
    for caspfile in caspfiles:
        caspfile_ranking_path=os.path.join(Gstep_path,caspfile,'ranking_debug.json')
        scores=json.load(open(caspfile_ranking_path,'r'))
        score=scores['plddts'][scores['order'][0]]
        caspfile_score[caspfile].append((gstep,score))



In [36]:
caspfile_step={}
def highest_gstep(g_scores):
    gsteps=[x[0] for x in g_scores]
    scores=np.array([x[1] for x in g_scores])
    idx=np.argmax(scores)
    return gsteps[idx]
for key in caspfile_score:
    g_scores=caspfile_score[key]
    caspfile_step[key]=highest_gstep(g_scores)

In [50]:
def fetch_file_path(caspfile_step,result_dir):
    input_dir=result_dir.replace('output','input')
    highest_collection_path=[os.path.join(input_dir,gstep,caspfile+'.a3m') for caspfile,gstep in caspfile_step.items()]
    for path in highest_collection_path:
        assert os.path.exists(path)
    return highest_collection_path

In [None]:
fetch_file_path(caspfile_step,result_dir)

In [55]:
def fetch_best_generation(result_dir):
    Gsteps=os.listdir(result_dir)
    caspfile_score={}
    caspfiles=os.listdir(os.path.join(result_dir,Gsteps[0]))
    keys=[example for example in caspfiles]
    for key in keys:
        caspfile_score.update({key:[]})
    for gstep in Gsteps:
        Gstep_path=os.path.join(result_dir,gstep)
        for caspfile in caspfiles:
            caspfile_ranking_path=os.path.join(Gstep_path,caspfile,'ranking_debug.json')
            try:
                scores=json.load(open(caspfile_ranking_path,'r'))
                score=scores['plddts'][scores['order'][0]]
                caspfile_score[caspfile].append((gstep,score))
            except Exception:
                pass
                
           
    caspfile_step={}
    def highest_gstep(g_scores):
        gsteps=[x[0] for x in g_scores]
        scores=np.array([x[1] for x in g_scores])
        idx=np.argmax(scores)
        return gsteps[idx]
    for key in caspfile_score:
        g_scores=caspfile_score[key]
        caspfile_step[key]=highest_gstep(g_scores)
    input_dir=result_dir.replace('output','input')
    highest_collection_path=[os.path.join(input_dir,gstep,caspfile+'.a3m') for caspfile,gstep in caspfile_step.items()]
    for path in highest_collection_path:
        assert os.path.exists(path)
    return highest_collection_path
fetch_best_generation('/user/sunsiqi/zl/T5/AF2TEST/CASP14/output/msa_l1_u50/predict/Gtime08-17-08:50_Rpen1_Gtimes5_f0')

['/user/sunsiqi/zl/T5/AF2TEST/CASP14/input/msa_l1_u50/predict/Gtime08-17-08:50_Rpen1_Gtimes5_f0/Gstep_1/T1093-D1_generate.a3m',
 '/user/sunsiqi/zl/T5/AF2TEST/CASP14/input/msa_l1_u50/predict/Gtime08-17-08:50_Rpen1_Gtimes5_f0/Gstep_3/T1068-D1_generate.a3m',
 '/user/sunsiqi/zl/T5/AF2TEST/CASP14/input/msa_l1_u50/predict/Gtime08-17-08:50_Rpen1_Gtimes5_f0/Gstep_1/T1038-D1_generate.a3m',
 '/user/sunsiqi/zl/T5/AF2TEST/CASP14/input/msa_l1_u50/predict/Gtime08-17-08:50_Rpen1_Gtimes5_f0/Gstep_0/T1099-D1_generate.a3m',
 '/user/sunsiqi/zl/T5/AF2TEST/CASP14/input/msa_l1_u50/predict/Gtime08-17-08:50_Rpen1_Gtimes5_f0/Gstep_4/T1026-D1_generate.a3m',
 '/user/sunsiqi/zl/T5/AF2TEST/CASP14/input/msa_l1_u50/predict/Gtime08-17-08:50_Rpen1_Gtimes5_f0/Gstep_2/T1082-D1_generate.a3m',
 '/user/sunsiqi/zl/T5/AF2TEST/CASP14/input/msa_l1_u50/predict/Gtime08-17-08:50_Rpen1_Gtimes5_f0/Gstep_3/T1064-D1_generate.a3m',
 '/user/sunsiqi/zl/T5/AF2TEST/CASP14/input/msa_l1_u50/predict/Gtime08-17-08:50_Rpen1_Gtimes5_f0/Gstep_3/