Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions analysis/embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import subprocess

# One-off driver that shells out to the ProtT5 embedder for each natural
# sequence fasta. Earlier one-off steps (concatenating generation fastas and
# embedding the generations with ProteInfer / ProtT5) were removed; see
# version control history.

# Embed the GigaRef natural-sequence baseline fastas with ProtBert-BFD,
# producing one mean-pooled ("per protein") embedding per sequence.
in_dir = '/home/kevyan/generations/natural_sequences/'
out_dir = '/home/kevyan/generations/proteinfer/'  # retained from earlier steps; unused below
pb_out_dir = os.path.expanduser('~/generations/protbert/')
fastas = os.listdir(in_dir)

for fasta in fastas:
    # Only the gigaref fasta files are embedded in this pass.
    if 'gigaref' not in fasta:
        continue
    print(fasta)
    stem = fasta[:-6]  # drop the '.fasta' suffix
    # Pass an argument list with shell=False so unusual file names cannot be
    # interpreted by a shell; '~' is expanded above since no shell runs here.
    subprocess.run(
        [
            'python', '/home/kevyan/src/ProtTrans/Embedding/prott5_embedder.py',
            '--input', os.path.join(in_dir, fasta),
            '--output', os.path.join(pb_out_dir, stem + '.h5'),
            '--model', 'ProtBert-BFD',
            '--per_protein', '1',
        ],
        check=False,  # match the original fire-and-forget behavior
    )
86 changes: 86 additions & 0 deletions analysis/extract_test_fastas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import datetime
import os
from tqdm import tqdm


import numpy as np

import torch

from sequence_models.constants import OTHER_AAS, AMB_AAS
from dayhoff.utils import seed_everything
from dayhoff.datasets import UniRefDataset


