In [24]:
import shap
import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, BertModel
from datasets import Dataset

# Read the labelled sequences from the spreadsheet: the 'Sequence' column
# becomes a Python list, the 'Label' column a NumPy array.
df = pd.read_excel('Aggregated.xlsx')
X, y = df['Sequence'].tolist(), df['Label'].values

# Fixed 80/20 split; random_state pins the shuffle so re-runs give the same split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Put a single space between the characters of every sequence
# (e.g. 'ABC' -> 'A B C').
X_train_preprocessed = list(map(' '.join, X_train))
X_test_preprocessed = list(map(' '.join, X_test))

# Define the Protein Classifier Model
class ProteinClassifier(nn.Module):
    """BERT encoder followed by a dropout -> linear -> sigmoid head.

    Args:
        n_classes: Number of output units of the linear layer (default 1).
        pretrained_model_name: Checkpoint identifier handed to
            ``BertModel.from_pretrained``. Defaults to the checkpoint that
            was previously hard-coded, so existing callers are unaffected.
        dropout: Dropout probability applied to the pooled encoder output
            before the linear layer (default 0.4, the previous fixed value).
    """

    def __init__(self, n_classes=1,
                 pretrained_model_name='Rostlab/prot_bert_bfd_localization',
                 dropout=0.4):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        # Keep the Sequential layout (index 1 is the Linear layer) so that
        # state_dict keys such as 'classifier.1.weight' stay unchanged and
        # previously saved checkpoints still load.
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.bert.config.hidden_size, n_classes),
            nn.Sigmoid(),  # squashes each output unit into (0, 1)
        )
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform init for the head's weight matrix, zeros for its bias."""
        head = self.classifier[1]
        nn.init.xavier_uniform_(head.weight)
        nn.init.constant_(head.bias, 0)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and apply the head to its ``pooler_output``.

        Args:
            input_ids: Token id tensor forwarded to the BERT encoder.
            attention_mask: Attention mask tensor forwarded to the BERT encoder.

        Returns:
            The sigmoid-activated output of the classification head.
        """
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(output.pooler_output)

# Load the tokenizer and restore the fine-tuned classifier weights
tokenizer = AutoTokenizer.from_pretrained('Rostlab/prot_bert')
model_save_path = "protein_classifier_model.pth"

# Pick the device first so the checkpoint can be mapped straight onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ProteinClassifier(1)
# map_location: lets a checkpoint saved on a GPU load on a CPU-only machine.
# weights_only=True: restricts unpickling to tensors and plain containers,
# which is all a state_dict needs, and is the opt-in that recent PyTorch
# releases ask for when loading checkpoints.
state_dict = torch.load(model_save_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # inference mode: disables dropout in the classifier head

# Prediction function that the SHAP explainer calls on batches of text
def f(sequences):
    """Score a batch of sequences with the loaded classifier.

    Every element of ``sequences`` is passed through ``''.join`` (which
    leaves a plain string unchanged), tokenized with padding and truncation
    to at most 60 tokens, and fed to ``model`` with gradients disabled.

    Args:
        sequences: Iterable of sequences to score.

    Returns:
        The model output moved to the CPU and converted to a NumPy array.
    """
    texts = list(map(''.join, sequences))
    encoded = tokenizer(
        texts,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=60,
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    with torch.no_grad():
        scores = model(input_ids, attention_mask)
    return scores.cpu().numpy()

# Build the explainer from the prediction function `f`; the tokenizer is
# passed as the second argument of shap.Explainer.
explainer = shap.Explainer(f, tokenizer)

# Wrap the training texts and labels in a Hugging Face dataset
data_dict = dict(text=X_train_preprocessed, label=y_train)
data_df = pd.DataFrame(data_dict)
protein_dataset = Dataset.from_pandas(data_df)

# Take the first 10 examples once, show them, then explain them
first_examples = protein_dataset[:10]
print(first_examples['text'])
print(first_examples['label'])

shap_values = explainer(first_examples)

  model.load_state_dict(torch.load(model_save_path))


['S D P K I G D G C F G L P L D H I G S V S G L G C N R P V Q N R P K K', 'M N L E V I A Q L T V L S L I V L S G P L V I I L L A A N R G N L', 'S K S S S P C F G G K L D R I G S Y S G L G C N S R K', 'I I P L P L G Y F A K K T', 'G V F G L L A K A A L K G A S K L I P H L L P S R Q Q', 'S V I E L G K M I L Q E T G K N P V T H Y G A', 'M D I L S L G W S A L M V V F T F S L A L V V W G R N G F', 'E V R P F P E V Y E R I A', 'S S R R P C R G R S C G P R L R G G Y T L I G R P V K N Q N R P K Y M W V', 'F I G A L L R P A L K L L A']
[0, 0, 0, 1, 1, 0, 0, 0, 1, 1]


PartitionExplainer explainer: 11it [00:11,  5.88s/it]


In [26]:
# Re-score exactly the examples that were explained. The batch size is taken
# from shap_values so the predictions, labels and explanations cannot fall
# out of step if the number of explained examples changes.
n_examples = len(shap_values)
batch = protein_dataset[:n_examples]  # one slice instead of re-reading whole columns per row
texts, labels = batch['text'], batch['label']

# Tokenize with the same settings as the SHAP prediction function
inputs = tokenizer(texts, return_tensors='pt', truncation=True, padding=True, max_length=60)
inputs = {key: val.to(device) for key, val in inputs.items()}  # Move inputs to GPU if available

# Make predictions
with torch.no_grad():  # Disable gradient calculation for inference
    predictions = model(inputs['input_ids'], inputs['attention_mask'])

# No gradient graph exists under no_grad, so moving to CPU is enough
raw_outputs = predictions.cpu().numpy()

# Display the raw outputs alongside SHAP explanations and ground truth labels
for i, (text, label) in enumerate(zip(texts, labels)):
    probability = raw_outputs[i][0]
    # Threshold the probability at 0.5 to get the predicted class
    predicted_label = 1 if probability >= 0.5 else 0

    print(f"Sequence: {text}")  # Original text input
    print(f"Raw Output Probability: {probability:.4f}")  # Model output for this example
    print(f"Predicted Label: {predicted_label}")
    print(f"Ground Truth Label: {label}")
    print("SHAP Explanation:")
    shap.plots.text(shap_values[i])  # SHAP plot for the i-th example
    print("\n" + "-" * 80 + "\n")  # Separator for clarity


Sequence: S D P K I G D G C F G L P L D H I G S V S G L G C N R P V Q N R P K K
Raw Output Probability: 0.9895
Predicted Label: 1
Ground Truth Label: 0
SHAP Explanation:



--------------------------------------------------------------------------------

Sequence: M N L E V I A Q L T V L S L I V L S G P L V I I L L A A N R G N L
Raw Output Probability: 0.0107
Predicted Label: 0
Ground Truth Label: 0
SHAP Explanation:



--------------------------------------------------------------------------------

Sequence: S K S S S P C F G G K L D R I G S Y S G L G C N S R K
Raw Output Probability: 0.1189
Predicted Label: 0
Ground Truth Label: 0
SHAP Explanation:



--------------------------------------------------------------------------------

Sequence: I I P L P L G Y F A K K T
Raw Output Probability: 0.9873
Predicted Label: 1
Ground Truth Label: 1
SHAP Explanation:



--------------------------------------------------------------------------------

Sequence: G V F G L L A K A A L K G A S K L I P H L L P S R Q Q
Raw Output Probability: 0.9897
Predicted Label: 1
Ground Truth Label: 1
SHAP Explanation:



--------------------------------------------------------------------------------

Sequence: S V I E L G K M I L Q E T G K N P V T H Y G A
Raw Output Probability: 0.0204
Predicted Label: 0
Ground Truth Label: 0
SHAP Explanation:



--------------------------------------------------------------------------------

Sequence: M D I L S L G W S A L M V V F T F S L A L V V W G R N G F
Raw Output Probability: 0.0109
Predicted Label: 0
Ground Truth Label: 0
SHAP Explanation:



--------------------------------------------------------------------------------

Sequence: E V R P F P E V Y E R I A
Raw Output Probability: 0.0227
Predicted Label: 0
Ground Truth Label: 0
SHAP Explanation:



--------------------------------------------------------------------------------

Sequence: S S R R P C R G R S C G P R L R G G Y T L I G R P V K N Q N R P K Y M W V
Raw Output Probability: 0.9883
Predicted Label: 1
Ground Truth Label: 1
SHAP Explanation:



--------------------------------------------------------------------------------

Sequence: F I G A L L R P A L K L L A
Raw Output Probability: 0.9895
Predicted Label: 1
Ground Truth Label: 1
SHAP Explanation:



--------------------------------------------------------------------------------

