In [1]:
%load_ext autoreload
%autoreload 2

Internal Imports

In [2]:
import os, sys
import typing as t
from typing import TypeAlias
from dotenv import find_dotenv

# Locate the ".env" file and add the directory that contains it to Python's
# module search path (sys.path, not the shell's $PATH) so the local package
# imported below can be resolved.
# find_dotenv() returns an empty string when no ".env" is found; skip that case,
# and skip directories already registered so re-running this cell is idempotent.
dotenv_dir = os.path.dirname(find_dotenv())
if dotenv_dir and dotenv_dir not in sys.path:
    sys.path.append(dotenv_dir)

# Use local package for modularity
import emotion_analysis as ea
import emotion_analysis.data.dataset as data
import emotion_analysis.data.transform as transform
from emotion_analysis.model.emotion_cause_text import load_text_model

JAX Backend:  gpu
JAX Version:  0.4.23
Python:  3.11.0 (main, Oct  4 2023, 22:00:02) [GCC 13.2.1 20230801]
System:  posix.uname_result(sysname='Linux', nodename='archlinux', release='6.6.8-arch1-1', version='#1 SMP PREEMPT_DYNAMIC Thu, 21 Dec 2023 19:01:01 +0000', machine='x86_64')


External Imports

In [3]:
import jax
import jax.numpy as jnp
import jax.random as jrm
from jax import Array
from jax.typing import ArrayLike, DTypeLike
import flax
import flax.linen as nn
import numpy as np
import optax as opt
import mlflow as mlf
from transformers import RobertaConfig, FlaxRobertaModel, RobertaTokenizerFast
from transformers import PretrainedConfig
import torch
from torch.utils.data import Dataset, DataLoader, Subset, random_split

In [4]:
# Tunable constants for this cell. SEED pins the train/valid split and the
# training shuffle order so a "Restart & Run All" reproduces the same batches.
SEED = 42
BATCH_SIZE = 32
TRAIN_VALID_FRACTIONS = [0.75, 0.25]

# Load pretrained model along with its tokenizer
model, params, tokenizer = load_text_model()

# Select tokenization method
tokenize = transform.Tokenize(tokenizer)
collate = transform.DataTokenizerCollator(tokenize)

# Load the data; the validation subset is carved out of the training split
task = 'task_1'
DataSplit: TypeAlias = t.Literal['train', 'valid', 'test']
ds_test = data.ECACDataset(ea.DATA_DIR, task, split='test')
ds_train_full = data.ECACDataset(ea.DATA_DIR, task, split='train')
ds_train, ds_valid = random_split(
    ds_train_full,
    TRAIN_VALID_FRACTIONS,
    # Dedicated generator: the split no longer depends on global RNG state
    generator=torch.Generator().manual_seed(SEED),
)
num_classes: int = ds_test.num_emotions

# One loader per split; only the training loader shuffles
dataloader: t.Dict[DataSplit, DataLoader[t.Dict[str, Array]]] = {
    'valid': DataLoader(ds_valid, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate),
    'train': DataLoader(
        ds_train,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate,
        # Seeded so the shuffling order is reproducible across kernel restarts
        generator=torch.Generator().manual_seed(SEED),
    ),
    'test': DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate),
}

In [5]:
# Pull a single training batch to inspect. next() raises StopIteration right
# here if the loader is empty, instead of leaving `sample` undefined for the
# cells below.
sample = next(iter(dataloader['train']))
print(sample['input_ids'].shape)

(32, 33, 93)


In [6]:
# List the fields the collator put into the batch
print(sample.keys())
# Cast the mask to bool so it can be used for boolean-mask indexing in the cells below
# (presumably it flags real, non-padding utterances — TODO confirm against the collator)
input_mask = sample['input_mask'].astype(bool)

dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'input_mask', 'emotion_labels', 'cause_labels', 'cause_span', 'cause_mask'])


In [20]:
# Shape of the batch's attention mask
sample['attention_mask'].shape

(32, 33, 93)

In [7]:
# Boolean-mask indexing over the two leading axes: the result is a 2-D array with
# one row of token ids per entry selected by input_mask
sample['input_ids'][input_mask]

Array([[   0, 2847, 2156, ...,    1,    1,    1],
       [   0, 2409,   47, ...,    1,    1,    1],
       [   0, 1185,  216, ...,    1,    1,    1],
       ...,
       [   0, 2264,   32, ...,    1,    1,    1],
       [   0,  170,   32, ...,    1,    1,    1],
       [   0, 7516, 2156, ...,    1,    1,    1]], dtype=int32)

In [8]:
# Turn the mask-selected token-id rows back into strings; tokenization spaces
# are deliberately left as produced by the tokenizer.
decoded = tokenizer.batch_decode(
    sample['input_ids'][input_mask],
    clean_up_tokenization_spaces=False,
)

