In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from pair_prediction.data.dataset import LinkPredictionDataset

# Root of the processed dataset. Set the PAIR_PREDICTION_DATA_DIR environment variable
# to run the notebook on another machine; the default keeps the original absolute path.
# NOTE(review): "Thesuis" looks like a typo, but it has to match the directory name on
# disk — confirm before renaming.
DATA_DIR = Path(
    os.environ.get(
        "PAIR_PREDICTION_DATA_DIR",
        "/Users/dawid/Private/School/Master's Thesuis/non-canonical-base-pair-prediction/data/",
    )
).expanduser()

In [None]:
# Both splits come from the same data directory and differ only in `mode`;
# the train split is constructed first, then the validation split.
train_dataset, validation_dataset = (
    LinkPredictionDataset(DATA_DIR, mode=split) for split in ("train", "validation")
)

In [None]:
datasets = [train_dataset, validation_dataset]


def _count_multiplets(edge_index, edge_type):
    """Count non-canonical edges and "multiplet" nodes/edges in a single graph.

    A multiplet node is a node incident to both a canonical and a non-canonical edge.
    A multiplet edge is a non-canonical edge with at least one multiplet endpoint.

    Args:
        edge_index: 2 x E edge array/tensor (edges are columns, as indexed below).
        edge_type: length-E sequence of labels; "canonical" and "non-canonical" are used.

    Returns:
        (non_canonical_edges, multiplet_nodes, multiplet_edges) as ints. Edge counts are
        halved on the assumption that every undirected pair is stored in both
        directions — TODO confirm against the dataset builder.
    """
    edge_type = np.asarray(edge_type)
    canonical_edges = np.asarray(edge_index[:, edge_type == "canonical"])
    non_canonical_edges = np.asarray(edge_index[:, edge_type == "non-canonical"])

    nodes_with_both = np.intersect1d(np.unique(canonical_edges), np.unique(non_canonical_edges))
    # A column (edge) is a multiplet edge if either endpoint also takes part in a canonical pair;
    # both stored directions of a pair get the same flag, so halving the sum is exact.
    is_multiplet_edge = np.isin(non_canonical_edges, nodes_with_both).any(axis=0)

    return (
        non_canonical_edges.shape[1] // 2,
        len(nodes_with_both),
        int(is_multiplet_edge.sum()) // 2,
    )


for dataset in datasets:
    non_canonical_edges_count = 0
    multiplet_nodes_count = 0
    multiplet_edges_count = 0
    multiplet_examples = []  # indices of graphs containing at least one multiplet node

    # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the progress bar knows the total.
    for idx, data in enumerate(tqdm(dataset, desc=f"{dataset.mode}: multiplets")):
        n_non_canonical, n_multiplet_nodes, n_multiplet_edges = _count_multiplets(
            data.edge_index, data.edge_type
        )
        non_canonical_edges_count += n_non_canonical
        multiplet_nodes_count += n_multiplet_nodes
        multiplet_edges_count += n_multiplet_edges
        if n_multiplet_nodes > 0:
            multiplet_examples.append(idx)

    print(f"Dataset: {dataset.mode}")
    print(f"Total non-canonical edges: {non_canonical_edges_count}")
    print(f"Total nodes with both canonical and non-canonical edges: {multiplet_nodes_count}")
    print(f"Total non-canonical edges touching such a node (multiplet edges): {multiplet_edges_count}")

    if non_canonical_edges_count == 0:
        print("No non-canonical edges found; skipping ratio and pie chart.")
        continue

    # Both counts are edge counts now, so the ratio is well defined and the pie wedges are non-negative.
    print(f"Multiplet edges make up {multiplet_edges_count / non_canonical_edges_count:.2%} of non-canonical edges")

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.pie(
        [multiplet_edges_count, non_canonical_edges_count - multiplet_edges_count],
        labels=['Multiplet Edges', 'Non-Multiplet Edges'],
        autopct='%1.1f%%',
        startangle=140,
    )
    ax.set_title(f"Multiplet non-canonical edges in {dataset.mode} Dataset")
    fig.savefig(f"multiplet_non_canonical_edges_{dataset.mode}.png")

