In [1]:
from pathlib import Path
from pynvml import *

curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))
os.chdir(str(curdir.parent.absolute()))
curdir = Path(os.getcwd())

from src.utils.data import (
    load_model,
    seed_everything,
    log_gpu_memory_usage
)
from src.utils.main_utils import get_or_generate_vocabularies,  get_or_generate_label_embeddings, get_or_generate_sequence_embeddings, validate_arguments
from src.data.datasets import ProteinDataset, calculate_pos_weight, create_multiple_loaders, calculate_label_weights
from src.models.ProTCLTrainer import ProTCLTrainer
from src.models.ProTCL import ProTCL
from src.models.protein_encoders import ProteInfer
from src.utils.evaluation import EvalMetrics
from src.utils.models import count_parameters_by_layer, sigmoid_bias_from_prob,load_checkpoint
from src.utils.configs import get_setup
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import torch.distributed as dist
import torch
import wandb
import os
import argparse
import json
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from src.data.collators import collate_variable_sequence_length
import mlflow
import loralib as lora
import random


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

### SETUP ###
torch.cuda.empty_cache()

# Check if master process
is_master = True
config = "configs/base_config.yaml"
name = "Test"
train_path_name = "TRAIN_DATA_PATH"
validation_path_name = "VAL_DATA_PATH"
test_paths_names = ["TEST_DATA_PATH"]
amlt = False
gpu=0
rank=0

# Unpack and process the config file
config = get_setup(
    config_path=config,
    run_name=name,
    overrides=[],
    train_path_name=train_path_name,
    val_path_name=validation_path_name,
    test_paths_names=test_paths_names,
    amlt=amlt,
    is_master=is_master,
)
params, paths, timestamp, logger = config["params"], config[
    "paths"], config["timestamp"], config["logger"]

# Set the GPU device, if using
torch.cuda.set_device(gpu)
device = torch.device('cuda:' + str(gpu)
                        if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


# Log the params
logger.info(json.dumps(params, indent=4))

# Initialize label tokenizer
label_tokenizer = AutoTokenizer.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)

# Initialize label encoder
label_encoder = AutoModel.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)
if params["GRADIENT_CHECKPOINTING"]:
    raise NotImplementedError(
        "Gradient checkpointing is not yet implemented.")

label_encoder = label_encoder.to(device)

# Load or generate the vocabularies
vocabularies = get_or_generate_vocabularies(
    paths["FULL_DATA_PATH"], paths["VOCABULARIES_DIR"], logger)


