<a href="https://colab.research.google.com/github/kattens/SASA-Calculation-For-LLMs/blob/main/Part_1_Dataframe_Creation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install necessary libraries
# NOTE(review): none of these installs is version-pinned, so a fresh runtime may resolve different versions.
!pip install faiss-gpu
# NOTE(review): the recorded output of this cell collects nvidia-*-cu12 (CUDA 12.1) packages for torch,
# so the cu113 extra index below may not be the one actually supplying torch — confirm.
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
!pip install torch-geometric

# Define the PyTorch and CUDA version for compatibility
import torch
TORCH_VERSION = torch.__version__.split('+')[0]  # get the PyTorch version
# NOTE(review): presumably torch.version.cuda is None on a CPU-only build, in which case .replace() would fail — confirm.
CUDA_VERSION = torch.version.cuda.replace('.', '')  # format the CUDA version correctly

# Install PyTorch Geometric dependencies
# The wheel index URL is built from the torch/CUDA versions detected above.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH_VERSION}+cu{CUDA_VERSION}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{TORCH_VERSION}+cu{CUDA_VERSION}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{TORCH_VERSION}+cu{CUDA_VERSION}.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH_VERSION}+cu{CUDA_VERSION}.html


Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu113
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidi

In [None]:
# NOTE(review): the recorded output of this cell is "No matching distribution found for pymol-pyqt",
# i.e. this install fails; no later cell in this notebook imports pymol.
!pip install pymol-pyqt

