In [None]:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
from deepchem.feat import PagtnMolGraphFeaturizer
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit.Chem import AllChem as Chem
from rdkit.Chem import rdFingerprintGenerator
from sklearn.metrics import accuracy_score, precision_score, recall_score
from torch.utils.data import random_split
import torch.nn as nn
from torch_geometric.nn import GATConv, global_max_pool

In [None]:
# Load the raw interaction table.
# NOTE(review): "/result.csv" is an absolute path at the filesystem root; confirm
# it is valid in the environment this notebook is run in.
mol = pd.read_csv("/result.csv")

# One SMILES string per unique 'Pubid' (first occurrence wins), in file order.
smiles = mol.drop_duplicates(subset=['Pubid'], keep='first')['Smiles']

# Number of raw rows carrying each SMILES string.
smiles_counts = mol['Smiles'].value_counts().reset_index()
pd.set_option('display.max_rows', 10)

# Featurize every SMILES exactly once; the resulting graphs are reused by later cells.
featurizer = PagtnMolGraphFeaturizer(max_length=5)
featurized_data = featurizer.featurize(smiles.tolist())

# Flatten each graph into plain Python lists so it can be stored in a CSV.
results = [
    {
        "SMILES": smiles_string,
        "Node Features": graph_data.node_features.tolist(),
        "Edge Features": graph_data.edge_features.tolist(),
        "Edge Index": graph_data.edge_index.tolist(),
    }
    for smiles_string, graph_data in zip(smiles, featurized_data)
]

results_df = pd.DataFrame(results)

results_df.to_csv("node_edge_features.csv", index=False)

# Display the results
print(results_df.head(6))

In [None]:
# Attach the per-SMILES row count to the featurized table. A left merge against the
# unique-SMILES count table keeps one row per featurized molecule, in the same order.
df = pd.merge(results_df, smiles_counts, left_on='SMILES', right_on='Smiles', how='left')
# Not the last expression of the cell, so Jupyter does not display this summary.
df['count'].describe()

# Number of distinct genes recorded for each compound.
unique_interactions = mol.groupby("Compound ID")["gene_name"].nunique()
df_dedup = unique_interactions.reset_index()

# Look the gene count up by compound key instead of copying it over by row position:
# the groupby result is ordered by 'Compound ID', whereas `df` follows the order of the
# Pubid-deduplicated rows of `mol`, so a positional copy can attach counts to the
# wrong molecule.
# NOTE(review): assumes the 'Compound ID' on each Pubid's first row identifies the
# compound whose targets should be counted -- confirm against the data schema.
compound_ids = (
    mol.drop_duplicates(subset=['Pubid'], keep='first')['Compound ID']
    .reset_index(drop=True)
)
if len(compound_ids) != len(df):
    raise ValueError(
        f"Row mismatch between compounds ({len(compound_ids)}) and featurized table ({len(df)})"
    )
df['number_dedup'] = compound_ids.map(unique_interactions).to_numpy()
df['count_binned_custom'] = pd.cut(df['number_dedup'], bins=[0, 16, 60.1, 115], labels=[0, 1, 2])

##Test start

In [None]:
# Bin the per-compound gene counts into three classes (0, 1, 2); values outside
# (0, 115] become NaN.
df['count_binned_custom'] = pd.cut(df['number_dedup'], bins=[0, 20, 60.1, 115], labels=[0, 1, 2])
graphs = []

for i, graph_data in enumerate(featurized_data):
    # Convert node features and edge data into PyTorch tensors
    x = torch.tensor(graph_data.node_features, dtype=torch.float)  # Node features
    # The edge index is used in the layout the featurizer provides (transposing it
    # twice, as was done before, gives back the same layout).
    # NOTE(review): presumably [2, num_edges], as PyG expects -- confirm.
    edge_index = torch.tensor(graph_data.edge_index, dtype=torch.long)
    edge_attr = torch.tensor(graph_data.edge_features, dtype=torch.float)  # Edge features

    # Fail early on an unbinned count instead of casting NaN to a meaningless integer.
    label = df['count_binned_custom'].iloc[i]
    if pd.isna(label):
        raise ValueError(f"Row {i} has no class label (count outside the bin edges)")
    y = torch.tensor(int(label), dtype=torch.long)  # Target label

    # Create a PyTorch Geometric Data object
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
    graphs.append(data)

