In [39]:
import numpy as np
import scipy.spatial.distance as sd
from neighborhood import neighbor_graph, laplacian
from correspondence import Correspondence
from stiefel import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from datareader import *
import pandas as pd 
import os.path
import pdb
#cuda = torch.device('cuda') 
import scipy as sp
from collections import Counter
import seaborn as sns
from random import sample
import random
from sklearn import preprocessing
import matplotlib.pyplot as plt
%matplotlib inline

In [40]:
"""Defines the neural network"""

class Net(nn.Module):
    def __init__(self, D_in, H1, H2, D_out):
        super(Net, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H1)
        self.linear2 = torch.nn.Linear(H1, H2)
        self.linear3 = torch.nn.Linear(H2, D_out)

    def forward(self, x):
        h1_sigmoid = self.linear1(x).sigmoid()
        h2_sigmoid = self.linear2(h1_sigmoid).sigmoid()
        y_pred = self.linear3(h2_sigmoid)
        return y_pred

In [41]:
def train_and_project(x1_np, x2_np):
    
    torch.manual_seed(0)

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H1, H2, D_out = x1_np.shape[0], x1_np.shape[1], 512, 64, 3

    model = Net(D_in, H1, H2, D_out)

    x1 = torch.from_numpy(x1_np.astype(np.float32))
    x2 = torch.from_numpy(x2_np.astype(np.float32))
    print(x1.dtype)
    
    adj1 = neighbor_graph(x1_np, k=5)
    adj2 = neighbor_graph(x2_np, k=5)

    #corr = Correspondence(matrix=np.eye(N))

    w1 = np.corrcoef(x1, x2)[0:x1.shape[0],x1.shape[0]:(x1.shape[0]+x2.shape[0])]
    w1[abs(w1) > 0.5] = 1
    w1[w1 != 1] = 0
    w = np.block([[w1,adj1],
                  [adj2,w1.T]])

    L_np = laplacian(w, normed=False)
    L = torch.from_numpy(L_np.astype(np.float32))
    
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.00001)
    
    for t in range(500):
        # Forward pass: Compute predicted y by passing x to the model
        y1_pred = model(x1)
        y2_pred = model(x2)

        outputs = torch.cat((y1_pred, y2_pred), 0)
        
        # Project the output onto Stiefel Manifold
        u, s, v = torch.svd(outputs, some=True)
        proj_outputs = u@v.t()

        # Compute and print loss
        print(L.dtype)
        loss = torch.trace(proj_outputs.t()@L@proj_outputs)
        print(t, loss.item())

        # Zero gradients, perform a backward pass, and update the weights.
        proj_outputs.retain_grad()

        optimizer.zero_grad()
        loss.backward(retain_graph=True)

        # Project the (Euclidean) gradient onto the tangent space of Stiefel Manifold (to get Rimannian gradient)
        rgrad = proj_stiefel(proj_outputs, proj_outputs.grad) 

        optimizer.zero_grad()
        # Backpropogate the Rimannian gradient w.r.t proj_outputs
        proj_outputs.backward(rgrad)

        optimizer.step()
        
    proj_outputs_np = proj_outputs.detach().numpy()
    return proj_outputs_np

In [42]:
Efeature = pd.read_csv('../data/efeature_filtered.csv',index_col=0)
geneExp = pd.read_csv('../data/expMat_filtered_5000.csv',index_col=0)
label = pd.read_csv('../data/label_visual.csv')
print('Shape of geneExp: ', geneExp.shape)
print('Shape of Efeature: ', Efeature.shape)

#x1_np = preprocessing.scale(np.log(geneExp+1).to_numpy())
#x2_np = preprocessing.scale(Efeature.T.to_numpy())

x1_np = np.log(geneExp+1).to_numpy()[0:500]
x2_np = preprocessing.scale(Efeature.T.to_numpy())

print(x1_np.shape)
print(x2_np.shape)

Shape of geneExp:  (5000, 3654)
Shape of Efeature:  (3654, 41)
(500, 3654)
(41, 3654)


In [43]:
projections = train_and_project(x1_np, x2_np)

