In [1]:
# #データセット名
# task_kind = 5 # 1, 2, 3, 4, 5
# #タスクデータの数
# dataset_len = 200 #100,1100,300, 500, 200

# #データの種類
# data_kind =3 #0:image, 1:in_text, 2:out_text
# index_range = 10000

# with open(f"./dataset_{task_kind}.csv", "w") as f:
#     writer = csv.writer(f)
#     for i in range(dataset_len):
#         data = [task_kind*10+j+i/index_range for j in range(data_kind)]
#         writer.writerow(data)
    

In [2]:
from torch.utils.data import Dataset, DataLoader,ConcatDataset
import torch
import csv
from typing import Callable, Optional, Tuple
import numpy as np
import random
from torch.utils.data import DistributedSampler
import math
import itertools


In [3]:
###持っているタスクから一回ずつ取り出す
#multitask_collate_fn = lambda sample_list: sample_list
def default_each_task_collate_fn(batch):
    # list[image,in_text,out_text]が入力される
    sample_list = [[] for _ in range(len(batch[0]))]
    for data in batch:
        for i, sample in enumerate(data):
            sample_list[i].append(sample)
    return sample_list


def default_multi_task_collate_fn(sample_per_task_list):
    if type(sample_per_task_list[0]) == list:
        next_sample = [[] for _ in range(len(sample_per_task_list[0]))]
        for sample_list in sample_per_task_list:
            for i, sample in enumerate(sample_list):
                next_sample[i].extend(sample) #[imageA,imageB,imageC],[in_textA,in_textB,in_textC],[out_textA,out_textB,out_textC]]
    else:
        raise NotImplementedError
    next_sample = [torch.stack(sample) for sample in next_sample]
    return next_sample


class MultiTaskDataIterator:
    def __init__(self, dataloader_list, step_list, multi_data_collate_fn=None) -> None:
        self.iter_list = [iter(dataloader) for dataloader in dataloader_list]
        self.min_step = min(step_list)
        self.step = 0
        self.multi_task_collate_fn = multi_data_collate_fn

    def __next__(self):
        if self.step == self.min_step:
            raise StopIteration

        next_sample_list = [next(iter) for iter in self.iter_list]  # [taskA,taskB,taskC]

        if self.multi_task_collate_fn is None:
            next_sample = default_multi_task_collate_fn(
                next_sample_list
            )  # [imageA,imageB,imageC],[in_textA,in_textB,in_textC],[out_textA,out_textB,out_textC]]
        else:
            next_sample = self.multi_task_collate_fn(next_sample_list)  # [taskA,taskB,taskC]
        self.step += 1
        return next_sample

    def __len__(self):
        return self.min_step


