In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import networkx as nx
from sklearn.decomposition import PCA
from sklearn.neighbors import kneighbors_graph
from GraphRicciCurvature.OllivierRicci import OllivierRicci
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, classification_report
from scipy.stats import ttest_ind

# Load scRNA data
data_path = '/users/barmanjy/Desktop/Persister Cell/GSE150949_scRNA.csv'
data = pd.read_csv(data_path,engine='python',encoding='utf-8')

# Ensure 'persister_label' column exists
if 'persister_label' not in data.columns:
    np.random.seed(42)  # For reproducibility
    data['persister_label'] = np.random.randint(0, 2, size=len(data))

actual_labels = data['persister_label'].values

# Drop the label column from the data
data = data.drop(columns=['persister_label'])

# Preprocess the scRNA data
def preprocess_scRNA_data(data):
    adata = sc.AnnData(data)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    return adata.X

# Preprocess scRNA data
X_data = preprocess_scRNA_data(data)

# Apply PCA for dimensionality reduction
pca = PCA(n_components=50)
X_pca = pca.fit_transform(X_data)

# Construct k-NN graph
def construct_knn_graph(X, k):
    knn_graph = kneighbors_graph(X, k, mode='connectivity', include_self=False)
    coo_graph = knn_graph.tocoo()  # Convert to COO format
    G = nx.from_edgelist(zip(coo_graph.row, coo_graph.col))  # Create graph from COO format
    return G

# Calculate Ollivier-Ricci curvature
def compute_ricci_curvature(G, alpha=0.5):
    orc = OllivierRicci(G, alpha=alpha, verbose="INFO")
    orc.compute_ricci_curvature()
    return nx.get_edge_attributes(orc.G, "ricciCurvature")

# Identify high proliferation index cells with alternative curvature aggregation methods
def identify_high_proliferation_nodes(ricci_curvatures, method='threshold', threshold=None, top_percentile=None, centrality_measure=None):
    if method == 'threshold':
        high_proliferation_edges = [(u, v) for (u, v), curvature in ricci_curvatures.items() if curvature > threshold]
    elif method == 'top_percentile':
        sorted_curvatures = sorted(ricci_curvatures.items(), key=lambda x: x[1], reverse=True)
        top_n = int(len(sorted_curvatures) * top_percentile / 100)
        high_proliferation_edges = [(u, v) for (u, v), _ in sorted_curvatures[:top_n]]
    elif method == 'centrality':
        centrality_scores = nx.betweenness_centrality(G, weight='ricciCurvature')
        sorted_centrality = sorted(centrality_scores.items(), key=lambda x: x[1], reverse=True)
        top_n = int(len(sorted_centrality) * top_percentile / 100)
        high_proliferation_nodes = [node for node, _ in sorted_centrality[:top_n]]
        high_proliferation_edges = [(u, v) for (u, v) in G.edges() if u in high_proliferation_nodes or v in high_proliferation_nodes]
    else:
        raise ValueError("Invalid method. Choose from 'threshold', 'top_percentile', or 'centrality'.")
    
    high_proliferation_nodes = set([node for edge in high_proliferation_edges for node in edge])
    return high_proliferation_nodes

# Differential Expression Analysis
def differential_expression_analysis(data, high_proliferation_nodes):
    if 'persister_label' in data.columns:
        persister_labels = data['persister_label'].values
        high_proliferation_labels = np.zeros_like(persister_labels)
        high_proliferation_labels[list(high_proliferation_nodes)] = 1

        # Perform differential expression analysis
        differential_genes = []
        for gene_index in range(data.shape[1]):
            gene_expr_high_prolif = data[high_proliferation_labels == 1, gene_index]
            gene_expr_non_high_prolif = data[high_proliferation_labels == 0, gene_index]
            _, p_value = ttest_ind(gene_expr_high_prolif, gene_expr_non_high_prolif)
            if p_value < 0.05:  # Adjust significance level as needed
                differential_genes.append(gene_index)
        return differential_genes
    else:
        print("Ground truth labels ('persister_label') not available. Skipping differential expression analysis.")
        return None

# Usage
k_values = [5, 10, 15, 20]
threshold = 0.5  # Threshold for identifying high proliferation cells
top_percentile = 5  # Top percentile for alternative aggregation methods
centrality_percentile = 5  # Percentile of nodes for centrality-based aggregation

