#### CHECK mdCATH H5 FILES

In [None]:
from os.path import join as opj
from glob import glob
import os
import pandas as pd
import h5py as h5
import numpy as np
import matplotlib 
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

In [None]:
# Selection shared by the cells below: the CATH domain, the simulation group
# (presumably the simulation temperature in K — confirm) and the replica.
# sim_name and repl are strings because they are used as HDF5 group keys.
pdbid = '1cqzB02'
sim_name = '379'
repl = '0'

# Open the per-domain mdCATH file read-only. The handle is intentionally left
# open: the following cells keep reading from `f`.
mdcath_dir = "/workspace3/mdcath/"
h5file = opj(mdcath_dir, "mdcath_dataset_" + pdbid + ".h5")
f = h5.File(h5file, 'r')

print('keys:' + str(f.keys()))
print('attrs:{}'.format(f.attrs["layout"]))

In [None]:
# Names of the datasets and attributes stored under the domain's top-level group.
print(
    f'molDatasets --> {f[pdbid].keys()}',
    f'molAttrs --> {f[pdbid].attrs.keys()}',
    sep='\n',
)

In [None]:
# Domain id, then a few size attributes read from the domain group.
print(
    pdbid,
    *(
        f"{attr_name} --> {f[pdbid].attrs[attr_name]}"
        for attr_name in ('numChains', 'numProteinAtoms', 'numResidues')
    ),
    sep='\n',
)

In [None]:
# Per-atom `z` dataset (presumably atomic numbers — confirm): shape and first 10 values.
print("z.shape --> {}".format(f[pdbid]['z'].shape))
print("z --> {}".format(f[pdbid]['z'][:10]))

In [None]:
# Recover the indices of the C-alpha atoms from the `pdbProteinAtoms` PDB text.
# Only ATOM/HETATM records are kept, instead of slicing off a fixed number of
# header/footer lines. Extra REMARK/CRYST1/TER/END lines, or a missing trailing
# newline, therefore cannot shift the atom list or crash the parsing.
# NOTE(review): assumes the record order matches the atom order of the
# per-atom datasets (z, coords, forces) — TODO confirm.
pdbProteinAtoms = [
    line
    for line in f[pdbid]['pdbProteinAtoms'][()].decode('utf-8').splitlines()
    if line.startswith(('ATOM', 'HETATM'))
]
# The atom name occupies the fixed-width columns 13-16 of a PDB ATOM record.
atomtypes = [line[12:16].strip() for line in pdbProteinAtoms]
ca_indices = np.flatnonzero([name == 'CA' for name in atomtypes])
print(f'Number of CA atoms: {len(ca_indices)}')

In [None]:
# Dump the raw protein-atom PDB text stored in the domain group.
print("pdbProteinAtoms\n\n" + f[pdbid]['pdbProteinAtoms'][()].decode('utf-8'))

In [None]:
# Dump the raw `pdb` text stored in the domain group.
print("PDB\n\n" + f[pdbid]['pdb'][()].decode('utf-8'))

In [None]:
# Dump the raw `psf` text stored in the domain group.
print("PSF\n\n" + f[pdbid]['psf'][()].decode('utf-8'))

In [None]:
# Children (e.g. the replica groups) and attributes of the selected simulation group.
print(
    f'{sim_name} --> {f[pdbid][sim_name].keys()}',
    f'{sim_name} --> {f[pdbid][sim_name].attrs.keys()}',
    sep='\n',
)

In [None]:
# For each dataset of the selected replica: its shape, followed by its attributes.
for prop_name, dset in f[pdbid][sim_name][str(repl)].items():
    print(f'prop {prop_name} --> {dset.shape}')
    for attr_name, attr_value in dset.attrs.items():
        print(f'{attr_name} --> {attr_value}')

In [None]:
# Attributes attached to the replica group itself.
for replattr, replattr_value in f[pdbid][sim_name][str(repl)].attrs.items():
    print(f'{replattr} --> {replattr_value}')

In [None]:
skipframes = 2  # NOTE(review): not referenced by any visible cell
group = f[f"{pdbid}/{sim_name}/{repl}"]

