## Data Preprocessing Module Design

- Module name: `preprocess_data`
- Exposed functions:
  - Get the formatted raw dataset
  - Get filtered dataset
  - Get tokenized dataset
  - Get dataloader

### Raw Dataset

- Within each post, concatenate the summaries that share the same emotion label
- Argument: Data split
- Return: Formatted raw dataset

In [90]:
import json
from collections import defaultdict
from itertools import chain
from pathlib import Path


def get_raw_dataset(split, concat_same_emo=True):
    '''
    Raw dataset format:
    raw_dataset: [sample]
    sample     : { 'post': str, 'annos': [anno] }
    anno       : { 'emo': str, 'summ': str }
    '''
    data_dir = Path('data/train_val_test-WITH_POSTS')
    assert data_dir.exists(), 'Data not in the correct file path.'

    data_path = data_dir / f'{split}_anonymized-WITH_POSTS.json'
    assert data_path.exists(), f'Cannot find {split} data split at {data_path}.'

    with data_path.open() as f:
        json_data = json.load(f)

    raw_dataset = []

    if concat_same_emo:
        for raw_sample in json_data.values():
            emo2summ = defaultdict(str)
            for anno in chain(*raw_sample['Annotations'].values()):
                if anno['Emotion'] != 'NA':
                    emo2summ[anno['Emotion']] += ' ' + anno['Abstractive']

            sample = {'post': raw_sample['Reddit Post'], 'annos': []}
            for emo, summ in emo2summ.items():
                anno = {'emo': emo, 'summ': summ}
                sample['annos'].append(anno)
            raw_dataset.append(sample)
    else:
        for raw_sample in json_data.values():
            sample = {'post': raw_sample['Reddit Post'], 'annos': []}
            for anno in chain(*raw_sample['Annotations'].values()):
                if anno['Emotion'] != 'NA':
                    emo_summ = {'emo': anno['Emotion'], 'summ': anno['Abstractive']}
                    sample['annos'].append(emo_summ)
            raw_dataset.append(sample)

    return raw_dataset


def verify_raw_dataset(raw_ds):
    assert isinstance(raw_ds, list), 'Raw dataset must be a list of samples'
    for sample in raw_ds:
        assert (ks := list(sample.keys())) == ['post', 'annos'], f'Invalid key set {ks}'
        assert len(sample['annos']) > 0, 'Empty emotion summary annotation'
        for anno in sample['annos']:
            assert (ks := list(anno.keys())) == ['emo', 'summ'], f'Invalid key set {ks}'
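
To make the concatenation step concrete, here is a minimal sketch on an in-memory record. The record mimics the JSON fields that `get_raw_dataset` reads (`Reddit Post`, `Annotations`, `Emotion`, `Abstractive`); the annotator keys and texts are made up, and `group_by_emotion` is a hypothetical helper, not part of the module. Unlike the function above, it joins with `' '.join`, so the merged summary has no leading space.

```python
from collections import defaultdict
from itertools import chain

# Hypothetical record shaped like one entry of the JSON file.
raw_sample = {
    'Reddit Post': 'Example post text.',
    'Annotations': {
        'annotator_a': [
            {'Emotion': 'fear', 'Abstractive': 'Worried about work.'},
            {'Emotion': 'NA', 'Abstractive': 'NA'},
        ],
        'annotator_b': [
            {'Emotion': 'fear', 'Abstractive': 'Unsure what to do.'},
            {'Emotion': 'anger', 'Abstractive': 'Upset with the boss.'},
        ],
    },
}


def group_by_emotion(raw_sample):
    """Merge summaries that share an emotion label, skipping 'NA'."""
    emo2summs = defaultdict(list)
    for anno in chain.from_iterable(raw_sample['Annotations'].values()):
        if anno['Emotion'] != 'NA':
            emo2summs[anno['Emotion']].append(anno['Abstractive'])
    annos = [{'emo': emo, 'summ': ' '.join(summs)} for emo, summs in emo2summs.items()]
    return {'post': raw_sample['Reddit Post'], 'annos': annos}


sample = group_by_emotion(raw_sample)
```

The two `fear` summaries collapse into one annotation, while `anger` stays separate and the `NA` entry is dropped.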

In [102]:
raw_train_ds = get_raw_dataset('train')
raw_train_ds2 = get_raw_dataset('train', False)
print(*raw_train_ds[3]['annos'], sep='\n')
print('')
print(*raw_train_ds2[3]['annos'], sep='\n')

{'emo': 'fear', 'summ': " I don't think am safe around my bosses, because one is vaccinated and isn't. I am really confuse on what to do. I'm apprehensive about whether I should be going into work still when one of my unvaccinated bosses was exposed to COVID last week through his uncle, who has tested positive."}

{'emo': 'fear', 'summ': "I don't think am safe around my bosses, because one is vaccinated and isn't. I am really confuse on what to do."}
{'emo': 'fear', 'summ': "I'm apprehensive about whether I should be going into work still when one of my unvaccinated bosses was exposed to COVID last week through his uncle, who has tested positive."}


In [82]:
raw_val_ds = get_raw_dataset('val')
raw_val_ds[0]

{'post': "In my area we have a super high vaccination rate. In the sf Metro area we have 65 fully vaccinated, and some parts of it like sf have 68.9 fully vaccinated of whole population. And yet the delta is still surging here. The cdc just said in areas with high transmission masks should be mandated again and I feel completely hopeless. It's so far unknown if the bay area will reimplement masks but I'm sure they will. It's been close too 2 weeks since LA reinstated masks and the cases are still exploding there which is pretty hopeless. I can just see another lockdown coming maybe in the winter.",
 'annos': [{'emo': 'anticipation',
   'summ': 'The person is anticipating a new lockdown, realizing that changes will occur and is on alert with COVID-19.'},
  {'emo': 'fear',
   'summ': 'The person cannot relax and becomes worried and apprehensive about the increase in the number of COVID-19 cases.'},
  {'emo': 'sadness',
   'summ': 'The person feels defeated and without expectations of hav

In [83]:
for sample in raw_val_ds:
    emos = [anno['emo'] for anno in sample['annos']]
    if len(emos) != len(set(emos)):
        print(*sample['annos'], sep='\n')
        break

{'emo': 'anticipation', 'summ': 'The person is anticipating a new lockdown, realizing that changes will occur and is on alert with COVID-19.'}
{'emo': 'fear', 'summ': 'The person cannot relax and becomes worried and apprehensive about the increase in the number of COVID-19 cases.'}
{'emo': 'sadness', 'summ': 'The person feels defeated and without expectations of having to bear wearing masks again.'}
{'emo': 'anticipation', 'summ': 'There have been limited vaccination in the delta but am sure it will improve.'}


In [74]:
raw_test_ds = get_raw_dataset('test')
raw_test_ds[85]

{'post': "Hey guys, Apologies- pessimistic post incoming- feel free to delete if inappropriate. I'm struggling so much at the moment with having any sort of hope for the future. It just feels like we will never get normality back. COVID has completely ruined me. I went from being a (mostly) happy and functional human to depressed and suicidal since 2020. I take anti depressants now and I started therapy but even on my good days, I still feel like there's this dark cloud over my head and I'm waiting for the next shitstorm. I live in the UK and we've had three long lockdowns. Everything was looking pretty hopeful a few months ago, but since the last lockdown rules started easing, cases and hospitalisations have started rising again, even though we're doing pretty well with vaccinations. Our Government have handled the pandemic pretty poorly throughout and managed to let the Delta variant in with their loose border control measures. It just feels like this never ending cycle of rising cas

### Data Dictionary

- Sampled from raw dataset
- Formatted as a dictionary

In [55]:
def verify_data_dict(dd):
    key_set = ['post', 'emo', 'summ']
    assert list(dd.keys()) == key_set, f'Invalid key set: {dd.keys()}'

    len_dict = {k: len(dd[k]) for k in key_set}
    assert len_dict['post'] == len_dict['emo'] == len_dict['summ'], f'{len_dict=}'
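
As an illustration of the layout, a data dictionary is column-oriented: three parallel lists in which index `i` of `post`, `emo`, and `summ` together describes one example. The values below are made up, and `rows` is only a convenience view for inspection.

```python
# Hypothetical data dictionary with two examples drawn from the same post.
data_dict = {
    'post': ['Example post text.', 'Example post text.'],
    'emo': ['fear', 'anger'],
    'summ': ['Worried about work.', 'Upset with the boss.'],
}

# The same checks verify_data_dict performs: key order and equal column lengths.
assert list(data_dict.keys()) == ['post', 'emo', 'summ']
assert len({len(col) for col in data_dict.values()}) == 1

# Row view: one (post, emo, summ) triple per example.
rows = list(zip(data_dict['post'], data_dict['emo'], data_dict['summ']))
```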

#### All summaries

- Duplicate each post once per summary, so every (post, emotion, summary) triple becomes one example

In [144]:
def sample_all_summaries(raw_dataset):
    verify_raw_dataset(raw_dataset)
    data_dict = {'post': [], 'emo': [], 'summ': []}

    for sample in raw_dataset:
        for anno in sample['annos']:
            data_dict['post'].append(sample['post'])
            data_dict['emo'].append(anno['emo'])
            data_dict['summ'].append(anno['summ'])

    return data_dict


def data_dict_allsumm(split, **kwargs):
    raw_ds = get_raw_dataset(split, **kwargs)
    sampled_raw_ds = sample_all_summaries(raw_ds)
    return sampled_raw_ds

In [145]:
from collections import Counter

train_ds = data_dict_allsumm('train')
Counter(train_ds['emo'])

Counter({'sadness': 360,
         'anger': 470,
         'fear': 765,
         'trust': 99,
         'anticipation': 873,
         'disgust': 192,
         'joy': 134})

In [146]:
from collections import Counter

val_ds = data_dict_allsumm('val')
Counter(val_ds['emo'])

Counter({'anticipation': 206,
         'fear': 192,
         'sadness': 97,
         'joy': 32,
         'anger': 114,
         'trust': 23,
         'disgust': 41})

In [147]:
train_ds = data_dict_allsumm('train')
train_ds['post'][0] == train_ds['post'][1]

True

#### Balanced

- Each emotion has the same number of summaries
- Specified number of samples per emotion
- No duplicated posts

In [148]:
from collections import Counter

EMO_LIST = ['anger', 'disgust', 'fear', 'joy', 'sadness', 'trust', 'anticipation']


def data_dict_balanced(split, sample_size=float('inf')):
    raw_dataset = get_raw_dataset(split, concat_same_emo=True)
    data_dict = {'post': [], 'emo': [], 'summ': []}

    n_samples = dict.fromkeys(EMO_LIST, 0)
    sampling_emos = set(EMO_LIST)
    emo_freq = Counter(sample_all_summaries(raw_dataset)['emo'])
    sample_size = min(min(emo_freq.values()), sample_size)

    for sample in raw_dataset:
        annos = list(filter(lambda es: es['emo'] in sampling_emos, sample['annos']))
        if annos:
            anno = min(annos, key=lambda anno: emo_freq[anno['emo']])
            data_dict['post'].append(sample['post'])
            data_dict['emo'].append(anno['emo'])
            data_dict['summ'].append(anno['summ'])

            emo = anno['emo']
            n_samples[emo] += 1
            if n_samples[emo] == sample_size:
                sampling_emos.remove(emo)

    return data_dict

In [149]:
from collections import Counter
train_ds = data_dict_balanced('train')
Counter(train_ds['emo']), len(set(train_ds['post'])) == len(train_ds['post'])

(Counter({'trust': 99,
          'sadness': 99,
          'fear': 99,
          'anger': 99,
          'disgust': 99,
          'joy': 99,
          'anticipation': 99}),
 True)

In [150]:
val_ds = data_dict_balanced('val', sample_size=10)
Counter(val_ds['emo']), len(set(val_ds['post'])) == len(val_ds['post'])

(Counter({'sadness': 10,
          'anticipation': 10,
          'joy': 10,
          'anger': 10,
          'trust': 10,
          'fear': 10,
          'disgust': 10}),
 True)
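
The selection rule in `data_dict_balanced` can be traced on a toy dataset. This is a sketch with made-up posts and only two emotions; `quota` plays the role of `sample_size`, and `open_emos` the role of `sampling_emos`. Each post contributes at most one example, and among the emotions that are still open, the globally rarest one wins.

```python
from collections import Counter

# Hypothetical raw dataset: 'fear' appears in four posts, 'joy' in two.
raw_dataset = [
    {'post': 'p0', 'annos': [{'emo': 'fear', 'summ': 's0f'}, {'emo': 'joy', 'summ': 's0j'}]},
    {'post': 'p1', 'annos': [{'emo': 'fear', 'summ': 's1f'}]},
    {'post': 'p2', 'annos': [{'emo': 'fear', 'summ': 's2f'}, {'emo': 'joy', 'summ': 's2j'}]},
    {'post': 'p3', 'annos': [{'emo': 'fear', 'summ': 's3f'}]},
]

emo_freq = Counter(a['emo'] for s in raw_dataset for a in s['annos'])
quota = 1                    # per-emotion cap
open_emos = set(emo_freq)    # emotions that have not reached the cap yet
n_samples = Counter()
picked = []

for sample in raw_dataset:
    candidates = [a for a in sample['annos'] if a['emo'] in open_emos]
    if not candidates:
        continue  # every emotion in this post is already full, so skip the post
    # Prefer the globally rarest emotion so scarce labels are not starved.
    anno = min(candidates, key=lambda a: emo_freq[a['emo']])
    picked.append((sample['post'], anno['emo']))
    n_samples[anno['emo']] += 1
    if n_samples[anno['emo']] == quota:
        open_emos.remove(anno['emo'])
```

Note that the cap is an upper bound: the greedy pass does not by itself guarantee that every emotion reaches the quota, so the `Counter` check in the cells above is worth keeping.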

### Dataset

- Configure dataset:
  - Argument: tokenizer
  - Return: Build dataset function
- Build dataset function:
  - Argument: data dictionary
  - Return: dataset

In [191]:
from datasets import Dataset


def config_dataset(tokenizer):
    instr = 'Generate a summary of what triggered {} in this post: {}'

    def make_prompt(sample):
        return {'prompt': instr.format(sample['emo'], sample['post'])}

    def tokenize(sample):
        inputs = tokenizer(
            sample['prompt'],
            max_length=512, truncation=True, padding='max_length'
        )
        labels = tokenizer(
            sample['summ'], return_attention_mask=False,
            max_length=128, truncation=True, padding='max_length'
        )
        return {**inputs, 'labels': labels['input_ids']}

    def make_dataset(data_dict):
        verify_data_dict(data_dict)
        dataset = Dataset.from_dict(data_dict)
        dataset = dataset.map(make_prompt)
        dataset = dataset.remove_columns(['post', 'emo'])
        dataset = dataset.map(tokenize, batched=True)
        dataset = dataset.remove_columns(['prompt', 'summ'])
        dataset.set_format('torch')
        return dataset

    return make_dataset

In [192]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
data_dict = data_dict_balanced('train')
make_dataset = config_dataset(tokenizer)
dataset = make_dataset(data_dict)
dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 693
})
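
For reference, this is what `make_prompt` produces before tokenization. The sample is made up; the template string is the one defined in `config_dataset`.

```python
instr = 'Generate a summary of what triggered {} in this post: {}'

# Hypothetical sample; real posts are much longer.
sample = {'emo': 'fear', 'post': 'My boss was exposed to COVID last week.'}
prompt = instr.format(sample['emo'], sample['post'])
```

Because the instruction comes first, truncating the prompt to 512 tokens from the right (assumed here to be the tokenizer's default side) cuts the end of a long post rather than the instruction.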

### Dataloader

- Configure dataloader:
  - Arguments: model, tokenizer, dataloader kwargs
  - Return: build dataloader function
- Build dataloader function:
  - Argument: dataset
  - Return: dataloader

In [193]:
import random
from os import sched_getaffinity  # not available on every platform (e.g. Windows)

import numpy as np
import torch
from absl import flags
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

FLAGS = flags.FLAGS


def config_dataloader(model, tokenizer, **kwargs):
    collator = DataCollatorForSeq2Seq(tokenizer, model, padding='longest')

    def seed_worker(worker_id):
        # Derive each worker's NumPy and `random` seeds from its PyTorch seed.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    dl_kwargs = dict(
        collate_fn=collator,
        num_workers=len(sched_getaffinity(0)), worker_init_fn=seed_worker
    )
    dl_kwargs.update(kwargs)
    # Read the flag only when no batch size was passed explicitly,
    # so the function also works before the flags are parsed.
    if 'batch_size' not in dl_kwargs:
        dl_kwargs['batch_size'] = FLAGS.batch_size

    make_dataloader = lambda dataset: DataLoader(dataset, **dl_kwargs)

    return make_dataloader

In [190]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-base')
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
make_dataloader = config_dataloader(model, tokenizer, batch_size=16)

## Training Script

### Create Argument Flags

### Create Components

- Model
  - Argument: `checkpoint`
- Tokenizer
  - Argument: `checkpoint`
  - if checkpoint is T5: `model_max_length=512`
- Optimizer
  - Arguments:
    - model parameters
    - `learning rate`
- Scheduler
  - Arguments:
    - optimizer
    - warmup schedule if `warmup` is defined else constant schedule
- Writer
  - Argument: `experiment name`
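
The scheduler rule above can be sketched as a learning-rate multiplier. The outline does not say which warmup shape is used, so this assumes a linear ramp followed by a constant rate; `lr_multiplier` and `base_lr` are hypothetical names.

```python
def lr_multiplier(step, warmup=None):
    """Factor applied to the base learning rate at a given optimization step."""
    if not warmup:        # no warmup configured: constant schedule
        return 1.0
    if step < warmup:     # linear ramp from 0 up to the base rate
        return step / warmup
    return 1.0            # hold the base rate once warmup is over


base_lr = 1e-4  # hypothetical value
multipliers = [lr_multiplier(step, warmup=4) for step in range(6)]
learning_rates = [base_lr * m for m in multipliers]
```

A function of this shape could be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base learning rate by the returned factor.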

### Training Pipeline

- Set up model and optimizer
- Set up progress bar and logging config
- Initialize `best_rouge` and `accum_loss`

For each batch in train dataloader at forward step `step`:

- Forward and backward pass:
  - Forward pass to get loss
  - Divide the loss by the number of gradient accumulation steps `grad_acc`
  - Backward pass
  - Accumulate loss
- Optimize model every `grad_acc` (gradient accumulation) forward steps:
  - Step the optimizer and the scheduler
  - Update progress bar
- Log metrics every `eval_freq` optimization steps:
  - Log learning rate
  - Log the average loss computed from the accumulated loss, then reset `accum_loss`
  - Evaluate on train data subset, log ROUGE-L score, and print it with logging
  - Evaluate on validation data subset, log ROUGE-L score, and print it with logging
  - Reset model status to `train`
  - If the validation ROUGE-L score improves upon `best_rouge`, update `best_rouge` and save the model

- Close progress bar and writer objects
- Save tokenizer
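
The control flow of this pipeline can be sketched without any framework code. Everything passed into `train` below is a stand-in (the stub callables, the fake losses and ROUGE values are all hypothetical); the sketch only shows when the optimizer steps, when metrics are logged, and when the model is saved. It omits the train-subset evaluation, the progress bar, and the switch back to train mode.

```python
def train(batches, forward_backward, optimizer_step, evaluate, save,
          grad_acc=2, eval_freq=2):
    """Control flow only: every callable argument stands in for real model code."""
    best_rouge, accum_loss, opt_steps = float('-inf'), 0.0, 0
    logs = []
    for step, batch in enumerate(batches, start=1):
        # forward_backward returns the loss already divided by grad_acc.
        accum_loss += forward_backward(batch, grad_acc)
        if step % grad_acc != 0:
            continue                       # keep accumulating gradients
        optimizer_step()                   # optimizer and scheduler step
        opt_steps += 1
        if opt_steps % eval_freq != 0:
            continue
        avg_loss = accum_loss / eval_freq  # mean loss per batch since last log
        accum_loss = 0.0
        rouge = evaluate()                 # validation ROUGE-L
        logs.append((opt_steps, avg_loss, rouge))
        if rouge > best_rouge:             # keep only the best checkpoint
            best_rouge = rouge
            save()
    return best_rouge, logs


# Stub components: each batch "is" its own loss, ROUGE values are scripted.
fake_rouges = iter([0.2, 0.5, 0.4])
saved = []
best_rouge, logs = train(
    batches=[1.0] * 12,
    forward_backward=lambda batch, grad_acc: batch / grad_acc,
    optimizer_step=lambda: None,
    evaluate=lambda: next(fake_rouges),
    save=lambda: saved.append('checkpoint'),
)
```

With 12 batches, `grad_acc=2`, and `eval_freq=2`, the loop takes 6 optimization steps and evaluates 3 times; the model is saved on the first two evaluations only, since the third score does not beat `best_rouge`.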

#### Evaluate Pipeline on Training Dataset

For each batch in validation dataloader:
- Generate summary under `torch.inference_mode()`
- Log ROUGE score
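
For intuition about the metric being logged, here is a minimal ROUGE-L sketch based on the longest common subsequence (LCS). It is an illustration, not the scorer used by the training script: it assumes lowercase whitespace tokenization, no stemming, and the F1 variant of the score, and `lcs_length` and `rouge_l_f1` are hypothetical helper names.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(reference, candidate):
    """ROUGE-L F1 on lowercase whitespace tokens."""
    ref, cand = reference.lower().split(), candidate.lower().split()
    lcs = lcs_length(ref, cand)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


score = rouge_l_f1('a b c d', 'a c')  # LCS is ['a', 'c']
```

Here precision is 2/2 and recall is 2/4, so the F1 score is 2/3.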

