# In [None]:
import pandas as pd
import numpy as np
import random
import time
import warnings
from Bio import pairwise2
from Bio.Seq import Seq

# Silence every warning category for the rest of the run.
# NOTE(review): this blanket filter also hides deprecation notices raised by
# the imported libraries - presumably intentional to keep the output clean;
# confirm before relying on it.
warnings.filterwarnings('ignore')


# Lookup table that collapses RNA residue codes (modified, non-standard and
# ambiguous ones included) onto the four standard bases A/U/G/C, so sequences
# can be aligned over a plain four-letter alphabet.
# The table is written as (standard base, residue codes) groups; keys are
# inserted in the order the groups are listed.
# NOTE: code that looks residues up one character at a time can only ever hit
# the single-character keys; the multi-character codes need tokenised input.
NUCLEOTIDE_MAPPING = {
    residue_code: standard_base
    for standard_base, residue_codes in (
        # Standard bases map to themselves
        ('A', ('A',)),
        ('U', ('U',)),
        ('G', ('G',)),
        ('C', ('C',)),

        # Inosine and adenosine modifications
        ('A', (
            'I',      # Inosine (deaminated adenosine)
            '1MA',    # N1-methyladenosine
            'm1A',    # N1-methyladenosine (alternative notation)
            'm6A',    # N6-methyladenosine
            '6MA',    # N6-methyladenosine (alternative notation)
            'A2M',    # 2'-O-methyladenosine
            'Am',     # 2'-O-methyladenosine (alternative)
            'AOM',    # 2'-O-methyladenosine (alternative)
        )),

        # Pseudouridine and uridine modifications
        ('U', (
            'PSU',    # Pseudouridine
            'PSI',    # Pseudouridine (alternative notation)
            'T',      # Ribothymidine (5-methyluridine)
            'm5U',    # 5-methyluridine
            '5MU',    # 5-methyluridine (alternative)
            'Um',     # 2'-O-methyluridine
            'UOM',    # 2'-O-methyluridine (alternative)
            'D',      # Dihydrouridine (takes precedence over IUPAC 'D' = not C)
            'DHU',    # Dihydrouridine (alternative)
            's2U',    # 2-thiouridine
            's4U',    # 4-thiouridine
            'U2M',    # 2'-O-methyluridine
        )),

        # Guanosine modifications
        ('G', (
            'M2G',    # N2-methylguanosine
            'm2G',    # N2-methylguanosine (alternative notation)
            'm7G',    # 7-methylguanosine
            '7MG',    # 7-methylguanosine (alternative)
            'Gm',     # 2'-O-methylguanosine
            'GOM',    # 2'-O-methylguanosine (alternative)
            'G2M',    # 2'-O-methylguanosine (alternative)
        )),

        # Cytidine modifications
        ('C', (
            '5MC',    # 5-methylcytidine
            'm5C',    # 5-methylcytidine (alternative notation)
            'Cm',     # 2'-O-methylcytidine
            'COM',    # 2'-O-methylcytidine (alternative)
            'C2M',    # 2'-O-methylcytidine (alternative)
            'hm5C',   # 5-hydroxymethylcytidine
            'ac4C',   # N4-acetylcytidine
        )),

        # Unknown / placeholder symbols all default to A
        ('A', (
            'N',      # Any nucleotide
            'X',      # Unknown nucleotide
            '?',      # Unknown nucleotide
            '-',      # Gap character (not expected in sequences; handled gracefully)
        )),

        # IUPAC degenerate bases, each mapped to one representative base
        ('U', ('Y',)),        # Pyrimidine (C or U)
        ('A', ('R', 'W')),    # Purine (A or G); Weak (A or U)
        ('G', ('S', 'K')),    # Strong (G or C); Keto (G or U)
        ('A', ('M',)),        # Amino (A or C)
        ('C', ('B',)),        # Not A (C, G or U)
        ('A', ('H', 'V')),    # Not G (A, C or U); Not U (A, C or G)
    )
    for residue_code in residue_codes
}

def clean_sequence(seq):
    """Return ``seq`` rewritten over the plain A/U/G/C alphabet.

    Every character is translated on its own through ``NUCLEOTIDE_MAPPING``,
    so the result always has exactly one output character per input character
    (callers index coordinates by position and rely on that).

    Lookup order per character:
      1. the character exactly as written;
      2. its upper-case form, so lower-case residues such as 'u' or 'g' map
         to their own base instead of falling through to the default;
      3. 'A' for anything still unknown.
    """
    return "".join(
        NUCLEOTIDE_MAPPING.get(b) or NUCLEOTIDE_MAPPING.get(b.upper(), 'A')
        for b in seq
    )


