In [1]:
import anndata as ad
import umap

import pronto
import warnings
warnings.filterwarnings("ignore", category=pronto.warnings.ProntoWarning)


import pandas as pd
import numpy as np
from scipy import sparse
import copy
import time
import sys
import os
import pickle
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler

from torcheval.metrics.functional import multilabel_accuracy

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style(style='whitegrid')
sns.set_context(context='notebook')
plt.rc('figure', autolayout=True)
plt.rc(
    'axes',
    labelweight='bold',
    labelsize='large',
    titleweight='bold',
    titlesize=9,
    linewidth=4
    )

%matplotlib inline

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


## Load the data and Cell Ontology. 

Great-lakes does not have internet access, so we pull in the data outside first, then load it here as an AnnData object. The Cell Ontology also needs to be loaded for access. 

In [3]:
# change into scratch directory (where data is saved)
#os.chdir('/scratch/welchjd_root/welchjd99/fujoshua')
os.chdir('/nfs/turbo/umms-welchjd/mccell')

FileNotFoundError: [Errno 2] No such file or directory: '/nfs/turbo/umms-welchjd/mccell'

### Load Single Dataset


In [5]:
#adata = ad.read_h5ad('small_1044_900') # 42000 cells
#adata = ad.read_h5ad('leaf_list_in_cl') # 658000 cells
#adata = ad.read_h5ad('leaf_list_leukocyte_24Aug') # 516000 cells, 60k genes
#adata = ad.read_h5ad('leaf_list_leukocyte_6Sep_coding_genes') # 516k cells, 20k genes
#adata = ad.read_h5ad('leaf_list_hematopoietic_14Sep_coding_genes') # 549k cells, 20k genes, upper level hematopoietic (0000988)
#adata = ad.read_h5ad('1044_624_895_27Sep_coding_genes') # 230k cells, only CL1044 (leaf), 624 (internal), and 895 (leaf) including leaf and internal node
#adata = ad.read_h5ad('13Oct_2int_3leaf') #379k cells, 2 internal (576,624), 3 leaf (1044,895,2057)

#adata = ad.read_h5ad('24Oct_hematopoietic_cells_p2') # 472k cells, part of the full hema data set
adata = ad.read_h5ad('24Oct_hematopoietic_cells_p2_subsample') # 47k cells, part of the full hema data set

### Load Multiple Datasets and combine

In [None]:
adata1 = ad.read_h5ad('24Oct_hematopoietic_cells_p1') # 1.98 million cells
adata2 = ad.read_h5ad('24Oct_hematopoietic_cells_p2') # 472k cells


In [None]:
adata = ad.concat([adata1,adata2]) #2.45 million

In [None]:
# delete adata1 and adata2 to save memory
del adata1
del adata2

In [None]:
adata

In [None]:
adata1

In [None]:
adata2

In [6]:
os.getcwd()

'/Users/josh.fuchs/My Drive/Personal/michigan'

### Load the Cell Ontology

You can visualize the ontology using https://www.ebi.ac.uk/ols4/ontologies/cl

And you can download the ontology file here: https://obofoundry.org/ontology/cl.html

In [None]:
os.chdir('/home/fujoshua/cell_classification')

In [7]:
cl = pronto.Ontology.from_obo_library('cl.owl')


## Checking memory allocations


In [None]:
print(torch.cuda.get_device_properties(0).total_memory*1e-9)
print(torch.cuda.memory_reserved()*1e-9)
print(torch.cuda.memory_allocated()*1e-9)
#print(torch.cuda.mem_get_info())


In [8]:
def sizeof_fmt(num, suffix='B'):
    ''' by Fred Cirera,  https://stackoverflow.com/a/1094933/1870254, modified'''
    for unit in ['','Ki','Mi','Gi','Ti','Pi','Ei','Zi']:
        if abs(num) < 1024.0:
            return "%3.1f %s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f %s%s" % (num, 'Yi', suffix)


In [9]:
for name, size in sorted(((name, sys.getsizeof(value)) for name, value in list(
                          locals().items())), key= lambda x: -x[1])[:10]:
    print("{:>30}: {:>8}".format(name, sizeof_fmt(size)))


                         adata: 806.6 MiB
                  LabelEncoder:  1.0 KiB
                           _i1:  966.0 B
                           _i4:  913.0 B
                           _i5:  913.0 B
                      datetime:  416.0 B
                            _i:  371.0 B
                           _i8:  371.0 B
                           _i9:  260.0 B
                           _oh:  232.0 B


