In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import json
import gc
import requests
import os
import torch
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.optim import Adam
import torch.optim as optim
from sklearn.model_selection import train_test_split
import random
import seaborn as sns
gc.collect()
torch.cuda.empty_cache()
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)
# Define device
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Define model

class ModifiedResNet18(nn.Module):
    def __init__(self, input_channels=1, pretrained=False):
        super(ModifiedResNet18, self).__init__()
        self.resnet = models.resnet18(pretrained=pretrained)
        self.resnet.conv1 = nn.Conv2d(
            input_channels,
            64,
            kernel_size=7,
            stride=1,
            padding=3,
            bias=False
        )
        self.resnet.maxpool = nn.Identity()
        self._modify_resnet_layers()
    def _modify_resnet_layers(self):
        for layer in [self.resnet.layer3, self.resnet.layer4]:
            for block in layer:
                block.conv1.stride = (1, 1)
                if block.downsample:
                    block.downsample[0].stride = (1, 1)
    def forward(self, x):
        # Collect intermediate features before residual connections
        features = {}
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        # Layer1
        before_layer1 = x.clone()
        x = self.resnet.layer1(x)
        features['layer1'] = before_layer1  # Features before skip connections in layer1
        # Layer2
        before_layer2 = x.clone()
        x = self.resnet.layer2(x)
        features['layer2'] = before_layer2  # Features before skip connections in layer2
        # Layer3
        before_layer3 = x.clone()
        x = self.resnet.layer3(x)
        features['layer3'] = before_layer3  # Features before skip connections in layer3
        # Layer4 (Final Output)
        x = self.resnet.layer4(x)
        features['layer4'] = x  
        return x, features  # Return final output and intermediate features

class ResNetAttentionModel(nn.Module):
    def __init__(self, num_classes=1, input_channels=1, seq_length=300):
        super(ResNetAttentionModel, self).__init__()
        self.seq_length = seq_length
        self.resnet = ModifiedResNet18(input_channels=input_channels, pretrained=False).to(device)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # Pool to (batch_size, 512, 1, 1)
            nn.Flatten(),                  # Flatten to (batch_size, 512)
            nn.Linear(512, num_classes)    # Final classification layer
        )
    def forward(self, x):
        # Passes input through the modified ResNet18 to get features
        features, intermediate_features = self.resnet(x)  
        logits = self.classifier(features).squeeze(1) 
        # Extract intermediate features before residual connections
        feat_low = intermediate_features['layer1']   
        feat_mid = intermediate_features['layer2']   
        feat_high = intermediate_features['layer3']  
        feat_vhigh = intermediate_features['layer4']
        attention_low = self.generate_attention(feat_low, target_height=self.seq_length)
        attention_mid = self.generate_attention(feat_mid, target_height=self.seq_length)
        attention_high = self.generate_attention(feat_high, target_height=self.seq_length)
        attention_vhigh = self.generate_attention(feat_vhigh, target_height=self.seq_length)
        attention_combined = (attention_low + attention_mid + attention_high) / 3  # Simple average
        return logits, attention_combined, attention_low, attention_mid, attention_high, attention_vhigh
    def generate_attention(self, feature_map, target_height):
        """
        Generates attention maps from feature maps using non-trainable operations.
        Returns attention map of shape (batch_size, target_height)
        """
        attention = torch.mean(feature_map, dim=1, keepdim=True)
        attention = F.adaptive_avg_pool2d(attention, (self.seq_length, 1))  
        attention = attention.squeeze(3).squeeze(1) 
        return attention

In [9]:
# Load enzyme sequences in the test set

csv_file_path = 'uniprot_sequences_with_positions_validation.csv'
df = pd.read_csv(csv_file_path)
# Extract the sequences and labels directly as lists
all_sequences = df['sequence'].tolist()
labels = df['label'].tolist()
# Function to parse the JSON-formatted functional positions
def parse_positions(pos):
    if isinstance(pos, str):
        try:
            return json.loads(pos)
        except json.JSONDecodeError:
            print(f"JSON decoding failed for input: {pos}")
            return []
    elif isinstance(pos, list):
        return pos
    else:
        print(f"Unexpected type {type(pos)} for input: {pos}")
        return []
