In [1]:
import numpy as np
import h5py
import random
random.seed(42)
import uproot
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import joblib
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch.nn import BatchNorm1d
import pickle

In [2]:
def branches_from_root_file(filename):
    """
    Read all branches of the first object stored in a ROOT file.

    Parameters
    ----------
    filename
        Passed unchanged to ``uproot.open``.

    Returns
    -------
    The value of ``tree.arrays()`` for the object found under the file's
    first key.
    """
    # The context manager releases the file handle even when reading fails.
    # ``arrays()`` is evaluated inside the block, before the handle is closed.
    with uproot.open(filename) as file:
        tree = file[file.keys()[0]]
        return tree.arrays()

In [3]:
def extractData(branches, dataName):
    """Return ``branches[dataName]`` converted to a NumPy array."""
    selected_branch = branches[dataName]
    return np.array(selected_branch)

In [4]:
def create_data_dict(rangeValue, variables, dataDict):
    """
    Fill ``dataDict`` with one column-joined array per index.

    For every ``i`` in ``range(rangeValue)`` the ``i``-th entry of each element
    of ``variables`` receives a new axis at position 1; the results are
    concatenated along that axis and stored under the key ``"data_<i>"``.
    ``dataDict`` is modified in place and also returned.
    """
    for event_index in range(rangeValue):
        columns = []
        for variable in variables:
            columns.append(np.expand_dims(variable[event_index], axis=1))
        dataDict[f"data_{event_index}"] = np.concatenate(columns, axis=1)
    return dataDict

In [5]:
def save_dict_to_hdf5(dic, filename):
    """
    Write the dictionary ``dic`` to an HDF5 file.

    The file at ``filename`` is opened in ``'w'`` mode and the actual writing
    is delegated to ``_save_dict_to_hdf5``.
    """
    with h5py.File(filename, 'w') as hdf5_root:
        _save_dict_to_hdf5(hdf5_root, dic)


def _save_dict_to_hdf5(group, dic):
    """
    Recursively store the entries of ``dic`` in ``group``.

    Nested dictionaries become sub-groups created with ``group.create_group``;
    lists are converted to NumPy arrays first; every other value is assigned
    as-is under its key.
    """
    for name, entry in dic.items():
        if isinstance(entry, dict):
            # Nested dictionary: descend into a freshly created sub-group.
            _save_dict_to_hdf5(group.create_group(name), entry)
            continue
        # Lists are converted to NumPy arrays before being stored.
        group[name] = np.array(entry) if isinstance(entry, list) else entry

In [6]:
# Helpers that remove permutation variants from a list of tuples
def canonical_form(t):
    """Return the elements of ``t`` in ascending order, packed into a tuple."""
    ordered = list(t)
    ordered.sort()
    return tuple(ordered)

def remove_permutation_variants(tuple_list):
    """
    Collapse tuples that differ only in the order of their elements.

    Each tuple is reduced to its canonical (sorted) form, duplicates are
    dropped through a set, and the survivors are returned as a list of sorted
    tuples. The order of the returned list is the set's iteration order.
    """
    seen = set()
    for candidate in tuple_list:
        seen.add(canonical_form(candidate))
    return [tuple(sorted(entry)) for entry in seen]

In [7]:
# y = cell_to_cluster_index[i]
# i = pair[0]=z[0], j = pair[1]=z[1]
def cluster_cluster_true(y,i,j):
    """True when ``y[i]`` and ``y[j]`` are equal and non-zero."""
    first, second = y[i], y[j]
    return first == second and first != 0
def lone_lone(y,i,j):
    """True when ``y[i]`` and ``y[j]`` are equal and both zero."""
    first, second = y[i], y[j]
    return first == second and first == 0
def cluster_cluster_false(y,i,j):
    """True when ``y[i]`` and ``y[j]`` differ and neither of them is zero."""
    first, second = y[i], y[j]
    return first != second and first != 0 and second != 0
def cluster_lone(y,i,j):
    """True when ``y[i]`` is non-zero, ``y[j]`` is zero and the two differ."""
    first, second = y[i], y[j]
    return first != second and first != 0 and second == 0
def lone_cluster(y,i,j):
    """True when ``y[i]`` is zero, ``y[j]`` is non-zero and the two differ."""
    first, second = y[i], y[j]
    return first != second and first == 0 and second != 0