# Distributed-run configuration. RANK and WORLD_SIZE must be present in the
# environment (e.g. set by torchrun); int(os.environ[...]) raises KeyError if
# they are missing — there is no single-GPU fallback.
RANK = int(os.environ["RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])  # NOTE(review): unused in this script
DEVICE = torch.device(f"cuda:{RANK}")  # NOTE(review): unused in this script


# Fix all RNG seeds so the sampled subsets are reproducible.
seed_everything(0)


def generate() -> None:
    """Sample 10k clean sequences per dataset split into fasta files.

    For every split of UniRef50 (202401) and GigaRef, shuffles the dataset
    indices and writes the first ``n`` sequences containing no non-canonical
    or ambiguous amino acids to a fasta file under
    /mnt/checkpoints/evodiff/generations/.
    """
    data_seq_dir = '/mnt/data/protein/'
    out_dir = '/mnt/checkpoints/evodiff/generations/'
    n = 10000

    data_name = 'uniref50_202401'
    for split_name in ['valid', 'train', 'test', 'rtest']:
        print(data_name, split_name, datetime.datetime.now(), flush=True)
        ds = UniRefDataset(os.path.join(data_seq_dir, data_name + '/'), split_name,
                           max_len=2048)
        _write_split(ds, os.path.join(out_dir, data_name + "_" + split_name + "_10k.fasta"), n)

    data_name = 'gigaref'
    for split_name in ['train', 'test']:
        print(data_name, split_name, datetime.datetime.now(), flush=True)
        # GigaRef uses an explicit split file (singleton clusters excluded).
        ds = UniRefDataset(data_seq_dir + data_name + '/', split_name,
                           max_len=2048,
                           split_file=data_seq_dir + data_name + '/' + 'no_singletons/splits.json')
        _write_split(ds, os.path.join(out_dir, data_name + "_" + split_name + "_10k.fasta"), n)


def _write_split(ds, out_path, n):
    """Write up to ``n`` clean sequences from ``ds`` (in shuffled order) to ``out_path``.

    A sequence is "clean" when it contains none of the non-canonical
    (OTHER_AAS) or ambiguous (AMB_AAS) residues. Fasta headers carry the
    draw counter (including skipped draws), matching the original script.
    """
    with open(out_path, 'w') as f:
        idx = np.arange(len(ds))
        np.random.shuffle(idx)
        successes = 0
        i = -1
        with tqdm(total=n) as pbar:
            while successes < n:
                i += 1
                seq = ds[idx[i]][0]
                # Skip sequences with non-canonical / ambiguous residues.
                if any(aa in seq for aa in OTHER_AAS + AMB_AAS):
                    continue
                f.write(">%d\n" % i)
                f.write(seq + "\n")
                successes += 1
                pbar.update(1)






def main():
    """Entry point: only the rank-0 process performs the extraction."""
    if RANK != 0:
        return
    generate()


# Script entry point; main() no-ops on non-zero ranks.
if __name__ == "__main__":
    main()
214 changes: 214 additions & 0 deletions analysis/fpd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import os

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy import linalg
from scipy.spatial import distance
from sklearn import metrics
from tqdm import tqdm

sns.set_style('white')

def mmd_rbf(X, Y, gamma=1.0):
    """MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2))

    Arguments:
        X {[n_sample1, dim]} -- [X matrix]
        Y {[n_sample2, dim]} -- [Y matrix]

    Keyword Arguments:
        gamma {float} -- [kernel parameter] (default: {1.0})

    Returns:
        [scalar] -- [biased MMD^2 estimate]
    """
    # Compute the three kernel blocks with scipy/numpy instead of sklearn:
    # rbf_kernel(A, B, gamma) == exp(-gamma * squared Euclidean distance),
    # so this is numerically equivalent while dropping the sklearn dependency.
    XX = np.exp(-gamma * distance.cdist(X, X, 'sqeuclidean'))
    YY = np.exp(-gamma * distance.cdist(Y, Y, 'sqeuclidean'))
    XY = np.exp(-gamma * distance.cdist(X, Y, 'sqeuclidean'))
    return XX.mean() + YY.mean() - 2 * XY.mean()


def calculate_fid(act1, act2, eps=1e-6):
    """Frechet inception distance between Gaussians fitted to two activation sets."""
    # Fit a Gaussian (mean + covariance) to each activation matrix.
    mu1 = act1.mean(axis=0)
    mu2 = act2.mean(axis=0)
    sigma1 = np.cov(act1, rowvar=False)
    sigma2 = np.cov(act2, rowvar=False)
    # Squared Euclidean distance between the two means.
    diff = mu1 - mu2
    ssdiff = np.sum(diff ** 2.0)
    # Matrix square root of the covariance product; retry with a jittered
    # diagonal when the product is numerically singular.
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = (
            "fid calculation produces singular product; "
            "adding %s to diagonal of cov estimates"
        ) % eps
        print(msg)
        jitter = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))
    # sqrtm may return complex values with tiny imaginary parts; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)

# --- Baselines: pairwise MMD / Frechet distance between natural sequence sets ---
natural_sets = ['uniref_train', 'uniref_valid', 'gigaref_train', 'gigaref_test']
# natural_sets = ['train_GO', 'test_GO']
# Map from logical set name to the embedding file stem on disk.
natural_files = {
    'uniref_train': 'uniref50_202401_train_10k',
    'uniref_test': 'uniref50_202401_test_10k',
    'uniref_valid': 'uniref50_202401_valid_10k',
    'uniref_rtest': 'uniref50_202401_rtest_10k',
    'gigaref_train': 'gigaref_train_10k',
    'gigaref_test': 'gigaref_test_10k',
    'train_GO': 'train_GO',
    'test_GO': 'test_10000_GO'
}
embedding_dir = '/home/kevyan/generations/proteinfer/'
# One embedding array per natural set, loaded from the ProteInfer output
# partitions (assumes a single partition_0.pt per set — TODO confirm).
embedding_dict = {s: torch.load(embedding_dir + natural_files[s] + '/partition_0.pt').numpy()
                  for s in natural_sets}