torch.float32
torch.float32
0 15.873481750488281
torch.float32
1 15.023713111877441
torch.float32
2 14.522339820861816
torch.float32
3 13.960118293762207
torch.float32
4 13.435967445373535
torch.float32
5 13.023812294006348
torch.float32
6 12.69229507446289
torch.float32
7 12.383977890014648
torch.float32
8 12.074462890625
torch.float32
9 11.763294219970703
torch.float32
10 11.461114883422852
torch.float32
11 11.177868843078613
torch.float32
12 10.910758972167969
torch.float32
13 10.648717880249023
torch.float32
14 10.39211368560791
torch.float32
15 10.156060218811035
torch.float32
16 9.949861526489258
torch.float32
17 9.763043403625488
torch.float32
18 9.578285217285156
torch.float32
19 9.389330863952637
torch.float32
20 9.200474739074707
torch.float32
21 9.016444206237793
torch.float32
22 8.838644027709961
torch.float32
23 8.667415618896484
torch.float32
24 8.503211975097656
torch.float32
25 8.347084999084473
torch.float32
26 8.200505256652832
torch.float32
27 8.062524795532227
torch

torch.float32
228 2.467444658279419
torch.float32
229 2.461538314819336
torch.float32
230 2.456616163253784
torch.float32
231 2.4525694847106934
torch.float32
232 2.4472692012786865
torch.float32
233 2.4423043727874756
torch.float32
234 2.438297748565674
torch.float32
235 2.4335763454437256
torch.float32
236 2.428725242614746
torch.float32
237 2.4246859550476074
torch.float32
238 2.420468330383301
torch.float32
239 2.415837049484253
torch.float32
240 2.411689281463623
torch.float32
241 2.4078099727630615
torch.float32
242 2.403571367263794
torch.float32
243 2.3993799686431885
torch.float32
244 2.3955891132354736
torch.float32
245 2.3917620182037354
torch.float32
246 2.3877315521240234
torch.float32
247 2.383899211883545
torch.float32
248 2.3802998065948486
torch.float32
249 2.3765978813171387
torch.float32
250 2.37282133102417
torch.float32
251 2.3692235946655273
torch.float32
252 2.365762948989868
torch.float32
253 2.362238645553589
torch.float32
254 2.3586912155151367
torch.float32
2

torch.float32
454 2.1035964488983154
torch.float32
455 2.1033718585968018
torch.float32
456 2.1031441688537598
torch.float32
457 2.1029176712036133
torch.float32
458 2.1026947498321533
torch.float32
459 2.1024768352508545
torch.float32
460 2.1022655963897705
torch.float32
461 2.1020588874816895
torch.float32
462 2.1018526554107666
torch.float32
463 2.101646900177002
torch.float32
464 2.1014416217803955
torch.float32
465 2.1012377738952637
torch.float32
466 2.1010348796844482
torch.float32
467 2.1008377075195312
torch.float32
468 2.1006405353546143
torch.float32
469 2.100449562072754
torch.float32
470 2.1002585887908936
torch.float32
471 2.1000709533691406
torch.float32
472 2.099884271621704
torch.float32
473 2.099701166152954
torch.float32
474 2.0995194911956787
torch.float32
475 2.099339723587036
torch.float32
476 2.0991599559783936
torch.float32
477 2.0989842414855957
torch.float32
478 2.098809003829956
torch.float32
479 2.0986363887786865
torch.float32
480 2.098466157913208
torch.fl

In [45]:
projections.shape

(541, 3)

In [46]:
projections = pd.DataFrame(projections)
features = geneExp.index.tolist()[0:500]+Efeature.columns.tolist()
projections.index = features
projections

Unnamed: 0,0,1,2
Adarb2,-0.026649,0.005948,-0.010327
Sst,-0.009147,0.016273,-0.017417
Vip,-0.026710,0.016914,0.000659
Npy,-0.023905,0.020164,-0.008712
Synpr,-0.011319,0.032751,-0.017051
...,...,...,...
fast_trough_v_short_square,-0.011954,0.022355,-0.010206
fast_trough_t_short_square,-0.014934,0.015737,-0.020701
threshold_v_short_square,-0.011773,0.018150,-0.017844
threshold_i_short_square,-0.009439,0.010755,-0.028545


In [47]:
projections.to_csv("../data/deepmanreg_latent_500.csv")