In [None]:
# versioning issues, run this cell once as a temporary fix

! pip install numpy==1.24.4
! pip install torch==2.0.1 torchvision==0.15.2 torchtext==0.15.2
! pip install scgpt


Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting numpy==1.24.4
  Downloading numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Downloading numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m181.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pyro-ppl 1.9.1 requires torch>=2.0, but you have torch 1.8.1 which is incompatible.
pytorch-lightning 1.9.5 requires torch>=1.10.0, but you have torch 1.8.1 which is incompatible.
scgpt 0.2.4 requires torch>=1.13.0

In [1]:
# Standard library imports
import os
import sys
import json
import math
import random
import argparse
from collections import defaultdict
from pathlib import Path
import datetime
from typing import Dict, List, Tuple, Set, Union, Optional

# Environment variables that are read while a library loads / initialises must
# be in place BEFORE that library is imported:
#   * CUDA_VISIBLE_DEVICES is read when CUDA is initialised, which any of the
#     imports below may trigger.
#   * OpenCV reads OPENCV_IO_MAX_IMAGE_PIXELS when the library is loaded, so
#     setting it after `import cv2` has no effect.
GPU_ID = "7"  # single physical GPU index to use; replace as needed
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(2**40)

# Third-party imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset
import torchvision
from torchvision import transforms
import wandb
import scanpy as sc
from tqdm import tqdm
import sklearn.model_selection
from PIL import Image
import seaborn as sns
from numba import njit, prange
from scipy.stats import wasserstein_distance
from scipy.spatial import cKDTree
import tangram as tg
import imageio.v3 as iio
import cv2
import scgpt
import timm
from einops import rearrange
from torch import einsum
import torch.nn.utils as U

# Project code. The cells below rely on it for ViT_UNI, HistSampleDataset,
# UnpairedDataset, HEGen, Discriminator, HEDecoder, StandardDecoder and
# TransferModel.
# NOTE(review): a star import hides where names come from and can shadow the
# imports above; prefer importing the needed names explicitly.
from schaf_method import *

# Configure system settings
Image.MAX_IMAGE_PIXELS = 933120000  # Allow loading large images

if torch.cuda.is_available():
    # With the CUDA_VISIBLE_DEVICES mask in effect the selected GPU is the only
    # visible one and is addressed as index 0. If CUDA had already been
    # initialised by the environment before the mask was set, all GPUs are
    # still visible and the physical index is the right one.
    _gpu_index = int(GPU_ID) if torch.cuda.device_count() > int(GPU_ID) else 0
    DEVICE = torch.device(f"cuda:{_gpu_index}")
else:
    DEVICE = torch.device("cpu")
device = DEVICE

NUM_WORKERS = 6 if torch.cuda.is_available() else 2
PIN_MEMORY = torch.cuda.is_available()


In [2]:
# Create example data: four folds, each with a synthetic single-cell
# expression matrix and a synthetic H&E image with random cell coordinates.

# Seed the RNGs so that Restart & Run All reproduces the same example.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Real gene symbols are borrowed from a public dataset so that scGPT can match
# them against its vocabulary.
reference = sc.datasets.pbmc3k()
real_gene_names = reference.var_names.tolist()

HOLD_OUT_FOLD = 1
n_cells = 1000      # cells per fold (also the number of H&E coordinates)
n_genes = 2000      # genes per synthetic expression matrix
image_size = 1000   # synthetic H&E images are image_size x image_size pixels

