In [None]:
pip install git+https://github.com/msmalmir/scTransID.git

In [1]:
import time

import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, TensorDataset

import scTransID
from scTransID.data_utils import load_data, preprocess_data, split_data
from scTransID.evaluation import evaluate_on_query
from scTransID.model import TransformerModel
from scTransID.train import train_model

# --- Configuration -----------------------------------------------------------
# Paths to sample data (Update these paths as needed)
train_path = './Datasets/hArtery/hArtery_train_adata.h5ad'
test_path = './Datasets/hArtery/hArtery_test_adata.h5ad'

# Tunable training settings, collected here so they are easy to find and change.
SEED = 42
LEARNING_RATE = 0.0001
BATCH_SIZE = 32
NUM_EPOCHS = 2

# Seed PyTorch so weight initialisation and DataLoader shuffling are repeatable
# across "Restart & Run All".
# NOTE(review): any randomness inside split_data is not controlled here -- confirm
# whether it uses a fixed random_state.
torch.manual_seed(SEED)

# --- Load and prepare data ---------------------------------------------------
train_adata, query_adata = load_data(train_path, test_path)
X_train, y_train, X_query, le = preprocess_data(train_adata, query_adata)
X_train_tensor, y_train_tensor, X_val_tensor, y_val_tensor = split_data(X_train, y_train)

# --- Initialize model, loss function, and optimizer --------------------------
num_genes = X_train_tensor.shape[1]  # second axis of the training tensor
num_classes = len(le.classes_)  # number of distinct labels known to the encoder
model = TransformerModel(num_genes=num_genes, num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- Wrap tensors in DataLoaders ---------------------------------------------
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
# Only the training loader is shuffled; validation order is left as-is.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# --- Train the model ---------------------------------------------------------
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=NUM_EPOCHS)

# --- Evaluate on the query dataset -------------------------------------------
# Convert query dataset and true labels to PyTorch tensors
X_query_tensor = torch.tensor(X_query, dtype=torch.float32)
y_query = query_adata.obs['celltype'].values  # Ensure true labels are available
# Encode the true labels with the encoder returned by preprocess_data.
# NOTE(review): if `le` is a scikit-learn LabelEncoder, transform() raises on labels
# it was not fitted on -- confirm the query set has no unseen cell types.
y_query_encoded = le.transform(y_query)

predicted_celltypes, accuracy, f1 = evaluate_on_query(model, X_query_tensor, y_query_encoded, le)

# --- Display the results -----------------------------------------------------
print("Predicted cell types for the query dataset:")
print(predicted_celltypes)
print(f"Accuracy on query dataset: {accuracy:.2f}")
print(f"F1 score on query dataset: {f1:.2f}")

FileNotFoundError: [Errno 2] Unable to open file (unable to open file: name = './Datasets/hArtery/hArtery_train_adata.h5ad', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [None]:
# Attach the model's predictions to the query AnnData so they can be plotted
# next to the ground-truth 'celltype' annotation.
query_adata.obs['predicted_celltypes'] = predicted_celltypes

# Build the UMAP embedding for the query dataset.
# PCA must run *before* the neighbor graph: computing it afterwards (as this
# cell previously did) is too late to influence the graph that UMAP is built on.
sc.pp.pca(query_adata, n_comps=50)  # n_comps sets the number of principal components
sc.pp.neighbors(query_adata)  # Compute neighbors (after PCA, so the reduced representation is available)
sc.tl.umap(query_adata)  # Run UMAP on the neighbor graph


In [None]:
# Two panels in a single row: ground truth on the left, predictions on the right.
fig, ax = plt.subplots(1, 2, figsize=(12, 5))

# (obs column to colour by, panel title) for each panel, left to right.
umap_panels = [
    ('celltype', 'True Cell Types'),
    ('predicted_celltypes', 'Predicted Cell Types'),
]

for panel_axis, (obs_column, panel_title) in zip(ax, umap_panels):
    # show=False defers rendering so both panels land in the same figure.
    sc.pl.umap(query_adata, color=obs_column, title=panel_title, show=False, ax=panel_axis)

# Tidy the spacing, then render both panels together.
plt.tight_layout()
# To keep a copy on disk, call plt.savefig('umap1.png') here (path/name are up to you).
plt.show()

In [None]:
# Ensure that both columns contain only strings (astype(str) also turns missing
# values into a string), so true and predicted labels are directly comparable.
query_adata.obs['celltype'] = query_adata.obs['celltype'].astype(str)
query_adata.obs['predicted_celltypes'] = query_adata.obs['predicted_celltypes'].astype(str)

# Convert columns to categorical
query_adata.obs['celltype'] = query_adata.obs['celltype'].astype('category')
query_adata.obs['predicted_celltypes'] = query_adata.obs['predicted_celltypes'].astype('category')

# Full, sorted set of cell types seen in either column (a type may appear only
# among the predictions, or only among the true labels).
all_cell_types = np.union1d(query_adata.obs['celltype'].unique(), query_adata.obs['predicted_celltypes'].unique())

# Update 'celltype' and 'predicted_celltypes' so both columns carry every category
query_adata.obs['celltype'] = query_adata.obs['celltype'].cat.add_categories([ctype for ctype in all_cell_types if ctype not in query_adata.obs['celltype'].cat.categories])
query_adata.obs['predicted_celltypes'] = query_adata.obs['predicted_celltypes'].cat.add_categories([ctype for ctype in all_cell_types if ctype not in query_adata.obs['predicted_celltypes'].cat.categories])

# Create confusion matrix (rows = true labels, columns = predicted labels)
conf_matrix = confusion_matrix(query_adata.obs['celltype'], query_adata.obs['predicted_celltypes'], labels=all_cell_types)

# Normalize the confusion matrix by row (true labels).
# A cell type that only occurs among the predictions has an all-zero row; dividing
# by its zero total would give NaN cells and a RuntimeWarning, so such rows are
# left at 0 instead.
row_totals = conf_matrix.sum(axis=1, keepdims=True)
conf_matrix_normalized = np.divide(
    conf_matrix.astype('float'),
    row_totals,
    out=np.zeros(conf_matrix.shape, dtype=float),
    where=row_totals > 0,
)

# Round the normalized confusion matrix
conf_matrix_normalized_rounded = np.round(conf_matrix_normalized, 2)

# Plot rounded, normalized confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix_normalized_rounded, annot=True, cmap="Blues", xticklabels=all_cell_types, yticklabels=all_cell_types, fmt='.2f')
plt.title('Rounded Normalized Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
# To save the plot as a PNG file, call plt.savefig('confusion_matrix.png') here
# (path and file name are up to you).
plt.show()