# One summary line instead of one printed line per graph.
print(f"Built {len(graphs)} graphs; first: {graphs[0] if graphs else None}")

# Create a DataLoader
loader = DataLoader(graphs, batch_size=32, shuffle=True)

In [None]:
# Shared generator for the default settings (radius 2, 2048 bits).
mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)


def compute_morgan_fingerprint(smiles, radius=2, fpSize=2048):
    """Compute the Morgan fingerprint of one SMILES string.

    Parameters:
    - smiles: SMILES string of the molecule
    - radius: Morgan radius passed to the fingerprint generator
    - fpSize: number of bits in the fingerprint

    Returns:
    - np.ndarray built from the fingerprint bit vector, or None when RDKit
      cannot parse the SMILES string.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None

    # Reuse the shared generator for the defaults; build a dedicated one when the
    # caller asks for different settings, so `radius` and `fpSize` are honoured.
    if radius == 2 and fpSize == 2048:
        generator = mfpgen
    else:
        generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize)

    return np.array(generator.GetFingerprint(molecule))


# Author's note: every SMILES was checked to parse, so no None entries are expected.
df['morgan_fingerprint'] = df['SMILES'].apply(compute_morgan_fingerprint)

In [None]:
numeric_columns = ['Exact Mass', 'XLogP3', 'Heavy Atom Count', 'Ring Count',
                   'Hydrogen Bond Acceptor Count', 'Hydrogen Bond Donor Count',
                   'Rotatable Bond Count', 'Topological Polar Surface Area']

# Additional columns to include (if any)
additional_columns = []

# One row per unique 'Pubid', the same selection the SMILES list was built from.
data = mol.drop_duplicates(subset = ['Pubid'], keep='first')

# Kept in case other cells reference it; no longer used to build `df_combined`.
new_df = pd.merge(data, df, left_on='Smiles', right_on='SMILES')

# Take the descriptors straight from the deduplicated rows and renumber them 0..n-1
# so they line up with `df` row by row. Going through the SMILES merge instead
# multiplies rows whenever two Pubids share a SMILES string, which shifts every
# following row when the frames are later concatenated on their index.
df_combined = (
    data[numeric_columns + additional_columns]
    .reset_index(drop=True)
    .astype(float)
)

In [None]:
# Put the numeric descriptor columns next to the per-molecule table; rows are paired
# by index label and the original column names are kept.
df_full = pd.concat(objs=[df, df_combined], axis="columns")

In [None]:
# Graph-level numeric descriptors that are appended to the Morgan fingerprint.
additional_features_list = [
    'Exact Mass',
    'XLogP3',
    'Heavy Atom Count',
    'Ring Count',
    'Hydrogen Bond Acceptor Count',
    'Hydrogen Bond Donor Count',
    'Rotatable Bond Count',
    'Topological Polar Surface Area',
]

# Sanity check: report which Python types occur in each descriptor column.
for col in additional_features_list:
    unique_types = df_full[col].map(type).value_counts()
    print(f"\nColumn: {col} - Unique Data Types:\n{unique_types}")

In [None]:
graphs = []

for i, graph_data in enumerate(featurized_data):
    # Node-level features
    node_features = torch.tensor(graph_data.node_features, dtype=torch.float)

    # Graph-level features: Morgan fingerprint bits ...
    morgan_fp = df_full.loc[i, 'morgan_fingerprint']
    if not isinstance(morgan_fp, np.ndarray):
        raise TypeError(f"Expected np.ndarray for 'morgan_fingerprint', got {type(morgan_fp)}")
    morgan_fp = torch.tensor(morgan_fp.astype(np.float32), dtype=torch.float)

    # ... followed by the numeric descriptors. The row slice is converted to a float32
    # array explicitly so the tensor does not depend on the dtype pandas picked for it.
    additional_values = df_full.loc[i, additional_features_list].to_numpy(dtype=np.float32)
    additional_features = torch.tensor(additional_values, dtype=torch.float)

    # Combine all graph-level features into a single 1-D tensor
    graph_features = torch.cat([morgan_fp, additional_features], dim=0)  # Shape: [2048 + len(additional_features_list)]

    # Edge indices and attributes
    edge_index = torch.tensor(graph_data.edge_index, dtype=torch.long)
    edge_attr = torch.tensor(graph_data.edge_features, dtype=torch.float)

    # Target label: fail early on an unbinned count instead of casting NaN to an integer.
    label = df_full.loc[i, 'count_binned_custom']
    if pd.isna(label):
        raise ValueError(f"Row {i} has no class label (count outside the bin edges)")
    y = torch.tensor(int(label), dtype=torch.long)

    # Create a PyTorch Geometric Data object
    data = Data(
        x=node_features,              # Node-level features
        edge_index=edge_index,        # Edge indices
        edge_attr=edge_attr,          # Edge features
        y=y,                          # Target label
        graph_features=graph_features  # Graph-level features
    )
    graphs.append(data)

# One summary line instead of one printed line per graph.
print(f"Built {len(graphs)} graphs; first: {graphs[0] if graphs else None}")

# Create a DataLoader
loader = DataLoader(graphs, batch_size=32, shuffle=True)

In [None]:
class GATWithGraphFeatures(torch.nn.Module):
    """Graph classifier combining a two-layer GAT with per-graph feature vectors.

    Node features go through two edge-aware GATConv layers and are max-pooled per
    graph; each graph's own `graph_features` vector goes through a two-layer MLP;
    the two embeddings are concatenated and classified.

    Parameters:
    - num_node_features: size of each node feature vector
    - num_edge_features: size of each edge feature vector (GATConv `edge_dim`)
    - num_graph_features: length of ONE graph's `graph_features` vector
    - hidden_channels: width of every hidden layer
    - num_classes: number of output classes

    `forward(data)` returns log-probabilities of shape [num_graphs, num_classes].
    """

    def __init__(self, num_node_features, num_edge_features, num_graph_features, hidden_channels, num_classes):
        super(GATWithGraphFeatures, self).__init__()

        self.num_graph_features = num_graph_features

        # GAT layers for node features
        self.conv1 = GATConv(num_node_features, hidden_channels, edge_dim=num_edge_features)
        self.conv2 = GATConv(hidden_channels, hidden_channels, edge_dim=num_edge_features)

        # Linear layer for node feature processing after pooling
        self.node_lin1 = nn.Linear(hidden_channels, hidden_channels)

        # Linear layers for graph-level features. The input width is the per-graph
        # feature length; it must not include the batch size, otherwise the model
        # only accepts one fixed batch size and mixes features of different graphs.
        self.graph_lin1 = nn.Linear(num_graph_features, hidden_channels)
        self.graph_lin2 = nn.Linear(hidden_channels, hidden_channels)

        # Final linear layer combining node and graph embeddings
        self.final_lin = nn.Linear(hidden_channels * 2, num_classes)

        # Regularization layers
        self.dropout = nn.Dropout(p=0.5)
        self.batch_norm1 = nn.BatchNorm1d(hidden_channels)
        self.batch_norm2 = nn.BatchNorm1d(hidden_channels)
        # Not used in forward(); kept so the set of parameters stays the same.
        self.batch_norm_graph = nn.BatchNorm1d(hidden_channels)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # 1-D per-graph vectors arrive from the PyG DataLoader concatenated into one
        # flat tensor; restore one row per graph. An input that is already
        # [num_graphs, num_graph_features] passes through unchanged.
        graph_features = data.graph_features.reshape(-1, self.num_graph_features)

        # First graph convolution layer (with edge attributes)
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.batch_norm1(x)

        # Second graph convolution layer (with edge attributes)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.batch_norm2(x)

        # Global pooling: one embedding per graph, shape [num_graphs, hidden_channels]
        node_embedding = global_max_pool(x, batch)
        node_embedding = self.node_lin1(node_embedding)
        node_embedding = F.relu(node_embedding)
        node_embedding = self.dropout(node_embedding)

        # Process graph-level features, shape [num_graphs, hidden_channels]
        graph_embedding = self.graph_lin1(graph_features)
        graph_embedding = F.relu(graph_embedding)
        graph_embedding = self.graph_lin2(graph_embedding)
        graph_embedding = F.relu(graph_embedding)
        graph_embedding = self.dropout(graph_embedding)

        # Pair every graph with its own feature embedding: [num_graphs, hidden_channels * 2]
        combined = torch.cat([node_embedding, graph_embedding], dim=-1)

        # Final classification layer
        out = self.final_lin(combined)
        return F.log_softmax(out, dim=-1)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphOnlyModel(nn.Module):
    """MLP classifier that uses only the graph-level feature vector of each graph.

    Parameters:
    - num_graph_features: length of ONE graph's `graph_features` vector
    - hidden_channels: width of the two hidden layers
    - num_classes: number of output classes

    `forward(data)` returns log-probabilities of shape [num_graphs, num_classes].
    """

    def __init__(self, num_graph_features, hidden_channels, num_classes):
        super(GraphOnlyModel, self).__init__()
        self.graph_lin1 = nn.Linear(num_graph_features, hidden_channels)
        self.graph_lin2 = nn.Linear(hidden_channels, hidden_channels)
        self.out = nn.Linear(hidden_channels, num_classes)

        self.dropout = nn.Dropout(0.5)
        self.batch_norm1 = nn.BatchNorm1d(hidden_channels)
        self.batch_norm2 = nn.BatchNorm1d(hidden_channels)

    def forward(self, data):
        # 1-D per-graph vectors arrive from the PyG DataLoader concatenated into one
        # flat tensor; restore [num_graphs, num_graph_features] before the first
        # linear layer. An input that is already 2-D passes through unchanged.
        x = data.graph_features.reshape(-1, self.graph_lin1.in_features)

        x = self.graph_lin1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.batch_norm1(x)

        x = self.graph_lin2(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.batch_norm2(x)

        x = self.out(x)
        return F.log_softmax(x, dim=-1)

In [None]:
def train_and_evaluate(graphs_subset, label):
    """Train a fresh GATWithGraphFeatures on `graphs_subset` and report test metrics.

    The graphs are split 75% / 25% into train and test with a fixed seed (129 / 43
    for a 172-graph dataset), the model is trained for 30 epochs with Adam, and the
    test accuracy, weighted precision and weighted recall are printed.

    Parameters:
    - graphs_subset: list of PyG Data objects carrying `graph_features` and `y`
    - label: text used to tag the printed result line

    Returns:
    - test accuracy as reported by `evaluate_more`

    Requires `evaluate_more`, which is defined in a later cell of this notebook,
    to exist when the function is called.
    """
    # Derive every size from the data itself rather than from notebook globals, so
    # the function works for any subset and does not depend on cell execution order.
    first_graph = graphs_subset[0]
    n_graphs = len(graphs_subset)
    n_train = (3 * n_graphs) // 4
    n_test = n_graphs - n_train

    train_subset, test_subset = torch.utils.data.random_split(
        graphs_subset,
        [n_train, n_test],
        generator=torch.Generator().manual_seed(42)
    )

    # drop_last=True keeps every batch at exactly 4 graphs; note that any trailing
    # partial batch is therefore excluded from training and from the test metrics.
    train_loader = DataLoader(train_subset, batch_size=4, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_subset, batch_size=4, drop_last=True)

    model = GATWithGraphFeatures(
        num_node_features=first_graph.num_node_features,
        num_edge_features=first_graph.edge_attr.size(1) if first_graph.edge_attr is not None else 0,
        num_graph_features=len(first_graph.graph_features),
        hidden_channels=64,
        num_classes=3
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(30):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch)
            loss = F.nll_loss(out, batch.y)
            loss.backward()
            optimizer.step()

    acc, prec, rec = evaluate_more(model, test_loader)
    print(f"[{label}] Test Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}")
    return acc

In [None]:
def remove_numeric_feature(graphs, numeric_index_to_remove, fp_size=2048):
    """
    Removes one numeric feature from the graph_features tensor of each graph.

    Each graph's `graph_features` is treated as `fp_size` fingerprint entries
    followed by the numeric features; the fingerprint part is left untouched.

    Parameters:
    - graphs: list of PyG Data objects (anything with `clone()` and a 1-D
      `graph_features` tensor works)
    - numeric_index_to_remove: int, position of the feature inside the numeric part
      (0 = first numeric feature). The valid range depends on how many numeric
      features were appended when the graphs were built.
    - fp_size: int, number of leading fingerprint entries (default 2048)

    Returns:
    - modified_graphs: list of cloned Data objects with the feature removed; the
      input graphs are not modified

    Raises:
    - IndexError: if `numeric_index_to_remove` is outside the numeric part
    """
    modified_graphs = []
    for g in graphs:
        new_graph = g.clone()

        # Split graph features into fingerprint and numeric parts
        features = new_graph.graph_features
        morgan_fp = features[:fp_size]
        numeric_feats = features[fp_size:]

        # Boolean mask that keeps everything except the requested position
        mask = torch.ones(numeric_feats.size(0), dtype=torch.bool)
        mask[numeric_index_to_remove] = False

        # Combine back
        new_graph.graph_features = torch.cat([morgan_fp, numeric_feats[mask]], dim=0)
        modified_graphs.append(new_graph)
    return modified_graphs


In [None]:
# Dry run of the ablation sweep: build the reduced graph set for each numeric
# feature in turn. Nothing is trained here; `ablated_graphs` is overwritten on every
# pass and ends up holding the set for the last feature.
for i in range(len(numeric_columns)):
    feature_name = numeric_columns[i]
    print(f"\n>>> Removing feature: {feature_name}")
    ablated_graphs = remove_numeric_feature(graphs=graphs, numeric_index_to_remove=i)
    # Then rerun training/eval using `ablated_graphs`

In [None]:
# Single ablation experiment: drop one numeric graph-level feature, retrain the
# combined GAT model from scratch and report test metrics.
#
# Step 1: Remove one numeric feature. Index 0 is passed here, i.e. the first numeric
# feature appended after the fingerprint ('Exact Mass' is the first entry of the
# `numeric_columns` lists in this notebook).
ablated_graphs = remove_numeric_feature(graphs, 0)

# Step 2: Train/test split (80/20) over graph indices, with a fixed seed
from sklearn.model_selection import train_test_split

train_indices, test_indices = train_test_split(
    list(range(len(ablated_graphs))), test_size=0.2, random_state=42
)
train_graphs = [ablated_graphs[i] for i in train_indices]
test_graphs = [ablated_graphs[i] for i in test_indices]

# Step 3: Create DataLoaders
# drop_last=True discards a trailing batch with fewer than 4 graphs, so such graphs
# are left out of training and of the reported test metrics.
train_loader = DataLoader(train_graphs, batch_size=4, shuffle=True, drop_last=True)
test_loader = DataLoader(test_graphs, batch_size=4, shuffle=False, drop_last=True)

# Step 4: Re-initialize model with dimensions read from the first ablated graph
num_node_features = ablated_graphs[0].num_node_features
num_edge_features = ablated_graphs[0].edge_attr.size(1) if ablated_graphs[0].edge_attr is not None else 0
num_graph_features = ablated_graphs[0].graph_features.size(0)

model = GATWithGraphFeatures(
    num_node_features=num_node_features,
    num_edge_features=num_edge_features,
    num_graph_features=num_graph_features,
    hidden_channels=64,
    num_classes=3
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Step 5: Training loop (30 epochs; prints the mean batch loss of each epoch)
for epoch in range(30):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")

# Step 6: Evaluate on test set
# `evaluate_more` is defined in a later cell of this notebook and must have been run.
accuracy, precision, recall = evaluate_more(model, test_loader)
print("\n=== Test Performance ===")
print(f"Accuracy:  {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")

In [None]:
# Shared generator for the default settings (radius 2, 2048 bits).
mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)


def compute_morgan_fingerprint(smiles, radius=2, fpSize=2048):
    """Compute the Morgan fingerprint of one SMILES string.

    Parameters:
    - smiles: SMILES string of the molecule
    - radius: Morgan radius passed to the fingerprint generator
    - fpSize: number of bits in the fingerprint

    Returns:
    - np.ndarray built from the fingerprint bit vector, or None when RDKit
      cannot parse the SMILES string.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None

    # Reuse the shared generator for the defaults; build a dedicated one when the
    # caller asks for different settings, so `radius` and `fpSize` are honoured.
    if radius == 2 and fpSize == 2048:
        generator = mfpgen
    else:
        generator = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize)

    return np.array(generator.GetFingerprint(molecule))


