# Genie: De Novo Protein Design Demo

This notebook demonstrates how to use the Genie model for de novo protein backbone generation.

## 1. Environment Setup

First, we import the necessary libraries and check for GPU availability.

In [None]:
import os
import sys
import glob
import shutil
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Add current directory to path
sys.path.append(os.getcwd())

from genie.config import Config
from genie.utils.model_io import load_model

# Select the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

## 2. Prepare Pre-trained Weights

The provided weights are stored in a flat directory, but the sampling script expects a specific directory hierarchy (`runs/<model_name>/version_<X>/checkpoints/`). We will create a `runs` directory, copy the configuration file into it, and symlink the checkpoint into place (falling back to a copy if symlinks are unavailable).

In [None]:
def setup_weights(source_path, model_name, target_root='runs', version=0):
    """
    Restructures weights from source_path to target_root compatible with Genie's loader.

    Creates ``<target_root>/<model_name>/version_<version>/checkpoints/``, copies
    the ``configuration`` file (when present in source_path) into
    ``<target_root>/<model_name>/``, and symlinks one ``.ckpt`` file into the
    ``checkpoints`` directory, falling back to a copy when a symlink cannot be
    created. Existing destination files are left untouched.

    Args:
        source_path: Directory holding the ``configuration`` file and ``*.ckpt`` file(s).
        model_name: Name of the run directory created under target_root.
        target_root: Root directory of the run hierarchy.
        version: Value used in the ``version_<version>`` directory name.

    Raises:
        FileNotFoundError: If source_path contains no ``.ckpt`` file. Nothing is
            created under target_root in that case.
    """
    # Locate the checkpoint before touching the filesystem so that a bad
    # source_path does not leave an empty run hierarchy behind.
    # glob returns matches in arbitrary order, so sort to make the choice
    # reproducible; escape source_path so characters like '[' are taken literally.
    ckpts = sorted(glob.glob(os.path.join(glob.escape(source_path), '*.ckpt')))
    if not ckpts:
        raise FileNotFoundError(f"No .ckpt files found in {source_path}")

    src_ckpt = ckpts[0]
    ckpt_name = os.path.basename(src_ckpt)
    if len(ckpts) > 1:
        print(f"Found {len(ckpts)} checkpoints in {source_path}; using {ckpt_name}")

    # Target paths
    run_dir = os.path.join(target_root, model_name)
    version_dir = os.path.join(run_dir, f'version_{version}')
    ckpt_dir = os.path.join(version_dir, 'checkpoints')
    os.makedirs(ckpt_dir, exist_ok=True)

    # 1. Copy configuration (only when the destination is not already present)
    src_config = os.path.join(source_path, 'configuration')
    dst_config = os.path.join(run_dir, 'configuration')
    if not os.path.exists(dst_config):
        if os.path.exists(src_config):
            shutil.copy(src_config, dst_config)
            print(f"Copied configuration to {dst_config}")
        else:
            # Best effort: keep going, but make the omission visible.
            print(f"Warning: no configuration file found at {src_config}")

    # 2. Link checkpoint
    dst_ckpt = os.path.join(ckpt_dir, ckpt_name)

    # A dangling symlink (e.g. the source directory was moved) makes
    # os.path.exists() return False while the path is still occupied. Remove it
    # so the link can be recreated instead of copying through the stale link.
    if os.path.islink(dst_ckpt) and not os.path.exists(dst_ckpt):
        os.remove(dst_ckpt)

    if not os.path.exists(dst_ckpt):
        # Use symlink if possible, else copy
        try:
            os.symlink(os.path.abspath(src_ckpt), dst_ckpt)
            print(f"Symlinked checkpoint to {dst_ckpt}")
        except OSError:
            shutil.copy(src_ckpt, dst_ckpt)
            print(f"Copied checkpoint to {dst_ckpt}")

# Pre-trained model to load; its files are read from weights/<MODEL_NAME>.
MODEL_NAME = 'scope_l_128'
WEIGHTS_PATH = os.path.join('weights', MODEL_NAME)

# Build the run directory layout (default target root and version).
setup_weights(source_path=WEIGHTS_PATH, model_name=MODEL_NAME)

## 3. Sampling

Now we load the model and generate some protein backbone structures.
We will generate proteins of length 64 as an example.

In [None]:
# Sampling settings
ROOT_DIR = 'runs'
MODEL_VERSION = 0   # version directory prepared in the weights-setup cell
BATCH_SIZE = 1      # number of structures generated in the single batch below
NOISE_SCALE = 1.0   # noise scale handed to p_sample_loop
LENGTH = 64         # number of residues flagged as present in the mask

# Restore the network, move it to the selected device and switch to eval mode.
# NOTE(review): no epoch is passed here; the original note says load_model then
# picks the latest checkpoint -- confirm against genie.utils.model_io.load_model.
model = load_model(ROOT_DIR, MODEL_NAME, MODEL_VERSION)
model = model.to(device)
model.eval()

print(f"Model {MODEL_NAME} loaded successfully.")

# Directory that receives the generated coordinate files
out_dir = os.path.join('outputs', 'demo_samples')
os.makedirs(out_dir, exist_ok=True)

print(f"Sampling {BATCH_SIZE} structure(s) of length {LENGTH}...")

