In [None]:
import pandas as pd
import json
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
import random
import string
import rdkit
import yaml

In [None]:
def prepare_af3_input(system_id, data, json_dir, model_seeds=None):
    """Write an AlphaFold3 input JSON (``{json_dir}/{system_id}.json``) for one system.

    Every protein sequence becomes a ``protein`` entity and every SMILES string a
    ``ligand`` entity. Chain IDs are single uppercase letters assigned in order:
    proteins first ("A", "B", ...), then ligands continue with the following letters.

    Parameters
    ----------
    system_id : str
        Job ``name`` in the JSON and stem of the output file name.
    data : dict
        Must contain ``"sequences"`` (mapping whose values are protein sequences;
        the keys are not used) and ``"smiles"`` (iterable of SMILES strings).
        Any ``"ccd_codes"`` entry is ignored.
    json_dir : str
        Existing directory the JSON file is written to.
    model_seeds : list of int, optional
        Values for ``modelSeeds``. When omitted, five seeds are drawn from the
        module-level ``random`` generator, so call ``random.seed`` first if the
        file must be reproducible.

    Raises
    ------
    KeyError
        If ``data`` lacks ``"sequences"`` or ``"smiles"``.
    IndexError
        If the system has more than 26 entities (single-letter IDs run out).
    """
    if model_seeds is None:
        model_seeds = [random.randint(0, 2**32 - 1) for _ in range(5)]

    letters = string.ascii_uppercase
    protein_seqs = list(data["sequences"].values())

    sequences_list = [
        {"protein": {"id": letters[i], "sequence": seq}}
        for i, seq in enumerate(protein_seqs)
    ]
    # Ligands take the letters after the protein chains. Iterate over the SMILES
    # alone: zipping with data["ccd_codes"] would silently drop ligands whenever
    # that list is shorter, and the CCD codes are not written to the AF3 JSON.
    sequences_list.extend(
        {"ligand": {"id": letters[len(protein_seqs) + j], "smiles": smiles}}
        for j, smiles in enumerate(data["smiles"])
    )

    system_dict = {
        "name": system_id,
        "dialect": "alphafold3",
        "version": 1,
        "modelSeeds": list(model_seeds),
        "sequences": sequences_list,
    }

    output_json_path = os.path.join(json_dir, f"{system_id}.json")
    with open(output_json_path, 'w') as f:
        json.dump(system_dict, f, indent=2)

In [None]:
def prepare_chai_input(system_id, data, fasta_dir):
    """Write a Chai-1 input FASTA (``{fasta_dir}/{system_id}.fasta``) for one system.

    Each entity is written as a header line ``>protein|name=X`` or
    ``>ligand|name=X`` followed by its sequence or SMILES string. Names are single
    uppercase letters: proteins first ("A", "B", ...), then ligands continue with
    the following letters. The file ends with a trailing newline.

    Parameters
    ----------
    system_id : str
        Stem of the output file name.
    data : dict
        Must contain ``"sequences"`` (mapping whose values are protein sequences;
        the keys are not used) and ``"smiles"`` (iterable of SMILES strings).
        Any ``"ccd_codes"`` entry is ignored.
    fasta_dir : str
        Existing directory the FASTA file is written to.

    Raises
    ------
    KeyError
        If ``data`` lacks ``"sequences"`` or ``"smiles"``.
    IndexError
        If the system has more than 26 entities (single-letter names run out).
    """
    letters = string.ascii_uppercase
    protein_seqs = list(data["sequences"].values())

    lines = []
    for i, seq in enumerate(protein_seqs):
        lines += [f">protein|name={letters[i]}", seq]

    # Iterate over the SMILES alone: zipping with data["ccd_codes"] would silently
    # drop ligands whenever that list is shorter, and CCD codes are not written here.
    for j, smiles in enumerate(data["smiles"]):
        lines += [f">ligand|name={letters[len(protein_seqs) + j]}", smiles]

    output_fasta_path = os.path.join(fasta_dir, f"{system_id}.fasta")
    with open(output_fasta_path, 'w') as f:
        f.write("\n".join(lines) + "\n")

In [None]:
def prepare_protenix_input(system_id, data, json_dir, msa_dir):
    """Write a Protenix input JSON (``{json_dir}/{system_id}.json``) for one system.

    The file holds a one-element list with a job ``name`` and its ``sequences``.

    - Every protein becomes a ``proteinChain`` with ``count`` 1.
    - Its precomputed MSA directory is
      ``{msa_dir}/{system_id.lower()}/{letter}``, where letter is "A", "B", ...
      in the order the sequences are iterated.
    - Its pairing database is ``"uniprot"``.
    - Every SMILES string becomes a ``ligand`` entry with ``count`` 1.

    Parameters
    ----------
    system_id : str
        Job name and output file stem; lower-cased for the MSA path.
    data : dict
        Must contain ``"sequences"`` (mapping whose values are protein sequences)
        and ``"smiles"`` (iterable of SMILES strings).
    json_dir : str
        Existing directory the JSON file is written to.
    msa_dir : str
        Root directory of the precomputed MSAs.
    """
    entities = [
        {
            "proteinChain": {
                "count": 1,
                "sequence": seq,
                "msa": {
                    "precomputed_msa_dir": os.path.join(
                        msa_dir, system_id.lower(), string.ascii_uppercase[pos]
                    ),
                    "pairing_db": "uniprot",
                },
            }
        }
        for pos, seq in enumerate(data["sequences"].values())
    ]
    entities.extend(
        {"ligand": {"ligand": smi, "count": 1}} for smi in data["smiles"]
    )

    # Protenix expects a list of jobs, even when there is only one.
    payload = [{"name": system_id, "sequences": entities}]

    destination = os.path.join(json_dir, f"{system_id}.json")
    with open(destination, 'w') as fh:
        json.dump(payload, fh, indent=2)