# Author's note: every SMILES was checked to parse, so no None entries are expected.
df['morgan_fingerprint'] = df['SMILES'].apply(compute_morgan_fingerprint)

# Numeric descriptors used for this run ('Heavy Atom Count' is not included here).
numeric_columns = ['Exact Mass', 'XLogP3', 'Ring Count',
                   'Hydrogen Bond Acceptor Count', 'Hydrogen Bond Donor Count',
                   'Rotatable Bond Count', 'Topological Polar Surface Area']

# Additional columns to include (if any)
additional_columns = []

# One row per unique 'Pubid', the same selection the SMILES list was built from.
data = mol.drop_duplicates(subset = ['Pubid'], keep='first')

# Kept in case other cells reference it; no longer used to build `df_combined`.
new_df = pd.merge(data, df, left_on='Smiles', right_on='SMILES')

# Take the descriptors straight from the deduplicated rows and renumber them 0..n-1
# so they line up with `df` row by row. Going through the SMILES merge instead
# multiplies rows whenever two Pubids share a SMILES string, which shifts every
# following row in the index-based concat below.
df_combined = (
    data[numeric_columns + additional_columns]
    .reset_index(drop=True)
    .astype(float)
)

df_full = pd.concat([df, df_combined], axis=1, ignore_index=False)