fold_to_he = {}
fold_to_he_xs = {}
fold_to_he_ys = {}
fold_to_sc = {}
for fold in range(4):
    # Single-cell side: negative-binomial counts, log1p-transformed in place.
    expression_matrix = np.random.negative_binomial(5, 0.3, size=(n_cells, n_genes))
    adata_sc = sc.AnnData(X=expression_matrix)
    adata_sc.obs.index = [f"cell_{i}" for i in range(n_cells)]
    adata_sc.var.index = real_gene_names[:n_genes]
    adata_sc.obs['cluster'] = np.random.choice([0, 1, 2, 3], n_cells)
    adata_sc.var['gene_col'] = list(adata_sc.var.index)
    sc.pp.log1p(adata_sc)
    fold_to_sc[fold] = adata_sc

    # Histology side: random RGB image scaled to [0, 1]. The upper bound of
    # randint is exclusive, so 256 is needed to cover the full uint8 range.
    he_image = np.random.randint(0, 256, size=(image_size, image_size, 3), dtype=np.uint8) / 255.
    fold_to_he[fold] = he_image
    # Every image gets n_cells random (x, y) coordinates inside the image.
    fold_to_he_xs[fold] = np.random.randint(0, image_size, size=n_cells)
    fold_to_he_ys[fold] = np.random.randint(0, image_size, size=n_cells)

# Embed both modalities of every fold: scGPT for the expression matrices and
# ViT_UNI for image patches sampled at the cell coordinates.

# Directory of the pretrained scGPT checkpoint. Set the SCGPT_MODEL_DIR
# environment variable to point elsewhere; the default is the original
# hard-coded location.
model_dir = os.environ.get(
    "SCGPT_MODEL_DIR",
    "/storage/ccomiter/htapp_supervise/new_schaf_experiment_scripts/final_new_schaf_start_jan2324/scgpt_model/scGPT_human/",
)

embedding_maker = ViT_UNI().to(device)
embedding_maker = embedding_maker.eval()
fold_to_he_embed = {}
fold_to_sc_embed = {}
for fold in range(4):
    # --- single-cell embeddings ---
    embed_data = scgpt.tasks.embed_data(fold_to_sc[fold], model_dir, gene_col='gene_col', batch_size=64)
    new_adata = sc.AnnData(X=embed_data.obsm['X_scGPT'], obs=embed_data.obs)
    fold_to_sc_embed[fold] = new_adata

    # --- histology embeddings ---
    the_ds = HistSampleDataset(
        fold_to_he[fold],
        fold_to_he_xs[fold],
        fold_to_he_ys[fold],
        10,
    )
    # shuffle must stay off: the x/y coordinates are attached to the embeddings
    # by position further down.
    the_dl = DataLoader(
        the_ds,
        batch_size=64,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )

    with torch.no_grad():
        embeds = []
        for batch in tqdm(the_dl):
            res = embedding_maker(batch.to(device).float()).cpu().numpy()
            embeds.append(res)
        torch.cuda.empty_cache()
    # Stack the per-batch arrays along the sample axis.
    embeds = np.concatenate(embeds, axis=0)
    embed_adata = sc.AnnData(X=embeds,)
    embed_adata.obs['x'] = fold_to_he_xs[fold]
    embed_adata.obs['y'] = fold_to_he_ys[fold]
    fold_to_he_embed[fold] = embed_adata


import anndata as ad

# Tag every embedding with the fold it came from. `.obs` is modified in place,
# so the dictionaries do not need to be re-assigned. Each dictionary must only
# ever hold its own modality: writing the scGPT embeddings into
# `fold_to_he_embed` would silently replace the histology embeddings that the
# splits below are built from.
for k, v in fold_to_he_embed.items():
    v.obs['fold'] = k
for k, v in fold_to_sc_embed.items():
    v.obs['fold'] = k

HOLD_OUT_FOLD = 1

# Train on every fold except the held-out one; the held-out fold is the test set.
train_he_embeds = ad.concat([fold_to_he_embed[k] for k in range(4) if k != HOLD_OUT_FOLD])
train_sc_embeds = ad.concat([fold_to_sc_embed[k] for k in range(4) if k != HOLD_OUT_FOLD])

test_he_embeds = fold_to_he_embed[HOLD_OUT_FOLD]
test_sc_embeds = fold_to_sc_embed[HOLD_OUT_FOLD]
# NOTE(review): the histology splits now carry the ViT_UNI embeddings rather
# than scGPT embeddings; confirm HEGen / HEDecoder are sized for that
# dimensionality.