In [9]:
# Full (unmasked) shape of the token-id batch
# (presumably (batch, utterances, tokens) — TODO confirm against the collator)
sample['input_ids'].shape

(32, 33, 93)

In [10]:
# Inspect the decoded strings; special tokens (<s>, </s>, <pad>) are kept in the text
decoded

['<s>So , you are like a zillionaire ?</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '<s>And you are our age . You are our age .</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '<s>You know what , you should like , you should buy a state and then just name it after yourself 

In [11]:
# Integer emotion labels of the sampled batch
# (presumably one label per utterance, zero-padded — TODO confirm)
sample['emotion_labels']

Array([[2, 2, 0, ..., 0, 0, 0],
       [4, 0, 4, ..., 0, 0, 0],
       [0, 4, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 4, 0, ..., 0, 0, 0],
       [2, 4, 4, ..., 0, 0, 0]], dtype=int32)

In [12]:
# Integer cause labels of the sampled batch
# (presumably a 0/1 flag per utterance, zero-padded — TODO confirm)
sample['cause_labels']

Array([[1, 1, 0, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)

In [13]:
# Expand the 2-D input_mask to a 4-D mask: axis 1 is repeated twice (the span
# tensor inspected below keeps start indices at index 0 and stop indices at
# index 1 of that axis) and the trailing axis is repeated once per entry along
# input_mask's second axis.
# That count is read from the mask itself rather than hard-coded, because the
# padded length can differ from one batch to the next.
num_utterances = input_mask.shape[1]
cause_mask = jnp.tile(input_mask[:, None, :, None], (1, 2, 1, num_utterances))

In [14]:
# Cause spans of the first batch element: slicing with [:1] and squeezing axis 0
# gives the same array as indexing with [0]
sample['cause_span'][:1, ...].squeeze(0)

Array([[[ 3,  0,  0, ...,  0,  0,  0],
        [ 3,  1,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        ...,
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0]],

       [[10,  0,  0, ...,  0,  0,  0],
        [10, 11,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        ...,
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0]]], dtype=int32)

In [15]:
# Cause spans of the first batch element; the leading axis of the result has two
# entries, read as span start (index 0) and span stop (index 1) by the decoding loop below
sample['cause_span'][0]

Array([[[ 3,  0,  0, ...,  0,  0,  0],
        [ 3,  1,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        ...,
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0]],

       [[10,  0,  0, ...,  0,  0,  0],
        [10, 11,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        ...,
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0]]], dtype=int32)

In [16]:
# Span start indices (channel 0) of the first batch element, first 8 rows only
sample['cause_span'][0, 0, :8]

Array([[3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)

In [17]:
# Decode every annotated cause span in the batch back to text.
# cause_span[k, 0, i, j] is the first token of a span and cause_span[k, 1, i, j]
# the last one (inclusive), both indexing into utterance j of batch element k.
# A start of 0 means "no span" and is skipped.
# Copy to NumPy once: indexing a JAX array element by element inside a Python
# loop issues a separate array operation for every lookup.
cause_span_np = np.asarray(sample['cause_span'])
input_ids_np = np.asarray(sample['input_ids'])

for k in range(cause_span_np.shape[0]):
    span_starts = cause_span_np[k, 0]
    span_stops = cause_span_np[k, 1]
    # np.nonzero walks the (i, j) grid in row-major order, i.e. the same order
    # as a nested "for i ... for j ..." loop
    for i, j in zip(*np.nonzero(span_starts)):
        span_tokens = input_ids_np[k, j, span_starts[i, j]: span_stops[i, j] + 1]
        print(tokenizer.decode(span_tokens, clean_up_tokenization_spaces=False))
    print('--------')

 you are like a zillionaire ?
 you are like a zillionaire ?
And you are our age . You are our age .
You are our age !
 we on for tomorrow ?
 I am running out of places I can touch him !
 is there something wrong with me ?
 why am I only attracted to guys where there is no future ?
Either they are too old , or they are too young , and then there is Pete who ... who crazy about me , and who absolutely perfect for me , and there is like zip going on !
 does it sound like something wrong with me ?
 is there something wrong with me ?
 why am I only attracted to guys where there is no future ?
Either they are too old , or they are too young , and then there is Pete who ... who crazy about me , and who absolutely perfect for me , and there is like zip going on !
 does it sound like something wrong with me ?
 is there something wrong with me ?
 why am I only attracted to guys where there is no future ?
Either they are too old , or they are too young , and then there is Pete who ... who crazy a

In [18]:
# Span stop indices (channel 1) of the first batch element, first 8 rows only
sample['cause_span'][0, 1, :8]

Array([[10,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [10, 11,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0