# x = neighbor_pairs_unique_sorted
# y = cell_to_cluster_index[i]
# z = pair
def assign_index(mapping, x, y):
    """
    Label every pair in ``x`` with the key of the first matching test.

    Parameters
    ----------
    mapping : dict
        Maps a label to a predicate called as ``test(y, pair[0], pair[1])``.
        Predicates are tried in the dictionary's iteration order.
    x : iterable
        Pairs; each element is indexed with ``[0]`` and ``[1]``.
    y : indexable
        Forwarded unchanged to every predicate.

    Returns
    -------
    list
        At most one label per pair, in the order of ``x``. A pair that
        satisfies no predicate contributes nothing, so the result can be
        shorter than ``x``.
    """
    out = []
    for pair in x:
        for index, test in mapping.items():
            if test(y, pair[0], pair[1]):
                out.append(index)
                # Stop at the first hit. The previous trailing ``continue`` was
                # a no-op, so a pair matching several predicates received
                # several labels and shifted every label after it.
                break
    return out

def neighbor_pairs_mapping(loneloneIndex, clusterloneIndex, 
                           loneclusterIndex, clusterclusterFalseIndex):
    """
    Build the label -> predicate table used to classify neighbor pairs.

    Label 1 is always bound to ``cluster_cluster_true``. The four arguments
    give the labels of the background categories: both cells lone, a cluster
    cell paired with a lone cell, a lone cell paired with a cluster cell, and
    two cells from different clusters. If labels coincide, the later entry
    overwrites the earlier one, so the labels should be distinct integers
    other than 1.
    """
    pairs_mapping = {}
    pairs_mapping[1] = cluster_cluster_true
    pairs_mapping[loneloneIndex] = lone_lone
    pairs_mapping[clusterloneIndex] = cluster_lone
    pairs_mapping[loneclusterIndex] = lone_cluster
    pairs_mapping[clusterclusterFalseIndex] = cluster_cluster_false
    return pairs_mapping

# list_of_pair_indices=neighbor_pairs_unique_sorted
# index of cluster cell is a part of =cell_to_cluster_index
def label_neighbor_pairs(range_value, cell_to_cluster_index, list_of_pair_indices, mapping):
    """
    Label the same list of pairs once per event.

    For each ``i`` in ``range(range_value)`` the pairs in
    ``list_of_pair_indices`` are labelled by ``assign_index`` using
    ``cell_to_cluster_index[i]``. The per-event label lists are returned as a
    single NumPy array.
    """
    per_event_labels = [
        assign_index(mapping=mapping, x=list_of_pair_indices, y=cell_to_cluster_index[event])
        for event in range(range_value)
    ]
    return np.array(per_event_labels)

In [8]:
def hdf5_to_dict(hdf5_file):
    """
    Build a plain Python dictionary from the contents of ``hdf5_file``.

    The traversal is delegated to ``_hdf5_to_dict``, which fills the new
    dictionary in place.
    """
    contents = dict()
    _hdf5_to_dict(hdf5_file, contents)
    return contents
# Initializes an empty dictionary and calls a function to recursively
# fill this dictionary with data from the hdf5 file.


def _hdf5_to_dict(group, dic):
    """
    Copy every item of ``group`` into ``dic``, descending into sub-groups.

    Items that are ``h5py.Group`` instances become nested dictionaries; every
    other item is converted with ``np.array``.
    """
    for name, member in group.items():
        if not isinstance(member, h5py.Group):
            # Anything that is not a group is read into a NumPy array.
            dic[name] = np.array(member)
            continue
        nested = {}
        _hdf5_to_dict(member, nested)
        dic[name] = nested
# Iterates over items in the hdf5 group. If the item is a group, 
# it creates a new dictionary and calls itself recursively. If the item
# is a dataset, it converts it to a numpy array and stores it in the dictionary.

In [9]:
def makeEdges(neighbor_truth, label):
    """Return, as a list, the first-axis indices where ``neighbor_truth == label``."""
    matches = neighbor_truth == label
    first_axis_hits = np.where(matches)[0]
    return list(first_axis_hits)

In [10]:
def indices_of_edges(range_value, neighbor_truth):
    """
    Collect, per event, the positions of the pairs carrying each truth label.

    Five lists are returned, for the truth labels 1, 0, 2, 3 and 4 in that
    order (the "true" pairs followed by the backgrounds bkg_0, bkg_2, bkg_3
    and bkg_4). Entry ``i`` of each list is
    ``makeEdges(neighbor_truth[i], label)``.
    """
    labels = (1, 0, 2, 3, 4)
    per_label = {label: [] for label in labels}
    for event in range(range_value):
        event_truth = neighbor_truth[event]
        for label in labels:
            per_label[label].append(makeEdges(event_truth, label))
    true_list, bkg_0, bkg_2, bkg_3, bkg_4 = (per_label[label] for label in labels)
    return true_list, bkg_0, bkg_2, bkg_3, bkg_4