In [None]:
# For several neighbourhood sizes k, compare two per-graph fractions, averaged over graphs.
# Distance is the node-index difference |i - j| (presumably sequence position — TODO confirm).
#   * non-canonical coverage: share of nodes with a non-canonical edge that lie within k of
#     some node with a canonical edge;
#   * neighbourhood coverage: share of *all* node indices 0..num_nodes-1 within k of some
#     node with a canonical edge, i.e. the hit rate of a uniformly random node.
n_neighbourhood = [3, 5, 10, 15]


def _distance_to_nearest(positions, sorted_anchors):
    """Return |p - a| to the closest anchor `a` for every position `p`.

    `sorted_anchors` must be sorted ascending (e.g. the output of `np.unique`).
    With no anchors every distance is `np.inf`, so no `<= k` threshold is ever met.
    """
    positions = np.asarray(positions, dtype=np.int64)
    sorted_anchors = np.asarray(sorted_anchors, dtype=np.int64)
    if sorted_anchors.size == 0:
        return np.full(positions.shape, np.inf)
    # Index of the first anchor >= p; the nearest anchor is either that one or its predecessor
    # (clipping handles positions before the first / after the last anchor).
    right = np.searchsorted(sorted_anchors, positions)
    last = sorted_anchors.size - 1
    right_anchor = sorted_anchors[np.minimum(right, last)]
    left_anchor = sorted_anchors[np.maximum(right - 1, 0)]
    return np.minimum(np.abs(positions - left_anchor), np.abs(positions - right_anchor))


for dataset in datasets:
    print(f"Dataset: {dataset.mode}")

    # Distances to the nearest canonical node do not depend on k, so iterate the dataset
    # once and threshold per k afterwards instead of reloading every graph for each k.
    non_canonical_distances = []  # per graph: distance of each non-canonical-edge node
    all_node_distances = []  # per graph: distance of every node index
    skipped_graphs = 0
    for data in tqdm(dataset, desc=f"{dataset.mode}: distances to canonical edges"):
        edges = data.edge_index
        edge_types = np.array(data.edge_type)
        nodes_with_canonical_edges = np.unique(edges[:, edge_types == "canonical"].flatten())
        nodes_with_non_canonical_edges = np.unique(edges[:, edge_types == "non-canonical"].flatten())

        # Non-canonical coverage is 0/0 for such graphs (previously a ZeroDivisionError);
        # skip them for both metrics so the two curves average over the same graphs.
        if len(nodes_with_non_canonical_edges) == 0:
            skipped_graphs += 1
            continue

        non_canonical_distances.append(
            _distance_to_nearest(nodes_with_non_canonical_edges, nodes_with_canonical_edges)
        )
        all_node_distances.append(
            _distance_to_nearest(np.arange(data.num_nodes), nodes_with_canonical_edges)
        )

    if skipped_graphs:
        print(f"Skipped {skipped_graphs} graph(s) without non-canonical edges")
    if not non_canonical_distances:
        print("No graphs with non-canonical edges; nothing to plot.")
        continue

    averge_coverages = []
    average_non_canonical_edges_in_k_neighbourhood = []
    for k in n_neighbourhood:
        nodes_coverage = [np.mean(distances <= k) for distances in all_node_distances]
        non_canonical_coverage = [np.mean(distances <= k) for distances in non_canonical_distances]

        print(f"Average coverage of non-canonical edge in {k}-neighbourhood around canonical edges: {np.mean(non_canonical_coverage):.2%}")
        print(f"Average coverage of {k}-neighbourhood around canonical edges: {np.mean(nodes_coverage):.2%}")
        averge_coverages.append(np.mean(nodes_coverage))
        average_non_canonical_edges_in_k_neighbourhood.append(np.mean(non_canonical_coverage))

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(n_neighbourhood, averge_coverages, marker='o', label='All nodes')
    ax.plot(
        n_neighbourhood,
        average_non_canonical_edges_in_k_neighbourhood,
        marker='x',
        label='Nodes with non-canonical edges',
    )
    ax.set_xticks(n_neighbourhood)
    ax.set_xlabel('k Neighbourhood around canonical edge')
    ax.set_ylabel('Average Coverage of k Neighbourhood')
    ax.set_title(f"Average Coverage of k Neighbourhood around Canonical Edges in {dataset.mode} Dataset")
    ax.grid()
    ax.legend()
    fig.savefig(f"{dataset.mode}_coverage_plot.png")