In [1]:
import numpy as np
import torch
import torch.utils.data
import matplotlib
import os
import sys
matplotlib.use('Agg')

import aitac
import plot_utils

import time
from sklearn.model_selection import train_test_split


In [2]:
# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [3]:
# Hyper parameters
num_classes = 141
batch_size = 10
num_filters = 300


In [4]:
#create output figure directory
model_name = 'mini_sample'
output_file_path = "../outputs/" + model_name + "/motifs/"
directory = os.path.dirname(output_file_path)
if not os.path.exists(directory):
    os.makedirs(directory)



In [5]:
# Load all data
x = np.load('../BRCA_data/mini_sample_one_hot_seqs.npy')
x = x.astype(np.float32)
y = np.load('../BRCA_data/mini_sample_cell_type_array.npy')
y = y.astype(np.float32)
peak_names = np.load('../BRCA_data/mini_sample_peak_names.npy')


In [6]:
# Data loader
dataset = torch.utils.data.TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)


In [7]:
# load trained model
model = aitac.ConvNet(num_classes, num_filters).to(device)
checkpoint = torch.load('../models/' + model_name + '.ckpt')
model.load_state_dict(checkpoint)


<All keys matched successfully>

In [8]:
#copy trained model weights to motif extraction model
motif_model = aitac.motifCNN(model).to(device)
motif_model.load_state_dict(model.state_dict())


<All keys matched successfully>

In [9]:
# run predictions with full model on all data
pred_full_model, max_activations, activation_idx = aitac.test_model(data_loader, model, device)
correlations = plot_utils.plot_cors(y, pred_full_model, output_file_path)


weighted_cor is 0.6108408034010132
number of NaN values: 0


In [10]:
# find well predicted OCRs
idx = np.argwhere(np.asarray(correlations)>0.75).squeeze()


In [11]:
#get data subset for well predicted OCRs to run further test
x2 = x[idx, :, :]
y2 = y[idx, :]

dataset = torch.utils.data.TensorDataset(torch.from_numpy(x2), torch.from_numpy(y2))
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)


In [12]:
# non-modified results for well-predicted OCRs only
pred_full_model2 = pred_full_model[idx,:]
correlations2 = plot_utils.plot_cors(y2, pred_full_model2, output_file_path)


weighted_cor is 0.7901855047060239
number of NaN values: 0


In [13]:
# get first layer activations and predictions with leave-one-filter-out
start = time.time()
activations, predictions = aitac.get_motifs(data_loader, motif_model, device)
print(time.time()- start)


77.33405542373657


In [14]:
#PLOTING
filt_corr, filt_infl, ave_filt_infl = plot_utils.plot_filt_corr(predictions, y2, correlations2, output_file_path)


could not set the title for graph
Replacement index 1 out of range for positional args tuple
could not set the title for graph
Replacement index 1 out of range for positional args tuple
Shape of filter-wise correlations:
(27, 300)
Shape of filter influence:
(27, 300)


In [15]:
infl, infl_by_OCR = plot_utils.plot_filt_infl(pred_full_model2, predictions, output_file_path)


could not set the title for graph
Replacement index 1 out of range for positional args tuple


In [16]:
pwm, act_ind, nseqs, activated_OCRs, n_activated_OCRs, OCR_matrix = plot_utils.get_memes(activations, x2, y2, output_file_path)


In [17]:
#save predictions
np.save(output_file_path + "filter_predictions.npy", predictions)
np.save(output_file_path + "predictions.npy", pred_full_model)


In [18]:
#save correlations
np.save(output_file_path + "correlations.npy", correlations)
np.save(output_file_path + "correlations_per_filter.npy", filt_corr)


In [19]:
#overall influence:
np.save(output_file_path + "influence.npy", ave_filt_infl)
np.save(output_file_path + "influence_by_OCR.npy", filt_infl)


In [20]:
#influence by cell type:
np.save(output_file_path + "filter_cellwise_influence.npy", infl)
np.save(output_file_path + "cellwise_influence_by_OCR.npy", infl_by_OCR)


In [21]:
#other metrics
np.savetxt(output_file_path + "nseqs_per_filters.txt", nseqs)
np.save(output_file_path + "mean_OCR_activation.npy", activated_OCRs)
np.save(output_file_path + "n_activated_OCRs.npy",  n_activated_OCRs)
np.save(output_file_path + "OCR_matrix.npy", OCR_matrix)


In [22]:
fitl_pred = np.load('../outputs/first_approach/motifs/filter_predictions.npy')
print(fitl_pred.shape)

