# RNA-FM pair-wise connectivity model

In [45]:
import torch
import fm

path= "/data/home/chenjiayang/projects/RNA-FM/redevelop/pretrained/Models/SS/RNA-FM-ResNet_PDB-All.pth"
model, alphabet = fm.downstream.build_rnafm_resnet(type="ss") #, model_location=path)

tensor([[[  0.1480,  -4.8981, -11.4271,  ..., -12.4071,  -9.5562,  -6.0911],
         [ -4.8981,   0.1253, -11.7829,  ..., -15.3051,  -9.5378, -10.4153],
         [-11.4271, -11.7829,  -6.3015,  ..., -15.4624, -14.3296, -10.6362],
         ...,
         [-12.4071, -15.3051, -15.4624,  ...,  -8.3671, -12.9738, -10.4467],
         [ -9.5562,  -9.5378, -14.3296,  ..., -12.9738,  -5.4848,  -9.9814],
         [ -6.0911, -10.4153, -10.6362,  ..., -10.4467,  -9.9814,  -5.6822]]])
torch.Size([1, 45, 45])


In [109]:
def predict_secondary_structure(seq):
    # model, alphabet = fm.downstream.build_rnafm_resnet(type="ss")
    batch_converter = alphabet.get_batch_converter()
    model.eval()  # disables dropout for deterministic results
    
    data = [("RNA1", seq.replace("T", "U"))]
    
    # Prepare data
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    input = {
        "description": batch_labels,
        "token": batch_tokens
    }

    # Secondary Structure Prediction (on CPU)
    with torch.no_grad():
        results = model(input)
    
    ss_prob_map = results["r-ss"]
    ss_prob = torch.sigmoid(ss_prob_map)[0]
    
    #print(ss_prob_map)
    #print(ss_prob_map.shape)
    return ss_prob

# Example usage
#ss_prob_map = predict_secondary_structure("GAUUCGACGGGGACUUCGGUCCUCGGACGCGGGUUCGAUUCCCGC")
ss_prob_map = predict_secondary_structure("GAUUCGACGGGGACUUCGGUCCUCGGACGCGGGUUCGAUUCCCGC")


In [87]:
ss_prob_map

tensor([[5.3692e-01, 7.4053e-03, 1.0896e-05,  ..., 4.0896e-06, 7.0754e-05,
         2.2578e-03],
        [7.4053e-03, 5.3129e-01, 7.6337e-06,  ..., 2.2547e-07, 7.2072e-05,
         2.9971e-05],
        [1.0896e-05, 7.6337e-06, 1.8302e-03,  ..., 1.9266e-07, 5.9803e-07,
         2.4029e-05],
        ...,
        [4.0896e-06, 2.2547e-07, 1.9266e-07,  ..., 2.3234e-04, 2.3203e-06,
         2.9042e-05],
        [7.0754e-05, 7.2072e-05, 5.9803e-07,  ..., 2.3203e-06, 4.1321e-03,
         4.6252e-05],
        [2.2578e-03, 2.9971e-05, 2.4029e-05,  ..., 2.9042e-05, 4.6252e-05,
         3.3944e-03]])

# Convert pairwise probability to dot bracket notation

In [79]:
import numpy as np

def pairwise_prob_to_dot_bracket(prob_matrix, threshold=0.5):
    """
    Convert a pairwise probability matrix into a Dot-Bracket notation.
    
    Args:
    - prob_matrix (np.ndarray): A NxN matrix of base-pair probabilities.
    - threshold (float): Minimum probability to consider a valid base pair.
    
    Returns:
    - str: Dot-Bracket notation for the RNA secondary structure.
    """
    n = len(prob_matrix)
    structure = ["." for _ in range(n)]  # Start with unpaired bases
    paired = set()
    
    # Flatten the upper-triangular part of the matrix to sort by probability
    pairs = [(i, j, prob_matrix[i, j]) for i in range(n) for j in range(i+1, n)]
    pairs = sorted(pairs, key=lambda x: x[2], reverse=True)  # Sort by probability (high to low)
    
    # Greedy pairing based on the highest probability
    for i, j, prob in pairs:
        if prob < threshold:
            break  # Stop if the probability is too low
        if i not in paired and j not in paired:  # Ensure valid pairing
            structure[i] = "("
            structure[j] = ")"
            paired.add(i)
            paired.add(j)

    return "".join(structure)

# Example usage
prob_matrix = np.array([
    [0.0, 0.9, 0.1, 0.0],
    [0.9, 0.0, 0.8, 0.2],
    [0.1, 0.8, 0.0, 0.7],
    [0.0, 0.2, 0.7, 0.0]
])

prob_matrix = torch.sigmoid(ss_prob_map)[0]

dot_bracket = pairwise_prob_to_dot_bracket(prob_matrix, threshold=0.95)
print(dot_bracket)  # Output: "(())." (Example, depends on input)


(..(..(((((((((..)))))))))).((((((..))..)))))


# RNA Structural Annotation Using DSSR-like Labels
To annotate each nucleotide position with a DSSR-like label, we categorize them as:

- Stem (S) → Paired bases in a helix (( )).

- Hairpin loop (H) → Unpaired bases enclosed by a single closing pair.

- Bulge (B) → Unpaired bases interrupting a stem on one side.

- Internal Loop (I) → Unpaired bases interrupting a stem on both sides.

- Multiloop (M) → Unpaired bases with multiple branching stems.

- Unstructured (U) → Unpaired bases not in any specific loop.

In [81]:
import numpy as np