In [None]:
del adata

## Data and Ontology Preprocessing

To prepare the data for modeling, we need to perform some preprocessing on the data and the ontology. We'll use three functions to make this happen. Full descriptions of these functions can be found in the functions. 


- set_internal_node_values: build a dictionary to set which internal nodes are to be used in the loss calculation for internal nodes in the data

- build_parent_mask: builds a masking matrix to use for masking internal node loss values

- preprocess_data_ontology: this function encodes the AnnData object, splits apart the target values and primary data, and calculates some important variables from the Cell Ontology for later use 

- transform_data: transforms the data with log(1+x)

- split_format_data: splits the data into train and validation sets, and moves the variables to PyTorch tensors 


In [10]:
def set_internal_node_values(internal_values,all_parent_nodes):
    '''
    Creates a dictionary where each key is an internal cell type and the values are the cell types
    we want to include when calculating the loss. We do not want to consider direct descendents of the
    internal cell type, so those are removed. 
    
    In other words, when calculating the loss for an internal node, we want to include all internal 
    nodes in the ontology EXCEPT those that are direct descendants of the target internal node. 
    
    Parameters
    ----------
    internal_values : list
        list of internal values that are included in the dataset
            
    all_parent_nodes : list
        from the dataset, a list of parent nodes in the ontology. Used to remove portions of
        the Ontology where we do not have child data
    
    Returns
    ----------
    parent_dict : dictionary
        keys are internal_values and values are all internal cell ontology terms EXCEPT descendents 
        of the internal value. The internal value is always included
    '''
    
    parent_dict = {}

    # loop through each value to calculate the values to include in parent_dict for that
    # internal value
    for internal_node in internal_values:
        # 1) get the children of this internal_node
        child_nodes = []
        for term in cl[internal_node].subclasses(distance=None,with_self=False).to_set():
            child_nodes.append(term.id)
        
        # 2) remove those values from all_parent_nodes
        cell_types_to_include = [x for x in all_parent_nodes if x not in child_nodes]
        
        # 3) create dictionary
        parent_dict[internal_node] = cell_types_to_include
    
    return(parent_dict)


In [11]:
def build_parent_mask(leaf_values,internal_values,ontology_df,parent_dict):
    '''
    Function to build a masking matrix for use when calculating the internal loss
    
    Uses parent_dict to denote, for internal cell types, which parents to include in the loss
    calculation. 
    
    Parameters
    -------
    leaf_values : list
        list composed of all leaf values included in the dataset
        includes internal nodes that do not have sub-values in the dataset, and thus are
        treated an leaf nodes

    internal_values : list
        list composed of interanal nodes in the dataset

    ontology_df : pandas dataframe
        pandas dataframe where indices (rows) are all leaf and parent cell IDs from the portion of 
        the ontology being queried, and columns are all leafs in portion of ontology being queried. 
        
        Dataframe is binary. For each parent node, element = 1 if parent node is an ancestor
        of corresponding leaf node.
        
    parent_dict : dictionary
        keys are internal_values and values are all internal cell ontology terms EXCEPT descendents 
        of the internal value. The internal value is always included
    
    Returns
    -------
    cell_parents_mask : tensor
        tensor of shape ik, where i = parent IDs and k = each cell type in the dataset
        binary tensor where 1 means for that cell type, that parent ID will be included
        in the internal loss calculation
        and 0 means for that cell type, that parent ID is excluded in the internal loss
        calculation
    
    '''
    num_leafs = len(leaf_values)
    num_parents = ontology_df.shape[0]

    # internal_values are included as column values AND rows


    # for cell_parents_to_include, each column is a cell type included in the
    # dataset, so it is length = len(leaf_values) + len(internal_values)
    # the row values are the total number of parents included for the dataset 
    # for each internal value, we need to pick (1/0) if we include that parent
    # for the loss. For this, we reference parents_dict
    # WHAT is the order of the cell IDs for the rows???? This is important
    # This needs to match what we are already doing later, so let's go figure that out FIRST. 

    # for the leaf values, we want to include ALL parents in the 
    # loss calculation. So, we initialize the tensor as a ones tensor
    # based on the number of leaf values and the number of parents
    cell_parents_mask = torch.ones(num_parents,num_leafs)

    # now we can deal with the internal values. For these, we will not
    # include all parents. We will use parent_dict to select which to include


    # first, get a list of all the parents. The ordering of this list
    # is used later to propogate probabilities up the ontology.
    list_of_parents = ontology_df.index.tolist()

    # now, we need to loop through each internal value
    # internal_values is ordered as -9999 + n
    # this will be helpful later when we need to pull these values out. 
    # so the columns here are ordered at 0 to (number of leaf values), then -9999
    # to (number of internal values)

    for cell_id in internal_values:
        # get the list of parent cell IDs we want to include for this
        # particular internal_values
        parent_list_for_cell = parent_dict[cell_id]

        # loop through the parent_list_for cell, create a new binary list where
        # list is 1 if the parent is in the list_of_parents, otherwise 0
        parent_binary_list = [1 if parent in parent_list_for_cell else 0 for parent in list_of_parents]

        # convert the list to a tensor and reshape for concatenation
        parent_binary_tensor = torch.tensor(parent_binary_list).reshape(-1,1)

        # append to cell_parents_to_include. 
        # we append along columns
        cell_parents_mask = torch.cat((cell_parents_mask,parent_binary_tensor),1)

    return(cell_parents_mask)