def pairs_of_edges(range_value, true_list, bkg_0_list, bkg_2_list, bkg_3_list, bkg_4_list):
    """
    Count the pairs of every category for the first ``range_value`` events.

    Each argument after ``range_value`` is a per-event collection of edge
    lists (true pairs, then the backgrounds with truth values 0, 2, 3 and 4).
    Five NumPy arrays are returned in the same order, holding the length of
    each event's list.
    """
    def count_pairs(per_event_edges):
        # Length of the edge list of every event, as a NumPy array.
        return np.array([len(per_event_edges[event]) for event in range(range_value)])

    return (
        count_pairs(true_list),
        count_pairs(bkg_0_list),
        count_pairs(bkg_2_list),
        count_pairs(bkg_3_list),
        count_pairs(bkg_4_list),
    )

In [11]:
def sorting_indices(pair_number_list, true_list, bkg_0, bkg_2, bkg_3, bkg_4):
    """
    Reorder five per-event arrays by descending ``pair_number_list``.

    The order is ``np.argsort(-pair_number_list)``. The five reordered arrays
    are returned in argument order, followed by the index array itself.
    """
    descending_order = np.argsort(-pair_number_list)
    reordered = [
        column[descending_order]
        for column in (true_list, bkg_0, bkg_2, bkg_3, bkg_4)
    ]
    return (*reordered, descending_order)

In [12]:
def sortList(data, sorted_indices):
    """Return a new list with the items of ``data`` picked in the order of ``sorted_indices``."""
    reordered = []
    for position in sorted_indices:
        reordered.append(data[position])
    return reordered

In [13]:
def create_train_set_test_set(split_num, true_list, bkg_0, bkg_2, bkg_3, bkg_4):
    """
    Split five sequences at ``split_num`` into training and testing parts.

    A 10-tuple is returned: the ``[:split_num]`` slices of the five inputs in
    argument order, followed by their ``[split_num:]`` slices in the same
    order.
    """
    categories = (true_list, bkg_0, bkg_2, bkg_3, bkg_4)
    train_part = tuple(category[:split_num] for category in categories)
    test_part = tuple(category[split_num:] for category in categories)
    return train_part + test_part

In [14]:
def scaling_of_features(dynamic_variables, split_num, sorted_indices, scaler_fileName):
    """
    Min-max scale per-event feature arrays, fitting on the training events only.

    Parameters
    ----------
    dynamic_variables : dict
        Per-event feature arrays; only the values, in insertion order, are used.
    split_num : int
        Number of events, counted after reordering, that form the training set.
    sorted_indices : sequence of int
        Positions into the dictionary's values, giving the order in which the
        events are arranged before the split.
    scaler_fileName : str
        Path the fitted scaler is written to with ``joblib.dump``. A falsy
        value falls back to ``"./bscaler_neighbor_data_train_sorted.save"``,
        the path that used to be hard-coded.

    Returns
    -------
    tuple
        ``(train, test)``: the scaled, concatenated training and testing
        features.
    """
    values = list(dynamic_variables.values())
    rearranged_values = [values[i] for i in sorted_indices]

    data_train = np.concatenate(rearranged_values[:split_num])
    # The test events must be taken from the same reordered sequence as the
    # training events. Slicing the unsorted dictionary here, as the code used
    # to, produced a test set that could overlap the training set.
    data_test = np.concatenate(rearranged_values[split_num:])

    # Fit on the training data only, then apply the same transform to the test data.
    scaler = MinMaxScaler()
    cellFeatures_trainS = scaler.fit_transform(data_train)
    cell_features_testS = scaler.transform(data_test)

    # Honour the caller-supplied path instead of silently ignoring it.
    scaler_filename = scaler_fileName or "./bscaler_neighbor_data_train_sorted.save"
    joblib.dump(scaler, scaler_filename)
    return cellFeatures_trainS, cell_features_testS

In [15]:
def reshapeFeatures(splitNum, features):
    """
    Reshape 2-D ``features`` into ``splitNum`` equally sized groups of rows.

    The result has shape ``(splitNum, rows / splitNum, columns)``.
    """
    total_rows, n_columns = features.shape[0], features.shape[1]
    rows_per_group = int(total_rows / splitNum)
    return features.reshape(splitNum, rows_per_group, int(n_columns))

In [16]:
def saveDataH5(fileName, dataDict, compressionType):
    """
    Write every entry of ``dataDict`` as a dataset of a new HDF5 file.

    The file is opened in ``'w'`` mode; each dataset is created under its
    dictionary key with ``compression=compressionType``.
    """
    with h5py.File(fileName, 'w') as h5_out:
        for dataset_name, payload in dataDict.items():
            h5_out.create_dataset(dataset_name, data=payload, compression=compressionType)