# Index of the last stored frame of the replica.
conf_idx = group['coords'].shape[0] - 1
print(f'conf_idx --> {conf_idx}')

# Preallocate one (n_atoms, 3) buffer per quantity; the atom count is taken
# from the length of the per-atom `z` dataset.
z = f[pdbid]["z"][:]
coords, forces = np.zeros((z.shape[0], 3)), np.zeros((z.shape[0], 3))

# Length-1 frame slice read straight into the buffers with read_direct; the
# leading length-1 axis of the selection is presumably absorbed by h5py.
slice_idxs = np.s_[conf_idx:conf_idx + 1]
for dset_name, dest in (('coords', coords), ('forces', forces)):
    group[dset_name].read_direct(dest, slice_idxs)

print(f'coords --> {coords.shape}')
print(f'forces --> {forces.shape}')

In [None]:
# Frame count of every replica of every simulation group. Both the group
# names and the number of replicas (5) are hard-coded, not read from the file.
for sim_key in ["320", "348", "379", "413", "450"]:
    sim_group = f[pdbid][sim_key]
    for replica_idx in range(5):
        print(f"{sim_key} Replica {replica_idx} --> {sim_group[str(replica_idx)].attrs['numFrames']}")

In [None]:
# Load the per-replica analysis series of the selected simulation into memory.
replica = f[pdbid][sim_name][str(repl)]
rmsd, rmsf, gyration_radius = (
    replica[key][:] for key in ('rmsd', 'rmsf', 'gyrationRadius')
)
for series_name, series_values in (
    ("rmsd", rmsd),
    ("rmsf", rmsf),
    ("gyration_radius", gyration_radius),
):
    print(f"{series_name}.shape --> {series_values.shape}")

In [None]:
# Plot RMSD and gyration radius against time, and RMSF against residue index.
# NOTE(review): frame index / 10 assumes frames are saved every 0.1 ns — confirm.
time = np.arange(0, len(rmsd)) / 10  # time in ns
fig, axs = plt.subplots(1, 3, figsize=(18, 5))
axs = axs.flatten()
last_frame = rmsd.shape[0]  # currently plots every frame

# (plot args, title, y label, x label) for each panel, left to right.
panels = (
    ((time[:last_frame], rmsd[:last_frame]), 'RMSD', 'RMSD (nm)', 'Time (ns)'),
    ((rmsf,), 'RMSF', 'RMSF (nm)', 'residue id'),
    ((time, gyration_radius), 'Gyration Radius', 'Gyration Radius (nm)', 'Time (ns)'),
)
for panel_ax, (plot_args, panel_title, panel_ylabel, panel_xlabel) in zip(axs, panels):
    panel_ax.plot(*plot_args)
    panel_ax.set_title(panel_title)
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_xlabel(panel_xlabel)

plt.tight_layout()
plt.show()

In [None]:
# Decode the per-frame, per-residue DSSP codes of the selected replica into
# integer class ids stored as float32. Rows are frames and columns are residues.
encoded_dssp = f[pdbid][sim_name][str(repl)]['dssp']

# One id per DSSP class, in the order of the 8-entry colour/label lists used
# by the DSSP heat map: alpha-helix, beta-bridge, beta-sheet, 3-10 helix,
# pi-helix, turn, bend, coil. G and I get their own ids (3 and 4). If they were
# collapsed onto 0, they would be drawn as alpha-helix and the 3-10/pi-helix
# legend entries could never appear.
floatMap = {"H": 0, "B": 1, "E": 2, "G": 3, "I": 4, "T": 5, "S": 6, " ": 7}

# Read the whole dataset once. Indexing the HDF5 dataset row by row would
# issue one file read per frame.
dssp_bytes = encoded_dssp[()]
dssp_decoded_float = np.zeros(dssp_bytes.shape[:2], dtype=np.float32)
for frame_idx, frame_codes in enumerate(dssp_bytes):
    dssp_decoded_float[frame_idx] = [floatMap[code.decode()] for code in frame_codes]
print(f"dssp_decoded.shape --> {dssp_decoded_float.shape}")