# Data loaders over the (unpaired) histology and single-cell embeddings.
loader_kwargs = dict(batch_size=64, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

train_he_dl = DataLoader(UnpairedDataset(train_he_embeds.X, is_hist=1), shuffle=True, **loader_kwargs)
train_sc_dl = DataLoader(UnpairedDataset(train_sc_embeds.X, is_hist=0), shuffle=True, **loader_kwargs)
# Evaluation loaders keep dataset order so that predictions can be matched
# row-by-row with `test_he_embeds.obs` / `test_sc_embeds.obs`.
test_he_dl = DataLoader(UnpairedDataset(test_he_embeds.X, is_hist=1), shuffle=False, **loader_kwargs)
test_sc_dl = DataLoader(UnpairedDataset(test_sc_embeds.X, is_hist=0), shuffle=False, **loader_kwargs)

# Networks (all defined in schaf_method).
generator = HEGen()
discriminator = Discriminator()
hist_decoder = HEDecoder()
# The single-cell decoder outputs one value per gene of the expression data.
sc_decoder = StandardDecoder(fold_to_sc[HOLD_OUT_FOLD].n_vars)

lr = 1e-4
# Initialize optimizers
gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
hist_dec_optimizer = optim.Adam(hist_decoder.parameters(), lr=lr)
sc_dec_optimizer = optim.Adam(sc_decoder.parameters(), lr=lr)

# Supervised pre-training of the single-cell decoder: scGPT embedding ->
# (log1p) expression, using every fold except the held-out one.
train_decoder_dl = DataLoader(
    ConcatDataset([
        TensorDataset(
            torch.from_numpy(fold_to_sc_embed[k].X),
            torch.from_numpy(fold_to_sc[k].X),
        )
        for k in fold_to_sc_embed
        if k != HOLD_OUT_FOLD
    ]),
    batch_size=128, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
)


criter = nn.MSELoss()
num_epochs = 1
# Moving the model and switching to train mode only needs to happen once.
sc_decoder = sc_decoder.train().to(device)
decoder_epoch_losses = []  # mean MSE per epoch, kept for later inspection
progress = tqdm(range(num_epochs))
for epoch in progress:
    epoch_loss = 0.0
    n_batches = 0

    for latent, trans in train_decoder_dl:
        latent = latent.to(device)
        trans = trans.to(device)
        sc_dec_optimizer.zero_grad()
        predicted_labels = sc_decoder(latent.float())

        this_batch_loss = criter(
            predicted_labels,
            trans.float(),
        )

        this_batch_loss.backward()
        sc_dec_optimizer.step()

        epoch_loss += this_batch_loss.item()
        n_batches += 1

    # Report the mean batch loss of the epoch on the progress bar.
    mean_loss = epoch_loss / max(n_batches, 1)
    decoder_epoch_losses.append(mean_loss)
    progress.set_postfix(loss=mean_loss)
 


scGPT - INFO - match 1569/2000 genes in vocabulary of size 60697.


Embedding cells: 100%|██████████| 16/16 [00:30<00:00,  1.91s/it]
100%|██████████| 16/16 [00:12<00:00,  1.32it/s]


scGPT - INFO - match 1569/2000 genes in vocabulary of size 60697.


Embedding cells: 100%|██████████| 16/16 [00:29<00:00,  1.83s/it]
100%|██████████| 16/16 [00:10<00:00,  1.51it/s]


scGPT - INFO - match 1569/2000 genes in vocabulary of size 60697.


Embedding cells: 100%|██████████| 16/16 [00:28<00:00,  1.78s/it]
100%|██████████| 16/16 [00:10<00:00,  1.48it/s]


scGPT - INFO - match 1569/2000 genes in vocabulary of size 60697.


Embedding cells: 100%|██████████| 16/16 [00:29<00:00,  1.82s/it]
100%|██████████| 16/16 [00:10<00:00,  1.48it/s]
100%|██████████| 1/1 [00:00<00:00,  1.18it/s]


In [3]:
# Adversarial training: the generator maps histology embeddings into the
# single-cell latent space, the discriminator tries to tell real single-cell
# latents from generated ones, and the histology decoder reconstructs the
# histology embedding from the generated latent.
generator = generator.to(device)
discriminator = discriminator.to(device)
hist_decoder = hist_decoder.to(device)
sc_decoder = sc_decoder.to(device)

recon_criter = nn.MSELoss()
num_epochs = 1
beta = .5  # weight of the adversarial term relative to the reconstruction term
# Two-logit targets: real single-cell latents are "source", generated ones "target".
source_label, target_label = [1., 0.], [0., 1.]

for epoch in tqdm(range(num_epochs)):
    generator = generator.train()
    discriminator = discriminator.train()

    for _id, (latent_batch, hist_batch) in enumerate(zip(train_sc_dl, train_he_dl)):
        gen_optimizer.zero_grad()
        disc_optimizer.zero_grad()
        hist_dec_optimizer.zero_grad()
        latent_batch = latent_batch.to(device)
        hist_batch = hist_batch.to(device)

        #### first part: discriminator update
        # The generator output is only an input here, so no graph is needed.
        with torch.no_grad():
            hist_encoded = generator(hist_batch.float())

        encodeds = torch.cat((latent_batch, hist_encoded), dim=0)
        discrim_labels = torch.tensor(
            [source_label] * latent_batch.shape[0]
            + [target_label] * hist_encoded.shape[0],
            device=device,
        )

        pred_discrim_labels = discriminator(encodeds.float())
        batch_discrim_loss = F.binary_cross_entropy_with_logits(
            pred_discrim_labels, discrim_labels,
        )

        batch_discrim_loss.backward()
        disc_optimizer.step()

        #### second part: generator + histology decoder update
        # Freeze the discriminator so only the generator/decoder receive
        # gradients. The finally-clause guarantees it is unfrozen again even if
        # this step raises, so a re-run of the cell does not start from a
        # frozen discriminator.
        discriminator.requires_grad_(False)
        try:
            hist_encoded = generator(hist_batch.float())
            hist_discrim_preds = discriminator(hist_encoded)
            he_decoded = hist_decoder(hist_encoded)

            # The generator is rewarded when its output is classified as "source".
            discrim_labels = torch.tensor(
                [source_label] * hist_encoded.shape[0], device=device,
            )
            he_gen_loss = F.binary_cross_entropy_with_logits(
                hist_discrim_preds, discrim_labels,
            )

            he_recon_loss = recon_criter(
                hist_batch.float(),
                he_decoded,
            )

            together_loss = beta*he_gen_loss + he_recon_loss
            together_loss.backward()
            gen_optimizer.step()
            hist_dec_optimizer.step()
        finally:
            # undo the freeze above
            discriminator.requires_grad_(True)


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:02<00:00,  2.05s/it]


In [4]:

# Infer expression for the held-out fold: histology embedding -> generator ->
# single-cell decoder.
tm = TransferModel(generator, sc_decoder).eval().to(device)

# Iterate the held-out embeddings strictly in dataset order. The predictions
# are labelled with `test_he_embeds.obs` below, which is only valid if row i of
# the output corresponds to row i of the dataset, so a shuffling loader must
# not be used here.
infer_dl = DataLoader(
    test_he_dl.dataset,
    batch_size=test_he_dl.batch_size,
    shuffle=False,
    num_workers=test_he_dl.num_workers,
    pin_memory=test_he_dl.pin_memory,
)

# infer
inferred_transcripts = []
with torch.no_grad():
    for images in infer_dl:
        # Same float cast as used for the generator input during training.
        outputs = tm(images.to(device).float())
        inferred_transcripts.append(outputs.cpu().numpy())

inferred_adata = sc.AnnData(X=np.concatenate(inferred_transcripts, axis=0))
# Copy so later edits to the predictions' metadata do not alter test_he_embeds.
inferred_adata.obs = test_he_embeds.obs.copy()
