In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset, DataLoader

import os
from tqdm import tqdm

from dataset.dataset import (
    # MultimodalPretrainedEmbeddingsDatasetLoader, 
    MultimodalPretrainedEmbeddingsDataset, 
)

from models.adaptor import Adaptor, AdaptorTrainer, AdaptorTrainingArguments, ExternalLoggingCallback
from models.configurations import (
    TEXT_PRETRAINED_AVAILABLE,
    VISION_PRETRAINED_AVAILABLE,
    VISION_MODEL_TYPE_2_DATA_TRANSFORM,
    VISION_MODEL_TYPE_2_VISION_OUTPUT_DIM, 
)
from utils.utils import load_timm_model, freeze_encoder
from utils.model_utils import load_vision_model
from transformers import AutoTokenizer
from transformers import BertModel
from transformers import TrainingArguments

from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from datasets import Dataset

import argparse

import logging

In [15]:

class MultimodalPretrainedEmbeddingsIteratbleDataset(torch.utils.data.IterableDataset):
    """Stream paired text/image embeddings, one sample at a time, from saved batch files.

    ``<text_embeds_raw_dir>/<split>`` and ``<image_embeds_raw_dir>/<split>`` must
    contain identically named files following the ``<prefix>_<index>.<ext>``
    pattern (the integer index is used for ordering). Each file is loaded with
    ``torch.load`` and iterated along its first dimension; an image file may
    also be a dict, in which case its ``'z'`` entry is used (ResNetAE output).

    This is an iterable-only dataset: it defines neither ``__len__`` nor
    ``__getitem__``.

    Args:
        text_embeds_raw_dir: Root directory of the saved text embeddings.
        image_embeds_raw_dir: Root directory of the saved image embeddings.
        split: Sub-directory (under both roots) to read from.
        device: Passed to ``torch.load`` as ``map_location`` while iterating.
        num_of_batches: Maximum number of batch files to use; any value <= 0
            means "use every file". The attribute of the same name always ends
            up holding the number of files actually used.
        shuffle: If true, shuffle the file order once at construction time.
            Samples inside a file are never shuffled.

    Raises:
        AssertionError: If the text and image directories do not list the same
            file names (after truncation to ``num_of_batches``).
    """

    def __init__(
        self, 
        text_embeds_raw_dir: str,
        image_embeds_raw_dir: str,
        split: str='train',
        device='cpu',
        num_of_batches=-1, 
        shuffle=True,
    ):
        super().__init__()
        self.text_embeds_raw_dir = os.path.join(text_embeds_raw_dir, split)
        self.image_embeds_raw_dir = os.path.join(image_embeds_raw_dir, split)
        self.device = torch.device(device)

        self.text_tensor_names = self._sorted_tensor_names(self.text_embeds_raw_dir)
        self.image_tensor_names = self._sorted_tensor_names(self.image_embeds_raw_dir)
        if num_of_batches > 0:
            self.text_tensor_names = self.text_tensor_names[:num_of_batches]
            self.image_tensor_names = self.image_tensor_names[:num_of_batches]

        # Validate the pairing *before* shuffling so that a mismatch is reported
        # with this message instead of an IndexError from the shuffle.
        assert self.text_tensor_names == self.image_tensor_names, "text and image tensor names do not match"

        # Always record the number of files really in use: the requested value
        # may exceed what is available on disk.
        self.num_of_batches = len(self.text_tensor_names)

        if shuffle:
            self.shuffle_batches()

        # NOTE(review): this reads the leading dimension of the *first element*
        # of the first text file, not of the file's tensor itself -- confirm
        # that this is the intended notion of batch size.
        self.batch_size = torch.load(os.path.join(self.text_embeds_raw_dir, self.text_tensor_names[0]), 
                                     map_location='cpu')[0].shape[0]

    @staticmethod
    def _sorted_tensor_names(directory):
        """Return the file names in ``directory`` ordered by their integer index."""
        return sorted(os.listdir(directory),
                      key=lambda x: int(x.split('_')[1].split('.')[0]))

    def process_single_tensor_file(self, text_tensor, image_tensor):
        """Yield one ``{'text_embeds_raw', 'image_embeds_raw'}`` dict per paired row."""
        for tt, it in zip(text_tensor, image_tensor):
            yield {'text_embeds_raw':tt, 'image_embeds_raw':it}

    def shuffle_batches(self):
        """Apply one random permutation to both file lists, keeping them paired."""
        # Permute over the actual list length so the indices are always valid.
        shuffled_idx = torch.randperm(len(self.text_tensor_names)).tolist()
        self.text_tensor_names = [self.text_tensor_names[i] for i in shuffled_idx]
        self.image_tensor_names = [self.image_tensor_names[i] for i in shuffled_idx]

    def __iter__(self):
        """Load the batch files one at a time and yield their samples in order."""
        file_pairs = list(zip(self.text_tensor_names, self.image_tensor_names))

        # Inside a multi-worker DataLoader every worker holds its own copy of
        # this dataset; give each worker a disjoint subset of the files so that
        # samples are not yielded once per worker. Outside a worker process
        # ``get_worker_info()`` returns None and all files are used.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            file_pairs = file_pairs[worker_info.id::worker_info.num_workers]

        for text_tensor_name, image_tensor_name in file_pairs:
            text_tensor = torch.load(os.path.join(self.text_embeds_raw_dir, text_tensor_name), 
                                     map_location=self.device)
            image_tensor = torch.load(os.path.join(self.image_embeds_raw_dir, image_tensor_name), 
                                      map_location=self.device)
            if isinstance(image_tensor, dict):  ### For ResNetAE
                image_tensor = image_tensor['z']

            yield from self.process_single_tensor_file(text_tensor, image_tensor)

In [22]:
# Loading options handed to the dataset constructor in the next cell.
dataset_device = 'cpu'
num_of_batches = -1  # presumably "use every saved batch" -- confirm against the dataset class

# Roots of the pre-computed embeddings; the dataset is given a split name
# alongside these paths.
image_embeds_raw_dir = '/vol/bitbucket/jq619/individual-project/saved_embeddings/image_embeds/ResNetAE'
text_embeds_raw_dir = '/vol/bitbucket/jq619/individual-project/saved_embeddings/text_embeds/BioBERT'

In [23]:
# Training split built from the saved text and image embeddings.
# NOTE(review): this instantiates the imported MultimodalPretrainedEmbeddingsDataset,
# not the iterable variant defined earlier in this notebook -- confirm that is intended.
train_dataset = MultimodalPretrainedEmbeddingsDataset(
    text_embeds_raw_dir,
    image_embeds_raw_dir,
    split='train',
    num_of_batches=num_of_batches,
    device=dataset_device,
)

In [31]:
from datasets import Dataset

def gen():
    """Yield every example of the notebook-level ``train_dataset``, in its own order."""
    # Works for any iterable dataset object, map-style or iterable-style.
    yield from train_dataset


# Wrap the generator as a Hugging Face dataset. ``streaming=True`` presumably
# makes the result lazily iterable instead of materialising it up front --
# confirm against the `datasets` documentation.
dset = Dataset.from_generator(gen, streaming=True)

In [32]:
# Walk the wrapped dataset once and record the position of every example seen.
# The positions are appended inside a plain loop (not a comprehension) so that
# `indices` still holds the progress made if the cell is interrupted part-way.
indices = []

for i, data in enumerate(dset):
    indices.append(i)
    # print(data)

KeyboardInterrupt: 

In [33]:
# Zero-based position of the last example recorded in `indices`
# (number of examples seen = this value + 1).
indices[-1]

2431