# In [80]:
# Retrieving contact maps from PINDER

from pathlib import Path
from torch.utils.data import DataLoader
from pinder.core import PinderLoader
from pinder.core.index.system import PinderSystem
from pinder.core.loader import structure
from pinder.core import get_pinder_location
from torch_geometric.nn import radius as georadius
from tqdm import tqdm
from pinder.core.loader import filters
from biotite.structure import sasa, apply_residue_wise
import numpy as np
from multiprocessing import Pool
from tqdm import tqdm
import os
import random
import pandas as pd
import torch

# Cluster-specific data locations: the local PINDER mirror and the directory
# holding the "<split>.txt" pair lists read by get_atom_data.
PINDER_BASE_DIR = "/scicore/home/schwede/durair0000/.local/share/"
SPLITS_DIR = "/scicore/home/schwede/pudziu0000/projects/gLM/data/PINDER/eubacteria_5_1024_512/"

# Export the base dir for pinder. NOTE(review): pinder is imported above, so
# this only takes effect if pinder reads the variable lazily — confirm.
os.environ.update({"PINDER_BASE_DIR": PINDER_BASE_DIR})
# Return value is discarded; presumably called to resolve/validate the
# location up front — verify against pinder's get_pinder_location.
get_pinder_location()

def get_data(
    system_id,
    interface_threshold=8.0,
):
    """Return interface contacts and per-residue SASA for one PINDER system.

    Loads the system's native receptor (``native_R``) and ligand
    (``native_L``) structures. It computes per-residue solvent-accessible
    surface area for each and finds every receptor/ligand CA pair lying
    within ``interface_threshold`` of each other.

    Args:
        system_id: PINDER system identifier passed to ``PinderSystem``.
        interface_threshold: Search radius for the CA-CA neighbour query, in
            the units of the structure coordinates.

    Returns:
        A dict with keys ``id``, ``pos_r``, ``pos_l``, ``seq_r``, ``seq_l``,
        ``sasa_r`` and ``sasa_l``. ``pos_r[k]`` and ``pos_l[k]`` are the
        indices (into the receptor and ligand CA arrays) of the k-th
        contacting pair.

        Returns ``None`` when:
        - the system cannot be loaded;
        - either chain's CA count differs from its sequence length, so CA
          indices would not line up with residues;
        - SASA computation fails.
    """
    # A constructor call never yields None, so the previous `system is None`
    # test was dead code. Load failures surface as exceptions; map them to
    # this function's "None means no data" result instead of aborting the
    # caller's loop.
    try:
        system = PinderSystem(system_id)
        receptor, ligand = system.native_R, system.native_L
    except Exception as e:
        print(e)
        return None

    seq_r, seq_l = receptor.sequence, ligand.sequence
    r_coords = receptor.filter("atom_name", mask=["CA"]).coords
    l_coords = ligand.filter("atom_name", mask=["CA"]).coords
    # Contact indices refer to CA atoms. They only line up with residues (and
    # with the per-residue SASA below) when the CA count equals the sequence
    # length. Checked before SASA so rejected systems skip that work.
    if r_coords.shape[0] != len(seq_r) or l_coords.shape[0] != len(seq_l):
        return None

    # sasa() can raise on structures it cannot handle; treat that as no data.
    try:
        sasa_r = apply_residue_wise(
            receptor.atom_array, sasa(receptor.atom_array), np.nansum
        )
        sasa_l = apply_residue_wise(
            ligand.atom_array, sasa(ligand.atom_array), np.nansum
        )
    except Exception as e:
        print(e)
        return None

    # radius(x, y) returns a [2, E] index tensor. Row 0 indexes y (here the
    # ligand CAs) and row 1 indexes x (the receptor CAs), hence the unpack
    # order. max_num_neighbors is set high so the per-point cap is
    # effectively never hit.
    pos_l, pos_r = georadius(
        torch.tensor(r_coords),
        torch.tensor(l_coords),
        r=interface_threshold,
        max_num_neighbors=10000,
    )
    return {
        "id": system_id,
        "pos_r": pos_r,
        "pos_l": pos_l,
        "seq_r": seq_r,
        "seq_l": seq_l,
        "sasa_r": sasa_r,
        "sasa_l": sasa_l,
    }

