### Imports

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

Internal Imports

In [2]:
# Find the ".env" file and add its directory to sys.path so the local package can be imported
import os, sys
import typing as t
from typing import Dict, TypeAlias, Any
from dotenv import find_dotenv
sys.path.append(os.path.dirname(find_dotenv()))

# Use local package for modularity
import emotion_analysis as ea
from emotion_analysis import config
from emotion_analysis.data.dataset import ECACDataset
from emotion_analysis.model.trainer import TrainerModule
from emotion_analysis.data.loader import DefaultDataLoader
from emotion_analysis.model.pretrained import load_text_model
from emotion_analysis.model.model import EmotionCauseTextModel
from emotion_analysis.data.types import TrainSplit, DataSplit, SubTask
from emotion_analysis.data.transform import DataTokenize, DataTransform, DataCollator

JAX Backend:  gpu
JAX Version:  0.4.23
Python:  3.11.0 (main, Oct  4 2023, 22:00:02) [GCC 13.2.1 20230801]
System:  posix.uname_result(sysname='Linux', nodename='archlinux', release='6.7.0-arch3-1', version='#1 SMP PREEMPT_DYNAMIC Sat, 13 Jan 2024 14:37:14 +0000', machine='x86_64')


External Imports

In [3]:
import jax
import jax.numpy as jnp
import jax.random as jrm
from jax import Array
import jax.tree_util as tree_util
from jax.tree_util import tree_structure
from jax.typing import ArrayLike, DTypeLike
import flax
import flax.linen as nn
import numpy as np
import optax as opt
import evaluate as eval
import mlflow as mlf
import torch
from torch.utils.data import Dataset, DataLoader, Subset, random_split

### Dataset

In [4]:
# Data Preprocessors: the tokenizer shipped with the pretrained text encoder is
# wrapped by the tokenize -> transform -> collate helpers from the local package.
text_encoder_pretrained = load_text_model(config.model_repo, config.cache_dir)
tokenizer = text_encoder_pretrained.tokenizer
tokenize = DataTokenize(tokenizer, max_seq_len=config.max_uttr_len)
transform = DataTransform(tokenize, max_conv_len=config.max_conv_len)
collator = DataCollator(transform)

# Load data subsets
dataset: Dict[DataSplit, ECACDataset] = ECACDataset.read_data(config.data_dir, config.subtask)

# Seed the train/valid split explicitly. Without a generator, random_split draws
# from torch's global RNG, which this notebook never seeds, so every kernel
# restart would produce a different validation set and make validation metrics
# incomparable between runs.
split_generator = torch.Generator().manual_seed(config.seed)
ds_train, ds_valid = random_split(dataset['train'], [0.75, 0.25], generator=split_generator)
ds_test = dataset['test']
num_classes: int = dataset['train'].num_emotions

# Train: drop_last=True to avoid JAX graph recompilation
# NOTE(review): the shuffle order of the train loader presumably still comes from
# torch's global RNG; confirm in DefaultDataLoader if fully reproducible runs are needed.
dataloader: t.Dict[TrainSplit, DataLoader[t.Dict[str, Array]]] = {
    'train': DefaultDataLoader(ds_train, shuffle=True, collate_fn=collator, drop_last=True),
    'valid': DefaultDataLoader(ds_valid, shuffle=False, collate_fn=collator),
    'test': DefaultDataLoader(ds_test, shuffle=False, collate_fn=collator),
}

### Model

In [5]:
# Metric accumulators loaded through the `evaluate` package
accuracy = eval.load('accuracy')
f1_score = eval.load('f1')

# Root PRNG key derived from the configured seed; it is split in two so the
# trainer receives its own key while `key` stays available to this notebook.
key = jrm.PRNGKey(config.seed)
key, trainer_key = jrm.split(key)

trainer = TrainerModule(
    key=trainer_key,
    # Model / optimisation settings
    text_model_repo=config.model_repo,
    finetune=config.finetune,
    learning_rate=config.learning_rate,
    # Input dimensions
    batch_size=config.batch_size,
    max_conv_len=config.max_conv_len,
    max_uttr_len=config.max_uttr_len,
)