graphs = []

for i, graph_data in enumerate(featurized_data):
    # Node-level features
    node_features = torch.tensor(graph_data.node_features, dtype=torch.float)

    # Graph-level features: Morgan fingerprint bits ...
    morgan_fp = df_full.loc[i, 'morgan_fingerprint']
    if not isinstance(morgan_fp, np.ndarray):
        raise TypeError(f"Expected np.ndarray for 'morgan_fingerprint', got {type(morgan_fp)}")
    morgan_fp = torch.tensor(morgan_fp.astype(np.float32), dtype=torch.float)

    # ... followed by the numeric descriptors. The row slice is converted to a float32
    # array explicitly so the tensor does not depend on the dtype pandas picked for it.
    additional_values = df_full.loc[i, numeric_columns].to_numpy(dtype=np.float32)
    additional_features = torch.tensor(additional_values, dtype=torch.float)

    # Combine all graph-level features into a single 1-D tensor
    graph_features = torch.cat([morgan_fp, additional_features], dim=0)  # Shape: [2048 + len(numeric_columns)]

    # Edge indices and attributes
    edge_index = torch.tensor(graph_data.edge_index, dtype=torch.long)
    edge_attr = torch.tensor(graph_data.edge_features, dtype=torch.float)

    # Target label: fail early on an unbinned count instead of casting NaN to an integer.
    label = df_full.loc[i, 'count_binned_custom']
    if pd.isna(label):
        raise ValueError(f"Row {i} has no class label (count outside the bin edges)")
    y = torch.tensor(int(label), dtype=torch.long)

    # Create a PyTorch Geometric Data object
    data = Data(
        x=node_features,              # Node-level features
        edge_index=edge_index,        # Edge indices
        edge_attr=edge_attr,          # Edge features
        y=y,                          # Target label
        graph_features=graph_features  # Graph-level features
    )
    graphs.append(data)

