In [1]:
import os
import time
import torch
import random
import logging
import argparse

import numpy as np 
import pandas as pd
import pytorch_lightning as pl

from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

In [2]:
# Configure the root logger so INFO-level (and above) records are reported.
logging.basicConfig(level=logging.INFO)

In [3]:
# Logger for this notebook, named after the running module (`__name__`).
logger = logging.getLogger(__name__)

### Load Dataset

In [4]:
# The test set is read as-is; only the labelled data has incomplete rows removed.
test = pd.read_csv('test.csv')
train_full = pd.read_csv('train.csv').dropna()

# Hold out 13% of the labelled rows for validation (fixed seed for repeatability).
train, val = train_test_split(train_full, test_size=0.13, random_state=42)

INFO:numexpr.utils:Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.


In [5]:
# Input: preview the sentiment / tweet pairs of the first ten training rows
preview_sentiments = train.sentiment.values[:10]
preview_tweets = train.text.values[:10]
for sentiment, tweet in zip(preview_sentiments, preview_tweets):
    print("sentiment:", sentiment, "tweet:", tweet)

sentiment: negative tweet:  How did we just get paid and still be broke as hell?! No shopping spree for me today
sentiment: positive tweet: i no i no bt i had only been a gamer for like 2 years when i made that attempt  lol yea i luvd F1 to an extent 
sentiment: positive tweet: I love when my ipod shuffles so all the good songs are all together
sentiment: neutral tweet:  no i mean 2moz. I`m workin` 7-1 in a bakers then 6-4 later in a pub
sentiment: positive tweet: Lovely walk this morning with the missus; drizzle didn`t matter
sentiment: neutral tweet:  , just dont understand what`s it got to do with me. I`m just a nice girl
sentiment: negative tweet: getting bored of walking up and down the stairs
sentiment: positive tweet:  have your own style. it just might work.
sentiment: negative tweet: fighting with mum on mothers day
sentiment: neutral tweet:  & I got too much work to do


In [6]:
# Target: preview the selected text of the same first ten training rows
for selected in train.selected_text.values[:10]:
    print(selected)

broke as hell?!
luvd
love
no i mean 2moz. I`m workin` 7-1 in a bakers then 6-4 later in a pub
Lovely walk this morning with the missus; drizzle didn`t matter
, just dont understand what`s it got to do with me. I`m just a nice girl
getting bored of walking up and down the stairs
it just might work.
fighting
I got too much work to do


### Preprocess Dataset

In [7]:
# Append the EOS token to each target text (the standard format for T5 targets).
# The suffix is only added when it is not already there, so re-running this cell
# does not stack ' </s> </s>' onto the targets. `assign` builds new frames rather
# than writing a column into the frames returned by train_test_split, which
# avoids a possible SettingWithCopyWarning.
EOS_SUFFIX = ' </s>'


def _with_eos(texts):
    """Return `texts` with EOS_SUFFIX appended to every entry that lacks it."""
    already_terminated = texts.str.endswith(EOS_SUFFIX, na=False)
    return texts.where(already_terminated, texts + EOS_SUFFIX)


def _to_question_context(frame):
    """Build the "question: <sentiment> context: <tweet>" prompt for each row."""
    return "question: " + frame.sentiment + " context: " + frame.text


train = train.assign(selected_text=_with_eos(train['selected_text']))
val = val.assign(selected_text=_with_eos(val['selected_text']))

# Apply Q&A structure
# From Appendix D in the T5 paper
processed_input_train = _to_question_context(train)
processed_input_test = _to_question_context(test)
processed_input_val = _to_question_context(val)

# Save data as string separated by \n (new line): one example per line.
# NOTE(review): a tweet that itself contains a line break would span several
# lines and shift the source/target alignment -- confirm the data has none.
processed_input_str_train = '\n'.join(processed_input_train.tolist())
processed_input_str_test = '\n'.join(processed_input_test.tolist())
selected_text_str_train = '\n'.join(train['selected_text'].tolist())
processed_input_str_val = '\n'.join(processed_input_val.tolist())
selected_text_str_val = '\n'.join(val['selected_text'].tolist())

In [8]:
# Label-based lookup: shows the row whose original index label is 0.
(processed_input_train.loc[0], train['selected_text'].loc[0])

('question: neutral context:  I`d have responded, if I were going',
 'I`d have responded, if I were going </s>')

In [9]:
# Label-based lookup of the test prompt with index label 0.
processed_input_test.loc[0]

'question: neutral context: Last session of the day  http://twitpic.com/67ezh'

