In [1]:
import torch
import numpy as np

dataset = "MPNN"
method = "AlphaDesign"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def loss_nll_flatten(S, log_probs):
    """ Negative log probabilities """
    criterion = torch.nn.NLLLoss(reduction='none')
    loss = criterion(log_probs, S)
    loss_av = loss.mean()
    return loss, loss_av

def get_metric(S, log_probs):
    nll_loss, _ = loss_nll_flatten(S, log_probs)
    chain_mask = torch.ones_like(nll_loss)
    loss = torch.sum(nll_loss * chain_mask).cpu().data.numpy()
    weight = torch.sum(chain_mask).cpu().data.numpy()
    return {"loss":loss, "weight":weight}

# 1. Results on test sets of 'CATH4.2', 'CATH4.3', 'MPNN'

In [7]:
def summary_perp_recovery(result):
    import torch.nn as nn
    criterion = nn.CrossEntropyLoss(reduction='none')
    
    
    all_conf = []
    all_rec = []
    all_lens = []
    for i in range(len(result['title'])):
        recovery = (result['true_seq'][i] == result['pred_probs'][i].argmax(dim=1)).float().mean()
        
        # loss = criterion(torch.log(result['pred_probs'][i]), result['true_seq'][i])
        conf = result['pred_probs'][i].max(dim=-1)[0].mean()
        

        all_conf.append(conf.item())
        all_rec.append(recovery.item())
        all_lens.append(len(result['true_seq'][i]))
    
    # print(torch.exp(torch.cat(all_perp).mean()))
    all_conf = np.array(all_conf)
    all_rec = np.array(all_rec)
    all_lens = np.array(all_lens)

    summary = {}
    if dataset in ['CATH4.2', 'CATH4.3']:
        summary['conf(0,100)'] = np.median(all_conf[all_lens<100])
        summary['conf(100,300)'] = np.median(all_conf[(100<=all_lens)&(all_lens<300)])
        summary['conf(300, 500)'] = np.median(all_conf[(300<=all_lens)&(all_lens<500)])
        summary['conffull'] = np.median(all_conf)
        
        summary['rec(0,100)'] = np.median(all_rec[all_lens<100])
        summary['rec(100,300)'] = np.median(all_rec[(100<=all_lens)&(all_lens<300)])
        summary['rec(300, 500)'] = np.median(all_rec[(300<=all_lens)&(all_lens<500)])
        summary['recfull'] = np.median(all_rec)
        
    if dataset in ['MPNN']:
        summary['conf(0,100)'] = np.median(all_conf[all_lens<100])
        summary['conf(100,500)'] = np.median(all_conf[(100<=all_lens)&(all_lens<500)])
        summary['conf(500, 1000)'] = np.median(all_conf[(500<=all_lens)&(all_lens<1000)])
        summary['conffull'] = np.median(all_conf)
        
        summary['rec(0,100)'] = np.median(all_rec[all_lens<100])
        summary['rec(100,500)'] = np.median(all_rec[(100<=all_lens)&(all_lens<500)])
        summary['rec(500, 1000)'] = np.median(all_rec[(500<=all_lens)&(all_lens<1000)])
        summary['recfull'] = np.median(all_rec)
    return summary

In [35]:
for dataset in ['CATH4.2', 'CATH4.3', 'MPNN']:
    for method in ['StructGNN', 'GraphTrans', 'GCA', 'GVP', 'AlphaDesign', 'ProteinMPNN', 'PiFold', 'KWDesign']:
        result = torch.load(f"/gaozhangyang/experiments/OpenCPD/results/{dataset}/{method}/results.pt")
        summary = summary_perp_recovery(result)
        summary_string = '\t'.join('{:.2f}'.format(x) for x in list(summary.values()))
        
        print("data {} method {} \t result {}".format(dataset, method, summary_string))

data CATH4.2 method StructGNN 	 result 0.31	0.45	0.45	0.43	0.26	0.36	0.36	0.35
data CATH4.2 method GraphTrans 	 result 0.31	0.43	0.43	0.43	0.25	0.35	0.35	0.34
data CATH4.2 method GCA 	 result 0.34	0.46	0.47	0.45	0.27	0.38	0.38	0.37
data CATH4.2 method GVP 	 result 0.40	0.52	0.53	0.51	0.28	0.40	0.41	0.39
data CATH4.2 method AlphaDesign 	 result 0.36	0.49	0.49	0.47	0.33	0.43	0.44	0.42
data CATH4.2 method ProteinMPNN 	 result 0.38	0.51	0.52	0.50	0.32	0.47	0.47	0.45
data CATH4.2 method PiFold 	 result 0.44	0.58	0.60	0.57	0.39	0.53	0.56	0.52
data CATH4.2 method KWDesign 	 result 0.50	0.68	0.72	0.67	0.44	0.62	0.66	0.61
data CATH4.3 method StructGNN 	 result 0.35	0.41	0.47	0.41	0.30	0.34	0.40	0.34
data CATH4.3 method GraphTrans 	 result 0.37	0.42	0.48	0.42	0.29	0.34	0.39	0.34
data CATH4.3 method GCA 	 result 0.38	0.43	0.49	0.43	0.32	0.36	0.41	0.36
data CATH4.3 method GVP 	 result 0.45	0.51	0.55	0.50	0.33	0.38	0.45	0.38
data CATH4.3 method AlphaDesign 	 result 0.41	0.48	0.53	0.47	0.37	0.43	0.4

