
# GB1 FITNESS PREDICTION WITH PROTEIN LANGUAGE MODEL EMBEDDINGS

Dataset: GB1 domain from Wu et al. (2016) eLife
Reference: https://elifesciences.org/articles/16965
Data from FLIP benchmark: https://github.com/J-SNACKKB/FLIP

This script demonstrates:
1. Loading multiple train/test splits from FLIP benchmark
2. Extracting embeddings using ESM-2 protein language model
3. Regression using two methods (Random Forest, Ridge)
4. Comparing performance across different splits
5. In silico screening to identify high-fitness variants

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import torch
from transformers import AutoTokenizer, EsmModel
import warnings
import gc
import copy
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7a3725dbf070>

In [2]:
if 'google.colab' in str(get_ipython()):
    print("Running on Google Colab. Executing Colab-specific commands...")
    # Mount Google Drive to access files
    from google.colab import drive
    drive.mount('/content/drive')

    # Drive location for the data files (the GB1 split CSVs live under GB1_data/splits/)
    data_loc = '/content/drive/MyDrive/AIDrivenDesignOfBiologics/AIDrivenDesignOfBiologics-PEGSEurope-2025/ProteinLanguageModels/pLM_basics_and_benchmarking/'

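else:
    print("Not running on Google Colab. Skipping Colab-specific commands.")
    print("Running in a local environment or Jupyter Notebook.")
    # data_loc must be defined before load_gb1_splits() is declared below.
    # Assumed default: the current working directory; edit this to the folder
    # that contains GB1_data/.
    data_loc = '.'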

Running on Google Colab. Executing Colab-specific commands...
Mounted at /content/drive


### STEP 1: Load Data from Multiple Splits

In [3]:
def load_gb1_splits(base_dir=f"{data_loc.rstrip('/')}/GB1_data/splits/"):
    """
    Load GB1 data from different split strategies.

    Available splits:
    - one_vs_rest: Single mutants in train, rest in test
    - two_vs_rest: Single + double mutants in train, rest in test
    - three_vs_rest: Single + double + triple mutants in train, rest in test
    - sampled: Random 80/20 split
    """

    splits = ['one_vs_rest', 'two_vs_rest', 'three_vs_rest', 'sampled']
    all_data = {}

    print("STEP 1: Loading GB1 Data from FLIP Benchmark")
    print("-" * 80)

    for split in splits:
        try:
            # Load the split CSV from the local or Drive path
            split_path = f"{base_dir}{split}.csv"
            df = pd.read_csv(split_path)
            df['split_name'] = split
            all_data[split] = df
            print(f"✓ Loaded {split}: {len(df)} sequences")
            print(f"  Train: {(df['set']=='train').sum()}, Test: {(df['set']=='test').sum()}")
        except Exception as e:
            print(f"✗ Could not load {split} from URL: {e}")
            print(f"  Please download manually from: {split_path}")

    # if not all_data:
    #     print("\n⚠ No data loaded from URL. Creating demo dataset...")
    #     all_data = create_demo_gb1_data()

    print()
    return all_data

all_splits = load_gb1_splits()

STEP 1: Loading GB1 Data from FLIP Benchmark
--------------------------------------------------------------------------------
✗ Could not load one_vs_rest from URL: [Errno 2] No such file or directory: '/content/drive/MyDrive/AIDrivenDesignOfBiologics/AIDrivenDesignOfBiologics-PEGSEurope-2025/ProteinLanguageModels/pLM_basics_and_benchmarking//GB1_data/splits/one_vs_rest.csv'
  Please download manually from: /content/drive/MyDrive/AIDrivenDesignOfBiologics/AIDrivenDesignOfBiologics-PEGSEurope-2025/ProteinLanguageModels/pLM_basics_and_benchmarking//GB1_data/splits/one_vs_rest.csv
✗ Could not load two_vs_rest from URL: [Errno 2] No such file or directory: '/content/drive/MyDrive/AIDrivenDesignOfBiologics/AIDrivenDesignOfBiologics-PEGSEurope-2025/ProteinLanguageModels/pLM_basics_and_benchmarking//GB1_data/splits/two_vs_rest.csv'
  Please download manually from: /content/drive/MyDrive/AIDrivenDesignOfBiologics/AIDrivenDesignOfBiologics-PEGSEurope-2025/ProteinLanguageModels/pLM_basics_and_be

In [4]:
all_df = pd.DataFrame()

# Merge the splits row by row into one table indexed by sequence, with one
# "<split> split set" column per split. This is slow on large tables but simple.
for split in all_splits:
    df = all_splits[split]
    #print(df.shape)
    for i, row in df.iterrows():
        if not row['sequence'] in all_df.index.values:
            all_df.at[row['sequence'],'sequence'] = row['sequence']
        all_df.at[row['sequence'],'target'] = row['target']
        all_df.at[row['sequence'],f"{split} split set"] = row['set']

all_df
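The row-by-row merge above can also be written with vectorized pandas operations. The following is a minimal sketch, not the notebook's method: `combine_splits` and the toy data are hypothetical names introduced here, and it assumes each split table has `sequence`, `target`, and `set` columns as in the FLIP CSVs. Unlike the loop, it returns rows sorted by sequence, and it raises an error if no splits were loaded.

```python
import pandas as pd

def combine_splits(all_splits):
    """One row per sequence: target plus one '<split> split set' column per split."""
    # Stack all splits into a long table, tagging each row with its split name
    long = pd.concat(
        [df.assign(split_name=name) for name, df in all_splits.items()],
        ignore_index=True,
    ).drop_duplicates(subset=["sequence", "split_name"], keep="last")

    # Pivot the train/test labels so each split becomes its own column
    sets = long.pivot(index="sequence", columns="split_name", values="set")
    sets.columns = [f"{name} split set" for name in sets.columns]

    # Keep the last target seen for each sequence, as the loop does
    target = long.groupby("sequence")["target"].last()
    wide = sets.assign(target=target)
    wide.insert(0, "sequence", wide.index)
    wide.index.name = None
    return wide

# Toy data (hypothetical) to show the resulting layout
toy_splits = {
    "one_vs_rest": pd.DataFrame(
        {"sequence": ["AA", "AC"], "target": [1.0, 0.5], "set": ["train", "test"]}
    ),
    "sampled": pd.DataFrame(
        {"sequence": ["AA", "AD"], "target": [1.0, 0.2], "set": ["test", "train"]}
    ),
}
toy_all_df = combine_splits(toy_splits)
print(toy_all_df)
```

Sequences that are absent from a split get `NaN` in that split's column, so a comparison such as `== 'train'` evaluates to `False` for them.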

In [5]:
sns.histplot(all_df, x='target')

ValueError: Could not interpret value `target` for `x`. An entry with this name does not appear in `data`.

In [6]:
# count the number of variants with activity at or below each threshold
print(f"activity fraction\tvariant count\tlibrary %")
for activity in [0.5,0.25,0.05, 0.02, 0.01, 0.0]:
    count = all_df.loc[all_df['target']<=activity].shape[0]
    print(f"{activity:.2f} \t\t\t{count} \t\t{count/all_df.shape[0]:.3f}")

activity fraction	variant count	library %


KeyError: 'target'

### STEP 2: Extract ESM-2 Embeddings

In [None]:
def extract_esm_embeddings(sequences, model_name="facebook/esm2_t33_650M_UR50D"):
    """
    Extract embeddings using ESM-2 model.

    ESM-2 is a widely used ESM release that performs well on fitness prediction.
    The 650M-parameter checkpoint is used here to balance performance and speed.
    """

    print("\nSTEP 2: Extracting ESM-2 Embeddings")
    print("-" * 80)
    print(f"Model: {model_name}")
    print(f"Sequences to embed: {len(sequences)}")

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmModel.from_pretrained(model_name)
    model.eval()

    # Move to GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    print(f"Using device: {device}")

    embeddings = []

    # Process sequences in batches
    batch_size = 8
    with torch.no_grad():
        for i in range(0, len(sequences), batch_size):
            batch = sequences[i:i+batch_size]

            # Tokenize
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Get embeddings
            outputs = model(**inputs)

            # Use mean pooling over sequence length
            attention_mask = inputs['attention_mask']
            token_embeddings = outputs.last_hidden_state

            # Mean pooling
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            batch_embeddings = sum_embeddings / sum_mask

            embeddings.append(batch_embeddings.cpu().numpy())

            # Clear intermediate tensors
            del inputs, outputs, token_embeddings, batch_embeddings

            if (i + batch_size) % 40 == 0:
                print(f"  Processed {min(i+batch_size, len(sequences))}/{len(sequences)} sequences")

    embeddings = np.vstack(embeddings)
    print(f"✓ Embedding shape: {embeddings.shape}")

    del model, tokenizer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print()

    return embeddings

embeddings = extract_esm_embeddings(all_df['sequence'].tolist())
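The mean-pooling step inside `extract_esm_embeddings` can be checked on a toy example. The sketch below re-implements the same masked average in NumPy; `masked_mean_pool` and the toy arrays are hypothetical names introduced for illustration. It shows that padded positions do not contribute to the pooled vector. In the notebook's function the attention mask should also cover the special tokens the tokenizer adds around each sequence, so those tokens are presumably included in the average.

```python
import numpy as np

def masked_mean_pool(token_embeddings, attention_mask):
    """Average token embeddings over the length axis, ignoring masked positions.

    token_embeddings: array of shape (batch, length, hidden)
    attention_mask:   array of shape (batch, length), 1 = real token, 0 = padding
    """
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, length, 1)
    summed = (token_embeddings * mask).sum(axis=1)                   # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                   # avoid division by zero
    return summed / counts

# Two toy "sequences" with hidden size 2; the first has one padded position
tokens = np.array([
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last position is padding
    [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
])
mask = np.array([[1, 1, 0],
                 [1, 1, 1]])

pooled = masked_mean_pool(tokens, mask)
print(pooled)
```

The padded `[100.0, 100.0]` row is excluded, so the first pooled vector is the mean of the first two tokens only.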


### STEP 3: Train Regression Models

In [None]:
def train_and_evaluate(X_train, y_train, X_test, y_test, split_name):
    """
    Train two regression models and evaluate performance.

    Models:
    1. Random Forest: Non-linear, captures complex patterns
    2. Ridge Regression: Linear, fast and interpretable
    """

    # Standardize features
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    results = {}

    # Model 1: Random Forest
    print("Training Random Forest Regression Model")
    rf_model = RandomForestRegressor(
        n_estimators=100,
        max_depth=10,
        random_state=42,
        n_jobs=-1
    )
    rf_model.fit(X_train_scaled, y_train)
    rf_pred = rf_model.predict(X_test_scaled)

    results['Random Forest'] = {
        'model': rf_model,
        'scaler': scaler,
        'predictions': rf_pred,
        'mse': mean_squared_error(y_test, rf_pred),
        'mae': mean_absolute_error(y_test, rf_pred),
        'r2': r2_score(y_test, rf_pred),
        'spearman': pd.Series(y_test).corr(pd.Series(rf_pred), method='spearman')
    }

    # Model 2: Ridge Regression
    print("Training Ridge Regression Model")
    ridge_model = Ridge(alpha=1.0, random_state=42)
    ridge_model.fit(X_train_scaled, y_train)
    ridge_pred = ridge_model.predict(X_test_scaled)

    results['Ridge'] = {
        'model': ridge_model,
        'scaler': scaler,
        'predictions': ridge_pred,
        'mse': mean_squared_error(y_test, ridge_pred),
        'mae': mean_absolute_error(y_test, ridge_pred),
        'r2': r2_score(y_test, ridge_pred),
        'spearman': pd.Series(y_test).corr(pd.Series(ridge_pred), method='spearman')
    }

    return results
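`train_and_evaluate` reports both R² and Spearman correlation, which can disagree sharply. The sketch below uses hypothetical helper names and toy numbers, and its rank function assumes there are no ties. It computes Spearman as the Pearson correlation of ranks and shows a prediction that orders variants perfectly yet has a negative R² because its scale is wrong.

```python
import numpy as np

def ranks(values):
    """Rank values 0..n-1. Assumes no ties (ties would need average ranks)."""
    order = np.argsort(values)
    r = np.empty(len(values), dtype=float)
    r[order] = np.arange(len(values))
    return r

def spearman(y_true, y_pred):
    """Spearman correlation as the Pearson correlation of the ranks."""
    return np.corrcoef(ranks(y_true), ranks(y_pred))[0, 1]

def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

y_true = np.array([0.0, 0.1, 0.5, 1.0, 2.0])
y_pred = y_true ** 3 + 5.0  # monotonic in y_true, but badly calibrated

rho = spearman(y_true, y_pred)
score = r2(y_true, y_pred)
print(f"Spearman: {rho:.3f}, R2: {score:.3f}")
```

For ranking candidates to test next, Spearman is often the more relevant number; R² matters when the predicted fitness values themselves are used.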


### STEP 4: Iterate Over All Splits

In [None]:
# Store results
all_results = []
split_predictions = {}

for split in ['one_vs_rest', 'two_vs_rest', 'three_vs_rest', 'sampled']:

    print(f"\nProcessing split: {split}")
    print("  " + "-" * 76)

    # Split data
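    # Convert the masks to plain boolean arrays so they index the embeddings by row position
    train_mask = (all_df[f'{split} split set'] == 'train').to_numpy()
    test_mask  = (all_df[f'{split} split set'] == 'test').to_numpy()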

    X_train = embeddings[train_mask]
    y_train = all_df.loc[train_mask, 'target'].values
    X_test = embeddings[test_mask]
    y_test = all_df.loc[test_mask, 'target'].values

    print(f"  Train size: {len(X_train)}, Test size: {len(X_test)}")

    # Train and evaluate
    results = train_and_evaluate(X_train, y_train, X_test, y_test, split)

    # Store results
    for model_name, metrics in results.items():
        try:
            result_dict = {
                'Split': split,
                'Model': model_name,
                'R²': metrics['r2'],
                'Spearman': metrics['spearman'],
                'MAE': metrics['mae'],
                'RMSE': np.sqrt(metrics['mse'])
            }
        except TypeError:
            result_dict = {
                'Split': split,
                'Model': model_name,
                'R²': metrics['r2'],
                'Spearman': metrics['spearman'],
                'MAE': metrics['mae'],
                'RMSE': 0.0
            }

        # Add classification metrics for two-stage model
        if model_name == 'Two-Stage':
            result_dict['Class_Acc'] = metrics['classification_acc']
            result_dict['Class_AUC'] = metrics['classification_auc']

        all_results.append(result_dict)

    # Store for plotting
    split_predictions[split] = {
        'y_test': y_test,
        'rf_pred': results['Random Forest']['predictions'],
        'ridge_pred': results['Ridge']['predictions']
    }

    print(f"  ✓ Random Forest - R²: {results['Random Forest']['r2']:.3f}, Spearman: {results['Random Forest']['spearman']:.3f}")
    print(f"  ✓ Ridge         - R²: {results['Ridge']['r2']:.3f}, Spearman: {results['Ridge']['spearman']:.3f}")

    gc.collect()

In [None]:
results_df = pd.DataFrame(all_results)
print("\nPerformance Metrics:")
print(results_df.to_string(index=False))

# Save to CSV
# results_df.to_csv('gb1_results_summary.csv', index=False)

### STEP 5: Visualizations

In [None]:
# Plot 1: Performance comparison across splits
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Spearman correlation
for model in ['Random Forest', 'Ridge']: #, 'XGBoost', 'Two-Stage'
    model_data = results_df[results_df['Model'] == model]
    axes[0].plot(model_data['Split'], model_data['Spearman'],
                 marker='o', linewidth=2, markersize=8, label=model)

axes[0].set_xlabel('Split Strategy', fontsize=12)
axes[0].set_ylabel('Spearman Correlation', fontsize=12)
axes[0].set_title('Model Performance Across Splits', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[0].tick_params(axis='x', rotation=45)

# R² score
for model in ['Random Forest', 'Ridge']: # , 'XGBoost', 'Two-Stage'
    model_data = results_df[results_df['Model'] == model]
    axes[1].plot(model_data['Split'], model_data['R²'],
                 marker='s', linewidth=2, markersize=8, label=model)

axes[1].set_xlabel('Split Strategy', fontsize=12)
axes[1].set_ylabel('R² Score', fontsize=12)
axes[1].set_title('R² Performance Across Splits', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.savefig('gb1_split_comparison.png', dpi=300, bbox_inches='tight')
print("✓ Saved: gb1_split_comparison.png")

# Plot 2: Predicted vs Actual for each split
n_splits = len(split_predictions)
fig, axes = plt.subplots(3, n_splits, figsize=(5*n_splits, 10))
if n_splits == 1:
    axes = axes.reshape(-1, 1)

for idx, (split_name, preds) in enumerate(split_predictions.items()):
    y_test = preds['y_test']

    # Random Forest
    axes[0, idx].scatter(y_test, preds['rf_pred'], alpha=0.6, s=50, edgecolors='black')
    axes[0, idx].plot([y_test.min(), y_test.max()],
                      [y_test.min(), y_test.max()],
                      'r--', lw=2, label='Perfect')
    axes[0, idx].set_xlabel('Actual Fitness', fontsize=11)
    axes[0, idx].set_ylabel('Predicted Fitness', fontsize=11)
    axes[0, idx].set_title(f'{split_name}\nRandom Forest', fontsize=12, fontweight='bold')
    axes[0, idx].grid(True, alpha=0.3)
    axes[0, idx].legend()

    # Ridge
    axes[1, idx].scatter(y_test, preds['ridge_pred'], alpha=0.6, s=50,
                         edgecolors='black', color='orange')
    axes[1, idx].plot([y_test.min(), y_test.max()],
                      [y_test.min(), y_test.max()],
                      'r--', lw=2, label='Perfect')
    axes[1, idx].set_xlabel('Actual Fitness', fontsize=11)
    axes[1, idx].set_ylabel('Predicted Fitness', fontsize=11)
    axes[1, idx].set_title(f'{split_name}\nRidge Regression', fontsize=12, fontweight='bold')
    axes[1, idx].grid(True, alpha=0.3)
    axes[1, idx].legend()

    # Comparison of Ridge and Random Forest
    axes[2, idx].scatter(preds['rf_pred'], preds['ridge_pred'], alpha=0.6, s=50,
                         edgecolors='black', color='green')
    axes[2, idx].plot([y_test.min(), y_test.max()],
                      [y_test.min(), y_test.max()],
                      'r--', lw=2, label='y = x')
    axes[2, idx].set_xlabel('Predicted Fitness (Random Forest)', fontsize=11)
    axes[2, idx].set_ylabel('Predicted Fitness (Ridge)', fontsize=11)
    axes[2, idx].set_title(f'{split_name}\nRegression Comparison', fontsize=12, fontweight='bold')
    axes[2, idx].grid(True, alpha=0.3)
    axes[2, idx].legend()

plt.tight_layout()
plt.savefig('gb1_predictions_by_split.png', dpi=300, bbox_inches='tight')
print("✓ Saved: gb1_predictions_by_split.png")

plt.show()


Examine these results and consider the following.

- How does the amount or character of the training data affect the quality of the predicted fitness?
   - Which splitting method shows highest performance?
   - Is there a potential for data leakage in some of the splitting methods?
   - Which method for generating training data is most practical from a wet-lab perspective?
- What aspects of the data might hinder model training?
    - How are non-functional variants differentiated in a regression? Zero is zero but some mutations will be more disruptive.
    - Could you improve the quality of training by removing non-functional variants? Or using a two-stage method that separates 'functional' classification from 'how functional' regression?
- Would an ensemble method be useful for picking variants for the next round of screening?
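One way to explore the two-stage idea raised above is sketched below. This is an illustration on synthetic data, not part of the notebook's pipeline: `fit_two_stage`, `predict_two_stage`, and the `threshold` value of 0.01 are assumptions introduced here. A classifier estimates whether a variant is functional, a regressor trained only on functional variants estimates how functional, and the two are combined as an expected value.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import Ridge

def fit_two_stage(X, y, threshold=0.01):
    """Fit a functional/non-functional classifier and a regressor on functional variants.

    threshold is a hypothetical cutoff; choose it from the fitness distribution.
    Assumes both classes are present in y.
    """
    is_functional = y > threshold
    clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, is_functional)
    reg = Ridge(alpha=1.0).fit(X[is_functional], y[is_functional])
    return clf, reg

def predict_two_stage(clf, reg, X):
    """Expected fitness = P(functional) * predicted fitness if functional."""
    p_functional = clf.predict_proba(X)[:, 1]  # column 1 corresponds to class True
    return p_functional * reg.predict(X)

# Synthetic stand-in for embeddings: about half the "variants" have zero fitness
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 5))
y_demo = np.maximum(X_demo[:, 0], 0.0)

clf_demo, reg_demo = fit_two_stage(X_demo, y_demo)
two_stage_pred = predict_two_stage(clf_demo, reg_demo, X_demo)
print(two_stage_pred[:5])
```

Multiplying the two outputs is only one way to combine the stages; a hard cutoff on the classifier probability is another, and which works better on GB1 is an open question for the exercise.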

### STEP 6: In Silico Screening - Find High-Fitness Variants

In [None]:
def generate_mutant_library(wt_sequence, positions_to_mutate, amino_acids='ACDEFGHIKLMNPQRSTVWY',
                            double_mutant=False, triple_mutant=False, random_mutant=False,
                            max_mutants=1000):
    """
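    Generate a library of mutants at specified positions: wild type and all single
    mutants, plus optional exhaustive double/triple mutants and random multi-site mutants.

    Parameters:
    -----------
    wt_sequence : str
        Wild-type protein sequence
    positions_to_mutate : list of int
        Positions (0-indexed) to introduce mutations
    amino_acids : str
        Amino acids to consider for mutations
    double_mutant, triple_mutant : bool
        Also enumerate all double / triple mutants at the given positions
    random_mutant : bool
        Also sample random multi-site mutants until the library holds max_mutants sequences
    max_mutants : int
        Target library size when random_mutant is True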

    Returns:
    --------
    variants : list of dict
        List of variant information with sequences and mutation descriptions
    """

    variant_seqs = set() # Use a set to ensure no duplicates
    variants = []

    # Add wild-type
    variant_seqs.add(wt_sequence)
    variants.append({
        'sequence': wt_sequence,
        'mutation': 'WT',
        'n_mutations': 0
    })

    # Generate single mutants
    for pos in positions_to_mutate:
        wt_aa = wt_sequence[pos]
        for aa in amino_acids:
            if aa != wt_aa:
                mutant_seq = wt_sequence[:pos] + aa + wt_sequence[pos+1:]
                if mutant_seq in variant_seqs: continue
                variant_seqs.add(mutant_seq)
                variants.append({
                    'sequence': mutant_seq,
                    'mutation': f'{wt_aa}{pos+1}{aa}',
                    'n_mutations': 1
                })

    if double_mutant:
        # Generate double mutants (optional - can be computationally expensive)
        for i, pos1 in enumerate(positions_to_mutate):
            wt_aa1 = wt_sequence[pos1]
            for pos2 in positions_to_mutate[i+1:]:
                wt_aa2 = wt_sequence[pos2]
                for aa1 in amino_acids:
                    if aa1 != wt_aa1:
                        for aa2 in amino_acids:
                            if aa2 != wt_aa2:
                                mutant_seq = wt_sequence[:pos1] + aa1 + wt_sequence[pos1+1:]
                                mutant_seq = mutant_seq[:pos2] + aa2 + mutant_seq[pos2+1:]
                                if mutant_seq in variant_seqs: continue
                                variant_seqs.add(mutant_seq)
                                variants.append({
                                    'sequence': mutant_seq,
                                    'mutation': f'{wt_aa1}{pos1+1}{aa1}/{wt_aa2}{pos2+1}{aa2}',
                                    'n_mutations': 2
                                })

    if triple_mutant:
        # Generate triple mutants (optional - can be computationally expensive)
        for i, pos1 in enumerate(positions_to_mutate):
            wt_aa1 = wt_sequence[pos1]
            for j, pos2 in enumerate(positions_to_mutate[i+1:], start=i+1):
                wt_aa2 = wt_sequence[pos2]
                # pos3 must come after pos2 so the three mutated sites are distinct
                for pos3 in positions_to_mutate[j+1:]:
                    wt_aa3 = wt_sequence[pos3]
                    for aa1 in amino_acids:
                        if aa1 != wt_aa1:
                            for aa2 in amino_acids:
                                if aa2 != wt_aa2:
                                    for aa3 in amino_acids:
                                        if aa3 != wt_aa3:
                                            mutant_seq = wt_sequence[:pos1] + aa1 + wt_sequence[pos1+1:]
                                            mutant_seq = mutant_seq[:pos2] + aa2 + mutant_seq[pos2+1:]
                                            mutant_seq = mutant_seq[:pos3] + aa3 + mutant_seq[pos3+1:]
                                            if mutant_seq in variant_seqs: continue
                                            variant_seqs.add(mutant_seq)
                                            variants.append({
                                                'sequence': mutant_seq,
                                                'mutation': f'{wt_aa1}{pos1+1}{aa1}/{wt_aa2}{pos2+1}{aa2}/{wt_aa3}{pos3+1}{aa3}',
                                                'n_mutations': 3
                                            })

    if random_mutant:
        # Generate random multi-site mutants until the library holds max_mutants sequences.
        # Cap the target at the number of distinct sequences reachable at these positions
        # so the loop cannot run forever (assumes the WT residues are in amino_acids).
        target_size = min(max_mutants, len(set(amino_acids)) ** len(positions_to_mutate))
        while len(variant_seqs) < target_size:
            mutant_seq = wt_sequence  # strings are immutable, so no copy is needed

            # Pick the positions to mutate; each position is mutated with probability 0.5
            mutate_positions = [np.random.rand() < 0.5 for pos in positions_to_mutate]  # increase the value to raise the proportion of multi-mutants

            mutant_str = []  # track the "name" of each mutated position

            # create the mutated sequence
            for mutate, pos in zip(mutate_positions, positions_to_mutate):
                if mutate:
                    wt_aa = wt_sequence[pos]
                    # pick a random amino acid other than the wild-type residue
                    mut_aa = np.random.choice([aa for aa in amino_acids if aa != wt_aa])
                    mutant_seq = mutant_seq[:pos] + mut_aa + mutant_seq[pos+1:]
                    mutant_str.append(f"{wt_aa}{pos+1}{mut_aa}")
            # skip duplicates (including draws where no position was mutated)
            if mutant_seq in variant_seqs: continue
            variant_seqs.add(mutant_seq)
            variants.append({
                'sequence': mutant_seq,
                'mutation': "/".join(mutant_str),
                'n_mutations': sum(mutate_positions)
            })

    return variants

# wt_sequence = "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTELEVLFQGPLDPNSMATYEVLCEVARKLGTDDREVVLFLLNVFIPQPTLAQLIGALRALKEEGRLTFPLLAECLFRAGRRDLLRDLLHLDPRFLERHLAGTMSYFSPYQLTVLHVDGELCARDIRSLIFLSKDTIGSRSTPQTFLHWVYCMENLDLLGPTDVDALMSMLRSLSRVDLQRQVQTLMGLHLSGPSHSQHYRHTPLEHHHHHH"
# positions_to_mutate = [38, 39, 40, 53]
# variants = generate_mutant_library(wt_sequence, positions_to_mutate,
#                                    double_mutant=False, triple_mutant=False, random_mutant=True)
# print(variants)

In [None]:
def screen_variants_for_high_fitness(model, scaler, known_sequences, wt_sequence,
                                     positions_to_mutate, top_n=20, batch_size=40,
                                     double_mutant=False, triple_mutant=False, random_mutant=False,
                                     max_mutants=1000):
    """
    In silico screening to identify high-fitness variants not in training set.

    This function demonstrates how to use the trained model for variant discovery:
    1. Generate a library of variants
    2. Filter out known sequences
    3. Extract embeddings for all variants
    4. Predict fitness using trained model
    5. Return top predicted variants for experimental validation

    Parameters:
    -----------
    model : trained model
        The regression model to use for predictions
    scaler : StandardScaler or None
        Feature scaler (if applicable)
    known_sequences : set
        Set of sequences already characterized (to exclude)
    wt_sequence : str
        Wild-type sequence
    positions_to_mutate : list of int
        Positions to mutate (0-indexed)
    top_n : int
        Number of top candidates to return
    batch_size : int
        Batch size for embedding extraction

    Returns:
    --------
    candidates_df : DataFrame
        Top predicted variants with predicted fitness scores
    """

    print("\n" + "="*80)
    print("IN SILICO SCREENING FOR HIGH-FITNESS VARIANTS")
    print("="*80)
    print()

    # Step 1: Generate variant library
    print("Generating variant library...")
    variants = generate_mutant_library(wt_sequence, positions_to_mutate,
                                       double_mutant=double_mutant, triple_mutant=triple_mutant,
                                       random_mutant=random_mutant, max_mutants=max_mutants)
    print(f"✓ Generated {len(variants)} variants")

    # Step 2: Filter out known sequences
    novel_variants = [v for v in variants if v['sequence'] not in known_sequences]
    print(f"✓ {len(novel_variants)} novel variants (not in training/test set)")

    if len(novel_variants) == 0:
        print("⚠ All variants are already characterized!")
        return pd.DataFrame()

    # Step 3: Extract embeddings
    print("\nExtracting embeddings for variant library...")
    sequences = [v['sequence'] for v in novel_variants]
    embeddings = extract_esm_embeddings(sequences)

    # Step 4: Make predictions
    print("\nPredicting fitness for novel variants...")
    if scaler is not None:
        embeddings_scaled = scaler.transform(embeddings)
    else:
        embeddings_scaled = embeddings

    predictions = model.predict(embeddings_scaled)

    # Step 5: Create results DataFrame
    candidates_df = pd.DataFrame(novel_variants)
    candidates_df['predicted_fitness'] = predictions

    # Step 6: Sort by predicted fitness and select top candidates
    candidates_df = candidates_df.sort_values('predicted_fitness', ascending=False)
    top_candidates = candidates_df.head(top_n)

    print(f"\n✓ Top {top_n} predicted high-fitness variants:")
    print("-" * 80)

    display_cols = ['mutation', 'predicted_fitness', 'n_mutations']
    if 'functionality_prob' in top_candidates.columns:
        display_cols.append('functionality_prob')

    print(top_candidates[display_cols].to_string(index=False))

    # Save to file
    candidates_df.to_csv('gb1_screening_results.csv', index=False)
    print(f"\n✓ Full screening results saved to: gb1_screening_results.csv")
    print(f"  (Total novel variants screened: {len(candidates_df)})")

    return top_candidates


In [None]:
# Get all known sequences
# Use the merged table (all_df), not the leftover loop variable df from the merge cell
known_sequences = set(all_df['sequence'].tolist())
print(f"Known sequences in dataset: {len(known_sequences)}")

# Define wild-type and positions to mutate
# GB1 wild-type at key positions
wt_sequence = "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTELEVLFQGPLDPNSMATYEVLCEVARKLGTDDREVVLFLLNVFIPQPTLAQLIGALRALKEEGRLTFPLLAECLFRAGRRDLLRDLLHLDPRFLERHLAGTMSYFSPYQLTVLHVDGELCARDIRSLIFLSKDTIGSRSTPQTFLHWVYCMENLDLLGPTDVDALMSMLRSLSRVDLQRQVQTLMGLHLSGPSHSQHYRHTPLEHHHHHH"

# Focus on the 4 positions from the Wu et al. study (0-indexed)
# Positions 38, 39, 40, 53 (V39, D40, G41, V54 in paper notation)
positions_to_mutate = [38, 39, 40, 53]

print(f"Wild-type sequence length: {len(wt_sequence)}")
print(f"Positions to mutate: {positions_to_mutate}")
print(f"WT amino acids at these positions: {[wt_sequence[p] for p in positions_to_mutate]}")


In [None]:
# Re-train model on full dataset for best predictions
print("\nRetraining Ridge model on full dataset...")
X_all = embeddings
y_all = all_df['target'].values

scaler = StandardScaler()
X_all_scaled = scaler.fit_transform(X_all)

final_model = Ridge(alpha=1.0, random_state=42)
final_model.fit(X_all_scaled, y_all)
print("✓ Model trained on full dataset")


In [None]:
# Perform screening
top_candidates = screen_variants_for_high_fitness(
    model=final_model,
    scaler=scaler,
    known_sequences=known_sequences,
    wt_sequence=wt_sequence,
    positions_to_mutate=positions_to_mutate,
    top_n=100,
    double_mutant=True, triple_mutant=False, random_mutant=True,
    max_mutants=1500 #50,000 provides a high-quality prediction but takes a while to run
)

#### Visualization of Screening Results


In [None]:
if len(top_candidates) > 0:
    fig, axes = plt.subplots(1, 1, figsize=(7, 5))

    # Plot 1: Top candidates by mutation type
    mutation_counts = top_candidates.groupby('n_mutations')['predicted_fitness'].apply(list)

    positions = []
    fitness_values = []
    labels = []

    for n_mut, fitnesses in mutation_counts.items():
        positions.extend([n_mut] * len(fitnesses))
        fitness_values.extend(fitnesses)
        labels.append(f'{n_mut} mutation(s)')

    axes.scatter(positions, fitness_values, s=100, alpha=0.6, edgecolors='black')
    axes.set_xlabel('Number of Mutations', fontsize=12)
    axes.set_ylabel('Predicted Fitness', fontsize=12)
    axes.set_title('Top Candidates by Mutation Count', fontsize=14, fontweight='bold')
    axes.grid(True, alpha=0.3)
    axes.set_xticks(sorted(top_candidates['n_mutations'].unique()))

    plt.tight_layout()
    plt.savefig('gb1_screening_visualization.png', dpi=300, bbox_inches='tight')
    print("✓ Saved: gb1_screening_visualization.png")
    plt.show()

**Discussion points**
- In a system with four mutation sites there are 20^4 = 160,000 possible sequences, including wild type. Is it possible to make a prediction for all of these?
- What alternative algorithms might be used to generate variants with the greatest predicted fitness?
    - Would a genetic algorithm be a practical method to traverse the sequence space?
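The size of the four-site sequence space can be worked out and enumerated directly. The sketch below uses only the standard library; the helper names are hypothetical. It counts variants by number of mutations and yields every residue combination lazily in batches, so an exhaustive screen would not need to hold all embeddings in memory at once.

```python
from itertools import islice, product
from math import comb

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
N_SITES = 4

# Total sequences: 20 choices at each of the 4 sites
total = len(AMINO_ACIDS) ** N_SITES

# Breakdown by mutation count: choose k sites, 19 non-WT residues at each
by_order = {k: comb(N_SITES, k) * (len(AMINO_ACIDS) - 1) ** k for k in range(N_SITES + 1)}

def iter_site_combinations(n_sites=N_SITES, amino_acids=AMINO_ACIDS):
    """Lazily yield every residue combination for the mutated sites, e.g. 'VDGV'."""
    for combo in product(amino_acids, repeat=n_sites):
        yield "".join(combo)

def batched(iterable, batch_size):
    """Yield lists of up to batch_size items until the iterable is exhausted."""
    it = iter(iterable)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch

first_batch = next(batched(iter_site_combinations(), 1000))
n_batches = sum(1 for _ in batched(iter_site_combinations(), 1000))
print(total, by_order, n_batches)
```

Each batch of site combinations would still need to be spliced into the wild-type sequence and embedded, which is where nearly all of the runtime would go.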


# Regression vs Fine-tuning

Think about the difference between regression, performed here, and model fine-tuning. Explore when each is practical. Think about opportunities and pitfalls of each.

The Gray Lab at Johns Hopkins has prepared the _Fitness Landscape for Antibodies (FLAb)_ database of antibody developability characteristics and scripts for fine-tuning of pLMs ([FLAb repository](https://github.com/Graylab/FLAb)). Many of the datasets are behind non-commercial licenses, unfortunately, but some are permissively licensed.

On your own, explore the zero-shot prediction methods in the scripts.