def loadDataH5(fileName):
    """
    Read every top-level entry of an HDF5 file into a dictionary.

    Each key of the file is mapped to ``np.array(file[key])``; the file is
    closed before the dictionary is returned.
    """
    with h5py.File(fileName, 'r') as h5_in:
        loaded = {name: np.array(h5_in[name]) for name in h5_in.keys()}
    return loaded

In [17]:
def removeBrokenCells(cell_noiseSigma, neighbor):
    """
    Return the neighbor pairs that involve no broken cell.

    Broken cells are determined by ``getBrokenCells(cell_noiseSigma)`` and the
    pairs are built by ``getNeighborPairs``.
    """
    return getNeighborPairs(getBrokenCells(cell_noiseSigma), neighbor)

In [18]:
def getBrokenCells(cell_noiseSigma):
    """
    List the indices at which the first entry of ``cell_noiseSigma`` is zero.

    Parameters
    ----------
    cell_noiseSigma : indexable
        Only ``cell_noiseSigma[0]`` is inspected; it is compared element-wise
        with 0.

    Returns
    -------
    list of int
        The index components reported by ``np.argwhere``, flattened, each
        appearing once.
    """
    broken_cell_indices = np.argwhere(np.asarray(cell_noiseSigma[0]) == 0)
    # The previous nested loop iterated over the whole index array twice
    # instead of over the current row, so k broken cells produced k*k
    # single-element arrays. Flattening yields each index once, as a plain int.
    return broken_cell_indices.flatten().tolist()

In [19]:
def getNeighborPairs(broken_cells, neighbor):
    """
    Build ``(cell, neighbor_cell)`` pairs that involve no broken cell.

    Parameters
    ----------
    broken_cells : iterable
        Indices of broken cells, given either as plain integers or as
        index arrays (each array's elements are all treated as broken).
    neighbor : sequence
        ``neighbor[i]`` is iterated to obtain the neighbors of cell ``i``.

    Returns
    -------
    list of tuple
        One ``(i, cell)`` tuple for every cell ``i`` and each of its
        neighbors, skipping every pair in which either member is broken.
    """
    # A set makes each membership test constant-time; testing against the
    # list inside the nested loops below scanned it once per cell and once
    # per neighbor.
    broken_lookup = set()
    for entry in broken_cells:
        # np.ravel handles both scalars and index arrays.
        broken_lookup.update(np.ravel(entry).tolist())

    neighbor_pairs = []
    for i in range(len(neighbor)):
        if i in broken_lookup:
            continue
        for cell in neighbor[i]:
            if cell not in broken_lookup:
                neighbor_pairs.append((i, cell))
    return neighbor_pairs

In [20]:
def writeH5File(fileName, datasetName, data):
    """Open ``fileName`` in ``"w"`` mode and store ``data`` in it as the dataset ``datasetName``."""
    with h5py.File(fileName, "w") as h5_file:
        h5_file.create_dataset(datasetName, data=data)

In [21]:
def readH5File(fileName, datasetName):
    """
    Read one dataset of an HDF5 file into a NumPy array.

    Parameters
    ----------
    fileName
        Path of the file, opened read-only.
    datasetName
        Name looked up with ``file.get``.

    Returns
    -------
    numpy.ndarray
        The full contents of the dataset (read with ``[:]``).
    """
    # ``with`` closes the file even when the lookup or the read raises; the
    # explicit ``close()`` used before was skipped on any error.
    with h5py.File(fileName, "r") as file:
        data = file.get(datasetName)[:]
    return np.array(data)

In [22]:
def sampleDataTraining(true, bkg_lone, bkg_cluster_lone, bkg_lone_cluster, bkg_cluster_cluster):
    """
    Sample the five pair categories for training.

    The sizes come from ``getTrainingSampleSizes``: one size for the true
    pairs and a single shared size for all four background categories. The
    sampling itself is done by ``sampleData``.
    """
    categories = (true, bkg_lone, bkg_cluster_lone, bkg_lone_cluster, bkg_cluster_cluster)
    true_sample_size, bkg_sample_size = getTrainingSampleSizes(*categories)
    # Every background category is sampled with the same size.
    background_sizes = (bkg_sample_size,) * 4
    return sampleData(*categories, true_sample_size, *background_sizes)
    

