In [None]:
import numpy as np
import matplotlib.pyplot as plt





def compute_lddt(predictionFilename: str, 
                 referenceFilename: str, 
                 outputFilename: str):
  %shell stage/bin/lddt -f -p /content/foldeval/openstructure/stereo_chemical_props.txt \
                        $predictionFilename \
                        $referenceFilename | tee $outputFilename
  global_lddt = 0
  lddt_scores = []
  reading_data = False
  for line in open(outputFilename, 'r'):
    if reading_data: 
      columns = line.split()
      if len(columns) == 0: continue
      lddt_scores.append(float(columns[5]))
    if line[:17] == 'Global LDDT score': global_lddt = float(line[18:25])
    if line[:5] == 'Chain': reading_data = True
  return global_lddt, np.array(lddt_scores)



def display_lddt_comparison(predictedScores,
                            measuredScores,
                            outputFilename: str,
                            dpi: int = 140):
  """Plot predicted and measured per-residue lDDT on a single chart.

  Both series are drawn against their index on the x-axis ('Residues'), with
  the y-axis fixed to [0, 1]. The figure (8x5 inches at `dpi`) is saved
  to `outputFilename` with a tight bounding box and then shown.
  NOTE(review): the [0, 1] y-range assumes both inputs are on a 0-1 scale —
  confirm that callers do not pass 0-100 pLDDT values.
  """
  fig, ax = plt.subplots(figsize=(8, 5), dpi=dpi)
  # Predicted curve is drawn first, in light grey, as a backdrop for the
  # measured curve, which uses the default colour cycle.
  series = (
      (predictedScores, {'color': '#ccc', 'label': 'Predicted lDDT'}),
      (measuredScores, {'label': 'Measured lDDT'}),
  )
  for values, line_style in series:
    ax.plot(values, **line_style)
  ax.set_xlabel('Residues')
  ax.set_ylim(0, 1)
  ax.legend()
  fig.savefig(outputFilename, bbox_inches='tight')
  plt.show()



def compute_metrics(plddtScores, 
                    predictionFilename: str, 
                    referenceFilename: str, 
                    outputFilename: str = 'metrics.txt'):
  global jobname
  global sequences
  global algorithm
  global dpi
  sequence_length = len(sequences[0])
  with open(f'{jobname}/{outputFilename}', 'wt') as file:
    file.write(f'JOBNAME={jobname}\n')
    file.write(f'SEQUENCE_LENGTH={sequence_length}\n')
    file.write(f'ALGORITHM={algorithm}\n')
  # %shell /content/pymol-open-source-build/bin/pymol -qc $predictionFilename \
  %shell pymol -qc $predictionFilename \
    $referenceFilename -r /concent/foldeval/pymol_gdt.py 
  global_lddt, lddt_scores = compute_lddt(predictionFilename, referenceFilename, 
                                          f'{jobname}/lDDT.txt')
  display_lddt_comparison(plddtScores, lddt_scores, 
                          f'{jobname}/lddt_comparison.png', dpi)
  with open(f'{jobname}/{outputFilename}', 'wt') as file:
    file.write(f'LDDT={global_lddt:.4}\n')
    file.write('LDDT_ATOMS=all\n')
    file.write('LDDT_CUTOFFS=[ 0.5, 1, 2, 4 ]\n')
    file.write('LDDT_BOND_ANGLE_TOLERANCE=12\n')
    file.write('LDDT_BOND_LENGTH_TOLERANCE=12\n')
  