In [7]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Internal Imports

In [8]:
# Find the ".env" file and add its directory to sys.path so the local package is importable
import os, sys
import typing as t
from typing import TypeAlias, Any
from dotenv import find_dotenv
sys.path.append(os.path.dirname(find_dotenv()))

# Use local package for modularity
import emotion_analysis as ea
import emotion_analysis.data.dataset as data
import emotion_analysis.data.transform as transform
from emotion_analysis.model.emotion_cause_text import load_text_model
from emotion_analysis.model.emotion_cause_text import EmotionCauseTextModel

External Imports

In [9]:
import jax
import jax.numpy as jnp
import jax.random as jrm
from jax import Array
from jax.typing import ArrayLike, DTypeLike
import flax
import flax.linen as nn
import numpy as np
import optax as opt
import mlflow as mlf
from transformers import RobertaConfig, FlaxRobertaModel, RobertaTokenizerFast
from transformers import PretrainedConfig
import torch
from torch.utils.data import Dataset, DataLoader, Subset, random_split

In [25]:
# Training split of the task-1 data; `split` is passed by keyword, matching
# how the same constructor is called elsewhere in this notebook.
dataset = data.ECACDataset(ea.DATA_DIR, 'task_1', split='train')

In [4]:
# Load pretrained model along with its tokenizer
text_encoder_pretrained = load_text_model()

# Select tokenization method
batch_size = 32
max_sen_len = 93  # max_length given to the tokenizer (sentences are padded to it)
max_doc_len = 33  # max_length given to the collator
tokenizer = text_encoder_pretrained.tokenizer
tokenize = transform.Tokenize(tokenizer, max_length=max_sen_len, padding='max_length')
collate = transform.DataTokenizerCollator(tokenize, max_length=max_doc_len)

# Load and prepare the data
task = 'task_1'
DataSplit: TypeAlias = t.Literal['train', 'valid', 'test']
ds_test = data.ECACDataset(ea.DATA_DIR, task, split='test')
# Keep the full training set under its own name so that re-running this cell
# never splits an already-split subset.
ds_train_full = data.ECACDataset(ea.DATA_DIR, task, split='train')

# Seed the split explicitly: without a generator, random_split draws from the
# global torch RNG and the train/valid partition differs between kernel runs.
# NOTE(review): assumes ea.SEED is an int (it is also passed to jax.random.key).
split_generator = torch.Generator().manual_seed(ea.SEED)
ds_train, ds_valid = random_split(ds_train_full, [0.75, 0.25], generator=split_generator)
num_classes: int = ds_test.num_emotions

# Split into train-valid-test
# Train: drop_last=True to avoid JAX graph recompilation
dataloader: t.Dict[DataSplit, DataLoader[t.Dict[str, Array]]] = {
    'train': DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate,
        drop_last=True,
    ),
    'valid': DataLoader(ds_valid, batch_size=batch_size, shuffle=False, collate_fn=collate),
    'test': DataLoader(ds_test, batch_size=batch_size, shuffle=False, collate_fn=collate),
}

In [31]:
# Generate a PR key for init
key = jrm.key(ea.SEED)
key, init_key = jrm.split(key, 2)

# Generate fake data to "imitate" a batch.
# The token axis is max_doc_len * max_sen_len (the original used
# max_doc_len * max_doc_len, which does not match the tokenizer settings).
# NOTE(review): assumes the collator flattens each document into
# max_doc_len * max_sen_len tokens per row -- confirm against
# transform.DataTokenizerCollator.
# Token ids are created as int32 because real ids are integer indices.
fake_input_ids = jnp.zeros((batch_size, max_doc_len * max_sen_len), dtype=jnp.int32)
fake_attn_mask = jnp.zeros_like(fake_input_ids)
fake_batch: Any = dict(input_ids=fake_input_ids, attention_mask=fake_attn_mask)

# Initialize the model with random params
ec_model = EmotionCauseTextModel(text_encoder=text_encoder_pretrained.module, num_classes=num_classes)
# NOTE: `vars` shadows the builtin vars(); the name is kept because other
# cells may reference it.
vars = ec_model.init(init_key, **fake_batch)
params = vars['params']

# Transfer the pretrained weights (replaces the randomly initialized
# 'text_encoder' subtree in place)
params['text_encoder'] = text_encoder_pretrained.params

In [33]:
# Train only the classifier for now...
def should_freeze(path, _):
    """Return 'frozen' if 'text_encoder' is one of the path entries, else 'trainable'."""
    if 'text_encoder' in path:
        return 'frozen'
    return 'trainable'


# Build a tree of labels by applying should_freeze to every (path, value) pair
param_labels = flax.traverse_util.path_aware_map(should_freeze, params)

# AdamW for leaves labelled 'trainable'; updates are zeroed for 'frozen' ones
tx = opt.multi_transform(
    {
        'trainable': opt.adamw(2e-4),
        'frozen': opt.set_to_zero(),
    },
    param_labels,
)
opt_state = tx.init(params)