In [12]:
def preprocess_data_ontology(adata, target_column,upper_limit = None, cl_only = False, include_leafs = False):
    '''
    This function perfroms preprocessing on ann AnnData object to prepare it for modelling. It will encode the 
    target column and returns x_data and y_data for modelling
    
    This function also preprocesses the ontology to build a pandas dataframe that can be used to 
    calculate predicted probabilities. This will enable simple matrix multiplication to calculate
    probabilities and loss.
    
    Can have an upper limit to the ontology if upper_limit is set
    
    
    Assumes there is an active census object already open as cl. 

    
    
    Parameters
    ----------
    adata : AnnData Object
        existing AnnData object to perform processing on 
        
    target_column : string
        string of target column (from cell metadata) to encode
     
    upper_limit : string
        if you want to specify an upper limit in the ontology, set this to 
        the upper limit (inclusive)
        Default: None (no limit to ontology)
        
    cl_only : boolean
        option to only include the Cell Ontology (CL) in the dataframe
        True means only those cell IDs that start with CL are included
        Default: False
        
    include_leafs : boolean
        option to include leafs in the list of parent cell IDs
        Default is False because we are calculating the leaf loss differently
        Default: False
        
    Returns
    -------
    x_data : SciPy Matrix
        scipy sparse CSR matrix
    
    y_data : Series
        Pandas Series of encoded target values
        
    mapping_dict : Dictionary
        dictionary mapping the Cell Ontology IDs (keys) to the encoded values (values)
        Values >= 0 are leaf nodes
        Values < 0 are internal nodes

    leaf_values : list
        list composed of all leaf values included in the dataset
        includes internal nodes that do not have sub-values in the dataset, and thus are
        treated an leaf nodes

    internal_values : list
        list composed of interanal nodes in the dataset

    ontology_df : pandas dataframe
        pandas dataframe where indices (rows) are all leaf and parent cell IDs from the portion of 
        the ontology being queried, and columns are all leafs in portion of ontology being queried. 
        
        Dataframe is binary. For each parent node, element = 1 if parent node is an ancestor
        of corresponding leaf node.
        
    parent_dict : dictionary
        keys are internal_values and values are all cell ontology terms within the same distance
        from the top node. 
        
    cell_parent_mask : tensor
        tensor of shape ik, where i = parent IDs and k = each cell type in the dataset
        binary tensor where 1 means for that cell type, that parent ID will be included
        in the internal loss calculation
        and 0 means for that cell type, that parent ID is excluded in the internal loss
        calculation

    
    '''
    
    # select the labels. 
    labels = adata.obs
    
    # encode the target column
    #lb = LabelEncoder()
    #labels['encoded_labels'] = lb.fit_transform(labels[target_column])
    
    # we want to only encode the targets that are leafs. We will leave 
    # internal nodes as the CL number in order to assist with masking 
    # the appropriate parent nodes 
    # first, get list of all cell values
    all_cell_values = labels[target_column].unique().to_list()
    
    # identify which values are leafs
    # we use positive number for leaf values
    # and negative number for internal nodes
    mapping_dict = {}
    leaf_values = []
    internal_values = []
    encoded_leaf_val = 0
    encoded_internal_val = -9999
    for term in all_cell_values:
        if cl[term].is_leaf():
            mapping_dict[term] = encoded_leaf_val
            leaf_values.append(term)
            encoded_leaf_val += 1
        else:
            # check if internal values have associated sub-values in the dataset
            #    sub-values do not have to be leafs
            # if so, add value as internal values
            # if not, prune ontology so consider 
            term_subvalues = []
            # get leaf values of this term
            for sub_term in cl[term].subclasses(distance=None,with_self=False).to_set():
                    term_subvalues.append(sub_term.id)
            
            # get values in all_call_values in term_leafs
            intersection_list = list(set(all_cell_values).intersection(term_subvalues))
            if len(intersection_list) == 0:
                mapping_dict[term] = encoded_leaf_val
                leaf_values.append(term)
                encoded_leaf_val += 1
            else:
                mapping_dict[term] = encoded_internal_val
                internal_values.append(term)
                encoded_internal_val += 1            
            
            
    # use the leaf_mapping_dict to 
    labels['encoded_labels'] = labels[target_column].map(mapping_dict)
    
    x_data = adata.X.copy()
    y_data = labels['encoded_labels']
    
    #########
    # now get a list of all parent nodes for each value in the dataset
    # if we want to include leafs, set with_self= True
    # else, set with_self = False
    
    all_parent_nodes = []
    for target in all_cell_values:
        for term in cl[target].superclasses(distance=None,with_self=include_leafs).to_set():
            all_parent_nodes.append(term.id)
            #if target == 'CL:0000904':
            #    print(term)
            
    # ensure that we do not have duplicate values
    all_parent_nodes = list(set(all_parent_nodes))

    # select only the Cell Ontology IDs if cl_only = True
    if cl_only:
        all_parent_nodes = [x for x in all_parent_nodes if x.startswith('CL')]
    
    # if there is an upper limit, 
    if upper_limit is not None:
        # get upper limit nodes
        upper_limit_nodes = []
        for term in cl[upper_limit].superclasses(distance=None,with_self=False).to_set():
            upper_limit_nodes.append(term.id)

        # remove these nodes from the parent_nodes list
        all_parent_nodes = [x for x in all_parent_nodes if x not in upper_limit_nodes]
        
    # create a dictionary that maps parents to reduce the ontology_df when
    # dealing with internal nodes
    #parent_dict = set_internal_node_relationships_by_depth(internal_values,upper_limit,all_parent_nodes)
    parent_dict = set_internal_node_values(internal_values,all_parent_nodes)
    
    # create the dataframe
    # use all_cell_values for the columns, because we need both leafs and
    # internals nodes for mapping
    ontology_df = pd.DataFrame(data=0, index = all_parent_nodes,
                                              columns = all_cell_values)
    
    # populate the dataframe with 1 if column is a sub-node 
    # for that particular cell ID
    # with_self = True because we need to include the leafs here
    for cell_id in ontology_df.index:
        for term in cl[cell_id].subclasses(distance=None,with_self=True).to_set():
            if term.id in ontology_df.columns:
                ontology_df.loc[cell_id,[term.id]] = [1]

    # create a dictionary that maps parents to reduce the ontology_df when
    # dealing with internal nodes
    #parent_dict = {}
    #for parent in internal_values:
    #    super_parent_list = []
    #    for term in cl[parent].superclasses(distance=None,with_self=True).to_set():
    #         if term.id in all_parent_nodes:
    #            super_parent_list.append(term.id)
    #    parent_dict[parent] = super_parent_list

    # build a matrix used to mask parent values
    cell_parent_mask = build_parent_mask(leaf_values,internal_values,ontology_df,parent_dict)
    
    return(x_data,y_data, mapping_dict, leaf_values, internal_values, ontology_df, parent_dict, cell_parent_mask)