#====Phase 1====
# Load the competition tables: training sequences, test sequences and the
# per-residue training labels, read in that order from DATA_PATH.
DATA_PATH = '/kaggle/input/stanford-rna-3d-folding-2/'
train_seqs, test_seqs, train_labels = (
    pd.read_csv(DATA_PATH + csv_name)
    for csv_name in ('train_sequences.csv', 'test_sequences.csv', 'train_labels.csv')
)

def process_labels(labels_df):
    """Group per-residue label rows into one coordinate array per structure.

    The structure key is the ``ID`` value with its last ``_``-separated field
    removed (``'abc_X_12'`` -> ``'abc_X'``).

    Parameters:
        labels_df: DataFrame with at least the columns ``ID``, ``resid``,
            ``x_1``, ``y_1`` and ``z_1``.

    Returns:
        dict mapping each ID prefix to an ``(n_rows, 3)`` array holding the
        ``x_1, y_1, z_1`` values of that group, ordered by ``resid``.
    """
    if labels_df.empty:
        return {}

    # Build the group key in one vectorised pass. It is handed to groupby as a
    # plain array so rows are matched by position, which keeps this working
    # for any index (including a non-unique one) and avoids one Python-level
    # label lookup per row.
    id_prefixes = labels_df['ID'].str.rsplit('_', n=1).str[0].to_numpy()

    coords_dict = {}
    for id_prefix, group in labels_df.groupby(id_prefixes):
        coords_dict[id_prefix] = group.sort_values('resid')[['x_1', 'y_1', 'z_1']].values
    return coords_dict

# Coordinate lookup built from the training labels: ID prefix -> array of
# (x_1, y_1, z_1) rows ordered by resid (see process_labels).
train_coords_dict = process_labels(train_labels)



#====Phase 2====
def adaptive_rna_constraints(coordinates, sequence, confidence=1.0):
    """Nudge consecutive residues toward fixed target spacings.

    Two sweeps are made over the chain. In each sweep, for every residue ``i``:
      * residue ``i+1`` is moved along the i -> i+1 direction toward a
        distance of 5.95 from residue ``i``;
      * residue ``i+2`` is moved along the i -> i+2 direction toward a
        distance of 10.2 from residue ``i``.
    Updates are applied immediately, so later steps see earlier corrections.
    Pairs at zero (or NaN) distance are left untouched.

    Parameters:
        coordinates: array-like of shape (n, 3) with at least
            ``len(sequence)`` rows. It is never modified.
        sequence: only its length is used (number of residues to process).
        confidence: larger values mean weaker corrections; values above 0.96
            are treated as 0.96, so some correction is always applied.

    Returns:
        A new floating-point array with the adjusted coordinates.
    """
    # Work on a floating-point copy: the in-place float updates below would
    # raise on an integer array, and list input has no array arithmetic.
    refined_coords = np.array(coordinates, copy=True)
    if not np.issubdtype(refined_coords.dtype, np.floating):
        refined_coords = refined_coords.astype(float)

    n = len(sequence)
    strength = 0.68 * (1.0 - min(confidence, 0.96))

    for _ in range(2):
        for i in range(n - 1):
            p1, p2 = refined_coords[i], refined_coords[i+1]
            dist = np.linalg.norm(p2 - p1)
            if dist > 0:
                # Positive adj pushes i+1 away from i, negative pulls it in.
                adj = (5.95 - dist) * strength * 0.45
                refined_coords[i+1] += (p2 - p1) / dist * adj

            if i < n - 2:
                p3 = refined_coords[i+2]
                dist2 = np.linalg.norm(p3 - p1)
                if dist2 > 0:
                    adj2 = (10.2 - dist2) * strength * 0.25
                    refined_coords[i+2] += (p3 - p1) / dist2 * adj2
    return refined_coords