# One summary line instead of one printed line per graph.
print(f"Built {len(graphs)} graphs; first: {graphs[0] if graphs else None}")

# Create a DataLoader
loader = DataLoader(graphs, batch_size=32, shuffle=True)


class GAT(torch.nn.Module):
    """Two-layer, edge-aware GAT classifier operating on node features only.

    `forward(data)` applies two GATConv + ReLU stages, mean-pools the nodes of each
    graph, applies dropout (p=0.5, active only in training mode) and returns
    log-probabilities of shape [num_graphs, num_classes].
    """

    def __init__(self, num_node_features, num_edge_features, hidden_channels, num_classes):
        super().__init__()
        # Both convolutions receive the edge feature width through `edge_dim`.
        edge_options = dict(edge_dim=num_edge_features)
        self.conv1 = GATConv(num_node_features, hidden_channels, **edge_options)
        self.conv2 = GATConv(hidden_channels, hidden_channels, **edge_options)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        hidden = data.x
        # Two rounds of attention-based message passing, each followed by ReLU.
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden, data.edge_index, data.edge_attr))

        # Average node embeddings per graph, then regularize.
        pooled = global_mean_pool(hidden, data.batch)
        pooled = F.dropout(pooled, p=0.5, training=self.training)

        # Class logits turned into log-probabilities.
        return F.log_softmax(self.lin(pooled), dim=-1)