gamma = 1e-3  # RBF bandwidth used for MMD in the ProteInfer embedding space
mult = 100  # scale factor applied to MMD values for readability
mmd_dict = {}
fpd_dict = {}
# Compare each unordered pair of sets exactly once (i > j); results are keyed
# 's:s2' and reused as baseline reference lines in the plots below.
for i, s in enumerate(natural_sets):
    ei = embedding_dict[s]
    for j, s2 in enumerate(natural_sets):
        if i > j:
            ej = embedding_dict[s2]
            mmd = mmd_rbf(ei[:], ej[:], gamma=gamma) * mult
            fpd = calculate_fid(ei[:], ej[:], eps=1e-6)
            print(s, s2, mmd, fpd)
            mmd_dict[s + ':' + s2] = mmd
            fpd_dict[s + ':' + s2] = fpd


# --- Same pairwise baselines in the ProtBert embedding space ---
pb_embedding_dir = '/home/kevyan/generations/protbert/'
pb_embedding_dict = {}
for s in natural_sets:
    fn = os.path.join(pb_embedding_dir, natural_files[s] + '.h5')
    # Use a context manager so each HDF5 handle is closed promptly instead of
    # leaking one open file per set.
    with h5py.File(fn, 'r') as f:
        pb_embedding_dict[s] = np.array([f[k] for k in f.keys()])
pb_gamma = 1  # RBF bandwidth for the ProtBert space (scale differs from ProteInfer)
mult = 100  # scale factor applied to MMD values for readability
pb_mmd_dict = {}
pb_fpd_dict = {}
# Compare each unordered pair of sets exactly once (i > j), keyed 's:s2'.
for i, s in enumerate(natural_sets):
    ei = pb_embedding_dict[s]
    for j, s2 in enumerate(natural_sets):
        if i > j:
            ej = pb_embedding_dict[s2]
            mmd = mmd_rbf(ei[:], ej[:], gamma=pb_gamma) * mult
            fpd = calculate_fid(ei[:], ej[:], eps=1e-6)
            print(s, s2, mmd, fpd)
            pb_mmd_dict[s + ':' + s2] = mmd
            pb_fpd_dict[s + ':' + s2] = fpd

