In [1]:
%%capture
!pip install scprep
!pip install anndata
!pip install scanpy

In [2]:
import numpy as np
import pandas as pd
import anndata
import scprep
import scanpy as sc
import sklearn
from sklearn.model_selection import train_test_split
import tempfile
import os
import sys
import scipy
from scipy import sparse

import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

import load_raw
import normalize_tools as nm
import metrics

In [3]:
#set up all hyper-parameters
hyper = {
    "nEpochs":120,
    "dimRNA":3633,
    "dimATAC":4403,
    "n_hidden":1024,
    "layer_sizes":[1024, 1024, 1024, 256, 256],
    "nz":128,
    "batchSize":512,
    "lr":1e-3,
    "lamb_kl":1e-9,
    "lamb_anc":1e-9,
    "clip_grad":0.1,
    "weightDirName": './checkpoint/',
}

In [4]:
torch.cuda.is_available()

  return torch._C._cuda_getDeviceCount() > 0


False

# **try out with scicar cell lines dataset**

**1. URLs for raw data**

In [5]:
rna_data, atac_data, rna_cells, atac_cells, rna_genes, atac_genes = load_raw.load_raw_cell_lines()

In [6]:
scicar_data, joint_index, keep_cells_idx = load_raw.merge_data(rna_data, atac_data, rna_cells, atac_cells, rna_genes, atac_genes)
#rna_df, atac_df = ann2df(scica|r_data)|

In [7]:
#tryout log cpm scicar_data
nm.log_cpm(scicar_data)
nm.log_cpm(scicar_data, obsm = "mode2", obs = "mode2_obs", var = "mode2_var")
nm.hvg_by_sc(scicar_data, proportion = 0.06)
nm.hvg_by_sc(scicar_data, obsm = "mode2", obs = "mode2_obs", 
             var = "mode2_var", proportion = 0.03)

In [8]:
scicar_data.uns["mode2_obs"] = np.array(scicar_data.uns["mode2_obs"][0])

In [9]:
scicar_data.uns["mode2_var"] = np.array(scicar_data.uns["mode2_var"][0])

In [10]:
scicar_data.uns = {"mode2_obs": scicar_data.uns["mode2_obs"], "mode2_var": scicar_data.uns["mode2_var"]}

In [11]:
train_data, test_data = load_raw.train_test_split(scicar_data)

In [12]:
test_data.X

<1422x3633 sparse matrix of type '<class 'numpy.float32'>'
	with 68648 stored elements in Compressed Sparse Row format>

In [13]:
train_data.obsm["mode2"]

<3317x4403 sparse matrix of type '<class 'numpy.float32'>'
	with 15736 stored elements in Compressed Sparse Row format>

# **define pytorch datasets for RNA and ATAC**

In [14]:
class Merge_Dataset(Dataset):
    def __init__(self, adata):
        self.rna_data, self.atac_data = self._load_merge_data(adata)

    def __len__(self):
        #assert(len(self.rna_data) == len(self.atac_data))
        return len(self.atac_data)
  
    def __getitem__(self, idx):
        rna_sample = self.rna_data.values[idx]
        atac_sample = self.atac_data.values[idx]
        #return a tensor that for a single observation
        return torch.from_numpy(rna_sample).float(), torch.from_numpy(atac_sample).float()
  
    def _load_merge_data(self, adata):
        rna_df = pd.DataFrame(data = adata.X.toarray(), index = np.array(adata.obs.index), columns = np.array(adata.var.index))
        atac_df = pd.DataFrame(data = adata.obsm["mode2"].toarray(), index = np.array(adata.uns["mode2_obs"]), columns = np.array(adata.uns["mode2_var"]))
        return rna_df, atac_df

# **define basic models(autoencoders) for learning latent space**

In [15]:
class Encoder(nn.Module):
    def __init__(self, n_input, n_latent, layer_sizes):
        super(Encoder, self).__init__()
        self.n_input = n_input
        self.n_latent = n_latent
        self.layer_sizes = [n_input] + layer_sizes + [n_latent]
        self.encoder_layers = []
        
        for idx in range(len(self.layer_sizes) - 1):
            fc1 = nn.Linear(self.layer_sizes[idx], self.layer_sizes[idx + 1])
            nn.init.xavier_uniform_(fc1.weight)
            self.encoder_layers.append(fc1)
            bn1 = nn.BatchNorm1d(self.layer_sizes[idx + 1])
            self.endocer_layers.append(bn1)
            act1 = nn.PReLU()
            self.encoder_layers.append(act1)
            
        self.encoder = nn.Sequential(*self.encoder_layers)
        
        def forward(self, x):
            return(self.encoder(x))

