# Quick Start: Generate Embeddings for a Custom Dataset

This notebook generates pretrained embeddings for your drugs and proteins.

## Input Files Required:
1. **drugs.csv** with columns: `drug_id`, `drug_smile`
2. **proteins.csv** with columns: `prot_id`, `prot_seq`
3. **model_300dim.pkl** (mol2vec model) - already in `data/`
4. **protVec_100d_3grams.csv** (protein model) - download if needed
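
Before running the notebook, it can help to confirm that both CSV files carry the expected header. The sketch below uses only the standard library. `missing_columns` is a hypothetical helper, and the sample text is a made-up stand-in for a real `drugs.csv`.

```python
import csv
import io

def missing_columns(csv_text, required):
    """Return the required column names that are absent from the CSV header."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader, [])
    present = {name.strip() for name in header}
    return [name for name in required if name not in present]

# Toy stand-in for drugs.csv (assumed content, not real data)
sample_drugs = "drug_id,drug_smile\nD1,CCO\n"

print(missing_columns(sample_drugs, ["drug_id", "drug_smile"]))  # []
print(missing_columns(sample_drugs, ["prot_id", "prot_seq"]))    # ['prot_id', 'prot_seq']
```

For a file on disk, the same check could be run on the text returned by `open(path).read()`.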

## Output:
- `data/{dataset_name}/{dataset_name}_drug_pretrain.pkl`
- `data/{dataset_name}/{dataset_name}_prot_pretrain.pkl`
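
Each output file is a pickled dictionary with the keys `dataset`, `vec_dict`, `mat_dict`, and `length_dict`, as written by the functions later in this notebook. The following sketch round-trips a toy payload through `pickle` to show that layout. Plain lists stand in for the NumPy arrays in the real files, and the ID `D1` and all values are made up.

```python
import pickle

# Toy payload mirroring the layout written by this notebook.
# Plain lists stand in for the NumPy arrays stored in the real files.
toy_payload = {
    "dataset": "custom",
    "vec_dict": {"D1": [0.1, 0.2]},                # id -> single feature vector
    "mat_dict": {"D1": [[0.1, 0.0], [0.0, 0.2]]},  # id -> per-token feature matrix
    "length_dict": {"D1": 2},                      # id -> number of tokens
}

# Round trip through pickle, as the real files are written and read
loaded = pickle.loads(pickle.dumps(toy_payload))

for key, value in loaded.items():
    summary = f"{len(value)} entries" if isinstance(value, dict) else value
    print(key, summary)
```

To inspect a real output file, open it in binary mode and pass the file object to `pickle.load` instead of using the in-memory round trip.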

## Steps:
1. Update the `dataset_name`, `drugs_csv`, and `proteins_csv` paths in the configuration cells
2. Run all cells in order
3. Check the output files in the `data/{dataset_name}/` folder

# Prepare ESM2 pretrain

We use the `esm2_t33_650M_UR50D` model from this [hub](https://github.com/facebookresearch/esm#available-models).

To fit the ESM-2 input limit, amino acid sequences are truncated to MAX_LEN = 1022 residues (1024 tokens once the cls and eos tokens are added).

For each protein of length m, the model produces a feature matrix of shape (m, 1280).

Averaging this matrix over the residues gives a feature vector of dimension 1280.

Because of memory limits, we compute only one protein at a time.
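
The truncation and pooling described above can be sketched with NumPy alone, without loading the model. This is an illustration under assumptions: a random array stands in for the ESM-2 output, `pool_residues` is a hypothetical helper, and the toy sequence is made up.

```python
import numpy as np

MAX_LEN = 1022   # residues kept per protein, as stated above
EMB_DIM = 1280   # feature size per residue, as stated above

def pool_residues(token_repr):
    """Drop the first (cls) and last (eos) rows, then average over residues."""
    return token_repr[1:-1].mean(axis=0)

seq = "M" * 1500             # toy sequence longer than the limit
truncated = seq[:MAX_LEN]    # same truncation rule as in the notebook

# Random stand-in for the model output: one row per token, including cls and eos
rng = np.random.default_rng(0)
token_repr = rng.normal(size=(len(truncated) + 2, EMB_DIM))

vec = pool_residues(token_repr)
print(len(truncated), token_repr.shape, vec.shape)
```

Slicing off the first and last rows matters because the cls and eos tokens do not correspond to residues and would otherwise shift the average.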

## Prepare Input
- [dta-origin-dataset](https://www.kaggle.com/datasets/christang0002/llmdta/data)
    - davis.txt
    - kiba.txt
    - metz.txt
- model_300dim.pkl
- protVec_100d_3grams.csv

In [None]:
# Install required packages
!pip install fair-esm

In [None]:
import torch
import pandas as pd
import esm


In [None]:
torch.set_num_threads(2)

In [None]:
# Load ESM-2 model (after installing fair-esm)
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

In [None]:
from tqdm import tqdm
import pickle
import os
import gc

def get_esm_pretrain(model, df_dir, db_name, sep=' ', header=None, col_names=['drug_id', 'prot_id', 'drug_smile', 'prot_seq', 'label'], batch_size=50):
    df = pd.read_csv(df_dir, sep=sep, header=header)
    
    # Only set column names if number matches
    if len(df.columns) == len(col_names):
        df.columns = col_names
    else:
        print(f"Warning: CSV has {len(df.columns)} columns but {len(col_names)} names provided. Using first 2 columns as prot_id and prot_seq.")
        if len(df.columns) >= 2:
            df.columns = ['prot_id', 'prot_seq'] + [f'col_{i}' for i in range(2, len(df.columns))]
        else:
            # Only 1 column, assume it's sequence and generate IDs
            df.columns = ['prot_seq']
            df['prot_id'] = [f'prot_{i}' for i in range(len(df))]
    
    df.drop_duplicates(subset='prot_id', inplace=True)
    prot_ids = df['prot_id'].tolist()
    prot_seqs = df['prot_seq'].tolist()
    data = []
    prot_size = len(prot_ids)
    for i in range(prot_size):
        seq_len = min(len(prot_seqs[i]),1022)
        data.append((prot_ids[i], prot_seqs[i][:seq_len]))
    
    emb_dict = {}
    emb_mat_dict = {}
    length_target = {}

    print(f"Processing {len(data)} proteins one at a time (memory cleanup every {batch_size} proteins)...")
    
    # One protein per forward pass; free memory every batch_size proteins to limit RAM use
    for d in tqdm(data):
        prot_id = d[0]
        batch_labels, batch_strs, batch_tokens = batch_converter([d])
        batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
        
        # Extract per-residue representations (on CPU)
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[33], return_contacts=False)  # Disable contact map to save memory
        token_representations = results["representations"][33].cpu().numpy()  # Explicitly move to CPU

        sequence_representations = []
        for i, tokens_len in enumerate(batch_lens):
            sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))

        emb_dict[prot_id] = sequence_representations[0]
        # Note: this matrix still includes the cls and eos rows, so its shape is (len + 2, 1280)
        emb_mat_dict[prot_id] = token_representations[0]
        length_target[prot_id] = len(d[1])
        
        # Clear tensors every batch_size iterations
        if len(emb_dict) % batch_size == 0:
            del batch_tokens, results, token_representations
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    dump_data = {
        "dataset": db_name,
        "vec_dict": emb_dict,
        "mat_dict": emb_mat_dict,
        "length_dict": length_target
    }
    
    # Create directory if not exists
    os.makedirs(f'./data/{db_name}', exist_ok=True)
    
    with open(f'./data/{db_name}/{db_name}_esm_pretrain.pkl', 'wb+') as f:
        pickle.dump(dump_data, f)
    print(f"✓ Saved to ./data/{db_name}/{db_name}_esm_pretrain.pkl")
    
    # Final cleanup
    del emb_dict, emb_mat_dict, length_target
    gc.collect()

In [None]:
# Check proteins.csv format first
df_dir = '/content/Temp/data/simple-Case/proteins.csv'
df_test = pd.read_csv(df_dir, sep=',', header=0)  # Try comma separator
print(f"Shape: {df_test.shape}")
print(f"Columns: {df_test.columns.tolist()}")
print(f"First 3 rows:\n{df_test.head(3)}")

# Now run get_esm_pretrain with correct separator
db_name = 'case_study'
col_names = ['prot_id', 'prot_seq']
get_esm_pretrain(model, df_dir, db_name, sep=',', header=0, col_names=col_names)

In [None]:
df = pd.read_csv(df_dir, sep=',', header=0)  # same separator and header as the format check above
df.columns = col_names
df

In [None]:
db_names = ['davis', 'kiba', 'metz']
df_dirs = [r'/home/tangwuguo/datasets/davis.txt', r'/home/tangwuguo/datasets/kiba.txt', r'/home/tangwuguo/datasets/metz.txt']

# Note: range(1, 2) processes only index 1 (kiba); use range(len(db_names)) to process all three
for i in range(1, 2):
    print(f'Compute {df_dirs[i]} protein pretrain feature by esm2.')
    get_esm_pretrain(model, df_dirs[i], db_names[i])

# Mol2Vec pretrain
Given a drug SMILES string, mol2vec first computes the substructures of the molecule.

For a SMILES string with m substructures, it outputs an (m, 300) feature matrix and a 300-dimensional feature vector.
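
The aggregation step can be illustrated without RDKit or the pretrained model. In this notebook's `get_mol2vec`, the per-drug vector is the sum of the substructure vectors, and substructures missing from the model fall back to the `UNK` vector. The sketch below reproduces that logic on a toy lookup table. `aggregate` is a hypothetical helper, the identifiers and values are made up, and 4 dimensions are used instead of 300 to keep the output readable.

```python
import numpy as np

def aggregate(sub_structures, lookup, dim):
    """Stack one vector per substructure and sum them into a single vector."""
    mat = np.zeros((len(sub_structures), dim))
    for i, sub in enumerate(sub_structures):
        # Identifiers missing from the lookup fall back to the 'UNK' vector
        mat[i] = lookup.get(sub, lookup["UNK"])
    return mat, mat.sum(axis=0)

# Toy lookup table (made-up identifiers and values)
dim = 4
lookup = {
    "sub_a": np.full(dim, 1.0),
    "sub_b": np.full(dim, 2.0),
    "UNK": np.full(dim, -1.0),
}

mat, vec = aggregate(["sub_a", "not_in_model", "sub_b"], lookup, dim)
print(mat.shape)  # (3, 4)
print(vec)        # [2. 2. 2. 2.]
```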

In [None]:
# Install required packages for mol2vec
!pip install gensim mol2vec rdkit

import numpy as np 
import pandas as pd 
import pickle
from gensim.models import word2vec
from mol2vec.features import mol2alt_sentence
from rdkit import Chem
from tqdm import tqdm

In [None]:
def get_mol2vec(mol2vec_dir, df_dir, db_name, sep=',', header=0, col_names=['drug_id', 'drug_smile'], embedding_dimension=300, is_debug=False, show_miss_details=False):
    mol2vec_model = word2vec.Word2Vec.load(mol2vec_dir)
    
    df = pd.read_csv(df_dir, header=header, sep=sep)
    df.columns = col_names
    df.drop_duplicates(subset='drug_id', inplace=True)    
    drug_ids = df['drug_id'].tolist()
    drug_seqs = df['drug_smile'].tolist()
    
    emb_dict = {}
    emb_mat_dict = {}
    length_dict = {}
    
    percent_unknown = []
    bad_mol = 0
    miss_details = []  # Track which drugs have missing substructures
    
    # get pretrain feature
    for idx in tqdm(range(len(drug_ids))):
        flag = 0
        mol_miss_words = 0
        
        drug_id = str(drug_ids[idx])
        molecule = Chem.MolFromSmiles(drug_seqs[idx])
        
        try:
            # Get fingerprint from molecule
            sub_structures = mol2alt_sentence(molecule, 2)
        except Exception as e: 
            if is_debug: 
                print (e)
            percent_unknown.append(100)
            continue    
                
        emb_mat = np.zeros((len(sub_structures), embedding_dimension))
        length_dict[drug_id] = len(sub_structures)
        
        missed_subs = []  # Track missing substructures for this drug
        for i, sub in enumerate(sub_structures):
            # Check to see if substructure exists
            try:
                emb_dict[drug_id] = emb_dict.get(drug_id, np.zeros(embedding_dimension)) + mol2vec_model.wv[sub]  
                emb_mat[i] = mol2vec_model.wv[sub]  
            # If not, replace with UNK (unknown)
            except Exception as e:
                if is_debug : 
                    print ("Sub structure not found")
                    print (e)
                emb_dict[drug_id] = emb_dict.get(drug_id, np.zeros(embedding_dimension)) + mol2vec_model.wv['UNK']
                emb_mat[i] = mol2vec_model.wv['UNK']                
                flag = 1
                mol_miss_words = mol_miss_words + 1
                missed_subs.append(sub)
        
        emb_mat_dict[drug_id] = emb_mat
        
        miss_rate = (mol_miss_words / len(sub_structures)) * 100
        percent_unknown.append(miss_rate)
        if flag == 1:
            bad_mol = bad_mol + 1
            if show_miss_details:
                miss_details.append({
                    'drug_id': drug_id,
                    'smile': drug_seqs[idx],
                    'total_subs': len(sub_structures),
                    'missed_subs': mol_miss_words,
                    'miss_rate': miss_rate,
                    'missed_substructures': missed_subs[:5]  # Show first 5 missing
                })
            
    print(f'All Bad Mol: {bad_mol}, Avg Miss Rate: {sum(percent_unknown)/len(percent_unknown):.2f}%')
    
    # Show details of top 10 worst drugs
    if show_miss_details and miss_details:
        print("\n=== Top 10 drugs with highest miss rates ===")
        miss_details.sort(key=lambda x: x['miss_rate'], reverse=True)
        for i, detail in enumerate(miss_details[:10]):
            print(f"\n{i+1}. Drug ID: {detail['drug_id']}")
            print(f"   SMILES: {detail['smile'][:50]}...")
            print(f"   Miss Rate: {detail['miss_rate']:.1f}% ({detail['missed_subs']}/{detail['total_subs']} substructures)")
            print(f"   Missing substructures: {detail['missed_substructures']}")
        
    dump_data = {
        "dataset": db_name,
        "vec_dict": emb_dict,
        "mat_dict": emb_mat_dict,
        "length_dict": length_dict
    }
    
    # Create directory if not exists
    import os
    os.makedirs(f'./data/{db_name}', exist_ok=True)
    
    with open(f'./data/{db_name}/{db_name}_drug_pretrain.pkl', 'wb+') as f:
        pickle.dump(dump_data, f)
    print(f"\n✓ Saved to ./data/{db_name}/{db_name}_drug_pretrain.pkl")

In [None]:
# Download mol2vec model if not exists
import os
if not os.path.exists('./data/model_300dim.pkl'):
    print("Downloading mol2vec model (73MB)...")
    !mkdir -p ./data
    # Download from GitHub repo or use wget
    !wget -O ./data/model_300dim.pkl https://github.com/samoturk/mol2vec/raw/master/examples/models/model_300dim.pkl
    print("✓ Downloaded model_300dim.pkl")
else:
    print("✓ model_300dim.pkl already exists")

In [None]:
# Configuration for your custom dataset
mol2vec_dir = './data/model_300dim.pkl'
dataset_name = 'custom'  # Change this to your dataset name
drugs_csv = './data/simple-Case/drugs.csv'  # Path to your drugs.csv

# Generate mol2vec embeddings with miss details
get_mol2vec(mol2vec_dir, drugs_csv, dataset_name, sep=',', header=0, 
            col_names=['drug_id', 'drug_smile'], show_miss_details=True)

In [None]:
mol2vec_dir = './data/model_300dim.pkl'
db_names = ['davis', 'kiba', 'metz']
df_dirs = [r'/home/tangwuguo/datasets/davis.txt', r'/home/tangwuguo/datasets/kiba.txt', r'/home/tangwuguo/datasets/metz.txt']

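# The original DTA files are read as space-separated with five unnamed columns,
# matching the defaults of get_esm_pretrain above
dta_col_names = ['drug_id', 'prot_id', 'drug_smile', 'prot_seq', 'label']
for i in range(3):
    print(f'Compute {db_names[i]} drug pretrain feature by mol2vec.')
    get_mol2vec(mol2vec_dir, df_dirs[i], db_names[i], sep=' ', header=None, col_names=dta_col_names)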

# ProtVec pretrain
Input: a protein's amino acid sequence, split into m non-overlapping 3-grams.

Output: a feature matrix of shape (m, 100) and a feature vector of dimension 100.
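
The 3-gram splitting used by `get_protvec` below can be sketched in a few lines. The sequence is cut into non-overlapping chunks of three, and a short final chunk is left-padded with `X`. `split_trigrams` is a hypothetical helper name and the sequences are toy examples.

```python
def split_trigrams(seq, n=3):
    """Split a sequence into non-overlapping n-grams, left-padding a short tail with 'X'."""
    chunks = [seq[i:i + n] for i in range(0, len(seq), n)]
    return [chunk.rjust(n, "X") for chunk in chunks]

print(split_trigrams("MKTAY"))  # ['MKT', 'XAY']
print(split_trigrams("MKTA"))   # ['MKT', 'XXA']
```

With this splitting, a sequence of length L yields ceil(L / 3) trigrams, which is the number of rows in the feature matrix.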

In [None]:
def get_protvec(protvec_dir, df_dir, db_name, col_names=['prot_id', 'prot_seq'], embedding_dimension=100, is_debug=False, sep=',', header=0):        
    protvec_model = pd.read_csv(protvec_dir, delimiter = '\t')
    trigram_dict = {}
    for idx, row in tqdm(protvec_model.iterrows()):
        trigram_dict[row['words']] = protvec_model.iloc[idx, 1:].values.astype(np.float64)
    trigram_list = set(trigram_dict.keys())
    
    # Read CSV with proper parameters
    df = pd.read_csv(df_dir, header=header, sep=sep)
    df.columns = col_names
    df.drop_duplicates(subset='prot_id', inplace=True)    
    prot_ids = df['prot_id'].tolist()
    prot_seqs = df['prot_seq'].tolist()
    
    emb_dict = {}
    emb_mat_dict = {}
    length_3mer_target = {}

    # get pretrain feature
    for idx in tqdm(range(len(prot_ids))):
        n = 3
        target = prot_seqs[idx]
        prot_id = str(prot_ids[idx])
        split_by_three = [target[i : i + n] for i in range(0, len(target), n)]
        mer_len = len(split_by_three)
        length_3mer_target[prot_id] = mer_len
        
        emb_mat = np.zeros((mer_len, embedding_dimension))
        for i, trigram in enumerate(split_by_three): 
            if len(trigram) == 2: 
                trigram = "X" + trigram
            elif len(trigram) == 1:
                trigram = "XX" + trigram
            if trigram in trigram_list:
                emb_dict[prot_id] = emb_dict.get(prot_id, np.zeros(embedding_dimension))+ trigram_dict[trigram]
                emb_mat[i] = trigram_dict[trigram]
        emb_mat_dict[prot_id] = emb_mat
        
    # Save pretrain embeddings
    dump_data = {
        "dataset": db_name,
        "vec_dict": emb_dict,
        "mat_dict": emb_mat_dict,
        "length_dict": length_3mer_target
    }    
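    # Create directory if not exists (os is imported in an earlier cell)
    os.makedirs(f'./data/{db_name}', exist_ok=True)

    with open(f'./data/{db_name}/{db_name}_prot_pretrain.pkl', 'wb+') as f: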
        pickle.dump(dump_data, f)
    print(f"✓ Saved to ./data/{db_name}/{db_name}_prot_pretrain.pkl")

In [None]:
# Configuration for protein embeddings
protvec_dir = './data/protVec_100d_3grams.csv'
dataset_name = 'custom'  # Same name as above
proteins_csv = './data/simple-Case/proteins.csv'  # Path to your proteins.csv

# Generate protein embeddings
print(f'Computing {dataset_name} protein pretrain features by protvec...')
get_protvec(protvec_dir, proteins_csv, dataset_name, col_names=['prot_id', 'prot_seq'], embedding_dimension=100)