class GATWithGraphFeatures(torch.nn.Module):
    """Graph classifier combining a two-layer GAT with per-graph feature vectors.

    Node features go through two edge-aware GATConv layers and are max-pooled per
    graph; each graph's own `graph_features` vector goes through a two-layer MLP;
    the two embeddings are concatenated and classified.

    Parameters:
    - num_node_features: size of each node feature vector
    - num_edge_features: size of each edge feature vector (GATConv `edge_dim`)
    - num_graph_features: length of ONE graph's `graph_features` vector
    - hidden_channels: width of every hidden layer
    - num_classes: number of output classes

    `forward(data)` returns log-probabilities of shape [num_graphs, num_classes].
    """

    def __init__(self, num_node_features, num_edge_features, num_graph_features, hidden_channels, num_classes):
        super(GATWithGraphFeatures, self).__init__()

        self.num_graph_features = num_graph_features

        # GAT layers for node features
        self.conv1 = GATConv(num_node_features, hidden_channels, edge_dim=num_edge_features)
        self.conv2 = GATConv(hidden_channels, hidden_channels, edge_dim=num_edge_features)

        # Linear layer for node feature processing after pooling
        self.node_lin1 = nn.Linear(hidden_channels, hidden_channels)

        # Linear layers for graph-level features. The input width is the per-graph
        # feature length; it must not include the batch size, otherwise the model
        # only accepts one fixed batch size and mixes features of different graphs.
        self.graph_lin1 = nn.Linear(num_graph_features, hidden_channels)
        self.graph_lin2 = nn.Linear(hidden_channels, hidden_channels)

        # Final linear layer combining node and graph embeddings
        self.final_lin = nn.Linear(hidden_channels * 2, num_classes)

        # Regularization layers
        self.dropout = nn.Dropout(p=0.5)
        self.batch_norm1 = nn.BatchNorm1d(hidden_channels)
        self.batch_norm2 = nn.BatchNorm1d(hidden_channels)
        # Not used in forward(); kept so the set of parameters stays the same.
        self.batch_norm_graph = nn.BatchNorm1d(hidden_channels)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # 1-D per-graph vectors arrive from the PyG DataLoader concatenated into one
        # flat tensor; restore one row per graph. An input that is already
        # [num_graphs, num_graph_features] passes through unchanged.
        graph_features = data.graph_features.reshape(-1, self.num_graph_features)

        # First graph convolution layer (with edge attributes)
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.batch_norm1(x)

        # Second graph convolution layer (with edge attributes)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.batch_norm2(x)

        # Global pooling: one embedding per graph, shape [num_graphs, hidden_channels]
        node_embedding = global_max_pool(x, batch)
        node_embedding = self.node_lin1(node_embedding)
        node_embedding = F.relu(node_embedding)
        node_embedding = self.dropout(node_embedding)

        # Process graph-level features, shape [num_graphs, hidden_channels]
        graph_embedding = self.graph_lin1(graph_features)
        graph_embedding = F.relu(graph_embedding)
        graph_embedding = self.graph_lin2(graph_embedding)
        graph_embedding = F.relu(graph_embedding)
        graph_embedding = self.dropout(graph_embedding)

        # Pair every graph with its own feature embedding: [num_graphs, hidden_channels * 2]
        combined = torch.cat([node_embedding, graph_embedding], dim=-1)

        # Final classification layer
        out = self.final_lin(combined)
        return F.log_softmax(out, dim=-1)


