In [1]:
import sys

# Relative sys.path entries are resolved against the kernel's current working
# directory, so this assumes the notebook is started from its own folder and
# that the `scGraphLLM` package sits one level up — TODO confirm layout.
sys.path.extend(["../"])

In [2]:
from scGraphLLM.utils import node_batching
import pandas as pd
import numpy as np
import scanpy as sc 

In [3]:
import torch
# Record the torch version in the output for reproducibility.
print(str(torch.__version__))

2.2.0


# Test Neighborhood Batching

In [4]:
# Regulator -> target edge table (tab-separated).
# NOTE(review): the first column loads as `Unnamed: 0`, which looks like a saved
# row index; kept as-is because node_batching receives this frame unchanged —
# consider index_col=0 once its column usage is confirmed.
network_path = "../../Data/scGraphLLM/SMC23/network.tsv"
network = pd.read_csv(network_path, sep="\t")
network.head()

Unnamed: 0,regulator.values,target.values,mi.values,scc.values,count.values,log.p.values
0,ZMYND15,BAMBI,0.010403,0.210074,1,-5.95313
1,ZMYND15,MMP23B,0.000566,0.004131,1,-5.95313
2,ZMYND15,CDC25C,0.019265,-0.222711,1,-5.95313
3,ZMYND15,UACA,0.001546,0.051827,1,-5.95313
4,ZMYND15,TCEA3,0.000566,0.048722,1,-5.95313


In [5]:
# Rank matrix as an AnnData object; its obs_names are used as gene identifiers below.
rank_csv_path = "../../Data/scGraphLLM/SMC23/rank_raw.csv"
raw_count = sc.read_csv(rank_csv_path)
raw_count

AnnData object with n_obs × n_vars = 5000 × 271

In [6]:
# Alias kept because later cells read `adata` (e.g. its obs_names).
adata = raw_count
# torch.tensor copies the matrix into a new tensor; assumes adata.X is a dense
# array rather than a sparse matrix — TODO confirm if the loader changes.
rank_values = adata.X
rank_mat = torch.tensor(rank_values)
rank_mat.shape

torch.Size([5000, 271])

In [7]:
# Node identifiers: the rows (obs) of the rank matrix are treated as genes here.
genes = adata.obs_names.to_list()
len(genes)

5000

In [8]:
# Random placeholder node features, one row per gene.
# Seed the global torch RNG so this embedding (and any later unseeded draws)
# is identical across Restart & Run All.
EMBEDDING_DIM = 128
torch.manual_seed(0)
node_embedding = torch.rand(len(genes), EMBEDDING_DIM)
node_embedding.shape

torch.Size([5000, 128])

In [16]:
# `neigborhood_size` (sic) is the spelling node_batching accepts (the call runs
# with it) — do not "fix" it here. Presumably -1 means "no neighbor cap"; TODO confirm.
# The SettingWithCopyWarning in the output is raised inside node_batching while it
# remaps edge endpoints, not by this cell.
batching_kwargs = dict(
    node_embedding=node_embedding,
    ranks=rank_mat,
    network=network,
    genes=genes,
    batch_size=64,
    neigborhood_size=-1,
    num_hops=2,
)
dataloader, dataset = node_batching(**batching_kwargs)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  edges['regulator.values'] = edges['regulator.values'].map(gene_to_node_index)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  edges['target.values'] = edges['target.values'].map(gene_to_node_index)


In [17]:
# Preview the first few mini-batches to sanity-check the shapes from node_batching.
# zip stops as soon as range is exhausted, so exactly N_PREVIEW_BATCHES batches
# are pulled from the loader (and 0 means none).
N_PREVIEW_BATCHES = 4
for _, batch in zip(range(N_PREVIEW_BATCHES), dataloader):
    # batch.n_id presumably holds the global node ids of the sampled subgraph
    # (PyG loader convention) — used to gather the matching rank rows.
    rank_embedding_batch = dataset.rank_embedding[batch.n_id]
    print("Batched graph embeddings:", batch.x.shape)
    print("Batched edge indices:", batch.edge_index.shape)
    print("Batched edge weights:", batch.edge_weight.shape)
    print("Batched rank embeddings:", rank_embedding_batch.shape)
    print("-----")
    # Invariants: one weight per edge column, one rank row per node row.
    assert batch.edge_index.shape[1] == batch.edge_weight.shape[0]
    assert rank_embedding_batch.shape[0] == batch.x.shape[0]

Batched graph embeddings: torch.Size([820, 128])
Batched edge indices: torch.Size([2, 15044])
Batched edge weights: torch.Size([15044])
Batched rank embeddings: torch.Size([820, 271])
-----
Batched graph embeddings: torch.Size([814, 128])
Batched edge indices: torch.Size([2, 14556])
Batched edge weights: torch.Size([14556])
Batched rank embeddings: torch.Size([814, 271])
-----
Batched graph embeddings: torch.Size([821, 128])
Batched edge indices: torch.Size([2, 14672])
Batched edge weights: torch.Size([14672])
Batched rank embeddings: torch.Size([821, 271])
-----
Batched graph embeddings: torch.Size([813, 128])
Batched edge indices: torch.Size([2, 14116])
Batched edge weights: torch.Size([14116])
Batched rank embeddings: torch.Size([813, 271])
-----


# Test attention network and contrastive loss

In [12]:
import torch
import torch.nn.functional as F
import torch.nn as nn

In [18]:
# Four random 20x20 inputs for the attention / contrastive-loss checks below.
# Seeded so these inputs are identical across kernel restarts.
torch.manual_seed(0)
a, b, c, d = (torch.rand(20, 20) for _ in range(4))

In [19]:
from scGraphLLM.MLP_modules import MLPAttention

In [20]:
# Build the attention module first, then feed it the four inputs stacked along dim 1.
attention = MLPAttention(20, 8)
stacked = torch.stack([a, b, c, d], dim=1)
out, attn_w = attention(stacked)
out.shape

torch.Size([20, 20])

In [21]:
# squeeze() with no dim drops every size-1 axis; it would also drop the leading
# axis if that were 1 — consider an explicit dim once the raw shape is confirmed.
attn_w = torch.squeeze(attn_w)
attn_w.shape

torch.Size([20, 4])

In [17]:
from scGraphLLM.MLP_modules import ContrastiveLoss

In [24]:
# Labels pair a with c and b with d (equal labels mark a matching pair).
label = [0, 1, 0, 1]
embeddings = [a, b, c, d]
criterion = ContrastiveLoss(margin=1, verbose=True)
loss = criterion(embeddings, label)
loss_value = loss.item()
print("Contrastive loss:", loss_value)

Number of matches: 2
Number of mismatches: 4
Contrastive loss: 0.08056469261646271


In [25]:
# a, b and c share a label; d is the odd one out.
label = [0, 0, 0, 1]
embeddings = [a, b, c, d]
criterion = ContrastiveLoss(margin=1, verbose=True)
loss = criterion(embeddings, label)
loss_value = loss.item()
print("Contrastive loss:", loss_value)

Number of matches: 3
Number of mismatches: 3
Contrastive loss: 0.07045094668865204


In [26]:
# b, c and d share a label; a is the odd one out.
label = [0, 1, 1, 1]
embeddings = [a, b, c, d]
criterion = ContrastiveLoss(margin=1, verbose=True)
loss = criterion(embeddings, label)
loss_value = loss.item()
print("Contrastive loss:", loss_value)

Number of matches: 3
Number of mismatches: 3
Contrastive loss: 0.1009301245212555
