### Imports

In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the local package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

Internal Imports

In [2]:
# Find ".env" file and add the package to $PATH
import os, sys
import typing as t
from typing import Dict, TypeAlias, Any
from dotenv import find_dotenv
sys.path.append(os.path.dirname(find_dotenv()))

# Use local package for modularity
import emotion_analysis as ea
from emotion_analysis import config
from emotion_analysis.data.dataset import ECACDataset
from emotion_analysis.model.trainer import TrainerModule
from emotion_analysis.data.loader import DefaultDataLoader
from emotion_analysis.model.pretrained import load_text_model
from emotion_analysis.model.model import EmotionCauseTextModel
from emotion_analysis.data.types import TrainSplit, DataSplit, SubTask
from emotion_analysis.data.transform import DataTokenize, DataTransform, DataCollator

JAX Backend:  gpu
JAX Version:  0.4.23
Python:  3.11.0 (main, Oct  5 2023, 23:57:12) [GCC 13.2.1 20230801]
System:  posix.uname_result(sysname='Linux', nodename='archlinux', release='6.7.0-arch3-1', version='#1 SMP PREEMPT_DYNAMIC Sat, 13 Jan 2024 14:37:14 +0000', machine='x86_64')


External Imports

In [3]:
import jax
import jax.numpy as jnp
import jax.random as jrm
from jax import Array
import jax.tree_util as tree_util
from jax.tree_util import tree_structure
from jax.typing import ArrayLike, DTypeLike
import flax
import flax.linen as nn
import numpy as np
import optax as opt
import evaluate as eval
import mlflow as mlf
import torch
from torch.utils.data import Dataset, DataLoader, Subset, random_split

### Dataset

In [4]:
# Data Preprocessors
# The tokenizer is taken from the pretrained text model, then wrapped step by
# step: tokenize -> transform -> collate (the collator is handed to the loaders).
text_encoder_pretrained = load_text_model(config.model_repo, config.cache_dir)
tokenizer = text_encoder_pretrained.tokenizer
tokenize = DataTokenize(tokenizer, max_seq_len=config.max_uttr_len)
transform = DataTransform(tokenize, max_conv_len=config.max_conv_len)
collator = DataCollator(transform)

# Load data subsets
dataset: Dict[DataSplit, ECACDataset] = ECACDataset.read_data(config.data_dir, config.subtask)

# Seed the 75/25 train/validation split so the partition is identical on every
# "Restart Kernel -> Run All"; an unseeded random_split draws from torch's
# global RNG and yields a different validation set per run.
split_generator = torch.Generator().manual_seed(config.seed)
ds_train, ds_valid = random_split(dataset['train'], [0.75, 0.25], generator=split_generator)
ds_test = dataset['test']
num_classes: int = dataset['train'].num_emotions

# Train: drop_last=True to avoid JAX graph recompilation
dataloader: t.Dict[TrainSplit, DataLoader[t.Dict[str, Array]]] = {
    'train': DefaultDataLoader(ds_train, shuffle=True, collate_fn=collator, drop_last=True),
    'valid': DefaultDataLoader(ds_valid, shuffle=False, collate_fn=collator),
    'test': DefaultDataLoader(ds_test, shuffle=False, collate_fn=collator),
}

### Model

In [5]:
# Seed the root PRNG key from the config and immediately split it in two:
# `key` is carried through the training loop, `trainer_key` goes to the trainer.
key, trainer_key = jrm.split(jrm.PRNGKey(config.seed))

# 'f1' metric object from the `evaluate` package; batches are added to it
# during the loops and it is computed once per phase.
f1_score = eval.load('f1')

trainer = TrainerModule(
    key=trainer_key,
    # Model / optimisation settings
    text_model_repo=config.model_repo,
    finetune=config.finetune,
    learning_rate=config.learning_rate,
    # Input shape settings
    batch_size=config.batch_size,
    max_conv_len=config.max_conv_len,
    max_uttr_len=config.max_uttr_len,
)