def get_atom_data(split, interface_threshold=8.0, max_pairs=11):
    """Compute CA-level interface contacts and per-residue SASA for a split.

    Reads ``SPLITS_DIR/<split>.txt`` as a tab-separated table with columns
    ``protein1``, ``protein2`` and ``label``. Rows with ``label == 0``
    (negatives) are dropped. For each remaining pair it loads
    ``<protein1>-R.pdb`` and ``<protein2>-L.pdb`` from the local PINDER
    2024-02 PDB directory.

    Args:
        split: Split name; selects ``<split>.txt`` inside ``SPLITS_DIR``.
        interface_threshold: Search radius for the receptor/ligand CA
            neighbour query, in the units of the structure coordinates.
        max_pairs: Maximum number of positive pairs to attempt. Pairs that
            are later skipped still count toward the cap. The default of 11
            reproduces the debugging cap this function used to hard-code;
            pass ``None`` to process the whole split.

    Returns:
        A list of dicts with keys:
        - ``id``: ``"<protein1>--<protein2>"``;
        - ``pos_r`` / ``pos_l``: paired CA index tensors;
        - ``sasa_r`` / ``sasa_l``: per-residue SASA arrays.

        A pair is skipped, with a printed message, when:
        - its structures fail to load;
        - SASA computation fails;
        - either chain's CA count differs from its sequence length.
    """
    df = pd.read_csv(f"{SPLITS_DIR}/{split}.txt", sep="\t")
    # Only positive (interacting) pairs have an interface to extract.
    positives = df.loc[df["label"] != 0, ["protein1", "protein2"]]
    pairs = list(positives.itertuples(index=False, name=None))
    if max_pairs is not None:
        pairs = pairs[:max_pairs]

    pdb_dir = f"{PINDER_BASE_DIR}/pinder/2024-02/pdbs"
    atom_data = []
    for id_r, id_l in pairs:
        # Loading and SASA are both guarded: sasa() can raise (get_data
        # guards it the same way), and one bad structure must not abort the
        # whole split.
        try:
            struct_r = structure.Structure(f"{pdb_dir}/{id_r}-R.pdb", pinder_id=id_r)
            struct_l = structure.Structure(f"{pdb_dir}/{id_l}-L.pdb", pinder_id=id_l)
            sasa_r = apply_residue_wise(
                struct_r.atom_array, sasa(struct_r.atom_array), np.nansum
            )
            sasa_l = apply_residue_wise(
                struct_l.atom_array, sasa(struct_l.atom_array), np.nansum
            )
        except Exception as e:
            print(e)
            continue

        r_coords = struct_r.filter("atom_name", mask=["CA"]).coords
        l_coords = struct_l.filter("atom_name", mask=["CA"]).coords
        # Contact indices refer to CA atoms, while SASA is per residue. They
        # only line up when the CA count equals the sequence length (same
        # check as get_data).
        if (
            r_coords.shape[0] != len(struct_r.sequence)
            or l_coords.shape[0] != len(struct_l.sequence)
        ):
            print(f"{id_r}--{id_l}: CA count does not match sequence length, skipped")
            continue

        # radius(x, y) returns [2, E] indices: row 0 into y (ligand CAs),
        # row 1 into x (receptor CAs).
        pos_l, pos_r = georadius(
            torch.tensor(r_coords),
            torch.tensor(l_coords),
            r=interface_threshold,
            max_num_neighbors=10000,
        )
        atom_data.append(
            {
                "id": f"{id_r}--{id_l}",
                "pos_r": pos_r,
                "pos_l": pos_l,
                "sasa_r": sasa_r,
                "sasa_l": sasa_l,
            }
        )

    return atom_data
        
            

# In [81]:
# Interface contacts + per-residue SASA for the positive pairs of the "test"
# split (get_atom_data only processes the first 11 positive pairs by default).
data = get_atom_data(split="test")