class MultiTaskDataLoader:
    def __init__(
        self,
        dataset_dict: dict[str, DataLoader],
        batch_size_dict: dict[str, int],
        each_task_collate_fn_dict: dict[str, Callable] = None,
        multi_task_collate_fn=None,
        is_ddp=False,
        seed=0,
        loader_drop_last=False,
        sampler_drop_last=False,
        **dataloader_args,
    ) -> None:
        """_summary_

        Args:
            dataset_dict (dict[str,DataLoader]): {taskA:Dataset,taskB:Dataset,taskC:Dataset}のようなdict
            batch_size_dict (_type_): ｛taskA:10,taskB:20,taskC:10｝のようなdict
            each_task_collate_fn_dict (dict[str,function], optional): {taskA:collate_fnA,taskB:collate_fnB,task:C,collate_fnC}のようなdict.タスクごとのデーターローダー Defaults to None.
            multi_task_collate_fn (_type_, optional): すべてのタスクからのバッチを統合する関数 Defaults to None.
            is_ddp (bool, optional): DDPか否か Defaults to False.
            seed (int, optional): 乱数シード Defaults to 0.
        """

        dataset_dict_keys = dataset_dict.keys()
        if is_ddp:
            distributed_keys = ["num_replicas", "rank", "shuffle"]
            distributed_args_dict = {}
            for key in distributed_keys:
                if key in dataloader_args:
                    distributed_args_dict[key] = dataloader_args[key]
                    del dataloader_args[key]

            self.distributed_sampler_dict = {
                key: DistributedSampler(dataset_dict[key], seed=seed, drop_last=sampler_drop_last, **distributed_args_dict) for key in dataset_dict_keys
            }
        else:
            self.distributed_sampler_dict = {key: None for key in dataset_dict_keys}

        if each_task_collate_fn_dict is None:
            each_task_collate_fn_dict = {key: default_each_task_collate_fn for key in dataset_dict_keys}

        def seed_worker(worker_id):
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        g = torch.Generator()
        g.manual_seed(seed)

        self.dataloader_list = [
            DataLoader(
                dataset_dict[key],
                batch_size_dict[key],
                collate_fn=each_task_collate_fn_dict[key],
                sampler=self.distributed_sampler_dict[key],
                worker_init_fn=seed_worker,
                generator=g,
                drop_last=loader_drop_last,
                **dataloader_args,
            )
            for key in dataset_dict_keys
        ]
        self.step_list = [len(dataloader) for dataloader in self.dataloader_list]
        self.min_step = min(self.step_list)
        self.multi_task_collate_fn = multi_task_collate_fn

    def __iter__(self):
        return MultiTaskDataIterator(self.dataloader_list, self.step_list, self.multi_task_collate_fn)

    def __len__(self):
        return self.min_step

    def set_epoch(self, epoch: int):
        for sampler in self.distributed_sampler_dict.values():
            sampler.set_epoch(epoch)


In [4]:
class CountingIterator(object):
    """Wrapper around an iterable that maintains the iteration count.

    Args:
        iterable (iterable): iterable to wrap
        start (int): starting iteration count. Note that this doesn't
            actually advance the iterator.
        total (int): override the iterator length returned by ``__len``.
            This can be used to truncate *iterator*.

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, start=None, total=None):
        self._itr = iter(iterable)
        self.n = start or getattr(iterable, "n", 0)
        self.total = total if total is not None else self.n + len(iterable)

    def __len__(self):
        return self.total

    def __iter__(self):
        return self

    def __next__(self):
        if not self.has_next():
            raise StopIteration
        try:
            x = next(self._itr)
        except StopIteration:
            raise IndexError(
                f"Iterator expected to have length {self.total}, "
                f"but exhausted at position {self.n}."
            )
        self.n += 1
        return x

    def has_next(self):
        """Whether the iterator has been exhausted."""
        return self.n < self.total

    def skip(self, n):
        """Fast-forward the iterator by skipping n elements."""
        for _ in range(n):
            next(self)
        return self

    def take(self, n):
        """Truncate the iterator to n elements at most."""
        self.total = min(self.total, n)
        # Propagate this change to the underlying iterator
        if hasattr(self._itr, "take"):
            self._itr.take(max(n - self.n, 0))
        return self

class GroupedIterator(CountingIterator):
    """Wrapper around an iterable that returns groups (chunks) of items.

    Args:
        iterable (iterable): iterable to wrap
        chunk_size (int): size of each chunk
        skip_remainder_batch (bool, optional): if set, discard the last grouped batch in
          each training epoch, as the last grouped batch is usually smaller than
                local_batch_size * distributed_word_size * chunk_size (default: ``False``).
    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, chunk_size, skip_remainder_batch=False):
        if skip_remainder_batch:
            total_num_itrs = int(math.floor(len(iterable) / float(chunk_size)))
            #logger.info(
            #    f"skip final residual batch, grouped total_num_itrs = {total_num_itrs}"
            #)
        else:
            #raise NotImplementedError
            total_num_itrs = int(math.ceil(len(iterable) / float(chunk_size)))
            #logger.info(f"grouped total_num_itrs = {total_num_itrs}")

        itr = _chunk_iterator(iterable, chunk_size, skip_remainder_batch)
        super().__init__(
            itr,
            start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))),
            total=total_num_itrs,
        )
        self.chunk_size = chunk_size

        if skip_remainder_batch:
            self.take(total_num_itrs)
            # TODO: [Hack] Here the grouped iterator modifies the base iterator size so that
            # training can move into the next epoch once the grouped iterator is exhausted.
            # Double-check this implementation in case unexpected behavior occurs.
            iterable.take(total_num_itrs * chunk_size)