In [None]:
# Heat map of the DSSP class of every residue (rows) over time (columns).
# The values in `dssp_decoded_float` are categorical ids that index color_list
# and labels, so two settings are needed:
#  - The colour normalisation is pinned to the id range. With auto-scaling, a
#    trajectory that lacks some classes (e.g. no coil) would shift every
#    colour away from its legend entry.
#  - Interpolation is disabled so that resampling cannot blend neighbouring
#    ids into colours that belong to no class.
color_list = ["blue", "black", "red", "grey", "purple", "yellow", "seagreen", "white"]
cmap = matplotlib.colors.ListedColormap(color_list)
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
# NOTE(review): assumes 0.1 ns between frames, like the other time axes — confirm.
xtime = np.arange(0, len(dssp_decoded_float)) / 10
x_min, x_max = xtime[0], xtime[-1]
extent = [x_min, x_max, dssp_decoded_float.shape[1], 0]

# Bin edges at half-integers put id k exactly on color_list[k].
cax = ax.imshow(
    dssp_decoded_float.T,
    aspect='auto',
    cmap=cmap,
    extent=extent,
    vmin=-0.5,
    vmax=len(color_list) - 0.5,
    interpolation='nearest',
)
labels = ["$\\alpha$-helix", "$\\beta$-Bridge", "$\\beta$-sheet", "3-10 helix", "$\\pi$-helix", "Turn", "Bend", "Coils"]
handles = [
    mpatches.Patch(facecolor=color, label=label, edgecolor="darkgrey")
    for color, label in zip(color_list, labels)
]
ax.legend(handles=handles, loc='center left', bbox_to_anchor=(1, 0.5))

ax.set_title('DSSP')
ax.set_xlabel('Time (ns)')
ax.set_ylabel('Residue ID')
plt.tight_layout()
plt.show()

In [None]:
# Per-residue fraction of frames spent in a structured DSSP state (helix or
# strand/bridge), plotted against the per-residue RMSF.
# 3-state reduction: 0 = helix (H/G/I), 1 = strand/bridge (E/B), 2 = everything else.
floatMap = {"H": 0, "B": 1, "E": 1, "G": 0, "I": 0, "T": 2, "S": 2, " ": 2}

# Read all frames in one HDF5 access instead of one read per frame.
dssp_bytes = encoded_dssp[()]
dssp_decoded_float = np.zeros(dssp_bytes.shape[:2], dtype=np.float32)
for frame_idx, frame_codes in enumerate(dssp_bytes):
    dssp_decoded_float[frame_idx] = [floatMap[code.decode()] for code in frame_codes]
print(f"dssp_decoded.shape --> {dssp_decoded_float.shape}")

# Average over frames (axis 0), which gives one value per residue.
solid_fraction_time = np.logical_or(dssp_decoded_float == 0, dssp_decoded_float == 1).mean(axis=0)

plt.figure(figsize=(5, 5))
# Colour by residue number; the colorbar makes that colour scale readable.
points = plt.scatter(rmsf, solid_fraction_time, c=np.arange(len(rmsf)), cmap='rainbow')
plt.colorbar(points, label='Residue index')
plt.xlabel('RMSF (nm)')
plt.ylabel('Solid Fraction')
plt.xlim(0, 2)
plt.ylim(-0.1, 1.1)
plt.show()

In [None]:
# Lazy h5py handles to the full coordinate and force trajectories of the
# selected replica; data is only read when they are indexed.
# The handles use `repl`, like the unit lookup and the later box read, so every
# analysis below (including drawing the box around the coordinates) refers to
# the same replica.
coords = f[pdbid][sim_name][str(repl)]["coords"]
forces = f[pdbid][sim_name][str(repl)]["forces"]
print(f'coords --> {coords.shape}, units: {coords.attrs["unit"]}')
print(f'forces --> {forces.shape}, units: {forces.attrs["unit"]}')

