### Imports

In [3]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Internal Imports

In [4]:
# Find ".env" file and add the package to $PATH
import os, sys
import typing as t
from typing import Dict, TypeAlias, Any
from dotenv import find_dotenv
sys.path.append(os.path.dirname(find_dotenv()))

# Use local package for modularity
import emotion_analysis as ea
from emotion_analysis import config
from emotion_analysis.data.dataset import ECACDataset
from emotion_analysis.model.trainer import TrainerModule
from emotion_analysis.data.loader import DefaultDataLoader
from emotion_analysis.model.pretrained import load_text_model
from emotion_analysis.model.model import EmotionCauseTextModel
from emotion_analysis.data.types import TrainSplit, DataSplit, SubTask
from emotion_analysis.data.transform import DataTokenize, DataTransform, DataCollator

External Imports

In [5]:
import jax
import jax.numpy as jnp
import jax.random as jrm
from jax import Array
import jax.tree_util as tree_util
from jax.tree_util import tree_structure
from jax.typing import ArrayLike, DTypeLike
import flax
import flax.linen as nn
import numpy as np
import optax as opt
import evaluate as eval
import mlflow as mlf
import torch
from torch.utils.data import Dataset, DataLoader, Subset, random_split

### Dataset

In [6]:
# Data Preprocessors
# The pretrained text model supplies the tokenizer; each preprocessing stage
# wraps the previous one (tokenize -> transform -> collator).
text_encoder_pretrained = load_text_model(config.model_repo, config.cache_dir)
tokenizer = text_encoder_pretrained.tokenizer
tokenize = DataTokenize(tokenizer, max_seq_len=config.max_uttr_len)
transform = DataTransform(tokenize, max_conv_len=config.max_conv_len)
collator = DataCollator(transform)

# Load data subsets
dataset: t.Dict[DataSplit, ECACDataset] = ECACDataset.read_data(config.data_dir, config.subtask)

# Hold out 25% of the training data for validation. The split is driven by a
# generator seeded from `config.seed`, so the same examples land in each subset
# on every run (an unseeded `random_split` changes the validation set on each
# kernel restart, which makes metrics incomparable between runs).
ds_train, ds_valid = random_split(
    dataset['train'],
    [0.75, 0.25],
    generator=torch.Generator().manual_seed(config.seed),
)
ds_test = dataset['test']
num_classes: int = dataset['train'].num_emotions

# Train: drop_last=True to avoid JAX graph recompilation
dataloader: t.Dict[TrainSplit, DataLoader[t.Dict[str, Array]]] = {
    'train': DefaultDataLoader(ds_train, shuffle=True, collate_fn=collator, drop_last=True),
    'valid': DefaultDataLoader(ds_valid, shuffle=False, collate_fn=collator),
    'test': DefaultDataLoader(ds_test, shuffle=False, collate_fn=collator),
}

### Model

In [7]:
# Metric accumulators shared by the training and validation loops
accuracy = eval.load('accuracy')
f1_score = eval.load('f1')

# Derive two keys from the configured seed: `key` is threaded through the
# train/eval steps, `trainer_key` is handed to the trainer at construction.
key, trainer_key = jrm.split(jrm.PRNGKey(config.seed), 2)

# Build the trainer from the shared configuration
trainer = TrainerModule(
    key=trainer_key,
    text_model_repo=config.model_repo,
    finetune=config.finetune,
    learning_rate=config.learning_rate,
    batch_size=config.batch_size,
    max_conv_len=config.max_conv_len,
    max_uttr_len=config.max_uttr_len,
)


In [9]:
# Number of full passes over the training split.
NUM_EPOCHS = 20


def _track_metrics(logits, batch):
    """Add one batch of predictions to the running F1 and accuracy metrics.

    Positions where `batch['conv_attn_mask']` is zero are excluded, so only
    unmasked entries contribute to the scores.

    Args:
        logits: Model outputs for the batch; the predicted class is the
            argmax over the last axis of the unmasked rows.
        batch: Collated batch providing `conv_attn_mask` and `emotion_labels`.
    """
    # Ignore padded entries
    input_mask = batch['conv_attn_mask'].astype(jnp.bool_)
    pred = logits[input_mask, :].argmax(axis=1)
    real = batch['emotion_labels'][input_mask]
    f1_score.add_batch(predictions=pred, references=real)
    accuracy.add_batch(predictions=pred, references=real)


def _report_metrics(split, epoch):
    """Print the weighted F1 and the accuracy accumulated so far for `split`.

    NOTE(review): this relies on `compute()` clearing the accumulated batches,
    so that train and valid scores do not mix -- confirm against the
    `evaluate` documentation.
    """
    print('{} {}: {}  f1'.format(split, epoch, f1_score.compute(average='weighted')))
    print('{} {}: {} acc'.format(split, epoch, accuracy.compute()))


# NOTE(review): if this cell is interrupted mid-epoch, batches already passed
# to `add_batch` presumably remain in the metric objects; re-run the cell that
# creates `f1_score` / `accuracy` before restarting training.
num_train_steps = len(dataloader['train'])

for epoch in range(NUM_EPOCHS):
    # ---- Training pass ----
    for step, X in enumerate(dataloader['train'], start=1):
        # Forward and backward pass; the step returns the advanced PRNG key
        # and the updated trainer state, which are carried to the next step.
        loss_train, logits, key, trainer.state = trainer.train_step(key, trainer.state, X)

        # Track training loss
        print(step, num_train_steps, ': ', loss_train)

        # Track F1 / accuracy
        _track_metrics(logits, X)
    _report_metrics('train', epoch)

    # ---- Validation pass ----
    losses = []
    for X in dataloader['valid']:
        # Forward pass only; the returned state is discarded.
        loss_valid, logits, key, _ = trainer.eval_step(key, trainer.state, X)
        losses.append(loss_valid)

        # Track F1 / accuracy
        _track_metrics(logits, X)
    # Unweighted mean of the per-batch losses.
    print('valid {}: {} loss'.format(epoch, np.array(losses).mean()))
    _report_metrics('valid', epoch)

1 64 :  2.2270439
2 64 :  1.0621885
3 64 :  1.221214
4 64 :  1.0962687
5 64 :  1.1727551
6 64 :  0.6713203
7 64 :  0.76793635
8 64 :  1.1773005
9 64 :  0.7098235
10 64 :  0.93583095
11 64 :  1.0595444
12 64 :  0.84797025
13 64 :  0.65686053
14 64 :  0.57161474
15 64 :  0.6522981
16 64 :  0.5666307
17 64 :  0.50737184
18 64 :  0.5652526
19 64 :  0.4539348
20 64 :  0.44855458
21 64 :  0.4720006
22 64 :  0.4528192
23 64 :  0.53168285
24 64 :  0.4695926
25 64 :  0.37161326
26 64 :  0.3941302
27 64 :  0.5419952
28 64 :  0.48950705
29 64 :  0.58330864
30 64 :  0.7249353
31 64 :  0.5716671
32 64 :  0.44968382
33 64 :  0.667528
34 64 :  0.54698396
35 64 :  0.6088284
36 64 :  0.51612085
37 64 :  0.40283164
38 64 :  0.5115936
39 64 :  0.26780596
40 64 :  0.49915668
41 64 :  0.47447413
42 64 :  0.5111405
43 64 :  0.5911601
44 64 :  0.56397516
45 64 :  0.4524314
46 64 :  0.4302904
47 64 :  0.48532957
48 64 :  0.4806305
49 64 :  0.53956836
50 64 :  0.4261108
51 64 :  0.5737419
52 64 :  0.41775984
5

KeyboardInterrupt: 