In [23]:
def sampleDataTesting(true, bkg_lone, bkg_cluster_lone, bkg_lone_cluster, bkg_cluster_cluster):
    """
    Sample the five pair categories for testing.

    Unlike the training variant, each category gets its own sample size,
    computed independently with ``getTestingSampleSize``. The sampling itself
    is done by ``sampleData``.
    """
    categories = (true, bkg_lone, bkg_cluster_lone, bkg_lone_cluster, bkg_cluster_cluster)
    sample_sizes = [getTestingSampleSize(category) for category in categories]
    return sampleData(*categories, *sample_sizes)

In [24]:
def getTestingSampleSize(data):
    """Return the shortest row length of ``data``, rounded down to a multiple of 100."""
    shortest = getMinimum(data)
    return shortest - shortest % 100

In [25]:
def getMinimum(data):
    """Return the length of the shortest row in ``data``."""
    row_lengths = list(map(len, data))
    return min(row_lengths)

In [26]:
def getBackgroundMin(bkg_lone, bkg_cluster_lone, bkg_lone_cluster, bkg_cluster_cluster):
    """Return the smallest ``getMinimum`` value over the four background categories."""
    backgrounds = (bkg_lone, bkg_cluster_lone, bkg_lone_cluster, bkg_cluster_cluster)
    return min(getMinimum(category) for category in backgrounds)

In [27]:
def getTrainingSampleSizes(true, bkg_lone, bkg_cluster_lone, bkg_lone_cluster, bkg_cluster_cluster):
    """
    Choose training sample sizes with a 4:1 ratio of true to background pairs.

    The background size starts as the smallest background row length rounded
    down to a multiple of 100, and the true size as four times that. If the
    true rows are too short for this, the true size is instead derived from
    the shortest true row (rounded down to a multiple of 100, then of 4) and
    the background size becomes a quarter of it.

    Returns
    -------
    tuple of int
        ``(true_sample_size, bkg_sample_size)``.
    """
    true_min = getMinimum(true)
    bkg_min = getBackgroundMin(bkg_lone, bkg_cluster_lone, bkg_lone_cluster, bkg_cluster_cluster)
    bkg_sample_size = bkg_min - (bkg_min % 100)
    true_sample_size = bkg_sample_size * 4
    if true_sample_size > true_min:
        true_sample_size = true_min - (true_min % 100)
        true_sample_size = true_sample_size - (true_sample_size % 4)
        # Floor division keeps the size an int. True division produced a
        # float here, which ``random.sample`` rejects as a sample size.
        # ``true_sample_size`` is a multiple of 4, so nothing is lost.
        bkg_sample_size = true_sample_size // 4
    return true_sample_size, bkg_sample_size

In [28]:
def sampleData(true, bkg_lone, bkg_cluster_lone, bkg_lone_cluster, bkg_cluster_cluster, true_sample_size, bkg_lone_sample_size, bkg_cluster_lone_sample_size, bkg_lone_cluster_sample_size, bkg_cluster_cluster_sample_size):
    """
    Draw a random sample from every row of each of the five categories.

    Each category is passed to ``sampleDataset`` together with its own sample
    size, in argument order. The five results are returned as NumPy arrays in
    that same order.
    """
    categories_with_sizes = (
        (true, true_sample_size),
        (bkg_lone, bkg_lone_sample_size),
        (bkg_cluster_lone, bkg_cluster_lone_sample_size),
        (bkg_lone_cluster, bkg_lone_cluster_sample_size),
        (bkg_cluster_cluster, bkg_cluster_cluster_sample_size),
    )
    return tuple(
        np.array(sampleDataset(category, size))
        for category, size in categories_with_sizes
    )

In [29]:
def sampleDataset(data, data_sample_size):
    """
    Draw ``data_sample_size`` distinct items from every row of ``data``.

    Rows are processed in order with ``random.sample``; a list holding one
    sample list per row is returned.
    """
    sampled_rows = []
    for row in data:
        sampled_rows.append(random.sample(row, data_sample_size))
    return sampled_rows

In [30]:
def createRandomIndices(total_indices_shape):
    """
    Build one independent random permutation of column indices per row.

    ``total_indices_shape[0]`` rows are produced; each is
    ``np.arange(total_indices_shape[1])`` shuffled in place with
    ``np.random.shuffle``. The rows are returned stacked in a NumPy array.
    """
    def shuffled_positions():
        # A fresh index range, shuffled with NumPy's global random state.
        positions = np.arange(total_indices_shape[1])
        np.random.shuffle(positions)
        return positions

    return np.array([shuffled_positions() for _ in range(total_indices_shape[0])])

