In [1]:
import torch
import torch.nn as nn
import json
import torch.nn.functional as F

In [2]:
CB513_PATH = '../q8-protein-structure-prediction/preprocess/CB513.json'

# `d` maps a protein id (stored as a string key) to that protein's record dict.
with open(CB513_PATH, 'r') as json_file:
    d = json.load(json_file)

In [3]:
len(d)  # number of protein records loaded from CB513.json

514

In [4]:
# JSON object keys are always strings, so protein ids are looked up as '1', not 1.
example = d['1']

In [5]:
# Summarize one record: list-valued fields (the flat encodings) are far too long
# to print, so only their length is shown; every other field is printed as-is.
for key, entry in example.items():
    if isinstance(entry, list):
        print(f"{key} : {len(entry)} items")
    else:
        print(f"{key} : {entry}")

protein_encoding : 35700 items
protein_length : 87
secondary_structure_onehot : 6300 items
secondary_structure : LEEEEELLTTTSLLHHHHHHHHHHHHTTLLEEEEESLSBTTBLLHHHHHHHHHHHTLSLSSSLLSLEEELTTSLEEESHHHHHHHTL
primary_structure : MFKVYGYDSNIHKCVYCDNAKRLLTVKKQPFEFINIMPEKGVFDDEKIAELLTKLGRDTQIGLTMPQVFAPDGSHIGGFDQLREYFK


In [6]:
# The flat encoding holds 700 padded residue rows (35700 values / 700 = 51 features
# per residue). Show only the first two residues instead of all 35700 values.
n_features = len(example['protein_encoding']) // 700
example['protein_encoding'][:2 * n_features]

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 2.35,
 0.22,
 4.43,
 1.23,
 5.71,
 0.38,
 0.32,
 0.0025224474607696942,
 0.0015011822567369917,
 0.00024845508183933427,
 0.0007921241383950947,
 0.007174655609074738,
 0.00037432940001475007,
 0.0012795071344630135,
 0.023660578155461204,
 0.0015938622283030088,
 0.10340045145824957,
 0.9999878722231386,
 0.0006683700347954174,
 0.000471085384762809,
 0.004363968112352542,
 0.001578028060403861,
 0.00138592583863808,
 0.0033682097912468285,
 0.014774031693273055,
 0.0014862674441058345,
 0.2689414213699951,
 0.002375900566444986,
 2.1233990830315563,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 2.94,
 0.29,
 5.89,
 1.79,
 5.67,
 0.3,
 0.38,
 0.8234647252208833,
 0.027652422322823136,
 0.021248322711013484,
 0.03455623028627651,
 0.8975229665559027,
 0.0171240333157277

In [7]:
# vincent's dataloader

import numpy as np
import torch
import torch.utils.data as data

class WaveNetDataset(data.Dataset):
    """In-memory dataset of padded per-residue protein features.

    Every record's flat ``"protein_encoding"`` list is reshaped to
    ``(MAX_LEN, n_features)`` (one row per residue). From each row, the
    one-hot residue columns and the PSSM columns are kept.

    Items are ``(encoding, pssm, length)`` tuples:
      * encoding: ``(MAX_LEN, 22)`` uint8
      * pssm:     ``(MAX_LEN, 21)`` float32
      * length:   int32 scalar taken from ``"protein_length"``
    """

    # Number of (padded) residue rows every protein encoding is reshaped to.
    MAX_LEN = 700
    # Column ranges inside one residue's feature row.
    ONEHOT_COLS = slice(0, 22)
    PSSM_COLS = slice(29, 50)

    def __init__(self, protein_data, ids):
        """Load and slice the features of the proteins listed in ``ids``.

        Args:
            protein_data: mapping from string protein id to a record dict with
                ``"protein_encoding"`` (flat list) and ``"protein_length"`` keys.
            ids: sized iterable of protein ids. Each id is passed through
                ``str()`` before lookup, so ints and strings both work.

        Raises:
            KeyError: if an id is missing from ``protein_data`` or a record lacks a key.
            ValueError: if an encoding cannot be reshaped to ``MAX_LEN`` rows.
        """
        n_proteins = len(ids)
        n_onehot = self.ONEHOT_COLS.stop - self.ONEHOT_COLS.start
        n_pssm = self.PSSM_COLS.stop - self.PSSM_COLS.start

        # Allocate directly in the final dtypes (avoids a float64 buffer plus an astype copy).
        all_encodings = np.zeros([n_proteins, self.MAX_LEN, n_onehot], dtype=np.uint8)
        all_pssm = np.zeros([n_proteins, self.MAX_LEN, n_pssm], dtype=np.float32)
        all_lengths = np.zeros(n_proteins, dtype=np.int32)

        for i, protein_id in enumerate(ids):
            if i % 250 == 0:
                print("Loading {0}/{1} proteins".format(i, n_proteins))

            record = protein_data[str(protein_id)]
            all_lengths[i] = record["protein_length"]

            # One row per residue; the per-residue feature count is inferred.
            residues = np.array(record["protein_encoding"]).reshape([self.MAX_LEN, -1])
            all_encodings[i] = residues[:, self.ONEHOT_COLS]
            all_pssm[i] = residues[:, self.PSSM_COLS]

        self.all_encodings = all_encodings
        self.all_pssm = all_pssm
        self.all_lengths = all_lengths

        print(len(all_encodings), len(all_pssm), len(all_lengths))

    def __getitem__(self, index):
        """Return the ``(encoding, pssm, length)`` triple for protein ``index``."""
        return self.all_encodings[index], self.all_pssm[index], self.all_lengths[index]

    def __len__(self):
        """Number of proteins held by the dataset."""
        return len(self.all_encodings)


