In [None]:
import os
from typing import List

import h5py
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data_utils
from deeprank.learn import DataSet, NeuralNet
from torch.autograd import Variable

from models import CnnClassificationBaseline
# from deeprankCore.Data import save_hdf5_keys

In [None]:
# Copied from deeprankcore.Dataset to avoid having to install that

def save_hdf5_keys(
    f_src_path: str,
    src_ids: List[str],
    f_dest_path: str,
    hardcopy = False
    ):
    """Save references to (or copies of) keys in src_ids in a new hdf5 file.

    Parameters
    ----------
    f_src_path : str
        The path to the hdf5 file containing the keys.
    src_ids : List[str]
        Keys to be saved in the new hdf5 file.
        It should be a list containing at least one key. A bare string is
        rejected, because iterating over it would yield single characters.
    f_dest_path : str
        The path to the new hdf5 file. It is opened in write mode and must
        not point to the same file as f_src_path.
    hardcopy : bool, default = False
        If False, the new file contains only references
        (external links, see h5py ExternalLink class) to the original hdf5 file.
        If True, the new file contains a copy of the objects specified in src_ids
        (see h5py HardLink class).

    Raises
    ------
    TypeError
        If src_ids is a single string or contains items that are not strings.
    ValueError
        If f_src_path and f_dest_path resolve to the same file.
    """
    # A str is itself an iterable of str, so it would pass the element check
    # below and silently be treated as one key per character.
    if isinstance(src_ids, str):
        raise TypeError("src_ids should be a list of strings, not a single string.")

    if not all(isinstance(d, str) for d in src_ids):
        raise TypeError("data_ids should be a list containing strings.")

    # Opening the destination with 'w' would wipe the source if both paths
    # point to the same file.
    if os.path.realpath(f_src_path) == os.path.realpath(f_dest_path):
        raise ValueError("f_dest_path must be different from f_src_path.")

    # The source is opened first so that a missing or unreadable source fails
    # before the destination file is created or truncated.
    with h5py.File(f_src_path,'r') as f_src, h5py.File(f_dest_path,'w') as f_dest:
        for key in src_ids:
            if hardcopy:
                f_src.copy(f_src[key],f_dest)
            else:
                f_dest[key] = h5py.ExternalLink(f_src_path, "/" + key)

In [None]:
# Directory holding the hdf5 data. Set the DEEPRANK_DATA_PATH environment
# variable to point the notebook at another location; the fallback is the
# directory this notebook was originally written against.
DATA_PATH = os.environ.get(
    'DEEPRANK_DATA_PATH',
    '/Users/aronjansen/Documents/deeprankData/')
# os.path.join keeps the paths valid whether or not DATA_PATH ends with a
# separator.
hdf5_path = os.path.join(DATA_PATH, '000_hla_drb1_0101_15mers.hdf5')
sample_path = os.path.join(DATA_PATH, 'one_sample.hdf5')
# Checkpoint file name, given relative to the working directory.
pretrained_model = 'best_valid_model.pth.tar'

In [None]:
# Write a real copy (not an external link) of one entry into its own file.
save_hdf5_keys(
    f_src_path=hdf5_path,
    src_ids=['BA_105966'],
    f_dest_path=sample_path,
    hardcopy=True,
)

In [None]:
# Inspect the copied sample. The file is opened in a context manager so the
# handle is released when the cell finishes; a handle left open would get in
# the way of re-running the cell that rewrites this same file in 'w' mode.
with h5py.File(sample_path, 'r') as tst:
    sample_group = tst['BA_105966']
    print(sample_group.keys())
    print(sample_group['features_raw'].keys())
    # [:] reads the dataset into memory, so the values outlive the file handle.
    pssm_arg = sample_group['features']['PSSM_ARG'][:]
pssm_arg

In [None]:
# Build a NeuralNet from the one-sample file and the pretrained checkpoint,
# then run its test pass.
model = NeuralNet(
    sample_path,
    CnnClassificationBaseline,
    pretrained_model=pretrained_model,
    outdir='./out/',
)
model.test()

In [None]:
# Input shape reported by the NeuralNet's data set (shown as the cell output).
model.data_set.input_shape

