# Running Pika for CARE

To run Pika for CARE and retrain the model, first create a dedicated environment for Pika, then run this notebook from within that environment.

Pika can be found here: https://github.com/EMCarrami/Pika

### Installation/setup of Pika environment
```
conda create --name pika python=3.10
```

```
conda activate pika
```

```
pip install git+https://github.com/EMCarrami/Pika.git
```
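
Before running the cells below, it can help to confirm that the notebook kernel really is using the new environment. The following is a minimal sketch using only the standard library; the helper name `package_available` is hypothetical and not part of Pika. An unrelated package named `pika` also exists on PyPI, so checking for the `pika.main` submodule (the one imported in this notebook) is the stronger signal.

```python
import importlib.util
import sys


def package_available(name: str) -> bool:
    """Return True if `name` can be resolved by the current interpreter."""
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # Raised when the parent package of a dotted name is missing
        return False


# Show which interpreter the kernel uses and whether Pika resolves from it
print("Python executable:", sys.executable)
print("pika.main importable:", package_available("pika.main"))
```

If the executable path does not point inside the `pika` conda environment, the kernel is probably running from a different environment.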

In [2]:
import sys
from pika.main import Pika
from pika.utils.helpers import load_config
import warnings
import logging

warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)



### Running Pika and retraining

To run and retrain Pika, the datasets must be in the same format that Pika uses for its own dataset.

This requires four files: `metrics`, `sequences`, `split` and `annotations`. The config also needs to be updated to reflect their paths:

```
{
  "seed": 7,
  "datamodule": {
    "sequence_data_path": "sequences.csv",
    "annotations_path": "annotations.csv",
    "metrics_data_path": "metrics.csv",
    "split_path": "split.csv",
    "max_protein_length": 1500,
    "max_text_length": 250,
    "data_types_to_use": ["qa"],
    "sequence_placeholder": "<protein sequence placeholder> ",
    "train_batch_size": 10,
    "eval_batch_size": 2,
    "num_workers": 0
  },
  "model": {
    "language_model": "gpt2",
    "protein_model": "esm2_t6_8M_UR50D",
    "multimodal_strategy": "self-pika",
    "protein_layer_to_use": -1,
    "perceiver_latent_size": 10,
    "num_perceiver_layers": 4,
    "multimodal_layers": [0],
    "enable_gradient_checkpointing": false,
    "lr": 1e-4,
    "weight_decay": 1e-4
  },
  "checkpoint_callback": {
    "checkpoint_path": "test_checkpoint",
    "save_partial_checkpoints": true,
    "checkpoint_monitors": ["loss/val_loss"],
    "checkpoint_modes": ["min"]
  },
  "trainer": {
    "max_epochs": 2,
    "limit_train_batches": 100,
    "limit_val_batches": 1,
    "limit_test_batches": 100
  }
}

```
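
Since the config refers to four data files by path, a quick check that they all exist can catch a wrong working directory before training starts. This is a rough sketch, not part of Pika: the helper `missing_data_files` is hypothetical, and the inline config fragment only mirrors the `datamodule` keys shown above.

```python
import json
from pathlib import Path

# The four datamodule keys that point at data files in the config above
DATA_KEYS = ["sequence_data_path", "annotations_path", "metrics_data_path", "split_path"]


def missing_data_files(config: dict, base_dir: str = ".") -> list:
    """Return the configured data paths that do not exist under `base_dir`."""
    datamodule = config["datamodule"]
    return [datamodule[key] for key in DATA_KEYS
            if not (Path(base_dir) / datamodule[key]).exists()]


# Inline fragment for illustration; in practice the dict would be loaded
# from the config file with json.load
example_config = json.loads("""
{"datamodule": {"sequence_data_path": "sequences.csv",
                "annotations_path": "annotations.csv",
                "metrics_data_path": "metrics.csv",
                "split_path": "split.csv"}}
""")
print("Missing files:", missing_data_files(example_config))
```

An empty list means every file named in the config was found relative to the current directory.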

In [3]:
import sys
from pika.main import Pika
from pika.utils.helpers import load_config
import warnings
import logging
import pandas as pd
import numpy as np


ec_column = 'EC All'

df_train = pd.read_csv('../../splits/task1/protein_train.csv')
rows = []
for entry, seq, ec in df_train[['Entry', 'Sequence', ec_column]].values:
    rows.append([entry, 'qa', f"What is the EC number of this protein? {ec}"])
    
sample_annotations = pd.DataFrame(rows, columns=['uniprot_id', 'type', 'annotation'])
sample_annotations.to_csv('annotations.csv', index=False)

# Also split into train, test and validation sets for the model training
from sklearn.model_selection import train_test_split

# Hold out 30% of the rows, then divide the hold-out evenly into test and validation
train, test = train_test_split(df_train, test_size=0.3)
half = int(0.5 * len(test))
rows = []
# Rows in the `train` subset are labelled 'train'
for entry, seq, ec in train[['Entry', 'Sequence', ec_column]].values:
    rows.append([entry, len(seq), 'train'])

for entry, seq, ec in test[['Entry', 'Sequence', ec_column]].values[:half]:
    rows.append([entry, len(seq), 'test'])

for entry, seq, ec in test[['Entry', 'Sequence', ec_column]].values[half:]:
    rows.append([entry, len(seq), 'val'])
    
sample_split = pd.DataFrame(rows, columns=['uniprot_id' , 'protein_length', 'split'])

sample_split.to_csv('split.csv', index=False)

# Next we need to make the metrics
# uniprot_id,metric,value
# A0A068BGA5,is_enzyme,True
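
Because `split.csv` assigns proteins to `train`, `test` or `val`, one useful sanity check is whether any `uniprot_id` ends up in more than one split. The sketch below is an assumption about what is worth checking, not a documented Pika requirement; the helper name and the example IDs are made up for illustration.

```python
from collections import defaultdict


def ids_in_multiple_splits(rows):
    """Return {uniprot_id: sorted splits} for IDs assigned to more than one split.

    rows: iterable of (uniprot_id, protein_length, split) triples,
    i.e. the same column order as split.csv.
    """
    seen = defaultdict(set)
    for uniprot_id, _length, split in rows:
        seen[uniprot_id].add(split)
    return {uid: sorted(splits) for uid, splits in seen.items() if len(splits) > 1}


# Toy rows: P2 is deliberately placed in two splits
example_rows = [
    ("P1", 100, "train"),
    ("P2", 120, "train"),
    ("P2", 120, "test"),
    ("P3", 90, "val"),
]
print(ids_in_multiple_splits(example_rows))
```

With the DataFrame from this cell, the call would be `ids_in_multiple_splits(sample_split.values)`; an empty dict means no ID is shared between splits.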


## Pika requires additional information about each enzyme

Although this likely does not affect the retraining, we fill in this information as well.

The Pika dataset provides these metrics (the authors used ChatGPT 3.5 to extract the metrics for each protein).

We download the metrics file with wget and use it to fill in the metrics for the training dataset:

https://huggingface.co/datasets/EMCarrami/Pika-DS/tree/main/dataset

```
wget https://huggingface.co/datasets/EMCarrami/Pika-DS/resolve/main/dataset/pika_metrics.csv
```



In [6]:
metrics_df = pd.read_csv('pika_metrics.csv')
metrics_df

| | uniprot_id | metric | value |
|---|---|---|---|
| 0 | A0A009IHW8 | in_membrane | False |
| 1 | A0A009IHW8 | in_nucleus | False |
| 2 | A0A009IHW8 | in_mitochondria | False |
| 3 | A0A009IHW8 | is_enzyme | True |
| 4 | A0A009IHW8 | mw | 30922 |
| ... | ... | ... | ... |
| 1432678 | W6Q4Q9 | in_nucleus | False |
| 1432679 | W6Q4Q9 | in_mitochondria | False |
| 1432680 | W6Q4Q9 | is_enzyme | True |
| 1432681 | W6Q4Q9 | cofactor | mg(2+) |


In [7]:
metrics_df = metrics_df[metrics_df['uniprot_id'].isin(list(set(df_train['Entry'].values)))]
metrics_df

| | uniprot_id | metric | value |
|---|---|---|---|
| 0 | A0A009IHW8 | in_membrane | False |
| 1 | A0A009IHW8 | in_nucleus | False |
| 2 | A0A009IHW8 | in_mitochondria | False |
| 3 | A0A009IHW8 | is_enzyme | True |
| 4 | A0A009IHW8 | mw | 30922 |
| ... | ... | ... | ... |
| 1432660 | S3DQP8 | in_nucleus | False |
| 1432661 | S3DQP8 | in_mitochondria | False |
| 1432662 | S3DQP8 | is_enzyme | True |
| 1432663 | S3DQP8 | cofactor | pyridoxal 5'-phosphate (plp) |


In [8]:

from sklearn.model_selection import train_test_split
from tqdm import tqdm

# A0A084R1H6,in_membrane,False
# A0A084R1H6,in_nucleus,False
# A0A084R1H6,in_mitochondria,False
# A0A084R1H6,is_enzyme,True
# A0A084R1H6,mw,263256
rows = []
for entry, seq, ec in tqdm(df_train[['Entry', 'Sequence', ec_column]].values):
    metrics = metrics_df[metrics_df['uniprot_id'] == entry]
    # now we can assign the map for each one
    for metric_name, value in metrics[['metric', 'value']].values:
        rows.append([entry, metric_name, value])

sample_metrics = pd.DataFrame(rows, columns=['uniprot_id' , 'metric', 'value'])
sample_metrics.to_csv('metrics.csv', index=False)


  0%|          | 0/184529 [00:00<?, ?it/s]

100%|██████████| 184529/184529 [59:15<00:00, 51.90it/s] 
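
The loop above filters the full metrics table once per training entry, which is likely why it takes about an hour. One possible alternative is to index the metrics by `uniprot_id` once and then look each entry up. The sketch below uses plain Python so it is easy to follow; the helper `collect_metrics` is hypothetical, the toy rows are made up, and no timing has been measured for it on the real data.

```python
from collections import defaultdict


def collect_metrics(entries, metric_rows):
    """Return [uniprot_id, metric, value] rows for each entry, in entry order.

    entries: iterable of UniProt IDs (repeated IDs are kept, as in the loop above).
    metric_rows: iterable of (uniprot_id, metric, value) triples.
    """
    # Build the index once: uniprot_id -> list of (metric, value)
    by_id = defaultdict(list)
    for uniprot_id, metric, value in metric_rows:
        by_id[uniprot_id].append((metric, value))

    rows = []
    for entry in entries:
        # .get avoids inserting empty lists for entries without metrics
        for metric, value in by_id.get(entry, []):
            rows.append([entry, metric, value])
    return rows


# Toy data for illustration only
toy_metrics = [("P1", "is_enzyme", True), ("P1", "mw", 30922), ("P2", "is_enzyme", False)]
print(collect_metrics(["P1", "P3"], toy_metrics))
```

With the DataFrames from this notebook, the call would be `collect_metrics(df_train['Entry'].values, metrics_df[['uniprot_id', 'metric', 'value']].values)`, and the result can be passed to `pd.DataFrame` with the same column names as before.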


In [2]:
import pandas as pd
pd.read_csv('metrics.csv')

| | uniprot_id | metric | value |
|---|---|---|---|
| 0 | A0A009IHW8 | in_membrane | False |
| 1 | A0A009IHW8 | in_nucleus | False |
| 2 | A0A009IHW8 | in_mitochondria | False |
| 3 | A0A009IHW8 | is_enzyme | True |
| 4 | A0A009IHW8 | mw | 30922 |
| ... | ... | ... | ... |
| 451840 | Q9J5H2 | in_membrane | False |
| 451841 | Q9J5H2 | in_nucleus | False |
| 451842 | Q9J5H2 | in_mitochondria | False |
| 451843 | Q9J5H2 | is_enzyme | True |


## Formatting the sequence dataset

We format the sequence dataset in the same way, starting from the sequences file of the Pika dataset.

```
wget https://huggingface.co/datasets/EMCarrami/Pika-DS/resolve/main/dataset/pika_sequences.csv
```

In [6]:
seq_df = pd.read_csv('pika_sequences.csv')
# seq_df = seq_df[seq_df['uniprot_id'].isin(list(set(df_train['Entry'].values)))]
# seq_df

In [15]:
seq_df

| | uniprot_id | uniref_cluster | taxonomy | sequence | length | mw | num_fields | num_summary | num_qa |
|---|---|---|---|---|---|---|---|---|---|
| 0 | A0A009IHW8 | UniRef50_A0A009IHW8 | Bacteria, Pseudomonadota, Gammaproteobacteria | MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENA... | 269 | 30922 | 5 | 6 | 11 |
| 1 | A0A067XGX8 | UniRef50_A0A067XH53 | Eukaryota, Viridiplantae, Streptophyta | MALTATATTRGGSALPNSCLQTPKFQSLQKPTFISSFPTNKKTKPR... | 512 | 57062 | 7 | 5 | 8 |
| 2 | A0A067XH53 | UniRef50_A0A067XH53 | Eukaryota, Viridiplantae, Streptophyta | MALSTNSTTSSLLPKTPLVQQPLLKNASLPTTTKAIRFIQPISAIH... | 533 | 58894 | 7 | 7 | 8 |
| 3 | A0A068BGA5 | UniRef50_A0A068BGA5 | Eukaryota, Viridiplantae, Streptophyta | MASFPPSLVFTVRRKEPILVLPSKPTPRELKQLSDIDDQEGLRFQV... | 456 | 50972 | 4 | 4 | 6 |
| 4 | A0A072VHJ1 | UniRef50_A0A072VHJ1 | Eukaryota, Viridiplantae, Streptophyta | MSGVPFPSNLLPSPSSPEWLSKADNAWQLMAATLVGMQSVPGLIIL... | 481 | 52906 | 4 | 4 | 7 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 257162 | Q9ZWB3 | UniRef50_Q9ZWB3 | Eukaryota, Viridiplantae, Streptophyta | MDLVIGGKFKLGRKIGSGSFGELYLGINVQTGEEVAVKLESVKTKH... | 471 | 53002 | 6 | 7 | 8 |
| 257163 | S3DQP8 | UniRef50_S3DQP8 | Eukaryota, Fungi, Dikarya | MTENFPLPPLLGVDWDHLGFEPLEVNGHVECTFSTTTSCWTEPVFV... | 358 | 38895 | 5 | 6 | 7 |
| 257164 | V6F510 | UniRef50_V6F510 | Bacteria, Pseudomonadota, Alphaproteobacteria | MKFENCRDCREEVVWWAFTADICMTLFKGILGLMSGSVALVADSLH... | 297 | 31942 | 5 | 5 | 6 |
| 257165 | W6KHH6 | UniRef50_W6KHH6 | Bacteria, Pseudomonadota, Alphaproteobacteria | MTTAACRKCRDEVIWWAFFINIGQTTYKGVLGVLSGSAALVADAMH... | 293 | 31446 | 5 | 6 | 8 |


In [10]:
id_to_cluster = dict(zip(seq_df['uniprot_id'], seq_df['uniref_cluster']))
id_to_tax = dict(zip(seq_df['uniprot_id'], seq_df['taxonomy']))
id_to_mw = dict(zip(seq_df['uniprot_id'], seq_df['mw']))
id_to_num_fields = dict(zip(seq_df['uniprot_id'], seq_df['num_fields']))
id_to_summary = dict(zip(seq_df['uniprot_id'], seq_df['num_summary']))



In [13]:
uniprot = pd.read_csv('../../pretrained/raw_data/uniprotkb_AND_reviewed_true_2024_10_21_annot_pika.tsv', sep='\t')
id_to_tax = dict(zip(uniprot['Entry'], uniprot['Taxonomic lineage']))
id_to_mw = dict(zip(uniprot['Entry'], uniprot['Mass']))
id_to_len = dict(zip(uniprot['Entry'], uniprot['Length']))

In [16]:
df_train['Entry'].value_counts()

Entry
Q04828    9
Q95JH6    9
Q2NM15    9
Q5REQ0    9
G8H5N0    8
         ..
B5E3K8    1
B5E3L5    1
B5E3S7    1
B5E3Z6    1
B5E327    1
Name: count, Length: 173550, dtype: int64

In [18]:
# This time we'll make a map for each to make this more efficient

# Save the training sequences
rows = []
for entry, seq, ec, clust in df_train[['Entry', 'Sequence', ec_column, 'clusterRes50']].values:
    rows.append([entry, f'UniRef50_{clust}', id_to_tax.get(entry), seq, len(seq), id_to_mw.get(entry), 1, 1, 1])
    
sample_seqs = pd.DataFrame(rows, columns=['uniprot_id', 'uniref_cluster', 'taxonomy', 'sequence', 'length', 'mw', 'num_fields', 'num_summary', 'num_qa'])

sample_seqs.to_csv('sequences.csv', index=False)

In [19]:
sample_seqs

| | uniprot_id | uniref_cluster | taxonomy | sequence | length | mw | num_fields | num_summary | num_qa |
|---|---|---|---|---|---|---|---|---|---|
| 0 | A0A009IHW8 | UniRef50_A0A009IHW8 | cellular organisms (no rank), Bacteria (superk... | MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENA... | 269 | 30922.0 | 1 | 1 | 1 |
| 1 | A0A023I7E1 | UniRef50_A0A023I7E1 | cellular organisms (no rank), Eukaryota (super... | MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIF... | 796 | 89495.0 | 1 | 1 | 1 |
| 2 | A0A024SC78 | UniRef50_A8QPD8 | cellular organisms (no rank), Eukaryota (super... | MRSLAILTTLLAGHAFAYPKPAPQSVNRRDWPSINEFLSELAKVMP... | 248 | 25924.0 | 1 | 1 | 1 |
| 3 | A0A024SH76 | UniRef50_A1CCN4 | cellular organisms (no rank), Eukaryota (super... | MIVGILTTLATLATLAASVPLEERQACSSVWGQCGGQNWSGPTCCA... | 471 | 49653.0 | 1 | 1 | 1 |
| 4 | A0A044RE18 | UniRef50_A0A044RE18 | cellular organisms (no rank), Eukaryota (super... | MYWQLVRILVLFDCLQKILAIEHDSICIADVDDACPEPSHTVMRLR... | 693 | 76800.0 | 1 | 1 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 184524 | Q05115 | UniRef50_Q05115 | cellular organisms (no rank), Bacteria (superk... | MQQASTPTIGMIVPPAAGLVPADGARLYPDLPFIASGLGLGSVTPE... | 240 | 24735.0 | 1 | 1 | 1 |
| 184525 | Q6HX62 | UniRef50_Q6HX62 | cellular organisms (no rank), Bacteria (superk... | MGQNQFRWSNEQLREHVEIIDGTRSPHKLLKNATYLNSYIREWMQA... | 584 | 66760.0 | 1 | 1 | 1 |
| 184526 | Q6L032 | UniRef50_Q6L032 | cellular organisms (no rank), Archaea (superki... | MLLKNIKISNDYNIFMIIASRKPSLKDIYKIIKVSKFDEPADLIIE... | 573 | 65635.0 | 1 | 1 | 1 |
| 184527 | Q94MV8 | UniRef50_P39262 | Viruses (superkingdom), Duplodnaviria (clade),... | MAHFNECAHLIEGVDKANRAYAENIMHNIDPLQVMLDMQRHLQIRL... | 172 | 20191.0 | 1 | 1 | 1 |
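
The taxonomy and molecular weight above are filled in with `dict.get`, which returns `None` for any entry that is absent from the UniProt export, so some rows may carry empty metadata. The following sketch lists such rows; the helper names are hypothetical, the toy rows are made up, and whether Pika tolerates empty values in these columns is not something this notebook establishes.

```python
import math


def is_missing(value) -> bool:
    """True for None or a float NaN (how pandas represents empty cells)."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def rows_missing_metadata(rows, columns, required=("taxonomy", "mw")):
    """Return the uniprot_id of every row where a required column is missing."""
    idx = {name: i for i, name in enumerate(columns)}
    return [row[idx["uniprot_id"]] for row in rows
            if any(is_missing(row[idx[col]]) for col in required)]


# Toy rows for illustration only
toy_columns = ["uniprot_id", "taxonomy", "mw"]
toy_rows = [
    ["P1", "Bacteria", 30922.0],
    ["P2", None, 100.0],
    ["P3", "Eukaryota", float("nan")],
]
print(rows_missing_metadata(toy_rows, toy_columns))
```

For the DataFrame above, the call would be `rows_missing_metadata(sample_seqs.values, list(sample_seqs.columns))`.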


## Re-train Pika

In [None]:
# Make the model 
# prep config
assets_path = "../assets/"
config = load_config("pika_config.json")
config["datamodule"]["split_path"] = "split.csv"
model = Pika(config)
model.train()

# For each of the test sets, ask the model for the EC number of every protein
splits = ['30', '30-50', 'price', 'promiscuous']
rows = []
for split in splits: 
    df_test = pd.read_csv(f'../../splits/task1/{split}_protein_test.csv')
    
    for entry, seq in df_test[['Entry', 'Sequence']].values:
        ec = model.enquire(
            proteins=seq,
            question="What is the EC number of this protein?"
        )
        rows.append([split, seq, entry, '|'.join(ec)])
saving_df = pd.DataFrame(rows, columns=['Split', 'seq', 'Entry', 'EC'])
saving_df.to_csv(f'../results_summary/Pika/all_test_datasets_output.csv', index=False)


### Save the results now individually 

df = saving_df.copy()

# The datasets we want to go through
splits = ['30', '30-50', 'price', 'promiscuous']

for split in splits:
    # Entry,EC number,
    sub_df = df[df['Split'] == split]
    # Make the enrty to the EC 
    test_df = pd.read_csv(f'../../splits/task1/{split}_protein_test.csv')

    # Make sure the EC is clean
    sub_df['EC number'] = [e.strip() for e in sub_df['EC'].values]
    
    # Make the EC format the same as the other datasets
    entry_to_ec = dict(zip(sub_df['Entry'], sub_df['EC number']))
    test_df['0'] = [entry_to_ec.get(e) for e in test_df['Entry'].values]
    test_df.to_csv(f'../results_summary/Pika/{split}_protein_test_results_df.csv', index=False)
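
The model answers in free text, so the saved `EC` column may contain more than the EC number itself. The notebook does not show what the generated answers look like, so the sketch below simply assumes they contain EC numbers in the usual dotted form (for example `2.7.11.1`, or partial classes such as `3.1.-.-`). The pattern and the helper `extract_ec_numbers` are illustrative, not part of Pika.

```python
import re

# Four dot-separated fields; later fields may be '-' and the last may carry
# a preliminary 'n' prefix (e.g. 1.1.1.n2)
EC_PATTERN = re.compile(r"\b\d+\.(?:\d+|-)\.(?:\d+|-)\.(?:n?\d+|-)")


def extract_ec_numbers(text: str) -> list:
    """Return the EC-like strings found in `text`, in order, without duplicates."""
    found = []
    for match in EC_PATTERN.findall(text):
        if match not in found:
            found.append(match)
    return found


# Made-up answer text for illustration
example_answer = "The EC number is 2.7.11.1 and 3.1.-.-."
print(extract_ec_numbers(example_answer))
```

The pattern is deliberately loose: any four dot-separated numbers will match, so version strings or similar text in an answer would be picked up as well.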

In [None]:
task1_baselines/Pika/Pika/notebooks/train_pika.ipynb