def _chunk_iterator(itr, chunk_size, skip_remainder_batch=False):
    chunk = []
    for x in itr:
        chunk.append(x)
        if len(chunk) == chunk_size:
            yield chunk
            chunk = []
    if not skip_remainder_batch and len(chunk) > 0:
        yield chunk
        
def default_each_task_collate_fn(batch):
    # list[image,in_text,out_text]が入力される
    sample_list = [[] for _ in range(len(batch[0]))]
    for data in batch:
        for i, sample in enumerate(data):
            sample_list[i].append(sample)
    return sample_list


def default_multi_task_collate_fn(sample_per_task_list):
    if type(sample_per_task_list[0]) == list:
        next_sample = [[] for _ in range(len(sample_per_task_list[0]))]
        for sample_list in sample_per_task_list:
            for i, sample in enumerate(sample_list):
                next_sample[i].extend(sample) #[imageA,imageB,imageC],[in_textA,in_textB,in_textC],[out_textA,out_textB,out_textC]]
    else:
        raise NotImplementedError
    next_sample = [torch.stack(sample) for sample in next_sample]
    return next_sample


class MultiTaskDataIterator:
    def __init__(self, dataloader_list, step_list, sample_num_list) -> None:
        self.dataloader_list = dataloader_list
        self.min_step = min(step_list)
        self.step = 0
        self.sample_num_list = sample_num_list
        self.start = 0
        self.grouped_iterator_list = [iter(GroupedIterator(CountingIterator(dataloader,0), sample_num,skip_remainder_batch=True)) for dataloader,sample_num in zip(self.dataloader_list,self.sample_num_list)]

    def __next__(self):
        if self.step == self.min_step:
            raise StopIteration

        next_sample_list = itertools.chain.from_iterable([next(iter) for iter in self.grouped_iterator_list])  # [taskA,taskB,taskC]

        self.step += 1
        return next_sample_list

    def __len__(self):
        return self.min_step


class MultiTaskDataLoader:
    def __init__(
        self,
        dataset_dict: dict[str, DataLoader],
        batch_size_dict: dict[str, int],
        each_task_collate_fn_dict: dict[str, Callable] = None,
        each_task_sample_num_dict: dict[str, int] = None,
        is_ddp=False,
        seed=0,
        loader_drop_last=False,
        sampler_drop_last=False,
        **dataloader_args,
    ) -> None:
        """_summary_

        Args:
            dataset_dict (dict[str,DataLoader]): {taskA:Dataset,taskB:Dataset,taskC:Dataset}のようなdict
            batch_size_dict (_type_): ｛taskA:10,taskB:20,taskC:10｝のようなdict
            each_task_collate_fn_dict (dict[str,function], optional): {taskA:collate_fnA,taskB:collate_fnB,task:C,collate_fnC}のようなdict.タスクごとのデーターローダー Defaults to None.
            multi_task_collate_fn (_type_, optional): すべてのタスクからのバッチを統合する関数 Defaults to None.
            is_ddp (bool, optional): DDPか否か Defaults to False.
            seed (int, optional): 乱数シード Defaults to 0.
        """

        dataset_dict_keys = dataset_dict.keys()
        if is_ddp:
            distributed_keys = ["num_replicas", "rank", "shuffle"]
            distributed_args_dict = {}
            for key in distributed_keys:
                if key in dataloader_args:
                    distributed_args_dict[key] = dataloader_args[key]
                    del dataloader_args[key]

            self.distributed_sampler_dict = {
                key: DistributedSampler(dataset_dict[key], seed=seed, drop_last=sampler_drop_last, **distributed_args_dict) for key in dataset_dict_keys
            }
        else:
            self.distributed_sampler_dict = {key: None for key in dataset_dict_keys}

        if each_task_collate_fn_dict is None:
            each_task_collate_fn_dict = {key: default_each_task_collate_fn for key in dataset_dict_keys}

        def seed_worker(worker_id):
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        g = torch.Generator()
        g.manual_seed(seed)

        self.dataloader_list = [
            DataLoader(
                dataset_dict[key],
                batch_size_dict[key],
                collate_fn=each_task_collate_fn_dict[key],
                sampler=self.distributed_sampler_dict[key],
                worker_init_fn=seed_worker,
                generator=g,
                drop_last=loader_drop_last,
                **dataloader_args,
            )
            for key in dataset_dict_keys
        ]
        
        self.sample_num_list = [each_task_sample_num_dict[key] for key in dataset_dict_keys]
        self.step_list = [int(math.floor(len(dataloader) / float(sample_num))) for sample_num,dataloader in zip(self.sample_num_list,self.dataloader_list)]
        self.min_step = min(self.step_list)
        #self.multi_task_collate_fn = multi_task_collate_fn

    def __iter__(self):
        return MultiTaskDataIterator(self.dataloader_list, self.step_list, self.sample_num_list)

    def __len__(self):
        return self.min_step

    def set_epoch(self, epoch: int):
        for sampler in self.distributed_sampler_dict.values():
            sampler.set_epoch(epoch)


