In [1]:
# Model backend selector: the protenix cells further down only run when this equals 'protenix'.
MODEL_TYPE='protenix'
# Validation switch: gates the pip installs in the next cell.
# NOTE(review): a later cell assigns VALIDATION=False again, so a value set here
# only affects the cells that run before that reassignment.
VALIDATION=False

## Install requirements 

In [2]:
# !pip install --no-deps protenix
# !pip install biopython
# !pip install ml-collections
# !pip install biotite==1.0.1
# !pip install rdkit

if MODEL_TYPE=='protenix' and VALIDATION:
    !pip install --no-deps protenix
    !pip install biopython
    !pip install ml-collections
    !pip install biotite==1.0.1
    !pip install rdkit
!export PROTENIX_DATA_ROOT_DIR=/kaggle/input/protenix-checkpoints

In [3]:
! mkdir /af3-dev 
! ln -s /kaggle/input/protenix-checkpoints /af3-dev/release_data
! ls /af3-dev/release_data/

components.v20240608.cif		model_v0.2.0.pt
components.v20240608.cif.rdkit_mol.pkl


In [4]:
# NOTE(review): this re-assigns the VALIDATION flag already set in the first cell,
# so every cell below sees False regardless of the value chosen at the top.
# Confirm this override is intentional before relying on the top-level flag.
VALIDATION=False

## Helper scripts

In [5]:
import Bio

from copy import deepcopy

import pandas as pd
from Bio.PDB import Atom, Model, Chain, Residue, Structure, PDBParser
from Bio import SeqIO
import os, sys
import re
import numpy as np
import torch

import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
# Wall-clock start of the session; the validation loop compares time.time() against
# this value to stop before its time budget is exhausted.
time0=time.time()

print('IMPORT OK !!!!')

IMPORT OK !!!!


In [6]:
# Interpreter that runs this kernel (printed for reference).
PYTHON = sys.executable
print('PYTHON',PYTHON)

# RhoFold sources; outside Kaggle use '<your downloaded rhofold repo>/RhoFold-main'.
RHONET_DIR = '/kaggle/input/data-for-demo-for-rhofold-plus-with-kaggle-msa/RhoFold-main'

# USalign executable; outside Kaggle use '<your us align path>/USalign'.
USALIGN = '/kaggle/working//USalign'

# Copy the USalign binary into the working directory and mark it executable.
for shell_command in (
    'cp /kaggle/input/usalign/USalign /kaggle/working/',
    'sudo chmod u+x /kaggle/working//USalign',
):
    os.system(shell_command)

# Make the RhoFold directory importable.
sys.path.append(RHONET_DIR)


# Root of the competition dataset.
DATA_KAGGLE_DIR = '/kaggle/input/stanford-rna-3d-folding'