In [13]:
def transform_data(x_data):
    '''
    This function takes the input x_data, transforms the data with log(1+x) and 
    returns the transformed data
    
    Parameters
    ----------
    x_data : scipy matrix
        scipy sparse CSR matrix  
        
    Returns
    -------
    x_data : SciPy Matrix
        scipy sparse CSR matrix
    
    '''
    
    # np.log takes the natural log
    x_data.data = np.log(1+ x_data.data)

    return x_data


In [14]:
def split_format_data(x_data, y_data, train_size, val_size, holdout_size = None, random_state = None):
    '''
    This function splits x_data and y_data into training and validation sets, then formats the data into
    tensors for modeling with PyTorch
    
    
    Parameters
    ----------
    x_data : scipy matrix
        scipy sparse CSR matrix  
        
    y_data : Series
        Pandas series of encoded target values
        
    train_size: float
        float between 0.0 and 1.0 to select the training fraction of the data set
        
        
    Returns
    -------
    x_data : SciPy Matrix
        scipy sparse CSR matrix
    
    y_data : Series
        Pandas Series of encoded target values
    
    X_train : Tensor
        pytorch tensor of training values
    
    X_val : Tensor
        pytorch tensor of validation values
        
    y_train : Tensor
        pytorch tensor of training target values
        
    y_val : Tensor
        pytorch tensor of validation target values
    '''
    
    
    if holdout_size:
        # split into training and validation sets
        # first split into train and validation/holdout
        X_train, X_val_holdout, y_train, y_val_holdout = train_test_split(x_data,y_data,
                                                       train_size = train_size,
                                                         random_state=random_state)

        # calculate the validation split of the remainder 
        val_split_size = val_size / (val_size + holdout_size)

        # split the validation/holdout set to separate sets
        X_val, X_holdout, y_val, y_holdout = train_test_split(X_val_holdout, y_val_holdout,
                                                             train_size = val_split_size,
                                                             random_state=random_state)
    else:
        # split into training and validation sets
        # first split into train and validation/holdout
        X_train, X_val, y_train, y_val = train_test_split(x_data,y_data,
                                                       train_size = train_size,
                                                         random_state=random_state)

    
    # check if number of genes in X_train = X_val
    # rarely, splitting the dataset can cause a difference
    # if it does, resplit the data by looping back into this function

    if holdout_size:
        if (X_train.shape[1] != X_val.shape[1]) or (X_train.shape[1] != X_holdout.shape[1]):
            X_train, X_val, y_train, y_val, X_holdout, y_holdout = split_format_data(x_data, y_data, train_size, val_size, holdout_size, random_state = None)
        else:
            print('Success. Number of genes in datasets match.')
    else:
        if (X_train.shape[1] != X_val.shape[1]):
            X_train, X_val, y_train, y_val, X_holdout, y_holdout = split_format_data(x_data, y_data, train_size, val_size, holdout_size, random_state = None)
        else:
            print('Success. Number of genes in datasets match.')


    
    # convert the data to tensors
    # we'll change the data from CSR (compressed sparse row) format
    # to COO (coordinate) format for better use with pytorch
    # see https://pytorch.org/docs/stable/sparse.html for additional details
    # conversion from COO to tensor based on https://stackoverflow.com/questions/50665141/converting-a-scipy-coo-matrix-to-pytorch-sparse-tensor

    # copy the X matrix to save in scipy CSR format
    x_train_csr = X_train.copy()
    

    X_train_coo = X_train.tocoo()
    #X_train = torch.sparse.FloatTensor(torch.LongTensor(np.vstack((X_train_coo.row,X_train_coo.col))),
    #                              torch.FloatTensor(X_train_coo.data))

    X_train_values = X_train_coo.data
    X_train_indices = np.vstack((X_train_coo.row, X_train_coo.col))

    X_train_i = torch.LongTensor(X_train_indices)
    X_train_v = torch.FloatTensor(X_train_values)
    X_train_shape = X_train_coo.shape

    X_train = torch.sparse.FloatTensor(X_train_i, X_train_v, torch.Size(X_train_shape))  
    
    
    # y_train is a Series, so it is easier to convert to a tensor
    y_train = torch.tensor(y_train,device=device)#, dtype=torch.long)

    # and the same for the validation set
    X_val_coo = X_val.tocoo()
    #X_val = torch.sparse.FloatTensor(torch.LongTensor(np.vstack((X_val_coo.row,X_val_coo.col))),
    #                                  torch.FloatTensor(X_val_coo.data))

    
    X_val_values = X_val_coo.data
    X_val_indices = np.vstack((X_val_coo.row, X_val_coo.col))

    X_val_i = torch.LongTensor(X_val_indices)
    X_val_v = torch.FloatTensor(X_val_values)
    X_val_shape = X_val_coo.shape

    X_val = torch.sparse.FloatTensor(X_val_i, X_val_v, torch.Size(X_val_shape)).to(device)  
    
    
    # y_val is a Series, so it is easier to convert to a tensor
    y_val = torch.tensor(y_val,device=device)#, dtype=torch.long)

    if holdout_size:
        #pass
        # and the same for the holdout set
        X_holdout_coo = X_holdout.tocoo()
        #X_holdout = torch.sparse.FloatTensor(torch.LongTensor(np.vstack((X_holdout_coo.row,X_holdout_coo.col))),
        #                              torch.FloatTensor(X_holdout_coo.data))

        X_holdout_values = X_holdout_coo.data
        X_holdout_indices = np.vstack((X_holdout_coo.row, X_holdout_coo.col))

        X_holdout_i = torch.LongTensor(X_holdout_indices)
        X_holdout_v = torch.FloatTensor(X_holdout_values)
        X_holdout_shape = X_holdout_coo.shape

        X_holdout = torch.sparse.FloatTensor(X_holdout_i, X_holdout_v, torch.Size(X_holdout_shape)).to(device)   

        
        # y_holdout is a Series, so it is easier to convert to a tensor
        y_holdout = torch.tensor(y_holdout, dtype=torch.long,device=device)
    
    
    if device.type == "cuda":
        print('y_train, X_val, y_val moved to GPU. X_train not moved to GPU. ')
            
    if holdout_size:
        return(X_train, X_val, y_train, y_val, X_holdout, y_holdout, x_train_csr)
    else:
        return(X_train, X_val, y_train, y_val,x_train_csr)