# Apply the parsing function to the 'functional_positions' column
valid_functional_positions = df['functional_positions'].apply(parse_positions).tolist()

In [11]:
labels = np.array([1]*len(all_sequences))
important_positions_combined = valid_functional_positions

In [12]:
# Encoding

amino_acids = 'ACDEFGHIKLMNPQRSTVWYX'
aa_to_int = {aa: idx for idx, aa in enumerate(amino_acids)}
NUM_AA = len(amino_acids)
SEQ_LENGTH = 300  # Example sequence length; adjust as needed
NUM_CLASSES = 1    # Example number of classes; adjust as needed
EPOCHS = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def encode_sequence(seq, max_length):
    seq = seq.upper()
    int_seq = [aa_to_int.get(aa, NUM_AA-1) for aa in seq]  # Use 'X' index for unknowns
    if len(int_seq) < max_length:
        int_seq += [NUM_AA-1] * (max_length - len(int_seq))
    else:
        int_seq = int_seq[:max_length]
    one_hot = np.zeros((max_length, NUM_AA), dtype=np.float32)
    one_hot[np.arange(max_length), int_seq] = 1.0
    return one_hot
encoded_sequences = np.array([encode_sequence(seq, SEQ_LENGTH) for seq in all_sequences])  # Shape: (num_samples, SEQ_LENGTH, NUM_AA)
print(encoded_sequences.shape)
print(len(labels))
X_val_np = np.expand_dims(encoded_sequences, axis=1)  # Shape: (num_samples, 1, SEQ_LENGTH, NUM_AA)
y_val_np = labels

(3448, 300, 21)
3448


In [243]:
# Load trained model
model = ResNetAttentionModel(num_classes=1, input_channels=1, seq_length=SEQ_LENGTH)
# Select model to test
saved_dict = torch.load('models/high_attention_NONNUM_60_ALPHA_3.pth', map_location=device)
model.load_state_dict(saved_dict)
model.to(device)
model.eval()

  saved_dict = torch.load('over90_accuracy_student_resnet_high_attention_NONNUM_60_ALPHA_3_EPOCH_32_122924.pth', map_location=device)


