### Imports

In [1]:
%load_ext autoreload
%autoreload 2

Internal Imports

In [2]:
# Find ".env" file and add the package to $PATH
import os, sys
import typing as t
from typing import Dict
from dotenv import find_dotenv
sys.path.append(os.path.dirname(find_dotenv()))

# Use local package for modularity
from emotion_analysis import config
from emotion_analysis.utils.weight import INSWeight
from emotion_analysis.model.pretrained import load_text_model
from emotion_analysis.model.trainer import TrainerModule
from emotion_analysis.data.dataset import ECACDataset
from emotion_analysis.data.loader import DefaultDataLoader
from emotion_analysis.data.types import EmotionCauseEncoding
from emotion_analysis.data.types import TrainSplit, DataSplit
from emotion_analysis.data.transform import TokenizeTransform, EncodeTransform, CollateTransform, Transform, WeightTransform

JAX Backend:  gpu
JAX Version:  0.4.23
Python:  3.11.0 (main, Oct  4 2023, 22:00:02) [GCC 13.2.1 20230801]
System:  posix.uname_result(sysname='Linux', nodename='archlinux', release='6.7.0-arch3-1', version='#1 SMP PREEMPT_DYNAMIC Sat, 13 Jan 2024 14:37:14 +0000', machine='x86_64')


External Imports

In [3]:
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as jrm
import jax.tree_util as tree_util
import mlflow as mlf
import seaborn as sea
import numpy as np
import optax as opt
from jax import Array
from jax.tree_util import tree_structure
from jax.typing import ArrayLike, DTypeLike
from numpy import ndarray
from torch.utils.data import DataLoader, Dataset, Subset, random_split

### Dataset

In [4]:
# Load all data subsets in-memory
dataset: Dict[DataSplit, ECACDataset] = ECACDataset.read_data(config.data_dir, config.subtask)
ds_train, ds_valid, ds_test = *random_split(dataset['train'], [0.75, 0.25]), dataset['test']

# Extract relevant statistics
num_classes: int = dataset['train'].num_emotions
emotion_labels = dataset['train'].emotion_labels
emotion_weights = INSWeight(num_classes, emotion_labels)

# Load pretrained model
text_encoder_pretrained = load_text_model(config.model_repo, config.cache_dir)
tokenizer = text_encoder_pretrained.tokenizer

# Compose data transformations
tokenize = TokenizeTransform(tokenizer, max_seq_len=config.max_uttr_len)
encode = EncodeTransform(tokenize, max_conv_len=config.max_conv_len)
weight = WeightTransform(emotion_weights)
collator = CollateTransform(Transform.chain(encode, weight))

# Train: drop_last=True to avoid JAX graph recompilation
dataloader: t.Dict[TrainSplit, DataLoader[EmotionCauseEncoding]] = {
    'train': DefaultDataLoader(ds_train,  shuffle=True, collate_fn=collator, drop_last=True),
    'valid': DefaultDataLoader(ds_valid, shuffle=False, collate_fn=collator),
    'test' : DefaultDataLoader( ds_test, shuffle=False, collate_fn=collator),
}

### Model

In [10]:
key = jrm.PRNGKey(config.seed)
key, trainer_key = jrm.split(key, 2)
trainer = TrainerModule(
    key=trainer_key,
    finetune=config.finetune,
    batch_size=config.batch_size,
    max_conv_len=config.max_conv_len,
    max_uttr_len=config.max_uttr_len,
    text_model_repo=config.model_repo,
    learning_rate=config.learning_rate,
)

### Training

In [12]:
pred = trainer.predict(dataloader['test'])

inference: 100%|██████████| 42/42 [00:26<00:00,  1.61it/s]


In [16]:
pred[0]['span_start'].shape

(16, 33, 33)

In [7]:
trainer.train(
    num_epochs=40,
    train_dataloader=dataloader['train'],
    valid_dataloader=dataloader['valid'],
)

[train][epoch: 0]:   0%|          | 0/64 [20:34<?, ?it/s]
[training]:   0%|          | 0/40 [20:34<?, ?it/s]

Unexpected exception formatting exception. Falling back to standard exception



Traceback (most recent call last):
  File "/home/invokariman/.cache/pypoetry/virtualenvs/emotion-analysis-adVGJZXD-py3.11/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_8272/200836941.py", line 1, in <module>
    trainer.train(
  File "/home/invokariman/Projects/git/ub-g21-irtm/emotion-analysis/emotion_analysis/model/trainer.py", line 256, in train
    pred = output['cause']['out'][conv_mask, :].argmax(axis=1)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/invokariman/Projects/git/ub-g21-irtm/emotion-analysis/emotion_analysis/model/trainer.py", line 202, in train_epoch
  File "/home/invokariman/.cache/pypoetry/virtualenvs/emotion-analysis-adVGJZXD-py3.11/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/invokariman/.cache/py