def annotate_rna_structure(dot_bracket):
    """
    Annotate RNA secondary structure with DSSR-like labels.

    Args:
    - dot_bracket (str): Dot-Bracket notation of RNA structure.

    Returns:
    - list: Labels for each nucleotide position.
    """
    n = len(dot_bracket)
    labels = ["U"] * n  # Default to unstructured
    stack = []
    pairs = {}

    # Step 1: Identify base pairs using a stack
    for i, char in enumerate(dot_bracket):
        if char == "(":
            stack.append(i)
        elif char == ")":
            if stack:
                j = stack.pop()
                pairs[i] = j
                pairs[j] = i

    # Step 2: Identify structural elements
    sorted_pairs = sorted(pairs.items())
    
    # Annotate stems
    stems = []
    stem = []
    for i, j in sorted_pairs:
        if stem and i == stem[-1][0] + 1 and j == stem[-1][1] - 1:
            stem.append((i, j))
        else:
            if len(stem) > 1:
                stems.append(stem)
            stem = [(i, j)]
    if len(stem) > 1:
        stems.append(stem)

    for stem in stems:
        for i, j in stem:
            labels[i] = "S"  # Mark as stem
            labels[j] = "S"

    # Annotate loops, bulges, and internal loops
    for stem in stems:
        first, last = stem[0][0], stem[-1][1]
        loop_region = dot_bracket[first + 1:last]

        if "(" not in loop_region and ")" not in loop_region:
            for i in range(first + 1, last):
                labels[i] = "H"  # Hairpin Loop
        else:
            left_unpaired = [i for i in range(first + 1, last) if dot_bracket[i] == "."]
            right_unpaired = [i for i in range(first + 1, last) if dot_bracket[i] == "."]

            if len(left_unpaired) > 0 and len(right_unpaired) > 0:
                for i in left_unpaired + right_unpaired:
                    labels[i] = "I"  # Internal Loop
            elif len(left_unpaired) > 0 or len(right_unpaired) > 0:
                for i in left_unpaired + right_unpaired:
                    labels[i] = "B"  # Bulge

    # Detect multiloops (Regions with multiple stems connecting)
    opening_brackets = set(pairs.values())
    for i, char in enumerate(dot_bracket):
        if char == "." and i > 0 and i < n - 1:
            if dot_bracket[i - 1] in "(.)" and dot_bracket[i + 1] in "(.)":
                if dot_bracket[i - 1] == "(" or dot_bracket[i + 1] == ")":
                    labels[i] = "M"  # Multiloop

    return labels

# Example usage
dot_bracket = "(((...)))..(((..)))"
labels = annotate_rna_structure(dot_bracket)
print("".join(labels))


SSSMIMSSSUUSSSMMSSS


# Read in Fasta file

In [None]:
def read_fasta(file_path):
    sequences = {}
    with open(file_path, 'r') as file:
        sequence_name = None
        sequence_data = []
        
        for line in file:
            line = line.strip()
            if line.startswith(">"):  # header line
                if sequence_name:  # Save the previous sequence
                    sequences[sequence_name] = ''.join(sequence_data)
                sequence_name = line[1:]  # Remove ">" from the header
                sequence_data = []  # Reset sequence data
            else:
                sequence_data.append(line)
        
        if sequence_name:  # Save the last sequence
            sequences[sequence_name] = ''.join(sequence_data)

    return sequences

# Example usage:
file_path = "/Users/rkwan/Downloads/3d_banana/result/transcript.fasta"
sequences = read_fasta(file_path)



# Process the workflow

In [126]:
total_sequences = len(sequences)

with open("/Users/rkwan/Downloads/3d_banana/result/ss_features.fasta", "w") as file:
    counter = 0
    for sequence_name, sequence in sequences.items():
        #if counter >= 5:
        #    break
            
        # Print progress
        print(f"Processing: {sequence_name} {counter + 1} / {total_sequences} sequences")
            
        #print(f">{sequence_name}\n{sequence}")
        #print(f">{sequence_name}")
        #print(f">{sequence}")
        file.write(f">{sequence_name}\n")
        counter += 1

        try:
            # Call the prediction function and get the secondary structure
            ss_prob_map = predict_secondary_structure(sequence)
            dot_bracket = pairwise_prob_to_dot_bracket(ss_prob_map, threshold=0.95)
            labels = annotate_rna_structure(dot_bracket)

            # Write the labels to the file
            file.write("".join(labels) + "\n")
        
        except Exception as e:
            # If an error occurs, print the error and skip to the next sequence
            print(f"Error processing {sequence_name}: {e}")
            continue

Processing: 1 / 398 sequences
Processing: 2 / 398 sequences
Processing: 3 / 398 sequences
Processing: 4 / 398 sequences
Processing: 5 / 398 sequences
Processing: 6 / 398 sequences
Processing: 7 / 398 sequences
Processing: 8 / 398 sequences
Processing: 9 / 398 sequences
Processing: 10 / 398 sequences
Processing: 11 / 398 sequences
Processing: 12 / 398 sequences
Processing: 13 / 398 sequences
Processing: 14 / 398 sequences
Processing: 15 / 398 sequences
Processing: 16 / 398 sequences
Processing: 17 / 398 sequences
Processing: 18 / 398 sequences
Processing: 19 / 398 sequences
Processing: 20 / 398 sequences
Processing: 21 / 398 sequences
Error processing ENST00000828396: index out of range in self
Processing: 22 / 398 sequences
Processing: 23 / 398 sequences
Processing: 24 / 398 sequences
Processing: 25 / 398 sequences
Processing: 26 / 398 sequences
Processing: 27 / 398 sequences
Processing: 28 / 398 sequences
Processing: 29 / 398 sequences
Error processing ENST00000444815: index out of ra