### Imports

In [1]:
%load_ext autoreload
%autoreload 2

Internal Imports

In [2]:
# Find ".env" file and add the package to $PATH
import os, sys
import typing as t
from typing import Dict, TypeAlias, Any
from dotenv import find_dotenv
sys.path.append(os.path.dirname(find_dotenv()))

# Use local package for modularity
import emotion_analysis as ea
from emotion_analysis import config
from emotion_analysis.data.dataset import ECACDataset
from emotion_analysis.data.loader import DefaultDataLoader
from emotion_analysis.data.types import TrainSplit, DataSplit, SubTask
from emotion_analysis.data.transform import DataTokenize, DataTransform, DataCollator
from emotion_analysis.model.emotion_cause_text import load_text_model
from emotion_analysis.model.emotion_cause_text import EmotionCauseTextModel

JAX Backend:  gpu
JAX Version:  0.4.23
Python:  3.11.0 (main, Oct  5 2023, 23:57:12) [GCC 13.2.1 20230801]
System:  posix.uname_result(sysname='Linux', nodename='archlinux', release='6.7.0-arch3-1', version='#1 SMP PREEMPT_DYNAMIC Sat, 13 Jan 2024 14:37:14 +0000', machine='x86_64')


External Imports

In [3]:
import jax
import jax.numpy as jnp
import jax.random as jrm
from jax import Array
import jax.tree_util as tree_util
from jax.tree_util import tree_structure
from jax.typing import ArrayLike, DTypeLike
import flax
import flax.linen as nn
import numpy as np
import optax as opt
import mlflow as mlf
from torch.utils.data import Dataset, DataLoader, Subset, random_split

### Dataset

In [4]:
# Data Preprocessors
text_encoder_pretrained = load_text_model(config.model_repo)
tokenizer = text_encoder_pretrained.tokenizer
tokenize = DataTokenize( tokenizer,  max_seq_len=config.max_uttr_len)
transform = DataTransform(tokenize, max_conv_len=config.max_conv_len)
collator = DataCollator(transform)

# Load data subsets
dataset: Dict[DataSplit, ECACDataset] = ECACDataset.read_data(config.data_dir, config.subtask)
ds_train, ds_valid, ds_test = *random_split(dataset['train'], [0.75, 0.25]), dataset['test']
num_classes: int = dataset['train'].num_emotions

# Train: drop_last=True to avoid JAX graph recompilation
dataloader: t.Dict[TrainSplit, DataLoader[t.Dict[str, Array]]] = {
    'train': DefaultDataLoader(ds_train,  shuffle=True, collate_fn=collator, drop_last=True),
    'valid': DefaultDataLoader(ds_train, shuffle=False, collate_fn=collator),
    'test' : DefaultDataLoader( ds_test, shuffle=False, collate_fn=collator),
}

In [5]:
next(iter(dataloader['train']))['input_mask'].shape

(4, 33)

### Model

In [6]:
# Generate a PR key for init
key = jrm.key(config.seed)
key, init_key = jrm.split(key, 2)

# Generate fake data to "imitate" a batch
fake_input_ids = jnp.zeros((config.batch_size, config.max_conv_len, config.max_uttr_len))
fake_conv_attn_mask = jnp.zeros((config.batch_size, config.max_conv_len))
fake_uttr_attn_mask = jnp.zeros_like(fake_input_ids)
fake_batch: Any = dict(
    train=False,
    input_ids=fake_input_ids,
    uttr_attn_mask=fake_uttr_attn_mask,
    conv_attn_mask=fake_conv_attn_mask,
)

# Initialize the model with random params
ect_model = EmotionCauseTextModel(text_encoder=text_encoder_pretrained.module, num_classes=num_classes)
vars = ect_model.init(init_key, **fake_batch)
params = vars['params']

# Transfer the pretrained weights 
params['text_encoder'] = text_encoder_pretrained.params

In [7]:
print(ect_model.tabulate(init_key, **fake_batch))


[3m                         EmotionCauseTextModel Summary                          [0m
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath         [0m[1m [0m┃[1m [0m[1mmodule       [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs     [0m[1m [0m┃[1m [0m[1mparams       [0m[1m [0m┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│               │ EmotionCause… │ conv_attn_ma… │ [2mfloat32[0m[4,3… │               │
│               │               │ [2mfloat32[0m[4,33] │              │               │
│               │               │ input_ids:    │              │               │
│               │               │ [2mfloat32[0m[4,33… │              │               │
│               │               │ train: False  │              │               │
│               │               │ uttr_attn_ma… │              │               │
│               │               │ [2

In [6]:
def test_jit(x, params, train, drop_key):
    y= ect_model.apply({ 'params': params }, **x, train=train, rngs={ 'dropout': drop_key })
    return y

In [7]:
speed = jax.jit(test_jit, static_argnames='train')

In [11]:
for X in dataloader['train']:
    X = { 'input_ids': X['input_ids'], 'attn_mask': X['attention_mask'] }
    y = speed(X, params, True, init_key)

In [7]:
# Train only the classifier for now...
should_freeze = lambda p, _: 'frozen' if 'text_encoder' in p else 'trainable'
param_labels = flax.traverse_util.path_aware_map(should_freeze, params)
tx = opt.multi_transform({ 'trainable': opt.adamw(2e-4), 'frozen': opt.set_to_zero() }, param_labels)
opt_state = tx.init(params)