In [1]:
import gc
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import torch
from latentmi import lmi
from tqdm.notebook import tqdm

from data_utils import *
from nanoTxformer.model import nanoTxformer
from nanoTxformer.train import train_model, generalization_loss
from nanoTxformer.utils import get_mean_pooled_embeddings

# Float32 matmul precision mode handed to PyTorch; 'medium' favours speed over
# full float32 accuracy on GPUs that support reduced-precision internal math.
MATMUL_PRECISION = 'medium'
torch.set_float32_matmul_precision(MATMUL_PRECISION)



In [2]:
# Accumulator for the parameter screen: one list per output column, appended to in
# lockstep so that `pd.DataFrame(d)` yields one row per (model, extraction layer).
d = {
    column: []
    for column in ("MI", "Layers", "Embedding size", "Heads", "Val. loss", "Extract layer")
}

# Dataset variant encoded in the file name (Q formatted to 3 decimals, rep as an int).
DATA_Q = 1.
DATA_REP = 0
DATA_PATH = '../../scaling_playground/data/PBMC_CITEseq_Q%.3f_rep%d.h5ad' % (DATA_Q, DATA_REP)

# Seed before splitting so the seen / held-out partition is identical after every
# kernel restart; otherwise the held-out MI numbers below are not reproducible.
# NOTE(review): assumes split_ad samples via NumPy's global RNG — TODO confirm in data_utils.
SEED = 0
np.random.seed(SEED)

ad = sc.read_h5ad(DATA_PATH)

# presumably `frac` is the share of cells assigned to `seen` — verify against split_ad
seen, held_out = split_ad(ad, frac=0.75)

In [3]:
# Hyper-parameter screen over encoder depth, attention heads and embedding width.
# For each configuration a fresh nanoTxformer is trained on one fixed subset of `seen`.
# latentmi then estimates MI between mean-pooled held-out embeddings and held-out
# protein counts, once for the default ("Last") embedding and once for layer_index=l-2.
# Each estimate appends one row to the column lists in `d`.
layers = [2, 6]
heads = [4]
embs = [128, 256]

TRAIN_FRAC = 0.25      # fraction of `seen` used to train every model in the screen
MAX_EPOCHS = 10**4     # epoch budget passed to train_model
LMI_BATCH_SIZE = 2048  # batch size for lmi.estimate
SEED = 0

# Draw the training subset ONCE so every configuration is trained on identical data;
# re-splitting inside the loop confounded architecture effects with data resampling.
# NOTE(review): assumes split_ad samples via NumPy's global RNG — TODO confirm.
np.random.seed(SEED)
ad1, _ = split_ad(seen, frac=TRAIN_FRAC)
print(f"training cells per model: {len(ad1)}")

# MI target; identical for every configuration, so look it up once.
protein_counts = held_out.obsm['protein_counts']


def _record_result(mi_values, n_layers, embed_size, n_heads, val_loss, extract_layer):
    """Append one screen row to the parallel column lists in `d`.

    `mi_values` is the first value returned by `lmi.estimate`; it is reduced to a
    single number with `np.nanmean` (NaNs ignored).
    """
    d["MI"].append(np.nanmean(mi_values))
    d["Layers"].append(n_layers)
    d["Embedding size"].append(embed_size)
    d["Heads"].append(n_heads)
    d["Val. loss"].append(val_loss)
    d["Extract layer"].append(extract_layer)


for l in layers:
    for h in heads:
        for e in embs:
            # Re-seed per configuration so weight init (and any torch/NumPy randomness
            # inside the MI estimator) is reproducible across re-runs.
            torch.manual_seed(SEED)
            np.random.seed(SEED)

            model = nanoTxformer(ad1, embed_size=e, num_heads=h, num_encoder_layers=l).cuda()
            train_losses, val_losses = train_model(model, ad1, epochs=MAX_EPOCHS)

            # Store a scalar (validation loss at the end of training) rather than the whole
            # curve, so the "Val. loss" column stays numeric in the DataFrame / CSV.
            final_val_loss = float(val_losses[-1]) if len(val_losses) else np.nan

            # Insertion order matters: 'Last' rows are appended before 'Second to last'.
            embeddings = {
                "Last": get_mean_pooled_embeddings(model, held_out).cpu(),
                # presumably layer_index is 0-based over encoder layers, so l-2 is the
                # penultimate layer — TODO confirm in nanoTxformer.utils
                "Second to last": get_mean_pooled_embeddings(
                    model, held_out, layer_index=l - 2
                ).cpu(),
            }

            for extract_layer, emb in embeddings.items():
                pmis, _, _ = lmi.estimate(emb, protein_counts,
                                          quiet=True, batch_size=LMI_BATCH_SIZE)
                _record_result(pmis, l, e, h, final_val_loss, extract_layer)

            print(f"layers={l} heads={h} embed={e} val_loss={final_val_loss:.4f} "
                  f"MI(last)={d['MI'][-2]:.3f} MI(second to last)={d['MI'][-1]:.3f}")

            # Drop the trained model before the next iteration builds a new one; otherwise
            # the previous model is still resident on the GPU while the next is constructed.
            del model, embeddings, emb, train_losses, val_losses
            gc.collect()
            torch.cuda.empty_cache()


28517
epoch 2/10000 (batch 1000) - train loss: 7.4290, val loss: 7.4170
epoch 3/10000 (batch 2000) - train loss: 7.0215, val loss: 6.8082
epoch 5/10000 (batch 3000) - train loss: 5.8815, val loss: 5.7896
epoch 6/10000 (batch 4000) - train loss: 4.7728, val loss: 4.4709
epoch 8/10000 (batch 5000) - train loss: 3.1957, val loss: 3.1783
epoch 9/10000 (batch 6000) - train loss: 2.3811, val loss: 2.2802
epoch 10/10000 (batch 7000) - train loss: 1.9330, val loss: 1.8378
epoch 12/10000 (batch 8000) - train loss: 1.6496, val loss: 1.6363
epoch 13/10000 (batch 9000) - train loss: 1.5470, val loss: 1.5314
epoch 15/10000 (batch 10000) - train loss: 1.4236, val loss: 1.4579
epoch 16/10000 (batch 11000) - train loss: 1.3961, val loss: 1.3934
epoch 17/10000 (batch 12000) - train loss: 1.3510, val loss: 1.3388
epoch 19/10000 (batch 13000) - train loss: 1.3037, val loss: 1.2995
epoch 20/10000 (batch 14000) - train loss: 1.2698, val loss: 1.2687
epoch 22/10000 (batch 15000) - train loss: 1.2411, val lo

In [5]:
# Persist the screen results. Create the output directory first: if it is missing,
# to_csv raises and the outcome of the (long) training sweep would be lost.
# The default index column is kept so the CSV layout matches earlier runs.
results_path = Path('../results/Tx_param_screen.csv')
results_path.parent.mkdir(parents=True, exist_ok=True)

screen_results = pd.DataFrame(d)
screen_results.to_csv(results_path)
screen_results