In [1]:
import mdtraj as mda
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

data_path = './data/dump-surface.lammpstrj'

layer_num = 40

max_pos, min_pos = None, None

In [2]:
class H2O :
	def __init__ (self, pO : np.ndarray, pH1 : np.ndarray, pH2 : np.ndarray) :
		self.pO = pO
		self.pH1 = pH1
		self.pH2 = pH2

		def fix_pH(pH) :
			dis = abs(pH - pO)
			for i in range(3) :
				if dis[i] > (max_pos[i] - min_pos[i]) / 2 :
					if pH[i] > pO[i] :
						pH[i] -= (max_pos[i] - min_pos[i])
					else :
						pH[i] += (max_pos[i] - min_pos[i])

		if np.linalg.norm(pO - pH1) > 20 :
			fix_pH(pH1)
		if np.linalg.norm(pO - pH2) > 20 :
			fix_pH(pH2)

def toArr(h2o : list[H2O]) -> tuple[np.ndarray, np.ndarray, np.ndarray] :
	n = len(h2o)
	pO = np.zeros((n, 3))
	pH1 = np.zeros((n, 3))
	pH2 = np.zeros((n, 3))
	h2o = sorted(h2o, key=lambda x: x.pO[2])
	for i in range(n) :
		pO[i, :] = h2o[i].pO
		pH1[i, :] = h2o[i].pH1
		pH2[i, :] = h2o[i].pH2
	return pO, pH1, pH2


# Load Data

In [4]:
data_raw : list[str]

with open(data_path, 'r') as f:
	data_raw = f.readlines()

data_raw = [line.strip() for line in data_raw]

cur_frame : list[H2O] = None
frames : list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []

line_id = 0

process = tqdm(total=len(data_raw), desc="Parsing LAMMPS trajectory")
while line_id < len(data_raw) :
	cur_line = data_raw[line_id].strip()
	if cur_line.startswith("ITEM: TIMESTEP") :
		if cur_frame is not None :
			frames.append(toArr(cur_frame))
		cur_frame = []
		if min_pos is None :
			min_pos = np.zeros(3)
			max_pos = np.zeros(3)
			for x in range(3) :
				a, b = data_raw[line_id + 5 + x].strip().split()
				min_pos[x], max_pos[x] = float(a), float(b)
		line_id += 9
		process.update(9)
		continue
	def get_pos(line : str, type : int) -> np.ndarray :
		parts = line.strip().split()
		# print(parts)
		assert type == int(parts[1]), f"Expected type {type}, got {parts[1]}"
		return np.array([float(parts[2]), float(parts[3]), float(parts[4])])
	pO = get_pos(data_raw[line_id], 1)
	pH1 = get_pos(data_raw[line_id + 1], 2)
	pH2 = get_pos(data_raw[line_id + 2], 2)
	cur_frame.append(H2O(pO, pH1, pH2))
	process.update(3)
	line_id += 3

process.close()
assert max([frame[0].shape[0] for frame in frames]) == min([frame[0].shape[0] for frame in frames]), \
	"Number of H2O molecules must be consistent across frames."

print(f"Total frames parsed: {len(frames)} with {frames[0][0].shape[0]} H2O molecules each.")

print("box bounds:", min_pos, max_pos)

Parsing LAMMPS trajectory: 100%|██████████| 6357177/6357177 [00:14<00:00, 426554.80it/s]

Total frames parsed: 2000 with 1056 H2O molecules each.
box bounds: [-14.14213562 -14.14213562 -40.        ] [14.14213562 14.14213562 40.        ]





# Split Layers