### Notebook Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dotenv import find_dotenv
import pathlib as pb
import sys


# Add `ape` package to SYS path
sys.path.append(str(pb.Path(find_dotenv()).parent))

### Imports

In [3]:
from typing import Dict, cast
import re
import torch
from torch.utils.data import Dataset, Subset
from torch.utils.data import DataLoader, random_split
import evaluate as eval
import lightning as lit
from lightning.pytorch.loggers import MLFlowLogger
from dataclasses import asdict
from transformers import AlbertTokenizer
from collections import defaultdict, namedtuple
from transformers import AutoModel, AutoTokenizer, XLMRobertaForCausalLM
from functools import partial

In [4]:
from onmt.onmt.modules.position_ffn import ActivationFunction
from onmt.onmt.translate.translator import GeneratorLM

In [5]:
import ape
from ape.data.types import DataSplit
from ape.data.types import APETripletDict
from ape.eval.metrics import APEMetrics
from ape.data.dataset import APEDataset
from ape.data.transform import Tokenize, HFTokenizer
from ape.model.mst import MultiSourceTransformerCausalLM
from ape.model.encoders import MultiSourceTransformerEncoder
from ape.model.decoders import MultiSourceTransformerDecoder
from ape.light.causal_lm import MultiSourceCausalLMLightningModule

### Data & Metrics

In [6]:
# Load TER, chrF and BLEU
metrics = APEMetrics(cache_dir=ape.HF_CACHE_DIR)

# Load APE Dataset (Original + Synthetic)
ds_test = APEDataset(path=ape.DATA_DIR, split='test')
ds_train = APEDataset(path=ape.DATA_DIR, split='train')
ds_train, ds_valid = random_split(ds_train, lengths=[0.99, 0.01], generator=ape.gen_torch)

# Aggregate all subsets into a single object
ds: Dict[DataSplit, Dataset[APETripletDict[str]]] = {
    'train': ds_train,
    'valid': ds_valid,
    'test': ds_test,
}

In [7]:
# Choose encoders for each input type
encoder_type_src = 'roberta-base'
encoder_type_mt = 'l3cube-pune/marathi-roberta'

# Load source and target tokenizers
hf_tokenizer_src = AutoTokenizer.from_pretrained(encoder_type_src,
                                                                                 use_fast=True,
                                                                                 padding_side='right',
                                                                                 truncation_side='right',
                                                                                 cache_dir=ape.HF_CACHE_DIR / 'tokenizers')
hf_tokenizer_mt = AutoTokenizer.from_pretrained(encoder_type_mt,
                                                                                use_fast=True,
                                                                                padding_side='right',
                                                                                truncation_side='right',
                                                                                cache_dir=ape.HF_CACHE_DIR / 'tokenizers')

# Wrap and customize HF Tokenizers
max_seq_len = 512
tokenizer_src = HFTokenizer(hf_tokenizer_src, source_prefix='src', max_length=max_seq_len)
tokenizer_mt = HFTokenizer(hf_tokenizer_mt, source_prefix='mt', max_length=max_seq_len)
tokenizer_pe = HFTokenizer(hf_tokenizer_mt, source_prefix='pe', max_length=max_seq_len)
tokenize = Tokenize([tokenizer_src, tokenizer_mt, tokenizer_pe])

# Use same settings across all splits
DefaultDataLoader = partial(DataLoader,
                            collate_fn=tokenize,
                            num_workers=ape.WORKERS,
                            batch_size=ape.BATCH_SIZE,
                            prefetch_factor=ape.PREFETCH_FACTOR)

# Aggregate all dataloaders into a single object
dl: Dict[DataSplit, DataLoader] = {
    'train': DefaultDataLoader(ds['train']),
    'valid': DefaultDataLoader(ds['valid']),
    'test': DefaultDataLoader(ds['test']),
}

### Models

In [8]:
model = MultiSourceCausalLMLightningModule(
    bos_token_id=hf_tokenizer_mt.bos_token_id,
    encoder_type_src=encoder_type_src,
    encoder_type_mt=encoder_type_mt,
    tokenizer_mt=hf_tokenizer_mt,
    block_size=max_seq_len,
    temperature=1.0,
    do_sample=True,
    top_k=8,
)

logger = MLFlowLogger(
    experiment_name='Automatic Post-Editing',
    tracking_uri=ape.MLFLOW_TRACKING_URI,
    tags={ 'test': 'true', },
)

trainer = lit.Trainer(
    logger=logger,
    limit_val_batches=200,
    val_check_interval=5_000,
    accumulate_grad_batches=1,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


### Training

In [9]:
trainer.fit(model, train_dataloaders=dl['train'], val_dataloaders=dl['valid'])

/home/invokariman/.cache/pypoetry/virtualenvs/ape-zcZ_0igR-py3.11/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                           | Params
---------------------------------------------------------
0 | model | MultiSourceTransformerCausalLM | 629 M 
---------------------------------------------------------
629 M     Trainable params
0         Non-trainable params
629 M     Total params
2,519.085 Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/invokariman/.cache/pypoetry/virtualenvs/ape-zcZ_0igR-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/invokariman/.cache/pypoetry/virtualenvs/ape-zcZ_0igR-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]