In [None]:
""" Heatmap of RMSD values of proteins in dataset
NOTE: the TM-align command-line tool (TMalign) must be installed, e.g.:
## apt install tm-align --user
"""
import os
import re
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Build symmetric all-vs-all matrices of RMSD and sequence identity by running
# the external TMalign program on every unordered pair of files in the PDB
# directory and parsing the values out of its text output.

# Compiled once; these match the "RMSD=<value>," and
# "Seq_ID=n_identical/n_aligned=<value>" fields of the TMalign output.
_RMSD_PATTERN = re.compile(r"RMSD=\s*([\d.]+),")
_SEQID_PATTERN = re.compile(r"Seq_ID=n_identical/n_aligned=\s*([\d.]+)")


def _extract_value(pattern, output, label, pdb1, pdb2):
    """Return the number captured by *pattern* in *output*, or NaN if absent.

    A missing value is reported on stdout and returned as NaN so that one
    failed alignment neither aborts the whole run nor silently inherits the
    value parsed for the previous pair.
    """
    match = pattern.search(output)
    if match is None:
        print(f"Error: could not find {label} value for {pdb1} vs. {pdb2} in files {pdb1} and {pdb2}")
        return float("nan")
    return float(match.group(1))


pdbs = os.path.join(str(new_folder), "pdbs")  # directory containing the PDB files
pdb_files = sorted(os.listdir(pdbs))
n_pdbs = len(pdb_files)

# RMSD of a structure against itself is 0; its sequence identity is 1.
matrix_rmsd = [[0.0] * n_pdbs for _ in range(n_pdbs)]
matrix_seqid = [[1.0 if row == col else 0.0 for col in range(n_pdbs)] for row in range(n_pdbs)]

for i, pdb1 in enumerate(pdb_files):
    # Only the upper triangle is computed; each result is mirrored below.
    for j in range(i + 1, n_pdbs):
        pdb2 = pdb_files[j]
        # Run TMalign with an argument list (no shell), so file names that
        # contain spaces or shell metacharacters are passed through intact.
        # Raises FileNotFoundError if the TMalign executable is not on PATH.
        completed = subprocess.run(
            ["TMalign", os.path.join(pdbs, pdb1), os.path.join(pdbs, pdb2), "-a"],
            stdout=subprocess.PIPE,
            universal_newlines=True,  # decode stdout to str
            check=False,  # a failed alignment is reported below as a missing value
        )
        output = completed.stdout

        rmsd = _extract_value(_RMSD_PATTERN, output, "RMSD", pdb1, pdb2)
        matrix_rmsd[i][j] = rmsd
        matrix_rmsd[j][i] = rmsd

        seqid = _extract_value(_SEQID_PATTERN, output, "SeqID", pdb1, pdb2)
        matrix_seqid[i][j] = seqid
        matrix_seqid[j][i] = seqid

## write both pairwise matrices to .npy files in the current working directory
for npy_name, pairwise_values in (
    ("matrix_rmsd.npy", matrix_rmsd),
    ("matrix_seqid.npy", matrix_seqid),
):
    np.save(npy_name, pairwise_values)

# Draw the RMSD and SeqID matrices as two side-by-side heatmaps, save the
# figure as a JPEG inside `new_folder`, and display it.
print("RMSD and SeqID analysis")

# Wrap the matrices in DataFrames so rows and columns are labelled with the
# PDB file names.
heatmap_rmsd = pd.DataFrame(matrix_rmsd, columns=pdb_files, index=pdb_files)
heatmap_seqid = pd.DataFrame(matrix_seqid, columns=pdb_files, index=pdb_files)

# One row, two panels: RMSD on the left, SeqID on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# RMSD heatmap
sns.heatmap(heatmap_rmsd, ax=ax1, annot=False, cmap="YlGnBu", cbar=False)
ax1.set_title('Heatmap of RMSD values')

# SeqID heatmap
sns.heatmap(heatmap_seqid, ax=ax2, annot=False, cmap="YlGnBu", cbar=False)
ax2.set_title('Heatmap of SeqID values')

# Remove the x and y axis labels on both panels.
for axis in (ax1, ax2):
    axis.set_xlabel("")
    axis.set_ylabel("")

# Save the figure before showing it: with non-interactive use, show() may
# leave the figure closed/empty afterwards.
plt.savefig(os.path.join(str(new_folder), "heatmap_rmsd_seqid.jpg"), format='jpg')

# Display the plot (a bare `plt` expression does not render anything).
plt.show()