In [1]:
import numpy as np
import pandas as pd
from Bio.Seq import Seq
from Bio import pairwise2
from scipy.spatial.transform import Rotation
from scipy.spatial.distance import pdist, squareform

class Config:
    DATA_PATH = '/kaggle/input/stanford-rna-3d-folding-2'
    MAX_RELATIVE_LENGTH_DIFF = 0.5
    ALIGNMENT_MATCH = 2
    ALIGNMENT_MISMATCH = -1
    ALIGNMENT_GAP_OPEN = -10
    ALIGNMENT_GAP_EXTEND = -0.5
    BOND_DISTANCE_TARGET = 6.0
    BOND_DISTANCE_TOL = 0.5
    MIN_NONBOND_DISTANCE = 3.8
    BASE_PAIRING_DISTANCE_IDEAL = 10.5
    BASE_PAIRING_DISTANCE_RANGE = (8.0, 14.0)
    HELIX_RADIUS = 10.0
    HELIX_RISE_PER_BASE = 2.5
    HELIX_ANGLE_STEP = 0.6
    NUM_PREDICTIONS = 5
    RANDOM_SEED_OFFSET = 1000
    DEFAULT_RELIABILITY = 0.2
    PAIRING_PROB_THRESHOLD = 0.7
    STEP_LENGTH_RANGE = (3.5, 4.5)
    NOISE_SCALE_MIN = 0.03
    TEMPLATE_WEIGHT = 0.71
    RANDOM_WEIGHT = 0.23
    TOP_TEMPLATES = 8
    LINEAR_INTERP_LIMIT = 3
    SIMILARITY_DECAY = 0.6

train = pd.read_csv(f'{Config.DATA_PATH}/train_sequences.csv')
val = pd.read_csv(f'{Config.DATA_PATH}/validation_sequences.csv')
test = pd.read_csv(f'{Config.DATA_PATH}/test_sequences.csv')
train_labels = pd.read_csv(f'{Config.DATA_PATH}/train_labels.csv')
val_labels = pd.read_csv(f'{Config.DATA_PATH}/validation_labels.csv')

def extract_structures(labels_data):
    structures = {}
    for group_label, group_data in labels_data.groupby(lambda x: labels_data['ID'][x].rsplit('_', 1)[0]):
        coords_list = []
        for _, row_data in group_data.sort_values('resid').iterrows():
            coords_list.append([row_data['x_1'], row_data['y_1'], row_data['z_1']])
        structures[group_label] = np.array(coords_list, dtype=np.float32)
    return structures

def secondary_structure_filter(seq_a, seq_b, min_bp=2):
    pairing_map = {'A': 'U', 'U': 'A', 'G': 'C', 'C': 'G'}
    count = 0
    for i in range(min(len(seq_a), len(seq_b))):
        if seq_a[i] in pairing_map and seq_b[i] == pairing_map[seq_a[i]]:
            count += 1
    return count >= min_bp

def compute_alignment_score(query_seq, template_seq):
    align_result = pairwise2.align.globalms(
        Seq(query_seq),
        template_seq,
        Config.ALIGNMENT_MATCH,
        Config.ALIGNMENT_MISMATCH,
        Config.ALIGNMENT_GAP_OPEN,
        Config.ALIGNMENT_GAP_EXTEND,
        one_alignment_only=True
    )
    if not align_result:
        return None, 0.0
    alignment = align_result[0]
    score = alignment.score / (2 * min(len(query_seq), len(template_seq)))
    return alignment, score

def morph_template_with_interpolation(dest_seq, source_seq, source_coords):
    alignment, _ = compute_alignment_score(dest_seq, source_seq)
    if not alignment:
        return generate_basic_structure(dest_seq)

    aligned_dest = str(alignment.seqA)
    aligned_source = str(alignment.seqB)

    morphed_coords = np.full((len(dest_seq), 3), np.nan, dtype=np.float32)
    dest_idx = source_idx = 0

    for a, b in zip(aligned_dest, aligned_source):
        if a != '-' and b != '-':
            if source_idx < len(source_coords):
                morphed_coords[dest_idx] = source_coords[source_idx]
            dest_idx += 1
            source_idx += 1
        elif a != '-' and b == '-':
            dest_idx += 1
        elif a == '-' and b != '-':
            source_idx += 1

    valid_mask = ~np.isnan(morphed_coords[:, 0])
    if not valid_mask.any():
        return generate_basic_structure(dest_seq)

    for dim in range(3):
        col_series = pd.Series(morphed_coords[:, dim])
        col_series = col_series.interpolate(method='linear', limit=Config.LINEAR_INTERP_LIMIT, limit_direction='both')
        morphed_coords[:, dim] = col_series.values
    
    nan_mask = np.isnan(morphed_coords[:, 0])
    if nan_mask.any():
        morphed_coords = np.nan_to_num(morphed_coords)

    return morphed_coords

