Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: analysis radical migration tool #388

Merged
merged 9 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions example/charged_peptide_homolysis_hat_naive/kimmdy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: 'kimmdy_001'
dryrun: false
max_tasks: 100
gromacs_alias: 'gmx'
gmx_mdrun_flags: -maxh 24 -dlb yes -nt 8
gmx_mdrun_flags: -maxh 24 -dlb yes
top: 'IMREE.top'
gro: 'IMREE_npt.gro'
ndx: 'index.ndx'
Expand Down Expand Up @@ -37,7 +37,7 @@ sequence:
- pull
- homolysis
-
mult: 2
mult: 5
tasks:
- equilibrium
- pull
Expand Down
2 changes: 1 addition & 1 deletion example/charged_peptide_homolysis_hat_naive/plumed.dat
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ d8: DISTANCE ATOMS=82,84
d9: DISTANCE ATOMS=84,95

#Print distances ARG to FILE every STRIDE steps
PRINT ARG=d0,d1,d2,d3,d4,d5,d6,d7,d8,d9, STRIDE=100 FILE=distances.dat
PRINT ARG=d0,d1,d2,d3,d4,d5,d6,d7,d8,d9, STRIDE=500 FILE=distances.dat
109 changes: 108 additions & 1 deletion src/kimmdy/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from seaborn import axes_style
import pandas as pd
from datetime import datetime
import json

