In [None]:
from typing import Any, Dict, List, Tuple

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

# What I want to do
- Convert tokenized triples of (query, positive_document, negative_document) into batches.

More precisely, if batch_size=2,

from

[(tk_query1, tk_positive_document1, tk_negative_document1), 
(tk_query2, tk_positive_document2, tk_negative_document2)
]

to

(tk_query_batch1&2, tk_positive_document_batch1&2, tk_negative_document_batch1&2)

In [27]:
class QDCollator():
    """Collate (query, positive_doc, negative_doc) examples position by position.

    Each example is a tuple of tokenizer outputs (dict-like objects, e.g. with
    an ``input_ids`` entry). The batch is transposed so that all first
    elements (queries) land in one dict, all second elements in another, and
    so on. No padding is applied: every value is a plain list holding one
    entry per example.
    """

    def __call__(self, features: List[Tuple[Dict[str, Any], ...]]) -> List[Dict[str, Any]]:
        """Transpose a list of example tuples into one dict of lists per position.

        Args:
            features: The examples of one batch; each example is a tuple of
                dict-like encodings that share the same keys.

        Returns:
            A list with one dict per tuple position. Each dict maps every key
            of the first encoding to the list of values collected across the
            batch. An empty batch (or empty example tuples) yields ``[]``.
        """
        # Nothing to transpose; also avoids indexing into an empty batch.
        if not features or not features[0]:
            return []

        # The key set is taken from the very first encoding only.
        input_keys = list(features[0][0].keys())

        # One accumulator per tuple position (query, positive, negative, ...).
        batch = [{key: [] for key in input_keys} for _ in features[0]]

        # zip(*features) regroups the examples by position: all queries
        # together, then all positive documents, and so on.
        for slot, column in zip(batch, zip(*features)):
            for encoded in column:
                for key in input_keys:
                    slot[key].append(encoded[key])

        return batch

# Preparing sample data

In [98]:
# Sample (query, positive_doc, negative_doc) records used throughout the notebook.
# NOTE(review): sequences such as \u00e2\u0080\u0099 and \u00c2\u00bd look like UTF-8
# bytes that were decoded one byte per character (probably a right single quote
# and a one-half sign). They are left untouched because the recorded outputs
# below were produced from these exact strings - confirm before cleaning.
texts = [{"query": "is a little caffeine ok during pregnancy", "positive_doc": "We don\u00e2\u0080\u0099t know a lot about the effects of caffeine during pregnancy on you and your baby. So it\u00e2\u0080\u0099s best to limit the amount you get each day. If you\u00e2\u0080\u0099re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1\u00c2\u00bd 8-ounce cups of coffee or one 12-ounce cup of coffee.", "negative_doc": "It is generally safe for pregnant women to eat chocolate because studies have shown to prove certain benefits of eating chocolate during pregnancy. However, pregnant women should ensure their caffeine intake is below 200 mg per day."},
{"query": "what fruit is native to australia", "positive_doc": "Passiflora herbertiana. A rare passion fruit native to Australia. Fruits are green-skinned, white fleshed, with an unknown edible rating. Some sources list the fruit as edible, sweet and tasty, while others list the fruits as being bitter and inedible.assiflora herbertiana. A rare passion fruit native to Australia. Fruits are green-skinned, white fleshed, with an unknown edible rating. Some sources list the fruit as edible, sweet and tasty, while others list the fruits as being bitter and inedible.", "negative_doc": "The kola nut is the fruit of the kola tree, a genus (Cola) of trees that are native to the tropical rainforests of Africa."},
{"query": "how large is the canadian military", "positive_doc": "The Canadian Armed Forces. 1  The first large-scale Canadian peacekeeping mission started in Egypt on November 24, 1956. 2  There are approximately 65,000 Regular Force and 25,000 reservist members in the Canadian military. 3  In Canada, August 9 is designated as National Peacekeepers\u00e2\u0080\u0099 Day.", "negative_doc": "The Canadian Physician Health Institute (CPHI) is a national program created in 2012 as a collaboration between the Canadian Medical Association (CMA), the Canadian Medical Foundation (CMF) and the Provincial and Territorial Medical Associations (PTMAs)."},
{"query": "types of fruit trees", "positive_doc": "Cherry. Cherry trees are found throughout the world. There are 40 or more varieties, ranging from bing cherry to black cherry. Along with the fruit, cherry trees produce light and delicate pinkish-white blossoms that are highly fragrant.omments. Submit. Planting fruit trees on your property not only provides you with a steady supply of organic fruit, it also allows you to beautify your yard and give oxygen back to the environment.", "negative_doc": "The kola nut is the fruit of the kola tree, a genus (Cola) of trees that are native to the tropical rainforests of Africa."},
{"query": "how many calories a day are lost breastfeeding", "positive_doc": "Not only is breastfeeding better for the baby, however, research also says it\u00e2\u0080\u0099s better for the mother. Breastfeeding burns an average of 500 calories a day, with the typical range from 200 to 600 calories burned a day. It\u00e2\u0080\u0099s estimated that the production of 1 oz. of breast milk burns 20 calories. The amount of calories burned depending on how much the baby eats. Breastfeeding twins burn twice as much as feeding only one baby. With twins their mom burns 1000 calories a day. Burning an extra 500 calories a day will result in one pound of weekly weight loss.", "negative_doc": "However, you still need some niacin each day; men need about 16 mg per day and women need 14 mg per day unless they are pregnant or nursing (pregnant and breastfeeding women have higher niacin requirements)."},
]

In [101]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Tokenize every record field by field, without padding or truncation, and
# keep the three encodings together as one (query, positive, negative) tuple.
tk_texts = []
for text in texts:
    tk_query, tk_posi_doc, tk_nega_doc = (
        tokenizer(text[field])
        for field in ("query", "positive_doc", "negative_doc")
    )
    tk_texts.append((tk_query, tk_posi_doc, tk_nega_doc))

In [102]:
# Inspect the first tokenized (query, positive_doc, negative_doc) triple.
tk_texts[0]

({'input_ids': [101, 2003, 1037, 2210, 24689, 7959, 3170, 7929, 2076, 10032, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]},
 {'input_ids': [101, 2057, 24260, 2102, 2113, 1037, 2843, 2055, 1996, 3896, 1997, 24689, 7959, 3170, 2076, 10032, 2006, 2017, 1998, 2115, 3336, 1012, 2061, 2009, 3022, 2190, 2000, 5787, 1996, 3815, 2017, 2131, 2169, 2154, 1012, 2065, 2017, 12069, 6875, 1010, 5787, 24689, 7959, 3170, 2000, 3263, 4971, 8004, 6444, 2015, 2169, 2154, 1012, 2023, 2003, 2055, 1996, 3815, 1999, 20720, 13714, 1022, 1011, 19471, 10268, 1997, 4157, 2030, 2028, 2260, 1011, 19471, 2452, 1997, 4157, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [103]:
# Number of (query, positive_doc, negative_doc) triples per batch.
batch_size = 2

# Using the default DataLoader

In [104]:
# No collate_fn is given, so the DataLoader falls back to its default collation.
dl = DataLoader(dataset=tk_texts, batch_size=batch_size)

In [105]:
# The default collation cannot combine the variable-length token lists, so
# iterating raises a RuntimeError. The failure is the point of this cell:
# catch it and display it so the notebook still runs from top to bottom.
try:
    for batch in dl:
        print(batch)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")

RuntimeError: each element in list of batch should be of equal size

# Using Data Collator

In [106]:
class QDCollator():
    """Regroup a batch of (query, positive_doc, negative_doc) tuples by position.

    Every example is a tuple of dict-like tokenizer outputs. The result holds
    one dict per tuple position (all queries, all positive documents, all
    negative documents), each mapping a key to the list of per-example values.
    Sequences are not padded here.
    """

    def __call__(self, features: List[Tuple[Dict[str, Any], ...]]) -> List[Dict[str, Any]]:
        """Collate one batch.

        Args:
            features: Examples of the batch; each is a tuple of dict-like
                encodings sharing the same keys.

        Returns:
            One dict of lists per tuple position, or ``[]`` when the batch
            (or its example tuples) is empty.
        """
        # Guard against indexing into an empty batch.
        if not features or not features[0]:
            return []

        # Keys are read from the first encoding of the first example.
        input_keys = list(features[0][0].keys())

        # One accumulator dict per position inside an example tuple.
        batch = [{key: [] for key in input_keys} for _ in features[0]]

        # zip(*features) yields, per position, the encodings of all examples.
        for slot, column in zip(batch, zip(*features)):
            for encoded in column:
                for key in input_keys:
                    slot[key].append(encoded[key])

        return batch

In [107]:
# Same data, but batches are now assembled by the custom collator.
new_dl = DataLoader(dataset=tk_texts, batch_size=batch_size, collate_fn=QDCollator())

In [108]:
# Per batch: number of positions, then the number of examples collected for
# the query, positive and negative position respectively.
for batch in new_dl:
    print(len(batch), *(len(batch[pos]["input_ids"]) for pos in range(3)))

3 2 2 2
3 2 2 2
3 1 1 1


# Pad each batch to a common length

In [109]:
from transformers.data.data_collator import DataCollatorWithPadding

In [110]:
class QDCollator_with_Padding(DataCollatorWithPadding):
    """Regroup (query, positive_doc, negative_doc) tuples by position, then pad.

    The batch is first transposed into one dict of lists per tuple position.
    Each of those dicts is then handed to ``DataCollatorWithPadding.__call__``,
    so queries, positive documents and negative documents are padded
    independently of each other.
    """

    def __call__(self, features: List[Tuple[Dict[str, Any], ...]]) -> List[Dict[str, Any]]:
        """Collate and pad one batch.

        Args:
            features: Examples of the batch; each is a tuple of dict-like
                tokenizer outputs sharing the same keys.

        Returns:
            One padded batch per tuple position, as returned by the parent
            collator, or ``[]`` when the batch (or its example tuples) is
            empty.
        """
        # Nothing to pad; also avoids indexing into an empty batch.
        if not features or not features[0]:
            return []

        # Keys are read from the first encoding of the first example.
        input_keys = list(features[0][0].keys())

        # One accumulator dict per position inside an example tuple.
        grouped = [{key: [] for key in input_keys} for _ in features[0]]

        # zip(*features) yields, per position, the encodings of all examples.
        for slot, column in zip(grouped, zip(*features)):
            for encoded in column:
                for key in input_keys:
                    slot[key].append(encoded[key])

        # Bind the parent's __call__ before the comprehension: zero-argument
        # super() does not work inside a comprehension's own scope on older
        # Python versions.
        padding_func = super().__call__
        return [padding_func(slot) for slot in grouped]

In [111]:
# Batches are regrouped by position and padded by the tokenizer-aware collator.
align_dl = DataLoader(
    dataset=tk_texts,
    batch_size=batch_size,
    collate_fn=QDCollator_with_Padding(tokenizer),
)

In [112]:
# Per batch: number of positions, then the padded input_ids shape of the
# query, positive and negative position respectively.
for batch in align_dl:
    print(len(batch), *(batch[pos]["input_ids"].shape for pos in range(3)))

3 torch.Size([2, 11]) torch.Size([2, 116]) torch.Size([2, 44])
3 torch.Size([2, 8]) torch.Size([2, 89]) torch.Size([2, 51])
3 torch.Size([1, 13]) torch.Size([1, 127]) torch.Size([1, 49])


# Some Comments
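- You can avoid the default DataLoader's error by padding every text to a fixed length at tokenization time and returning PyTorch tensors (`return_tensors="pt"`).
- However, the result may not be what you expect: each tensor keeps an extra dimension of size 1 (e.g. `[2, 1, 64]` instead of `[2, 64]`).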

In [113]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Re-tokenize every record, this time padded/truncated to a fixed length per
# field (64 for queries, 512 for documents) and returned as PyTorch tensors.
tk_texts = []
for text in texts:
    tk_query, tk_posi_doc, tk_nega_doc = (
        tokenizer(
            text[field],
            truncation=True,
            padding="max_length",
            max_length=limit,
            return_tensors="pt",
        )
        for field, limit in (("query", 64), ("positive_doc", 512), ("negative_doc", 512))
    )
    tk_texts.append((tk_query, tk_posi_doc, tk_nega_doc))

In [114]:
# Default collation again; it works here because every tensor has a fixed size.
same_length_dl = DataLoader(dataset=tk_texts, batch_size=batch_size)

In [115]:
# Show the input_ids shape of the query, positive and negative position.
for batch in same_length_dl:
    print(*(batch[pos]["input_ids"].shape for pos in range(3)))

torch.Size([2, 1, 64]) torch.Size([2, 1, 512]) torch.Size([2, 1, 512])
torch.Size([2, 1, 64]) torch.Size([2, 1, 512]) torch.Size([2, 1, 512])
torch.Size([1, 1, 64]) torch.Size([1, 1, 512]) torch.Size([1, 1, 512])


- Transposing a list of lists with `zip(*list)`

In [116]:
# Five toy (query, positive_doc, negative_doc) rows, numbered 1 to 5.
demo_texts = [
    [f"query{n}", f"posi_doc{n}", f"nega_doc{n}"]
    for n in range(1, 6)
]

In [117]:
# zip(*rows) turns rows into columns: all queries, then all positive
# documents, then all negative documents, each paired with its position.
for position, column in enumerate(zip(*demo_texts)):
    print((position, column))

(0, ('query1', 'query2', 'query3', 'query4', 'query5'))
(1, ('posi_doc1', 'posi_doc2', 'posi_doc3', 'posi_doc4', 'posi_doc5'))
(2, ('nega_doc1', 'nega_doc2', 'nega_doc3', 'nega_doc4', 'nega_doc5'))