# Assuming 'graphs' is already created with node and graph-level features.
# The dataset size is read from the data instead of being hard-coded, so the split
# cannot disagree with len(graphs).
num_graphs = len(graphs)

# Split into train and test sets: 75% / 25% (129 / 43 for a 172-graph dataset)
train_size = (3 * num_graphs) // 4
test_size = num_graphs - train_size
train_graphs, test_graphs = torch.utils.data.random_split(
    graphs,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42)  # Fixed seed for splits
)

# Create DataLoaders with consistent shuffling.
# A batch size of 43 divides the 129 / 43 split of a 172-graph dataset evenly.
train_loader = DataLoader(
    train_graphs,
    batch_size=43,
    shuffle=True,
    generator=torch.Generator().manual_seed(42)  # Fixed shuffle seed
)
test_loader = DataLoader(
    test_graphs,
    batch_size=43,
    shuffle=False  # Never shuffle test set!
)

# Get feature dimensions from the first graph
num_node_features = graphs[0].num_node_features
num_edge_features = graphs[0].edge_attr.size(1) if graphs[0].edge_attr is not None else 0
num_graph_features = graphs[0].graph_features.size(0)  # Morgan FP + numeric descriptors

# Initialize the model
model = GATWithGraphFeatures(
    num_node_features=num_node_features,
    num_edge_features=num_edge_features,
    num_graph_features=num_graph_features,
    hidden_channels=64,
    num_classes=3
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Plain accuracy over a loader
def evaluate(model, loader):
    """Return the fraction of samples in `loader` whose argmax prediction equals `batch.y`.

    The model is switched to eval mode and gradients are disabled while iterating.
    """
    model.eval()
    hits_per_batch = []
    sizes_per_batch = []
    with torch.no_grad():
        for batch in loader:
            labels = batch.y
            predicted = model(batch).argmax(dim=1)
            hits_per_batch.append((predicted == labels).sum().item())
            sizes_per_batch.append(labels.size(0))
    return sum(hits_per_batch) / sum(sizes_per_batch)

# Accuracy plus weighted precision / recall over a loader
def evaluate_more(model, loader):
    """Return (accuracy, precision, recall) of `model` over every batch in `loader`.

    Predictions are the argmax of the model output. Precision and recall are
    computed with average='weighted' and zero_division=0. The model is switched to
    eval mode and gradients are disabled while iterating.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in loader:
            predicted = model(batch).argmax(dim=1)
            y_true += list(batch.y.cpu().numpy())
            y_pred += list(predicted.cpu().numpy())

    weighted = dict(average='weighted', zero_division=0)
    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, **weighted),
        recall_score(y_true, y_pred, **weighted),
    )

# Training loop
num_epochs = 30
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Mean batch loss of this epoch; the total shown is the real number of epochs.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(train_loader):.4f}")

    # Accuracy on the TRAINING set (no separate validation split exists here), so
    # it is reported as such. The variable keeps its earlier name `val_accuracy`.
    val_accuracy = evaluate(model, train_loader)
    print(f"Training Accuracy after Epoch {epoch+1}: {val_accuracy:.4f}")

# Final evaluation on test set.
# The loader is rebuilt without shuffling so repeated evaluations see the graphs in
# the same order and in the same batches. drop_last=True is kept, which leaves out a
# trailing batch with fewer than 43 graphs if the test set is not a multiple of 43.
test_loader = DataLoader(test_graphs, batch_size=43, shuffle=False, drop_last=True)
test_accuracy, test_precision, test_recall = evaluate_more(model, test_loader)
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall: {test_recall:.4f}")