def adaptive_rna_constraints(positions, rna_string, confidence=1.0):
    refined = positions.copy()
    n = len(rna_string)
    strength = 0.8 * (1.0 - min(confidence, 0.8))
    
    for i in range(n - 1):
        curr, next_pos = refined[i], refined[i+1]
        dist = np.linalg.norm(next_pos - curr)
        if dist < 5.5 or dist > 6.5:
            target = Config.BOND_DISTANCE_TARGET
            direc = (next_pos - curr)
            norm = np.linalg.norm(direc)
            if norm < 1e-10:
                direc = np.random.normal(0, 1, 3)
                direc /= np.linalg.norm(direc)
            else:
                direc = direc / norm
            adj = (target - dist) * strength
            refined[i+1] = curr + direc * (dist + adj)

    if n > 3:
        for i in range(1, n - 1):
            prev = refined[i] - refined[i-1]
            nxt = refined[i+1] - refined[i]
            norm_prev = np.linalg.norm(prev)
            norm_nxt = np.linalg.norm(nxt)
            if norm_prev > 1e-10 and norm_nxt > 1e-10:
                angle = np.arccos(np.dot(prev, nxt) / (norm_prev * norm_nxt))
                if angle > 2.5:
                    smoothed = (refined[i-1] + refined[i+1]) / 2
                    refined[i] = refined[i] * 0.3 + smoothed * 0.7

    dist_matrix = squareform(pdist(refined))
    clashes = (dist_matrix < Config.MIN_NONBOND_DISTANCE) & (dist_matrix > 0)
    for i in range(n):
        for j in range(i + 2, n):
            if clashes[i, j]:
                vec = refined[j] - refined[i]
                unit = vec / (np.linalg.norm(vec) + 1e-10)
                push = (Config.MIN_NONBOND_DISTANCE - dist_matrix[i, j]) * strength
                refined[i] -= unit * (push / 2)
                refined[j] += unit * (push / 2)

    if strength > 0.285:
        base_pairs = {'A': 'U', 'U': 'A', 'G': 'C', 'C': 'G'}
        for i in range(n):
            partner = base_pairs.get(rna_string[i])
            if not partner:
                continue
            for j in range(i + 3, min(i + 20, n)):
                if rna_string[j] == partner:
                    d = np.linalg.norm(refined[i] - refined[j])
                    if Config.BASE_PAIRING_DISTANCE_RANGE[0] < d < Config.BASE_PAIRING_DISTANCE_RANGE[1]:
                        ideal = Config.BASE_PAIRING_DISTANCE_IDEAL
                        delta = (ideal - d) * strength * 0.3
                        unit = (refined[j] - refined[i]) / (d + 1e-10)
                        refined[i] -= unit * (delta / 2)
                        refined[j] += unit * (delta / 2)
                        break

    return refined

def generate_basic_structure(sequence):
    n = len(sequence)
    coords = np.zeros((n, 3), dtype=np.float32)
    for i in range(n):
        angle = i * 0.6
        coords[i] = [
            10.0 * np.cos(angle),
            10.0 * np.sin(angle),
            i * 2.5
        ]
    return coords

def create_default_structure(seq):
    n = len(seq)
    coords = np.zeros((n, 3), dtype=np.float32)
    
    if n <= 10:
        for i in range(n):
            angle = i * 0.3
            coords[i] = [
                8.0 * np.cos(angle),
                8.0 * np.sin(angle),
                i * 3.0
            ]
        return coords
    
    base_pairs = {'G': 'C', 'C': 'G', 'A': 'U', 'U': 'A'}
    paired = {}
    
    for i in range(n):
        if i in paired:
            continue
        for j in range(i+4, min(i+15, n)):
            if j not in paired and base_pairs.get(seq[i]) == seq[j]:
                paired[i] = j
                paired[j] = i
                break
    
    current_pos = np.array([0.0, 0.0, 0.0])
    direction = np.array([1.0, 0.0, 0.0])
    
    for i in range(n):
        if i in paired and paired[i] > i:
            j = paired[i]
            if np.random.rand() < 0.6:
                vec = np.random.normal(0, 1, 3)
                vec /= np.linalg.norm(vec)
                coords[i] = current_pos + vec * 10.0
                coords[j] = current_pos - vec * 10.0
                current_pos = (coords[i] + coords[j]) / 2
                direction = np.random.normal(0, 1, 3)
                direction /= np.linalg.norm(direction)
            else:
                coords[i] = current_pos
                current_pos += direction * np.random.uniform(3.5, 4.5)
                coords[j] = current_pos
                current_pos += direction * np.random.uniform(3.5, 4.5)
        else:
            coords[i] = current_pos
            if np.random.rand() < 0.2:
                rot_angle = np.random.uniform(-0.4, 0.4)
                rot_axis = np.random.normal(0, 1, 3)
                rot_axis /= np.linalg.norm(rot_axis)
                rot = Rotation.from_rotvec(rot_angle * rot_axis)
                direction = rot.apply(direction)
            current_pos += direction * np.random.uniform(3.5, 4.5)
    
    return coords