In [16]:
class Decoder(nn.Module):
    def __init__(self, n_output, n_latent, layer_sizes, final_activation=None):
        self.n_output = n_output
        self.n_latent = n_latent
        self.layer_sizes = [n_output] + layer_sizes + [n_latent]
        self.decoder1_layers = []
        for idx in range(len(self.layer_sizes) - 1, 1,  -1):
            fc1 = nn.Linear(self.layer_sizes[idx], self.layer_sizes[idx - 1])
            nn.init.xavier_uniform_(fc1.weight)
            self.decoder1_layers.append(fc1)
            bn1 = nn.BatchNorm1d(self.layer_sizes[idx - 1])
            self.decoder1_layers.append(bn1)
            act1 = nn.PReLU()
            self.decoder1_layers.append(act1)
        self.decoder1 = nn.Sequential(self.decoder1_layers)
        
        self.n_inter = self.layer_sizes[len(self.layer_sizes)-2]
        self.decoder21 = nn.Linear(self.n_inter, self.n_output)
        nn.init.xavier_uniform_(self.decoder21.weight)
        self.decoder22 = nn.Linear(self.n_inter, self.n_output)
        nn.init.xavier_uniform_(self.decoder22.weight)
        self.decoder23 = nn.Linear(self.n_inter, self.n_output)
        nn.init.xavier_uniform_(self.decoder23.weight)
        
        self.final_activations = nn.ModuleDict()
        if final_activation is not None:
            if isinstance(final_activation, list) or isinstance(final_activation, tuple):
                for i, act in enumerate(final_activation):
                    self.final_activations[f"act{i+1}"] = act
            elif isinstance(final_activation, nn.Module):
                self.final_activations["act1"] = final_activation
            else:
                raise ValueError(
                    f"Unrecognized type for final_activation: {type(final_activation)}"
                )
            
        def forward(self, x):
            x = self.decoder1(x)
            
            retval1 = self.decoder21(x)
            if "act1" in self.final_activations.keys():
                retval1 = self.final_activations["act1"](retval1)
            
            retval2 = self.decoder22(x)
            if "act2" in self.final_activations.keys():
                retval2 = self.final_activations["act2"](retval2)
            
            retval3 = self.decoder23(x)
            if "act3" in self.final_activations.keys():
                retval3 = self.final_activations["act3"](retval3)
                
            return retval1, retval2, retval3

