In [88]:
import os

import numpy as np
import pandas as pd
from transformers import T5Tokenizer

class Dataset(object):
    
    def __init__(self, data_dir):
        self.df = pd.read_csv(os.path.join(data_dir, "train_df.csv"))
        self.training_column = "cat_conc_sec"
        self.tokenizer = T5Tokenizer.from_pretrained('t5-small')
        self.max_length = 400

    def __getitem__(self, idx):

        X = self.df[self.training_column][idx]
        y = self.df['target_text'][idx]
        
        inputbatch = self.tokenizer.encode_plus(
            text=X,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt')["input_ids"][0]
        
        labelbatch = self.tokenizer.encode_plus(
            text=y,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt")["input_ids"][0]
        
        return inputbatch, labelbatch

    def __len__(self):
        return self.df.shape[0]

In [89]:
D = Dataset('../data')

In [90]:
from torch.utils.data import DataLoader
dataloader = DataLoader(D, batch_size=4, shuffle=True)

In [91]:
X, y = next(iter(dataloader))

In [92]:
X.shape, y.shape

(torch.Size([4, 400]), torch.Size([4, 400]))

In [93]:
model.train()
model(input_ids=X, labels=y)

Seq2SeqLMOutput(loss=tensor(2.4109, grad_fn=<NllLossBackward>), logits=tensor([[[ 10.9310, -12.5166, -14.9949,  ..., -41.2985, -41.2933, -41.3102],
         [ -4.2262, -17.9021, -10.2211,  ..., -49.7641, -49.8876, -49.8886],
         [ -3.0235,  -9.9621,  -8.8996,  ..., -40.2220, -40.2278, -40.3380],
         ...,
         [ 14.5240, -13.9865, -16.0803,  ..., -48.2629, -48.3096, -48.3490],
         [ 16.9134, -12.9729, -15.2565,  ..., -45.0781, -45.1830, -45.1838],
         [ 10.5574, -12.8382, -17.0926,  ..., -45.4711, -45.5147, -45.4705]],

        [[-18.1025,  -7.4446,  -9.8509,  ..., -37.0133, -37.0284, -37.0193],
         [-27.9342, -12.0741, -15.3477,  ..., -47.7742, -47.7753, -47.7409],
         [  0.3036, -14.2789, -16.2239,  ..., -49.4646, -49.5569, -49.6069],
         ...,
         [  4.7874, -13.7262, -15.0421,  ..., -45.9564, -45.9093, -45.9857],
         [  5.2998, -12.6286, -15.7437,  ..., -48.4817, -48.4517, -48.5013],
         [ 30.4321, -12.5227, -15.3501,  ..., -43.01