(4, 300, 81)


In [23]:
pred = np.load('../outputs/first_approach/motifs/predictions.npy')
print(pred.shape)

(100, 81)


In [24]:
corr = np.load('../outputs/first_approach/motifs/correlations.npy')
print(corr.shape)
print(corr)

(100,)
[ 0.47675767  0.40375768  0.43856965  0.34622067  0.34175713  0.33306235
  0.36085143  0.41178661  0.47650863  0.26243808  0.49829282  0.39362301
  0.58005024  0.3121501   0.44273265  0.29923952  0.34399613  0.45722866
  0.04290422  0.20510039  0.23700697  0.40977453  0.39995587  0.49250575
  0.54419714  0.59417629  0.58399674  0.37848552  0.46973537  0.56316622
  0.44795484  0.4284225   0.41034534  0.46075243  0.60010022  0.6027513
  0.43957152  0.66977485  0.32403217  0.35901842  0.17025956  0.6249035
  0.49610941  0.70867305  0.63161055  0.57016048  0.77479902  0.59382748
  0.59704294  0.61372098  0.52449654  0.78738511  0.82606442  0.80597185
  0.5298308   0.67791769  0.32628988  0.62885986  0.69838738  0.65561966
  0.67790521  0.51783238 -0.1070718   0.45097762  0.36749504  0.24292242
  0.45150874  0.57062927  0.4555621   0.48499268  0.49454515  0.39996137
  0.40682     0.59667445  0.51592891  0.1202623   0.44862352  0.67326006
  0.74031491  0.25016936  0.35509721  0.388038

In [26]:
corr_filter = np.load('../outputs/first_approach/motifs/correlations_per_filter.npy')
print(corr_filter.shape)
#print(corr_filter)

(4, 300)


In [27]:
infl = np.load('../outputs/first_approach/motifs/influence.npy')
print(infl.shape)
print(infl)

(300,)
[1.06637769e-07 1.88214652e-07 5.85818634e-08 1.63259709e-07
 2.46407326e-07 1.19791469e-07 2.66905133e-08 2.40893426e-07
 1.98626859e-07 1.01209231e-07 1.33714193e-07 6.95548972e-08
 2.32257736e-07 7.02706632e-08 3.97032948e-07 6.05702951e-08
 4.39679260e-08 1.28310142e-07 2.30410839e-07 5.41441200e-07
 1.47290983e-07 2.11355077e-08 5.34058788e-08 1.00534061e-06
 1.91756856e-07 1.36954842e-07 1.30834185e-07 1.18577096e-07
 1.94126163e-08 3.09706629e-07 8.07221689e-08 1.55142850e-07
 6.35365196e-08 1.31972679e-07 2.76465086e-07 1.01283946e-07
 2.94435637e-08 6.08434691e-08 7.66887126e-08 4.13737483e-07
 7.96533715e-08 2.24572912e-09 9.53475706e-08 5.60325242e-08
 7.45111346e-08 1.17896093e-07 6.97362070e-08 1.54267003e-07
 1.89749798e-07 1.92105551e-07 4.73060523e-08 2.41113845e-08
 5.36700165e-07 1.65246702e-08 9.10137078e-08 1.49676195e-07
 3.40157582e-07 2.16780632e-07 3.24955815e-07 5.45600470e-07
 2.41558239e-07 7.80055891e-08 2.00686572e-07 7.26958767e-08
 8.20406232e-08 1

In [28]:
infl_ocr = np.load('../outputs/first_approach/motifs/influence_by_OCR.npy')
print(infl_ocr.shape)
print(infl_ocr)

(4, 300)
[[1.60859885e-07 1.66444582e-07 1.38686120e-08 ... 5.56852008e-10
  1.73122197e-07 2.33318312e-08]
 [1.39155826e-10 2.37573326e-08 5.31088516e-08 ... 1.61771687e-08
  4.59046313e-08 3.21955786e-08]
 [2.52579880e-07 8.13380674e-09 3.05329318e-08 ... 2.98002981e-08
  2.16931859e-07 3.71149841e-07]
 [1.29721569e-08 5.54522886e-07 1.36817058e-07 ... 7.73108852e-08
  5.25405823e-07 5.83212453e-08]]


In [29]:
filt_cell_infl = np.load('../outputs/first_approach/motifs/filter_cellwise_influence.npy')
print(filt_cell_infl.shape)
print(filt_cell_infl)

(300, 81)
[[2.78251400e-06 2.69042794e-05 1.48753279e-05 ... 1.21165858e-05
  2.51902122e-07 1.52528582e-05]
 [2.00672798e-06 4.20817523e-05 1.65379079e-05 ... 1.70422336e-05
  3.04177007e-07 3.75397758e-05]
 [1.34896652e-06 2.86930172e-05 1.40840957e-05 ... 2.27190558e-05
  1.01004264e-06 2.77765012e-05]
 ...
 [1.96082601e-06 1.88911326e-05 1.48243789e-05 ... 1.84044293e-05
  1.12887392e-06 1.81113719e-05]
 [7.97796929e-06 5.71771379e-05 3.75082309e-05 ... 1.47347837e-05
  9.74385216e-07 5.44135110e-05]
 [2.23765642e-06 1.15168805e-05 8.62323122e-06 ... 6.44846386e-06
  3.56151190e-07 8.70548593e-06]]


In [30]:
cell_infl_ocr = np.load('../outputs/first_approach/motifs/cellwise_influence_by_OCR.npy')
print(cell_infl_ocr.shape)
print(cell_infl_ocr)

(4, 300, 81)
[[[1.08684475e-07 3.33902381e-06 7.97451867e-06 ... 2.96143199e-09
   5.51276798e-07 1.41392438e-05]
  [7.40139058e-06 3.37265847e-05 1.82913773e-05 ... 3.89154138e-07
   3.24758986e-07 4.01642028e-05]
  [1.15642513e-06 1.35533019e-05 7.98125257e-06 ... 5.19621153e-06
   1.49919299e-06 8.86864382e-06]
  ...
  [2.93717835e-06 1.37263251e-05 9.97072493e-06 ... 3.33413851e-07
   1.36498898e-06 1.99882779e-05]
  [5.23296103e-06 8.38103006e-05 3.49504044e-05 ... 1.88634876e-05
   4.57745086e-09 8.04158626e-05]
  [1.56779140e-06 8.68113602e-06 1.37499592e-06 ... 1.26503926e-06
   1.28628717e-08 1.51008753e-05]]

 [[2.58192495e-06 2.59086883e-05 9.64981609e-06 ... 1.79792187e-06
   4.37121344e-08 9.19789727e-06]
  [4.24230429e-09 4.19376120e-05 3.46323259e-06 ... 3.15388897e-05
   1.22488501e-07 4.45000078e-05]
  [1.62084837e-06 8.17582113e-05 2.55681971e-05 ... 7.90369231e-05
   1.05556921e-06 8.46176408e-05]
  ...
  [1.03126894e-08 2.42152564e-05 3.23812819e-05 ... 2.35135412e-

In [31]:
nsew_filter = open('../outputs/first_approach/motifs/nseqs_per_filters.txt','rt')
print(nsew_filter.readline())
print(nsew_filter)
nsew_filter.close()

6.500000000000000000e+01

<_io.TextIOWrapper name='../outputs/first_approach/motifs/nseqs_per_filters.txt' mode='rt' encoding='UTF-8'>


In [32]:
mean_ocr_activ = np.load('../outputs/first_approach/motifs/mean_OCR_activation.npy')
print(mean_ocr_activ.shape)
print(mean_ocr_activ)

(300, 81)
[[2.41208673 1.7952466  2.08129239 ... 3.75345469 2.96064448 2.63983297]
 [2.41208673 1.7952466  2.08129239 ... 3.75345469 2.96064448 2.63983297]
 [2.41208673 1.7952466  2.08129239 ... 3.75345469 2.96064448 2.63983297]
 ...
 [2.41208673 1.7952466  2.08129239 ... 3.75345469 2.96064448 2.63983297]
 [2.41208673 1.7952466  2.08129239 ... 3.75345469 2.96064448 2.63983297]
 [2.41208673 1.7952466  2.08129239 ... 3.75345469 2.96064448 2.63983297]]


In [33]:
n_actv_ocr = np.load('../outputs/first_approach/motifs/n_activated_OCRs.npy')
print(n_actv_ocr.shape)
print(n_actv_ocr)

(300,)
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 3. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.
 4. 4. 4. 4. 3. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.
 4. 4. 4. 4. 4. 3. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.
 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.
 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.
 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 1. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.
 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 3. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.
 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.
 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.
 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.
 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 3. 4.
 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 3. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 3. 4.
 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]


In [34]:
ocr_matrix = np.load('../outputs/first_approach/motifs/OCR_matrix.npy')
print(ocr_matrix.shape)
print(ocr_matrix)

(300, 4)
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 ...
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]