In [17]:
class FC_VAE(nn.Module):
    def __init__(self, n_input, nz, n_hidden=hyper["n_hidden"], layer_sizes=hyper["layer_sizes"]):
        super(FC_VAE, self).__init__()
        self.n_input = n_input
        self.nz = nz
        self.n_hidden = n_hidden
        self.layer_sizes = layer_sizes

        self.encoder_layers = []

        self.encoder_layers.append(nn.Linear(n_input, self.layer_sizes[0]))
        self.encoder_layers.append(nn.LeakyReLU(inplace=True))
        self.encoder_layers.append(nn.BatchNorm1d(self.layer_sizes[0]))

        for layer_idx in range(len(layer_sizes)-1):
            if layer_idx == len(layer_sizes) - 2:
                self.encoder_layers.append(nn.Linear(self.layer_sizes[layer_idx], self.layer_sizes[layer_idx+1]))
            else:
                self.encoder_layers.append(nn.Linear(self.layer_sizes[layer_idx], self.layer_sizes[layer_idx+1]))
                self.encoder_layers.append(nn.BatchNorm1d(self.layer_sizes[layer_idx+1]))
                self.encoder_layers.append(nn.LeakyReLU(inplace=True))

        self.encoder = nn.Sequential(
            *self.encoder_layers
        )
        self.fc1 = nn.Linear(self.layer_sizes[-1], nz)
        self.fc2 = nn.Linear(self.layer_sizes[-1], nz)

        self.decoder_layers = []
        self.decoder_layers.append(nn.Linear(nz, self.layer_sizes[-1]))
        self.decoder_layers.append(nn.LeakyReLU(inplace=True))
        self.decoder_layers.append(nn.BatchNorm1d(self.layer_sizes[-1]))

        for layer_idx in range(len(self.layer_sizes)-1, 0, -1):
            self.decoder_layers.append(nn.Linear(self.layer_sizes[layer_idx], self.layer_sizes[layer_idx-1]))
            self.decoder_layers.append(nn.LeakyReLU(inplace=True))
            self.decoder_layers.append(nn.BatchNorm1d(self.layer_sizes[layer_idx-1]))

        self.decoder_layers.append(nn.Linear(self.layer_sizes[0], self.n_input))

        self.decoder = nn.Sequential(
            *self.decoder_layers
        )
    def encode(self, x):
        h = self.encoder(x)
        return self.fc1(h), self.fc2(h)

    def reparametrize(self, mu, logvar):
        #calculate std from log(var)
        std = logvar.mul(0.5).exp_()
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)
    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, z, mu, logvar

    def get_latent_var(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z
    
    def generate(self, z):
        return self.decode(z)

# **train VAE model based on reconstruction, KL divergence, and anchor loss**

In [18]:
def collate_fn(batch):
    n_svd = 100
    proportion_neighbors = 0.1
    
    rna_inputs, atac_inputs = zip(*batch)
    rna_inputs = torch.stack(rna_inputs)
    atac_inputs = torch.stack(atac_inputs)
    
    n_svd = min([n_svd, min(rna_inputs.shape) - 1])
    n_neighbors = int(np.ceil(proportion_neighbors * rna_inputs.shape[0]))
    X_pca = sklearn.decomposition.TruncatedSVD(n_svd).fit_transform(rna_inputs)
    _, indices_true = (
        sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(X_pca).kneighbors(X_pca)
    )
    
    return rna_inputs, atac_inputs, torch.from_numpy(indices_true)

In [19]:
#load dataset and split train and test data
def get_data_loaders(train_data, test_data):
    train_set = Merge_Dataset(train_data)
    test_set = Merge_Dataset(test_data)
    #load data loader
    train_loader = DataLoader(
        train_set, 
        batch_size=hyper["batchSize"], 
        collate_fn=collate_fn, 
        drop_last=False, 
        shuffle=True
    )
    test_loader = DataLoader(
        test_set, 
        batch_size=test_data.shape[0], 
        collate_fn=collate_fn, 
        drop_last=False,
        shuffle=False
    )
    return train_loader, test_loader

train_loader, test_loader = get_data_loaders(train_data=train_data, test_data=test_data)

In [20]:
#load basic models
netRNA = FC_VAE(n_input=hyper["dimRNA"], nz=hyper["nz"], layer_sizes=hyper["layer_sizes"])
netATAC = FC_VAE(n_input=hyper["dimATAC"], nz=hyper["nz"], layer_sizes=hyper["layer_sizes"])

In [21]:
#use GPU
if torch.cuda.is_available():
    print("using GPU")
    netRNA.cuda()
    netATAC.cuda()
    device = "cuda"
else:
    device = "cpu"
#setup optimizers for two nets
opt_netRNA = optim.Adam(list(netRNA.parameters()), lr=hyper["lr"])
opt_netATAC = optim.Adam(list(netATAC.parameters()), lr=hyper["lr"])

In [22]:
class StructureHingeLoss(nn.Module):
    def __init__(self, margin, max_val, lamb_match, lamb_nn):
        super(StructureHingeLoss, self).__init__()
        self.margin = margin
        self.max_val = max_val
        self.lamb_match = lamb_match
        self.lamb_nn = lamb_nn
        
    def forward(self, rna_outputs, atac_outputs, nn_indices):
        #rna_outputs: n_batch x n_latent
        #atac_outputs: n_batch x n_latent
        assert rna_outputs.shape[0] == atac_outputs.shape[0]
        assert rna_outputs.shape[1] == atac_outputs.shape[1]
        n_batch = rna_outputs.shape[0]
        
        #calculated pairwise L2 distance
        #dist_rna_atac[i][j]: the L2 distance between RNA embedding i
        #and ATAC embedding j (n_batch x n_batch)
        #constraint for ensuring every rna embedding is close to matched atac embedding
        dist_rna_atac = torch.cdist(rna_outputs, atac_outputs, p=2)
        match_labels = torch.eye(n_batch)
        match_mask = match_labels > 0
        pos_match_dist = torch.masked_select(dist_rna_atac, match_mask).view(n_batch, 1)
        neg_match_dist = torch.masked_select(dist_rna_atac, ~match_mask).view(n_batch, -1)
        
        loss_match_rna = torch.clamp(self.margin + pos_match_dist - neg_match_dist, 0, self.max_val)
        loss_match_rna = loss_match_rna.mean()
        #print(f"loss_match_rna: {loss_match_rna}")
        
        #constraint for ensuring every atac embedding is close to matched rna embedding
        dist_atac_rna = dist_rna_atac.t()
        pos_match_dist = torch.masked_select(dist_atac_rna, match_mask).view(n_batch, 1)
        neg_match_dist = torch.masked_select(dist_atac_rna, ~match_mask).view(n_batch, -1)
        
        loss_match_atac = torch.clamp(self.margin + pos_match_dist - neg_match_dist, 0, self.max_val)
        loss_match_atac = loss_match_rna.mean()
        #print(f"loss_match_atac: {loss_match_atac}")
        
        #constraint for ensuring that every RNA embedding is close to 
        #the neighboring RNA embeddings.
        nn_masked = torch.zeros(n_batch, n_batch)
        nn_masked.scatter_(1, nn_indices, 1.)
        nn_masked = nn_masked > 0
        
        dist_rna_rna = torch.cdist(rna_outputs, rna_outputs, p=2)
        
        #pos_rna_nn_dist: n_batch x n_neighbor
        pos_rna_nn_dist = torch.masked_select(dist_rna_rna, nn_masked).view(n_batch, -1)
        neg_rna_nn_dist = torch.masked_select(dist_rna_rna, ~nn_masked).view(n_batch, -1)
        rna_nn_loss = torch.clamp(self.margin + pos_rna_nn_dist[...,None] - neg_rna_nn_dist[..., None, :], 0, self.max_val)
        rna_nn_loss = rna_nn_loss.mean()
        #print(f"rna_nn_loss: {rna_nn_loss}")
        
        #constraint for ensuring that every ATAC embedding is close to 
        #the neighboring ATAC embeddings.
        dist_atac_atac = torch.cdist(atac_outputs, atac_outputs, p=2)
        #pos_rna_nn_dist: n_batch x n_neighbor
        pos_atac_nn_dist = torch.masked_select(dist_atac_atac, nn_masked).view(n_batch, -1)
        neg_atac_nn_dist = torch.masked_select(dist_atac_atac, ~nn_masked).view(n_batch, -1)
        atac_nn_loss = torch.clamp(self.margin + pos_atac_nn_dist[...,None] - neg_atac_nn_dist[..., None, :], 0, self.max_val)
        atac_nn_loss = atac_nn_loss.mean()
        #print(f"atac_nn_loss: {atac_nn_loss}")
        
        loss = (self.lamb_match * loss_match_rna 
                + self.lamb_match * loss_match_atac
                + self.lamb_nn * rna_nn_loss 
                + self.lamb_nn * atac_nn_loss)
        return loss

In [23]:
#set up loss function
def basic_loss(recon_x, x, mu, logvar, lamb1):
    MSE = nn.MSELoss()
    lloss = MSE(recon_x, x)
    #KL divergence
    KL_loss = -0.5*torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    lloss = lloss + lamb1*KL_loss
    return lloss

#anchor loss for minimizing distance between paired observation
def anchor_loss(embed_rna, embed_atac):
    L1 = nn.L1Loss()
    anc_loss = L1(embed_rna, embed_atac)
    return anc_loss

def hinge_loss(
    margin, 
    max_val, 
    lamb_match,
    lamb_nn, 
    embed_rna, 
    embed_atac, 
    nn_indices,
):
    Hinge_Loss = StructureHingeLoss(margin, max_val, lamb_match, lamb_nn)
    loss = Hinge_Loss(embed_rna, embed_atac, nn_indices)
    return loss

In [24]:
def knn_criteria(rna_inputs, atac_inputs, rna_outputs, atac_outputs, proportion_neighbors=0.1, n_svd=100):
    n_svd = min([n_svd, min(rna_inputs.shape)-1])
    n_neighbors = int(np.ceil(proportion_neighbors*rna_inputs.shape[0]))
    X_pca = sklearn.decomposition.TruncatedSVD(n_svd).fit_transform(rna_inputs)
    _, indices_true = (
        sklearn.neighbors.NearestNeighbors(n_neighbors = n_neighbors).fit(X_pca).kneighbors(X_pca)
    )
    _, indices_pred = (
        sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(rna_outputs).kneighbors(atac_outputs)
    )
    neighbors_match = np.zeros(n_neighbors, dtype=int)
    for i in range(rna_inputs.shape[0]):
        _, pred_matches, true_matches = np.intersect1d(
            indices_pred[i], indices_true[i], return_indices=True
        )
        neighbors_match_idx = np.maximum(pred_matches, true_matches)
        neighbors_match += np.sum(np.arange(n_neighbors) >= neighbors_match_idx[:, None], axis = 0,)
    neighbors_match_curve = neighbors_match/(np.arange(1, n_neighbors + 1) * rna_inputs.shape[0])
    area_under_curve = np.mean(neighbors_match_curve)
    return area_under_curve

In [25]:
#set up train functions
def train(epoch):
    netRNA.train()
    netATAC.train()
    train_losses = []
    for idx, samples in enumerate(train_loader):
        rna_inputs, atac_inputs, nn_indices = Variable(samples[0]), Variable(samples[1]), samples[2]
        if torch.cuda.is_available():
            rna_inputs = rna_inputs.cuda()
            atac_inputs = atac_inputs.cuda()
            
        opt_netATAC.zero_grad()
        opt_netRNA.zero_grad()
        recon_rna, z_rna, mu_rna, logvar_rna = netRNA(rna_inputs)
        recon_atac, z_atac, mu_atac, logvar_atac = netATAC(atac_inputs)
        rna_loss = basic_loss(recon_rna, rna_inputs, mu_rna, logvar_rna, lamb1=hyper["lamb_kl"])
        atac_loss = basic_loss(recon_atac, atac_inputs, mu_atac, logvar_atac, lamb1=hyper["lamb_kl"])
        
        h_loss = hinge_loss(
            margin=0.2, 
            max_val=1e6, 
            lamb_match=1,
            lamb_nn=0.5,
            embed_rna=z_rna, 
            embed_atac=z_atac, 
            nn_indices=nn_indices,
        )
        '''if epoch % 5 == 0:
            print(f"rna_loss: {rna_loss}")
            print(f"atac_loss:{atac_loss}")
            print(f"anc_loss: {anc_loss}")
            print(f"hinge loss: {h_loss}")'''
        
        #loss functions for each modalities
        train_loss = rna_loss + atac_loss + h_loss
        #train_loss = h_loss
        #train_loss = rna_loss + atac_loss + hyper["lamb_anc"] * anc_loss
        #rain_loss = rna_loss + atac_loss + hyper["lamb_anc"] * anc_loss + h_loss
        train_loss.backward()
        nn.utils.clip_grad_norm_(netRNA.parameters(), max_norm=hyper["clip_grad"])
        nn.utils.clip_grad_norm_(netATAC.parameters(), max_norm=hyper["clip_grad"])
        opt_netRNA.step()
        opt_netATAC.step()
        train_losses.append(train_loss.item())
    if epoch % 5 == 0:
        print("Epoch: " + str(epoch) + ", train loss: " + str(np.mean(train_losses)))

In [26]:
def evaluate(epoch):
    #evaluating step
    with torch.no_grad():
        netRNA.eval()
        netATAC.eval()
        knn_acc = []
        #mse_acc = []
        for idx, samples in enumerate(test_loader):
            rna_inputs = samples[0].float()
            atac_inputs = samples[1].float()
            rna_inputs = rna_inputs.to(device)
            atac_inputs = atac_inputs.to(device)

            _, output_rna, _, _ = netRNA(rna_inputs)
            _, output_atac, _, _ = netATAC(atac_inputs)
            knn_acc.append(knn_criteria(rna_inputs.cpu().detach(), atac_inputs.cpu().detach(), 
                                        output_rna.cpu().detach(), output_atac.cpu().detach()))
        avg_knn_auc = np.mean(knn_acc)
    if epoch % 5 == 0:
        print("Epoch: " + str(epoch) + ", acc: " + str(avg_knn_auc))

In [27]:
#train a toy model and see the scores
max_iter = hyper["nEpochs"]
for epoch in range(max_iter):
    train(epoch)
    evaluate(epoch)
  #set up log
  #if epoch % 50 == 0:
    #print("***saving checkpoints***")
    #path = "{}Max_iter_{}lamb_anc_{}Epoch_{}params.pth".format(hyper["weightDirName"], str(hyper["nEpochs"]), str(hyper["lamb_anc"]), str(epoch))
    
    #torch.save({
    #    "epoch": epoch,
    #    'netRNA_state_dict': netRNA.state_dict(),
    #    'netATAC_state_dict': netATAC.state_dict(),
    # }, path)


Epoch: 0, train loss: 2.5642493792942593
Epoch: 0, acc: 0.08530599526387524
Epoch: 5, train loss: 0.7254585112844195
Epoch: 5, acc: 0.19060071915223933
Epoch: 10, train loss: 0.5740096313612801
Epoch: 10, acc: 0.18146362370125327
Epoch: 15, train loss: 0.5465811320713588
Epoch: 15, acc: 0.20300014506820613
Epoch: 20, train loss: 0.5254029291016715
Epoch: 20, acc: 0.2237362471747925
Epoch: 25, train loss: 0.5062948976244245
Epoch: 25, acc: 0.23392157118731247
Epoch: 30, train loss: 0.493059550012861
Epoch: 30, acc: 0.24246833663124398
Epoch: 35, train loss: 0.482784628868103
Epoch: 35, acc: 0.2479453977558715
Epoch: 40, train loss: 0.4709799417427608
Epoch: 40, acc: 0.2434553108098933
Epoch: 45, train loss: 0.45106356910296846
Epoch: 45, acc: 0.25968260116577185
Epoch: 50, train loss: 0.4330393544265202
Epoch: 50, acc: 0.2624669185834086
Epoch: 55, train loss: 0.4121246039867401
Epoch: 55, acc: 0.25749721364258243
Epoch: 60, train loss: 0.3933856614998409
Epoch: 60, acc: 0.2614627820497

In [28]:
def model_eval(test_adata):
    netRNA.eval()
    netATAC.eval()
    rna_inputs = Variable(torch.from_numpy(test_adata.X.toarray()).float())
    atac_inputs = Variable(torch.from_numpy(test_adata.obsm["mode2"].toarray()).float())
    if torch.cuda.is_available():
        rna_inputs = rna_inputs.cuda()
        atac_inputs = atac_inputs.cuda()
    _, z_rna, _, _ = netRNA(rna_inputs)
    _, z_atac, _, _ = netATAC(atac_inputs)
    test_adata.obsm["aligned"] = sparse.csr_matrix(z_rna.cpu().detach())
    test_adata.obsm["mode2_aligned"] = sparse.csr_matrix(z_atac.cpu().detach())
    knn_score, mse_score = metrics.knn_auc(test_adata), metrics.mse(test_adata)
    return knn_score, mse_score

In [29]:
#test knn_auc plateau at around 0.09, seems that training starts to overfit
test_knn_score, test_mse_score = model_eval(test_data)
print(test_knn_score)
print(test_mse_score)
train_knn_score, train_mse_score = model_eval(train_data)
print(train_knn_score)
print(train_mse_score)

0.2638007403655664
0.9727991
0.3449979906182975
0.6296172


In [None]:
#log the metrics
path = "{}Max_iter_{}lamb_anc_{}metrics.txt".format(hyper["weightDirName"], str(hyper["nEpochs"]), str(hyper["lamb_anc"]))
'''torch.save({
    "num_iter": hyper["nEpochs"],
    "lamb_anc": hyper["lamb_anc"],
    'knn_auc': knn_score,
    'mse': mse_score,
}, path)'''

In [None]:
with open(path, 'a') as f:
        print('nEpoch: ', hyper["nEpochs"], 'lamb_anc:%.8f'%float(hyper["lamb_anc"]) , ',knn_auc: %.8f' % float(knn_score), ', mse_score: %.8f' % float(mse_score), file=f)