# PPFT example notebook
This notebook contains a toy example of training with the PPFT (property prediction fine-tuning) loss. In this example, the pretrained BioEmu model is finetuned to modify the proportion of folded states that it samples for a single protein.

Fine-tuning a model this narrowly (on a single protein and a single target property) may well have detrimental effects on its overall performance. For that reason, when we fine-tuned with PPFT in [our own work](https://doi.org/10.1101/2024.12.05.626885), we interspersed PPFT weight updates with standard denoising score-matching updates using structures of a large variety of proteins.

In [None]:
from pathlib import Path

# Root directory for everything this notebook writes (samples, plots, checkpoints).
OUTPUT_DIR = Path("~/ppft_example_output").expanduser()
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Location of the local bioemu checkout; adjust if the repository lives elsewhere.
repo_dir = Path("~/bioemu").expanduser()

# Reference structure downloaded from https://zenodo.org/records/7992926
reference_pdb = repo_dir / "notebooks" / "HHH_rd1_0335.pdb"
# Raise explicitly instead of using `assert`, so the check still runs under `python -O`.
if not reference_pdb.exists():
    raise FileNotFoundError(f"Reference structure not found at {reference_pdb}")


In [None]:
# Compute reference contacts using the reference structure.
from bioemu.training.foldedness import TargetInfo, compute_contacts
import numpy as np
import mdtraj


# Load the reference structure and reduce it to its C-alpha atoms before
# computing the reference contacts from it.
reference_traj = mdtraj.load(reference_pdb)
ca_atom_indices = reference_traj.topology.select("name CA")
reference_traj = reference_traj.atom_slice(ca_atom_indices)
# `reference_info` is reused below, both for its `.sequence` and as the
# reference against which sampled structures are compared.
reference_info = compute_contacts(traj=reference_traj)



In [None]:
from bioemu.sample import main as sample_main
from bioemu.training.foldedness import compute_fnc_for_list, foldedness_from_fnc
import torch
from bioemu.chemgraph import ChemGraph
from matplotlib import pyplot as plt

# See TargetInfo for documentation of these parameters.
# Both values are shared between evaluation and training: they are passed to
# `foldedness_from_fnc` when plotting samples and to `TargetInfo` when
# fine-tuning, so the two use the same FNC -> foldedness mapping.
STEEPNESS = 10.0
P_FOLD_THR = 0.5


def sample_and_plot(output_dir: Path, ckpt_path: Path | None = None, model_config_path: Path | None = None) -> None:
    """Sample structures for the reference sequence and plot their FNC and foldedness.

    Generates 300 unfiltered samples with `sample_main`, computes the fraction of
    native contacts (FNC) of every sample relative to `reference_info`, converts
    FNC to foldedness using `P_FOLD_THR` and `STEEPNESS`, and saves side-by-side
    histograms of both quantities to `output_dir / "histograms.png"`.

    Args:
        output_dir: Directory passed to `sample_main` as its output directory.
            `samples.xtc` and `topology.pdb` are read back from it, and the
            histogram image is written to it.
        ckpt_path: Checkpoint forwarded to `sample_main`; None is forwarded
            as-is (the notebook uses this to sample with the default checkpoint).
        model_config_path: Model config forwarded to `sample_main`; None is
            forwarded as-is.
    """
    # Generate samples.
    sample_main(
        sequence=reference_info.sequence,
        num_samples=300,
        output_dir=output_dir,
        filter_samples=False,
        ckpt_path=ckpt_path,
        model_config_path=model_config_path,
    )

    # Compute fraction of native contacts (FNC) and foldedness of the samples.
    # Only C-alpha atoms are kept, matching how the reference structure is processed.
    traj = mdtraj.load(output_dir / "samples.xtc", top=output_dir / "topology.pdb")
    traj = traj.atom_slice(traj.topology.select("name CA"))
    chemgraph_list = [
        ChemGraph(pos=torch.tensor(frame_xyz), sequence=reference_info.sequence) for frame_xyz in traj.xyz
    ]
    fnc = compute_fnc_for_list(batch=chemgraph_list, reference_info=reference_info)
    foldedness = foldedness_from_fnc(fnc=fnc, p_fold_thr=P_FOLD_THR, steepness=STEEPNESS)

    # Plot the results on a dedicated figure. Using explicit Figure/Axes objects
    # (rather than the implicit pyplot "current figure") guarantees that repeated
    # calls never draw on top of a previous call's histograms.
    fig, (fnc_ax, foldedness_ax) = plt.subplots(1, 2, figsize=(10, 4))

    fnc_ax.hist(fnc.numpy(), bins=50, density=True)
    fnc_ax.set_xlabel("FNC")
    fnc_ax.set_ylabel("Density")
    fnc_ax.set_title("Fraction of native contacts")
    fnc_ax.text(
        0.5, 0.9, f"Mean FNC: {fnc.mean().item():.3f}", ha="center", va="center", transform=fnc_ax.transAxes
    )

    foldedness_ax.hist(foldedness.numpy(), bins=50, density=True)
    foldedness_ax.set_xlabel("Foldedness")
    foldedness_ax.set_ylabel("Density")
    foldedness_ax.set_title("Foldedness")
    foldedness_ax.text(
        0.5,
        0.9,
        f"Mean foldedness: {foldedness.mean().item():.2f}",
        ha="center",
        va="center",
        transform=foldedness_ax.transAxes,
    )

    # Keep the axis labels of the two panels from overlapping.
    fig.tight_layout()
    fig.savefig(output_dir / "histograms.png")

In [None]:
# Baseline run: sample and plot using the default checkpoint (no checkpoint or
# model config is supplied), writing the results under OUTPUT_DIR.
pretrained_samples_dir = OUTPUT_DIR / "pretrained_samples"
sample_and_plot(
    output_dir=pretrained_samples_dir,
    ckpt_path=None,
    model_config_path=None,
)

In [None]:
from bioemu.sample import maybe_download_checkpoint, load_model, load_sdes
from bioemu.training.loss import calc_ppft_loss
from bioemu.get_embeds import get_colabfold_embeds
import shutil

def finetune_with_target(
    p_fold_target: float,
    seed: int = 42,
    n_steps_train: int = 200,
    checkpoint_every: int = 50,
) -> tuple[Path, Path]:
    """Finetune the pretrained BioEmu model to sample structures with mean foldedness p_fold_target.

    Loads the "bioemu-v1.0" checkpoint, builds a single ChemGraph for the
    reference sequence from its ColabFold embeddings, and runs Adam updates on
    the PPFT loss. All outputs are written to
    `OUTPUT_DIR / f"target_{p_fold_target:.2f}_seed_{seed}"`.

    Training can be stopped early with Ctrl-C (KeyboardInterrupt); the model
    state after the last completed optimizer step is then saved and returned.

    Args:
        p_fold_target: Target mean foldedness, passed to `TargetInfo`.
        seed: Seed for torch's global random number generators. It is also part
            of the output directory name.
        n_steps_train: Number of optimizer steps to run.
        checkpoint_every: Save an intermediate checkpoint every this many steps.

    Returns:
        A tuple `(checkpoint_path, model_config_path)`: the checkpoint holding
        the model state dict after the last completed step, and the copy of the
        model config stored next to it.

    Raises:
        ValueError: If `n_steps_train` is negative or `checkpoint_every` is not positive.
        RuntimeError: If the loss becomes NaN.
    """
    if n_steps_train < 0:
        raise ValueError(f"n_steps_train must be non-negative, got {n_steps_train}")
    if checkpoint_every < 1:
        raise ValueError(f"checkpoint_every must be positive, got {checkpoint_every}")

    # Actually apply the seed (previously it was only used in the directory name).
    # NOTE(review): this covers torch's RNGs only; confirm that the loss code
    # does not draw from other random sources if exact reproducibility is needed.
    torch.manual_seed(seed)

    output_dir = OUTPUT_DIR / f"target_{p_fold_target:.2f}_seed_{seed}"
    output_dir.mkdir(parents=True, exist_ok=True)

    target_info = TargetInfo(p_fold_thr=P_FOLD_THR, steepness=STEEPNESS, p_fold_target=p_fold_target)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Keep a copy of the model config next to the checkpoints so the output
    # directory is self-contained; the copy is what gets loaded and returned.
    ckpt_path, model_config_path = maybe_download_checkpoint(model_name="bioemu-v1.0")
    shutil.copy(model_config_path, output_dir / "config.yaml")
    model_config_path = output_dir / "config.yaml"
    sdes = load_sdes(model_config_path=model_config_path)
    score_model = load_model(ckpt_path=ckpt_path, model_config_path=model_config_path).to(device)
    score_model.train()
    n = len(reference_info.sequence)
    system_id = reference_pdb.stem

    single_embeds_file, pair_embeds_file = get_colabfold_embeds(
        seq=reference_info.sequence, cache_embeds_dir=None
    )
    single_embeds = np.load(single_embeds_file)
    pair_embeds = np.load(pair_embeds_file)
    # Sanity-check that the embeddings match the sequence length.
    assert pair_embeds.shape[0] == pair_embeds.shape[1] == n
    assert single_embeds.shape[0] == n
    assert len(single_embeds.shape) == 2
    _, _, n_pair_feats = pair_embeds.shape  # [seq_len, seq_len, n_pair_feats]

    single_embeds, pair_embeds = torch.from_numpy(single_embeds), torch.from_numpy(pair_embeds)
    # Flatten the pair features to one row per (i, j) residue pair, in the same
    # row-major order as the columns of `edge_index` below.
    pair_embeds = pair_embeds.view(n**2, n_pair_feats)

    # All n**2 ordered residue pairs: row 0 is 0,0,...,1,1,... and row 1 is 0,1,...,n-1,0,1,...
    edge_index = torch.cat(
        [
            torch.arange(n).repeat_interleave(n).view(1, n**2),
            torch.arange(n).repeat(n).view(1, n**2),
        ],
        dim=0,
    )
    # Positions and orientations are NaN placeholders; presumably they are
    # generated during the rollout inside calc_ppft_loss -- TODO confirm.
    pos = torch.full((n, 3), float("nan"))
    node_orientations = torch.full((n, 3, 3), float("nan"))

    chemgraph = ChemGraph(
        edge_index=edge_index,
        pos=pos,
        node_orientations=node_orientations,
        single_embeds=single_embeds,
        pair_embeds=pair_embeds,
        system_id=system_id,
        sequence=reference_info.sequence,
    ).to(device)

    optimizer = torch.optim.Adam(
        score_model.parameters(),
        lr=1e-5,
        eps=1e-6,
    )

    rolling_mean_loss = None
    # Number of optimizer steps that have actually been applied. Tracked
    # separately from the loop variable so that an interrupt (even before the
    # first iteration) still yields a correctly named final checkpoint.
    completed_steps = 0

    try:
        for step in range(1, n_steps_train + 1):
            print(f"Iteration {step}/{n_steps_train}")
            optimizer.zero_grad()
            # See calc_ppft_loss for the meaning of the rollout hyperparameters.
            loss = calc_ppft_loss(
                score_model=score_model,
                sdes=sdes,
                batch=[chemgraph] * 10,
                n_replications=2,
                mid_t=0.786,
                target_info_lookup={system_id: target_info},
                N_rollout=7,
                record_grad_steps={3, 4, 5},
                reference_info_lookup={system_id: reference_info},
            )
            # Fail before backpropagating so NaN gradients never reach the weights.
            if torch.isnan(loss).any():
                raise RuntimeError("Loss contains NaN values")
            loss.backward()
            loss_value = loss.item()
            # Exponential moving average of the loss, for smoother progress output.
            rolling_mean_loss = (
                loss_value if rolling_mean_loss is None else 0.9 * rolling_mean_loss + 0.1 * loss_value
            )

            print(f"Rolling mean loss: {rolling_mean_loss:.4f}")
            optimizer.step()
            completed_steps = step
            if step % checkpoint_every == 0:
                checkpoint_path = output_dir / f"step_{step}.ckpt"
                torch.save(score_model.state_dict(), checkpoint_path)
                print(f"Checkpoint saved to {checkpoint_path}")
    except KeyboardInterrupt:
        print("Training interrupted. Saving the model...")

    # Always save the final state, named after the steps that were really applied.
    checkpoint_path = output_dir / f"step_{completed_steps}.ckpt"
    torch.save(score_model.state_dict(), checkpoint_path)
    print(f"Model saved to {checkpoint_path}, config at {model_config_path}")
    return checkpoint_path, model_config_path

# Fine-tune towards a mean foldedness of 0.5, then sample from the resulting
# checkpoint. Samples and plots go next to the checkpoint, in a directory named
# like the checkpoint file but with a ".samples" suffix.
checkpoint_path, model_config_path = finetune_with_target(p_fold_target=0.5, seed=42)
finetuned_samples_dir = checkpoint_path.with_suffix(".samples")
sample_and_plot(
    output_dir=finetuned_samples_dir,
    ckpt_path=checkpoint_path,
    model_config_path=model_config_path,
)