# --- Per-model metrics for generated sequences ---
models = os.listdir(embedding_dir)
models = [m for m in models if 'jamba' in m]
# Map raw run names to the human-readable labels used for plotting.
model_name = {
    'jamba-3b-indel-gigaclust-120k-2': '3b-msa-gigaclust',
    'jamba-3b-cooldown': '3b-msa-uniref90-cooldown',
    'jamba-3b-cooldown7': '3b-msa-uniref90-cooldown',
    'jamba-170m-10mnovelty-36w': '170m-1novelty',
    'jamba-170m-seq-36w': '170m-uniref50',
    'jamba-170m-10mrmsd-36w': '170m-rmsd',
    'jamba-170m-10mbothfilter-36w': '170m-bothfilter',
    'jamba-3b-seq-sam-biar-fsdp-tok90k': '3b-uniref90',
    'jamba-170m-10mnofilter-36w': '170m-nofilter',
    'jamba-170m-seqsam-36w': '170m-uniref90',
    'jamba-170m-gigaclust-36w': '170m-gigaclust'
}
df = pd.DataFrame(columns=[
    'name',
    'direction',
    'temperature',
    'step',
    'proteinfer_mmd_to_uniref',
    # Fixed typo ('proterinfer' -> 'proteinfer') so the declared schema
    # matches the column names referenced by the plots below.
    'proteinfer_mmd_to_gigaref',
    'protbert_mmd_to_uniref',
    'protbert_mmd_to_gigaref',
    'proteinfer_fd_to_uniref',
    'proteinfer_fd_to_gigaref',
    'protbert_fd_to_uniref',
    'protbert_fd_to_gigaref',
])
for i, m in tqdm(enumerate(models)):
    # ProteInfer-space metrics are currently disabled; restore this block to
    # populate the name/step/direction/temperature and proteinfer_* columns.
    # d = m.split('_')
    # df.loc[i, 'name'] = model_name[d[0]]
    # df.loc[i, 'step'] = int(d[1])
    # df.loc[i, 'direction'] = d[2].split('.')[1]
    # df.loc[i, 'temperature'] = float(d[3][1:])
    # emb = torch.load(embedding_dir + m + '/partition_0.pt').numpy()
    # if np.isnan(emb).any():
    #     emb = emb[np.isnan(emb).sum(axis=1) == 0]
    # df.loc[i, 'proteinfer_mmd_to_uniref'] = mmd_rbf(emb, embedding_dict['uniref_valid'], gamma=gamma) * mult
    # df.loc[i, 'proteinfer_mmd_to_gigaref'] = mmd_rbf(emb, embedding_dict['gigaref_test'], gamma=gamma) * mult
    # df.loc[i, 'proteinfer_fd_to_uniref'] = calculate_fid(emb, embedding_dict['uniref_valid'])
    # df.loc[i, 'proteinfer_fd_to_gigaref'] = calculate_fid(emb, embedding_dict['gigaref_test'])
    # Load the model's ProtBert embeddings, closing the HDF5 handle promptly
    # (the original left one file open per model).
    with h5py.File(pb_embedding_dir + '/' + m + '.h5', 'r') as fh:
        emb = np.array([fh[k] for k in fh.keys()])
    # Drop rows containing NaNs (failed embeddings) before computing metrics.
    if np.isnan(emb).any():
        emb = emb[np.isnan(emb).sum(axis=1) == 0]
    df.loc[i, 'protbert_mmd_to_uniref'] = mmd_rbf(emb, pb_embedding_dict['uniref_valid'], gamma=pb_gamma) * mult
    df.loc[i, 'protbert_mmd_to_gigaref'] = mmd_rbf(emb, pb_embedding_dict['gigaref_test'], gamma=pb_gamma) * mult
    df.loc[i, 'protbert_fd_to_uniref'] = calculate_fid(emb, pb_embedding_dict['uniref_valid'])
    df.loc[i, 'protbert_fd_to_gigaref'] = calculate_fid(emb, pb_embedding_dict['gigaref_test'])
df.to_csv('/home/kevyan/generations/fpd.csv', index=False)

# --- Plotting: metric-vs-temperature curves per model, with natural-vs-natural
# baselines drawn as gray horizontal lines ---
# NOTE(review): '3b-uniref' below never appears as a value of model_name
# (which yields '3b-uniref90'), so rows for that run are silently excluded
# from every plot — confirm whether '3b-uniref90' was intended.
models_to_plot = ['3b-msa-gigaclust', '3b-msa-uniref90-cooldown', '3b-uniref', '170m-uniref90', '170m-gigaclust']
uniref_hue_order = ['3b-uniref', '3b-msa-uniref90-cooldown', '170m-uniref90', '3b-msa-gigaclust', '170m-gigaclust']
# Restrict to temperatures near 1.0 for the ProteInfer-space plots.
plot_me = df[(df['name'].isin(models_to_plot)) & (df['temperature'] > 0.8) & (df['temperature'] < 1.2)]

# MMD in ProteInfer space.
# NOTE(review): the proteinfer_* columns are only populated when the
# commented-out block in the model loop above is restored; as written,
# 'proteinfer_mmd_to_gigaref' is also absent from df (declared with a
# 'proterinfer' typo), so this plot would raise — verify before running.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
_ = sns.lineplot(plot_me, x='temperature', y='proteinfer_mmd_to_uniref',
                 ax=axs[0], hue='name', style='direction', hue_order=uniref_hue_order, legend=False)
