In [50]:
import os
import gc
import argparse
import json
import random
import math
import random
from functools import reduce
import numpy as np
import pandas as pd
from scipy import sparse
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from performer_pytorch import PerformerLM
import scanpy as sc
import anndata as ad
from utils import *

## Create guidance mask based on scout results

Set:
- `mask_name` to the name of the guidance mask
- `n` to the number of scouts you wish to use to create the guidance mask

In [51]:
# Load the pretraining checkpoint for inspection.
# Tensors are mapped to CPU, matching the other checkpoint loads in this
# notebook, so loading does not depend on the device the file was saved from.
path = "ckpts/recount3_mouse_15117_bulk_pretrain_10.pth"
ckpt = torch.load(path, map_location = torch.device('cpu'))

In [52]:
# Number of entries stored in the checkpoint's state dict.
len(ckpt['model_state_dict'])

91

In [57]:
mask_name = "mask17guidepred" # set a new name for each guidance mask


def create_mask(cousins, begins, model_checkpoint, save_name=None):
    """Build a guidance mask from scout checkpoints and save it under ``masks/``.

    For every parameter name in ``begins[0]`` that at least one scout pair
    contains, the mask entry is the mean (over scouts) squared difference
    between the trained scout tensor and a reference tensor, min-max
    normalised to the range 0-1. The reference is the tensor of that name in
    ``model_checkpoint`` when present; otherwise the scouts' own starting
    tensors are used.

    Args:
        cousins: list of trained scout state dicts (name -> tensor).
        begins: list of starting state dicts, paired by position with
            ``cousins``.
        model_checkpoint: path of a checkpoint file whose
            ``'model_state_dict'`` entry supplies the reference tensors.
        save_name: file stem for the saved mask. Defaults to the
            module-level ``mask_name``.

    Returns:
        dict mapping parameter name -> normalised mask tensor.

    Raises:
        ValueError: if no scouts are given or ``cousins`` and ``begins``
            differ in length.
    """
    if len(cousins) != len(begins):
        raise ValueError(
            f"cousins and begins must pair up one-to-one, "
            f"got {len(cousins)} and {len(begins)}"
        )
    if len(begins) == 0:
        raise ValueError("at least one scout is required to build a mask")

    def learning_spread(reference, trained):
        # this code uses squared difference, but other options like absolute difference can work too
        ls = torch.mean((trained - reference) ** 2, 0)

        # this code performs 0-1 normalization, but other normalization methods can work too
        min_ls = torch.min(ls)
        spread = torch.max(ls) - min_ls
        if spread == 0:
            # Every element moved by the same amount: nothing to rank.
            return torch.zeros_like(ls)
        return (ls - min_ls) / spread

    def _compute_mask(model_dict, cousin_dict_list, begin_dict_list):
        # Group, per parameter name, the trained and the starting tensors of
        # every scout that has that name in both of its state dicts.
        trained_by_name = {}
        start_by_name = {}
        for cousin_dict, begin_dict in zip(cousin_dict_list, begin_dict_list):
            for name, start in begin_dict.items():
                if name in cousin_dict:
                    trained_by_name.setdefault(name, []).append(cousin_dict[name])
                    start_by_name.setdefault(name, []).append(start)

        out_mask = {}
        # Only names known to the first scout are emitted, in its key order.
        for name in begin_dict_list[0]:
            if name not in trained_by_name:
                continue
            if name in model_dict:
                reference = model_dict[name]
            else:
                # Name is absent from the reference checkpoint: compare each
                # scout against its own starting point instead.
                reference = torch.stack(start_by_name[name])
            # Cast to float in both branches; torch.mean rejects integer
            # tensors, which previously broke the fallback branch.
            out_mask[name] = learning_spread(
                reference.float(), torch.stack(trained_by_name[name])
            )
        return out_mask

    # Only the stored state dict is needed; the model is never instantiated.
    ckpt = torch.load(model_checkpoint,  map_location = torch.device('cpu'))
    model = ckpt['model_state_dict']

    mask = _compute_mask(model, cousins, begins)

    if save_name is None:
        save_name = mask_name
    os.makedirs('masks', exist_ok=True)
    torch.save(mask, f'masks/{save_name}.pt')
    return mask

In [58]:
model_checkpoint = "ckpts/recount3_mouse_15117_bulk_pretrain_continuation_1_bs_3_lr10e4_16.pth"
cousins = []
begins = []
n = 17 # set to number of scouts you want to use to create guidance mask
# Scout checkpoint files are numbered from 1, so scouts 1..n (inclusive) are
# loaded; `n` is therefore exactly the number of scouts used.
for i in range(1, n + 1):
    path = f"ckpts/scout_pairs_{i}_best.pth"
    begin_path = f"ckpts/scout_pairs_{i}_begin.pth"
    # Each scout contributes a (trained, starting) pair of state dicts.
    cousins.append(torch.load(path, map_location = torch.device('cpu'))["model_state_dict"])
    begins.append(torch.load(begin_path, map_location = torch.device('cpu'))["model_state_dict"])


In [59]:
# Compute the guidance mask from the loaded scout pairs; this also writes it to disk.
mask = create_mask(
    cousins=cousins,
    begins=begins,
    model_checkpoint=model_checkpoint,
)

## Check the output

In [60]:
# Reload the mask that was just written (path derived from `mask_name`, so the
# check follows whatever name was configured) and count its entries.
len(torch.load(f"masks/{mask_name}.pt", map_location = torch.device('cpu')))

97

In [None]:
# Check proportion of guide values that are greater than 0.5 (arbitrary)
guide = mask['to_out.conv1.weight']
# Divide by the true element count rather than counting values above a sentinel.
(guide > 0.5).sum().item() / guide.numel()