FileNotFoundError: [Errno 2] No such file or directory: '/gaozhangyang/experiments/OpenCPD/results/MPNN/KWDesign/results.pt'

# 2. Results on CASP15

In [13]:

def summary_perp_recovery(result):
    all_conf = []
    all_rec = []
    all_class = []
    for i in range(len(result['title'])):
        recovery = (result['true_seq'][i] == result['pred_probs'][i].argmax(dim=1)).float().mean()
        conf = result['pred_probs'][i].max(dim=-1)[0].mean()
        
        all_conf.append(conf.item())
        all_rec.append(recovery.item())
        all_class.append(result['classification'][i])
    
    # print(torch.exp(torch.cat(all_perp).mean()))
    all_conf = np.array(all_conf)
    all_rec = np.array(all_rec)
    all_class = np.array(all_class)
    # print(set(all_class))
    # print(np.sort(all_rec))

    summary = {}
    mask = np.array([ one in ['FM'] for one in all_class])
    summary['FM'] = np.median(all_conf[mask])
    mask = np.array([ one not in ['FM'] for one in all_class])
    summary['TBM'] = np.median(all_conf[mask])
    mask = np.array([ one in ['TBM-easy'] for one in all_class])
    summary['TBM-easy'] = np.median(all_conf[mask])
    mask = np.array([ one in ['TBM-hard'] for one in all_class])
    summary['TBM-hard'] = np.median(all_conf[mask])
    summary['Full'] = np.median(all_conf)
    
    mask = np.array([ one in ['FM'] for one in all_class])
    summary['rec FM'] = np.median(all_rec[mask])
    mask = np.array([ one not in ['FM/TBM'] for one in all_class])
    summary['rec TBM'] = np.median(all_rec[mask])
    mask = np.array([ one in ['TBM-easy'] for one in all_class])
    summary['rec TBM-easy'] = np.median(all_rec[mask])
    mask = np.array([ one in ['TBM-hard'] for one in all_class])
    summary['rec TBM-hard'] = np.median(all_rec[mask])
    summary['rec Full'] = np.median(all_rec)
    return summary

In [3]:
result = torch.load(f"/gaozhangyang/experiments/OpenCPD/results/{dataset}/{method}/results_casp15.pt")

In [10]:
print(sum([one =='FM' for one in result['classification']]))
print(sum([one =='FM/TBM' for one in result['classification']]))
print(sum([one =='TBM-easy' for one in result['classification']]))
print(sum([one =='TBM-hard' for one in result['classification']]))

16
2
20
5


In [14]:
for dataset in ['CATH4.2', 'CATH4.3', 'MPNN']:
    for method in ['StructGNN', 'GraphTrans', 'GCA', 'GVP', 'AlphaDesign', 'ProteinMPNN', 'PiFold', 'KWDesign']:
        result = torch.load(f"/gaozhangyang/experiments/OpenCPD/results/{dataset}/{method}/results_casp15.pt")
        summary = summary_perp_recovery(result)
        summary_string = '\t'.join('{:.2f}'.format(x) for x in list(summary.values()))
        
        print("data {} method {} \t result {}".format(dataset, method, summary_string))

data CATH4.2 method StructGNN 	 result 0.41	0.46	0.48	0.43	0.45	0.35	0.35	0.38	0.35	0.35
data CATH4.2 method GraphTrans 	 result 0.39	0.45	0.46	0.42	0.44	0.33	0.36	0.37	0.36	0.36
data CATH4.2 method GCA 	 result 0.48	0.52	0.53	0.48	0.50	0.39	0.40	0.41	0.38	0.40
data CATH4.2 method GVP 	 result 0.48	0.49	0.50	0.50	0.49	0.37	0.39	0.42	0.39	0.39
data CATH4.2 method AlphaDesign 	 result 0.44	0.49	0.50	0.46	0.48	0.41	0.42	0.46	0.41	0.42
data CATH4.2 method ProteinMPNN 	 result 0.49	0.52	0.53	0.51	0.52	0.44	0.44	0.46	0.40	0.44
data CATH4.2 method PiFold 	 result 0.52	0.56	0.59	0.53	0.55	0.47	0.47	0.50	0.47	0.47
data CATH4.2 method KWDesign 	 result 0.55	0.66	0.70	0.62	0.64	0.49	0.55	0.59	0.55	0.54
data CATH4.3 method StructGNN 	 result 0.40	0.44	0.45	0.43	0.44	0.35	0.36	0.38	0.37	0.36
data CATH4.3 method GraphTrans 	 result 0.39	0.45	0.46	0.43	0.45	0.35	0.35	0.37	0.35	0.35
data CATH4.3 method GCA 	 result 0.46	0.49	0.51	0.44	0.48	0.37	0.41	0.43	0.40	0.41
data CATH4.3 method GVP 	 result 0.47