_ = axs[0].axhline(mmd_dict['uniref_valid:uniref_train'], color='gray')
_ = sns.lineplot(plot_me, x='temperature', y='proteinfer_mmd_to_gigaref',
                 ax=axs[1], hue='name', style='direction', legend=True, hue_order=uniref_hue_order)
_ = axs[1].axhline(mmd_dict['gigaref_test:gigaref_train'], color='gray')
# for ax in axs:
#     _ = ax.set_ylim([-0.01, 0.6])
_ = axs[1].legend(bbox_to_anchor=(1.1, 1.))
_ = fig.savefig('/home/kevyan/generations/proteinfer_mmd.png', dpi=300, bbox_inches='tight')

# Frechet distance in ProteInfer space.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
_ = sns.lineplot(plot_me, x='temperature', y='proteinfer_fd_to_uniref',
                 ax=axs[0], hue='name', style='direction', hue_order=uniref_hue_order, legend=False)
_ = axs[0].axhline(fpd_dict['uniref_valid:uniref_train'], color='gray')
_ = sns.lineplot(plot_me, x='temperature', y='proteinfer_fd_to_gigaref',
                 ax=axs[1], hue='name', style='direction', legend=True, hue_order=uniref_hue_order)
_ = axs[1].axhline(fpd_dict['gigaref_test:gigaref_train'], color='gray')
#
# for ax in axs:
#     _ = ax.set_ylim([-0.01, 0.4])
_ = axs[1].legend(bbox_to_anchor=(1.1, 1.))
_ = fig.savefig('/home/kevyan/generations/proteinfer_fpd.png', dpi=300, bbox_inches='tight')


# ProtBert-space MMD over the full temperature range.
plot_me = df[(df['name'].isin(models_to_plot))]  # & (df['temperature'] > 0.8) & (df['temperature'] < 1.2)]

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
_ = sns.lineplot(plot_me, x='temperature', y='protbert_mmd_to_uniref',
                 ax=axs[0], hue='name', style='direction', hue_order=uniref_hue_order, legend=False)
_ = axs[0].axhline(pb_mmd_dict['uniref_valid:uniref_train'], color='gray')
_ = sns.lineplot(plot_me, x='temperature', y='protbert_mmd_to_gigaref',
                 ax=axs[1], hue='name', style='direction', legend=True, hue_order=uniref_hue_order)
_ = axs[1].axhline(pb_mmd_dict['gigaref_test:gigaref_train'], color='gray')
_ = axs[1].legend(bbox_to_anchor=(1.1, 1.))
_ = fig.savefig('/home/kevyan/generations/protbert_mmd.png', dpi=300, bbox_inches='tight')

# ProtBert-space Frechet distance, dropping the lowest temperatures.
plot_me = df[(df['name'].isin(models_to_plot)) & (df['temperature'] > 0.7)]  # & (df['temperature'] < 1.2)]

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
_ = sns.lineplot(plot_me, x='temperature', y='protbert_fd_to_uniref',
                 ax=axs[0], hue='name', style='direction', hue_order=uniref_hue_order, legend=False)
_ = axs[0].axhline(pb_fpd_dict['uniref_valid:uniref_train'], color='gray')
_ = sns.lineplot(plot_me, x='temperature', y='protbert_fd_to_gigaref',
                 ax=axs[1], hue='name', style='direction', legend=True, hue_order=uniref_hue_order)
_ = axs[1].axhline(pb_fpd_dict['gigaref_test:gigaref_train'], color='gray')
_ = axs[1].legend(bbox_to_anchor=(1.1, 1.))
_ = fig.savefig('/home/kevyan/generations/protbert_fpd.png', dpi=300, bbox_inches='tight')
# Inspect the cooldown model's ProtBert FD values (REPL-style bare expression;
# has no effect when this file is run as a script).
df[df['name'] == '3b-msa-uniref90-cooldown'][['direction', 'temperature', 'protbert_fd_to_uniref']]
Loading