In [5]:
range_10 = range(10)
iter_10 = iter(range_10)
chunk_size = 3
skip_remainder_batch = True
counting_range_10 = CountingIterator(range_10,0)
counting_range_10.take(len(counting_range_10)-1)
grouped_range_10 = GroupedIterator(counting_range_10, chunk_size, skip_remainder_batch)
#最後まで読んだら初期化しなしないといけない


In [6]:
for sample in grouped_range_10:
    print(sample)
    print("next")


[0, 1, 2]
next
[3, 4, 5]
next
[6, 7, 8]
next


In [7]:
grouped_iter_10 = iter(grouped_range_10)

print("iter")
for i in range(5):
    print(next(grouped_iter_10))
    print("next")


iter


StopIteration: 

In [8]:
class MyDataset(Dataset):
    def __init__(self, csv_file):
        self.data = []
        with open(csv_file, "r") as f:
            reader = csv.reader(f)
            for row in reader:
                self.data.append(row)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return [torch.tensor([float(data)]) for data in self.data[idx]]


In [9]:
def get_dataset(path):
    return MyDataset(path)

def get_data(dataset_name_dict):
    dataset_dict = {
        key: ConcatDataset([get_dataset(f"./dataset_{task_kind}.csv") for task_kind in dataset_name_dict[key]]) for key in dataset_name_dict.keys()
    }
    return dataset_dict


In [10]:
#python multi_task_train/jsonargs.py '{"taskA":["1","4"],"taskB":["2"],"taskC":["3","5"]}'
dataset_name_dict = {"taskA":["1","4"], #1:100, 4:500 = 600
                "taskB":["2"], #2: 1100
                "taskC":["3","5"]} # 3:300, 5:200 = 500

dataset_dict = get_data(dataset_name_dict)
print(dataset_dict["taskC"][499])
# batch_size_dict = {"taskA": 10, "taskB": 109, "taskC": 5} 
# sample_num_dict = {"taskA": 2, "taskB": 1, "taskC": 4}
# #Bのバッチは1～9stepは109個、10stepは10個

batch_size_dict = {"taskA": 10, "taskB": 20, "taskC": 5}
sample_num_dict = {"taskA": 2, "taskB": 2, "taskC": 4}
# inex0 20,40,20
# index24 500,1000,500


[tensor([50.0199]), tensor([51.0199]), tensor([52.0199])]


(500, 1000, 500)

In [11]:
dataloader = MultiTaskDataLoader(dataset_dict,batch_size_dict,None,each_task_sample_num_dict=sample_num_dict)


In [None]:
#python multi_task_train/jsonargs.py '{"taskA":["1","4"],"taskB":["2"],"taskC":["3","5"]}'
dataset_name_dict = {"taskA":["1","4"], #1:100, 4:500 = 600
                "taskB":["2"], #2: 1100
                "taskC":["3","5"]} # 3:300, 5:200 = 500