In [None]:
def plot_box(box, ax, origin=(0, 0, 0)):
    """Draw a rectangular simulation box as translucent faces on a 3-D axis.

    Only the diagonal of ``box`` (``box[0][0]``, ``box[1][1]``, ``box[2][2]``)
    is used, so the box is treated as orthorhombic. Off-diagonal terms of a
    triclinic cell are ignored.

    Parameters
    ----------
    box : array-like, shape (3, 3)
        Box matrix whose diagonal gives the edge lengths along x, y and z.
    ax : 3-D matplotlib axis
        Axis to draw on, e.g. from ``fig.add_subplot(..., projection='3d')``.
    origin : array-like of 3 floats, optional
        Point the box is centred on. Defaults to the coordinate origin.
        An immutable tuple default avoids sharing a mutable list between calls.
    """
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection

    # Half edge lengths along x, y, z.
    hx, hy, hz = box[0][0] / 2, box[1][1] / 2, box[2][2] / 2

    # Vertex k has x, y, z signs given by bits 0, 1, 2 of k, so vertex 0 is
    # (-,-,-) and vertex 7 is (+,+,+). The box is then moved to centre on `origin`.
    vertices = np.array([
        [-hx, -hy, -hz],
        [hx, -hy, -hz],
        [-hx, hy, -hz],
        [hx, hy, -hz],
        [-hx, -hy, hz],
        [hx, -hy, hz],
        [-hx, hy, hz],
        [hx, hy, hz],
    ]) + np.asarray(origin)

    # Each face is listed as its 4 corner vertices in perimeter order.
    faces = [
        [vertices[0], vertices[1], vertices[3], vertices[2]],  # bottom (z = -hz)
        [vertices[4], vertices[5], vertices[7], vertices[6]],  # top    (z = +hz)
        [vertices[0], vertices[1], vertices[5], vertices[4]],  # front  (y = -hy)
        [vertices[2], vertices[3], vertices[7], vertices[6]],  # back   (y = +hy)
        [vertices[0], vertices[2], vertices[6], vertices[4]],  # left   (x = -hx)
        [vertices[1], vertices[3], vertices[7], vertices[5]],  # right  (x = +hx)
    ]

    # Translucent faces with black edges.
    ax.add_collection3d(Poly3DCollection(faces, linewidths=1, edgecolors='k', alpha=0.1))

    # Mark the 8 corners.
    ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], s=5, color='r')

In [None]:
import matplotlib.pyplot as plt

# Box matrix scaled by 10 (presumably nm -> Angstrom to match the coordinates
# — confirm against the datasets' `unit` attributes).
box = f[pdbid][sim_name][str(repl)]["box"][:] * 10

# First and last frame, and the mean position of their atoms (used as box centre).
init_coords, end_coords = coords[0], coords[-1]
init_coords_barycenter = np.mean(init_coords, axis=0)
end_coords_barycenter = np.mean(end_coords, axis=0)

fig = plt.figure()
ax = fig.add_subplot(121, projection='3d')
ax1 = fig.add_subplot(122, projection='3d')
for panel, panel_title, frame_xyz, frame_center in (
    (ax, 'Initial Coordinates', init_coords, init_coords_barycenter),
    (ax1, 'Final Coordinates', end_coords, end_coords_barycenter),
):
    panel.set_title(panel_title)
    panel.scatter(frame_xyz[:, 0], frame_xyz[:, 1], frame_xyz[:, 2], s=2)
    plot_box(box, panel, frame_center)
plt.show()

In [None]:
# Time series of one Cartesian component of the force on a single atom.
# Frame indices are converted to ns with the same 0.1 ns/frame convention as
# the other time plots, so the 'Time (ns)' label matches the data.
# NOTE(review): confirm the frame saving interval.
atom_idx = 50
force_component = 2  # column order x, y, z (as in the coordinate scatter plots)
force_atom_i = forces[:, atom_idx, force_component]
force_time = np.arange(len(force_atom_i)) / 10  # ns

plt.figure(figsize=(8, 5))
plt.grid()
plt.plot(force_time, force_atom_i)
plt.xlabel('Time (ns)')
plt.ylabel('Force z (kcal/mol/A)')
plt.show()

In [None]:
# Histogram of the per-atom force magnitude |F| in the first frame of `forces`.
frame0_forces = forces[0]
normForces = np.linalg.norm(frame0_forces, axis=1)

plt.figure(figsize=(8, 5))
plt.grid()
plt.hist(normForces, bins=50)
plt.ylabel("Atom Count")
plt.xlabel("Force (kcal/mol/A)")
plt.show()