# helper ----
class dotdict(dict):
    """dict whose items can also be read, written and deleted as attributes.

    Reading a missing key through attribute access raises AttributeError
    (carrying the key name) instead of KeyError.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so fall back to items.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    # Attribute assignment / deletion map straight onto item assignment / deletion.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

# visualisation helper ----
def set_aspect_equal(ax):
    """Give the x, y and z axes of ``ax`` the same span.

    Each axis keeps its current midpoint; all three are resized to the widest
    of the current ranges.

    Args:
        ax: object exposing get_/set_ xlim, ylim and zlim.
    """
    current = [ax.get_xlim(), ax.get_ylim(), ax.get_zlim()]

    # Midpoint of every axis.
    centres = [np.mean(lim) for lim in current]

    # Half of the widest range among the three axes.
    half_span = max(lim[1] - lim[0] for lim in current) / 2.0

    # Re-centre each axis on its midpoint with the common half-span.
    setters = (ax.set_xlim, ax.set_ylim, ax.set_zlim)
    for apply_limits, centre in zip(setters, centres):
        apply_limits(centre - half_span, centre + half_span)




# xyz df helper --------------------
def get_truth_df(target_id):
    """Return the rows of the global LABEL_DF for one target, re-indexed from 0.

    Args:
        target_id: value matched against the 'target_id' column of LABEL_DF.

    Returns:
        DataFrame with the matching rows and a fresh 0-based index.
    """
    is_target = LABEL_DF['target_id'] == target_id
    return LABEL_DF[is_target].reset_index(drop=True)

def parse_output_to_df(output, seq, target_id):
    """Turn predicted coordinates into a one-row-per-residue DataFrame.

    Args:
        output: array-like indexed as output[n, i, k], where n is the prediction,
            i the residue position and k in {0, 1, 2} the x/y/z component; each
            element must support ``.item()``.
        seq: iterable of residue names (one per position).
        target_id: value stored in the 'ID' column of every row.

    Returns:
        A list holding a single DataFrame with columns ID, resname, resid and
        x_k, y_k, z_k for every prediction k (1-based, rounded to 3 decimals),
        or an empty list when ``seq`` is empty.
    """
    rows = []
    for position, base in enumerate(seq):
        record = {'ID': target_id, 'resname': base, 'resid': position + 1}
        for sample in range(len(output)):
            for component, axis in enumerate('xyz'):
                value = output[sample, position, component].item()
                record[f'{axis}_{sample+1}'] = round(value, 3)
        rows.append(record)

    if not rows:
        return []
    return [pd.DataFrame(rows)]

def parse_pdb_to_df(pdb_file, target_id):
    """Extract C1' coordinates from a PDB file, one DataFrame per chain.

    Only residues named 'A', 'U', 'G' or 'C' that contain a C1' atom are kept.
    Each chain object is printed as it is visited.

    Args:
        pdb_file: path or handle accepted by Bio.PDB.PDBParser.get_structure.
        target_id: prefix used to build the 'ID' column as '<target_id>_<resid>'.

    Returns:
        List of DataFrames (columns ID, resname, resid, x_1, y_1, z_1), one for
        every chain of every model that yielded at least one row.
    """
    structure = PDBParser().get_structure('', pdb_file)
    accepted_names = ['A', 'U', 'G', 'C']

    frames = []
    for model in structure:
        for chain in model:
            print(chain)
            records = []
            for residue in chain:
                if residue.get_resname() not in accepted_names:
                    continue
                # Skip residues that lack a C1' atom.
                if "C1'" not in residue:
                    continue

                coord = residue["C1'"].get_coord()
                name = residue.get_resname()
                number = residue.get_id()[1]

                # todo: detect discontinuous numbering (resid != prev_resid + 1)
                # Row layout: ID  resname  resid  x_1  y_1  z_1
                records.append({
                    'ID': target_id+'_'+str(number),
                    'resname': name,
                    'resid': number,
                    'x_1': coord[0],
                    'y_1': coord[1],
                    'z_1': coord[2],
                })

            if records:
                frames.append(pd.DataFrame(records))
    return frames

# usalign helper --------------------
def write_target_line(
    atom_name, atom_serial, residue_name, chain_id, residue_num, x_coord, y_coord, z_coord, occupancy=1.0, b_factor=0.0, atom_type='P'
):
    """
    Writes a single line of PDB format based on provided atom information.

    The residue number is written right-aligned in a 4-character field that
    absorbs the blank which used to precede a 3-character field. Output is
    therefore byte-identical for residue numbers below 1000, while numbers up
    to 9999 no longer push the coordinate columns one position to the right.

    Args:
        atom_name (str): Name of the atom (e.g., "N", "CA").
        atom_serial (int): Atom serial number.
        residue_name (str): Residue name (e.g., "ALA").
        chain_id (str): Chain identifier. Accepted for interface compatibility
            but not written to the line.
        residue_num (int): Residue number.
        x_coord (float): X coordinate.
        y_coord (float): Y coordinate.
        z_coord (float): Z coordinate.
        occupancy (float, optional): Occupancy value (default: 1.0).
        b_factor (float, optional): B-factor value (default: 0.0).
        atom_type (str, optional): Text written at the end of the line
            (default: 'P').

    Returns:
        str: A single line of PDB string, terminated by a newline.
    """
    return (
        f'ATOM  {atom_serial:>5d}  {atom_name:<5s} {residue_name:<3s}{residue_num:>4d}    '
        f'{x_coord:>8.3f}{y_coord:>8.3f}{z_coord:>8.3f}{occupancy:>6.2f}{b_factor:>6.2f}           {atom_type}\n'
    )

def write_xyz_to_pdb(df, pdb_file, xyz_id = 1):
    """Write one C1' ATOM line per resolved row of ``df`` into ``pdb_file``.

    A row counts as resolved when its x, y and z values for the chosen
    coordinate set are all greater than -1e17 (NaN fails that comparison and
    is skipped as well).

    Args:
        df: DataFrame with 'resid', 'resname' and x_/y_/z_<xyz_id> columns.
        pdb_file: path of the file to (over)write.
        xyz_id: suffix selecting which coordinate set to export.

    Returns:
        int: number of ATOM lines written.
    """
    coord_columns = [f'{axis}_{xyz_id}' for axis in 'xyz']
    written = 0
    with open(pdb_file, 'w') as handle:
        for _, record in df.iterrows():
            x_val, y_val, z_val = (record[column] for column in coord_columns)

            resolved = x_val > -1e17 and y_val > -1e17 and z_val > -1e17
            if not resolved:
                continue

            written += 1
            residue_number = int(record['resid'])
            handle.write(write_target_line(
                atom_name="C1'",
                atom_serial=residue_number,
                residue_name=record['resname'],
                chain_id='0',
                residue_num=residue_number,
                x_coord=x_val,
                y_coord=y_val,
                z_coord=z_val,
                atom_type='C',
            ))
    return written

def parse_usalign_for_tm_score(output):
    """Return the second 'TM-score=' value found in USalign's text output.

    The second reported score is the one used here (per the original comment:
    the score based on the length of the reference, i.e. second, structure).

    Args:
        output (str): text captured from a USalign run.

    Returns:
        float: the second TM-score in ``output``.

    Raises:
        ValueError: if ``output`` contains fewer than two TM-score entries
            (previously this surfaced as an IndexError before the check ran).
    """
    scores = re.findall(r'TM-score=\s+([\d.]+)', output)
    # Validate before indexing so that missing scores give the intended error.
    if len(scores) < 2:
        raise ValueError('No TM score found')
    return float(scores[1])

def parse_usalign_for_transform(output):
    """Extract the numeric rows of the rotation-matrix section of USalign output.

    Scanning starts after the line containing the matrix header and stops at
    the first blank line that follows it. Every row that consists of an integer
    index followed by four numeric fields contributes those four numbers.

    Args:
        output (str): text captured from a USalign run.

    Returns:
        numpy.ndarray: one row of four floats per matrix line (index column
        dropped); an empty array when no such rows are found.
    """
    header = "The rotation matrix to rotate Structure_1 to Structure_2"
    row_pattern = re.compile(r'^\d+\s+[-\d.]+\s+[-\d.]+\s+[-\d.]+\s+[-\d.]+$')

    rows = []
    inside_matrix = False
    for text in output.splitlines():
        if header in text:
            inside_matrix = True
            continue
        if not inside_matrix:
            continue
        if row_pattern.match(text):
            # Drop the leading row index, keep the four numeric fields.
            rows.append([float(token) for token in text.split()[1:]])
        elif not text.strip():
            # A blank line after the header ends the matrix section.
            break

    return np.array(rows)

def call_usalign(predict_df, truth_df, verbose=1):
    """Run USalign on a predicted and a reference coordinate table.

    Both tables are written (coordinate set 1 only) to fixed-name PDB files in
    the current directory, the USALIGN executable is run on them, and its text
    output is parsed.

    Args:
        predict_df: predicted coordinates, as accepted by write_xyz_to_pdb.
        truth_df: reference coordinates, as accepted by write_xyz_to_pdb.
        verbose: when equal to 1, the raw USalign output is printed.

    Returns:
        tuple: (tm_score, transform) as produced by parse_usalign_for_tm_score
        and parse_usalign_for_transform.
    """
    predict_pdb = '~predict.pdb'
    truth_pdb = '~truth.pdb'
    for frame, path in ((predict_df, predict_pdb), (truth_df, truth_pdb)):
        write_xyz_to_pdb(frame, path, xyz_id=1)

    command = f'{USALIGN} {predict_pdb} {truth_pdb} -atom " C1\'" -m -'
    with os.popen(command) as pipe:
        output = pipe.read()

    if verbose==1:
        print(output)

    return parse_usalign_for_tm_score(output), parse_usalign_for_transform(output)

# Marker printed once all helper definitions of this cell have been executed.
print('HELPER OK!!!')

PYTHON /usr/bin/python3
HELPER OK!!!


In [7]:
if MODEL_TYPE=='protenix':

    from runner.batch_inference import get_default_runner
    from runner.inference import update_inference_configs, InferenceRunner

    from protenix.data.infer_data_pipeline import InferenceDataset

    # Seed numpy and torch (CPU and all CUDA devices) with the same fixed value.
    np.random.seed(101)
    torch.random.manual_seed(101)
    torch.cuda.manual_seed_all(101)

    class DictDataset(InferenceDataset):
        """InferenceDataset whose inputs are built in memory from RNA sequences.

        ``__init__`` does not call the parent constructor; it only sets
        ``dump_dir``, ``use_msa`` and ``inputs``. Every input describes one RNA
        chain with count 1 and is named either "query" (no ids given) or by the
        matching entry of ``id_list``.
        """

        def __init__(
            self,
            seq_list: list,
            dump_dir: str,
            id_list: list = None,
            use_msa: bool = False,
        ) -> None:

            def build_entry(name, sequence):
                # One single-copy RNA chain per input record.
                return {"sequences":
                        [{"rnaSequence":
                          {"sequence": sequence,
                           "count": 1}}],
                        "name": name}

            self.dump_dir = dump_dir
            self.use_msa = use_msa
            if id_list is None:
                self.inputs = [build_entry("query", seq) for seq in seq_list]
            else:
                self.inputs = [build_entry(name, seq) for name, seq in zip(id_list, seq_list)]

In [8]:
if MODEL_TYPE=='protenix':

    from configs.configs_base import configs as configs_base
    from configs.configs_data import data_configs
    from configs.configs_inference import inference_configs
    from protenix.config.config import parse_configs

    # Enabled only when the environment variable is exactly the string "true".
    deepspeed_flag = os.environ.get("USE_DEEPSPEED_EVO_ATTENTION", False)
    configs_base["use_deepspeed_evo_attention"] = (deepspeed_flag == "true")

    configs_base["model"]["N_cycle"] = 10 #10
    # One diffusion sample per target when validating, eight otherwise.
    if VALIDATION:
        configs_base["sample_diffusion"]["N_sample"] = 1
    else:
        configs_base["sample_diffusion"]["N_sample"] = 8
    configs_base["sample_diffusion"]["N_step"] = 200

    inference_configs['load_checkpoint_path']='/kaggle/input/protenix-checkpoints/model_v0.2.0.pt'
    # inference_configs['load_checkpoint_path']='/kaggle/input/protenix_version1/pytorch/default/1/model_v1.pt'

    # Later entries win on key clashes: base < data < inference.
    configs = parse_configs(
            configs={**configs_base, "data": data_configs, **inference_configs},
            fill_required_with_null=True,
        )

    runner=InferenceRunner(configs)


train scheduler 16.0
inference scheduler 16.0
Diffusion Module has 16.0


  checkpoint = torch.load(checkpoint_path, self.device)


In [9]:
if VALIDATION:
    # Training sequences and their labels from the competition dataset.
    train_df = pd.read_csv('/kaggle/input/stanford-rna-3d-folding/train_sequences.csv')
    LABEL_DF = pd.read_csv('/kaggle/input/stanford-rna-3d-folding/train_labels.csv')
    # target_id is the label ID with its last '_'-separated field removed.
    LABEL_DF['target_id'] = ['_'.join(label_id.split('_')[:-1]) for label_id in LABEL_DF['ID']]


In [10]:

if MODEL_TYPE=='protenix' and VALIDATION:
    import warnings
    warnings.filterwarnings("ignore")

    # One TM-score per training target; None means "not scored".
    train_df['protenix_tm_score']=None
    dataset = DictDataset(train_df.sequence, dump_dir='output', id_list=train_df.target_id, use_msa=False)
    num_data = len(dataset)
    for i, seq in tqdm(enumerate(train_df.sequence),total=num_data):
        # Skip targets that already carry a score.
        if train_df.loc[i,'protenix_tm_score'] is not None:
            continue
        # Skip sequences longer than 300 residues.
        if len(seq)>300:
            continue
        target_id = train_df.loc[i,'target_id']
        truth_df = get_truth_df(target_id)
        # Need at least 3 non-NaN reference x coordinates to score against.
        if sum(~np.isnan(truth_df.x_1))<3:
            continue
        data, atom_array, data_error_message=dataset[i]
        if data_error_message!='':
            continue
        new_configs = update_inference_configs(configs, data["N_token"].item())
        runner.update_model_configs(new_configs)
        prediction = runner.predict(data)
        # Keep only the atoms whose per-token atom index is 12.
        # NOTE(review): presumably index 12 is the C1' atom of RNA residues - confirm.
        prediction=prediction['coordinate'][:,data['input_feature_dict']['atom_to_tokatom_idx']==12]
        result = parse_output_to_df(prediction[:1], seq, target_id)[0]
        try:
            tm_score, _ = call_usalign(result, truth_df, verbose=0)
            train_df.loc[i,'protenix_tm_score']=tm_score
        except Exception as err:
            # Best effort: leave the score as None, but report why scoring failed.
            # Catching Exception (not a bare except) lets KeyboardInterrupt through.
            print('Failed to score', target_id, repr(err))
        # Stop 360 s before the 12 h mark measured from time0.
        if (time.time()-time0)>(12*3600-360):
            break
    train_df.to_csv('tm_scores.csv', index=False)
    print(train_df.protenix_tm_score.mean())
    display(train_df.protenix_tm_score.hist())

In [11]:
if MODEL_TYPE=='protenix' and not VALIDATION:
    test_df=pd.read_csv('/kaggle/input/stanford-rna-3d-folding/test_sequences.csv')
    import warnings
    warnings.filterwarnings("ignore")

    # Rows are appended target by target below, so remove any file left by an
    # earlier run; otherwise a re-run would duplicate rows under a stale header.
    if os.path.exists('submission.csv'):
        os.remove('submission.csv')

    # The fallback frame must have the same columns as a successful prediction,
    # which carries one x/y/z triple per diffusion sample.
    # NOTE(review): the previous fallback hard-coded 5 triples; presumably the
    # submission format expects x_1..z_5 - confirm whether N_sample should be 5.
    n_sample = configs_base["sample_diffusion"]["N_sample"]
    coord_columns = [f'{axis}_{k+1}' for k in range(n_sample) for axis in 'xyz']

    dataset = DictDataset(test_df.sequence, dump_dir='output', id_list=test_df.target_id, use_msa=False)
    num_data = len(dataset)
    for i, seq in tqdm(enumerate(test_df.sequence),total=num_data):
        # Resolve the id up front so the failure path never sees a stale or
        # undefined target_id (the old code compared with `==` instead of assigning).
        target_id = test_df.target_id[i]
        try:
            data, atom_array, data_error_message=dataset[i]
            assert data["sample_name"]==target_id
            assert data_error_message==''

            new_configs = update_inference_configs(configs, data["N_token"].item())
            runner.update_model_configs(new_configs)
            prediction = runner.predict(data)
            # Keep only the atoms whose per-token atom index is 12.
            prediction=prediction['coordinate'][:,data['input_feature_dict']['atom_to_tokatom_idx']==12]

            result = parse_output_to_df(prediction, seq, target_id)[0]
        except Exception as err:
            # Best effort: emit all-zero coordinates for this target and continue.
            # Catching Exception (not a bare except) lets KeyboardInterrupt through.
            print('Failed to predict', target_id, repr(err))
            result=pd.DataFrame(columns=['ID', 'resname', 'resid'] + coord_columns,
                                data=[[target_id, x, j+1] + [0.0]*len(coord_columns) for j, x in enumerate(seq)])

        # Final ID format is '<target_id>_<resid>'.
        result['ID']=result.apply(lambda x: x.ID + '_' + str(x.resid), axis=1)
        result.to_csv('submission.csv', index=False, mode='a', header=(i==0))
        torch.cuda.empty_cache()

    display(pd.read_csv('submission.csv'))

100%|██████████| 12/12 [50:41<00:00, 253.44s/it]


Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,...,z_5,x_6,y_6,z_6,x_7,y_7,z_7,x_8,y_8,z_8
0,R1107_1,G,1,7.400,6.808,1.849,4.616,9.902,-10.451,-9.121,...,4.893,3.818,-0.547,-5.254,-14.966,-14.775,-5.534,11.790,20.536,-0.589
1,R1107_2,G,2,9.115,11.626,-0.550,-0.489,8.299,-8.796,-13.908,...,8.900,8.644,0.394,-2.605,-11.392,-15.373,-9.747,7.360,18.574,2.145
2,R1107_3,G,3,8.500,15.795,-4.381,3.713,11.512,-4.159,-18.564,...,12.557,11.843,3.373,1.030,-8.005,-13.510,-14.000,7.790,14.946,9.019
3,R1107_4,G,4,5.493,18.587,-8.231,5.691,14.516,0.812,-21.742,...,14.770,12.487,7.101,5.070,-5.137,-9.594,-16.724,4.504,12.245,12.930
4,R1107_5,G,5,0.858,20.087,-10.997,3.365,16.921,5.649,-23.007,...,15.017,11.002,10.632,9.079,-2.891,-4.513,-17.999,-0.289,9.715,14.933
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2510,R1190_114,U,114,-22.989,-34.172,-7.109,-21.409,6.637,4.558,-0.316,...,-0.834,8.242,17.833,-0.169,-18.456,10.881,5.151,5.050,0.162,19.284
2511,R1190_115,U,115,-18.757,-32.785,-9.543,-20.557,1.653,5.169,-2.636,...,3.981,12.918,15.938,-0.826,-21.713,12.362,8.910,1.294,-2.727,21.416
2512,R1190_116,U,116,-14.818,-30.996,-13.506,-20.316,-3.463,7.504,-5.759,...,8.595,18.179,15.194,-2.694,-25.832,15.125,11.385,-1.509,-6.356,24.785
2513,R1190_117,U,117,-12.855,-27.164,-16.857,-22.443,-7.735,10.280,-6.988,...,11.141,23.201,16.895,-3.710,-30.704,17.496,11.328,-2.769,-8.394,29.684