#====Phase 3=====
def _fill_missing_coords(coords, step=3.5):
    """Fill, in place, every row of ``coords`` whose x value is NaN.

    * Rows between two known rows are linearly interpolated by position.
    * Rows before the first / after the last known row continue the chain
      along the x axis, ``step`` per residue, away from the nearest known row.
    * With no known row at all the chain is laid out on the x axis from 0.

    Returns the same array for convenience.
    """
    positions = np.arange(len(coords))
    missing = np.isnan(coords[:, 0])
    known = positions[~missing]

    if known.size == 0:
        coords[:] = 0.0
        coords[:, 0] = positions * step
        return coords

    first, last = known[0], known[-1]

    interior = missing & (positions > first) & (positions < last)
    if interior.any():
        for axis in range(coords.shape[1]):
            coords[interior, axis] = np.interp(
                positions[interior], known, coords[known, axis]
            )

    # Extrapolate both ends with the same per-residue spacing. Stepping in -x
    # before the first known residue keeps the chain direction consistent
    # with the +x steps used after the last one.
    leading = positions < first
    coords[leading] = coords[first]
    coords[leading, 0] -= (first - positions[leading]) * step

    trailing = positions > last
    coords[trailing] = coords[last]
    coords[trailing, 0] += (positions[trailing] - last) * step
    return coords


def adapt_template_to_query(query_seq, template_seq, template_coords):
    """Transfer template coordinates onto the query through a global alignment.

    Both sequences are normalised with ``clean_sequence`` and aligned with
    ``pairwise2.align.globalms`` (match 2, mismatch -1, gap open -7, gap
    extend -0.25). Every query residue aligned to a template residue takes
    that residue's coordinates; all remaining query residues (alignment gaps,
    template rows beyond ``len(template_coords)``, or template rows whose x is
    NaN) are filled by ``_fill_missing_coords``.

    Parameters:
        query_seq: query sequence; the result has one row per character.
        template_seq: sequence of the template structure.
        template_coords: array of shape (m, 3) with the template coordinates.

    Returns:
        Array of shape ``(len(query_seq), 3)`` without NaNs. If no alignment
        is produced, the result is a straight chain along the x axis (the same
        fallback that is used when no residue could be placed).
    """
    q_c = clean_sequence(query_seq)
    t_c = clean_sequence(template_seq)
    new_coords = np.full((len(query_seq), 3), np.nan)

    alignments = pairwise2.align.globalms(Seq(q_c), Seq(t_c), 2, -1, -7, -0.25, one_alignment_only=True)
    if alignments:
        a_q, a_t = alignments[0].seqA, alignments[0].seqB
        n_template = len(template_coords)
        q_idx, t_idx = 0, 0

        # Walk the alignment columns, advancing each cursor only when its
        # sequence contributes a residue (not a gap) to the column.
        for cq, ct in zip(a_q, a_t):
            query_has_residue = cq != '-'
            template_has_residue = ct != '-'
            if query_has_residue and template_has_residue and t_idx < n_template:
                new_coords[q_idx] = template_coords[t_idx]
            if query_has_residue:
                q_idx += 1
            if template_has_residue:
                t_idx += 1

    # nan_to_num clears NaNs that can survive in y/z when a template row had a
    # valid x but a missing y or z.
    return np.nan_to_num(_fill_missing_coords(new_coords))

#=====Phase 4=====
def find_similar_sequences(query_seq, train_seqs_df, train_coords_dict, top_n=5,
                           max_candidates=300, min_good_score=0.7):
    """Return the ``top_n`` training sequences most similar to ``query_seq``.

    Rows of ``train_seqs_df`` are scanned in order. A row is a candidate when
    its ``target_id`` has coordinates in ``train_coords_dict`` and its length
    differs from the query's by at most 40% of the longer of the two. Each
    candidate is globally aligned to the query and scored as
    ``alignment_score / (2 * shorter_length)``; with a match score of 2 a
    gap-free identical pair scores 1.0.

    The scan stops when either
      * ``max_candidates`` candidates have been aligned, or
      * ``3 * top_n`` candidates scoring at least ``min_good_score`` have
        been found.
    Only matches reaching ``min_good_score`` count toward the early exit, so
    poor matches near the top of the table no longer end the search before
    the ``max_candidates`` budget is used.

    Parameters:
        query_seq: query sequence string.
        train_seqs_df: DataFrame with ``target_id`` and ``sequence`` columns.
        train_coords_dict: mapping ``target_id`` -> coordinate array.
        top_n: number of matches to return.
        max_candidates: upper bound on the number of alignments computed.
        min_good_score: normalised score from which a match counts as "good"
            for the early exit. With the scoring used here, a gap-free
            alignment at 80% identity scores (2*0.8 - 0.2) / 2 = 0.7.

    Returns:
        List of at most ``top_n`` tuples
        ``(target_id, sequence, score, coords)``, best score first.
    """
    similar = []
    query_len = len(query_seq)
    if query_len == 0:
        # Nothing to align; also avoids dividing by a zero length below.
        return similar

    query = Seq(clean_sequence(query_seq))  # loop-invariant, build once
    candidates_checked = 0
    good_matches = 0

    for t_id, t_seq in zip(train_seqs_df['target_id'], train_seqs_df['sequence']):
        if t_id not in train_coords_dict:
            continue
        # Skip missing (NaN) or empty sequences instead of failing on len()
        # or dividing by zero.
        if not isinstance(t_seq, str) or not t_seq:
            continue

        # Cheap length filter before the expensive alignment.
        t_len = len(t_seq)
        if abs(t_len - query_len) / max(t_len, query_len) > 0.4:
            continue

        candidates_checked += 1
        if candidates_checked > max_candidates:
            break

        alns = pairwise2.align.globalms(query, Seq(clean_sequence(t_seq)), 2, -1, -7, -0.25, one_alignment_only=True)
        if not alns:
            continue

        score = alns[0].score / (2 * min(query_len, t_len))
        similar.append((t_id, t_seq, score, train_coords_dict[t_id]))

        # Early exit once enough *good* matches were collected (3x top_n).
        if score >= min_good_score:
            good_matches += 1
            if good_matches >= top_n * 3:
                break

    similar.sort(key=lambda x: x[2], reverse=True)
    return similar[:top_n]