# Residue mask of shape [BATCH_SIZE, max_n_res]: ones for the first LENGTH
# positions, zeros for the remaining padding positions.
max_n_res = model.config.io['max_n_res']
present = torch.ones((BATCH_SIZE, LENGTH))
padding = torch.zeros((BATCH_SIZE, max_n_res - LENGTH))
mask = torch.cat([present, padding], dim=1).to(device)

with torch.no_grad():
    # Keep only the last element returned by p_sample_loop.
    # NOTE(review): per the original comment, the loop returns the intermediate
    # states and the last one is the final sample -- confirm against the model code.
    ts = model.p_sample_loop(mask, NOISE_SCALE, verbose=True)[-1]

# Write one file per generated structure, dropping the padded positions.
# The files are comma-delimited text (np.savetxt) despite the .npy extension,
# so they must be read back with np.loadtxt rather than np.load.
saved_files = []
for i in range(ts.shape[0]):
    translations = ts[i].trans.detach().cpu().numpy()
    filepath = os.path.join(out_dir, f'sample_{LENGTH}_{i}.npy')
    np.savetxt(filepath, translations[:LENGTH], fmt='%.3f', delimiter=',')
    saved_files.append(filepath)
    print(f"Saved: {filepath}")

## 4. Visualization

Visualize the generated C-alpha backbone.

In [None]:
# Simple visualizations using Matplotlib
from mpl_toolkits.mplot3d import Axes3D

def plot_structure(filepath):
    """
    Plots saved backbone coordinates as a 3D trace coloured by position in the file.

    Args:
        filepath: Path to a comma-delimited text file with one ``x,y,z`` row per
            point (the format produced by ``np.savetxt(..., delimiter=',')``).
            Only the first three columns are plotted.

    Raises:
        ValueError: If the file has fewer than three columns.
    """
    # ndmin=2 keeps a single-row file as shape (1, 3); without it loadtxt
    # returns a 1-D array and the column indexing below fails.
    coords = np.loadtxt(filepath, delimiter=',', ndmin=2)
    if coords.shape[1] < 3:
        raise ValueError(
            f"Expected at least 3 comma-separated columns in {filepath}, "
            f"got {coords.shape[1]}"
        )

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_proj_type('persp', focal_length=0.2)

    xs = coords[:, 0]
    ys = coords[:, 1]
    zs = coords[:, 2]

    # Draw backbone
    ax.plot(xs, ys, zs, c='lightblue', linewidth=2, label='Backbone', alpha=0.8)
    ax.scatter(xs, ys, zs, c=np.arange(len(xs)), cmap='viridis', s=50, depthshade=True)

    # Scale the plot box to the data extent so one unit has the same on-screen
    # length along every axis; the default cubic box would stretch the shape.
    # Skipped for degenerate extents, which set_box_aspect cannot represent.
    if len(coords) > 1:
        extent = np.ptp(coords[:, :3], axis=0)
        if np.all(extent > 0):
            ax.set_box_aspect(extent)

    ax.set_xlabel('X (Å)')
    ax.set_ylabel('Y (Å)')
    ax.set_zlabel('Z (Å)')
    ax.set_title(f'Structure: {os.path.basename(filepath)}')

    # View angle
    ax.view_init(elev=20., azim=-35)
    plt.show()

# Plot the first saved sample; nothing is drawn when no files were written.
for first_sample in saved_files[:1]:
    plot_structure(first_sample)

## 5. Evaluation & Analysis

The repository provides tools to evaluate the novelty of generated structures and analyze their properties.

### Novelty Evaluation
You can evaluate how "novel" your generated proteins are by comparing them against a reference database (e.g., PDB) using TM-score.

There are two provided scripts:
1. `Novelty_Evaluation_CPU.py`: Exhaustive search (slower, exact).
2. `Novelty_Evaluation_GPU.py`: Hybrid search using embeddings for fast screening (faster).

**Example Usage (Command generation):**
The following cell generates the commands you would run in a terminal. Note that you need a reference database (e.g., `data/pdbstyle-2.08`) installed for these to work.

In [None]:
# Inputs for the evaluation scripts: the samples written earlier and the
# location of the reference structure database.
eval_input_dir = out_dir
ref_db_dir = os.path.join('data', 'pdbstyle-2.08')

# The commands are only printed, not executed; run them from a terminal.
cpu_command = f"python Novelty_Evaluation_CPU.py --input_dir {eval_input_dir} --ref_dir {ref_db_dir} --num_workers 4"
gpu_command = f"python Novelty_Evaluation_GPU.py --input_dir {eval_input_dir}"

print("To run CPU-based Novelty Evaluation:")
print(cpu_command)

print("\nTo run GPU-based Novelty Evaluation:")
print(gpu_command)

### Analysis Plotting

Once you have evaluation results (e.g., `info.csv` from the evaluation pipeline or `novelty.csv` from the scripts above), you can generate analysis plots.

1. **MDS Plot**: Visualizes the design space.
2. **General Analysis**: Plots pLDDT vs scTM, SSE distribution, etc.

In [None]:
# Build the plotting commands first, then print them; nothing is executed here.
mds_command = f"python plot_genie_mds_novelty.py --input_dir {eval_input_dir} --output_file mds_plot.png"
analysis_command = f"python plot_genie_analysis.py --input_dir {eval_input_dir} --output_file analysis_plot.png"

print("To plot Design Space MDS:")
print(mds_command)

print("\nTo generate General Analysis Plots:")
print(analysis_command)