In [1]:
import pickle
import torch
import pandas as pd
import numpy as np

from torch_geometric.utils import unbatch_edge_index
from pair_prediction.model.utils import get_negative_edges
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

In [2]:
with open('../outputs/rinalmo_6chr_full/results.pkl', 'rb') as f:
    full_results = pickle.load(f)

with open('../outputs/rinalmo_6chr_desc/results.pkl', 'rb') as f:
    desc_results = pickle.load(f)

In [9]:
def df_from_results(results):
    extracted_data = {
        'ids': [],
        'sequences': [],
        'inference_edge_index': [],
        'probabilities': [],
        'labels': [],
        'edge_index': [],
    }

    for result in results:
        data = result['data']
        extracted_data['ids'].extend(data.id)
        extracted_data['sequences'].extend(data.seq)
        extracted_data['edge_index'].extend([x.numpy() for x in unbatch_edge_index(data.edge_index, data.batch)])

        edge_mask = np.concatenate(data.edge_type) == 'non-canonical'
        pos_edge_index = data.edge_index[:, edge_mask]
        neg_edge_index = get_negative_edges(data, validation=True)
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=1).numpy()
        
        extracted_data['inference_edge_index'].append(edge_index)
        extracted_data['probabilities'].append(result['probabilities'].numpy())
        extracted_data['labels'].append(result['labels'].numpy())

    return pd.DataFrame(extracted_data)

desc_df = df_from_results(desc_results)
full_df = df_from_results(full_results)

In [11]:
# calculate metrics
def calculate_metrics(row):
    y_true = row['labels']
    y_pred = (row['probabilities'] > 0.5).astype(int)
    
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    
    return accuracy, precision, recall, f1

