Skip to content

Commit

Permalink
Feat: analysis radical migration tool (#388)
Browse files Browse the repository at this point in the history
* initial setup

* black

* wip

* black

* finish first version

* add pymol file

* address review

* move script

---------

Co-authored-by: Eric Hartmann <hartmaec@rh05659.villa-bosch.de>
  • Loading branch information
ehhartmann and Eric Hartmann committed Feb 29, 2024
1 parent 0bc5a68 commit ba4cb0f
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 4 deletions.
4 changes: 2 additions & 2 deletions example/charged_peptide_homolysis_hat_naive/kimmdy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: 'kimmdy_001'
dryrun: false
max_tasks: 100
gromacs_alias: 'gmx'
gmx_mdrun_flags: -maxh 24 -dlb yes -nt 8
gmx_mdrun_flags: -maxh 24 -dlb yes
top: 'IMREE.top'
gro: 'IMREE_npt.gro'
ndx: 'index.ndx'
Expand Down Expand Up @@ -37,7 +37,7 @@ sequence:
- pull
- homolysis
-
mult: 2
mult: 5
tasks:
- equilibrium
- pull
Expand Down
2 changes: 1 addition & 1 deletion example/charged_peptide_homolysis_hat_naive/plumed.dat
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ d8: DISTANCE ATOMS=82,84
d9: DISTANCE ATOMS=84,95

#Print distances ARG to FILE every STRIDE steps
PRINT ARG=d0,d1,d2,d3,d4,d5,d6,d7,d8,d9, STRIDE=100 FILE=distances.dat
PRINT ARG=d0,d1,d2,d3,d4,d5,d6,d7,d8,d9, STRIDE=500 FILE=distances.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import MDAnalysis as mda
from pymol import cmd
from pymol import cgo
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import json


# %%
def get_inbetween(p1, p2, val):
    """Return the point a fraction ``val`` of the way from ``p1`` to ``p2``.

    Parameters
    ----------
    p1
        Start point; a numpy array or any array-like (list, tuple) of
        coordinates.
    p2
        End point, with a shape compatible with ``p1``.
    val
        Interpolation factor. 0 yields ``p1``, 1 yields ``p2``; values
        outside [0, 1] extrapolate along the same line.

    Returns
    -------
    The interpolated point as a numpy array, so callers can always use
    ``.tolist()`` on the result.
    """
    # Coerce to arrays so plain lists/tuples work as well; inputs that are
    # already ndarrays are passed through unchanged (no copy, same dtype).
    start = np.asarray(p1)
    end = np.asarray(p2)
    return start + val * (end - start)


# %%
# User settings.
# interactions_path: json file mapping "<atom id 1>_<atom id 2>" to a dict
#                    with the keys "count" and "max_rate".
# geometry_path:     structure file the atom ids refer to.
# do_arrow:          draw arrows (cylinder + cone) instead of plain cylinders.
# analysis_type:     one of "quantitative", "max_rate" or "count".
# manual_max_count:  upper bound of the colour scale for analysis_type
#                    "count"; set to 0 or None to use the largest count found.
interactions_path = Path("radical_migration_dopa_merged.json")
geometry_path = Path("dopa.gro")
do_arrow = True
analysis_type = "count"
manual_max_count = 40


def _normalize(values, lower, upper):
    """Scale ``values`` linearly so that ``lower`` maps to 0 and ``upper`` to 1."""
    span = upper - lower
    if span == 0:
        # All values identical: avoid a 0/0 division (which would produce NaN
        # colours) and put every interaction at the top of the colour scale.
        return np.ones_like(values, dtype=float)
    return (values - lower) / span


# %%
u = mda.Universe(geometry_path.as_posix(), guess_bonds=True)

with open(interactions_path, "r") as json_file:
    interactions = json.load(json_file)

# %%
# Build one RGB triple per interaction, in the iteration order of the json.
if analysis_type == "quantitative":
    # Every interaction gets the same fixed colour.
    rgb = np.tile(np.asarray([0.62, 0.59, 0.98]), (len(interactions), 1))
elif analysis_type == "max_rate":
    max_rates = np.array([entry["max_rate"] for entry in interactions.values()])
    log_max_rates = np.log(max_rates)

    # Normalize the log values to the range [0, 1]
    norm_log_max_rates = _normalize(
        log_max_rates, np.min(log_max_rates), np.max(log_max_rates)
    )

    # Use the purple sequential colormap
    cmap = plt.get_cmap("Purples")

    # Get RGB values for the normalized log max_rates
    rgb = [np.asarray(cmap(value)[:3]) for value in norm_log_max_rates]
elif analysis_type == "count":
    print("Selected analysis by count of reaction occurence.")

    counts = np.array([entry["count"] for entry in interactions.values()])

    # Normalize the counts to the range [0, 1]. A manually chosen maximum
    # takes precedence; otherwise fall back to the largest observed count.
    min_count = np.min(counts)
    max_count = manual_max_count if manual_max_count else np.max(counts)

    norm_counts = _normalize(counts, min_count, max_count)
    print(f"min count: {min_count}, max count: {max_count}")
    print(norm_counts)

    # Use the blue sequential colormap
    cmap = plt.get_cmap("Blues")

    # Get RGB values for the normalized counts
    rgb = [np.asarray(cmap(value)[:3]) for value in norm_counts]
else:
    raise ValueError(
        f"Unknown analysis_type '{analysis_type}'. "
        "Available are 'quantitative', 'max_rate' and 'count'."
    )

# %%
# Draw one CGO object per interaction between the two atoms named in its key.
radius = 0.12
for counter, k in enumerate(interactions):
    print(f"Building object for atoms {k}")
    atom_1_nr, atom_2_nr = k.split(sep="_")

    a1 = u.select_atoms(f"id {atom_1_nr}")[0].position
    a2 = u.select_atoms(f"id {atom_2_nr}")[0].position

    color = rgb[counter].tolist()

    # The drawn object is kept slightly shorter than the atom-atom distance.
    p1 = get_inbetween(a1, a2, 0.1)
    p2 = get_inbetween(a1, a2, 0.9)

    if do_arrow:
        # Arrow = shaft (cylinder) up to p_cap followed by a cone whose tip
        # sits at p2, i.e. pointing from the first to the second atom.
        p1_far = get_inbetween(a1, a2, 0.18)
        p_cap = get_inbetween(a1, a2, 0.75)

        cylinder = (
            [cgo.CYLINDER]
            + p1_far.tolist()
            + p_cap.tolist()
            + [radius]
            + color
            + color
        )
        head = (
            [cgo.CONE]
            + p_cap.tolist()
            + p2.tolist()
            + [radius * 2, 0.0]
            + color
            + color
            + [1.0, 0.0]
        )
        shape = cylinder + head
    else:
        # Without arrows, draw a plain cylinder between the shortened ends.
        shape = (
            [cgo.CYLINDER]
            + p1.tolist()
            + p2.tolist()
            + [radius]
            + color
            + color
        )

    cmd.load_cgo(shape, f"cylinder{k}")

print("Loading geometry file")
cmd.load(geometry_path.as_posix())
print("Changing default settings")
cmd.set("virtual_trackball", 0)
cmd.bg_color("white")
cmd.set("orthoscopic", 1)
cmd.set("stick_radius", 0.1)
cmd.show("licorice", "all")
cmd.hide("everything", "resn SOL")
# %%
109 changes: 108 additions & 1 deletion src/kimmdy/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from seaborn import axes_style
import pandas as pd
from datetime import datetime
import json

from kimmdy.utils import run_shell_cmd
from kimmdy.parsing import read_json, write_json
Expand Down Expand Up @@ -350,6 +351,87 @@ def radical_population(
run_shell_cmd(f"vmd {pdb_output}", cwd=analysis_dir)


def _summarize_migrations(migrations, cutoff):
    """Collapse single migration events into per-atom-pair statistics.

    Parameters
    ----------
    migrations
        Iterable of ``[from_atom, to_atom, rate]`` entries.
    cutoff
        Pairs that occurred fewer than ``cutoff`` times are dropped.

    Returns
    -------
    Dict mapping ``"<from_atom>_<to_atom>"`` to a dict with the keys
    ``"count"`` and ``"max_rate"``.
    """
    summary = {}
    for from_atom, to_atom, rate in migrations:
        # f-string instead of str.join so non-string atom ids work as well
        key = f"{from_atom}_{to_atom}"
        entry = summary.setdefault(key, {"count": 0, "max_rate": 1e-70})
        entry["count"] += 1
        entry["max_rate"] = max(entry["max_rate"], rate)

    # filter by cutoff
    return {key: entry for key, entry in summary.items() if entry["count"] >= cutoff}


def radical_migration(
    dirs: list[str],
    type: str = "qualitative",
    cutoff: int = 1,
):
    """Collect radical migration events of KIMMDY runs and write them to json.

    For every picked recipe found in the run directories, the net change in
    the number of bonds per atom is computed from its ``Break`` and ``Bind``
    steps. If exactly one atom gained a bond and exactly one atom lost a bond,
    this is recorded as a migration from the former to the latter. Events are
    then aggregated per atom pair and written to ``radical_migration.json`` in
    the analysis directory of the first run directory.

    Parameters
    ----------
    dirs
        KIMMDY run directories to be analysed.
    type
        How to analyse radical migration. Available are 'qualitative',
        'occurence' and 'min_rate'. The value is currently only echoed and
        does not influence the written json.
    cutoff
        Ignore migration between two atoms if it happened less often than the
        specified value.
    """
    print(
        "Running radical migration analysis\n"
        f"dirs: \t\t{dirs}\n"
        f"type: \t\t{type}\n"
        f"cutoff: \t{cutoff}\n\n"
        f"Writing analysis files in {dirs[0]}"
    )

    migrations = []
    analysis_dir = get_analysis_dir(Path(dirs[0]))
    for d in dirs:
        run_dir = Path(d).expanduser().resolve()

        # Picked recipe per task, keyed by the task number that prefixes the
        # name of the decide_recipe directory, to process them in run order.
        picked_recipes = {}
        for recipes in run_dir.glob("*decide_recipe/recipes.csv"):
            task_nr = int(recipes.parents[0].stem.split(sep="_")[0])
            _, picked_recipe = RecipeCollection.from_csv(recipes)
            picked_recipes[task_nr] = picked_recipe

        for _, picked_recipe in sorted(picked_recipes.items()):
            # NOTE(review): presumably from_csv yields None when no recipe
            # was picked in that step; confirm against RecipeCollection.
            if picked_recipe is None:
                continue

            # Net change of the number of bonds per atom in this recipe
            connectivity_difference = {}
            for step in picked_recipe.recipe_steps:
                if isinstance(step, Break):
                    change = -1
                elif isinstance(step, Bind):
                    change = 1
                else:
                    continue
                for atom_id in (step.atom_id_1, step.atom_id_2):
                    connectivity_difference[atom_id] = (
                        connectivity_difference.get(atom_id, 0) + change
                    )

            # The atom that gained a bond is where the radical came from, the
            # atom that lost a bond is where it went.
            from_atom = [
                key for key, value in connectivity_difference.items() if value == 1
            ]
            to_atom = [
                key for key, value in connectivity_difference.items() if value == -1
            ]
            # Only unambiguous one-to-one transfers count as a migration
            if len(from_atom) == 1 and len(to_atom) == 1:
                migrations.append([from_atom[0], to_atom[0], max(picked_recipe.rates)])

    # get unique migrations, filtered by cutoff
    unique_migrations = _summarize_migrations(migrations, cutoff)

    # write json
    out_path = analysis_dir / "radical_migration.json"
    with open(out_path, "w") as json_file:
        json.dump(unique_migrations, json_file)
    print("Done!")


def plot_rates(dir: str):
"""Plot rates of all possible reactions for each 'decide_recipe' step.
Expand Down Expand Up @@ -643,7 +725,30 @@ def get_analysis_cmdline_args() -> argparse.Namespace:
help="Open VMD with the concatenated trajectory."
"To view the radical occupancy per atom, add a representation with the beta factor as color.",
)

parser_radical_migration = subparsers.add_parser(
name="radical_migration",
help="Create a json of radical migration events for further analysis.",
)
parser_radical_migration.add_argument(
"dirs",
type=str,
help="One or multiple KIMMDY run directories to be analysed.",
nargs="+",
)
parser_radical_migration.add_argument(
"--type",
"-t",
type=str,
help="How to analyse radical migration. Available are 'qualitative','occurence' and 'min_rate'",
default="qualitative",
)
parser_radical_migration.add_argument(
"--cutoff",
"-c",
type=int,
help="Ignore migration between two atoms if it happened less often than the specified value.",
default=1,
)
parser_rates = subparsers.add_parser(
name="rates",
help="Plot rates of all possible reactions after a MD run. Rates must have been saved!",
Expand Down Expand Up @@ -710,6 +815,8 @@ def entry_point_analysis():
args.open_plot,
args.open_vmd,
)
elif args.module == "radical_migration":
radical_migration(args.dirs, args.type, args.cutoff)
elif args.module == "rates":
plot_rates(args.dir)
elif args.module == "runtime":
Expand Down

0 comments on commit ba4cb0f

Please sign in to comment.