ResNetAttentionModel(
  (resnet): ModifiedResNet18(
    (resnet): ResNet(
      (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): Identity()
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True,

In [229]:
# Initialize lists to store correctly predicted sequences and their attention scores
correct_sequences = []
correct_attentions = []  # To store attention scores
model.eval()
val_correct = 0
val_total = 0
with torch.no_grad():
    for batch_idx, (inputs, labels) in enumerate(val_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs, student_att, attention_low, attention_mid, attention_high, attention_vhigh = model(inputs)
        preds = torch.sigmoid(outputs) > 0.5
        labels_byte = labels.byte().squeeze(1)
        # Identify correctly predicted samples
        correct_mask = preds == labels_byte
        correct_preds = preds[correct_mask]
        correct_labels = labels_byte[correct_mask]
        # Only interested in correctly predicted enzymes (label == 1 and pred ==1)
        enzyme_mask = correct_labels == 1
        enzyme_indices = torch.nonzero(enzyme_mask).squeeze(1).tolist()
        if isinstance(enzyme_indices, int):
            enzyme_indices = [enzyme_indices]
        for idx in enzyme_indices:
            # Calculate the global index in the dataset
            global_idx = batch_idx * batch_size + idx
            sequence = all_sequences[global_idx]
            correct_sequences.append(sequence)
            # Extract attention scores
            # Assuming 'attention_high' corresponds to layer3 with 256 channels
            # Here, we average across channels and spatial dimensions to get a single attention score per residue
            attention = attention_high[idx].cpu().numpy().squeeze()[:SEQ_LENGTH]  # Shape: (256, H3, W3)
            attention_scores = attention
            # Resize attention scores to match sequence length if necessary
            # Assuming sequence length is SEQ_LENGTH
            if len(attention_scores) != SEQ_LENGTH:
                attention_scores_resized = np.interp(
                    np.linspace(0, len(attention_scores), num=SEQ_LENGTH),
                    np.arange(len(attention_scores)),
                    attention_scores
                )
            else:
            correct_attentions.append(attention_scores)    
        correct = (preds == labels_byte).sum().item()
        val_correct += correct
        val_total += labels.size(0)
val_accuracy = val_correct / val_total
print(f"Validation Accuracy: {val_accuracy:.4f}")
print(f"Number of correctly predicted enzymes: {len(correct_sequences)}")

Validation Accuracy: 0.8744
Number of correctly predicted enzymes: 1517


In [234]:
# Compute % conversion from enzyme to nonenzyme with alanine substitution, given the number of the substitution

from tqdm import tqdm
def get_top_n_positions(attention_scores, top_n=1):
    """
    Returns the indices of the top N attention scores.
    """
    return np.argsort(attention_scores)[-top_n:][::-1]  # Descending order
    #return np.argsort(attention_scores)[:top_n] #Ascending order
def mutate_sequence(sequence, positions):
    """
    Mutate the specified positions in the sequence to alanine ('A').
    Args:
        sequence (str): Original amino acid sequence.
        positions (list of int): 0-based indices to mutate.
    Returns:
        str: Mutated sequence.
    """
    seq = list(sequence)
    for pos in positions:
        if 0 <= pos < len(seq):
            seq[pos] = 'A'
    return ''.join(seq)
# Prepare lists to store mutated sequences and their indices
mutated_sequences = []
mutation_indices = []
for i, (seq, att) in enumerate(zip(correct_sequences, correct_attentions)):
    top_n = get_top_n_positions(att)
    mutated_seq = mutate_sequence(seq, top_n)
    mutated_sequences.append(mutated_seq)
    mutation_indices.append(top_n)
encoded_mutated_sequences = np.array([encode_sequence(seq, SEQ_LENGTH) for seq in mutated_sequences])  # Shape: (num_samples, SEQ_LENGTH, NUM_AA)
# Expand dimensions to match model input: (num_samples, 1, SEQ_LENGTH, NUM_AA)
X_mutated_np = np.expand_dims(encoded_mutated_sequences, axis=1).astype(np.float32)
class MutatedProteinDataset(Dataset):
    def __init__(self, X):
        self.X = torch.tensor(X, dtype=torch.float32)
    def __len__(self):
        return self.X.shape[0]
    def __getitem__(self, idx):
        return self.X[idx]
mutated_dataset = MutatedProteinDataset(X_mutated_np)
mutated_loader = DataLoader(mutated_dataset, batch_size=128, shuffle=False)
model.eval()
mutation_changes = 0  # Count of sequences where prediction changed to 0
total_mutations = len(mutated_sequences)
with torch.no_grad():
    for batch_mutated in tqdm(mutated_loader, desc="Evaluating Mutated Sequences"):
        batch_mutated = batch_mutated.to(device)
        outputs, _, _, _, _, _ = model(batch_mutated)
        preds = torch.sigmoid(outputs) > 0.5  # Predictions after mutation
        mutation_changes += (preds == 0).sum()
percentage_changes = (mutation_changes / total_mutations) * 100
print(f"Number of mutations that changed prediction from 1 to 0: {mutation_changes}")
print(f"Percentage of mutations that changed prediction: {percentage_changes:.2f}%")

Evaluating Mutated Sequences: 100%|█████████████| 12/12 [01:52<00:00,  9.39s/it]

Number of mutations that changed prediction from 1 to 0: 190
Percentage of mutations that changed prediction: 12.52%