In [10]:
# Write the newline-separated model inputs, one file per split.
source_files = (
    ('train.source', processed_input_str_train),
    ('test.source', processed_input_str_test),
    ('val.source', processed_input_str_val),
)
for file_name, contents in source_files:
    with open(file_name, 'w') as f:
        f.write(contents)

In [11]:
# Write the newline-separated targets; the test split has no target file.
target_files = (
    ('train.target', selected_text_str_train),
    ('val.target', selected_text_str_val),
)
for file_name, contents in target_files:
    with open(file_name, 'w') as f:
        f.write(contents)

### Prep the T5 Dataset

In [12]:
def encode_file(tokenizer, data_path, max_length, padding='max_length', return_tensors="pt"):
    """
    Read a text file holding one example per line and tokenize every line.

    Each line is encoded on its own with `tokenizer.batch_encode_plus`, so the
    result is a list with one tokenizer output per line (the word piece indices
    among other relevant inputs for training & inference).

    Args:
        tokenizer: object exposing `batch_encode_plus`.
        data_path: path of the text file to read.
        max_length: length forwarded to the tokenizer; examples are truncated to
            it and, with the default `padding`, padded up to it.
        padding: padding strategy forwarded to the tokenizer.
        return_tensors: tensor format forwarded to the tokenizer.

    Returns:
        list: one `batch_encode_plus` result per line of the file, in file order.
    """
    examples = []
    with open(data_path, "r") as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line in f:
            # The trailing newline is only the record separator written between
            # examples; it is not part of the text (and the last line has none).
            text = line.rstrip("\n")
            tokenized = tokenizer.batch_encode_plus(
                [text],
                max_length=max_length,
                padding=padding,
                # Without truncation an over-long example would keep its full
                # length and could not be stacked with the padded ones later.
                truncation=True,
                return_tensors=return_tensors,
            )
            examples.append(tokenized)
    return examples

In [13]:
class T5Dataset(Dataset):
    """
    This is the T5 dataset that can read our train, test, and dev files separately

    This was patterned after the SummarizationDataset from the `transformers` library's
    summarization example (compatible with T5)

    All examples are tokenized once in `__init__` through `encode_file`. The "test"
    split has no target file, so its items and batches only carry the source tensors.

    Args:
        tokenizer: tokenizer handed to `encode_file`; also kept on the instance.
        data_dir: directory holding the `<type_path>.source` / `<type_path>.target` files.
        type_path: file stem of the split to load ("train", "val" or "test").
        max_source_length: `max_length` used when encoding the source file.
        max_target_length: `max_length` used when encoding the target file.
    """
    def __init__(
        self,
        tokenizer,
        data_dir="./",
        type_path="train",
        max_source_length=1024,
        max_target_length=56,
    ):
        super().__init__()
        # Store the tokenizer
        self.tokenizer = tokenizer
        self.type_path = type_path
        # Read the source and target files for the type of file (train, test, or val)
        self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
        # Only the labelled splits come with a target file.
        self.target = None
        if self.type_path != "test":
            self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)

    def __len__(self):
        """Number of examples, i.e. the number of encoded source lines."""
        return len(self.source)

    def __getitem__(self, index):
        """
        Return example `index` as a dictionary containing source_ids, source_mask,
        and (except for the "test" split) target_ids.
        """
        # Each encoded example carries a leading batch axis of size one; squeeze(0)
        # drops exactly that axis and never the sequence axis.
        source_ids = self.source[index]["input_ids"].squeeze(0)  # (source_len,)
        # We need masks for transformers to:
        # 1) ignore padding for both the encoder and decoder stages (src_mask)
        # 2) ignore future tokens at the decoder stage
        src_mask = self.source[index]["attention_mask"].squeeze(0)

        if self.type_path == "test":
            return {"source_ids": source_ids, "source_mask": src_mask}

        target_ids = self.target[index]["input_ids"].squeeze(0)  # (target_len,)
        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}

    def collate_fn(self, batch):
        """
        Stack the per-example tensors produced by `__getitem__` into batch tensors.

        Meant to be handed to a DataLoader as its `collate_fn`.

        Returns:
            dict: `source_ids` and `source_mask` (BS x SL) and, except for the
            "test" split, `target_ids` (BS x TL).
        """
        source_ids = torch.stack([x["source_ids"] for x in batch])  # BS x SL
        source_mask = torch.stack([x["source_mask"] for x in batch])  # BS x SL
        if self.type_path == "test":
            return {"source_ids": source_ids, "source_mask": source_mask}

        target_ids = torch.stack([x["target_ids"] for x in batch])  # BS x TL
        return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": target_ids}