In [1]:
import pandas as pd

df = pd.read_csv(
    "/Users/eshitagupta/Study/rna_project/bprna_preprocessed_subset.csv",
    sep=",",
    engine="python",
    quoting=3,        # ignore quotes weirdness
    on_bad_lines="skip"
)
print("rows:", len(df))
print(df.columns.tolist())

rows: 1000
['id', 'sequence', 'secondary_structure', 'structural_annotation', 'functional_annotation']


In [2]:
print("df rows:", len(df))
print(df.columns.tolist())
print(df.head(3))

df rows: 1000
['id', 'sequence', 'secondary_structure', 'structural_annotation', 'functional_annotation']
            id                                           sequence  \
0  bpRNA_CRW_1  ACACAUGCAAGCGAACGUGAUCUCCAGCUUGCUGGGGGAUUAGUGG...   
1  bpRNA_CRW_2  AACACAUGCAAGUCGAACGAUGAUCUCCAGCUUGCUGGGGGAUUAG...   
2  bpRNA_CRW_3  CGAACGCUGGCGGCGUGCUUAACACAUGCAAGUCGAACGGAAAGGC...   

                                 secondary_structure  \
0  .(((.(((..((..((((.(((((.((....))))))))..))))....   
1  ..(((.(((..(((..(((((.(((((.((....))))))))).))...   
2  (.((((((.(((((((((....(((.(((..(((..(((((..(((...   

                               structural_annotation  \
0  ESSSBSSSMMSSBBSSSSBSSSSSBSSHHHHSSSSSSSSBBSSSSB...   
1  EESSSBSSSMMSSSBBSSSSSBSSSSSBSSHHHHSSSSSSSSSBSS...   
2  SBSSSSSSMSSSSSSSSSIIIISSSBSSSMMSSSBBSSSSSBBSSS...   

                               functional_annotation  
0  NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...  
1  NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN...  
2 

In [3]:
import re

OPEN = set("([{<")
CLOSE = set(")]}>")

def normalize_structure_to_3class(db: str):
    db = re.sub(r"\s+", "", str(db).strip())
    out = []
    for ch in db:
        if ch == '.':
            out.append('.')
        elif ch in OPEN or ch == '(':
            out.append('(')
        elif ch in CLOSE or ch == ')':
            out.append(')')
        else:
            # any other symbol (letters/digits/pseudoknot marks) -> unpaired for POC
            out.append('.')
    return "".join(out)

def clean_row(seq, db):
    seq = re.sub(r"\s+", "", str(seq).strip().upper().replace("T", "U"))
    db3 = normalize_structure_to_3class(db)

    # Keep only RNA chars (AUGCN)
    if re.search(r"[^AUGCN]", seq):
        return None

    # Now db is guaranteed to be only . ( ) after mapping
    if re.search(r"[^().]", db3):
        return None

    if len(seq) != len(db3):
        return None

    return seq, db3


In [4]:
out = clean_row(df["sequence"].iloc[0], df["secondary_structure"].iloc[0])
print(out is None)
print("unique db chars:", set(out[1]))
print("len:", len(out[0]), len(out[1]))

False
unique db chars: {')', '.', '('}
len: 1434 1434


In [5]:
pairs = []
for _, r in df.iterrows():
    out = clean_row(r["sequence"], r["secondary_structure"])
    if out is not None:
        pairs.append(out)

print("usable pairs:", len(pairs))


usable pairs: 989


In [6]:
MAX_LEN = 2000
pairs_small = [(s, d) for (s, d) in pairs if len(s) <= MAX_LEN]
print("pairs <= 500:", len(pairs_small))

pairs <= 500: 989


In [7]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

VOCAB = {'A':0, 'U':1, 'G':2, 'C':3, 'N':4}
LAB   = {'.':0, '(':1, ')':2}
INV_LAB = {v:k for k,v in LAB.items()}

PAD_X = 5      # padding token for sequence
PAD_Y = -100   # ignored index for loss

class BPRNADataset(Dataset):
    def __init__(self, pairs_list):
        self.items = pairs_list

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        seq, db = self.items[idx]
        x = torch.tensor([VOCAB.get(ch, VOCAB['N']) for ch in seq], dtype=torch.long)
        y = torch.tensor([LAB[ch] for ch in db], dtype=torch.long)
        return x, y, seq, db

def collate(batch):
    xs, ys, seqs, dbs = zip(*batch)
    max_len = max(x.size(0) for x in xs)

    x_pad = torch.full((len(xs), max_len), PAD_X, dtype=torch.long)
    y_pad = torch.full((len(xs), max_len), PAD_Y, dtype=torch.long)
    mask  = torch.zeros((len(xs), max_len), dtype=torch.bool)

    for i, (x, y) in enumerate(zip(xs, ys)):
        L = x.size(0)
        x_pad[i, :L] = x
        y_pad[i, :L] = y
        mask[i, :L] = True

    return x_pad, y_pad, mask, seqs, dbs

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# Use the filtered set for fast POC
pairs_used = pairs_small  # or pairs if you want full

train_pairs, val_pairs = train_test_split(pairs_used, test_size=0.1, random_state=42)

train_loader = DataLoader(BPRNADataset(train_pairs), batch_size=32, shuffle=True,  collate_fn=collate)
val_loader   = DataLoader(BPRNADataset(val_pairs),   batch_size=32, shuffle=False, collate_fn=collate)

print("train:", len(train_pairs), "val:", len(val_pairs))


train: 890 val: 99


In [8]:
import torch.nn as nn

class BiLSTMDotBracket(nn.Module):
    def __init__(self, vocab_size=6, emb=32, hidden=128, n_classes=3):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb, padding_idx=PAD_X)
        self.lstm = nn.LSTM(
            input_size=emb,
            hidden_size=hidden,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )
        self.fc = nn.Linear(hidden*2, n_classes)

    def forward(self, x):
        e = self.emb(x)       # [B,L,emb]
        h, _ = self.lstm(e)   # [B,L,2H]
        return self.fc(h)     # [B,L,3]


In [9]:
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BiLSTMDotBracket(emb=32, hidden=128).to(device)

opt = torch.optim.Adam(model.parameters(), lr=2e-3)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_Y)

In [10]:
def dotbracket_to_pairs(db: str):
    stack = []
    pairs = set()
    for i, ch in enumerate(db):
        if ch == '(':
            stack.append(i)
        elif ch == ')':
            if stack:
                j = stack.pop()
                pairs.add((j, i))
    return pairs

def masked_token_accuracy(logits, y_true, mask):
    pred = logits.argmax(dim=-1)
    correct = ((pred == y_true) & mask).sum().item()
    total = mask.sum().item()
    return correct / max(total, 1)

def decode_db(pred_ids, mask):
    out = []
    for b in range(pred_ids.size(0)):
        L = int(mask[b].sum().item())
        out.append("".join(INV_LAB[int(i)] for i in pred_ids[b, :L]))
    return out

def pair_f1(pred_dbs, true_dbs):
    tp = fp = fn = 0
    for p_db, t_db in zip(pred_dbs, true_dbs):
        P = dotbracket_to_pairs(p_db)
        T = dotbracket_to_pairs(t_db)
        tp += len(P & T)
        fp += len(P - T)
        fn += len(T - P)
    prec = tp / max(tp + fp, 1)
    rec  = tp / max(tp + fn, 1)
    f1   = 2*prec*rec / max(prec + rec, 1e-12)
    return prec, rec, f1


In [12]:
import torch.nn.functional as F

EPOCHS = 10

for epoch in range(1, EPOCHS + 1):
    # ---- train ----
    model.train()
    tr_loss = tr_acc = 0.0
    n = 0

    for x, y, mask, seqs, dbs in train_loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        opt.zero_grad()
        logits = model(x)  # [B,L,3]
        loss = loss_fn(logits.reshape(-1, 3), y.reshape(-1))
        loss.backward()
        opt.step()

        tr_loss += loss.item()
        tr_acc  += masked_token_accuracy(logits, y, mask)
        n += 1

    # ---- val ----
    model.eval()
    va_loss = va_acc = 0.0
    all_pred_db, all_true_db = [], []
    m = 0

    with torch.no_grad():
        for x, y, mask, seqs, dbs in val_loader:
            x, y, mask = x.to(device), y.to(device), mask.to(device)
            logits = model(x)
            loss = loss_fn(logits.reshape(-1, 3), y.reshape(-1))

            va_loss += loss.item()
            va_acc  += masked_token_accuracy(logits, y, mask)

            pred_ids = logits.argmax(dim=-1).cpu()
            all_pred_db.extend(decode_db(pred_ids, mask.cpu()))
            all_true_db.extend(dbs)
            m += 1

    p, r, f1 = pair_f1(all_pred_db, all_true_db)

    print(
        f"Epoch {epoch:02d} | "
        f"train loss {tr_loss/n:.4f} acc {tr_acc/n:.4f} | "
        f"val loss {va_loss/m:.4f} acc {va_acc/m:.4f} | "
        f"pairF1 {f1:.4f} (P {p:.4f}, R {r:.4f})"
    )

Epoch 01 | train loss 0.9845 acc 0.4983 | val loss 0.9267 acc 0.5492 | pairF1 0.0059 (P 0.0071, R 0.0051)
Epoch 02 | train loss 0.7923 acc 0.6367 | val loss 0.6497 acc 0.7274 | pairF1 0.0629 (P 0.0638, R 0.0621)
Epoch 03 | train loss 0.4910 acc 0.8013 | val loss 0.3909 acc 0.8437 | pairF1 0.1616 (P 0.1637, R 0.1596)
Epoch 04 | train loss 0.2783 acc 0.8998 | val loss 0.2212 acc 0.9286 | pairF1 0.4385 (P 0.4407, R 0.4364)
Epoch 05 | train loss 0.1531 acc 0.9545 | val loss 0.1302 acc 0.9644 | pairF1 0.6053 (P 0.6073, R 0.6033)
Epoch 06 | train loss 0.0909 acc 0.9772 | val loss 0.0835 acc 0.9802 | pairF1 0.7761 (P 0.7791, R 0.7731)
Epoch 07 | train loss 0.0610 acc 0.9860 | val loss 0.0629 acc 0.9842 | pairF1 0.8520 (P 0.8568, R 0.8473)
Epoch 08 | train loss 0.0468 acc 0.9885 | val loss 0.0487 acc 0.9873 | pairF1 0.8694 (P 0.8730, R 0.8658)
Epoch 09 | train loss 0.0415 acc 0.9895 | val loss 0.0561 acc 0.9846 | pairF1 0.8519 (P 0.8574, R 0.8465)
Epoch 10 | train loss 0.1204 acc 0.9622 | val 

In [13]:
print("pairs_used:", len(pairs_used))
print("train_pairs:", len(train_pairs), "val_pairs:", len(val_pairs))
print("train batches:", len(train_loader), "val batches:", len(val_loader))

pairs_used: 989
train_pairs: 890 val_pairs: 99
train batches: 28 val batches: 4


In [14]:
for i in range(3):
    print("\nExample", i)
    print("TRUE:", all_true_db[i][:120])
    print("PRED:", all_pred_db[i][:120])


Example 0
TRUE: (((.(((..(((..(((((.((((((((....)))))))))).))))))......(((......((((((((..((...(((((((.((((....(((((((....))))))).....))
PRED: (((((((....(..)))...(((((((((.....))))))))..)))))......(((......((((((((..((...(((((((.((((....(((((((....))))))).....))

Example 1
TRUE: ((.((((((.(((((((((....(((.(((..(((..(((((.((((((....)))))))).))))))......(((......((((((((..((...(((((((.((((....((((((
PRED: ((.((((((.(((((((((....(((.(((..(((..(((((.((((((....)))))))).))))))......(((......((((((((..((...(((((((.((((....((((((

Example 2
TRUE: ((((....(((.(((..(((..(((((.((((((....)))))))).))))))......(((......((((((((..((...(((((((.((((....(((((((....)))))))...
PRED: (((.....(((.(((..(((..(((((.((((((....)))))))).))))))......(((......((((((((..((...(((((((.((((....(((((((....)))))))...


In [15]:
def balance_parentheses(db):
    db = list(db)
    stack = 0
    for i, ch in enumerate(db):
        if ch == '(':
            stack += 1
        elif ch == ')':
            if stack == 0:
                db[i] = '.'   # remove invalid close
            else:
                stack -= 1
    # remove extra opens from the end
    for i in range(len(db)-1, -1, -1):
        if stack == 0:
            break
        if db[i] == '(':
            db[i] = '.'
            stack -= 1
    return "".join(db)

In [16]:
pred_db = [balance_parentheses(s) for s in decode_db(pred_ids, mask.cpu())]

In [17]:
torch.save(model.state_dict(), "bprna_layer1_bilstm.pt")
print("saved bprna_layer1_bilstm.pt")

saved bprna_layer1_bilstm.pt