In [31]:
def randomize2DArray(rand_indices, unrandomized_array):
    """
    Reorder every row of ``unrandomized_array`` with its own index row.

    Row ``i`` of the result is ``unrandomized_array[i][rand_indices[i]]``.
    """
    n_rows = unrandomized_array.shape[0]
    shuffled_rows = [unrandomized_array[row][rand_indices[row]] for row in range(n_rows)]
    return np.array(shuffled_rows)

In [32]:
def randomizeEdges(rand_indices, unrandomized_edges):
    """
    Apply every row of ``rand_indices`` as an index into ``unrandomized_edges``.

    Entry ``i`` of the result is ``unrandomized_edges[rand_indices[i]]``; the
    same edge array is indexed once per row of ``rand_indices``.
    """
    n_rows = rand_indices.shape[0]
    shuffled_edges = [unrandomized_edges[rand_indices[row]] for row in range(n_rows)]
    return np.array(shuffled_edges)

In [33]:
def createEdgeArrays(inputData):
    """
    Build bidirectional and unidirectional edge arrays for every event.

    ``createBDAndNoBDArrays`` is applied to each ``inputData[i]`` along the
    first axis. Four NumPy arrays are returned: bidirectional sources,
    bidirectional destinations, unidirectional sources and unidirectional
    destinations, each with one entry per event.
    """
    per_event = []
    for event in range(inputData.shape[0]):
        src_bd, dst_bd, src_nobd, dst_nobd = createBDAndNoBDArrays(inputData[event])
        per_event.append((src_bd, dst_bd, src_nobd, dst_nobd))

    source_BD = [entry[0] for entry in per_event]
    dest_BD = [entry[1] for entry in per_event]
    source_noBD = [entry[2] for entry in per_event]
    dest_noBD = [entry[3] for entry in per_event]
    return np.array(source_BD), np.array(dest_BD), np.array(source_noBD), np.array(dest_noBD)

In [34]:
def createBDAndNoBDArrays(inputData):
    """
    Turn a sequence of pairs into source/destination edge lists.

    Four lists are returned:

    * bidirectional sources and destinations, where every pair ``(a, b)``
      contributes the two edges ``a -> b`` and ``b -> a``;
    * unidirectional sources and destinations, where every pair contributes
      only ``a -> b``.
    """
    source_BD, dest_BD = [], []
    source_noBD, dest_noBD = [], []

    for pair in inputData:
        first, second = pair[0], pair[1]

        # Both directions for the bidirectional lists.
        source_BD.extend((first, second))
        dest_BD.extend((second, first))

        # Original direction only for the unidirectional lists.
        source_noBD.append(first)
        dest_noBD.append(second)

    return source_BD, dest_BD, source_noBD, dest_noBD

