In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [34]:
import sys
sys.path.append('../pipeline/src')

import pandas as pd
import numpy as np
from datamodules.paired_dataset import PairedDataset
from datamodules.paired_datamodule import PairedProteinDataModule, PairedProteinDataModuleConfig
from datamodules.datamodule import ProteinDataModuleConfig, SamplingConfig, ProteinDataModule
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt










In [35]:
# Input file locations, relative to this notebook: the three dataset split
# files plus the similarity and cluster tables from the pipeline's mmseqs
# output directory. Nothing is loaded here; the paths are consumed by the
# configuration cell below.
train_dataset_path = '../pipeline/output/datasets/train.pt'
val_dataset_path = '../pipeline/output/datasets/val.pt'
test_dataset_path = '../pipeline/output/datasets/test.pt'
similarity_file_path = '../pipeline/output/mmseqs/all_seqs_similarity.tsv'
cluster_path = '../pipeline/output/mmseqs/all_seqs_clust.tsv'


In [None]:

# Sampler configuration; keyword arguments are grouped by the kind of setting
# rather than by split.
sampling_conf = SamplingConfig(
    cluster_path=cluster_path,
    # n_pdb value for each split
    train_n_pdb=1.0,
    val_n_pdb=1.0,
    test_n_pdb=1.0,
    # max chains per PDB for each split
    train_max_chains_per_pdb=1,
    val_max_chains_per_pdb=1,
    test_max_chains_per_pdb=1,
    # train/val and train/test cluster-intersection flags
    train_intersect_val_clusters=True,
    train_intersect_test_clusters=True,
)

# Paired datamodule configuration: identity, input files, tokenization,
# loading options, and the similarity-pairing settings.
paired_datamodule_config = PairedProteinDataModuleConfig(
    _target_="pipeline.src.datamodules.paired_datamodule.PairedProteinDataModuleConfig",
    name="paired_datamodule",
    # input files
    train_dataset_path=train_dataset_path,
    val_dataset_path=val_dataset_path,
    test_dataset_path=test_dataset_path,
    similarity_file_path=similarity_file_path,
    # tokenizer and sequence handling
    tokenizer_name="facebook/esm2_t33_650M_UR50D",
    max_seq_length=1024,
    contact_threshold=8.0,
    # loading
    batch_size=1,
    num_workers=1,
    # sampling and pairing
    sampler=sampling_conf,
    min_similarity_threshold=0.3,
)

# Build the datamodule from the config and run its setup step so the
# dataloaders/datasets used in the following cells are available.
paired_datamodule = PairedProteinDataModule(paired_datamodule_config)
paired_datamodule.setup()




In [None]:
# Smoke test: draw a handful of batches from the training dataloader.
# The loop breaks at the first index greater than 10, so 12 batches
# (indices 0-11) are drawn, and `batch` stays bound to the last one for the
# inspection cells that follow.
# tqdm wraps the dataloader itself rather than the enumerate object, so the
# progress bar can show a total when the dataloader defines a length.
for i, batch in enumerate(tqdm(paired_datamodule.train_dataloader())):
    print(f"OUR BATCH {i}")
    if i > 10:
        break

In [None]:
# Show the 'metadata' entry of the last batch drawn by the loop above.
# NOTE(review): the next cell reads metadata nested under 'primary_sequence' and
# 'similar_sequence'; confirm that paired batches also expose a top-level
# 'metadata' key, otherwise this lookup raises KeyError.
batch['metadata']

In [None]:
# Display, as one tuple, the metadata of both members of the pair and the
# similarity score stored in the last batch.
(
    batch['primary_sequence']['metadata'],
    batch['similar_sequence']['metadata'],
    batch['similarity_score'],
)

In [None]:
# Load the similarity table through the datamodule's own loader, from the same
# file that was passed to the datamodule config, for cross-checking a batch below.
df_similarity = paired_datamodule.load_similarity_df(similarity_file_path)

In [None]:
# Rows of the similarity table whose "query_id" equals the id of the first
# primary sequence in the last batch.
# NOTE(review): presumably used to check the batch's pairing and similarity score
# against the raw table — confirm the id format matches the "query_id" column.
df_similarity[df_similarity["query_id"] == batch["primary_sequence"]['metadata'][0]['id']]

In [25]:
def _similarity_scores(dataset):
    """Return the 'similarity_score' of every item yielded by iterating `dataset`."""
    return [item['similarity_score'] for item in dataset]


# One full pass over each split's dataset (train, then test, then val).
train_similarities = _similarity_scores(paired_datamodule.train_dataset)
test_similarities = _similarity_scores(paired_datamodule.test_dataset)
val_similarities = _similarity_scores(paired_datamodule.val_dataset)

In [None]:
# Draw one histogram of similarity scores per split, each as its own figure,
# in the order train, test, val.
for title, scores in (
    ("Train similarities", train_similarities),
    ("Test similarities", test_similarities),
    ("Val similarities", val_similarities),
):
    sns.histplot(scores)
    plt.title(title)
    plt.show()