def predict_rna_structures(sequence, target_id, train_seqs_df, train_coords_dict, n_predictions=5):
    """Build ``n_predictions`` coordinate arrays for ``sequence``.

    Slot ``k`` uses the k-th best template returned by
    ``find_similar_sequences``: the template is adapted to the query, refined
    with ``adaptive_rna_constraints`` (the similarity score is passed as the
    confidence) and, for every slot except the first, perturbed with Gaussian
    noise. Slots without a template receive a straight chain along the x axis
    with a spacing of 4.0 per residue.

    ``target_id`` is accepted for interface compatibility but is not used.

    Returns:
        List of ``n_predictions`` arrays of shape ``(len(sequence), 3)``.
    """
    templates = find_similar_sequences(sequence, train_seqs_df, train_coords_dict, top_n=n_predictions)
    structures = []

    for slot in range(n_predictions):
        if slot >= len(templates):
            # No template left for this slot: straight chain on the x axis.
            straight = np.zeros((len(sequence), 3))
            straight[:, 0] = 4.0 * np.arange(len(sequence))
            structures.append(straight)
            continue

        _, template_seq, similarity, template_coords = templates[slot]
        model = adaptive_rna_constraints(
            adapt_template_to_query(sequence, template_seq, template_coords),
            sequence,
            confidence=similarity,
        )

        # NOISE: slot 0 stays clean, the other slots get micro-noise
        if slot == 0:
            sigma = 0.0
        else:
            sigma = max(0.006, (0.38 - similarity) * 0.07)
        if sigma > 0:
            model += np.random.normal(0, sigma, model.shape)
        structures.append(model)

    return structures

#====Phase 5====
# Stream the submission to disk one target at a time: the rows of a single
# target are built, appended to submission.csv and released before the next
# target is processed, so the full table is never held in memory.
import gc

cols = ['ID', 'resname', 'resid']
cols.extend(f'{c}_{i}' for i in range(1,6) for c in ['x','y','z'])
first_write = True
start_time = time.time()

for idx, row in test_seqs.iterrows():
    # Progress report (and a garbage-collection pass) every 10th target.
    if idx % 10 == 0:
        print(f"Processing {idx} | {time.time()-start_time:.1f}s")
        gc.collect()

    tid = row['target_id']
    seq = row['sequence']
    preds = predict_rna_structures(seq, tid, train_seqs, train_coords_dict)

    # One output row per residue, carrying the five predicted (x, y, z) sets.
    batch_rows = []
    for resid, resname in enumerate(seq, start=1):
        record = {'ID': f"{tid}_{resid}", 'resname': resname, 'resid': resid}
        for slot in range(1, 6):
            x, y, z = preds[slot - 1][resid - 1]
            record[f'x_{slot}'] = x
            record[f'y_{slot}'] = y
            record[f'z_{slot}'] = z
        batch_rows.append(record)

    # The first target creates the file with a header; later ones append.
    batch_df = pd.DataFrame(batch_rows)
    write_mode = 'w' if first_write else 'a'
    batch_df[cols].to_csv('submission.csv', mode=write_mode,
                          header=first_write, index=False)
    first_write = False

    # Drop this target's objects right away; collect every 5th target.
    del preds, batch_rows, batch_df
    if idx % 5 == 0:
        gc.collect()

print("ALL COMPLETED. SUBMIT!!")