In [None]:
# Step by step: build the data set, network and loader manually, without using NeuralNet

In [None]:
# The data set is constructed with process=False; process_dataset() is called
# explicitly at the end, after the checkpoint settings have been applied.
data_set = DataSet(sample_path, chain1="M", chain2="P", process=False)

# Checkpoint contents, mapped onto the CPU.
state = torch.load(pretrained_model, map_location='cpu')

# Copy the stored settings onto the data set, one attribute at a time and in
# a fixed order.
for _attr in ('select_feature', 'select_target',
              'pair_chain_feature', 'dict_filter',
              'normalize_targets'):
    setattr(data_set, _attr, state[_attr])

# The target range is only read when target normalization is switched on.
if data_set.normalize_targets:
    for _attr in ('target_min', 'target_max'):
        setattr(data_set, _attr, state[_attr])

setattr(data_set, 'normalize_features', state['normalize_features'])

# Likewise, feature statistics are only read when feature normalization is on.
if data_set.normalize_features:
    for _attr in ('feature_mean', 'feature_std'):
        setattr(data_set, _attr, state[_attr])

for _attr in ('transform', 'proj2D', 'target_ordering',
              'clip_features', 'clip_factor',
              'mapfly', 'grid_info'):
    setattr(data_set, _attr, state[_attr])

data_set.process_dataset()

In [None]:
# Instantiate the network for the data set's input shape and keep it on CPU.
net = CnnClassificationBaseline(data_set.input_shape)
device = torch.device("cpu")

net.to(device)

if state['cuda']:
    # Parameter names in this checkpoint may carry a leading 'module.' that
    # has to be removed before loading into the bare network.
    # str.lstrip('module.') must not be used for this: it strips any leading
    # run of the characters m, o, d, u, l, e and '.', so a name such as
    # 'module.dense.weight' would be mangled to 'nse.weight'.
    module_prefix = 'module.'
    for paramname in list(state['state_dict'].keys()):
        if paramname.startswith(module_prefix):
            paramname_new = paramname[len(module_prefix):]
            # Rename in place so the state dict object itself is preserved.
            state['state_dict'][paramname_new] = \
                state['state_dict'].pop(paramname)

net.load_state_dict(state['state_dict'])

In [None]:
# SGD optimizer over the network parameters, then overwritten with the
# optimizer state stored in the checkpoint.
optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9, weight_decay=0.001)
optimizer.load_state_dict(state['optimizer'])

In [None]:
# One index per data set entry; the sampler draws them in random order.
index = list(range(len(data_set)))
sampler = data_utils.sampler.SubsetRandomSampler(index)
loader = data_utils.DataLoader(data_set, sampler=sampler)

In [None]:
# Sanity check: print the shape of the 'feature' entry of every batch the
# loader yields.
for idx, data in enumerate(loader):
    print(data['feature'].shape)

In [None]:
# Inference only: put the network in evaluation mode and switch gradient
# tracking off globally.
net.eval()
torch.set_grad_enabled(False)

In [None]:
# Forward pass over every batch, printing the intermediate values.
for d in loader:
    inputs = d['feature']
    targets = d['target']
    mol = d['mol']
    # Tensors are used directly: the torch.autograd.Variable wrapper is
    # deprecated and no longer needed.
    inputs = inputs.float()
    # Same float -> long conversion chain as before, so the target values are
    # converted exactly as they were.
    targets = targets.float().long()
    print(mol)
    print(inputs.shape)
    print(targets)
    outputs = net(inputs)
    print(outputs)
    # Flatten the targets to one dimension.
    targets = targets.view(-1)
    print(targets)
    # F.softmax(torch.FloatTensor(out), dim=1).data.numpy()[:, 1]

In [None]:
# Copy the inputs of the last batch into a NumPy array and drop all
# singleton axes.
inputs_toplot = np.array(inputs).squeeze()
inputs_toplot.shape


In [None]:
# Take index 0 along the first axis. The three explicit slices mean the array
# must have at least four dimensions for this indexing to succeed.
one_vol = inputs_toplot[0, :, :, :]
one_vol.shape