In [None]:
import pandas as pd
import gemmi
import matplotlib.pyplot as plt

# Absolute path to the CSV of per-residue output labels that this script plots.
# NOTE(review): hard-coded to one dataset on one filesystem — edit it to point
# at a different output_labels CSV.
output_labels_csv_name = '/dls/labxchem/data/2018/lb18145-80/processing/analysis/eugene/pandda_score/score_examples/NSP14-x1069/output_labels_NSP14-x1069-aligned-structure.csv'

# Load the whole table into memory. The code below reads the columns
# 'input_model', 'input_chain_idx', 'residue_name', 'pred_probabilities'
# and 'pred_label' from it.
df = pd.read_csv(output_labels_csv_name)

def read_chain_name(input_model: str,
                    input_chain_idx: int):
    """Look up the name of one chain in a structure file.

    The file is parsed with ``gemmi.read_structure`` and only its first
    model (index 0) is consulted.

    Args:
        input_model: Path of the structure file to read.
        input_chain_idx: Positional index of the chain within the first model.

    Returns:
        The ``name`` attribute of the selected chain.
    """
    first_model = gemmi.read_structure(input_model)[0]
    return first_model[input_chain_idx].name

def list_residues_and_pred_probs_and_labels(output_labels_dataframe: pd.DataFrame) -> tuple:
    """Extract per-residue plot labels, predicted probabilities and predicted labels.

    Each row is turned into a ``'<chain name>-<residue name>'`` label, where
    the chain name is looked up with ``read_chain_name`` from the row's
    'input_model' and 'input_chain_idx' values.

    Args:
        output_labels_dataframe: Table with the columns 'input_model',
            'input_chain_idx', 'residue_name', 'pred_probabilities' and
            'pred_label'.

    Returns:
        A 3-tuple of lists in row order:
        ``(chain_residue_names, pred_probabilities, pred_labels)``.
    """
    input_model_list = output_labels_dataframe['input_model'].tolist()
    input_chain_idx_list = output_labels_dataframe['input_chain_idx'].tolist()

    # read_chain_name parses the structure file on every call, so remember the
    # answer per (model, chain index) pair: rows that repeat a pair then cost a
    # dictionary lookup instead of another file parse.
    chain_name_by_key = {}
    chain_names = []
    for model_path, chain_idx in zip(input_model_list, input_chain_idx_list):
        key = (model_path, chain_idx)
        if key not in chain_name_by_key:
            chain_name_by_key[key] = read_chain_name(model_path, chain_idx)
        chain_names.append(chain_name_by_key[key])

    residue_names = output_labels_dataframe['residue_name'].tolist()
    chain_residue_names = [
        f'{chain}-{residue}' for chain, residue in zip(chain_names, residue_names)
    ]
    pred_probabilities = output_labels_dataframe['pred_probabilities'].tolist()
    pred_labels = output_labels_dataframe['pred_label'].tolist()

    return chain_residue_names, pred_probabilities, pred_labels

def plot_residue_vs_pred_label(residue_names: list, pred_labels: list,
                               residues_per_plot: int = 100):
    """Show bar charts of the predicted label per residue, split over several figures.

    The residues are divided into ``len(residue_names) // residues_per_plot``
    groups (but never fewer than one) and one 20x10 figure is shown per group.

    Args:
        residue_names: Labels for the x axis, one per residue.
        pred_labels: Bar heights, aligned with ``residue_names``.
        residues_per_plot: Approximate number of residues per figure; the
            number of figures is the floor of ``len(residue_names)`` divided
            by this value.

    Raises:
        ValueError: If ``residues_per_plot`` is smaller than 1.
    """
    if residues_per_plot < 1:
        raise ValueError('residues_per_plot must be a positive integer')
    if len(residue_names) == 0:
        # Nothing to draw; do not open an empty figure.
        return

    # Floor division is 0 for inputs shorter than residues_per_plot, which
    # previously meant no figure was produced at all; always keep one figure.
    n_figures = max(1, len(residue_names) // residues_per_plot)

    # NOTE(review): the stride slice puts every n_figures-th residue in the
    # same figure, so a figure does not show a contiguous stretch of residues.
    # Kept as-is so figures stay aligned with the other plots in this file —
    # confirm whether contiguous chunks were intended.
    residue_groups = [residue_names[start::n_figures] for start in range(n_figures)]
    label_groups = [pred_labels[start::n_figures] for start in range(n_figures)]

    for residue_group, label_group in zip(residue_groups, label_groups):
        plt.figure(figsize=(20, 10))
        plt.bar(residue_group, label_group)
        plt.xticks(rotation=90)
        # Trim the default margins so the bars span the full axis width.
        plt.xlim(-0.5, len(residue_group) - 0.5)
        plt.xlabel('residue', fontsize=20)
        plt.ylabel('needs remodelling?', fontsize=20)
        plt.show()

def plot_residue_vs_pred_prob(residue_names: list, pred_probabilities: list, threshold: float = 0.17,
                              residues_per_plot: int = 100):
    """Show bar charts of the output probability per residue, split over several figures.

    The residues are divided into ``len(residue_names) // residues_per_plot``
    groups (but never fewer than one) and one 20x10 figure is shown per group.
    Each figure also carries a horizontal red line at ``threshold``.

    Args:
        residue_names: Labels for the x axis, one per residue.
        pred_probabilities: Bar heights, aligned with ``residue_names``.
        threshold: y position of the horizontal line labelled 'threshold'.
        residues_per_plot: Approximate number of residues per figure; the
            number of figures is the floor of ``len(residue_names)`` divided
            by this value.

    Raises:
        ValueError: If ``residues_per_plot`` is smaller than 1.
    """
    if residues_per_plot < 1:
        raise ValueError('residues_per_plot must be a positive integer')
    if len(residue_names) == 0:
        # Nothing to draw; do not open an empty figure.
        return

    # Floor division is 0 for inputs shorter than residues_per_plot, which
    # previously meant no figure was produced at all; always keep one figure.
    n_figures = max(1, len(residue_names) // residues_per_plot)

    # NOTE(review): the stride slice puts every n_figures-th residue in the
    # same figure, so a figure does not show a contiguous stretch of residues.
    # Kept as-is so figures stay aligned with the other plots in this file —
    # confirm whether contiguous chunks were intended.
    residue_groups = [residue_names[start::n_figures] for start in range(n_figures)]
    probability_groups = [pred_probabilities[start::n_figures] for start in range(n_figures)]

    for residue_group, probability_group in zip(residue_groups, probability_groups):
        plt.figure(figsize=(20, 10))
        plt.bar(residue_group, probability_group)
        plt.axhline(y=threshold, color='r', label='threshold')
        plt.legend()
        plt.xticks(rotation=90)
        # Trim the default margins so the bars span the full axis width.
        plt.xlim(-0.5, len(residue_group) - 0.5)
        plt.xlabel('residue', fontsize=20)
        plt.ylabel('output probability', fontsize=20)
        plt.show()


# Build the per-residue x-axis labels and pull out the predicted
# probabilities and labels from the loaded table.
residue_names, pred_probabilities, pred_labels = list_residues_and_pred_probs_and_labels(df)
# Show the predicted-label bar charts, then the probability bar charts
# (the latter with the function's default threshold line).
plot_residue_vs_pred_label(residue_names, pred_labels)
plot_residue_vs_pred_prob(residue_names, pred_probabilities)

# Dump the generated '<chain>-<residue>' labels for inspection.
print(residue_names)