batch_size_dict = {"taskA": 10, "taskB": 20, "taskC": 5}
sample_num_dict = {"taskA": 2, "taskB": 2, "taskC": 4}
# inex0 20,40,20
# index24 500,1000,500

In [12]:
print(len(dataloader))
for index,samples in enumerate(dataloader):
    if index == 0 or index == len(dataloader)-1:
        print(f"index:{index}")
        for sample in samples:
            print(sample)


25
index:0
[[tensor([10.]), tensor([10.0001]), tensor([10.0002]), tensor([10.0003]), tensor([10.0004]), tensor([10.0005]), tensor([10.0006]), tensor([10.0007]), tensor([10.0008]), tensor([10.0009])], [tensor([11.]), tensor([11.0001]), tensor([11.0002]), tensor([11.0003]), tensor([11.0004]), tensor([11.0005]), tensor([11.0006]), tensor([11.0007]), tensor([11.0008]), tensor([11.0009])], [tensor([12.]), tensor([12.0001]), tensor([12.0002]), tensor([12.0003]), tensor([12.0004]), tensor([12.0005]), tensor([12.0006]), tensor([12.0007]), tensor([12.0008]), tensor([12.0009])]]
[[tensor([10.0010]), tensor([10.0011]), tensor([10.0012]), tensor([10.0013]), tensor([10.0014]), tensor([10.0015]), tensor([10.0016]), tensor([10.0017]), tensor([10.0018]), tensor([10.0019])], [tensor([11.0010]), tensor([11.0011]), tensor([11.0012]), tensor([11.0013]), tensor([11.0014]), tensor([11.0015]), tensor([11.0016]), tensor([11.0017]), tensor([11.0018]), tensor([11.0019])], [tensor([12.0010]), tensor([12.0011]), 

In [13]:
batch_size_dict = {"taskA":10,"taskB":20,"taskC":10}
max_batch_size = 40
assert max_batch_size == sum([batch_size_dict[key] for key in batch_size_dict.keys()]), "batch_size_dictの合計がmax_batch_sizeと一致しません"


In [14]:
from typing import Dict, List, Callable, Any
def each_task_collate_fn(batch):
    #list[image,in_text,out_text]が入力される
    sample_list = [[] for _ in range(len(batch[0]))]
    for data in batch:
        for i,sample in enumerate(data):
            sample_list[i].append(sample)
    return sample_list

class MultiDataIterator():
    def __init__(self,dataloader_list,min_step,multi_data_collate_fn=None) -> None:
        self.iter_list = [iter(dataloader) for dataloader in dataloader_list]
        self.min_step = min(min_step)
        self.step = 0
        self.multi_task_collate_fn = multi_data_collate_fn
    def __next__(self):
        if self.step == self.min_step:
            raise StopIteration
        
        next_sample_list =  [next(iter) for iter in self.iter_list]#[taskA,taskB,taskC]
        if type(next_sample_list[0]) == list:
            next_sample = [[] for _ in range(len(next_sample_list[0]))]
            for sample_list in next_sample_list:
                for i,sample in enumerate(sample_list):
                    next_sample[i].extend(sample) #[imageA,imageB,imageC],[in_textA,in_textB,in_textC],[out_textA,out_textB,out_textC]]
        else:
            raise NotImplementedError
        if self.multi_task_collate_fn is None:
            next_sample = [torch.stack(sample) for sample in next_sample]
        else:
            next_sample = self.multi_task_collate_fn(next_sample)
        self.step += 1
        return next_sample
    def __len__(self):
        return self.min_step
    
class MultiDataLoader():
    def __init__(self,dataset_dict:dict[str,DataLoader],batch_size_dict:dict[str,int],each_task_collate_fn_dict:dict[str,Callable]=None,multi_task_collate_fn=None,is_ddp=False,**dataloader_args) -> None:
        """_summary_

        Args:
            dataset_dict (dict[str,DataLoader]): {taskA:Dataset,taskB:Dataset,taskC:Dataset}のようなdict
            batch_size_dict (_type_): ｛taskA:10,taskB:20,taskC:10｝のようなdict
            each_task_collate_fn_dict (dict[str,function], optional): {taskA:collate_fnA,taskB:collate_fnB,task:C,collate_fnC}のようなdict. Defaults to None.
            multi_task_collate_fn (_type_, optional): すべてのデータセットからのバッチを統合する関数 Defaults to None.
            is_ddp (bool, optional): DDPか否か Defaults to False.
        """
        if each_task_collate_fn_dict is None:
            each_task_collate_fn_dict = {key:each_task_collate_fn for key in dataset_dict.keys()}
        self.dataloader_list = [DataLoader(dataset_dict[key],batch_size_dict[key],collate_fn=each_task_collate_fn_dict[key],**dataloader_args) for key in dataset_dict.keys()]
        self.step_list = [len(dataloader) for dataloader in self.dataloader_list]
        self.min_step = min(self.step_list)
        self.multi_task_collate_fn = multi_task_collate_fn
    def __iter__(self):
        return MultiDataIterator(self.dataloader_list,self.step_list,self.multi_task_collate_fn)
    def __len__(self):
        return self.min_step

    

In [16]:


def multi_task_collate_fn(sample):
    #[[image_list],[in_text_list],[out_text_list]]が入力される
    
    image_list = sample[0]
    in_text_list = sample[1]
    out_text_list = sample[2]
    return torch.stack(image_list),in_text_list,out_text_list
multi_dataloader = MultiDataLoader(dataset_dict,batch_size_dict,shuffle=True,drop_last=True,multi_task_collate_fn=multi_task_collate_fn)
print(len(multi_dataloader))
print_slices = [1,10,30,50]
epoch = 2
for _ in range(epoch):
    for index,data in enumerate(multi_dataloader):
        if index+1 in print_slices:
            print(f"index:{index+1}")
            print([len(data[i]) for i in range(len(data))])
            print(data)
            print("\n")
            

50
index:1
[40, 40, 40]
(tensor([[40.0229],
        [40.0047],
        [40.0062],
        [40.0113],
        [40.0358],
        [40.0131],
        [10.0057],
        [40.0110],
        [40.0420],
        [40.0283],
        [20.0115],
        [20.0657],
        [20.0967],
        [20.1015],
        [20.0224],
        [20.0114],
        [20.0372],
        [20.0597],
        [20.0783],
        [20.0600],
        [20.0447],
        [20.0310],
        [20.0623],
        [20.0167],
        [20.0365],
        [20.0594],
        [20.0280],
        [20.0061],
        [20.0575],
        [20.0974],
        [30.0001],
        [30.0118],
        [30.0073],
        [30.0075],
        [30.0262],
        [30.0250],
        [30.0217],
        [50.0104],
        [50.0158],
        [50.0038]]), [tensor([41.0229]), tensor([41.0047]), tensor([41.0062]), tensor([41.0113]), tensor([41.0358]), tensor([41.0131]), tensor([11.0057]), tensor([41.0110]), tensor([41.0420]), tensor([41.0283]), tensor([21.0115]), ten

In [21]:
import torch.nn.functional as F
source = torch.randint(0,3,(5,10))
print(source.shape)
print(source)
# now we expand to size (7, 11) by appending a row of 0s at pos 0 and pos 6, 
# and a column of 0s at pos 10
result = F.pad(input=source, pad=(0, 1, 0, 0), mode='constant', value=0)
print(result.shape)
print(result)


torch.Size([5, 10])
tensor([[2, 1, 2, 2, 2, 2, 0, 0, 1, 0],
        [1, 1, 0, 0, 1, 1, 2, 1, 2, 1],
        [1, 1, 2, 0, 2, 2, 0, 0, 1, 1],
        [0, 1, 2, 0, 1, 1, 0, 0, 1, 2],
        [2, 2, 2, 0, 2, 1, 1, 2, 1, 0]])
torch.Size([5, 11])
tensor([[2, 1, 2, 2, 2, 2, 0, 0, 1, 0, 0],
        [1, 1, 0, 0, 1, 1, 2, 1, 2, 1, 0],
        [1, 1, 2, 0, 2, 2, 0, 0, 1, 1, 0],
        [0, 1, 2, 0, 1, 1, 0, 0, 1, 2, 0],
        [2, 2, 2, 0, 2, 1, 1, 2, 1, 0, 0]])


In [25]:
source1 = torch.randint(1,9,(5,10))
source2 = torch.randint(1,9,(10,20))
print(source1)
print(source2)
max_length = max(source1.shape[-1],source2.shape[-1])
pad_value = 0
# from torch.nn.utils.rnn import pad_sequence
# source1 = pad_sequence([source1,source2],batch_first=True,padding_value=pad_value)
# print(source1.shape)
source1 = F.pad(input=source1, pad=(0, max_length-source1.shape[-1]), mode='constant', value=pad_value)
source2 = F.pad(input=source2, pad=(0, max_length-source2.shape[-1]), mode='constant', value=pad_value)
source = torch.vstack([source1,source2])
print("source")
print(source.shape)
print(source)


tensor([[7, 3, 3, 2, 3, 6, 7, 6, 6, 7],
        [8, 7, 1, 3, 8, 2, 1, 3, 2, 2],
        [6, 2, 6, 3, 1, 3, 3, 6, 6, 7],
        [8, 1, 6, 6, 6, 8, 5, 8, 5, 2],
        [7, 8, 4, 7, 1, 3, 3, 8, 4, 5]])
tensor([[3, 3, 6, 8, 4, 6, 2, 4, 2, 5, 2, 7, 1, 6, 2, 5, 3, 2, 3, 1],
        [7, 8, 7, 1, 3, 7, 3, 5, 3, 5, 8, 7, 5, 5, 5, 5, 5, 7, 3, 6],
        [6, 5, 5, 7, 7, 5, 7, 1, 3, 5, 8, 6, 1, 8, 3, 7, 5, 4, 2, 4],
        [2, 4, 5, 7, 5, 7, 5, 4, 8, 8, 4, 3, 4, 8, 3, 1, 5, 8, 3, 1],
        [8, 4, 2, 1, 3, 8, 8, 4, 3, 8, 6, 7, 1, 3, 6, 5, 8, 5, 7, 8],
        [7, 8, 5, 5, 8, 8, 4, 2, 8, 5, 5, 5, 5, 3, 1, 3, 6, 8, 4, 7],
        [4, 3, 2, 4, 4, 7, 2, 1, 5, 3, 4, 4, 4, 2, 3, 6, 3, 5, 6, 2],
        [8, 6, 5, 3, 6, 5, 3, 5, 5, 3, 4, 3, 1, 4, 7, 7, 1, 3, 2, 6],
        [5, 1, 3, 2, 6, 1, 4, 8, 7, 6, 7, 2, 1, 6, 4, 3, 7, 2, 7, 2],
        [6, 7, 5, 1, 1, 4, 7, 3, 4, 8, 3, 8, 6, 6, 4, 7, 4, 1, 3, 1]])


RuntimeError: The size of tensor a (10) must match the size of tensor b (20) at non-singleton dimension 1

In [None]:
def collate_fn(self, batch):
    # バッチ内のテンソルをパッドする
    src_images, tgt_images, src_texts, tgt_texts = [], [], [], []
    for src_image, tgt_image, src_text, tgt_text in batch:
        src_images.append(src_image)
        tgt_images.append(tgt_image)
        src_texts.append(src_text)
        tgt_texts.append(tgt_text)

    src_images = torch.stack(src_images)
    tgt_images = torch.stack(tgt_images)
    src_texts = pad_sequence(src_texts, batch_first=True, padding_value=self.src_tokenizer.pad_token_id)
    tgt_texts = pad_sequence(tgt_texts, batch_first=True, padding_value=self.tgt_tokenizer.pad_token_id)


In [26]:
dataset_1 = MyDataset("./dataset_1.csv")


In [27]:
for i in range(10):
    print(dataset_1[i])


[tensor([10.]), tensor([11.]), tensor([12.])]
[tensor([10.0001]), tensor([11.0001]), tensor([12.0001])]
[tensor([10.0002]), tensor([11.0002]), tensor([12.0002])]
[tensor([10.0003]), tensor([11.0003]), tensor([12.0003])]
[tensor([10.0004]), tensor([11.0004]), tensor([12.0004])]
[tensor([10.0005]), tensor([11.0005]), tensor([12.0005])]
[tensor([10.0006]), tensor([11.0006]), tensor([12.0006])]
[tensor([10.0007]), tensor([11.0007]), tensor([12.0007])]
[tensor([10.0008]), tensor([11.0008]), tensor([12.0008])]
[tensor([10.0009]), tensor([11.0009]), tensor([12.0009])]


In [30]:
dataloader = DataLoader(dataset_1, batch_size=10, shuffle=False)
dataloader_iter = iter(dataloader)
print(next(dataloader_iter))


10
[tensor([[10.0000],
        [10.0001],
        [10.0002],
        [10.0003],
        [10.0004],
        [10.0005],
        [10.0006],
        [10.0007],
        [10.0008],
        [10.0009]]), tensor([[11.0000],
        [11.0001],
        [11.0002],
        [11.0003],
        [11.0004],
        [11.0005],
        [11.0006],
        [11.0007],
        [11.0008],
        [11.0009]]), tensor([[12.0000],
        [12.0001],
        [12.0002],
        [12.0003],
        [12.0004],
        [12.0005],
        [12.0006],
        [12.0007],
        [12.0008],
        [12.0009]])]


In [31]:
print(len(dataloader))
print(len(iter(dataloader)))


10
10


In [8]:
from ex_module import ExModel
model = ExModel(None)
data = next(dataloader_iter)
out = model(data[0], data[1])


In [9]:
import torch
data1 = [[1],[2],[3]]
data2 = [[4],[5],[6],[7]]
print(torch.cat([torch.tensor(data1), torch.tensor(data2)],dim=0))


tensor([[1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7]])


In [10]:
import torch
from torch.utils.data import Dataset, DataLoader, ChainDataset
from ex_module import MyDataset,MyChainDataset


In [11]:
dataset_1 = MyDataset("./dataset_1.csv")
dataset_2 = MyDataset("./dataset_2.csv")
dataset_3 = MyDataset("./dataset_3.csv")

dataset = MyChainDataset([dataset_1, dataset_2, dataset_3],[[0,1],[1,2],[2,0]])


In [12]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True,num_workers=4,pin_memory=True)


In [13]:
data_iter = iter(dataloader)
for i in range(10):
    print(next(data_iter))


[tensor([[21.0079],
        [21.0603],
        [21.0834],
        [21.0860]]), tensor([[22.0079],
        [22.0603],
        [22.0834],
        [22.0860]])]
[tensor([[10.0053],
        [21.0394],
        [21.0858],
        [21.0271]]), tensor([[11.0053],
        [22.0394],
        [22.0858],
        [22.0271]])]
[tensor([[32.0271],
        [21.0072],
        [21.0586],
        [32.0232]]), tensor([[30.0271],
        [22.0072],
        [22.0586],
        [30.0232]])]
[tensor([[21.0893],
        [21.0944],
        [21.0180],
        [21.0756]]), tensor([[22.0893],
        [22.0944],
        [22.0180],
        [22.0756]])]
[tensor([[32.0252],
        [21.0284],
        [21.0052],
        [32.0263]]), tensor([[30.0252],
        [22.0284],
        [22.0052],
        [30.0263]])]
[tensor([[21.0931],
        [21.0924],
        [21.0279],
        [21.0547]]), tensor([[22.0931],
        [22.0924],
        [22.0279],
        [22.0547]])]
[tensor([[21.0353],
        [21.0164],
        [21.1006],


In [14]:
len(dataset)


1500

In [15]:
dataset[1500]


IndexError: 