2023-12-16 10:06:14 PST INFO Logging to ./outputs/logs/2023-12-16_10-06-14_Test.log and console...
2023-12-16 10:06:14 PST INFO Using device: cuda:0
2023-12-16 10:06:14 PST INFO {
    "TRAIN_BATCH_SIZE": 32,
    "VALIDATION_BATCH_SIZE": 64,
    "TEST_BATCH_SIZE": 64,
    "GRID_SAMPLER_LABEL_BATCH_SIZE": 1000,
    "IN_BATCH_SAMPLING": false,
    "TRAIN_LABEL_SAMPLE_SIZE": null,
    "VALIDATION_LABEL_SAMPLE_SIZE": null,
    "LABEL_BATCH_SIZE_LIMIT_NO_GRAD": 1500,
    "SEQUENCE_BATCH_SIZE_LIMIT_NO_GRAD": 128,
    "LEARNING_RATE": 0.0003,
    "OPTIMIZER": "Adam",
    "PROTEIN_EMBEDDING_DIM": 1100,
    "LABEL_EMBEDDING_DIM": 1024,
    "LATENT_EMBEDDING_DIM": 1024,
    "OUTPUT_MLP_HIDDEN_DIM_SCALE_FACTOR": 3,
    "OUTPUT_MLP_NUM_LAYERS": 3,
    "OUTPUT_NEURON_PROBABILITY_BIAS": null,
    "OUTPUT_MLP_BATCHNORM": true,
    "PROJECTION_HEAD_NUM_LAYERS": 3,
    "PROJECTION_HEAD_HIDDEN_DIM_SCALE_FACTOR": 2,
    "FEATURE_FUSION": "concatenation",
    "LABEL_EMBEDDING_POOLING_METHOD": "mean",
    "

In [5]:
from torch.utils.data import Dataset, DataLoader
from src.utils.data import read_fasta, read_json, get_vocab_mappings, read_pickle
from src.utils.models import tokenize_labels, get_label_embeddings
from typing import Dict
import pandas as pd
import logging
from typing import List
from src.data.collators import collate_variable_sequence_length
from collections import defaultdict
from joblib import Parallel, delayed, cpu_count
from functools import partial
from collections import Counter
from torch.utils.data.distributed import DistributedSampler
from src.utils.main_utils import get_or_generate_label_embeddings


class ProteinDataset(Dataset):
    """
    Dataset class for protein sequences with GO annotations.
    """
    def __init__(
        self,
        data_paths: dict,
        config: dict,
        vocabularies: dict,
        label_tokenizer=None,
        label_encoder=None,
        logger=None,
        require_label_idxs=False,
        subset_fraction: float = 1.0,
        deduplicate: bool = False,
        is_master: bool = True,
    ):
        """
        paths (dict): Dictionary containing paths to the data and vocabularies.
            data_path (str): Path to the FASTA file containing the protein sequences and corresponding GO annotations
            dataset_type (str): One of 'train', 'validation', or 'test'
            go_descriptions_path (str): Path to the pickled file containing the GO term descriptions mapped to GO term IDs
        deduplicate (bool): Whether to remove duplicate sequences (default: False)
        """
        # Error handling: check for missing keys and invalid dataset types
        required_keys = ["data_path", "dataset_type"]
        for key in required_keys:
            if key not in data_paths:
                raise ValueError(
                    f"Missing required key in paths dictionary: {key}")

        assert data_paths["dataset_type"] in [
            "train",
            "validation",
            "test",
        ], "dataset_type must be one of 'train', 'val', or 'test'"

        # Set the dataset type and data path
        self.dataset_type = data_paths["dataset_type"]
        self.data_path = data_paths["data_path"]

        # Set and process the vocabularies
        self.amino_acid_vocabulary = vocabularies["amino_acid_vocab"]
        self.label_vocabulary = vocabularies["GO_label_vocab"]
        self.sequence_id_vocabulary = vocabularies["sequence_id_vocab"]
        self._process_vocab()

        # Initialize class variables
        self.data = read_fasta(data_paths["data_path"])
        self.label_embedding_matrix = self.sequence_embedding_df = None
        
        #Flag to know how Dataset indexing will be handle.
        self.require_label_idxs = require_label_idxs

        # Subset the data if subset_fraction is provided
        if subset_fraction < 1.0:
            logging.info(
                f"Subsetting {subset_fraction*100}% of the {self.dataset_type} set..."
            )
            self.data = self.data[:int(subset_fraction * len(self.data))]

        # Deduplicate the data if deduplicate is True
        if deduplicate:
            self._remove_duplicates()

        # Load the map from alphanumeric label id to text label
        self.label_annotation_map = {key: value['label'] for key, value in read_pickle(
            data_paths["go_annotations_path"]).to_dict(orient='index').items()}

        # Create ordered list of labels
        label_text_list = []
        for label_id in self.label_vocabulary:
            label_text_list.append(self.label_annotation_map[label_id])
        self.label_text_list = label_text_list

        # Loop through the label IDs and tokenize the labels if a label tokenizer is provided
        self.tokenized_labels = None
        self.label_tokenizer = None
        if label_tokenizer is not None:
            self.label_tokenizer = label_tokenizer
            self.tokenized_labels = tokenize_labels(
                label_text_list, label_tokenizer)

        # If a label encoder is provided, encode the labels
        # TODO: Move back to main to remove warning
        self.label_embedding_matrix = None
        self.label_encoder = None
        if label_encoder is not None and config["params"]["LABEL_ENCODER_NUM_TRAINABLE_LAYERS"]==0:
            self.label_encoder = label_encoder
            label_embedding_matrix = get_or_generate_label_embeddings(
                label_annotations=self.label_text_list,
                label_tokenizer=label_tokenizer,
                label_encoder=label_encoder,
                label_embedding_path=config["paths"]["LABEL_EMBEDDING_PATH"],
                logger=logger,
                batch_size_limit=config["params"]["LABEL_BATCH_SIZE_LIMIT_NO_GRAD"],
                is_master=is_master,
                pooling_method=config["params"]["LABEL_EMBEDDING_POOLING_METHOD"]
            )
            self.label_embedding_matrix = label_embedding_matrix

    # Helper functions for setting embedding dictionaries
    def set_sequence_embedding_df(self, embedding_df: pd.DataFrame):
        self.sequence_embedding_df = embedding_df

    def set_label_embedding_matrix(self, embedding_matrix: torch.Tensor):
        self.label_embedding_matrix = embedding_matrix

    def _remove_duplicates(self):
        """
        Remove duplicate sequences from self.data, keeping only the first instance of each sequence
        Use pandas to improve performance
        """

        # Convert self.data to a DataFrame
        df = pd.DataFrame(self.data, columns=["sequence", "labels"])

        # Drop duplicate rows based on the 'sequence' column, keeping the first instance
        df = df.drop_duplicates(subset="sequence", keep="first")

        # Log the number of duplicate sequences removed
        num_duplicates = len(self.data) - len(df)
        logging.info(
            f"Removing {num_duplicates} duplicate sequences from {self.data_path}...")

        # Convert the DataFrame back to the list of tuples format
        self.data = list(df.itertuples(index=False, name=None))

    # Helper functions for processing and loading vocabularies
    def _process_vocab(self):
        self._process_amino_acid_vocab()
        self._process_label_vocab()
        self._process_sequence_id_vocab()

    def _process_amino_acid_vocab(self):
        self.aminoacid2int, self.int2aminoacid = get_vocab_mappings(
            self.amino_acid_vocabulary
        )

    def _process_label_vocab(self):
        self.label2int, self.int2label = get_vocab_mappings(
            self.label_vocabulary)

    def _process_sequence_id_vocab(self):
        self.sequence_id2int, self.int2sequence_id = get_vocab_mappings(
            self.sequence_id_vocabulary
        )

    def __len__(self) -> int:
        return len(self.data)

    def process_example(self, sequence: str, labels: list[str], label_idxs:list[int] = None) -> dict:
        sequence_id_alphanumeric, labels = labels[0], labels[1:]

        

        # Convert the sequence and labels to integers for one-hot encoding
        amino_acid_ints = torch.tensor(
            [self.aminoacid2int[aa] for aa in sequence], dtype=torch.long
        )

        labels_ints = torch.tensor(
            [self.label2int[label] for label in labels], dtype=torch.long
        )

        # Get the length of the sequence
        sequence_length = torch.tensor(len(amino_acid_ints))

        # Get multi-hot encoding of sequence and labels
        sequence_onehots = torch.nn.functional.one_hot(
            amino_acid_ints, num_classes=len(self.amino_acid_vocabulary)
        ).permute(1, 0)
        label_multihots = torch.nn.functional.one_hot(
            labels_ints, num_classes=len(self.label_vocabulary)
        ).sum(dim=0)

        if label_idxs is not None:
            label_multihots = label_multihots[label_idxs]
            label_idxs = torch.tensor(label_idxs)

        # Set the label embeddings, if provided
        label_embeddings = self.label_embedding_matrix if self.label_embedding_matrix is not None else None

        # Get the sequence embedding, if provided
        sequence_embedding = None
        # TODO: Remove this check
        if self.sequence_embedding_df is not None:
            sequence_embedding = torch.tensor(
                self.sequence_embedding_df.loc[sequence_id_alphanumeric].values)

        # Get the tokenized labels, if provided
        tokenized_labels = self.tokenized_labels if self.tokenized_labels is not None else None

        # Return a dict containing the processed example
        return {
            "sequence_onehots": sequence_onehots,
            "sequence_id": sequence_id_alphanumeric,
            "sequence_embedding": sequence_embedding,
            "sequence_length": sequence_length,
            "label_multihots": label_multihots,
            "tokenized_labels": tokenized_labels,
            "label_embeddings": label_embeddings,
            "label_idxs":label_idxs
        }

    def __getitem__(self, idx) -> tuple:
        
        if self.require_label_idxs:
            sequence_idx,label_idxs = idx[0],idx[1]
            sequence = self.data[sequence_idx][0]
            labels = self.data[sequence_idx][1]
        else:
            label_idxs=None
            sequence, labels = self.data[idx]
        

        return self.process_example(sequence, labels,label_idxs)

    @classmethod
    def create_multiple_datasets(
        cls,
        paths_list: List[Dict[str, str]],
        config: dict,
        vocabularies: dict,
        require_train_label_idxs:bool,
        subset_fractions: dict = None,
        label_tokenizer=None,
        label_encoder=None,
        logger=None,
        deduplicate: bool = False,
    ) -> List[Dataset]:
        """
        paths_list (List[Dict[str, str]]): List of dictionaries, each containing paths to the data and vocabularies.
        subset_fractions (dict): Dictionary containing the subset fraction for each dataset type (default: None)
        """
        datasets = defaultdict(list)
        subset_fractions = subset_fractions or {}
        for data_paths in paths_list:
            datasets[data_paths["dataset_type"]].append(
                cls(
                    data_paths,
                    config,
                    vocabularies,
                    label_tokenizer=label_tokenizer,
                    label_encoder=label_encoder,
                    logger=logger,
                    require_label_idxs=require_train_label_idxs if data_paths["dataset_type"]=='train' else False,
                    subset_fraction=subset_fractions.get(
                        data_paths["dataset_type"], 1.0),
                    deduplicate=deduplicate
                )
            )
        return datasets


In [4]:
datasets = ProteinDataset.create_multiple_datasets(
    paths_list=config['dataset_paths_list'],
    config=config,
    logger=logger,
    label_tokenizer=label_tokenizer,
    label_encoder=label_encoder,
    vocabularies=vocabularies,
    subset_fractions={
        "train": params["TRAIN_SUBSET_FRACTION"],
        "validation": params["VALIDATION_SUBSET_FRACTION"],
        "test": params["TEST_SUBSET_FRACTION"],
    },
    deduplicate=params["DEDUPLICATE"],
)

2023-12-16 05:32:17 PST INFO Removing 66586 duplicate sequences from ./data/swissprot/proteinfer_splits/random/train_GO.fasta...
2023-12-16 05:32:31 PST INFO Loaded label embeddings from ./data/embeddings/frozen_BioGPT_label_embeddings_mean.pkl
2023-12-16 05:32:31 PST INFO Removing 8479 duplicate sequences from ./data/swissprot/proteinfer_splits/random/dev_GO.fasta...
2023-12-16 05:32:43 PST INFO Loaded label embeddings from ./data/embeddings/frozen_BioGPT_label_embeddings_mean.pkl
2023-12-16 05:32:44 PST INFO Removing 8176 duplicate sequences from ./data/swissprot/proteinfer_splits/random/test_GO.fasta...
2023-12-16 05:32:55 PST INFO Loaded label embeddings from ./data/embeddings/frozen_BioGPT_label_embeddings_mean.pkl


In [6]:
d = ProteinDataset(
    data_paths=config['dataset_paths_list'][0],
    config=config,
    logger=logger,
    vocabularies=vocabularies,
    require_label_idxs=True

)

In [29]:
[d.label2int[i] for i in (d.data[0][1][1:])]

[13652,
 11611,
 3049,
 8380,
 3172,
 27552,
 27610,
 15406,
 16259,
 10173,
 3029,
 15405,
 14624,
 11617,
 3041,
 19707,
 11613,
 3067,
 4662,
 6054,
 25242,
 3697,
 108,
 4666,
 13992,
 1872,
 16297,
 25390]

In [35]:
d[(0,[(d.label2int[i]+1 if d.label2int[i]==25390 else d.label2int[i]) for i in (d.data[0][1][1:])])]

MSKIIEYDETARRAIEAGVNTLADAVRVTLGPRGRHVVLAKAFGGPAVTNDGVTVAREIDLEDPFENLGAQLVKSVATKTNDVAGDGTTTATVLAQALVKGGLRLVAAGANPIELGAGISKAADAVSEALLASATPVSGKDAIAQVATVSSRDQVLGELVGEAMTKVGVDGVVSVEESSTLNTELEFTEGVGFDKGFLSAYFVTDFDAQQAVLDDPVILLHQEKISSLPDLLPMLEKVAESGKPLLIIAEDIEGEALATLVVNSIRKTLKAVAVKAPFFGDRRKAFLEDLAIVTGGQVINPDTGLLLREVGTEVLGSARRVVVSKDDTIIVDGGGAKDAVANRIKQLRAEIEKTDSDWDREKLQERLAKLAGGVAVIKVGAATETALKERKESVEDAVAAAKAAVEEGIVAGGGSALLQARKALDELRGSLSGDQALGVDVFAEALGAPLYWIASNAGLDGAVAVHKVAELPAGHGLNAEKLSYGDLIADGVIDPVKVTRSAVLNSASVARMVLTTETAVVDKPAEEADDHGHGHHHH ['P60545', 'GO:0035639', 'GO:0032553', 'GO:0005524', 'GO:0017076', 'GO:0005737', 'GO:1901265', 'GO:1901363', 'GO:0043168', 'GO:0044424', 'GO:0030554', 'GO:0005488', 'GO:0043167', 'GO:0042026', 'GO:0032559', 'GO:0005515', 'GO:0051082', 'GO:0032555', 'GO:0005575', 'GO:0008144', 'GO:0009987', 'GO:0097159', 'GO:0006457', 'GO:0000166', 'GO:0008150', 'GO:0036094', 'GO:0003674', 'GO:0044464', 'GO:0097367']


{'sequence_onehots': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'sequence_id': 'P60545',
 'sequence_embedding': None,
 'sequence_length': tensor(538),
 'label_multihots': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 0]),
 'tokenized_labels': None,
 'label_embeddings': None,
 'label_idxs': tensor([13652, 11611,  3049,  8380,  3172, 27552, 27610, 15406, 16259, 10173,
          3029, 15405, 14624, 11617,  3041, 19707, 11613,  3067,  4662,  6054,
         25242,  3697,   108,  4666, 13992,  1872, 16297, 25391])}