## Main Loop For Preprocessing Data

In [15]:
target_column = 'cell_type_ontology_term_id'

upper_limit = 'CL:0000988' # leukocyte = 738, hematopoietic = 988

print('start preprocess data and ontology')
x_data,y_data, mapping_dict, leaf_values,internal_values, \
    ontology_df, parent_dict, cell_parent_mask =  preprocess_data_ontology(adata, target_column,
                                                                           upper_limit = upper_limit, 
                                                                 cl_only = True, include_leafs = False)

###del adata

# create dataframe that only includes leaf nodes
ontology_leaf_df = ontology_df[leaf_values]

print('start transforming data')
x_data = transform_data(x_data)


train_size = 0.8
val_size = 0.2
holdout_size =  None # None if you don't want holdout set
random_state = 42

print('start split and format data')
if holdout_size:
    X_train, X_val, y_train, y_val, \
    X_holdout, y_holdout, x_train_csr = split_format_data(x_data, y_data, train_size, val_size, 
                                                          holdout_size, random_state=random_state)
else:
    X_train, X_val, y_train, y_val, \
    x_train_csr = split_format_data(x_data, y_data, train_size, val_size, 
                                    holdout_size, random_state=random_state)

print('Preprocessing complete. There are {0} leaf values and {1} internal values.'.format(len(leaf_values),len(internal_values)
                                                                                         ))
