In [None]:
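# Basset model definition copied from the Kipoi model repository on GitHub, for use in my project at work:
# https://github.com/kipoi/models/blob/master/Basset/pretrained_model_reloaded_th.py
# see the Kipoi model page (model description and reference to the Basset paper) at
# http://kipoi.org/models/Basset/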

In [2]:
# imports
import torch
from torch import nn
import twobitreader
from twobitreader import TwoBitFile

print("got pytorch version of " + str(torch.__version__))  # report the installed torch version


got pytorch version of 1.5.1


In [3]:
# import relative libraries: make the parent directory importable so the
# project-local helper module dcc_basset_lib can be found
import os
import sys

# Resolve the parent directory to an absolute path once, so a later os.chdir()
# cannot change which directory is searched. Only insert it if it is missing,
# so re-running this cell does not keep prepending duplicate sys.path entries.
_PARENT_DIR = os.path.abspath('../')
if _PARENT_DIR not in sys.path:
    sys.path.insert(0, _PARENT_DIR)
import dcc_basset_lib



have pytorch version 1.5.1
have numpy version 1.19.0


In [4]:
class LambdaBase(nn.Sequential):
    """Sequential container that also stores an arbitrary callable.

    Args:
        fn: callable kept on ``self.lambda_func``; subclasses decide when to apply it.
        *args: optional child modules, registered exactly as ``nn.Sequential`` would.
    """

    def __init__(self, fn, *args):
        super().__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Feed ``input`` to every child module independently.

        Returns:
            A list with one output per child (in registration order), or
            ``input`` itself, unchanged, when the container has no children.
        """
        child_outputs = [child(input) for child in self._modules.values()]
        if not child_outputs:
            return input
        return child_outputs

class Lambda(LambdaBase):
    """Module whose forward pass applies the stored callable.

    The callable receives whatever ``forward_prepare`` produces: the raw input
    when there are no child modules, otherwise the list of child outputs.
    """

    def forward(self, input):
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)
        

In [5]:
# Rebuild the Basset architecture (Kipoi's PyTorch port). Module order and
# nesting must stay exactly as written: load_state_dict() matches weights by
# the positional indices inside this Sequential (e.g. "13.1.weight").
pretrained_model_reloaded_th = nn.Sequential(
    # conv block 1: 4 input channels -> 300 filters
    nn.Conv2d(4, 300, kernel_size=(19, 1)),
    nn.BatchNorm2d(300),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1)),
    # conv block 2: 300 -> 200 filters
    nn.Conv2d(300, 200, kernel_size=(11, 1)),
    nn.BatchNorm2d(200),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(4, 1), stride=(4, 1)),
    # conv block 3: 200 -> 200 filters
    nn.Conv2d(200, 200, kernel_size=(7, 1)),
    nn.BatchNorm2d(200),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=(4, 1), stride=(4, 1)),
    # flatten every dimension except the batch dimension
    Lambda(lambda x: x.view(x.size(0), -1)),
    # fully connected block 1; the inner Lambda turns a 1-D input into a batch of one
    nn.Sequential(
        Lambda(lambda x: x.view(1, -1) if len(x.size()) == 1 else x),
        nn.Linear(2000, 1000),
    ),
    nn.BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True),
    nn.ReLU(),
    nn.Dropout(p=0.3),
    # fully connected block 2
    nn.Sequential(
        Lambda(lambda x: x.view(1, -1) if len(x.size()) == 1 else x),
        nn.Linear(1000, 1000),
    ),
    nn.BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True),
    nn.ReLU(),
    nn.Dropout(p=0.3),
    # output layer: 164 independent sigmoid scores
    nn.Sequential(
        Lambda(lambda x: x.view(1, -1) if len(x.size()) == 1 else x),
        nn.Linear(1000, 164),
    ),
    nn.Sigmoid(),
)

print("got model of type %s" % (type(pretrained_model_reloaded_th),))

got model of type <class 'torch.nn.modules.container.Sequential'>


In [6]:
# show the layer-by-layer architecture (nn.Module's repr lists each submodule)
print(repr(pretrained_model_reloaded_th))

Sequential(
  (0): Conv2d(4, 300, kernel_size=(19, 1), stride=(1, 1))
  (1): BatchNorm2d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(300, 200, kernel_size=(11, 1), stride=(1, 1))
  (5): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU()
  (7): MaxPool2d(kernel_size=(4, 1), stride=(4, 1), padding=0, dilation=1, ceil_mode=False)
  (8): Conv2d(200, 200, kernel_size=(7, 1), stride=(1, 1))
  (9): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): ReLU()
  (11): MaxPool2d(kernel_size=(4, 1), stride=(4, 1), padding=0, dilation=1, ceil_mode=False)
  (12): Lambda()
  (13): Sequential(
    (0): Lambda()
    (1): Linear(in_features=2000, out_features=1000, bias=True)
  )
  (14): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (15): ReLU()

In [7]:
# load the pretrained Basset weights into the model
# NOTE(review): absolute local path; consider making this configurable.
MODEL_WEIGHTS_PATH = '/home/javaprog/Data/Broad/Basset/Model/pretrained_model_reloaded_th.pth'

# map_location='cpu' lets the checkpoint load on machines without CUDA even if
# it was saved from GPU tensors; the model above lives on the CPU anyway.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
sd = torch.load(MODEL_WEIGHTS_PATH, map_location='cpu')

# left as the last expression so the cell displays the key-matching result
pretrained_model_reloaded_th.load_state_dict(sd)



<All keys matched successfully>

In [26]:
# keep a handle on the model's parameters/buffers; assigned rather than
# displayed because the full state_dict output is LARGE
model_weights = pretrained_model_reloaded_th.state_dict(keep_vars=False)


In [8]:
# switch the model (and all its submodules) to evaluation mode; eval()
# returns the module itself, so its architecture can be printed directly
print(pretrained_model_reloaded_th.eval())

Sequential(
  (0): Conv2d(4, 300, kernel_size=(19, 1), stride=(1, 1))
  (1): BatchNorm2d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(300, 200, kernel_size=(11, 1), stride=(1, 1))
  (5): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU()
  (7): MaxPool2d(kernel_size=(4, 1), stride=(4, 1), padding=0, dilation=1, ceil_mode=False)
  (8): Conv2d(200, 200, kernel_size=(7, 1), stride=(1, 1))
  (9): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): ReLU()
  (11): MaxPool2d(kernel_size=(4, 1), stride=(4, 1), padding=0, dilation=1, ceil_mode=False)
  (12): Lambda()
  (13): Sequential(
    (0): Lambda()
    (1): Linear(in_features=2000, out_features=1000, bias=True)
  )
  (14): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (15): ReLU()

In [9]:
# load the chromosome data and extract the reference / alternate sequences
# NOTE(review): relative path depends on the notebook's working directory.
HG19_2BIT_PATH = '../../../../../../Data/Broad/Basset/TwoBitReader/hg19.2bit'
CHROMOSOME_NAME = 'chr11'
VARIANT_POSITION = 95311422
# presumably the number of bases taken on each side of the position; the
# recorded output shows 600-character sequences — TODO confirm against
# dcc_basset_lib.get_ref_alt_sequences
SEQUENCE_FLANK = 300
# presumably the alternate allele placed at VARIANT_POSITION — TODO confirm
ALT_ALLELE = 'C'

# get the genome file
hg19 = TwoBitFile(HG19_2BIT_PATH)
print("two bit file of type {}".format(type(hg19)))

# get the chrom
chromosome = hg19[CHROMOSOME_NAME]
position = VARIANT_POSITION

ref_sequence, alt_sequence = dcc_basset_lib.get_ref_alt_sequences(position, SEQUENCE_FLANK, chromosome, ALT_ALLELE)

# the helper returns plain strings (one-hot encoding happens in the next cell),
# so report their type and length rather than a one-hot "shape"
print("got ref sequence of type {} and length {}".format(type(ref_sequence), len(ref_sequence)))
print("got alt sequence of type {} and length {}".format(type(alt_sequence), len(alt_sequence)))




two bit file of type <class 'twobitreader.TwoBitFile'>
got ref sequence one hot of type <class 'str'> and shape 600
got alt sequence one hot of type <class 'str'> and shape 600


In [10]:
# batch the reference (index 0) and alternate (index 1) sequences together
# and one-hot encode them with the project helper
sequence_list = [ref_sequence, alt_sequence]
sequence_one_hot = dcc_basset_lib.get_one_hot_sequence_array(sequence_list)

print("got sequence one hot of type {kind} and shape {shape}".format(
    kind=type(sequence_one_hot), shape=sequence_one_hot.shape))



got sequence one hot of type <class 'numpy.ndarray'> and shape (2, 600, 4)


In [11]:
# wrap the numpy one-hot array as a torch tensor (torch.from_numpy shares
# memory with the array and keeps its dtype)
tensor = torch.from_numpy(sequence_one_hot)

print("got pytorch tensor with type {kind} and shape {shape} and data type \n{dtype}".format(
    kind=type(tensor), shape=tensor.shape, dtype=tensor.dtype))


got pytorch tensor with type <class 'torch.Tensor'> and shape torch.Size([2, 600, 4]) and data type 
torch.float64


In [14]:
# Conv2d(4, ...) expects the 4 one-hot channels on axis 1, so append a trailing
# width-1 axis, move the channel axis next to the batch axis, and cast to
# float32 to match the model's weights
tensor_input = tensor.unsqueeze(3).transpose(1, 2).to(torch.float32)

print("got pytorch tensor with type {kind} and shape {shape} and data type \n{dtype}".format(
    kind=type(tensor_input), shape=tensor_input.shape, dtype=tensor_input.dtype))


got pytorch tensor with type <class 'torch.Tensor'> and shape torch.Size([2, 4, 600, 1]) and data type 
torch.float32


In [15]:
# run the model predictions
# eval() again so this cell is correct on its own (BatchNorm uses running
# statistics, Dropout is disabled); no_grad() skips building an autograd
# graph, which inference does not need and which only costs memory.
pretrained_model_reloaded_th.eval()
with torch.no_grad():
    predictions = pretrained_model_reloaded_th(tensor_input)

print("got predictions of type {} and shape {} and result \n{}".format(type(predictions), predictions.shape, predictions))


got predictions of type <class 'torch.Tensor'> and shape torch.Size([2, 164]) and result 
tensor([[1.0749e-02, 1.0213e-03, 1.4848e-02, 6.5434e-03, 8.6458e-02, 3.0645e-02,
         4.6448e-03, 5.1843e-03, 4.9607e-03, 7.4255e-03, 9.2591e-03, 7.7353e-03,
         2.2206e-02, 3.3359e-03, 3.6842e-03, 2.9438e-02, 4.2232e-03, 9.1712e-04,
         1.6221e-03, 1.0524e-02, 3.6748e-03, 1.8557e-03, 1.6109e-03, 1.3017e-03,
         3.6147e-03, 1.0297e-02, 6.1578e-02, 6.6933e-02, 1.0849e-02, 7.4299e-02,
         6.2833e-03, 1.7966e-02, 3.4399e-02, 6.2949e-03, 2.3675e-03, 1.8804e-03,
         9.4161e-03, 2.7993e-02, 2.4177e-03, 8.7791e-03, 8.7258e-04, 6.5026e-04,
         1.0449e-03, 4.0967e-04, 5.7718e-04, 7.3653e-04, 1.3984e-02, 6.1543e-04,
         4.3805e-04, 1.3010e-02, 8.3348e-03, 1.8281e-02, 9.5618e-03, 2.8722e-02,
         2.6964e-02, 1.4387e-02, 6.9775e-04, 2.8968e-03, 1.2265e-03, 7.0210e-03,
         2.7426e-03, 6.4689e-04, 4.3560e-03, 1.0491e-03, 6.2473e-04, 2.2319e-03,
         3.2119e-03