In [2]:
import pickle
import torch
from tqdm import tqdm
import os

In [3]:
path = "AlphaFoldDBEncoded"
AlphaFoldResults = [f[:-3] for f in os.listdir(path) if (os.path.isfile(path + "/" + f) & f.__contains__(".pt") )]
# torch.load("AlphaFoldDBEncoded/" + proteinID + ".pt")

In [7]:
AlphaFoldResults.__len__()

3432

In [9]:
torch.load(path + "/" + AlphaFoldResults[0] + ".pt")

tensor([[ 2.2407e-01,  3.9482e-02, -1.8339e-01,  ..., -3.7869e-02,
          1.3181e-01, -7.8958e-01],
        [-1.3097e+00, -8.8381e-01, -9.4595e-01,  ...,  1.1374e-01,
          3.9836e-01, -5.5720e-01],
        [-1.7352e+00, -9.2590e-01, -1.2808e+00,  ...,  2.8873e-01,
          9.2625e-02, -1.1629e-01],
        ...,
        [-4.8368e-01,  3.2090e-01,  1.1987e-03,  ...,  4.6423e-01,
         -1.0691e-01, -5.1099e-01],
        [ 4.4373e-01, -1.3974e-01,  2.5706e-01,  ...,  4.5446e-01,
         -1.6559e-01, -4.2754e-01],
        [ 4.7123e-01, -2.8098e-01,  3.8144e-01,  ...,  3.0119e-01,
         -3.3479e-02, -4.0293e-01]])

In [11]:
A = torch.zeros((3,3,3))

In [5]:
# A.squeeze().shape
import json
import pandas as pd

In [6]:
BlackListProteins = ['Q841A2', 'D6R8X8', 'Q8I2A6']
proteinIDs = []
AlphaFoldResultsSet = set(AlphaFoldResults)

with open("DeepTMHMM.partitions.json","r") as FileObj:
    CVs = json.loads(FileObj.read())
    for cv in CVs.keys():
        cvProteins = CVs[cv]
        for idx, protein in enumerate(cvProteins):
          if (protein["sequence"].__len__() > 1_500):
            continue
          if protein["id"] in BlackListProteins:
            continue
          if (protein["id"] in AlphaFoldResultsSet):
              proteinIDs += [[protein["id"], protein, protein["sequence"], protein["labels"], cv, idx]]


columns = ["proteinID", "protein", "sequence", "labels", "CV", "index"]
proteinMap = pd.DataFrame(proteinIDs, columns=columns)
proteinMap.index = proteinMap["proteinID"].values

In [7]:
cv0Indices = list(proteinMap[proteinMap.CV == "cv0"].index.values)

In [8]:
from torch.utils.data import Dataset
unique_label_ids = ['B','I','M','O','P','S','X']

encode_length = 1_500

def EncodeLabel(label):
  return [unique_label_ids.index(lab) for lab in label]

class CustomProteinDataset(Dataset):
  def __init__(self, protein_code, transform=None, target_transform=None):
    self.proteins = proteinMap.loc[protein_code][["proteinID","labels"]]
    self.transform = transform
    self.target_transform = target_transform
    self.proteinsEncoded = []
    self.labels = []

    print("encoding proteins")
    for index, protein in tqdm(self.proteins.iterrows()):
      latent = torch.load(path + "/" + protein["proteinID"] + ".pt")
      self.proteinsEncoded += [ torch.cat([latent, torch.zeros((encode_length-latent.shape[0],latent.shape[1]))],0) ]
      self.labels += [torch.tensor(EncodeLabel(protein["labels"].ljust(encode_length,'X')))]

    self.proteinsEncoded = torch.stack(self.proteinsEncoded,0)
    self.labels = torch.stack(self.labels,0)

    # print(self.labels.shape)


  def __len__(self):
    return self.proteinsEncoded.shape[0]

  def __getitem__(self, idx):
    encodeLatent = self.proteinsEncoded[idx]
    label = self.labels[idx]

    if self.transform:
      encodeLatent = self.transform(encodeLatent)
    if self.target_transform:
      label = self.target_transform(label)
    return encodeLatent, label

In [23]:
train_dataset = CustomProteinDataset(cv0Indices[0:100])

encoding proteins


100it [00:00, 246.62it/s]


torch.Size([100, 1500])


In [24]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn=list,
)

In [29]:
def make_batch(batch):
    labels = [b[1] for b in batch]
    latent = [b[0] for b in batch]
    return torch.stack(latent), torch.stack(labels)
    

In [31]:
for batch in train_loader:
    print(batch)
    inputs, targets = make_batch(batch)

    # print(inputs.shape)
    # print(targets.shape)
    # print(batch[0][1])

    # print(targets.__len__(), targets[0].shape)
    print("")

[(tensor([[ 0.0168,  0.1001, -0.2247,  ...,  0.4565,  0.1462, -0.4621],
        [-1.4382, -0.7791, -0.8029,  ...,  0.3053,  0.0196,  0.1065],
        [-1.3222, -0.6228, -0.8992,  ...,  0.2575, -0.0287,  0.0280],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]), tensor([5, 5, 5,  ..., 6, 6, 6])), (tensor([[-0.1411,  0.1488, -0.1018,  ..., -0.0033,  0.0910, -0.3126],
        [-1.1729, -0.5648, -0.5289,  ...,  0.1490, -0.2954,  0.0986],
        [-1.3112, -0.5239, -0.5366,  ...,  0.3141, -0.5098, -0.1722],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]), tensor([3, 3, 3,  ..., 6, 6, 6])), (tensor([[-0.0085,  0.2434, -0.0823,  ...,  0.4764,  0.09

In [58]:
print(train_dataset.proteinsEncoded.shape)
print(train_dataset.labels.shape)


torch.Size([10, 1500, 512])
torch.Size([10, 1500])