In [6]:
def take_until_index(array, index):
    """Keep the leading entries of every batch row and stack them together.

    For each batch position ``b`` (enumerated from ``index``), the slice
    ``array[b, :index[b], ...]`` is taken; all slices are then concatenated
    along their first axis, which removes the batch dimension.

    Args:
        array: Batched array indexed as ``array[batch, position, ...]``.
        index: Iterable with one stop position per batch row.

    Returns:
        The concatenation of the per-row slices.
    """
    kept_rows = [array[row, :stop, ...] for row, stop in enumerate(index)]
    return jnp.concatenate(kept_rows)

In [7]:
# Run a short training schedule; after each epoch's training pass, evaluate on
# the validation split. The same `f1_score` metric object is reused for both
# phases: batches are accumulated with add_batch() and reported via compute().
for epoch in range(2):
    # ---- Training phase ----
    for i, X in enumerate(dataloader['train']):
        # Ignore padded entries: the per-row sum of the attention mask is used
        # as the slice end in take_until_index.
        # NOTE(review): assumes the mask is 1 for real entries and padding is
        # trailing -- confirm against DataCollator.
        input_mask = X['conv_attn_mask'].sum(axis=1).astype(jnp.int32)

        # Forward and backward pass (returns the advanced PRNG key and new state)
        loss_train, logits, key, trainer.state = trainer.train_step(key, trainer.state, X)

        # Track training loss
        print(i + 1, len(dataloader['train']), ': ', loss_train)

        # Track F1
        pred = take_until_index(logits, input_mask).argmax(axis=1)
        real = take_until_index(X['emotion_labels'], input_mask)
        f1_score.add_batch(predictions=pred, references=real)
    print('train {}: {} f1'.format(epoch, f1_score.compute(average='weighted')))

    # ---- Validation phase ----
    losses = []
    for i, X in enumerate(dataloader['valid']):
        # Ignore padded entries
        input_mask = X['conv_attn_mask'].sum(axis=1).astype(jnp.int32)

        # Forward pass only; the returned state is discarded
        loss_valid, logits, key, _ = trainer.eval_step(key, trainer.state, X)
        losses.append(loss_valid)

        # Track F1
        pred = take_until_index(logits, input_mask).argmax(axis=1)
        real = take_until_index(X['emotion_labels'], input_mask)
        f1_score.add_batch(predictions=pred, references=real)
    # Use the same 'weighted' averaging as the training report so the two
    # scores are comparable ('weighed' is not a valid averaging mode).
    print('valid {}: {} f1'.format(epoch, f1_score.compute(average='weighted')))
    print('valid {}: {} loss'.format(epoch, np.array(losses).mean()))

1 64 :  0.6829146
2 64 :  1.8118005
3 64 :  1.6454159
4 64 :  0.7176057
5 64 :  1.1118233
6 64 :  1.1784569
7 64 :  0.99458706
8 64 :  0.9376453
9 64 :  0.6109783
10 64 :  0.5112756
11 64 :  0.77690494
12 64 :  0.71182024
13 64 :  0.49245372
14 64 :  0.8214066
15 64 :  0.60007596
16 64 :  0.5766284
17 64 :  0.68032676
18 64 :  0.53054476
19 64 :  0.6370242
20 64 :  0.5195521
21 64 :  0.51246345
22 64 :  0.6043467
23 64 :  0.4873317
24 64 :  0.5514187
25 64 :  0.567977
26 64 :  0.53544307
27 64 :  0.53609014
28 64 :  0.53278816
29 64 :  0.5270432
30 64 :  0.50832355
31 64 :  0.7034503
32 64 :  0.50238997
33 64 :  0.53093904
34 64 :  0.5158324
35 64 :  0.53135645
36 64 :  0.5307175
37 64 :  0.4110269
38 64 :  0.5268087
39 64 :  0.46615025
40 64 :  0.5552973
41 64 :  0.39142346
42 64 :  0.5050903
43 64 :  0.531948
44 64 :  0.4484294
45 64 :  0.6159646
46 64 :  0.40936208
47 64 :  0.64027905
48 64 :  0.4821464
49 64 :  0.5448167
50 64 :  0.42928627
51 64 :  0.5368924
52 64 :  0.42956966
53