In this notebook I will try to implement a new _Dataset_ object and a new _Dataloader_ to deal with the new memory constraints of my computer. In fact, I cannot simply load the whole dataset into _RAM_ and then batch it to the model. The main sources for the implementation are this [post](https://teddykoker.com/2020/12/dataloader/) and this [page](https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel). The latter is essentially an explanation of what happens under the hood in the former, which is the one that I actually follow in this implementation.

In [1]:
from torch.utils.data import Dataset
import os
from ioutils import read_fasta, read_encodings
import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def get_embedding(q):
    """Return a frozen one-hot embedding layer for ``q`` symbols plus one extra index.

    The layer has ``q + 1`` rows of dimension ``q``: index ``i < q`` maps to
    the ``i``-th unit vector, and the extra index ``q`` maps to the all-zero
    vector. Gradients are disabled on the weight.
    """
    layer = torch.nn.Embedding(num_embeddings=q + 1, embedding_dim=q)
    layer.requires_grad_(False)

    # Identity block for the q real symbols, followed by one zero row.
    one_hot_table = torch.cat([torch.eye(q), torch.zeros(1, q)], dim=0)
    layer.weight.copy_(one_hot_table)

    return layer


class EncodedProteinDataset_new(Dataset):
    """Dataset of (MSA, encoding) pairs that are read from disk on access.

    Only file paths are kept in memory; the tensors themselves are loaded in
    ``__getitem__``. MSA files and encoding files are matched through the
    first 7 characters of their file names.

    Args:
        msa_folder: Directory scanned for MSA files whose name ends in ``pt``
            (loaded with ``torch.load``).
        encodings_folder: Directory holding the encoding files (read with
            ``read_encodings``).
        noise: If greater than 0, Gaussian noise scaled by this value is added
            to the encodings every time an item is fetched.
        max_msas: If not ``None``, stop scanning once this many MSAs have been
            accepted.
        max_seqs: Accepted for interface compatibility; not used by this class.
        max_length: MSAs whose second dimension is greater than or equal to
            this value are skipped (default 512).

    Raises:
        ValueError: If an MSA file has no encoding file with the same prefix.
    """

    # Number of leading file-name characters used to pair MSA and encoding files.
    _ID_LENGTH = 7

    def __init__(self, msa_folder, encodings_folder, noise=0.0, max_msas=None, max_seqs=None,
                 max_length=512):
        super().__init__()
        self.msa_folder = msa_folder
        self.encodings_folder = encodings_folder
        self.encodings_paths = []
        self.msas_paths = []
        self.q = None  # never assigned by this class; kept for callers that read it
        self.encoding_dim = None  # filled in lazily by the first __getitem__ call
        self.noise = noise

        # Map file-name prefix -> encoding file name.
        encoding_files = {name[:self._ID_LENGTH]: name for name in os.listdir(encodings_folder)}

        for numseq_file in os.listdir(msa_folder):
            if not numseq_file.endswith('pt'):
                continue

            msa_id = numseq_file[:self._ID_LENGTH]
            if msa_id not in encoding_files:
                raise ValueError('No encoding file found for MSA file: ' + numseq_file)

            numseq_path = os.path.join(msa_folder, numseq_file)
            encoding_path = os.path.join(encodings_folder, encoding_files[msa_id])

            if not os.path.isfile(encoding_path):
                print("{} does not exist, skipping {}".format(encoding_path, numseq_path))
                continue

            # The MSA is loaded here only to inspect its shape; it is not kept.
            # NOTE(review): torch.load unpickles the file - only use trusted data.
            if torch.load(numseq_path).shape[1] >= max_length:
                continue

            self.encodings_paths.append(encoding_path)
            self.msas_paths.append(numseq_path)

            if max_msas is not None and len(self.msas_paths) >= max_msas:
                break

    def __len__(self):
        """Return the number of accepted (MSA, encoding) pairs."""
        return len(self.msas_paths)

    def __getitem__(self, idx):
        """Load and return the ``(msa, encodings)`` pair at position ``idx``.

        Raises:
            AssertionError: If the encoding dimension differs from the one seen
                on the first fetched item.
            ValueError: If ``msa.shape[1]`` differs from ``encodings.shape[0]``.
        """
        encoding_path = self.encodings_paths[idx]
        msa_path = self.msas_paths[idx]

        # Cast to int: embedding lookups do not work with uint or int8 tensors.
        # NOTE(review): torch.load unpickles the file - only use trusted data.
        msa = torch.load(msa_path).type(torch.int)
        encodings = torch.tensor(read_encodings(encoding_path, trim=False))
        if self.noise > 0:
            encodings = encodings + self.noise * torch.randn(encodings.shape)

        if self.encoding_dim is None:
            self.encoding_dim = encodings.shape[1]
        else:
            assert self.encoding_dim == encodings.shape[1], "Inconsistent encoding dimension"

        # presumably msa is (num_sequences, length) and encodings is
        # (length, encoding_dim) - TODO confirm against the data files
        if msa.shape[1] != encodings.shape[0]:
            raise ValueError(
                "Inconsistent encoding and sequence length for numerical sequence file: " + msa_path
            )

        return msa, encodings


def collate_fn_new(batch, q, batch_msa_size):
    """Collate a list of ``(msa, encodings)`` pairs into padded batch tensors.

    For each item, ``batch_msa_size`` rows of the MSA are drawn at random
    (with replacement). MSAs are then padded along their second dimension
    with the value ``q`` and encodings along their first dimension with 0.0.

    Returns:
        A tuple ``(msas, encodings, padding_mask)`` where ``padding_mask`` is
        True at the positions of ``msas`` that hold the padding value ``q``.
    """
    sampled_msas = []
    for msa, _ in batch:
        # Row indices of the sequences the Potts model sees for this MSA.
        rows = torch.randint(0, msa.shape[0], (batch_msa_size, ))
        # pad_sequence pads the first dimension, so move the length axis there.
        sampled_msas.append(msa[rows, :].transpose(1, 0))

    encoding_list = [encoding for _, encoding in batch]

    padded_msas = pad_sequence(sampled_msas, batch_first=True, padding_value=q)
    padded_encodings = pad_sequence(encoding_list, batch_first=True, padding_value=0.0)

    # Restore the original axis order of the MSAs after padding.
    padded_msas = padded_msas.transpose(2, 1)

    # Every sequence of an MSA shares the same padding, so the first row suffices.
    padding_mask = padded_msas[:, 0, :] == q

    return padded_msas, padded_encodings, padding_mask