[31mERROR: Could not find a version that satisfies the requirement pymol-pyqt (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for pymol-pyqt[0m[31m
[0m

# Dataset Overview

## Base Dataset:
Our dataset comprises a folder of PDB files, each separated by their chains and renamed to follow the format: `protein_name_chain_ID`.

## Purpose:
We aim to construct a dataset containing global sequences, local sequences, and C-alpha coordinates of the sequences.

### Global Sequences:
These are the unaltered protein sequences directly extracted from the corresponding PDB files.

### Local Sequences:
These are global sequences in which every residue is masked except those that lie close to a residue of another chain in the same PDB file. The close residues are identified with a FAISS nearest-neighbour search over the C-alpha coordinates.

### Coordinates:
We will extract only the C-alpha coordinates of each amino acid. To turn these coordinates into embeddings, we will build geometric graphs from them using PyTorch Geometric.

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab only); the batch CSV read below lives under this mount point.
drive.mount('/content/drive')

In [None]:
import numpy as np
import pandas as pd
import torch

# Define the PyTorch and CUDA version for compatibility
# NOTE(review): these two constants repeat the ones computed in the install cell and are not read again in this cell.
TORCH_VERSION = torch.__version__.split('+')[0]  # get the PyTorch version
CUDA_VERSION = torch.version.cuda.replace('.', '')  # format the CUDA version correctly

# Load one batch of chains from the mounted Drive and drop every row that has a missing value in any column.
csv_file = '/content/drive/MyDrive/batches/Batch1.csv'
df = pd.read_csv(csv_file).dropna()


In [None]:
import pandas as pd
import ast

# Function to parse coordinates from the string representation of tuples
def parse_coordinates(input_str):
    """Parse a ';'-separated string of coordinate tuples into a list of float tuples.

    Example: "(1, 2, 3); (4.5, 5, 6)" -> [(1.0, 2.0, 3.0), (4.5, 5.0, 6.0)]

    Args:
        input_str: String such as "(x, y, z); (x, y, z); ...". Empty segments
            (for instance after a trailing ';') are ignored.

    Returns:
        A list of tuples of floats, or None when the input is not a string,
        contains no coordinates, or cannot be parsed. A message is printed
        in every None case so the failure stays visible.
    """
    if not isinstance(input_str, str):
        print(f"Error in parsing coordinates: expected a string, got {type(input_str).__name__}")
        return None

    try:
        parsed_coords = []
        for tup_str in input_str.split(";"):
            tup_str = tup_str.strip()
            if not tup_str:
                # A trailing or doubled ';' yields an empty segment; literal_eval would raise on it.
                continue
            # literal_eval only evaluates Python literals, so no arbitrary code is executed.
            parsed_coords.append(tuple(float(val) for val in ast.literal_eval(tup_str)))
    except (ValueError, SyntaxError, TypeError) as e:
        # ValueError: non-literal or non-numeric content; SyntaxError: truncated/unbalanced tuple;
        # TypeError: the segment evaluated to something that is not a sequence of numbers.
        print(f"Error in parsing coordinates: {e}")
        return None

    if not parsed_coords:
        print("Error in parsing coordinates: no coordinates found")
        return None
    return parsed_coords

# Parse the raw 'Calpha Coordinates' strings row by row and store the result in a new
# 'Parsed Coordinates' column (the raw column is left untouched). Rows for which
# parse_coordinates reports a parsing error hold None.
df['Parsed Coordinates'] = df['Calpha Coordinates'].apply(parse_coordinates)


In [None]:
# Vocabulary mapping special tokens and one-letter residue codes to integer ids.
# Ids 30 and 31 are not assigned; the mask/gap character "-" uses id 32.
amino_acid_tokens = {
    # special tokens
    "[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4,
    # residue letters
    "L": 5, "A": 6, "G": 7, "V": 8, "E": 9,
    "S": 10, "I": 11, "K": 12, "R": 13, "D": 14,
    "T": 15, "P": 16, "N": 17, "Q": 18, "F": 19,
    "Y": 20, "M": 21, "H": 22, "C": 23, "W": 24,
    "X": 25, "U": 26, "B": 27, "Z": 28, "O": 29,
    # character used for masked positions
    "-": 32,
}


def tokenize_sequence(sequence, tokenizer=amino_acid_tokens):
    """Convert a protein sequence into a list of integer token ids.

    Args:
        sequence: Iterable of single-character residue codes (typically a string).
        tokenizer: Mapping from character to id; must contain the key '[UNK]'
            whenever the sequence is non-empty.

    Returns:
        A list with one id per character; characters missing from the mapping
        receive the '[UNK]' id.
    """
    token_ids = []
    for residue in sequence:
        token_ids.append(tokenizer.get(residue, tokenizer['[UNK]']))
    return token_ids

# Tokenize every protein sequence in the DataFrame; characters that are not in the
# vocabulary (e.g. the '?' seen in some sequences) are mapped to the [UNK] id.
df['tokenized_sequence'] = df['Sequence'].apply(tokenize_sequence)

In [None]:
# Pool the parsed coordinates of all rows (rows whose parsing failed are skipped) into one
# array with one row per coordinate tuple, then take the column-wise mean and standard
# deviation; these dataset-wide statistics are used by normalize_coordinates below.
coords_array = np.array([coord for coords in df['Parsed Coordinates'].dropna() for coord in coords])
mean_values, std_values = np.mean(coords_array, axis=0), np.std(coords_array, axis=0)

def normalize_coordinates(coords_tuple):
    """Standardize one chain's coordinates with the dataset-wide mean and std.

    Args:
        coords_tuple: Sequence of coordinate tuples for one chain, or None.

    Returns:
        A list of tuples holding (value - mean_values) / std_values for each
        coordinate, or None when the input is None (missing/unparseable row).
    """
    if coords_tuple is None:
        return None
    # mean_values / std_values are the module-level statistics computed over all chains.
    z_scores = (np.array(coords_tuple) - mean_values) / std_values
    return [tuple(point) for point in z_scores]

# Standardize every row's parsed coordinates; rows whose 'Parsed Coordinates' is None stay None.

df['Normalized Coordinates'] = df['Parsed Coordinates'].apply(normalize_coordinates)


In [None]:
df

Unnamed: 0,Protein Name,Sequence,Calpha Coordinates,Parsed Coordinates,tokenized_sequence,Normalized Coordinates
0,3ZK8_A,KLKVVATNSIIADITKNIAGDKIDLHSIVPIGQDPHEYEPLPEDVK...,"(15.055, -28.894, 35.620); (12.377, -27.477, 3...","[(15.055, -28.894, 35.62), (12.377, -27.477, 3...","[12, 5, 12, 8, 8, 6, 15, 17, 10, 11, 11, 6, 14...","[(-0.025218022009773956, -0.6347698991412715, ..."
2,4YT3_B,VIAVKEITRFKTRTEEFSPYAWCKRMLENDPVSYHEGTDTWNVFKY...,"(-38.525, -18.057, 36.586); (-40.400, -19.890,...","[(-38.525, -18.057, 36.586), (-40.4, -19.89, 3...","[8, 11, 6, 8, 12, 9, 11, 15, 13, 19, 12, 15, 1...","[(-1.2066447358287875, -0.4811509148629433, 0...."
3,4A9N_B,PGRVTNQLQYLHKVVMKALWKHQFAWPFRQPVDAVKLGLPDYHKII...,"(32.482, 26.164, 21.833); (30.940, 22.843, 22....","[(32.482, 26.164, 21.833), (30.94, 22.843, 22....","[16, 7, 13, 8, 15, 17, 18, 5, 18, 20, 5, 22, 1...","[(0.35904333188578313, 0.1457001605058806, -0...."
4,5AQI_A,MSKGPAVGIDLGTTYSCVGVFQHGKVEIIANDQGNRTTPSYVAFTD...,"(-51.578, 1.626, 47.056); (-53.863, 1.310, 43....","[(-51.578, 1.626, 47.056), (-53.863, 1.31, 43....","[21, 10, 12, 7, 16, 6, 8, 7, 11, 14, 5, 7, 15,...","[(-1.4944603927059728, -0.20213619976187008, 0..."
6,4AY3_B,KSLVGVIMGSTSDWETMKYACDILDELNIPYEKKVVSAHRTPDYMF...,"(-24.139, -39.386, 35.906); (-21.446, -36.827,...","[(-24.139, -39.386, 35.906), (-21.446, -36.827...","[12, 10, 5, 8, 7, 8, 11, 21, 7, 10, 15, 10, 14...","[(-0.8894367346156422, -0.7834983648650159, 0...."
...,...,...,...,...,...,...
10895,6IQT_C,EAVVDSATSKFVSLLFGYSKNSLRDRKDQL?QYCDVSFQTQA?R?F...,"(38.831, -13.032, -42.515); (38.655, -16.261, ...","[(38.831, -13.032, -42.515), (38.655, -16.261,...","[9, 6, 8, 8, 14, 10, 6, 15, 10, 12, 19, 8, 10,...","[(0.4990373260260766, -0.4099194489592248, -1...."
10896,4Y0I_C,FDYDGPLMKTEVPGPRSRELMKQLNIIQNAEAVHFFCNYEESRGNY...,"(52.983, 48.156, 73.118); (49.196, 47.933, 72....","[(52.983, 48.156, 73.118), (49.196, 47.933, 72...","[19, 14, 20, 14, 7, 16, 5, 21, 12, 15, 9, 8, 1...","[(0.8110856808966752, 0.4574459113824138, 0.84..."
10897,6GPZ_B,GHMKVKLSAKEILEKEFKTGVRGYKQEDVDEFLDMIIKDYETFHQE...,"(18.932, 9.545, 8.554); (16.785, 6.556, 9.548)...","[(18.932, 9.545, 8.554), (16.785, 6.556, 9.548...","[7, 22, 21, 12, 8, 12, 5, 10, 6, 12, 9, 11, 5,...","[(0.06026893897335995, -0.08988107966404967, -..."
10898,3WAN_A,DDFVMIGSPSDRPFKQRRSFADRCKEVQQIRDQHPSKIPVIIERYK...,"(28.951, -6.400, -9.956); (25.263, -7.298, -9....","[(28.951, -6.4, -9.956), (25.263, -7.298, -9.8...","[14, 14, 19, 8, 21, 11, 7, 10, 16, 10, 14, 13,...","[(0.281185591562996, -0.3159080893824165, -0.6..."


In [None]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import radius_graph

# Adjust the radius for graph construction based on your specific requirements
# NOTE(review): build_graphs is called on the *normalized* coordinates, so this radius is in units of
# standard deviations, not in the units of the original coordinates — confirm that 10 is intended.
RADIUS = 10

def build_graphs(normalized_coords_list):
    """Build one radius graph per chain from its coordinate list.

    Args:
        normalized_coords_list: List with one entry per row; each entry is a
            sequence of coordinate tuples or None.

    Returns:
        A list of the same length. Entry i is a torch_geometric ``Data`` object
        (node features = coordinates, edges from ``radius_graph`` with
        ``r=RADIUS`` and no self-loops), or None when the row has no
        coordinates or fewer than two nodes.
    """
    graphs = [None] * len(normalized_coords_list)  # keeps row alignment for skipped entries
    for position, coords in enumerate(normalized_coords_list):
        if coords is None:
            continue
        points = torch.tensor(coords, dtype=torch.float)
        if not points.size(0) > 1:
            # A single node cannot form any edge, so the placeholder None is kept.
            continue
        edges = radius_graph(points, r=RADIUS, loop=False)
        graphs[position] = Data(x=points, edge_index=edges)
    return graphs

# Placeholder function for generating embeddings from graphs
# Replace with your actual model's forward pass for embedding generation
def generate_embeddings(graphs, model=None, embedding_dim=128):
    """Return one placeholder embedding vector per graph.

    This is a stand-in for a real GNN forward pass: ``model`` is accepted but
    not used, and the values for existing graphs are random (unseeded), so they
    carry no information and differ between runs.

    Args:
        graphs: Iterable of graph objects or None (None marks a missing graph).
        model: Unused; reserved for the future embedding model.
        embedding_dim: Length of each embedding vector (default 128).

    Returns:
        A list of 1-D float tensors of length ``embedding_dim``: random normal
        values for every non-None graph, all zeros for every None entry.
    """
    return [
        torch.randn(embedding_dim) if graph is not None else torch.zeros(embedding_dim)
        for graph in graphs
    ]

# Build one graph per row from the already-normalized coordinates
normalized_coords_list = df['Normalized Coordinates'].tolist()
graphs = build_graphs(normalized_coords_list)
# Placeholder for the GNN: generate_embeddings currently returns random / zero vectors, not learned embeddings
embeddings = generate_embeddings(graphs)

# Add embeddings back to the DataFrame. Both helper functions return one entry per input row,
# so the list has the same length as df.
df['Embeddings'] = embeddings


In [None]:
df

In [None]:
import numpy as np
import faiss
import pandas as pd
from itertools import combinations

# Function to calculate pairs within a specified distance using FAISS
def calculate_pairs_within_distance(coords1, coords2, max_dist=6):
    """Find the points of ``coords2`` that lie within ``max_dist`` of points of ``coords1``.

    ``coords2`` is loaded into a FAISS flat L2 index and every point of
    ``coords1`` is used as a query against it.

    Args:
        coords1: Query coordinates (sequence of 3-D points) or None.
        coords2: Indexed coordinates (sequence of 3-D points) or None.
        max_dist: Distance threshold, in the same units as the coordinates.

    Returns:
        A flat list of integer positions **into coords2**, one entry per
        (query, neighbour) hit, grouped by query in query order. A position
        appears once for every query point it is close to. Returns [] when
        either input is None or empty.
    """
    if coords1 is None or coords2 is None:
        return []
    if len(coords1) == 0 or len(coords2) == 0:
        # Nothing to compare; this also avoids a FAISS search with k == 0.
        return []

    coords1_np = np.array(coords1).reshape(-1, 3).astype('float32')
    coords2_np = np.array(coords2).reshape(-1, 3).astype('float32')

    # Initialize FAISS index and perform the search
    index = faiss.IndexFlatL2(3)  # L2 distance for Euclidean
    index.add(coords2_np)
    # k = len(coords2) so that every indexed point is returned for every query.
    distances, indices = index.search(coords1_np, len(coords2_np))

    # IndexFlatL2 reports squared distances, so the threshold is squared as well.
    within_range = distances <= max_dist ** 2
    # Boolean indexing walks the 2-D result row by row, i.e. query by query.
    return indices[within_range].tolist()

# Function to create a masked sequence where only close amino acids are visible
def mask_sequence(sequence, indices):
    """Return ``sequence`` with every position replaced by '-' except those in ``indices``.

    Args:
        sequence: The residue string to mask.
        indices: Iterable of integer positions to keep visible; positions
            outside ``0 <= idx < len(sequence)`` are ignored.

    Returns:
        A string of the same length as ``sequence``.
    """
    length = len(sequence)
    residues = ['-'] * length  # everything hidden until proven close
    for position in indices:
        if not (0 <= position < length):
            continue
        residues[position] = sequence[position]
    return ''.join(residues)

# Chains of the same PDB entry share the first four characters of their name (e.g. 3VUA_A / 3VUA_B).
df['pair_id'] = df['Protein Name'].apply(lambda x: x[:4])  # Simplified example for pair_id

pairs_list = []

# Generate pairwise comparisons for proteins within the same group
for pair_id, group in df.groupby('pair_id'):
    if len(group) > 1:
        for (idx1, row1), (idx2, row2) in combinations(group.iterrows(), 2):
            # calculate_pairs_within_distance(query, indexed) returns positions into the
            # *indexed* (second) argument. To unmask chain A we therefore need chain A as
            # the indexed set, and likewise for chain B. (The previous version applied
            # chain-B positions to sequence A and vice versa.)
            close_in_a = calculate_pairs_within_distance(row2['Parsed Coordinates'], row1['Parsed Coordinates'])
            close_in_b = calculate_pairs_within_distance(row1['Parsed Coordinates'], row2['Parsed Coordinates'])

            # Keep only chain pairs that actually have residues in contact.
            if close_in_a and close_in_b:
                masked_seq_1 = mask_sequence(row1['Sequence'], close_in_a)
                masked_seq_2 = mask_sequence(row2['Sequence'], close_in_b)

                pairs_list.append({
                    'pair_id': pair_id,
                    'Protein Name A': row1['Protein Name'],
                    'Protein Name B': row2['Protein Name'],
                    'masked_sequence_A': masked_seq_1,
                    'masked_sequence_B': masked_seq_2,
                    'coords_A': row1['Parsed Coordinates'],
                    'coords_B': row2['Parsed Coordinates'],
                })

# Convert the list of pairs into a DataFrame
pairs_df = pd.DataFrame(pairs_list)
print(pairs_df)


     pair_id Protein Name A Protein Name B  \
0       3VU0         3VU0_B         3VU0_C   
1       3VU9         3VU9_A         3VU9_B   
2       3VUA         3VUA_A         3VUA_D   
3       3VUA         3VUA_A         3VUA_B   
4       3VUA         3VUA_F         3VUA_C   
...      ...            ...            ...   
7573    6K8Z         6K8Z_A         6K8Z_B   
7574    6K90         6K90_A         6K90_B   
7575    6K92         6K92_A         6K92_B   
7576    6K9S         6K9S_A         6K9S_B   
7577    6K9Z         6K9Z_B         6K9Z_A   

                                      masked_sequence_A  \
0     ----------------------------------------------...   
1     ----------------------------------------------...   
2     ---------------------V--G---------------------...   
3     -------------ES-------------------------------...   
4     ----------------------------------------------...   
...                                                 ...   
7573  -----------VT---------------

In [None]:
pairs_df

Unnamed: 0,pair_id,Protein Name A,Protein Name B,masked_sequence_A,masked_sequence_B,coords_A,coords_B
0,3VU0,3VU0_B,3VU0_C,----------------------------------------------...,-------------------------V--W-----------------...,"[(58.138, 7.448, -53.606), (61.331, 9.196, -54...","[(41.128, -11.483, -45.678), (39.232, -14.595,..."
1,3VU9,3VU9_A,3VU9_B,----------------------------------------------...,---------------------------------------------L...,"[(45.065, -10.931, 82.73), (41.902, -8.723, 82...","[(41.184, 10.935, 14.492), (40.651, 11.419, 18..."
2,3VUA,3VUA_A,3VUA_D,---------------------V--G---------------------...,----------------------------------------------...,"[(4.049, -42.702, -2.177), (1.971, -43.235, 1....","[(19.162, -69.14, 35.277), (19.905, -66.898, 3..."
3,3VUA,3VUA_A,3VUA_B,-------------ES-------------------------------...,-------------ES-------------------------------...,"[(4.049, -42.702, -2.177), (1.971, -43.235, 1....","[(-11.677, 11.253, 37.887), (-14.425, 11.009, ..."
4,3VUA,3VUA_F,3VUA_C,----------------------------------------------...,------------------------G---------------------...,"[(-65.647, -52.657, 41.232), (-64.09, -54.007,...","[(-38.81, -56.242, -0.167), (-38.24, -53.36, 2..."
...,...,...,...,...,...,...,...
7573,6K8Z,6K8Z_A,6K8Z_B,-----------VT-------------------VDA-----L-----...,------------HG------------L--F---AIN-----N----...,"[(5.212, 10.332, 8.132), (6.741, 9.036, 11.435...","[(-27.838, -3.263, 18.535), (-24.763, -1.591, ..."
7574,6K90,6K90_A,6K90_B,-----------TH------------Q--G---DAI-----S-----...,-----------TH------------Q------DAI-----S-----...,"[(-0.897, -1.881, 17.933), (2.606, -0.798, 16....","[(33.904, 10.022, 10.971), (30.532, 9.244, 12...."
7575,6K92,6K92_A,6K92_B,-----------TH-------------------DAI-----S-----...,-----------TH------------Q------DAI-----S-----...,"[(14.977, 14.449, 27.52), (17.95, 13.315, 25.2...","[(49.379, 1.79, 20.031), (45.993, 3.173, 21.38..."
7576,6K9S,6K9S_A,6K9S_B,----------------------------------------------...,----------------------------------------------...,"[(10.858, 24.115, -23.296), (12.813, 26.693, -...","[(-0.335, 24.559, 99.336), (-1.587, 26.756, 96..."


In [None]:
row = pairs_df.iloc[7438]

# Print each column name and its value
for column, value in row.items():
    print(f"{column}: {value}")


pair_id: 6JY3
Protein Name A: 6JY3_D
Protein Name B: 6JY3_I
masked_sequence_A: ------------------------------------------------------------M---------------------------------------------------------------------------
masked_sequence_B: ----------------------------------------------------------------------
coords_A: [(28.447, 63.655, 39.832), (29.604, 60.323, 41.242), (26.365, 59.794, 43.28), (23.353, 57.711, 42.236), (20.885, 60.037, 40.529), (17.284, 60.363, 39.407), (16.059, 59.883, 35.828), (12.656, 59.739, 34.114), (13.104, 57.059, 31.45), (10.368, 54.708, 30.394), (12.7, 51.762, 30.246), (14.042, 52.523, 33.737), (10.717, 53.022, 35.597), (12.208, 52.683, 39.101), (15.278, 54.96, 39.344), (18.173, 54.6, 41.726), (17.04, 57.311, 44.199), (13.815, 59.037, 45.089), (12.935, 62.335, 43.483), (12.671, 64.066, 46.897), (15.598, 63.291, 49.155), (14.886, 65.597, 52.156), (11.444, 65.276, 53.73), (9.554, 67.86, 55.833), (7.631, 67.093, 59.013), (4.404, 65.95, 57.388), (6.13, 64.144, 54.519

In [None]:
import numpy as np
import faiss
import pandas as pd
from itertools import combinations

def calculate_unique_pairs_within_distance(coords1, coords2, min_dist=6, max_dist=8):
    """Greedily pair points of ``coords1`` with points of ``coords2`` in a distance band.

    Each point of ``coords1`` (in order) is matched with the first entry in its
    FAISS result list whose distance lies in [min_dist, max_dist] and which has
    not already been matched. Every point of either set appears in at most one
    pair.

    Args:
        coords1: Coordinates of chain A (sequence of 3-D points); falsy -> [].
        coords2: Coordinates of chain B (sequence of 3-D points); falsy -> [].
        min_dist: Lower distance bound (inclusive).
        max_dist: Upper distance bound (inclusive).

    Returns:
        List of ``(index_in_coords1, index_in_coords2)`` tuples.
    """
    if not coords1 or not coords2:
        return []

    queries = np.array(coords1, dtype=np.float32)
    targets = np.array(coords2, dtype=np.float32)

    # Index chain B and query it with every point of chain A, asking for all of B each time.
    index = faiss.IndexFlatL2(3)
    index.add(targets)
    sq_distances, neighbours = index.search(queries, targets.shape[0])

    # The index reports squared L2 distances, so compare against squared bounds.
    lower_sq = min_dist ** 2
    upper_sq = max_dist ** 2

    matched = []
    taken_b = set()
    for a_idx, (row_dists, row_ids) in enumerate(zip(sq_distances, neighbours)):
        for sq_dist, b_idx in zip(row_dists, row_ids):
            in_band = lower_sq <= sq_dist <= upper_sq
            if not in_band or b_idx in taken_b:
                continue
            matched.append((a_idx, b_idx))
            taken_b.add(b_idx)
            break  # each point of chain A is visited once, so it is paired at most once

    return matched

def mask_sequence(sequence, pairs, index_position):
    """Mask ``sequence`` with '-' except at the positions named by ``pairs``.

    Args:
        sequence: The residue string to mask.
        pairs: Iterable of index tuples, e.g. ``(index_in_A, index_in_B)``.
        index_position: Which element of each tuple refers to this sequence
            (0 for the first chain, 1 for the second).

    Returns:
        A string of the same length as ``sequence``. Indices that fall outside
        ``0 <= idx < len(sequence)`` are ignored instead of raising IndexError
        or wrapping around, matching the bounds check of the earlier
        two-argument version of this helper.
    """
    visible = {pair[index_position] for pair in pairs}
    return ''.join(
        residue if position in visible else '-'
        for position, residue in enumerate(sequence)
    )


# Chains of one PDB entry share the first four characters of their name; use that as the group key.
df['pair_id'] = df['Protein Name'].apply(lambda x: x[:4])  # Extract pair_id from Protein Name

pairs_list = []

# Compare every unordered pair of chains that belong to the same PDB entry.
for pair_id, group in df.groupby('pair_id'):
    if not group.shape[0] > 1:
        continue  # a single chain has no partner
    for (_, chain_a), (_, chain_b) in combinations(group.iterrows(), 2):
        unique_pairs = calculate_unique_pairs_within_distance(
            chain_a['Parsed Coordinates'], chain_b['Parsed Coordinates'])
        # Element 0 of each pair indexes chain A, element 1 indexes chain B.
        # A record is stored even when no residue pair was found (fully masked sequences).
        pairs_list.append({
            'pair_id': pair_id,
            'Protein Name A': chain_a['Protein Name'],
            'Protein Name B': chain_b['Protein Name'],
            'masked_sequence_A': mask_sequence(chain_a['Sequence'], unique_pairs, 0),
            'masked_sequence_B': mask_sequence(chain_b['Sequence'], unique_pairs, 1)
        })

# One row per chain pair.
pairs_df = pd.DataFrame(pairs_list)


In [None]:
# Attach the per-chain embeddings and full sequences to each side of every pair.
# The chain identifier column is 'Protein Name' (in df) and 'Protein Name A' / 'Protein Name B'
# (in pairs_df); the previous 'File Name' keys do not exist and raised a KeyError.
chain_features = df[['Protein Name', 'Embeddings', 'Sequence']]

pairs_df = (
    pairs_df
    # Chain A: adds 'Embeddings_A' and 'Sequence_A'
    .merge(chain_features.add_suffix('_A'),
           left_on='Protein Name A',
           right_on='Protein Name_A',
           how='left')
    # Chain B: adds 'Embeddings_B' and 'Sequence_B'
    .merge(chain_features.add_suffix('_B'),
           left_on='Protein Name B',
           right_on='Protein Name_B',
           how='left')
    # The suffixed key columns only duplicate 'Protein Name A' / 'Protein Name B'.
    .drop(columns=['Protein Name_A', 'Protein Name_B'])
)


KeyError: "['File Name'] not in index"

In [None]:
# The 'Embeddings' columns hold the coordinate (graph) embeddings
pairs_df.columns

In [None]:
# Assuming amino_acid_tokens and tokenize_sequence function are defined as previously
# Also requires the 'Sequence_A' / 'Sequence_B' columns that the merge cell is meant to add to pairs_df.

# Tokenize the original and masked sequences for both A and B
pairs_df['tokenized_sequence_A'] = pairs_df['Sequence_A'].apply(tokenize_sequence)
pairs_df['tokenized_sequence_B'] = pairs_df['Sequence_B'].apply(tokenize_sequence)
pairs_df['tokenized_masked_sequence_A'] = pairs_df['masked_sequence_A'].apply(tokenize_sequence)
pairs_df['tokenized_masked_sequence_B'] = pairs_df['masked_sequence_B'].apply(tokenize_sequence)

# Element-wise sum of a sequence's token ids and its masked counterpart's token ids
def sum_tokenized_sequences(seq_tokens, masked_tokens):
    """Add two token-id lists position by position.

    Args:
        seq_tokens: Token ids of the full sequence.
        masked_tokens: Token ids of the masked sequence.

    Returns:
        A list whose i-th entry is ``seq_tokens[i] + masked_tokens[i]``. When
        the inputs differ in length, the result stops at the shorter one.
    """
    summed = []
    for full_id, masked_id in zip(seq_tokens, masked_tokens):
        summed.append(full_id + masked_id)
    return summed

# Apply the summation function for both A and B sequences and their masked versions
# NOTE(review): the summed ids are not unique per (residue, mask) combination — e.g. 6 + 6 equals the id 12 —
# so confirm that downstream consumers do not need to recover the original tokens.
pairs_df['sum_tokenized_sequence_A'] = pairs_df.apply(lambda row: sum_tokenized_sequences(row['tokenized_sequence_A'], row['tokenized_masked_sequence_A']), axis=1)
pairs_df['sum_tokenized_sequence_B'] = pairs_df.apply(lambda row: sum_tokenized_sequences(row['tokenized_sequence_B'], row['tokenized_masked_sequence_B']), axis=1)


In [None]:
pairs_df['masked_sequence_A'][100]

In [None]:
pairs_df

In [None]:
# Write the pairs table to the current working directory, without the index column.
# NOTE(review): list- and tensor-valued columns will presumably be stored as their string
# representation — confirm the file can be read back in the form later steps need.
pairs_df.to_csv('Pairs_df.csv', index = False)

In [None]:
# Simulate the encoding of embeddings directly for demonstration purposes

# Use the second row of pairs_df as the worked example.
example_pair = pairs_df.iloc[1]

# Tokenized and summed sequences for protein A and B
input_ids_a_example = torch.tensor(example_pair['sum_tokenized_sequence_A'], dtype=torch.long)
input_ids_b_example = torch.tensor(example_pair['sum_tokenized_sequence_B'], dtype=torch.long)

# Embeddings for protein A and B. torch.as_tensor accepts both tensors and plain sequences and,
# unlike torch.tensor, does not emit the copy-construct warning when the value is already a tensor
# (it may share memory with the stored embedding; these two variables are only inspected, never modified).
embeddings_a_example = torch.as_tensor(example_pair['Embeddings_A'], dtype=torch.float)
embeddings_b_example = torch.as_tensor(example_pair['Embeddings_B'], dtype=torch.float)

# Placeholder "encoded embeddings": fixed id ranges, not derived from the embeddings above
encoded_embeddings_a_example = torch.arange(100, 105, dtype=torch.long)  # Simulated encoded embeddings for A
encoded_embeddings_b_example = torch.arange(105, 110, dtype=torch.long)  # Simulated encoded embeddings for B

# Concatenate the tokenized sequences and simulated encoded embeddings for the example input
combined_input_example = torch.cat([input_ids_a_example, input_ids_b_example, encoded_embeddings_a_example, encoded_embeddings_b_example])

combined_input_example


# Custom Collate Function

The collate_fn function is where we'll handle the padding and creation of attention masks. It will be passed to the DataLoader to process batches.

# Handling positional encoding

In [None]:
import numpy as np
import torch

def get_sinusoidal_encoding(n_positions, d_model):
    """Generate sinusoidal positional encodings.

    Even columns hold ``sin(pos * w_k)`` and odd columns hold ``cos(pos * w_k)``
    with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        n_positions: Number of positions (rows).
        d_model: Encoding width (columns). Both even and odd values are
            supported; for an odd width the last column is a sine column
            without a matching cosine column.

    Returns:
        A float32 tensor of shape ``(n_positions, d_model)``.
    """
    position = np.arange(n_positions)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    angles = position * div_term  # shape: (n_positions, ceil(d_model / 2))

    encoding = np.zeros((n_positions, d_model))
    encoding[:, 0::2] = np.sin(angles)
    # There are only d_model // 2 odd columns, one fewer than the even columns when d_model is odd.
    encoding[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return torch.tensor(encoding, dtype=torch.float)

# Example usage: build the encoding table for a 10-position sequence and show its shape (rows = positions, columns = d_model)
n_positions = 10  # Number of positions in the sequence
d_model = 512  # Dimensionality of the model/token embeddings
positional_encodings = get_sinusoidal_encoding(n_positions, d_model)
print(positional_encodings.shape)


In [None]:
#adding new tokens to the bert tokenizer
import pandas as pd
from transformers import BertTokenizer, BertModel
import torch

# Initialize the tokenizer with BERT's vocabulary and extend it with new special tokens
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Marker tokens placed in front of the first and second protein of each pair in the input text built below
new_tokens = ['[ENTITY1]', '[ENTITY2]']
num_added_toks = tokenizer.add_tokens(new_tokens)
print(f'We have added {num_added_toks} tokens to the tokenizer.')

# Load the BERT model and resize its token embeddings to accommodate the new special tokens
model = BertModel.from_pretrained('bert-base-uncased')
# The embedding table must match the enlarged vocabulary size reported by len(tokenizer)
model.resize_token_embeddings(len(tokenizer))
print('Token embeddings resized to accommodate new tokens.')


# Turn a list of token ids into a space-separated string of tokenizer tokens
def ids_to_text(ids):
    """Look up each id in the module-level ``tokenizer`` and join the tokens with spaces.

    NOTE(review): the lookup uses the BERT tokenizer's vocabulary, while the ids
    passed in below are summed amino-acid token ids — confirm this mapping is
    the intended one.
    """
    tokens = tokenizer.convert_ids_to_tokens(ids)
    return ' '.join(tokens)

# Function to convert embeddings (coordinates) to a string representation
def embeddings_to_string(embeddings):
    """Render an embedding vector as space-separated numbers.

    Args:
        embeddings: A torch tensor, NumPy array, or any iterable of numbers.

    Returns:
        A string such as ``"0.5 -1.25 3.0"``. Tensors and arrays are first
        converted with ``.tolist()`` so that each element is printed as a plain
        number; iterating a tensor directly would yield 0-d tensors whose
        ``str()`` is ``"tensor(0.5000)"``.
    """
    to_list = getattr(embeddings, 'tolist', None)
    values = to_list() if callable(to_list) else embeddings
    return ' '.join(str(value) for value in values)

input_texts = []  # one combined text per protein pair, in pairs_df row order

for _, pair in pairs_df.iterrows():
    # Text form of the summed token ids and of the coordinate embeddings, for both chains
    seq_a = ids_to_text(pair['sum_tokenized_sequence_A'])
    seq_b = ids_to_text(pair['sum_tokenized_sequence_B'])
    coords_a = embeddings_to_string(pair['Embeddings_A'])
    coords_b = embeddings_to_string(pair['Embeddings_B'])

    # Layout: [ENTITY1] <sequence A> [SEP] <embedding A> [SEP] [ENTITY2] <sequence B> [SEP] <embedding B>
    input_texts.append(
        f"[ENTITY1] {seq_a} [SEP] {coords_a} [SEP] [ENTITY2] {seq_b} [SEP] {coords_b}"
    )


# Tokenize all pair texts in one batch; padding and truncation are both enabled, so every
# row of the returned PyTorch tensors has the same length.
tokenized_inputs = tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt")

# tokenized_inputs contains the tokenized representation of your input texts,
# ready to be fed into the BERT model for sequence processing.

# tokenized_inputs contains the tokenized representation of your input texts,
# ready to be fed into the BERT model for sequence processing.
