## 1. Import Libraries

In [None]:
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle

# ESM-C imports
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

## 2. Configuration

In [None]:
# ESM-C checkpoint to embed with: esmc_300m, esmc_600m, or esmc_6b (the latter via Forge API)
MODEL_NAME = "esmc_300m"
# Switch to True to use esmc_6b through Forge (needs a Forge token)
USE_FORGE_API = False

# Forge credentials; only read when the Forge client is created for esmc_6b.
# Tokens are issued at https://forge.evolutionaryscale.ai
FORGE_TOKEN = "<your forge token>"

# Local models run on the GPU when PyTorch can see one, otherwise on the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Embedding width produced by each supported model variant
MODEL_DIMS = dict(
    esmc_300m=960,
    esmc_600m=1152,
    esmc_6b=2560,
)
EMBEDDING_DIM = MODEL_DIMS[MODEL_NAME]
print(f"Model: {MODEL_NAME}, Embedding dimension: {EMBEDDING_DIM}")
print(f"Using Forge API: {USE_FORGE_API}")

## 3. Load ESM-C Model

In [None]:
# Load model: either a remote Forge client or a local ESM-C checkpoint.
if USE_FORGE_API:
    from esm.sdk.forge import ESM3ForgeInferenceClient

    # Pick the Forge model from MODEL_NAME so the model that is queried always
    # matches EMBEDDING_DIM and the metadata written next to the embeddings.
    # NOTE(review): the 300m/600m identifiers follow the "<variant>-2024-12"
    # naming of the esmc-6b entry -- confirm against the Forge model list.
    _FORGE_MODEL_IDS = {
        "esmc_300m": "esmc-300m-2024-12",
        "esmc_600m": "esmc-600m-2024-12",
        "esmc_6b": "esmc-6b-2024-12",
    }
    if MODEL_NAME not in _FORGE_MODEL_IDS:
        raise ValueError(
            f"No Forge model known for MODEL_NAME={MODEL_NAME!r}; "
            f"expected one of {sorted(_FORGE_MODEL_IDS)}"
        )

    print(f"Connecting to ESM Forge API for {MODEL_NAME}...")
    model = ESM3ForgeInferenceClient(
        model=_FORGE_MODEL_IDS[MODEL_NAME],
        url="https://forge.evolutionaryscale.ai",
        token=FORGE_TOKEN
    )
    print("Connected to Forge API successfully!")
else:
    # The configuration cell documents esmc_6b as Forge-only, so fail early
    # with an actionable message instead of a library error during loading.
    if MODEL_NAME == "esmc_6b":
        raise ValueError(
            "esmc_6b is only available through the Forge API; "
            "set USE_FORGE_API = True or choose esmc_300m / esmc_600m."
        )

    print(f"Loading local model: {MODEL_NAME}...")
    model = ESMC.from_pretrained(MODEL_NAME).to(DEVICE)
    model.eval()  # Inference only: put the model in evaluation mode
    print("Model loaded successfully!")

## 4. Define Embedding Extraction Function

In [None]:
# Sequences longer than this many residues are truncated before embedding.
_MAX_SEQ_LEN = 2048


def _embeddings_to_numpy(embeddings):
    """Convert model/API embeddings to a float32 NumPy array without a batch axis."""
    if not isinstance(embeddings, torch.Tensor):
        embeddings = torch.as_tensor(embeddings)
    # NumPy has no bfloat16, so cast on the tensor side before converting.
    array = embeddings.detach().to(device="cpu", dtype=torch.float32).numpy()
    # Drop a leading batch axis of size 1 so that mean(axis=0) pools over
    # tokens and yields a (dim,) vector instead of a (tokens, dim) matrix.
    if array.ndim == 3 and array.shape[0] == 1:
        array = array[0]
    return array


def _store_embedding(prot_id, embeddings, seq_len, emb_dict, emb_mat_dict, length_dict):
    """Record one protein's result, substituting zeros when embedding failed."""
    length_dict[prot_id] = seq_len
    if embeddings is None:
        emb_dict[prot_id] = np.zeros(EMBEDDING_DIM, dtype=np.float32)
        emb_mat_dict[prot_id] = np.zeros((seq_len, EMBEDDING_DIM), dtype=np.float32)
    else:
        # Mean over the token axis is the per-sequence representation.
        emb_dict[prot_id] = embeddings.mean(axis=0)
        emb_mat_dict[prot_id] = embeddings


def _embed_with_forge(client, sequence, prot_id):
    """Embed one sequence through a Forge client; return (prot_id, array or None, length)."""
    seq = ""  # keeps the error path safe even if slicing the input fails
    try:
        from esm.sdk.api import ESMProteinError

        seq = sequence[:_MAX_SEQ_LEN]
        protein_tensor = client.encode(ESMProtein(sequence=seq))
        # The client returns errors as values instead of raising them.
        if isinstance(protein_tensor, ESMProteinError):
            raise protein_tensor

        output = client.logits(
            protein_tensor,
            LogitsConfig(sequence=True, return_embeddings=True)
        )
        if isinstance(output, ESMProteinError):
            raise output

        return (prot_id, _embeddings_to_numpy(output.embeddings), len(seq))
    except Exception as e:
        # Best effort: report the failure and let the caller store zeros.
        import traceback
        print(f"\nError processing {prot_id}: {e}")
        print(traceback.format_exc())
        return (prot_id, None, len(seq))


def get_esmc_pretrain(model, df_dir, db_name, sep=' ', header=None,
                      col_names=None,
                      use_forge_api=False, device='cuda'):
    """
    Extract ESM-C embeddings for the unique proteins of a CSV file and pickle them.

    Rows are de-duplicated on 'prot_id' and every sequence is truncated to 2048
    residues. A protein whose embedding fails is reported on stdout and stored
    as all-zero arrays so the output always covers every protein ID. Protein
    IDs are stored as strings on both the local and the Forge path.

    Reads the module-level globals EMBEDDING_DIM (zero fallback, metadata) and
    MODEL_NAME (metadata).

    NOTE(review): the stored matrix may contain special-token rows in addition
    to the residues, in which case its row count differs from length_dict --
    confirm against the ESM SDK and the downstream loader.

    Args:
        model: Loaded ESM-C model or Forge API client
        df_dir: Path to input CSV file
        db_name: Database name for output filename
        sep: CSV separator
        header: CSV header row (None for no header)
        col_names: Column names for the dataframe; must include 'prot_id' and
            'prot_seq'. Defaults to
            ['drug_id', 'prot_id', 'drug_smile', 'prot_seq', 'label'].
        use_forge_api: Whether `model` is a Forge API client
        device: Device label ('cuda' or 'cpu'); only used in progress messages

    Returns:
        Dictionary with keys "dataset", "vec_dict" (mean embedding per
        protein), "mat_dict" (per-token embedding matrix per protein),
        "length_dict" (truncated sequence length per protein), "model",
        "embedding_dim" and "use_forge_api". The same dictionary is written to
        './{db_name}_esmc_pretrain.pkl'.
    """
    if col_names is None:
        col_names = ['drug_id', 'prot_id', 'drug_smile', 'prot_seq', 'label']

    # Load data
    df = pd.read_csv(df_dir, sep=sep, header=header)
    df.columns = col_names
    df = df.drop_duplicates(subset='prot_id')

    # String keys on both code paths so downstream lookups behave the same.
    prot_ids = [str(pid) for pid in df['prot_id'].tolist()]
    prot_seqs = df['prot_seq'].tolist()

    emb_dict = {}
    emb_mat_dict = {}
    length_dict = {}

    print(f"Processing {len(prot_ids)} proteins on {device}...")

    if use_forge_api:
        # Use Forge Batch Executor for efficient API calls
        from esm.sdk import batch_executor

        print("Using Forge Batch Executor (server-side GPU processing)...")
        outputs = []
        if prot_ids:  # nothing to submit for an empty input file
            with batch_executor() as executor:
                outputs = executor.execute_batch(
                    user_func=_embed_with_forge,
                    client=[model] * len(prot_ids),
                    sequence=prot_seqs,
                    prot_id=prot_ids,
                )

        for prot_id, embeddings, seq_len in outputs:
            _store_embedding(prot_id, embeddings, seq_len,
                             emb_dict, emb_mat_dict, length_dict)

    else:
        # Local model processing
        print(f"Using local GPU: {device}")
        for prot_id, sequence in tqdm(zip(prot_ids, prot_seqs), total=len(prot_ids)):
            seq = ""  # keeps the fallback well-defined if slicing fails
            try:
                seq = sequence[:_MAX_SEQ_LEN]
                protein_tensor = model.encode(ESMProtein(sequence=seq))

                # Inference only: no autograd bookkeeping needed.
                with torch.no_grad():
                    logits_output = model.logits(
                        protein_tensor,
                        LogitsConfig(return_embeddings=True)
                    )

                embeddings = _embeddings_to_numpy(logits_output.embeddings)
            except Exception as e:
                print(f"\nError processing {prot_id}: {e}")
                embeddings = None  # stored as zeros below

            _store_embedding(prot_id, embeddings, len(seq),
                             emb_dict, emb_mat_dict, length_dict)

    # Prepare output dictionary
    dump_data = {
        "dataset": db_name,
        "vec_dict": emb_dict,
        "mat_dict": emb_mat_dict,
        "length_dict": length_dict,
        "model": MODEL_NAME,
        "embedding_dim": EMBEDDING_DIM,
        "use_forge_api": use_forge_api
    }

    # Save to pickle
    output_file = f'./{db_name}_esmc_pretrain.pkl'
    with open(output_file, 'wb') as f:
        pickle.dump(dump_data, f)

    print(f"\nSaved embeddings to: {output_file}")
    print(f"Total proteins: {len(emb_dict)}")
    print(f"Embedding dimension: {EMBEDDING_DIM}")

    return dump_data


## 5. Test on Single Case Study

Test the function on a small dataset first.

In [None]:
# Smoke test: embed the proteins of the simple-Case study.
db_name = 'simple_case'
df_dir = r'D:\Download D\2025\DTA\EXPLAIN 2\Temp\data\simple-Case\proteins.csv'
# The protein file is a comma-separated table with a header row and these two columns.
col_names = ['prot_id', 'prot_seq']

# Same model, API switch and device as configured in the cells above.
result = get_esmc_pretrain(
    model, df_dir, db_name,
    sep=',', header=0, col_names=col_names,
    use_forge_api=USE_FORGE_API, device=DEVICE,
)


## 6. Process Full Datasets (Davis, KIBA, Metz)

⚠️ **Important**: Update the file paths before running!

In [None]:
# Protein CSV of every benchmark dataset, keyed by dataset name
datasets = {
    name: f'./data/dta-5fold-dataset/{name}/{name}_prots.csv'
    for name in ('davis', 'kiba', 'metz')
}

col_names = ['prot_id', 'prot_seq']

# Embed the datasets one after another; a failing dataset is reported and
# the loop carries on with the next one.
for db_name, df_dir in datasets.items():
    print(
        f"\n{'='*60}",
        f"Processing {db_name.upper()} dataset",
        f"{'='*60}",
        sep="\n",
    )

    try:
        result = get_esmc_pretrain(
            model, df_dir, db_name,
            sep=',',  # change here if the files use another delimiter
            header=0,
            col_names=col_names,
            use_forge_api=USE_FORGE_API,
            device=DEVICE,
        )
        print(f"✓ Successfully processed {db_name}")
    except Exception as e:
        print(f"✗ Error processing {db_name}: {e}")


## 7. Verify Generated Embeddings

In [None]:
# Load and inspect generated embeddings
def verify_embeddings(db_name):
    """
    Load './{db_name}_esmc_pretrain.pkl' and print a short summary of it.

    Prints the stored metadata and the number of proteins, followed by the
    vector shape, matrix shape and sequence length of the first protein. A
    file without any proteins is summarised without the per-protein part.

    Args:
        db_name: Dataset name used to build the pickle filename

    Returns:
        The unpickled dictionary, or None when the file does not exist.
    """
    file_path = f'./{db_name}_esmc_pretrain.pkl'

    try:
        with open(file_path, 'rb') as f:
            # NOTE: unpickling can execute arbitrary code; only load files
            # that were produced by this notebook.
            data = pickle.load(f)
    except FileNotFoundError:
        print(f"✗ File not found: {file_path}")
        return None

    print(f"\n{'='*60}")
    print(f"Verification: {db_name}")
    print(f"{'='*60}")
    print(f"Dataset: {data['dataset']}")
    print(f"Model: {data['model']}")
    print(f"Embedding dimension: {data['embedding_dim']}")
    print(f"Number of proteins: {len(data['vec_dict'])}")

    # Take the first key without copying all keys; None marks an empty file.
    first_prot_id = next(iter(data['vec_dict']), None)
    if first_prot_id is None:
        print("\nNo proteins stored; nothing to inspect.")
        return data

    first_vec = data['vec_dict'][first_prot_id]
    first_mat = data['mat_dict'][first_prot_id]

    print(f"\nFirst protein: {first_prot_id}")
    print(f"  - Vec shape: {first_vec.shape}")
    print(f"  - Mat shape: {first_mat.shape}")
    print(f"  - Sequence length: {data['length_dict'][first_prot_id]}")

    return data

# Print a summary of the saved embeddings of every benchmark dataset
for db_name in ('davis', 'kiba', 'metz'):
    verify_embeddings(db_name)

## 8. Compare with ESM2 (Optional)

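If you also have ESM2 embeddings saved as `{dataset}_esm_pretrain.pkl` in the working directory, you can compare protein counts and embedding dimensions between the two models.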

In [None]:
def compare_embeddings(db_name):
    """
    Print protein counts and embedding widths of the ESM-C and ESM2 pickles.

    Reads './{db_name}_esmc_pretrain.pkl' and './{db_name}_esm_pretrain.pkl'.
    The ESM-C width comes from the stored 'embedding_dim' entry; the ESM2
    width is taken from the first stored vector and shown as "n/a" when the
    ESM2 file holds no vectors. A missing file is reported and nothing else
    is printed.

    Args:
        db_name: Dataset name used to build both pickle filenames

    Returns:
        None
    """
    esmc_file = f'./{db_name}_esmc_pretrain.pkl'
    esm2_file = f'./{db_name}_esm_pretrain.pkl'

    try:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(esmc_file, 'rb') as f:
            esmc_data = pickle.load(f)

        with open(esm2_file, 'rb') as f:
            esm2_data = pickle.load(f)
    except FileNotFoundError as e:
        print(f"✗ File not found: {e}")
        return

    # First ESM2 vector without copying all values; None marks an empty file.
    first_esm2_vec = next(iter(esm2_data['vec_dict'].values()), None)
    esm2_dim = "n/a" if first_esm2_vec is None else first_esm2_vec.shape[0]

    print(f"\n{'='*60}")
    print(f"Comparison: {db_name}")
    print(f"{'='*60}")
    print(f"ESM-C proteins: {len(esmc_data['vec_dict'])}")
    print(f"ESM2 proteins:  {len(esm2_data['vec_dict'])}")
    print(f"\nESM-C embedding dim: {esmc_data['embedding_dim']}")
    print(f"ESM2 embedding dim:  {esm2_dim}")

# Run comparison
# compare_embeddings('davis')

## 9. Next Steps

After generating ESM-C embeddings:

1. **Update hyperparameter.py**:
   - Set `use_esmc = True`
   - Set `esmc_model = "esmc_300m"` (or your chosen variant)
   - Verify `protvec_dim` matches your model (960, 1152, or 2560)

2. **Update data paths**:
   - Move generated `.pkl` files to `./data/{dataset}/` directory
   - Ensure filenames match: `{dataset}_esmc_pretrain.pkl`

3. **Retrain model**:
   ```bash
   python code/train.py
   ```

4. **Compare performance**:
   - Note: You'll need to retrain with new embeddings
   - Compare metrics (MSE, CI, R²) between ESM2 and ESM-C

## Summary

✅ ESM-C provides:
- Better performance (smaller models, better results)
- Longer sequences (2048 vs 1022 tokens)
- Faster inference
- Modern architecture

⚠️ Trade-offs:
- Different API (not drop-in replacement)
- Different dimensions (960/1152/2560 vs 1280)
- Requires retraining model with new embeddings