In [1]:
%load_ext autoreload
%autoreload 2

Internal Imports

In [2]:
# Find ".env" file and add the package to $PATH
import os, sys
import typing as t
from typing import Dict, TypeAlias, Any
from dotenv import find_dotenv
sys.path.append(os.path.dirname(find_dotenv()))

# Use local package for modularity
import emotion_analysis as ea
from emotion_analysis.data.dataset import ECACDataset
from emotion_analysis.data.types import TrainSplit, DataSplit, SubTask
from emotion_analysis.data.transform import DataTokenize, DataTransform, DataCollator
from emotion_analysis.model.emotion_cause_text import load_text_model
from emotion_analysis.model.emotion_cause_text import EmotionCauseTextModel

JAX Backend:  gpu
JAX Version:  0.4.23
Python:  3.11.0 (main, Oct  4 2023, 22:00:02) [GCC 13.2.1 20230801]
System:  posix.uname_result(sysname='Linux', nodename='archlinux', release='6.6.10-arch1-1', version='#1 SMP PREEMPT_DYNAMIC Fri, 05 Jan 2024 16:20:41 +0000', machine='x86_64')


External Imports

In [3]:
import jax
import jax.numpy as jnp
import jax.random as jrm
from jax import Array
import jax.tree_util as tree_util
from jax.tree_util import tree_structure
from jax.typing import ArrayLike, DTypeLike
import flax
import flax.linen as nn
import numpy as np
import optax as opt
import mlflow as mlf
from transformers import RobertaConfig, FlaxRobertaModel, RobertaTokenizerFast
from transformers import PretrainedConfig
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import pandas as pd
from collections import defaultdict
from itertools import chain

In [29]:
# Configuration
batch_size = 32
max_conv_len = 33
max_uttr_len = 93
subtask: SubTask = '1'

# Load pretrained model along with its tokenizer
text_encoder_pretrained = load_text_model()
tokenizer = text_encoder_pretrained.tokenizer
tokenize = DataTokenize(tokenizer, max_seq_len=max_uttr_len, padding='max_length')

# Load data subsets
ds_train = ECACDataset(
    data_dir=ea.DATA_DIR,
    subtask=subtask,
    split='train',
)
ds_test = ECACDataset(
    data_dir=ea.DATA_DIR,
    subtask=subtask,
    split='test',
)

# Data preprocessors
transform: Dict[TrainSplit, DataTransform] = {
    'train': DataTransform.from_data(ds_train, tokenize),
    'valid': DataTransform.from_data(ds_train, tokenize),
    'test': DataTransform.from_data(ds_test, tokenize),
}
collator: Dict[TrainSplit, DataCollator] = tree_util.tree_map(DataCollator, transform)

# Split training into train-valid subsets
num_classes: int = len(ds_test.emotions)

# Split into train-valid-test
# Train: drop_last=True to avoid JAX graph recompilation
dataloader: t.Dict[TrainSplit, DataLoader[t.Dict[str, Array]]] = {
    'train': DataLoader(ds_train, batch_size, False, collate_fn=collator['train'], drop_last=True),
    'test': DataLoader(ds_test, batch_size, False, collate_fn=collator['test']),
}

In [30]:
for x in dataloader['train']:
    sample = x
    print(x['input_ids'].shape)
    break

(32, 33, 93)


In [31]:
print(sample.keys())
input_mask = sample['input_mask'].astype(bool)

frozen_dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'input_mask', 'emotion_labels', 'cause_labels', 'cause_span', 'span_mask'])


In [32]:
sample['attention_mask'].shape

(32, 33, 93)

In [33]:
sample['input_ids'][input_mask]

Array([[    0, 45636,  2156, ...,     1,     1,     1],
       [    0,  7516,  2156, ...,     1,     1,     1],
       [    0, 12948,    38, ...,     1,     1,     1],
       ...,
       [    0,  1185,   216, ...,     1,     1,     1],
       [    0, 23692,   479, ...,     1,     1,     1],
       [    0, 47611,   479, ...,     1,     1,     1]], dtype=int32)

In [34]:
decoded = tokenizer.batch_decode(sample['input_ids'][input_mask], clean_up_tokenization_spaces=False)

In [35]:
sample['input_ids'].shape

(32, 33, 93)

In [36]:
decoded

['<s>Alright , so I am back in high school , I am standing in the middle of the cafeteria , and I realize I am totally naked .</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '<s>Oh , yeah . Had that dream .</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '<s>Then I look down , and I realize there is a phone ... there .</s><pad><pad><pad><pad><pad><p

In [37]:
sample['emotion_labels']

Array([[0, 0, 2, ..., 0, 0, 0],
       [3, 0, 2, ..., 0, 0, 0],
       [2, 0, 3, ..., 0, 0, 0],
       ...,
       [4, 0, 0, ..., 0, 0, 0],
       [0, 4, 0, ..., 0, 0, 0],
       [4, 4, 4, ..., 0, 0, 0]], dtype=int32)

In [38]:
sample['cause_labels']

Array([[1, 0, 1, ..., 0, 0, 0],
       [1, 0, 1, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 0, ..., 0, 0, 0],
       [1, 0, 1, ..., 0, 0, 0]], dtype=int32)

In [39]:
cause_mask = jnp.tile(input_mask[:, None, :, None], (1, 2, 1, 33))

In [40]:
sample['cause_span'][:1, ...].squeeze(0)

Array([[[ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [22,  0,  1, ...,  0,  0,  0],
        ...,
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0]],

       [[ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [28,  0, 15, ...,  0,  0,  0],
        ...,
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0]]], dtype=int32)

In [41]:
sample['cause_span'][0]

Array([[[ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [22,  0,  1, ...,  0,  0,  0],
        ...,
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0]],

       [[ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [28,  0, 15, ...,  0,  0,  0],
        ...,
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0]]], dtype=int32)

In [42]:
sample['cause_span'][0, 0, :8]

Array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [22,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [22,  0,  1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [22,  0,  1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0

In [44]:
for k in range(sample['cause_span'].shape[0]):
    for i in range(sample['cause_span'].shape[2]):
        for j in range(sample['cause_span'].shape[3]):
            if sample['cause_span'][k, 0, i, j] == 0:
                continue
            utterance = sample['input_ids'][k, j, ...]
            start = sample['cause_span'][k, 0, i, j]
            stop = sample['cause_span'][k, 1, i, j]
            print(tokenizer.decode(utterance[start: stop+1], clean_up_tokenization_spaces=False))
    print('--------')

 I realize I am totally naked .
Then I look down , and I realize there is a phone ... there .
 I realize I am totally naked .
Then I look down , and I realize there is a phone ... there .
Instead of ...
 I realize I am totally naked .
Then I look down , and I realize there is a phone ... there .
Instead of ...
--------
I do not want to be single
Rachel ? !
--------
 I should have caught on when she started going to the dentist four and five times a week .
 How did you get through it ?
 you might try accidentally breaking something valuable of hers , say her ...
leg ?
 you might try accidentally breaking something valuable of hers , say her ...
leg ?
That is one way !
 you might try accidentally breaking something valuable of hers
 I went for the watch .
You actually broke her watch ?
--------
--------
I am gonna go get one of those job things .
I am gonna go get one of those job things .
--------
 you probably did not know this , but back in high school , I had a , um , major crush on 

In [45]:
sample['cause_span'][0, 1, :8]

Array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [28,  0, 15,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [28,  0, 15,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [28,  0, 15,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0