def get_loader(protein_data, ids, batch_size, shuffle, num_workers):
    """Build a ``torch.utils.data.DataLoader`` over the proteins in ``ids``.

    Args:
        protein_data: mapping of protein id -> record, passed to ``WaveNetDataset``.
        ids: protein ids to include in the dataset.
        batch_size: number of proteins per batch.
        shuffle: whether the DataLoader reshuffles the proteins every epoch.
        num_workers: number of DataLoader worker processes (0 loads in the main process).

    Returns:
        ``(data_loader, n_proteins)``, where ``n_proteins`` is the dataset size.
    """
    dataset = WaveNetDataset(protein_data, ids)

    # The default collate function stacks each field of the (encoding, pssm, length)
    # items into batched tensors, so no custom collate_fn is needed.
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return data_loader, len(dataset)

In [8]:
# Seeded so the shuffled ordering of all protein indices is reproducible across runs.
SEED = 0
len_train = len(d)

# A random permutation of 0..len_train-1 (same distribution as choice(n, n, replace=False)).
# NOTE(review): `ids` is not used by the loader below, which takes the fixed ids [0, 1, 2].
ids = np.random.RandomState(SEED).permutation(len_train)

# Small inspection loader over the first three proteins. batch_size exceeds the
# number of ids, so iterating it yields a single batch of 3.
val_loader, len_val = get_loader(protein_data=d,
                                 ids=[0, 1, 2],
                                 batch_size=5,
                                 num_workers=1,
                                 shuffle=False)

Loading 0/3 proteins
3 3 3


In [9]:
val_loader  # DataLoader repr only; the number of proteins it wraps is in len_val

<torch.utils.data.dataloader.DataLoader at 0x7fd978d9cc18>

In [12]:
# First batch as [encoding, pssm, length]. With 3 ids and batch_size=5 it is also the only batch.
batch = next(iter(val_loader))

In [16]:
# Shapes of the three batch fields: encoding, pssm, length.
for part in batch[:3]:
    print(part.shape)

torch.Size([3, 700, 22])
torch.Size([3, 700, 21])
torch.Size([3])


In [17]:
batch[0] # encoding (one-hot): columns 0:22 of each residue row, shape (batch, 700, 22), uint8

tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 1]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 1]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 1]]], dtype=torch.uint8)

In [18]:
batch[1] # pssm: columns 29:50 of each residue row, shape (batch, 700, 21), float32

tensor([[[2.1417e-01, 1.7080e-01, 2.2977e-02,  ..., 2.7091e-01,
          2.6894e-01, 9.2344e-01],
         [8.3173e-02, 1.0457e-02, 4.2290e-02,  ..., 4.1520e-03,
          1.1920e-01, 9.1937e-03],
         [3.4751e-01, 2.5333e-02, 9.9451e-01,  ..., 1.1096e-02,
          5.0000e-01, 2.5087e-02],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]],

        [[2.5224e-03, 1.5012e-03, 2.4846e-04,  ..., 1.4863e-03,
          2.6894e-01, 2.3759e-03],
         [8.2346e-01, 2.7652e-02, 2.1248e-02,  ..., 1.8524e-02,
          2.6894e-01, 5.1500e-01],
         [1.2675e-01, 3.3017e-03, 3.6124e-01,  ..., 2.4974e-03,
          2.6894e-01, 1.7464e-02],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.000

In [21]:
batch[2] # seq length: each record's protein_length (rows beyond it appear to be padding), shape (batch,), int32

tensor([ 67,  87, 449], dtype=torch.int32)