def fabricate_rna_conformation(seq, seed=None):
    if seed is not None:
        np.random.seed(seed)
    
    n = len(seq)
    coords = np.zeros((n, 3), dtype=np.float32)
    base_pairs = {'G': 'C', 'C': 'G', 'A': 'U', 'U': 'A'}
    
    if n <= 3:
        for i in range(n):
            coords[i] = [i * 6.0, 0.0, 0.0]
        return coords
    
    start_helix = min(4, n // 2)
    for i in range(start_helix):
        angle = i * Config.HELIX_ANGLE_STEP * 0.8
        coords[i] = [
            Config.HELIX_RADIUS * 0.7 * np.cos(angle),
            Config.HELIX_RADIUS * 0.7 * np.sin(angle),
            i * Config.HELIX_RISE_PER_BASE * 1.2
        ]
    
    direction = np.array([0.0, 0.1, 0.995])
    direction /= np.linalg.norm(direction)
    
    for i in range(start_helix, n):
        current = seq[i]
        paired = False
        partner_idx = -1
        
        for j in range(max(0, i - 12), i):
            if seq[j] == base_pairs.get(current, None):
                paired = True
                partner_idx = j
                break
        
        if paired and (i - partner_idx <= 8) and (np.random.rand() < Config.PAIRING_PROB_THRESHOLD):
            partner_pos = coords[partner_idx]
            if i - partner_idx > 3:
                vec = coords[i-1] - coords[max(0, i-2)]
                if np.linalg.norm(vec) > 1e-6:
                    perp = np.cross(vec, np.array([0, 0, 1]))
                    perp /= np.linalg.norm(perp) + 1e-10
                    coords[i] = coords[i-1] + perp * (Config.BASE_PAIRING_DISTANCE_IDEAL * 0.6)
                else:
                    random_dir = np.random.normal(0, 1, 3)
                    random_dir /= np.linalg.norm(random_dir)
                    coords[i] = coords[i-1] + random_dir * (Config.BASE_PAIRING_DISTANCE_IDEAL * 0.6)
            else:
                center_vec = np.mean(coords[max(0, i-3):i], axis=0) - partner_pos
                if np.linalg.norm(center_vec) > 1e-6:
                    center_vec /= np.linalg.norm(center_vec)
                else:
                    center_vec = np.random.normal(0, 1, 3)
                    center_vec /= np.linalg.norm(center_vec)
                distance = Config.BASE_PAIRING_DISTANCE_IDEAL * 0.8
                coords[i] = partner_pos + center_vec * distance
            direction = np.random.normal(0, 0.2, 3)
            direction /= np.linalg.norm(direction) + 1e-10
        else:
            if np.random.rand() < 0.25:
                rot_angle = np.random.uniform(-0.3, 0.3)
                rot_axis = np.random.normal(0, 1, 3)
                rot_axis /= np.linalg.norm(rot_axis)
                rot = Rotation.from_rotvec(rot_angle * rot_axis)
                direction = rot.apply(direction)
            else:
                direction += np.random.normal(0, 0.1, 3)
                direction /= np.linalg.norm(direction) + 1e-10
            
            step = np.random.uniform(*Config.STEP_LENGTH_RANGE)
            coords[i] = coords[i - 1] + direction * step
    
    return coords

def find_similar_sequences(query_seq, train_seqs_df, train_coords_dict, temporal_cutoff=None, top_n=8):
    similar_seqs = []
    
    if temporal_cutoff:
        filtered_train_seqs = train_seqs_df[train_seqs_df['temporal_cutoff'] < temporal_cutoff]
    else:
        filtered_train_seqs = train_seqs_df
    
    for _, row in filtered_train_seqs.iterrows():
        target_id = row['target_id']
        train_seq = row['sequence']
        
        if target_id not in train_coords_dict:
            continue
        
        if abs(len(train_seq) - len(query_seq)) / max(len(train_seq), len(query_seq)) > Config.MAX_RELATIVE_LENGTH_DIFF:
            continue
        
        if not secondary_structure_filter(query_seq, train_seq):
            continue
        
        _, score = compute_alignment_score(query_seq, train_seq)
        similar_seqs.append((target_id, train_seq, score, train_coords_dict[target_id]))
    
    similar_seqs.sort(key=lambda x: x[2], reverse=True)
    return similar_seqs[:top_n]

def compute_structures(seq, identifier, template_df, struct_dict, num_outputs=5, date_limit=None):
    predictions = []

    similar = find_similar_sequences(seq, template_df, struct_dict, 
                                     temporal_cutoff=date_limit, top_n=Config.TOP_TEMPLATES)
    
    for tid, tseq, sim_score, tcoords in similar:
        morphed = morph_template_with_interpolation(seq, tseq, tcoords)
        refined = adaptive_rna_constraints(morphed, seq, confidence=sim_score)
        
        rand_scale = max(Config.NOISE_SCALE_MIN, Config.SIMILARITY_DECAY - sim_score)
        if sim_score > 0.5:
            rand_scale *= 0.5
        
        final = refined + np.random.normal(0, rand_scale, refined.shape)
        predictions.append(final)
        
        if len(predictions) >= num_outputs:
            break

    while len(predictions) < num_outputs:
        seed = (hash(identifier) % 10000) + len(predictions) * Config.RANDOM_SEED_OFFSET
        fake = fabricate_rna_conformation(seq, seed=seed)
        refined = adaptive_rna_constraints(fake, seq, confidence=Config.DEFAULT_RELIABILITY)
        
        scale_factor = np.random.uniform(0.8, 1.2)
        if len(predictions) > 0:
            centroid_pred = np.mean(predictions[0], axis=0)
            centroid_new = np.mean(refined, axis=0)
            refined = centroid_pred + (refined - centroid_new) * scale_factor
        
        predictions.append(refined)

    return predictions[:num_outputs]

train_structs = extract_structures(train_labels)
val_structs = extract_structures(val_labels)

prediction_records = []

for _, row in test.iterrows():
    seq_id = row['target_id']
    sequence = row['sequence']
    cutoff = row.get('temporal_cutoff', None)

    structs = compute_structures(
        sequence,
        seq_id,
        train,
        train_structs,
        num_outputs=Config.NUM_PREDICTIONS,
        date_limit=cutoff
    )

    for resid in range(len(sequence)):
        record = {
            'ID': f"{seq_id}_{resid + 1}",
            'resname': sequence[resid],
            'resid': resid + 1
        }
        for model_idx in range(Config.NUM_PREDICTIONS):
            record[f'x_{model_idx + 1}'] = structs[model_idx][resid][0]
            record[f'y_{model_idx + 1}'] = structs[model_idx][resid][1]
            record[f'z_{model_idx + 1}'] = structs[model_idx][resid][2]
        prediction_records.append(record)

submission = pd.DataFrame(prediction_records)
col_order = ['ID', 'resname', 'resid']
for m in range(1, Config.NUM_PREDICTIONS + 1):
    for axis in ['x', 'y', 'z']:
        col_order.append(f'{axis}_{m}')

submission = submission[col_order]
submission.to_csv('submission.csv', index=False)
submission.head()

  train_labels = pd.read_csv(f'{Config.DATA_PATH}/train_labels.csv')
  angle = np.arccos(np.dot(prev, nxt) / (norm_prev * norm_nxt))


Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5
0,8ZNQ_1,A,1,-2.947787,-19.630668,4.969795,-40.857909,-3.706493,8.339082,-25.113084,-9.520026,17.743116,193.425554,143.816444,155.25311,139.89916,184.437677,168.722252
1,8ZNQ_2,C,2,-1.274205,-18.2618,9.504327,-41.485874,-4.939674,14.464054,-25.259269,-5.241078,12.352264,187.592346,141.593698,155.539322,139.900093,184.456292,167.394753
2,8ZNQ_3,C,3,1.874546,-14.485601,11.598462,-37.804577,-5.55743,18.15152,-21.64378,-2.169601,9.018918,184.706497,138.554474,153.703002,140.37023,186.171554,166.603529
3,8ZNQ_4,G,4,5.445153,-9.419131,11.734588,-32.959959,-6.129994,19.713358,-16.456462,-0.808309,9.652559,181.066909,134.544364,150.625607,142.853686,185.641785,172.865775
4,8ZNQ_5,U,5,7.632892,-6.248647,7.750405,-26.517858,-7.015588,19.370432,-12.125967,-3.46504,10.4082,180.547876,132.14523,145.984978,144.811847,190.722206,176.433373