FileNotFoundError: [Errno 2] No such file or directory: '/gaozhangyang/experiments/OpenCPD/results/MPNN/KWDesign/results_casp15.pt'

# 2. Results on noisy data

In [3]:
def summary_perp_recovery(result):
    import torch.nn as nn
    criterion = nn.CrossEntropyLoss(reduction='none')
    
    
    all_conf = []
    all_rec = []
    all_lens = []
    for i in range(len(result['title'])):
        recovery = (result['true_seq'][i] == result['pred_probs'][i].argmax(dim=1)).float().mean()
        
        # loss = criterion(torch.log(result['pred_probs'][i]), result['true_seq'][i])
        conf = result['pred_probs'][i].max(dim=-1)[0].mean()
        

        all_conf.append(conf.item())
        all_rec.append(recovery.item())
        all_lens.append(len(result['true_seq'][i]))
    
    # print(torch.exp(torch.cat(all_perp).mean()))
    all_conf = np.array(all_conf)
    all_rec = np.array(all_rec)
    all_lens = np.array(all_lens)

    summary = {}
    if dataset in ['CATH4.2', 'CATH4.3']:
        summary['conf(0,100)'] = np.median(all_conf[all_lens<100])
        summary['conf(100,300)'] = np.median(all_conf[(100<=all_lens)&(all_lens<300)])
        summary['conf(300, 500)'] = np.median(all_conf[(300<=all_lens)&(all_lens<500)])
        summary['conffull'] = np.median(all_conf)
        
        summary['rec(0,100)'] = np.median(all_rec[all_lens<100])
        summary['rec(100,300)'] = np.median(all_rec[(100<=all_lens)&(all_lens<300)])
        summary['rec(300, 500)'] = np.median(all_rec[(300<=all_lens)&(all_lens<500)])
        summary['recfull'] = np.median(all_rec)
        
    if dataset in ['MPNN']:
        summary['conf(0,100)'] = np.median(all_conf[all_lens<100])
        summary['conf(100,500)'] = np.median(all_conf[(100<=all_lens)&(all_lens<500)])
        summary['conf(500, 1000)'] = np.median(all_conf[(500<=all_lens)&(all_lens<1000)])
        summary['conffull'] = np.median(all_conf)
        
        summary['rec(0,100)'] = np.median(all_rec[all_lens<100])
        summary['rec(100,500)'] = np.median(all_rec[(100<=all_lens)&(all_lens<500)])
        summary['rec(500, 1000)'] = np.median(all_rec[(500<=all_lens)&(all_lens<1000)])
        summary['recfull'] = np.median(all_rec)
    return summary

In [4]:
dataset = 'CATH4.3'
for eps in [0.02, 0.2, 0.5, 1.0]:
    for method in ['StructGNN', 'GraphTrans', 'GCA', 'GVP', 'AlphaDesign', 'ProteinMPNN', 'PiFold', 'KWDesign']:
        result = torch.load(f"/gaozhangyang/experiments/OpenCPD/results/{dataset}/{method}_{eps}/results.pt")
        summary = summary_perp_recovery(result)
        summary_string = '\t'.join('{:.2f}'.format(x) for x in list(summary.values()))
        
        print("data {} method {} \t result {}".format(dataset, method, summary_string))

data CATH4.3 method StructGNN 	 result 0.27	0.26	0.28	0.27	0.19	0.20	0.21	0.20
data CATH4.3 method GraphTrans 	 result 0.26	0.26	0.27	0.26	0.19	0.19	0.20	0.20
data CATH4.3 method GCA 	 result 0.25	0.25	0.26	0.25	0.19	0.19	0.20	0.19
data CATH4.3 method GVP 	 result 0.47	0.53	0.56	0.52	0.32	0.37	0.43	0.38
data CATH4.3 method AlphaDesign 	 result 0.16	0.16	0.15	0.16	0.18	0.18	0.18	0.18
data CATH4.3 method ProteinMPNN 	 result 0.31	0.30	0.32	0.31	0.22	0.23	0.25	0.23
data CATH4.3 method PiFold 	 result 0.28	0.29	0.32	0.29	0.26	0.28	0.29	0.28
data CATH4.3 method KWDesign 	 result 0.33	0.42	0.45	0.40	0.29	0.37	0.41	0.35