# Experiment with different k values for kNN graph
for k in k_values:
    G = construct_knn_graph(X_pca, k)
    ricci_curvatures = compute_ricci_curvature(G)
    
    # Identify high proliferation cells using thresholding method
    high_proliferation_nodes_threshold = identify_high_proliferation_nodes(ricci_curvatures, method='threshold', threshold=threshold)
    print(f"High Proliferation Cells with threshold method and k={k}:", high_proliferation_nodes_threshold)
    
    # Identify high proliferation cells using top percentile method
    high_proliferation_nodes_percentile = identify_high_proliferation_nodes(ricci_curvatures, method='top_percentile', top_percentile=top_percentile)
    print(f"High Proliferation Cells with top percentile method and k={k}:", high_proliferation_nodes_percentile)
    
    # Identify high proliferation cells using centrality method
    high_proliferation_nodes_centrality = identify_high_proliferation_nodes(ricci_curvatures, method='centrality', top_percentile=centrality_percentile)
    print(f"High Proliferation Cells with centrality method and k={k}:", high_proliferation_nodes_centrality)

    # Differential expression analysis
    differential_genes = differential_expression_analysis(X_data, high_proliferation_nodes_threshold)
    if differential_genes is not None:
        print("Differential Genes with threshold method:", differential_genes)
    
    differential_genes = differential_expression_analysis(X_data, high_proliferation_nodes_percentile)
    if differential_genes is not None:
        print("Differential Genes with top percentile method:", differential_genes)
    
    differential_genes = differential_expression_analysis(X_data, high_proliferation_nodes_centrality)
    if differential_genes is not None:
        print("Differential Genes with centrality method:", differential_genes)

    # Visualize curvature distribution
    plt.figure(figsize=(10, 6))
    sns.histplot(list(ricci_curvatures.values()), bins=50, kde=True)
    plt.axvline(np.mean(list(ricci_curvatures.values())), color='r', linestyle='--')
    plt.title(f'Distribution of Ricci Curvatures with k={k}')
    plt.xlabel('Ricci Curvature')
    plt.ylabel('Frequency')
    plt.show()

    # Visualize high proliferation nodes on the PCA plot
    plt.figure(figsize=(10, 6))
    plt.scatter(X_pca[:, 0], X_pca[:, 1], c='gray', alpha=0.5, label='All Cells')
    plt.scatter(X_pca[list(high_proliferation_nodes_threshold), 0], X_pca[list(high_proliferation_nodes_threshold), 1], c='red', label='High Proliferation Cells (Threshold)')
    plt.scatter(X_pca[list(high_proliferation_nodes_percentile), 0], X_pca[list(high_proliferation_nodes_percentile), 1], c='blue', label='High Proliferation Cells (Percentile)')
    plt.scatter(X_pca[list(high_proliferation_nodes_centrality), 0], X_pca[list(high_proliferation_nodes_centrality), 1], c='green', label='High Proliferation Cells (Centrality)')
    plt.title(f'PCA of scRNA-seq Data with k={k}')
    plt.xlabel('PCA 1')
    plt.ylabel('PCA 2')
    plt.legend()
    plt.show()




In [7]:
import numpy as np
import pandas as pd
import scanpy as sc
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.neighbors import kneighbors_graph
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from scipy.stats import ttest_ind
from statsmodels.stats.multitest import multipletests
from transformers import BertModel, BertTokenizer, GPT2Model, GPT2Tokenizer
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Load scRNA data
data_path = '/users/barmanjy/Desktop/Persister Cell/GSE150949_scRNA.csv'
data = pd.read_csv(data_path)

# Ensure 'persister_label' column exists
if 'persister_label' not in data.columns:
    np.random.seed(42)  # For reproducibility
    data['persister_label'] = np.random.randint(0, 2, size=len(data))

actual_labels = data['persister_label'].values

# Drop the label column from the data
data = data.drop(columns=['persister_label'])

# Preprocess the scRNA data
def preprocess_scRNA_data(data):
    adata = sc.AnnData(data)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    return adata.X

# Preprocess scRNA data
X_data = preprocess_scRNA_data(data)

# Standardize the data
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_data)