In [16]:
from torch.utils.data import BatchSampler,RandomSampler,SequentialSampler
import numpy as np
from itertools import product
from tqdm import tqdm
class MyDataset(Dataset):

    def __init__(self):
        self.data =[(j*1000,[i for i in range(10)]) for j in tqdm(range(1,8))]

    def __getitem__(self, idx):
        observation_idx,label_idxs = idx[0],idx[1]
        features = torch.tensor([self.data[observation_idx][0]])
        labels = torch.tensor([self.data[observation_idx][1][i] for i in label_idxs])

        return features,labels
    
    def __len__(self):
        return len(self.data)
    
class GridSampler(BatchSampler):

    def __init__(self,
                 observation_sampler,
                 observations_batch_size,
                 drop_last_observation_batch,
                 num_labels,
                 labels_batch_size):
        
        self.observation_sampler = observation_sampler
        self.observations_batch_size = observations_batch_size
        self.drop_last_observation_batch = drop_last_observation_batch

        self.num_labels = num_labels
        self.labels_batch_size = labels_batch_size
        self.labels_idxs = list(range(num_labels))
        self.calculate_num_batches()
        
    def __iter__(self):
        random.shuffle(self.labels_idxs)
        print('Getting label batches...')
        observation_batches = self.get_observation_batches()
        print('Done...')

        print('Getting observation batches...')
        label_batches = self.get_label_batches()
        print('Done...')

        print('Getting combinations...')
        obs_labels_batch_combinations = list(product(observation_batches,label_batches))

        print('Done...')
        print('Shuffling...')
        random.shuffle(obs_labels_batch_combinations)
        print('Done...')
        print(len(obs_labels_batch_combinations))
        for observation_batch,label_batch in obs_labels_batch_combinations:
            yield list(product(observation_batch, [label_batch]))#[observation_batch,label_batch]
    
    def calculate_num_batches(self):
        
        num_label_batches = np.ceil(self.num_labels/self.labels_batch_size)
        num_observation_batches = (np.ceil(len(self.observation_sampler)/self.observations_batch_size)
                                   if not self.drop_last_observation_batch
                                   else len(self.observation_sampler)//self.observations_batch_size)
        print('Done...')

        self.total_num_batches = int(num_label_batches*num_observation_batches)

    def __len__(self):
        return self.total_num_batches
    

    def get_label_batches(self):

        #n_chunks = int(np.ceil(self.num_labels/self.labels_batch_size))
        return [self.labels_idxs[i:i+self.labels_batch_size] for i in range(0,self.num_labels,self.labels_batch_size)]
        

    def get_observation_batches(self):

        batches = []

        if self.drop_last_observation_batch:
            observation_sampler_iter = iter(self.observation_sampler)
            while True:
                try:
                    batch = [next(observation_sampler_iter) for _ in range(self.observations_batch_size)]
                    batches.append(batch)
                except StopIteration:
                    break
        else:
            batch = [0] * self.observations_batch_size
            idx_in_batch = 0
            for idx in self.observation_sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.observations_batch_size:
                    batches.append(batch)
                    idx_in_batch = 0
                    batch = [0] * self.observations_batch_size
            if idx_in_batch > 0:
                batches.append(batch[:idx_in_batch])
        return batches



In [17]:
a = MyDataset()

100%|██████████| 7/7 [00:00<00:00, 131659.77it/s]


In [21]:
GS=GridSampler(observation_sampler=RandomSampler(data_source=a),
                 observations_batch_size=3,
                 drop_last_observation_batch=False,
                 num_labels=10,
                 labels_batch_size=3)
l = DataLoader(dataset=a,batch_sampler=GS)
l_iter = iter(l)
next(l_iter)

Done...
Getting label batches...
Done...
Getting observation batches...
Done...
Getting combinations...
Done...
Shuffling...
Done...
12


[tensor([[7000],
         [4000],
         [1000]]),
 tensor([[7, 9, 4],
         [7, 9, 4],
         [7, 9, 4]])]

In [34]:
DS = DistributedSampler(
    a,
    num_replicas=4,
    rank=rank,
    shuffle=True
)

DSGS=GridSampler(observation_sampler=DS,
                 observations_batch_size=3,
                 drop_last_observation_batch=False,
                 num_labels=10,
                 labels_batch_size=3)

Done...


In [35]:
l = DataLoader(dataset=a,batch_sampler=DSGS)
l_iter = iter(l)


In [42]:
len(d.label_vocabulary)

32102

In [43]:
next(l_iter)

Getting label batches...
Done...
Getting observation batches...
Done...
Getting combinations...
Done...
Shuffling...
Done...
30954


IndexError: list index out of range

In [14]:
d = defaultdict(list)
for i in l:
    labels,obs = i 

    labels = labels.flatten().tolist()
    print(labels,obs.tolist())
    for j_idx,j in enumerate(obs.tolist()):
        d[labels[j_idx]].extend(j)
        

    #print(obs.tolist())

Getting label batches...
Done...
Getting observation batches...
Done...
Getting combinations...
Done...
Shuffling...
Done...
61875


IndexError: list index out of range

In [154]:
{k:sorted(v) for k,v in d.items()}

{3000: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 2000: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 7000: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 5000: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 8000: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 6000: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}

In [115]:
import numpy as np
my_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
N = 4
np.array_split(list(range(11)),4)

[array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8]), array([ 9, 10])]

In [96]:
split_list_into_chunks(my_list,N)

[[1, 2, 3], [4, 5], [6, 7], [8, 9]]

In [172]:
a = MyDataset()
s=RandomSampler(data_source=a)
B = BatchSampler(sampler=s,batch_size=2,drop_last=True)
l = DataLoader(dataset=a,batch_sampler=B)

In [173]:
list(B)

[[2, 0], [5, 3], [4, 6], [7, 1]]

In [68]:
list(s)

[0, 3, 7, 2, 4, 1, 5, 6]