In [11]:
import torch
from torch.utils.data import Dataset
import json
import pandas as pd

In [21]:
# Class definition for the dataset
class CheeseDescriptionsDataset(Dataset):
    """Dataset of (description text, formatted slot string) pairs.

    The annotation file is a JSON object with two levels of nesting: each
    top-level value is itself a mapping whose values are records holding a
    ``'text'`` entry and a ``'slots'`` mapping. Every record becomes one row
    of ``self.df`` with columns ``'input'`` (the text) and ``'output'`` (the
    slots rendered as ``<key:value><key:value>...``).

    Args:
        annotation_file: Path to the JSON annotation file.
        loader_pipeline: Optional callable applied to each row in
            ``__getitem__``. When ``None`` (the default) the raw row is
            returned unchanged, so the dataset can be inspected or fed to a
            pipeline manually.
    """

    def __init__(self, annotation_file, loader_pipeline=None):
        self.annot_file = annotation_file
        self.pipeline = loader_pipeline
        self.df = self.load_data(self.annot_file)

    def load_data(self, annot_file):
        """Read ``annot_file`` and return a DataFrame with 'input'/'output' columns.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a record lacks the 'text' or 'slots' entry.
        """
        # Only the parsing needs the file handle; close it before building rows.
        with open(annot_file, 'r', encoding='utf-8') as f:
            annotations = json.load(f)

        task_input, task_output = [], []
        for records in annotations.values():
            for record in records.values():
                text = record['text']
                slots = record['slots']
                # f-string formatting also accepts non-string slot values
                # (e.g. JSON numbers), which plain `+` concatenation rejects.
                formatted_slot = ''.join(
                    f'<{slot_key}:{slot_value}>'
                    for slot_key, slot_value in slots.items()
                )
                task_input.append(text)
                task_output.append(formatted_slot)

        data = {'input': task_input, 'output': task_output}
        return pd.DataFrame(data=data, columns=['input', 'output'])

    def __len__(self):
        """Number of (text, slots) rows loaded from the annotation file."""
        return len(self.df)

    def __getitem__(self, index):
        """Return row ``index``, passed through the pipeline when one is set."""
        row = self.df.iloc[index]
        if self.pipeline is None:
            return row
        return self.pipeline(row)
    


In [46]:
class BertPipeline:
    """Convert one dataset row into padded encoder/decoder features.

    Instances are callable, so one can be passed directly as the
    ``loader_pipeline`` of a dataset that invokes ``pipeline(row)``.

    Args:
        tokenizer: Callable tokenizer accepting ``max_length``, ``padding``
            and ``truncation`` keywords and returning a mapping that contains
            ``'input_ids'`` (and optionally ``'attention_mask'``).
        max_len_encoder: Padded/truncated length of the encoder sequence.
        max_len_decoder: Padded/truncated length of the decoder sequence.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self,
                 tokenizer,
                 max_len_encoder,
                 max_len_decoder,
                 **kwargs):
        self.tokenizer = tokenizer
        self.max_len_encoder = max_len_encoder
        self.max_len_decoder = max_len_decoder

    def __call__(self, row):
        """Alias for :meth:`bert_pipeline` so the instance works as a callable."""
        return self.bert_pipeline(row)

    def _tokenize(self, text, max_length):
        """Tokenize ``text``, padding/truncating to exactly ``max_length``."""
        return self.tokenizer(text,
                              max_length=max_length,
                              padding='max_length',
                              truncation=True)

    def _attention_mask(self, tokens, input_ids):
        """Return a 1/0 mask marking real tokens versus padding.

        The tokenizer's own mask is used when it returns one. Otherwise the
        mask is derived from the tokenizer's ``pad_token_id`` (falling back
        to 0 when it is undefined) rather than assuming padding is id 0.
        """
        if 'attention_mask' in tokens:
            return list(tokens['attention_mask'])
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_id is None:
            pad_id = 0
        return [0 if token_id == pad_id else 1 for token_id in input_ids]

    def bert_pipeline(self, row):
        """Build the model feature dict for a single row.

        Args:
            row: Mapping-like object with ``'input'`` (encoder text) and
                ``'output'`` (decoder text) entries.

        Returns:
            dict with keys ``'input_ids'``, ``'attention_mask'``,
            ``'cross_attention_mask'``, ``'decoder_input_ids'``,
            ``'decoder_attention_mask'`` and ``'labels'``; each value is a
            list of ints.
        """
        decoder_text = row['output']
        encoder_text = row['input']

        # prepare encoder inputs
        enc_tokens = self._tokenize(encoder_text, self.max_len_encoder)
        encoder_input_ids = enc_tokens['input_ids']
        encoder_attention_mask = self._attention_mask(enc_tokens, encoder_input_ids)
        # The cross-attention mask is an independent copy of the encoder mask.
        encoder_cross_attention_mask = list(encoder_attention_mask)

        # prepare decoder inputs
        dec_tokens = self._tokenize(decoder_text, self.max_len_decoder)
        decoder_input_ids = dec_tokens['input_ids']
        decoder_attention_mask = self._attention_mask(dec_tokens, decoder_input_ids)

        # Labels are an unshifted copy of the decoder ids; shifting is left to
        # the decoder model's forward pass.
        decoder_target_ids = list(decoder_input_ids)

        return {
            'input_ids': encoder_input_ids,
            'attention_mask': encoder_attention_mask,
            'cross_attention_mask': encoder_cross_attention_mask,
            'decoder_input_ids': decoder_input_ids,
            'decoder_attention_mask': decoder_attention_mask,
            'labels': decoder_target_ids
        }


        
    



In [47]:
annot_file = '../Data/slots_data/rhet_data_slots_cleaned.json'
# Pass an identity pipeline explicitly so that indexing the dataset returns the
# raw (input, output) row; tokenization is applied in a separate step.
ds = CheeseDescriptionsDataset(annotation_file=annot_file,
                               loader_pipeline=lambda row: row)

In [48]:
from transformers import BertTokenizer

txt = 'Hello world'
# `do_lower_case` is the BertTokenizer option that controls lower-casing;
# `lowercase` is not a documented BertTokenizer argument.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
pipeline = BertPipeline(tokenizer, max_len_encoder=512, max_len_decoder=512)
# Index the dataset with the subscript operator instead of calling the dunder directly.
row = ds[0]
pipeline.bert_pipeline(row)



{'input_ids': [101, 13055, 28159, 2015, 14400, 3062, 22851, 2080, 3151, 24519, 9638, 2003, 1037, 2995, 17070, 1997, 9638, 1010, 2550, 2011, 1996, 8228, 8808, 7751, 13055, 28159, 2015, 1012, 2023, 8808, 2038, 1037, 4138, 2381, 1998, 10056, 10003, 6651, 1010, 2004, 2009, 2003, 2081, 2478, 7246, 23184, 6501, 2013, 8623, 2306, 1037, 2260, 1011, 3542, 12177, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 