desc_df[['accuracy', 'precision', 'recall', 'f1']] = desc_df.apply(calculate_metrics, axis=1, result_type='expand')
full_df[['accuracy', 'precision', 'recall', 'f1']] = full_df.apply(calculate_metrics, axis=1, result_type='expand')

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(re

In [12]:
print("Full Sequence Metrics:")
print(full_df[['accuracy', 'precision', 'recall', 'f1']].mean())

Full Sequence Metrics:
accuracy     0.998722
precision    0.275000
recall       0.517647
f1           0.359184
dtype: float64


In [13]:
print("\nDescriptor Sequence Metrics:")
print(desc_df[['accuracy', 'precision', 'recall', 'f1']].mean())
print("Standard Deviation:")
print(desc_df[['accuracy', 'precision', 'recall', 'f1']].std())


Descriptor Sequence Metrics:
accuracy     0.994418
precision    0.634739
recall       0.437113
f1           0.494954
dtype: float64
Standard Deviation:
accuracy     0.005198
precision    0.427956
recall       0.343915
f1           0.351910
dtype: float64


In [66]:
TRUE_IDX_MAPPING = {}

with open('../data/evaluation/6chr_full_clean/raw/idxs/6CHR_1_A-B.idx', 'rb') as f:
    # read line by line
    for line in f:
        line = line.decode('utf-8').strip()
        true_idx, idx = line.split(",", 1)
        if idx in TRUE_IDX_MAPPING:
            raise ValueError(f"Duplicate index found: {idx}")
        TRUE_IDX_MAPPING[idx] = int(true_idx) - 1

def get_true_idx_for_descriptor(row):
    desc_id = row['ids']
    true_nucleotide_idxes = []
    with open(f'../data/evaluation/6chr_desc_clean/raw/idxs/{desc_id}.idx', 'rb') as f:
        for line in f:
            line = line.decode('utf-8').strip()
            _, idx = line.split(",", 1)
            true_nucleotide_idxes.append(TRUE_IDX_MAPPING[idx])
    return true_nucleotide_idxes

def map_edge_index_to_true_idx(edge_index, true_nucleotide_idxes):
    true_edge_index = []
    for src, dst in edge_index.T:
        true_src = true_nucleotide_idxes[src]
        true_dst = true_nucleotide_idxes[dst]
        true_edge_index.append((true_src, true_dst))
    return true_edge_index

desc_df['true_nucleotide_idxes'] = desc_df.apply(get_true_idx_for_descriptor, axis=1)
desc_df['true_edge_index'] = desc_df.apply(
    lambda row: map_edge_index_to_true_idx(row['inference_edge_index'], row['true_nucleotide_idxes']), axis=1
)

In [67]:
# Aggregate predictions for true edge indices from all descriptors
def aggregate_predictions(df):
    aggregated = {}
    for _, row in df.iterrows():
        true_edge_indexes = tuple(map(tuple, row['true_edge_index']))
        probabilities = row['probabilities']
        labels = row['labels']
        for true_edge_index, probability, label in zip(true_edge_indexes, probabilities, labels):
            if true_edge_index not in aggregated:
                aggregated[true_edge_index] = {
                    'probabilities': [],
                    'predictions': [],
                    'labels': []
                }
            aggregated[true_edge_index]['probabilities'].append(probability)
            aggregated[true_edge_index]['predictions'].append(1 if probability > 0.5 else 0)
            aggregated[true_edge_index]['labels'].append(label)
    
    return aggregated

aggregated_desc = aggregate_predictions(desc_df)

In [68]:
# For each edge_index calculate the average probability, and the majority vote for prediction
def calculate_aggregated_metrics(aggregated):
    result = []
    for edge_index, data in aggregated.items():
        labels = data['labels']
        probabilities = data['probabilities']
        predictions = data['predictions']
        
        average_probability = sum(probabilities) / len(probabilities)
        max_probability = max(probabilities)
        majority_vote = max(set(predictions), key=predictions.count)

        result.append({
            'edge_index': edge_index,
            'average_probability': average_probability,
            'majority_vote': majority_vote,
            'max_probability': max_probability,
            'labels': int(labels[0])
        })

    return pd.DataFrame(result)

aggregated_metrics = calculate_aggregated_metrics(aggregated_desc)

In [69]:
desc_f1_majority = f1_score(aggregated_metrics['labels'].tolist(), aggregated_metrics['majority_vote'].tolist())
desc_f1_max = f1_score(aggregated_metrics['labels'].tolist(), (aggregated_metrics['max_probability'] > 0.5).astype(int).tolist())
desc_f1_average = f1_score(aggregated_metrics['labels'].tolist(), (aggregated_metrics['average_probability'] > 0.5).astype(int).tolist())

print("\nDescriptor Sequence Aggregated Metrics:")
print(f"F1 Score (Majority Vote): {desc_f1_majority}")
print(f"F1 Score (Max Probability): {desc_f1_max}")
print(f"F1 Score (Average Probability): {desc_f1_average}")


Descriptor Sequence Aggregated Metrics:
F1 Score (Majority Vote): 0.743455497382199
F1 Score (Max Probability): 0.5684931506849316
F1 Score (Average Probability): 0.7513227513227513


In [70]:
all_edges = set(map(tuple, full_df['inference_edge_index'].iloc[0].T.tolist()))
desc_edges = set(aggregated_metrics.edge_index.tolist())

print("\nFull Sequence vs Descriptor Sequence Edge Comparison:")
print(f"Total edges in full sequence: {len(all_edges)}")
print(f"Total edges in descriptor sequence: {len(desc_edges)}")
print(f"Common edges: {len(all_edges & desc_edges)}")
print(f"Unique edges in full sequence: {len(all_edges - desc_edges)}")
print(f"Unique edges in descriptor sequence: {len(desc_edges - all_edges)}")


Full Sequence vs Descriptor Sequence Edge Comparison:
Total edges in full sequence: 245626
Total edges in descriptor sequence: 54384
Common edges: 54322
Unique edges in full sequence: 191304
Unique edges in descriptor sequence: 62


In [73]:
import matplotlib.pyplot as plt

def plot_edge_comparison_heatmap(seq, set_a, set_b, title="Edge Comparison Heatmap", filename="edge_heatmap.png"):
    """
    Plot a heatmap showing edge presence in two sets.

    Args:
        seq (str or list): RNA base sequence.
        set_a (set of (i, j)): First set of edges (e.g., canonical).
        set_b (set of (i, j)): Second set of edges (e.g., predicted).
        title (str): Plot title.
        filename (str): File name to save the heatmap.

    Notes:
        Color coding:
            - Green (1): Edge in A only
            - Blue (2): Edge in B only
            - Red (3): Edge in both A and B
            - NaN: No edge
    """
    L = len(seq)
    heat = np.full((L, L), np.nan)

    set_a = set(set_a)
    set_b = set(set_b)

    only_a = set_a - set_b
    only_b = set_b - set_a
    overlap = set_a & set_b

    for i, j in only_a:
        heat[i-1, j-1] = 1
    for i, j in only_b:
        heat[i-1, j-1] = 2
    for i, j in overlap:
        heat[i-1, j-1] = 3

    cmap = plt.cm.get_cmap("viridis", 3)
    cmap.set_bad(color="white")

    step = 1

    fig, ax = plt.subplots(figsize=(15, 15), dpi=500)
    im = ax.imshow(heat, cmap=cmap, origin="upper", vmin=0.5, vmax=3.5)

    ax.set_title(title, pad=10)
    ax.set_xlabel("Node j  (sequence order)")
    ax.set_ylabel("Node i  (sequence order)")

    ax.set_xticks(np.arange(0, L, step))
    ax.set_yticks(np.arange(0, L, step))
    ax.set_xticklabels(seq[::step], fontsize=6, rotation=90)
    ax.set_yticklabels(seq[::step], fontsize=6)

    cbar = fig.colorbar(im, ax=ax, fraction=0.045, pad=0.04, ticks=[1, 2, 3])
    cbar.ax.set_yticklabels(['Full Sequence Considered Edges', 'Descriptors Considered Edges', 'Overlap'])

    plt.tight_layout()
    fig.savefig(filename)
    plt.close(fig)

plot_edge_comparison_heatmap(full_df['sequences'].iloc[0], set_a=all_edges, set_b=desc_edges)

  cmap = plt.cm.get_cmap("viridis", 3)


In [76]:
all_edges = list(map(tuple, full_df['inference_edge_index'].iloc[0].T.tolist()))
desc_edges = list(aggregated_metrics.edge_index.tolist())

In [99]:
import tqdm

preds_total = np.zeros(len(all_edges))

for i, edge in tqdm.tqdm(enumerate(all_edges)):
    if edge in desc_edges:
        preds_total[i] = (aggregated_metrics[aggregated_metrics.edge_index == edge]['average_probability'] > 0.5).astype(int).tolist()[0]

245626it [04:48, 850.98it/s] 


In [100]:
total_desc_f1 = f1_score(full_df['labels'].iloc[0], preds_total.tolist())
print("\nFull Sequence Total Descriptor F1 Score:")
print(f"F1 Score: {total_desc_f1}")


Full Sequence Total Descriptor F1 Score:
F1 Score: 0.6299212598425197
