In [None]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join('..')))
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torchmdnet.datasets.custom
import torch.nn as nn
import torch.optim as optim
from torchmdnet.module import LNNP
from torch_geometric.loader import DataLoader
import lightning.pytorch as pl
import yaml
import numpy as np
from torchmdnet.models.model import create_model
import itertools
from module import dataset
import os
import matplotlib.pyplot as plt
import json
import time
from tqdm import tqdm
import datetime
from collections import defaultdict
from typing import Optional, List, Tuple, Dict
from torch import Tensor
import matplotlib.pyplot as plt

In [None]:
class BatchWrapper(nn.Module):
    """Flatten batched per-frame tensors into the flat interface of ``model``.

    The wrapped model is called as ``model(embed, coord, batch_nums)`` with
    1-D ``embed``/``batch_nums`` and 2-D ``coord``; this wrapper performs that
    flattening so it can sit inside ``nn.DataParallel``, which splits inputs
    along dim 0.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, embed, coord, batch_nums) -> Tuple[Tensor, Tensor]:
        """Flatten the batch and run the wrapped model.

        Args:
            embed: Per-atom embedding ids; flattened to 1-D.
            coord: Per-atom coordinates; reshaped to ``(-1, coord.shape[-1])``.
            batch_nums: Molecule ids indexed as ``batch_nums[i][-1]``
                (presumably ``(frames, atoms)`` — TODO confirm). The caller's
                tensor is not modified.

        Returns:
            The ``(energy, computed_force)`` pair produced by the wrapped model.
        """
        embed = embed.flatten()
        coord = coord.reshape(-1, coord.shape[-1])

        # Offsets are applied in place below; clone first so they do not leak
        # into the caller's tensor (e.g. when nn.DataParallel has no devices
        # to scatter to and calls this module directly with the inputs).
        batch_nums = batch_nums.clone()

        # Increment the molecule id numbers in each row after the first by the
        # (already shifted) last id of the previous row + 1, producing a
        # sequential numbering for the entire batch.
        for i in range(1, len(batch_nums)):
            batch_nums[i] += batch_nums[i - 1][-1] + 1
        batch_nums = batch_nums.flatten()

        energy, computed_force = self.model(embed, coord, batch_nums)
        return energy, computed_force

In [None]:
def _residue_index_to_name() -> Dict[int, str]:
    """Return the lookup from 1-based residue index to three-letter name.

    Indices follow the alphabetical order of the names below, starting at 1.
    NOTE(review): assumes the dataset's ``embed`` values use this same
    encoding — confirm against ``dataset.ProteinDataset``.
    """
    standard_residues = {
        "ALA", "ARG", "ASN", "ASP", "ASX", "CYS", "GLU", "GLN", "GLX", "GLY",
        "HIS", "ILE", "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP",
        "TYR", "VAL",
    }
    return {
        index: name for index, name in enumerate(sorted(standard_residues), start=1)
    }


def val_model(
    input_directory,
    model_directory,
    gpu_ids=None,
    batch_size=50,
    max_proteins: Optional[int] = 1,
):
    """Evaluate a trained model and return per-atom force losses.

    Args:
        input_directory: Processed dataset directory; the protein list is read
            from ``<input_directory>/result/ok_list.txt``.
        model_directory: Directory containing ``checkpoint.pth`` with
            ``hyper_parameters`` and ``state_dict`` entries.
        gpu_ids: ``device_ids`` for ``nn.DataParallel`` (``None`` = all
            visible devices).
        batch_size: Number of frames per DataLoader batch.
        max_proteins: Evaluate only the first ``max_proteins`` proteins of the
            list (default 1, the previous hard-coded behaviour); ``None``
            evaluates all of them.

    Returns:
        A list of dicts, one per atom per frame, with keys ``frame_index``
        (index of the frame in loader order), ``atom_index`` (position of the
        atom within its frame), ``residue_id``, ``residue_name`` and ``loss``
        (mean squared error over the force components of that atom).
    """
    # Set random seed and enable deterministic cuDNN for reproducibility.
    torch.manual_seed(123456)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # One protein name per line; skip blank lines (such as the one produced by
    # a trailing newline) and strip stray whitespace / "\r".
    with open(f"{input_directory}/result/ok_list.txt", "r") as file:
        pdb_list = [line.strip() for line in file if line.strip()]
    if max_proteins is not None:
        pdb_list = pdb_list[:max_proteins]

    all_data = dataset.ProteinDataset(input_directory, pdb_list)
    num_proteins = all_data.num_proteins()
    print(f"Number of proteins in the dataset: {num_proteins}")
    # shuffle=False keeps frames in dataset order; frame_index relies on it.
    data_loader = DataLoader(
        all_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        persistent_workers=True,
        pin_memory=True,
    )

    # Load onto the CPU first so loading does not depend on the device the
    # checkpoint was saved from; DataParallel moves the weights below.
    checkpoint = torch.load(f"{model_directory}/checkpoint.pth", map_location="cpu")
    model = create_model(args=checkpoint["hyper_parameters"])
    model.load_state_dict(checkpoint["state_dict"])

    parallel_model = nn.DataParallel(BatchWrapper(model), device_ids=gpu_ids)
    parallel_model.to(parallel_model.src_device_obj)
    parallel_model.eval()
    # Deliberately not under torch.no_grad(): NOTE(review) the model
    # presumably derives forces from the energy via autograd, which needs
    # gradients enabled — confirm before adding no_grad.

    index_to_name = _residue_index_to_name()
    residue_losses = []
    frame_offset = 0  # loader-order index of the first frame in this batch

    for coord, embed, force, batch_nums in tqdm(
        data_loader, desc="Evaluation", total=len(data_loader)
    ):
        # NOTE(review): embed is assumed to be collated as (frames, atoms).
        frames_in_batch = embed.shape[0]
        atoms_per_frame = embed.numel() // frames_in_batch

        force = force.reshape(-1, force.shape[-1]).to(parallel_model.output_device)
        _, out = parallel_model(embed, coord, batch_nums)

        # Per-atom MSE over the force components: identical to
        # nn.MSELoss()(force[i], out[i]) for every atom i, but one vectorized
        # op and a single host copy instead of one .item() sync per atom.
        per_atom_loss = (force - out.detach()).pow(2).mean(dim=-1).cpu().tolist()
        residue_ids = embed.flatten().tolist()

        for atom, (residue_id, loss) in enumerate(zip(residue_ids, per_atom_loss)):
            residue_losses.append(
                {
                    "frame_index": frame_offset + atom // atoms_per_frame,
                    "atom_index": atom % atoms_per_frame,
                    "residue_id": residue_id,
                    "residue_name": index_to_name.get(residue_id, "UNKNOWN"),
                    "loss": loss,
                }
            )
        frame_offset += frames_in_batch

    return residue_losses

In [None]:
# Evaluate the first protein of the single-chain CA-prior dataset with its
# trained model. Command-line equivalent (exported script):
#   plot_residues.py <input_directory> <model_directory> --batch 32
#
# The paths default to the original workstation locations; set
# CGSCHNET_INPUT_DIR / CGSCHNET_MODEL_DIR to run elsewhere without editing.
input_directory = os.environ.get(
    "CGSCHNET_INPUT_DIR",
    "/media/DATA_18_TB_1/daniel_s/cgschnet/cg_single_chain_CAprior_2024.04.03/",
)
model_directory = os.environ.get(
    "CGSCHNET_MODEL_DIR",
    "/media/DATA_18_TB_1/daniel_s/cgschnet/cgschnet_models/cg_single_chain_CAprior_2024.04.03/",
)
# batch_size=1 so that every DataLoader step holds exactly one frame.
residue_losses = val_model(
    input_directory, model_directory, gpu_ids=None, batch_size=1
)

In [None]:
import pandas as pd

# Tabulate the per-atom loss records and print them to stdout (pandas
# truncates the printed view of long frames).
df = pd.DataFrame(data=residue_losses)
print(df)

In [None]:
# Write the losses to CSV with columns in alphabetical order and without
# the pandas row index.
df = df[sorted(df.columns)]
df.to_csv("residue_losses.csv", index=False)

In [None]:
def plot_performance(df, max_frames=50):
    """Bar-plot per-residue losses for the first ``max_frames`` frames.

    Args:
        df: DataFrame with at least ``frame_index``, ``loss`` and
            ``residue_name`` columns.
        max_frames: Only rows with ``frame_index < max_frames`` are plotted
            (default 50, the previously hard-coded limit).

    Bars are grouped by frame index and coloured by residue name; seaborn
    averages the loss of rows sharing a (frame_index, residue_name) pair.
    Prints a message and returns without plotting if no rows qualify.
    """
    # Build a new, explicitly typed frame instead of ``.loc[:, col] = ...``:
    # since pandas 2.0 that assignment writes in place and keeps the existing
    # column dtype, so the int cast was not guaranteed to take effect.
    df_filtered = df.loc[df["frame_index"] < max_frames].astype({"frame_index": int})

    if df_filtered.empty:
        print(f"No rows with frame_index < {max_frames}; nothing to plot.")
        return

    fig, ax = plt.subplots(figsize=(20, 10))
    sns.barplot(
        data=df_filtered,
        x="frame_index",
        y="loss",
        hue="residue_name",
        errorbar=None,
        ax=ax,
    )

    ax.set_xlabel("Frame Index")
    ax.set_ylabel("Loss")
    ax.set_title("Per-Residue Losses")
    # Legend outside the axes: there can be many residue types.
    ax.legend(title="Residue Name", bbox_to_anchor=(1.05, 1), loc="upper left")
    fig.tight_layout()

    plt.show()


# Draw the per-residue loss bar chart.
plot_performance(df=df)

In [None]:
# Command-line entry point for when this notebook is exported as a script:
#   plot_residues.py <input_dir> <model_dir> --batch 32
# Skipped inside a Jupyter kernel: there __name__ is also "__main__", but
# sys.argv holds the kernel launcher's own arguments ("-f <connection file>"),
# which argparse rejects by exiting — that would abort a "Run All".
if __name__ == "__main__" and "ipykernel" not in sys.modules:
    import argparse

    parser = argparse.ArgumentParser(
        description="Compute per-residue force losses of a trained CGSchNet network"
    )
    parser.add_argument("input", help="Processed data directory to evaluate on")
    # Required: val_model always loads <model>/checkpoint.pth, so a missing
    # value previously surfaced as a confusing "None/checkpoint.pth" error.
    parser.add_argument("model", help="Model directory containing checkpoint.pth")
    parser.add_argument(
        "--gpus",
        default=None,
        type=str,
        help='Comma-separated list of GPUs to evaluate on (e.g. "0,1,2")',
    )
    parser.add_argument("--batch", type=int, default=50, help="The batch size to use")
    parser.add_argument(
        "--output",
        default=None,
        help="Optional CSV path to write the per-atom losses to",
    )

    args = parser.parse_args()

    input_path = args.input
    model_path = args.model
    if args.gpus:
        # Skip empty items so inputs like "0,1," do not crash int().
        gpu_ids = [int(gpu) for gpu in args.gpus.split(",") if gpu.strip()]
    else:
        gpu_ids = None
    batch_size = args.batch

    losses = val_model(
        input_path, model_directory=model_path, gpu_ids=gpu_ids, batch_size=batch_size
    )

    # Previously the results were computed and discarded; report them.
    losses_df = pd.DataFrame(losses)
    if args.output:
        losses_df.reindex(sorted(losses_df.columns), axis=1).to_csv(
            args.output, index=False
        )
    if not losses_df.empty:
        print(losses_df.groupby("residue_name")["loss"].mean().sort_values())