print('There are {0} cells in the training set and {1} cells in the validation set, both contain {2} genes.'.format(X_train.shape[0],X_val.shape[0],X_train.shape[1]))


start preprocess data and ontology
start transforming data
start split and format data
Success. Number of genes in datasets match.
Preprocessing complete. There are 52 leaf values and 16 internal values.
There are 37787 cells in the training set and 9447 cells in the validation set, both contain 19966 genes.


## Save Results of Preprocessing to Disk for Use in Modeling

- X_train,X_val,y_train,y_val (do this only if we do manual batching. If we switch to DataLoader, we might not want to split the data. Come back to this)
- cell_parent_mask
- Mapping_dict
- Ontology_df
- Internal_values
- leaf_values


In [None]:
# change to directory you want to save the results in
os.chdir('/home/fujoshua/cell_classification')


In [18]:
# get today's date for saving information about this model
today = datetime.today().strftime('%Y-%m-%d')


In [20]:
# save information needed for testing external models 

ontology_df_name = today + '_ontology_df.csv'
ontology_df.to_csv(ontology_df_name)

mapping_dict_name = today + '_mapping_dict_df.csv'
mapping_dict_df = pd.DataFrame.from_dict(mapping_dict,orient='index')
mapping_dict_df.to_csv(mapping_dict_name)

leaf_values_name = today + '_leaf_values'
internal_values_name = today + '_internal_values'

with open(leaf_values_name, "wb") as fp:   #Pickling
    pickle.dump(leaf_values, fp)
    
with open(internal_values_name, "wb") as fp:   #Pickling
    pickle.dump(internal_values, fp)

X_train_name = today + '_X_train.pt'
X_val_name = today + '_X_val.pt'
y_train_name = today + '_y_train.pt'
y_val_name = today + '_y_val.pt'

torch.save(X_train, X_train_name)
torch.save(y_train,y_train_name)
torch.save(X_val, X_val_name)
torch.save(y_val,y_val_name)
    
cell_parent_mask_name = today + '_cell_parent_mask.pt'
torch.save(cell_parent_mask,cell_parent_mask_name)