# Split data for training and testing
X_train, X_test, y_train, y_test = train_test_split(X_scaled, actual_labels, test_size=0.2, random_state=42)

# Define custom dataset for DataLoader
class RNASeqDataset(Dataset):
    def __init__(self, X, y, tokenizer, model_type='bert', max_length=512):
        self.X = X
        self.y = y
        self.tokenizer = tokenizer
        self.model_type = model_type
        self.max_length = max_length

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sequence = self.X[idx]
        label = self.y[idx]

        if self.model_type == 'bert':
            inputs = self.tokenizer(sequence, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            inputs = {key: tensor.squeeze(0) for key, tensor in inputs.items()}
        elif self.model_type == 'gpt':
            inputs = self.tokenizer.encode(sequence, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)[0]

        return inputs, label

# Define scBERT model
class SCBERTClassifier(nn.Module):
    def __init__(self, pretrained_model='bert-base-uncased', num_classes=2):
        super(SCBERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.fc(pooled_output)
        return logits

# Define scGPT model
class SCGPTClassifier(nn.Module):
    def __init__(self, pretrained_model='gpt2', num_classes=2):
        super(SCGPTClassifier, self).__init__()
        self.gpt = GPT2Model.from_pretrained(pretrained_model)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.gpt.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.gpt(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.fc(pooled_output)
        return logits

# Instantiate tokenizer and models
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
gpt_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
scbert_model = SCBERTClassifier()
scgpt_model = SCGPTClassifier()

# Define training parameters
batch_size = 32
num_epochs = 10
learning_rate = 1e-4

# Create datasets and dataloaders
train_dataset_bert = RNASeqDataset(X_train, y_train, bert_tokenizer, model_type='bert')
test_dataset_bert = RNASeqDataset(X_test, y_test, bert_tokenizer, model_type='bert')
train_dataloader_bert = DataLoader(train_dataset_bert, batch_size=batch_size, shuffle=True)
test_dataloader_bert = DataLoader(test_dataset_bert, batch_size=batch_size)

train_dataset_gpt = RNASeqDataset(X_train, y_train, gpt_tokenizer, model_type='gpt')
test_dataset_gpt = RNASeqDataset(X_test, y_test, gpt_tokenizer, model_type='gpt')
train_dataloader_gpt = DataLoader(train_dataset_gpt, batch_size=batch_size, shuffle=True)
test_dataloader_gpt = DataLoader(test_dataset_gpt, batch_size=batch_size)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer_bert = optim.AdamW(scbert_model.parameters(), lr=learning_rate)
optimizer_gpt = optim.AdamW(scgpt_model.parameters(), lr=learning_rate)

# Training loop for scBERT model
for epoch in range(num_epochs):
    scbert_model.train()
    running_loss = 0.0
    for inputs, labels in train_dataloader_bert:
        optimizer_bert.zero_grad()

        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        outputs = scbert_model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_bert.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1} - Loss: {running_loss / len(train_dataloader_bert)}")

# Evaluation loop for scBERT model
scbert_model.eval()
correct_bert = 0
total_bert = 0
with torch.no_grad():
    for inputs, labels in test_dataloader_bert:
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        outputs = scbert_model(input_ids, attention_mask)
        _, predicted = torch.max(outputs, 1)
        total_bert += labels.size(0)
        correct_bert += (predicted == labels).sum().item()

print('Accuracy of scBERT model on test set:', (correct_bert / total_bert))

# Training loop for scGPT model
for epoch in range(num_epochs):
    scgpt_model.train()
    running_loss = 0.0
    for inputs, labels in train_dataloader_gpt:
        optimizer_gpt.zero_grad()

        input_ids = inputs
        outputs = scgpt_model(input_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_gpt.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1} - Loss: {running_loss / len(train_dataloader_gpt)}")

# Evaluation loop for scGPT model
scgpt_model.eval()
correct_gpt = 0
total_gpt = 0
with torch.no_grad():
    for inputs, labels in test_dataloader_gpt:
        input_ids = inputs
        outputs = scgpt_model(input_ids)
        _, predicted = torch.max(outputs, 1)
        total_gpt += labels.size(0)
        correct_gpt += (predicted == labels).sum().item()

print('Accuracy of scGPT model on test set:', (correct_gpt / total_gpt))


ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [None]:
29/05/2024

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, GPT2Model, GPT2Tokenizer
import torch
import torch.nn as nn
import torch.optim as optim
import optuna

# Load scRNA data
data_path = '/users/barmanjy/Desktop/Persister Cell/GSE150949_scRNA.csv'
data = pd.read_csv(data_path)

# Ensure 'persister_label' column exists
if 'persister_label' not in data.columns:
    np.random.seed(42)  # For reproducibility
    data['persister_label'] = np.random.randint(0, 2, size=len(data))

actual_labels = data['persister_label'].values

# Drop the label column from the data
data = data.drop(columns=['persister_label'])

# Preprocess the scRNA data
def preprocess_scRNA_data(data):
    adata = sc.AnnData(sparse.csr_matrix(data.values))  # Convert to sparse matrix
    sc.pp.normalize_total(adata, target_sum=1e4)  # Normalize
    sc.pp.log1p(adata)  # Logarithmic transformation
    sc.pp.highly_variable_genes(adata, n_top_genes=2000, subset=True)  # Select highly variable genes
    adata.obs['batch'] = np.random.randint(0, 2, size=adata.shape[0])  # Dummy batch column for batch correction
    sc.pp.combat(adata, key='batch')  # Batch correction
    return adata.X

# Preprocess scRNA data
X_data = preprocess_scRNA_data(data)

# Standardize the data
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_data)

# Split data for training and testing
X_train, X_test, y_train, y_test = train_test_split(X_scaled, actual_labels, test_size=0.2, random_state=42)

# Define custom dataset for DataLoader
class RNASeqDataset(Dataset):
    def __init__(self, X, y, tokenizer, model_type='bert', max_length=512):
        self.X = X
        self.y = y
        self.tokenizer = tokenizer
        self.model_type = model_type
        self.max_length = max_length

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sequence = ' '.join(map(str, self.X[idx]))
        label = self.y[idx]

        if self.model_type == 'bert':
            inputs = self.tokenizer(sequence, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            inputs = {key: tensor.squeeze(0) for key, tensor in inputs.items()}
        elif self.model_type == 'gpt':
            inputs = self.tokenizer.encode(sequence, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length).squeeze(0)

        return inputs, label

# Define scBERT model
class SCBERTClassifier(nn.Module):
    def __init__(self, pretrained_model='bert-base-uncased', num_classes=2):
        super(SCBERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.fc(pooled_output)
        return logits

# Define scGPT model
class SCGPTClassifier(nn.Module):
    def __init__(self, pretrained_model='gpt2', num_classes=2):
        super(SCGPTClassifier, self).__init__()
        self.gpt = GPT2Model.from_pretrained(pretrained_model)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.gpt.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.gpt(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.fc(pooled_output)
        return logits

# Instantiate tokenizer and models
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
gpt_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
scbert_model = SCBERTClassifier()
scgpt_model = SCGPTClassifier()

# Define training parameters
batch_size = 32
num_epochs = 10
learning_rate = 1e-4

# Create datasets and dataloaders
train_dataset_bert = RNASeqDataset(X_train, y_train, bert_tokenizer, model_type='bert')
train_dataset_gpt = RNASeqDataset(X_train, y_train, gpt_tokenizer, model_type='gpt')

test_dataset_bert = RNASeqDataset(X_test, y_test, bert_tokenizer, model_type='bert')
test_dataset_gpt = RNASeqDataset(X_test, y_test, gpt_tokenizer, model_type='gpt')

train_dataloader_bert = DataLoader(train_dataset_bert, batch_size=batch_size, shuffle=True)
test_dataloader_bert = DataLoader(test_dataset_bert, batch_size=batch_size, shuffle=False)

train_dataloader_gpt = DataLoader(train_dataset_gpt, batch_size=batch_size, shuffle=True)
test_dataloader_gpt = DataLoader(test_dataset_gpt, batch_size=batch_size, shuffle=False)

# Define training and evaluation functions
criterion = nn.CrossEntropyLoss()

def train_model(model, optimizer, train_dataloader, criterion, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in train_dataloader:
            optimizer.zero_grad()
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_dataloader)}')

def evaluate_model(model, test_dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            outputs = model(input_ids, attention_mask)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

# Hyperparameter tuning with Optuna for scBERT
def objective(trial):
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64])
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
    dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)

    scbert_model = SCBERTClassifier()
    scbert_model.dropout = nn.Dropout(dropout_rate)
    optimizer_bert = optim.AdamW(scbert_model.parameters(), lr=learning_rate)

    train_dataloader_bert = DataLoader(train_dataset_bert, batch_size=batch_size, shuffle=True)

    best_loss = float('inf')
    early_stop_count = 0
    for epoch in range(num_epochs):
        scbert_model.train()
        running_loss = 0.0
        for inputs, labels in train_dataloader_bert:
            optimizer_bert.zero_grad()
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            outputs = scbert_model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer_bert.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(train_dataloader_bert)
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            early_stop_count = 0
        else:
            early_stop_count += 1
            if early_stop_count >= 3:
                break
    return best_loss

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=50)
best_params = study.best_params

