In [57]:
import warnings
warnings.filterwarnings("ignore")

In [58]:
import scanpy as sc
import anndata as ad
from sklearn.preprocessing import StandardScaler
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim


In [59]:
file = ''
gdata = sc.read_visium(file, count_file=r'C:\Users\KARAN\Desktop\MultiOmics-Research\STAGATE\Landau\SPOTS Landau paper dataset\protein\GSE198353_mmtv_pymt_GEX_filtered_feature_bc_matrix.h5',load_images=True)
gdata.var_names_make_unique()

In [60]:
pdata = pd.read_csv(r'C:\Users\KARAN\Desktop\MultiOmics-Research\STAGATE\Landau\SPOTS Landau paper dataset\protein\GSE198353_mmtv_pymt_ADT_t.csv', index_col=0)

In [61]:
pdata

Unnamed: 0_level_0,CD4,CD8a,CD366,CD279,CD117,Ly-6C,Ly-6G,CD19,CD45,CD25,...,CD11a,P2X7R,CD1d,Notch 4,CD31,Podoplanin,CD45R/B220,CD27,CD11b,CD202b
FIELD1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACAAGTATCTCCCA-1,478,583,877,446,57,0,481,434,3157,430,...,638,1222,1253,1273,1354,4858,656,1067,1193,587
AAACACCAATAACTGC-1,1504,1217,1731,943,64,2,1027,933,6580,872,...,1808,2240,1932,2253,3095,10214,1266,2040,3056,1000
AAACAGGGTCTATATT-1,1526,1231,1433,849,23,1,1322,1515,5964,1020,...,1778,2120,1971,2216,2927,2700,1437,2193,3863,985
AAACAGTGTTCCTGGG-1,847,787,1028,517,67,0,610,567,3476,579,...,939,1266,1242,1268,1742,4985,881,1230,1046,634
AAACATGGTGAGAGGA-1,2317,1770,2347,1475,58,1,1802,2371,7370,1333,...,2573,3359,2859,3335,4107,6983,2777,3766,4316,1723
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTGTTGGCAATGACTG-1,1669,1378,1766,932,60,0,1070,957,6174,1022,...,1595,2025,1963,2319,2821,6929,1472,2130,1904,1053
TTGTTTCACATCCAGG-1,903,693,919,523,61,2,613,623,3531,526,...,917,1395,1224,1319,1623,9180,870,1226,1119,569
TTGTTTCATTAGTCTA-1,661,528,729,477,23,1,512,404,2571,465,...,714,984,952,1057,1242,4286,699,948,801,404
TTGTTTCCATACAACT-1,1031,857,1250,666,40,0,951,938,5599,782,...,1425,1454,1123,1547,2710,2784,1106,1582,3864,652


In [62]:
gene_data = gdata
protein_data = pdata


In [64]:
# Add protein data to AnnData object
gdata.obsm['protein_data'] = protein_data.values


In [66]:
gdata.obsm['protein_data']

array([[ 478,  583,  877, ..., 1067, 1193,  587],
       [1504, 1217, 1731, ..., 2040, 3056, 1000],
       [1526, 1231, 1433, ..., 2193, 3863,  985],
       ...,
       [ 661,  528,  729, ...,  948,  801,  404],
       [1031,  857, 1250, ..., 1582, 3864,  652],
       [ 861,  720, 1080, ..., 1124, 1002,  492]], dtype=int64)

In [67]:
#Normalization
sc.pp.highly_variable_genes(gdata, flavor="seurat_v3", n_top_genes=50)
sc.pp.normalize_total(gdata, target_sum=1e4)
sc.pp.log1p(gdata)

In [68]:
class GeneProteinDataset(Dataset):
    def __init__(self, gene_data, protein_data):
        self.gene_data = gene_data
        self.protein_data = protein_data
    
    def __len__(self):
        return len(self.gene_data)
    
    def __getitem__(self, index):
        gene_item = self.gene_data[index]
        protein_item = self.protein_data[index]
        
        gene_tensor = torch.Tensor(gene_item)
        protein_tensor = torch.Tensor(protein_item)
        
        return {'gene_data': gene_tensor, 'protein_data': protein_tensor}


In [69]:
class GATE(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(GATE, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [75]:
class STAGATE:
    def __init__(self, adata, hidden_dims, alpha, n_epochs):
        self.adata = adata
        self.hidden_dims = hidden_dims
        self.alpha = alpha
        self.n_epochs = n_epochs
        self.scaler = StandardScaler()
        self.gate_net = None
    
    def prepare_graph_data(self):
        # Prepare the graph data (genes and proteins)
        gene_data = self.adata.X.toarray()
        protein_data = self.adata.obsm['protein_data']  # Assuming you have protein data stored in adata.obsm['protein_data']
        
        # Normalize the gene data and protein data
        gene_data = self.scaler.fit_transform(gene_data)
        protein_data = self.scaler.fit_transform(protein_data)
        
        return gene_data, protein_data
    
    def train(self):
        # Prepare the graph data
        gene_data, protein_data = self.prepare_graph_data()
        
        # Create the GATE network
        self.gate_net = GATE(gene_data.shape[1], self.hidden_dims[0])
        
        # Define the loss function and optimizer
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(self.gate_net.parameters(), lr=self.alpha)
        
        # Convert the numpy arrays to tensors
        gene_data = torch.tensor(gene_data, dtype=torch.float32)
        protein_data = torch.tensor(protein_data, dtype=torch.float32)
        
        # Training loop
        for epoch in range(self.n_epochs):
            # Perform one epoch of training
            for batch in dataloader:
                genes_batch, proteins_batch, spatial_batch = batch

                # Forward pass
                latent_z, pred_genes, pred_proteins = self.model(genes_batch, proteins_batch, spatial_batch)

                # Compute losses
                genes_loss = self.genes_criterion(pred_genes, genes_batch)
                proteins_loss = self.proteins_criterion(pred_proteins, proteins_batch)
                spatial_loss = self.spatial_criterion(latent_z, spatial_batch)

                # Backpropagation and optimization
                total_loss = genes_loss + proteins_loss + self.alpha * spatial_loss
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

    def cal_spatial_net(self, adata, rad_cutoff=150):
        """Calculate spatial network based on spatial coordinates."""
        self.spatial_net = spatnet.cal_spatial_net(adata, rad_cutoff)

    def stats_spatial_net(self, adata):
        """Compute statistics of the spatial network."""
        if not hasattr(self, 'spatial_net'):
            raise ValueError("Spatial network not initialized. Run 'cal_spatial_net' first.")
        spatnet.stats_spatial_net(adata, self.spatial_net)


In [None]:
# Calculate spatial network
stagate.cal_spatial_net(train_data, rad_cutoff=150)


In [76]:
# Instantiate STAGATE model
hidden_dims = [128]
alpha = 0.001
n_epochs = 100
stagate = STAGATE(gdata, hidden_dims, alpha, n_epochs)

# Train STAGATE model
stagate.train()


RuntimeError: The size of tensor a (128) must match the size of tensor b (32) at non-singleton dimension 1