from kimmdy.utils import run_shell_cmd
from kimmdy.parsing import read_json, write_json
Expand Down Expand Up @@ -350,6 +351,87 @@ def radical_population(
run_shell_cmd(f"vmd {pdb_output}", cwd=analysis_dir)


def radical_migration(
    dirs: list[str],
    type: str = "qualitative",
    cutoff: int = 1,
) -> None:
    """Collect radical migration events of one or more KIMMDY runs as JSON.

    For every ``*decide_recipe/recipes.csv`` found in each run directory the
    picked recipe is reduced to a per-atom change in connectivity (``Break``
    lowers it by one, ``Bind`` raises it by one). A recipe in which exactly one
    atom gains a bond and exactly one atom loses a bond is counted as a radical
    migration from the gaining atom to the losing atom. Migrations are then
    aggregated per atom pair and written to ``radical_migration.json`` in the
    analysis directory of the first entry of ``dirs``.

    Parameters
    ----------
    dirs
        KIMMDY run directories to be analysed.
    type
        How to analyse radical migration. Available are 'qualitative',
        'occurence' and 'min_rate'. In this function the value is only echoed
        in the run summary; it does not change the written output.
    cutoff
        Ignore migration between two atoms if it happened less often than the
        specified value.

    """
    print(
        "Running radical migration analysis\n"
        f"dirs: \t\t{dirs}\n"
        f"type: \t\t{type}\n"
        f"cutoff: \t{cutoff}\n\n"
        f"Writing analysis files in {dirs[0]}"
    )

    # Each entry is (from_atom, to_atom, highest rate of the picked recipe).
    migrations = []
    analysis_dir = get_analysis_dir(Path(dirs[0]))
    for d in dirs:
        run_dir = Path(d).expanduser().resolve()

        # Map task number -> picked recipe so events can be replayed in order.
        picked_recipes = {}
        for recipes in run_dir.glob("*decide_recipe/recipes.csv"):
            # The task directory name starts with "<task_nr>_".
            task_nr = int(recipes.parents[0].stem.split(sep="_")[0])
            _, picked_recipe = RecipeCollection.from_csv(recipes)
            picked_recipes[task_nr] = picked_recipe

        for task_nr in sorted(picked_recipes):
            picked_recipe = picked_recipes[task_nr]
            if picked_recipe is None:
                # NOTE(review): presumably from_csv yields None when no recipe
                # was picked in that task - confirm; nothing to analyse then.
                continue

            # Net number of bonds gained (+) or lost (-) per atom.
            connectivity_difference = {}
            for step in picked_recipe.recipe_steps:
                if isinstance(step, Break):
                    delta = -1
                elif isinstance(step, Bind):
                    delta = 1
                else:
                    continue
                for atom_id in (step.atom_id_1, step.atom_id_2):
                    connectivity_difference[atom_id] = (
                        connectivity_difference.get(atom_id, 0) + delta
                    )

            from_atom = [
                key for key, value in connectivity_difference.items() if value == 1
            ]
            to_atom = [
                key for key, value in connectivity_difference.items() if value == -1
            ]
            # Only an unambiguous one-to-one transfer counts as a migration.
            if len(from_atom) == 1 and len(to_atom) == 1:
                migrations.append(
                    (from_atom[0], to_atom[0], max(picked_recipe.rates))
                )

    # get unique migrations
    unique_migrations = {}
    for from_atom, to_atom, rate in migrations:
        # f-string instead of str.join so non-string atom ids cannot raise.
        key = f"{from_atom}_{to_atom}"
        entry = unique_migrations.setdefault(key, {"count": 0, "max_rate": 1e-70})
        entry["count"] += 1
        if rate > entry["max_rate"]:
            entry["max_rate"] = rate

    # filter by cutoff: drop atom pairs seen fewer than `cutoff` times
    unique_migrations = {
        key: entry
        for key, entry in unique_migrations.items()
        if entry["count"] >= cutoff
    }

    # write json
    out_path = analysis_dir / "radical_migration.json"
    with open(out_path, "w") as json_file:
        json.dump(unique_migrations, json_file)
    print("Done!")


def plot_rates(dir: str):
"""Plot rates of all possible reactions for each 'decide_recipe' step.

Expand Down Expand Up @@ -643,7 +725,30 @@ def get_analysis_cmdline_args() -> argparse.Namespace:
help="Open VMD with the concatenated trajectory."
"To view the radical occupancy per atom, add a representation with the beta factor as color.",
)

parser_radical_migration = subparsers.add_parser(
name="radical_migration",
help="Create a json of radical migration events for further analysis.",
)
parser_radical_migration.add_argument(
"dirs",
type=str,
help="One or multiple KIMMDY run directories to be analysed.",
nargs="+",
)
parser_radical_migration.add_argument(
"--type",
"-t",
type=str,
help="How to analyse radical migration. Available are 'qualitative','occurence' and 'min_rate'",
default="qualitative",
)
parser_radical_migration.add_argument(
"--cutoff",
"-c",
type=int,
help="Ignore migration between two atoms if it happened less often than the specified value.",
default=1,
)
parser_rates = subparsers.add_parser(
name="rates",
help="Plot rates of all possible reactions after a MD run. Rates must have been saved!",
Expand Down Expand Up @@ -710,6 +815,8 @@ def entry_point_analysis():
args.open_plot,
args.open_vmd,
)
elif args.module == "radical_migration":
radical_migration(args.dirs, args.type, args.cutoff)
elif args.module == "rates":
plot_rates(args.dir)
elif args.module == "runtime":
Expand Down
113 changes: 113 additions & 0 deletions src/kimmdy/assets/scripts/analyse_kimmdy_migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import MDAnalysis as mda
from pymol import cmd
from pymol import cgo
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import json


# %%
def get_inbetween(p1, p2, val):
    """Return the point at fraction ``val`` of the way from ``p1`` to ``p2``.

    ``val = 0`` gives ``p1``; larger values move linearly along the direction
    ``p2 - p1``. Works for scalars as well as for array-like positions that
    support subtraction, scaling and addition.
    """
    direction = p2 - p1
    return p1 + val * direction


# %%
# --- user settings ---------------------------------------------------------
# JSON with one entry per "<from>_<to>" atom pair holding "count" and "max_rate".
interactions_path = Path("radical_migration_dopa_merged.json")
geometry_path = Path("dopa.gro")
# Draw each migration as an arrow (cylinder shaft plus cone head). When False,
# no migration objects are built.
do_arrow = True
# Colouring scheme: "qualitative" (one uniform colour; the earlier spelling
# "quantitative" is still accepted), "max_rate" or "count".
analysis_type = "count"
# Upper bound used to normalise counts. A falsy value (0 or None) means
# "use the largest observed count".
manual_max_count = 40


def _unit_scale(values, lower, upper):
    """Linearly map ``values`` from [lower, upper] onto [0, 1], clipping outliers."""
    values = np.asarray(values, dtype=float)
    span = upper - lower
    if span <= 0:
        # Degenerate range (e.g. a single migration or all values equal):
        # avoid 0/0 -> NaN colours and show everything at full intensity.
        return np.ones_like(values)
    return np.clip((values - lower) / span, 0.0, 1.0)


# %%
u = mda.Universe(geometry_path.as_posix(), guess_bonds=True)

with open(interactions_path, "r") as json_file:
    interactions = json.load(json_file)

# %%
# Build one RGB triple per interaction, in the iteration order of `interactions`.
if not interactions:
    # Nothing to colour; the drawing loop below is then a no-op.
    rgb = np.empty((0, 3))
elif analysis_type in ("qualitative", "quantitative"):
    rgb = np.tile(np.asarray([0.62, 0.59, 0.98]), (len(interactions), 1))
elif analysis_type == "max_rate":
    max_rates = np.array([entry["max_rate"] for entry in interactions.values()])
    log_max_rates = np.log(max_rates)

    # Normalize the log values to the range [0, 1]
    norm_log_max_rates = _unit_scale(
        log_max_rates, np.min(log_max_rates), np.max(log_max_rates)
    )

    # Use the purple sequential colormap
    cmap = plt.get_cmap("Purples")

    # Get RGB values for the normalized log max_rates
    rgb = np.asarray([cmap(value)[:3] for value in norm_log_max_rates])
elif analysis_type == "count":
    print("Selected analysis by count of reaction occurence.")

    counts = np.array([entry["count"] for entry in interactions.values()])

    # Normalize the counts to the range [0, 1]; the upper bound is the manual
    # override when given, otherwise the largest observed count.
    min_count = np.min(counts)
    max_count = manual_max_count if manual_max_count else np.max(counts)

    norm_counts = _unit_scale(counts, min_count, max_count)
    print(f"min count: {min_count}, max count: {max_count}")
    print(norm_counts)

    # Use the blue sequential colormap
    cmap = plt.get_cmap("Blues")

    # Get RGB values for the normalized counts
    rgb = np.asarray([cmap(value)[:3] for value in norm_counts])
else:
    raise ValueError(
        f"Unknown analysis_type {analysis_type!r}; "
        "expected 'qualitative', 'max_rate' or 'count'."
    )

# %%
radius = 0.12
for counter, k in enumerate(interactions):
    print(f"Building object for atoms {k}")
    # Keys have the form "<from_atom_id>_<to_atom_id>".
    atom_1_nr, atom_2_nr = k.split(sep="_")

    a1 = u.select_atoms(f"id {atom_1_nr}")[0].position
    a2 = u.select_atoms(f"id {atom_2_nr}")[0].position

    if do_arrow:
        color = rgb[counter].tolist()
        # Shaft runs from 18 % to 75 % of the a1->a2 distance, the cone head
        # from 75 % to 90 %, so the arrow does not touch either atom.
        p1_far = get_inbetween(a1, a2, 0.18)
        p_cap = get_inbetween(a1, a2, 0.75)
        p2 = get_inbetween(a1, a2, 0.9)

        cylinder = (
            [cgo.CYLINDER]
            + p1_far.tolist()
            + p_cap.tolist()
            + [radius]
            + color
            + color
        )
        head = (
            [cgo.CONE]
            + p_cap.tolist()
            + p2.tolist()
            + [radius * 2, 0.0]
            + color
            + color
            + [1.0, 0.0]
        )
        cmd.load_cgo(cylinder + head, f"cylinder{k}")
print("Loading geometry file")
cmd.load(geometry_path.as_posix())
print("Changing default settings")
cmd.set("virtual_trackball", 0)
cmd.bg_color("white")
cmd.set("orthoscopic", 1)
cmd.set("stick_radius", 0.1)
cmd.show("licorice", "all")
cmd.hide("everything", "resn SOL")
# %%