In [2]:
import numpy as np
import h5py
from tqdm import tqdm
import os
import torch
from torch.utils.data import Dataset, DataLoader
from selfpeptide.utils.constants import *

import time

In [23]:
class Self_NonSelf_PeptideDataset(Dataset):
    """Map-style dataset interleaving self and non-self peptides read from HDF5.

    Only a window of ``gen_size`` peptides is kept in memory at a time; call
    :meth:`refresh_data` to replace it with the next window. Even positions hold
    rows of the ``"reference_human_peptides"`` dataset (label ``1``) and odd
    positions hold rows of the ``"nonself_peptides"`` dataset (label
    ``negative_label``).

    The first ``val_size // 2`` rows of each HDF5 dataset are never loaded by
    this class; presumably they are reserved for a validation split built
    elsewhere (TODO confirm).

    Parameters
    ----------
    hdf5_dataset_fname : str
        Path to the HDF5 file holding both datasets. Each row must fit a tensor
        row of width ``MAX_PEPTIDE_LEN`` (rows are copied into such a tensor).
    gen_size : int
        Number of peptides per in-memory window, split evenly between self and
        non-self. An odd value is rounded down to the nearest even number.
    val_size : int
        Total number of rows skipped across the two datasets; ``val_size // 2``
        are skipped in each. Note: the attribute ``self.val_size`` stores this
        per-dataset (halved) value.
    negative_label : int
        Label assigned to non-self peptides.

    Raises
    ------
    FileNotFoundError
        If ``hdf5_dataset_fname`` does not exist.
    ValueError
        If a dataset has too few unread rows to fill one window.
    """

    def __init__(self, hdf5_dataset_fname, gen_size=1000,
                 val_size=0,
                 negative_label=-1):
        self.hdf5_dataset_fname = hdf5_dataset_fname
        self.gen_size = gen_size
        # Per-dataset number of skipped leading rows (half of the total).
        self.val_size = val_size // 2
        self.negative_label = negative_label

        if not os.path.exists(self.hdf5_dataset_fname):
            raise FileNotFoundError("Specify a valid HDF5 file for the dataset")
        self._get_n_peptides()

        # Read positions: start right after the skipped prefix of each dataset.
        self.idx_self = self.val_size
        self.idx_nonself = self.val_size

        self._load_peptides(gen_size)

    def _get_n_peptides(self):
        """Record the number of rows in each of the two HDF5 datasets."""
        with h5py.File(self.hdf5_dataset_fname, 'r') as f:
            self.n_self_peptides = len(f["reference_human_peptides"])
            self.n_nonself_peptides = len(f["nonself_peptides"])

    def _load_peptides(self, n_peptides):
        """Load the next window of interleaved peptides into memory.

        Reads ``n_peptides // 2`` rows from each dataset starting at the
        current read positions, stores them in ``self.peptides`` /
        ``self.labels`` (both ``torch.long``) and advances both positions.

        Raises
        ------
        ValueError
            If either dataset has fewer than ``n_peptides // 2`` rows left
            from its current read position.
        """
        # Read the same number of rows from each source so every slot of the
        # interleaved tensor is filled (an odd n_peptides would otherwise leave
        # ceil(n/2) even slots for only n//2 self rows and fail).
        n_per_source = n_peptides // 2

        with h5py.File(self.hdf5_dataset_fname, 'r') as f:
            self_block = f["reference_human_peptides"][self.idx_self:self.idx_self + n_per_source]
            nonself_block = f["nonself_peptides"][self.idx_nonself:self.idx_nonself + n_per_source]

        # HDF5 slicing silently truncates at the end of a dataset; report that
        # clearly instead of failing later with a tensor shape mismatch.
        for name, block, start in (("reference_human_peptides", self_block, self.idx_self),
                                   ("nonself_peptides", nonself_block, self.idx_nonself)):
            if len(block) < n_per_source:
                raise ValueError(
                    f"Dataset '{name}' has only {len(block)} rows available from "
                    f"index {start}, but {n_per_source} are needed"
                )

        peptides = torch.zeros((2 * n_per_source, MAX_PEPTIDE_LEN), dtype=torch.long)
        peptides[0::2, :] = torch.from_numpy(self_block)
        peptides[1::2, :] = torch.from_numpy(nonself_block)

        labels = torch.ones(2 * n_per_source, dtype=torch.long)
        labels[1::2] = self.negative_label

        self.peptides = peptides
        self.labels = labels

        self.idx_self += n_per_source
        self.idx_nonself += n_per_source

    def refresh_data(self):
        """Replace the in-memory window with the next ``gen_size`` peptides.

        A source whose unread rows are fewer than ``gen_size // 2`` wraps back
        to its first non-skipped row (``self.val_size``) before loading.
        """
        # Each window consumes gen_size // 2 rows per source, so only wrap when
        # that many are no longer available; comparing against the full
        # gen_size would always skip a usable final window.
        n_per_source = self.gen_size // 2
        if self.n_self_peptides - self.idx_self < n_per_source:
            self.idx_self = self.val_size
        if self.n_nonself_peptides - self.idx_nonself < n_per_source:
            self.idx_nonself = self.val_size
        self._load_peptides(self.gen_size)

    def __len__(self):
        """Number of peptides in the current in-memory window."""
        return len(self.peptides)

    def __getitem__(self, idx):
        """Return ``(peptide_tokens, label)`` for position ``idx`` of the window."""
        return self.peptides[idx], self.labels[idx]

In [30]:
# Path to the pre-tokenized peptide HDF5 file, relative to the notebook's working directory.
hdf5_file = "../processed_data/pre_tokenized_peptides_dataset.hdf5"
dset = Self_NonSelf_PeptideDataset(
    hdf5_dataset_fname=hdf5_file,
    gen_size=10000,  # peptides held in memory per window (half self, half non-self)
    val_size=20,     # val_size // 2 = 10 leading rows skipped in each HDF5 dataset
)

In [31]:
# Inspect the first item: even indices hold self peptides (label 1).
first_peptide, first_label = dset[0]
first_peptide, first_label

(tensor([10,  2, 20, 17, 17,  5,  0, 14, 16,  4, 22, 22]), tensor(1))

In [19]:
# Un-shuffled loader, so batches follow the interleaved self / non-self order.
dloader = DataLoader(dataset=dset, batch_size=4, shuffle=False)

In [22]:
# NOTE(review): the saved output below has execution count 22, earlier than the
# cell that built `dset` (30), so it came from an older kernel state; re-run in order.
first_batch = next(iter(dloader))
first_batch

[tensor([[13, 15, 16,  9,  7,  0, 17, 12, 22, 22, 22, 22],
         [17,  2, 15,  6,  9,  7, 14,  8, 22, 22, 22, 22],
         [17,  2, 17,  3,  9,  7,  2,  8, 15, 16, 11, 14],
         [ 2, 17,  8,  1, 15, 15,  1,  3,  5,  9, 22, 22]]),
 tensor([ 1, -1,  1, -1])]

In [21]:
# Replace the in-memory window with the next gen_size peptides; a source with too
# few unread rows wraps back to its first non-skipped row before loading.
dset.refresh_data()

In [29]:
# After refresh_data() the window normally holds different peptides than before.
# NOTE(review): the saved output below has execution count 29, earlier than the cell
# that built `dset` (30), and is identical to dset[0] above — it is stale; re-run in order.
peptide_after_refresh, label_after_refresh = dset[20]
peptide_after_refresh, label_after_refresh

(tensor([10,  2, 20, 17, 17,  5,  0, 14, 16,  4, 22, 22]), tensor(1))