In [6]:
def take_until_index(array, index):
    """Keep the leading entries of every batch row and join them together.

    For each batch position ``b`` the slice ``array[b, :index[b], ...]`` is
    taken; the slices are then joined with ``jnp.concatenate`` along their
    first axis.

    Args:
        array: Array indexed as ``[batch, position, ...]``.
        index: Iterable with one stop index per batch row.

    Returns:
        The concatenation of all per-row slices.
    """
    kept = [array[row, :stop, ...] for row, stop in enumerate(index)]
    return jnp.concatenate(kept)

In [7]:
# Number of full passes over the training split (previously a bare literal in
# the loop header).
NUM_EPOCHS = 20


def _track_metrics(logits, batch):
    """Feed the non-padded predictions of one batch into the F1/accuracy metrics.

    Args:
        logits: Model outputs for the batch, as returned by the train/eval step.
            # assumes (batch, max_conv_len, num_classes) — TODO confirm
        batch: Collated batch dict; reads 'conv_attn_mask' and 'emotion_labels'.
    """
    # Summing the attention mask over axis 1 gives, per batch row, how many
    # leading entries to keep; everything after that is treated as padding.
    valid_len = batch['conv_attn_mask'].sum(axis=1).astype(jnp.int32)
    pred = take_until_index(logits, valid_len).argmax(axis=1)
    real = take_until_index(batch['emotion_labels'], valid_len)
    f1_score.add_batch(predictions=pred, references=real)
    accuracy.add_batch(predictions=pred, references=real)


for epoch in range(NUM_EPOCHS):
    # ---- Training ----
    num_batches = len(dataloader['train'])
    for i, X in enumerate(dataloader['train']):
        # Forward and backward pass
        loss_train, logits, key, trainer.state = trainer.train_step(key, trainer.state, X)

        # Track training loss
        print(i + 1, num_batches, ': ', loss_train)

        # Track F1 / accuracy on the non-padded entries only
        _track_metrics(logits, X)
    # NOTE(review): the same metric objects are reused for validation below, which
    # relies on compute() clearing the accumulated batches — confirm in the
    # `evaluate` documentation.
    print('train {}: {}  f1'.format(epoch, f1_score.compute(average='weighted')))
    print('train {}: {} acc'.format(epoch, accuracy.compute()))

    # ---- Validation ----
    losses = []
    for i, X in enumerate(dataloader['valid']):
        # Forward pass
        loss_valid, logits, key, _ = trainer.eval_step(key, trainer.state, X)
        losses.append(loss_valid)

        # Track F1 / accuracy on the non-padded entries only
        _track_metrics(logits, X)
    print('valid {}: {} loss'.format(epoch, np.array(losses).mean()))
    print('valid {}: {}  f1'.format(epoch, f1_score.compute(average='weighted')))
    print('valid {}: {} acc'.format(epoch, accuracy.compute()))

1 32 :  1.0125986
2 32 :  2.0120068
3 32 :  1.0018982
4 32 :  0.9496703
5 32 :  1.2052962
6 32 :  1.0290931
7 32 :  0.56536543
8 32 :  0.6995568
9 32 :  0.7244954
10 32 :  0.5673435
11 32 :  0.6165367
12 32 :  0.62261796
13 32 :  0.64360917
14 32 :  0.7007814
15 32 :  0.65691155
16 32 :  0.7357431
17 32 :  0.49557114
18 32 :  0.6283635
19 32 :  0.73973775
20 32 :  0.50049025
21 32 :  0.52500874
22 32 :  0.57250863
23 32 :  0.5440023
24 32 :  0.5248589
25 32 :  0.5903663
26 32 :  0.47550568
27 32 :  0.4767759
28 32 :  0.53062516
29 32 :  0.59193736
30 32 :  0.5174201
31 32 :  0.4778231
32 32 :  0.5792128
train 0: {'f1': 0.29666067746653585}  f1
train 0: {'accuracy': 0.33894984326018807} acc
valid 0: 0.49913614988327026 loss
valid 0: {'f1': 0.28701135494218755}  f1
valid 0: {'accuracy': 0.40701128936423053} acc
1 32 :  0.49877405
2 32 :  0.5846058
3 32 :  0.39293116
4 32 :  0.42798916
5 32 :  0.40629077
6 32 :  0.59398454
7 32 :  0.4783363
8 32 :  0.49434924
9 32 :  0.4656752
10 32 :  0.