batch_size = best_params['batch_size']
learning_rate = best_params['learning_rate']
dropout_rate = best_params['dropout_rate']

# Train final scBERT model with best hyperparameters
scbert_model = SCBERTClassifier()
scbert_model.dropout = nn.Dropout(dropout_rate)
optimizer_bert = optim.AdamW(scbert_model.parameters(), lr=learning_rate)
train_dataloader_bert = DataLoader(train_dataset_bert, batch_size=batch_size, shuffle=True)
train_model(scbert_model, optimizer_bert, train_dataloader_bert, criterion, num_epochs)

# Evaluate scBERT model on test set
test_accuracy_bert = evaluate_model(scbert_model, test_dataloader_bert)
print('Accuracy of scBERT model on test set:', test_accuracy_bert)

# Hyperparameter tuning with Optuna for scGPT
def objective_gpt(trial):
    batch_size_gpt = trial.suggest_categorical('batch_size_gpt', [16, 32, 64])
    learning_rate_gpt = trial.suggest_float('learning_rate_gpt', 1e-5, 1e-3, log=True)
    dropout_rate_gpt = trial.suggest_float('dropout_rate_gpt', 0.1, 0.5)

    scgpt_model = SCGPTClassifier()
    scgpt_model.dropout = nn.Dropout(dropout_rate_gpt)
    optimizer_gpt = optim.AdamW(scgpt_model.parameters(), lr=learning_rate_gpt)
    train_dataloader_gpt = DataLoader(train_dataset_gpt, batch_size=batch_size_gpt, shuffle=True)

    best_loss_gpt = float('inf')
    early_stop_count_gpt = 0
    for epoch in range(num_epochs):
        scgpt_model.train()
        running_loss_gpt = 0.0
        for inputs, labels in train_dataloader_gpt:
            optimizer_gpt.zero_grad()
            input_ids = inputs
            outputs = scgpt_model(input_ids)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer_gpt.step()
            running_loss_gpt += loss.item()
        epoch_loss_gpt = running_loss_gpt / len(train_dataloader_gpt)
        if epoch_loss_gpt < best_loss_gpt:
            best_loss_gpt = epoch_loss_gpt
            early_stop_count_gpt = 0
        else:
            early_stop_count_gpt += 1
            if early_stop_count_gpt >= 3:
                break
    return best_loss_gpt

study_gpt = optuna.create_study(direction='minimize')
study_gpt.optimize(objective_gpt, n_trials=50)
best_params_gpt = study_gpt.best_params

batch_size_gpt = best_params_gpt['batch_size_gpt']
learning_rate_gpt = best_params_gpt['learning_rate_gpt']
dropout_rate_gpt = best_params_gpt['dropout_rate_gpt']

# Train final scGPT model with best hyperparameters
scgpt_model = SCGPTClassifier() 
scgpt_model.dropout = nn.Dropout(dropout_rate_gpt)
optimizer_gpt = optim.AdamW(scgpt_model.parameters(), lr=learning_rate_gpt)
train_dataloader_gpt = DataLoader(train_dataset_gpt, batch_size=batch_size_gpt, shuffle=True)
train_model(scgpt_model, optimizer_gpt, train_dataloader_gpt, criterion, num_epochs)

# Evaluate scGPT model on test set
test_accuracy_gpt = evaluate_model(scgpt_model, test_dataloader_gpt)
print('Accuracy of scGPT model on test set:', test_accuracy_gpt)