In [None]:
def prepare_boltz_input(system_id, data, yaml_dir, msa_dir):
    """Write a Boltz input YAML (``{yaml_dir}/{system_id}.yaml``) for one system.

    Chain letters are assigned in order: proteins first ("A", "B", ...), then
    ligands continue with the following letters. Identical protein sequences, and
    identical SMILES strings, are merged into one entity whose ``id`` is the
    sorted list of all chain letters that share it.

    For each protein the MSA is looked up at
    ``{msa_dir}/{system_id.lower()}/{sequence_id}.csv``. Here ``sequence_id`` is
    the key in ``data["sequences"]``. When the file exists, its absolute path is
    written as ``msa``. When no file exists for a sequence, the ``msa`` key is
    omitted, so Boltz must then be run with ``--use_msa_server``, or the entry
    edited to ``msa: empty`` for single-sequence mode.

    Parameters
    ----------
    system_id : str
        Output file stem; lower-cased for the MSA path.
    data : dict
        Must contain ``"sequences"`` (mapping sequence_id -> protein sequence) and
        ``"smiles"`` (iterable of SMILES strings).
    yaml_dir : str
        Existing directory the YAML file is written to.
    msa_dir : str
        Root directory of the per-system MSA CSV files.

    Raises
    ------
    IndexError
        If the system has more than 26 chains (single-letter IDs run out).
    """
    letters = string.ascii_uppercase

    query_sequences = {}
    msa_files = {}
    for chain_idx, (sequence_id, sequence) in enumerate(data['sequences'].items()):
        query_sequences.setdefault(sequence, []).append(letters[chain_idx])

        csv_filename = os.path.join(msa_dir, system_id.lower(), f'{sequence_id}.csv')
        if os.path.isfile(csv_filename):
            msa_files[sequence] = os.path.abspath(csv_filename)

    n_protein_chains = len(data['sequences'])

    query_ligands = {}
    for i, smiles in enumerate(data['smiles']):
        query_ligands.setdefault(smiles, []).append(letters[n_protein_chains + i])

    config = {'sequences': []}

    for sequence, chains in query_sequences.items():
        protein = {'id': sorted(chains), 'sequence': sequence}
        # Indexing msa_files unconditionally raised KeyError for any protein
        # without a CSV; only emit 'msa' when a file was actually found.
        if sequence in msa_files:
            protein['msa'] = msa_files[sequence]
        config['sequences'].append({'protein': protein})

    for smiles, chains in query_ligands.items():
        config['sequences'].append({'ligand': {'id': sorted(chains), 'smiles': smiles}})

    output_path = os.path.join(yaml_dir, f"{system_id}.yaml")
    with open(output_path, 'w') as w:
        # width/default_flow_style keep long sequences on one line and lists inline.
        yaml.dump(config, w, sort_keys=False, width=5000, default_flow_style=None)

In [None]:
# Load every system description, then write AF3 / Chai / Protenix / Boltz inputs
# for a single example system.
input_json = "data/inputs.json"
msa_dir = "data/msa_files"  # MSAs for the full dataset; the example uses example_msa_dir

with open(input_json, 'r') as f:
    input_data = json.load(f)

af3_input_dir = "examples/inputs/af3"
chai_input_dir = "examples/inputs/chai"
protenix_input_dir = "examples/inputs/protenix"
boltz_input_dir = "examples/inputs/boltz"
example_msa_dir = "examples/inputs/msa_files"

for output_dir in (af3_input_dir, chai_input_dir, protenix_input_dir, boltz_input_dir):
    os.makedirs(output_dir, exist_ok=True)

# prepare_af3_input draws its modelSeeds from the global `random` generator;
# seed it so re-running the notebook regenerates identical AF3 inputs.
random.seed(0)

example_id = "8c3u__1__1.A__1.C"
# Look the example up directly rather than scanning every system; fail loudly
# instead of silently writing nothing if it is missing from the input file.
if example_id not in input_data:
    raise KeyError(f"Example system {example_id!r} not found in {input_json}")
example_data = input_data[example_id]

prepare_af3_input(example_id, example_data, af3_input_dir)
prepare_chai_input(example_id, example_data, chai_input_dir)
prepare_protenix_input(example_id, example_data, protenix_input_dir, example_msa_dir)
prepare_boltz_input(example_id, example_data, boltz_input_dir, example_msa_dir)