In [1]:
import scanpy as sc
import muon as mu
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import roc_curve, auc

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Helpers for building the initial cell graph; the construction method is selectable.

def permutate_graph_construct(adata, dim=50, init_neighbor_method='louvain', res=1, batch_size=256, device='cuda', random_state=None):
    """Return a row-shuffled copy of ``adata`` with a freshly built graph.

    For contrastive learning we need a corrupted view of the data: the rows
    of the expression matrix ``X`` are shuffled on a copy, so the input
    ``adata`` is left untouched.

    Parameters
    ----------
    adata
        Object exposing ``copy()`` and a 2-D ``X`` matrix (dense or sparse).
    dim, batch_size, device
        Currently unused; kept so existing call sites keep working.
    init_neighbor_method
        ``'louvain'`` builds a 30-neighbour graph on ``X`` and runs Louvain
        clustering on the copy. Any other value returns the shuffled copy
        without building a graph.
    res
        Resolution forwarded to ``sc.tl.louvain``.
    random_state
        Optional seed for the row shuffle. ``None`` (default) draws from the
        global NumPy random state, as before.

    Returns
    -------
    The shuffled copy (the original implementation returned ``None``, so the
    result was lost).
    """
    adata_perm = adata.copy()

    # Shuffle row indices rather than the matrix itself: for a dense array
    # this yields the same ordering as permuting the array directly, and it
    # also works when X is a sparse matrix.
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    row_order = rng.permutation(adata_perm.X.shape[0])
    adata_perm.X = adata_perm.X[row_order]

    if init_neighbor_method == 'louvain':
        # Rebuild the neighbour graph on the shuffled matrix, then cluster it.
        sc.pp.neighbors(adata_perm, use_rep='X', n_neighbors=30)
        sc.tl.louvain(adata_perm, resolution=res)

    return adata_perm


def get_contrastive_loss(adata, adata_perm, dim=50, batch_size=256, device='cuda', lr=1e-3, epochs=100):
    """Train a network and return whatever the training routine reports.

    Parameters
    ----------
    adata
        Data handed to the data loader (shuffled mini-batches).
    adata_perm
        Data handed to the model constructor.
    dim
        Hidden width passed to the model as ``n_hidden``.
    batch_size
        Mini-batch size for the loader.
    device
        Device the model is moved to and the loader is created for.
    lr
        Adam learning rate (previously hard-coded to ``1e-3``).
    epochs
        Number of training epochs (previously hard-coded to ``100``).

    Returns
    -------
    The value returned by ``mu.train.train``.
    """
    # NOTE(review): mu.models.MuonNet, mu.data.DataLoader and mu.train.train
    # do not appear to be part of muon's documented public API - confirm
    # these names exist in the installed version before relying on this cell.
    # NOTE(review): the model is built from adata_perm while the loader reads
    # adata - confirm this pairing is intended.
    model = mu.models.MuonNet(adata_perm, n_hidden=dim, n_layers=2, dropout=0.1, batchnorm=True, residual=True)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    loader = mu.data.DataLoader(adata, batch_size=batch_size, shuffle=True, device=device)
    return mu.train.train(model, loader, optimizer, criterion, epochs=epochs, verbose=False)