In [35]:
class CustomDataset(torch.utils.data.Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of samples."""

    def __init__(self, data_list):
        # The collection is kept by reference; no copy is made.
        self.data_list = data_list

    def __len__(self):
        """Return the number of stored samples."""
        n_samples = len(self.data_list)
        return n_samples

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.data_list[idx]
        return sample

In [36]:
def collate(data_list):
    """
    Regroup a list of samples by attribute instead of merging them.

    Returns a 4-tuple of lists holding, in sample order, the ``x``,
    ``edge_index``, ``edge_index_out`` and ``y`` attribute of every sample.
    """
    fields = ("x", "edge_index", "edge_index_out", "y")
    batch_x, batch_edge_index, batch_edge_index_out, batch_y = (
        [getattr(sample, field) for sample in data_list] for field in fields
    )
    return batch_x, batch_edge_index, batch_edge_index_out, batch_y

In [37]:
def collate_true(data_list_true):
    """
    Collate a batch of samples into per-attribute lists.

    This is the same operation as ``collate``, kept under its own name for
    existing callers, so it delegates instead of duplicating the body.

    Returns a 4-tuple of lists: the ``x``, ``edge_index``, ``edge_index_out``
    and ``y`` attribute of every sample, in sample order.
    """
    return collate(data_list_true)

In [38]:
def collate_bkg_lone(data_list_bkg_lone):
    """
    Collate a batch of samples into per-attribute lists.

    This is the same operation as ``collate``, kept under its own name for
    existing callers, so it delegates instead of duplicating the body.

    Returns a 4-tuple of lists: the ``x``, ``edge_index``, ``edge_index_out``
    and ``y`` attribute of every sample, in sample order.
    """
    return collate(data_list_bkg_lone)

In [39]:
def collate_bkg_cluster_lone(data_list_bkg_cluster_lone):
    """
    Collate a batch of samples into per-attribute lists.

    This is the same operation as ``collate``, kept under its own name for
    existing callers, so it delegates instead of duplicating the body.

    Returns a 4-tuple of lists: the ``x``, ``edge_index``, ``edge_index_out``
    and ``y`` attribute of every sample, in sample order.
    """
    return collate(data_list_bkg_cluster_lone)

In [40]:
def collate_bkg_cluster_cluster(data_list_bkg_cluster_cluster):
    """
    Collate a batch of samples into per-attribute lists.

    This is the same operation as ``collate``, kept under its own name for
    existing callers, so it delegates instead of duplicating the body.

    Returns a 4-tuple of lists: the ``x``, ``edge_index``, ``edge_index_out``
    and ``y`` attribute of every sample, in sample order.
    """
    return collate(data_list_bkg_cluster_cluster)

In [41]:
def collate_bkg_total(data_list_total_bkg):
    """
    Collate a batch of samples into per-attribute lists.

    This is the same operation as ``collate``, kept under its own name for
    existing callers, so it delegates instead of duplicating the body.

    Returns a 4-tuple of lists: the ``x``, ``edge_index``, ``edge_index_out``
    and ``y`` attribute of every sample, in sample order.
    """
    return collate(data_list_total_bkg)

In [42]:
class EdgeClassifier(nn.Module):
    """
    Graph network that assigns a sigmoid score to a chosen set of edges.

    Node features pass through a linear embedding and two ``GCNConv`` layers,
    each followed by batch normalisation and ReLU. An edge is then described
    by the concatenated embeddings of its two end nodes and scored by a final
    linear layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        # Per-node linear embedding applied before any graph convolution.
        self.node_embedding = nn.Linear(input_dim, hidden_dim)

        # Two graph-convolution stages, each with its own batch normalisation.
        self.conv1 = GCNConv(hidden_dim, 128)
        self.bn1 = BatchNorm1d(128)

        self.conv2 = GCNConv(128, 64)
        self.bn2 = BatchNorm1d(64)

        # An edge is the concatenation of two 64-wide node embeddings, hence
        # the 128 input features of the classification layer.
        self.fc = nn.Linear(128 , output_dim)

    def forward(self, x, edge_index, edge_index_out):
        """
        Score the edges listed in ``edge_index_out``.

        ``edge_index`` is the connectivity used by both graph convolutions.
        ``edge_index_out[0]`` and ``edge_index_out[1]`` select the two end
        nodes of every edge to be scored. Returns the sigmoid of the final
        linear layer, one row per scored edge.
        """
        hidden = self.node_embedding(x)
        hidden = torch.relu(self.bn1(self.conv1(hidden, edge_index)))
        hidden = torch.relu(self.bn2(self.conv2(hidden, edge_index)))

        # Edge representation: embeddings of both end nodes, side by side.
        first_nodes, second_nodes = edge_index_out[0], edge_index_out[1]
        edge_rep = torch.cat([hidden[first_nodes], hidden[second_nodes]], dim=1)

        return torch.sigmoid(self.fc(edge_rep))
#Defines an edge classifier convolutional GNN model.

In [43]:
def lr_schedule(epoch):
    """
    Learning-rate multiplier for the given epoch.

    Returns 1.0 (no change) for epochs below 10 and 0.1 (a tenfold decrease)
    from epoch 10 onwards.
    """
    return 1.0 if epoch < 10 else 0.1

In [44]:
def train(model, device, data_loader, optimizer, criterion):
    """
    Run one training epoch and return the mean batch loss.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(x, edge_index, edge_index_out)`` once per graph.
    device
        Device the model and every tensor of a batch are moved to.
    data_loader : iterable
        Yields ``(batch_x, batch_edge_index, batch_edge_index_out, batch_y)``,
        each a sequence with one tensor per graph. ``batch_x`` is combined
        with ``torch.stack``, so its tensors must share one shape.
    optimizer
        Zeroed before and stepped after every batch.
    criterion
        Called as ``criterion(output.squeeze(), y.squeeze())`` per graph.

    Returns
    -------
    torch.Tensor
        Mean over all batches of the per-batch mean loss, detached from the
        autograd graph.

    Raises
    ------
    ZeroDivisionError
        If ``data_loader`` yields no batch, or a batch contains no graph.
    """
    model.train()
    model.to(device)
    totalLossPerEpoch = []
    for batch_x, batch_edge_index, batch_edge_index_out, batch_y in data_loader:
        batch_x = torch.stack(batch_x).to(device)
        batch_edge_index = [edge_index.to(device) for edge_index in batch_edge_index]
        batch_edge_index_out = [edge_index.to(device) for edge_index in batch_edge_index_out]
        batch_y = [y.to(device) for y in batch_y]

        optimizer.zero_grad()
        loss_per_batch = []
        for i in range(len(batch_edge_index)):
            _output = model(batch_x[i], batch_edge_index[i], batch_edge_index_out[i])
            loss_per_batch.append(criterion(_output.squeeze(), batch_y[i].squeeze()))

        # One backward pass on the mean loss of all graphs in the batch.
        total_loss_per_batch = sum(loss_per_batch) / len(loss_per_batch)
        total_loss_per_batch.backward()
        optimizer.step()

        # Only the value is needed for reporting. Detaching stops the epoch
        # accumulator from holding on to every batch's autograd graph.
        totalLossPerEpoch.append(total_loss_per_batch.detach())

    total_loss_per_epoch = sum(totalLossPerEpoch)/len(totalLossPerEpoch)
    print("total_loss_per_epoch:",total_loss_per_epoch)
    return total_loss_per_epoch
#Creates a method to train the GNN and return the loss for that epoch.

In [45]:
def testModel(model, device, data_loader_true, data_loader_bkg_total, criterion):
    all_scores = []
    true_labels = []
    totalLossPerEpochTestTrue = []
    totalLossPerEpochTestBackground = []

    with torch.no_grad():
        model.eval()
        model.to(device)
        for batch_x, batch_edge_index, batch_edge_index_out, batch_y in data_loader_true:
            batch_x = torch.stack(batch_x).to(device)
            batch_edge_index = [edge_index.to(device) for edge_index in batch_edge_index]
            batch_edge_index_out = [edge_index.to(device) for edge_index in batch_edge_index_out]
            batch_y = [y.to(device) for y in batch_y]
            loss_per_batch = []
            for i in range(len(batch_edge_index)):
                test_edge_scores = model(batch_x[i], batch_edge_index[i], batch_edge_index_out[i])
                loss = criterion(test_edge_scores.squeeze(), batch_y[i].squeeze())
                loss_per_batch.append(loss)
                all_scores.append(test_edge_scores)
                true_labels.append(torch.ones(test_edge_scores.size(0)))
            total_loss_per_batch = sum(loss_per_batch)/len(loss_per_batch)
            totalLossPerEpochTestTrue.append(total_loss_per_batch)
        for batch_x, batch_edge_index, batch_edge_index_out, batch_y in data_loader_bkg_total:
            batch_x = torch.stack(batch_x).to(device)
            batch_edge_index = [edge_index.to(device) for edge_index in batch_edge_index]
            batch_edge_index_out = [edge_index.to(device) for edge_index in batch_edge_index_out]
            batch_y = [y.to(device) for y in batch_y]
            loss_per_batch = []
            for i in range(len(batch_edge_index)):
                test_edge_scores = model(batch_x[i], batch_edge_index[i], batch_edge_index_out[i])
                loss = criterion(test_edge_scores.squeeze(), batch_y[i].squeeze())
                loss_per_batch.append(loss)
                all_scores.append(test_edge_scores)
                true_labels.append(torch.zeros(test_edge_scores.size(0)))
            total_loss_per_batch = sum(loss_per_batch)/len(loss_per_batch)
            totalLossPerEpochTestBackground.append(total_loss_per_batch)
        
    all_scores = torch.cat(all_scores, dim = 0).cpu().numpy()
    true_labels = torch.cat(true_labels, dim = 0).cpu().numpy()
    total_loss_per_epoch_test_true = sum(totalLossPerEpochTestTrue)/len(totalLossPerEpochTestTrue)
    total_loss_per_epoch_test_background = sum(totalLossPerEpochTestBackground)/len(totalLossPerEpochTestBackground)
    return all_scores, true_labels, total_loss_per_epoch_test_true, total_loss_per_epoch_test_background

In [46]:
def pickleData(fileName, data):
    """
    Serialize ``data`` with ``pickle`` into the file ``fileName``.

    The file is opened in binary write mode, so existing content is
    overwritten.
    """
    # ``with`` closes the file even if ``pickle.dump`` raises, for example on
    # an unpicklable object. The explicit ``close()`` was skipped in that case.
    with open(fileName, 'wb') as file:
        pickle.dump(data, file)

In [47]:
def makeTruthArray(data, truth):
    """
    Build a constant label tensor sized after ``data``.

    The tensor is filled with ones when ``truth`` is truthy and with zeros
    otherwise; its shape is ``getTruthArrayShape(data.shape)``.
    """
    fill = torch.ones if truth else torch.zeros
    return fill(getTruthArrayShape(data.shape))

In [48]:
def getTruthArrayShape(shape):
    """Return ``shape`` as a 3-tuple of ints with the middle dimension reduced by one."""
    first, middle, last = shape[0], shape[1], shape[2]
    return (int(first), int(middle - 1), int(last))