# In [None]:
import pandas as pd
import matplotlib.pyplot as plt

#### Written by Yilun Huang, 2025 ####
#### Written for Master Rotation Project ####
#### Classify pockets based on distance to two representative residues ####
#### Create pie chart and bar chart of pocket classification ####


# ========================= Input / output files ========================= #
# Boltz-2 results table; must contain a column whose name includes "model_path".
input_file = "boltz_results_combined.csv"
# Per-model distance matrix produced by Filter_Boltz2.py.
distance_file = "distance_matrices.csv"
# Intermediate output: model path, pocket label and the two residue distances.
output_file = "distance_with_labels.csv"
# Final output: the Boltz-2 table left-merged with the pocket classification.
merged_file = "pocket_classification_merged_combined_results.csv"
# ======================================================================== #

# The two representative residues used for the pocket classification.
# The distance-matrix columns are 0-based indices, so the code looks up
# column str(residue - 1) for each residue number given here.
residues = dict(
    res1=69,  # top of the receptor
    res2=85,  # middle of the receptor
)


# Read the Boltz-2 results and find the column that holds the model paths.
final_models = pd.read_csv(input_file)
final_cols = [c for c in final_models.columns if "model_path" in c.lower()]
if not final_cols:
    # Message translates to: "Final file is missing the model_path column".
    raise ValueError(" Final 文件缺少 model_path 列")
model_col_final = final_cols[0]
# Drop only a *leading* "./" so the paths can be matched against the distance
# matrix. A plain substring replace of "./" would also corrupt paths that
# contain it elsewhere (e.g. "runs/../model" -> "runs/.model").
final_models[model_col_final] = final_models[model_col_final].str.replace(
    r"^\./", "", regex=True
)


# Read the distance matrix. When it was saved together with its index, the
# first column comes back named "Unnamed: ..."; treat that column as the
# model path.
distance_data = pd.read_csv(distance_file)
if distance_data.columns[0].startswith("Unnamed"):
    distance_data.rename(columns={distance_data.columns[0]: "model_path"}, inplace=True)

dist_cols = [c for c in distance_data.columns if "model_path" in c.lower()]
if not dist_cols:
    raise ValueError("Error in input distance matrix")
model_col_dist = dist_cols[0]


# Keep only the distance rows whose model path also appears in the Boltz-2
# results. NOTE(review): the distance-matrix paths are used as-is (no "./"
# stripping) -- confirm Filter_Boltz2.py writes them without that prefix.
filtered = distance_data[distance_data[model_col_dist].isin(final_models[model_col_final])].copy()
if filtered.empty:
    raise ValueError("Didn't find any matching model_path between the two files.")

def classify_pocket(
    row,
    top_residue=None,
    middle_residue=None,
    allosteric_cutoff=8,
    top_cutoff=23,
    middle_cutoff=12,
):
    """Assign a pocket label to one row of the distance matrix.

    Distances are read from the columns named ``str(residue - 1)`` (the
    distance-matrix columns are 0-based residue indices). Cutoffs are in the
    same units as the distance matrix (presumably Angstrom -- confirm
    against Filter_Boltz2.py).

    Args:
        row: Mapping-like row keyed by string column names, e.g. the pandas
            Series passed by ``DataFrame.apply(..., axis=1)``.
        top_residue: Residue number of the "top of receptor" reference.
            Defaults to the module-level ``residues["res1"]``.
        middle_residue: Residue number of the "middle of receptor" reference.
            Defaults to the module-level ``residues["res2"]``.
        allosteric_cutoff: Top-residue distance at or below which the pocket
            is labelled allosteric.
        top_cutoff: Top-residue distance strictly below which a
            non-allosteric pocket is labelled orthosteric.
        middle_cutoff: Middle-residue distance at or below which a
            non-allosteric pocket is labelled orthosteric.

    Returns:
        "allosteric pocket", "orthosteric pocket" or "outside".

    Raises:
        KeyError: If a required residue column is missing from ``row``.
    """
    # Fall back to the script-level configuration only when not given, so the
    # function can also be used (and tested) without the module globals.
    if top_residue is None:
        top_residue = residues["res1"]
    if middle_residue is None:
        middle_residue = residues["res2"]

    dist_top = row[str(top_residue - 1)]
    dist_middle = row[str(middle_residue - 1)]

    if dist_top <= allosteric_cutoff:
        return "allosteric pocket"
    # Past this point dist_top > allosteric_cutoff (or is NaN, in which case
    # every comparison with it is False), so no lower-bound check is needed.
    if dist_top < top_cutoff or dist_middle <= middle_cutoff:
        return "orthosteric pocket"
    return "outside"


# Build the intermediate table: one row per matched model with its pocket
# label and the raw distances to the two representative residues. This table
# also feeds the pocket-count charts further down.
output_df = pd.DataFrame(
    {
        "model_path": filtered[model_col_dist],
        "pocket_label": filtered.apply(classify_pocket, axis=1),
        # Distance-matrix columns are 0-based residue indices.
        **{name: filtered[str(res - 1)] for name, res in residues.items()},
    }
)

# Save the labelled distances, then attach them to every Boltz-2 row
# (rows without a matching model get NaN in the added columns).
output_df.to_csv(output_file, index=False)

merged_df = final_models.merge(
    output_df,
    how="left",
    left_on=model_col_final,
    right_on="model_path",
)
merged_df.to_csv(merged_file, index=False)

# Overlay one affinity histogram per pocket label.
# The left merge above leaves NaN labels for Boltz-2 rows that had no
# distance-matrix match; those are skipped, because NaN never compares equal
# to itself and would only add an empty "nan" entry to the legend.
# NaN affinities are dropped so they cannot break NumPy's automatic
# bin-range detection.
plt.figure(figsize=(8, 6))
for pocket_type in merged_df["pocket_label"].dropna().unique():
    affinities = merged_df.loc[
        merged_df["pocket_label"] == pocket_type, "affinity_pred_value"
    ].dropna()
    plt.hist(affinities, bins=100, alpha=0.5, label=pocket_type)
plt.title("Affinity Prediction by Pocket Type")
plt.xlabel("Affinity Predicted Value")
plt.ylabel("Frequency")
plt.legend()
plt.tight_layout()
plt.savefig("pocket_affinity_pred_distribution.png", dpi=300)
plt.close()


# Count how many matched models fall into each pocket class
# (value_counts() orders the labels by descending count).
counts = output_df["pocket_label"].value_counts()

# One fixed colour per label, so the pie and bar charts agree and a class
# keeps its colour regardless of its rank in `counts`.
label_colors = {
    "allosteric pocket": "#66c2a5",
    "orthosteric pocket": "#fc8d62",
    "outside": "#8da0cb",
}
colors = [label_colors.get(label, "#cccccc") for label in counts.index]

# Pie chart of the pocket classification
plt.figure(figsize=(6, 6))
plt.pie(counts, labels=counts.index, colors=colors, autopct='%1.1f%%', startangle=140)
plt.title("Pocket Classification Distribution")
plt.tight_layout()
plt.savefig("pocket_piechart.png", dpi=300)
plt.close()

# Bar chart of the pocket classification; plt.bar applies `colors` per bar.
plt.figure(figsize=(6, 4))
plt.bar(counts.index.astype(str), counts.to_numpy(), color=colors)
plt.title("Pocket Classification Counts")
plt.xlabel("Pocket Type")
plt.ylabel("Number of Molecules")
plt.xticks(rotation=0)
plt.tight_layout()
plt.savefig("pocket_barchart.png", dpi=300)
plt.close()




