# Distributed Training with LightningLite, SageMaker, and Flair

## Getting started
- [Getting Started with Tensor Parallelism using the SageMaker Model Parallelism Library](https://github.com/aws/amazon-sagemaker-examples/blob/main/training/distributed_training/pytorch/model_parallel/gpt-j/11_train_gptj_smp_tensor_parallel_notebook.ipynb)
- [LightningLite Integration with Flair](https://github.com/flairNLP/flair/pull/2700)
- [LightningLite: Stepping Stone to Lightning](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_lite.html)

## Debugging training
- [LightningLite API reference (`pytorch_lightning.lite.LightningLite`)](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.lite.LightningLite.html#pytorch_lightning.lite.LightningLite)

In [None]:
# Enable IPython's autoreload extension. Mode 2 re-imports modules before each
# cell executes, so edits to imported modules (e.g. custom_trainer) take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import inspect
import json
import logging
import os
import sys
from dataclasses import dataclass, field

import torch
from transformers import HfArgumentParser

import flair
from flair import set_seed
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from custom_trainer import LiteTrainer # changed
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast # changed

# Reuse flair's named logger for this notebook's messages and lower its
# threshold to INFO so corpus / dictionary summaries are emitted.
logger = logging.getLogger("flair")
logger.setLevel(logging.INFO)


@dataclass
class model_args:
    """Options that select and shape the transformer-based NER model.

    ``model_name_or_path`` has no default and must be provided; every other
    option falls back to the default declared here.
    """

    # Checkpoint handed to TransformerWordEmbeddings(model=...),
    # e.g. 'xlm-roberta-base'. Required.
    model_name_or_path: str = field(
        metadata=dict(help="The model checkpoint for weights initialization."),
    )
    # Forwarded to TransformerWordEmbeddings(layers=...).
    layers: str = field(
        metadata=dict(help="Layers to be fine-tuned."),
        default="-1",
    )
    # Forwarded to TransformerWordEmbeddings(subtoken_pooling=...).
    subtoken_pooling: str = field(
        metadata=dict(help="Subtoken pooling strategy used for fine-tuned."),
        default="first",
    )
    # Forwarded to SequenceTagger(hidden_size=...).
    hidden_size: int = field(
        metadata=dict(help="Hidden size for NER model."),
        default=256,
    )
    # Forwarded to SequenceTagger(use_crf=...).
    use_crf: bool = field(
        metadata=dict(help="Whether to use a CRF on-top or not."),
        default=False,
    )


@dataclass
class training_args:
    """Optimisation, runtime and Lightning Lite launch options for a run.

    Every option has a default, so the class can be instantiated with no
    arguments.
    """

    # --- optimisation / schedule ---
    num_epochs: int = field(metadata=dict(help="The number of training epochs."), default=10)
    batch_size: int = field(metadata=dict(help="Batch size used for training."), default=8)
    mini_batch_chunk_size: int = field(
        metadata=dict(help="If smaller than batch size, batches will be chunked."),
        default=1,
    )
    learning_rate: float = field(metadata=dict(help="Learning rate"), default=5e-05)

    # --- reproducibility and flair runtime ---
    # Consumed by set_seed(...).
    seed: int = field(
        metadata=dict(help="Seed used for reproducible fine-tuning results."),
        default=42,
    )
    # Assigned to flair.device.
    device: str = field(metadata=dict(help="CUDA device string."), default="cuda:0")
    weight_decay: float = field(metadata=dict(help="Weight decay for optimizer."), default=0.0)
    embeddings_storage_mode: str = field(
        metadata=dict(help="Defines embedding storage method."),
        default="none",
    )

    # --- Lightning Lite launcher settings (forwarded to LiteTrainer) ---
    accelerator: Optional[str] = field(
        metadata=dict(help="Choose the hardware to run on e.g. 'gpu'."),
        default=None,
    )
    strategy: Optional[str] = field(
        metadata=dict(help="Strategy for how to run across multiple devices e.g. 'ddp', 'deepspeed'."),
        default=None,
    )
    # NOTE(review): the help text mentions list/str values, but the
    # annotation only admits int or None — confirm which LiteTrainer expects.
    devices: Optional[int] = field(
        metadata=dict(help="Number of devices to train on (int), which GPUs to train on (list or str)"),
        default=None,
    )
    num_nodes: Optional[int] = field(
        metadata=dict(help="Number of GPU nodes for distributed training."),
        default=1,
    )
    precision: Optional[int] = field(
        metadata=dict(help="Choose training precision to use."),
        default=32,
    )

@dataclass
class flert_args:
    """FLERT document-context options for the transformer embeddings."""

    # Forwarded to TransformerWordEmbeddings(use_context=...).
    context_size: int = field(
        metadata=dict(help="Context size when using FLERT approach."),
        default=0,
    )
    # Forwarded to TransformerWordEmbeddings(respect_document_boundaries=...).
    respect_document_boundaries: bool = field(
        metadata=dict(help="Whether to respect document boundaries or not when using FLERT."),
        default=False,
    )


@dataclass
class data_args:
    """Which Flair NER corpus to load and where to write the trained model.

    ``dataset_name`` has no default and must be provided.
    """

    # Class name looked up in flair's sequence-labeling datasets by
    # get_flair_corpus, e.g. 'NER_ENGLISH_PERSON'. Required.
    dataset_name: str = field(metadata=dict(help="Flair NER dataset name."))
    # JSON text decoded into keyword arguments for the corpus constructor;
    # the empty string means "no extra arguments".
    dataset_arguments: str = field(
        metadata=dict(help="Dataset arguments for Flair NER dataset."),
        default="",
    )
    output_dir: str = field(
        metadata=dict(help="Defines output directory for final fine-tuned model."),
        default="resources/taggers/ner",
    )


def get_flair_corpus(data_args):
    """Instantiate the Flair NER corpus named by ``data_args.dataset_name``.

    Candidate corpora are the classes defined in (or imported into)
    ``flair.datasets.sequence_labeling`` whose names start with ``NER``,
    ``CONLL`` or ``WNUT``.

    Args:
        data_args: object exposing ``dataset_name`` (str) and
            ``dataset_arguments`` (str; a JSON object, or empty for none).
            The decoded JSON object is passed as keyword arguments to the
            corpus constructor.

    Returns:
        The constructed corpus object.

    Raises:
        ValueError: if ``dataset_name`` is not a known NER corpus, if
            ``dataset_arguments`` is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError``), or if it does not decode to a JSON object.
    """
    # Import the submodule explicitly instead of relying on some other flair
    # import having loaded ``flair.datasets.sequence_labeling`` as a side effect.
    from flair.datasets import sequence_labeling

    ner_task_mapping = {
        name: obj
        for name, obj in inspect.getmembers(sequence_labeling, inspect.isclass)
        if name.startswith(("NER", "CONLL", "WNUT"))
    }

    dataset_name = data_args.dataset_name
    # Validate the name first so a typo is reported before any argument parsing.
    if dataset_name not in ner_task_mapping:
        raise ValueError(f"Dataset name {dataset_name} is not a valid Flair datasets name!")

    dataset_args = {}
    if data_args.dataset_arguments:
        dataset_args = json.loads(data_args.dataset_arguments)
        # ``**`` below needs a mapping; a JSON list/number would otherwise
        # surface as an opaque TypeError from the call site.
        if not isinstance(dataset_args, dict):
            raise ValueError(
                f"Dataset arguments for {dataset_name} must be a JSON object, "
                f"got: {data_args.dataset_arguments!r}"
            )

    return ner_task_mapping[dataset_name](**dataset_args)

In [None]:
# Experiment configuration. The argument dataclasses are used directly as
# mutable namespaces: assigning to their class attributes overrides the
# defaults that the training cell reads.
data_args.dataset_name = 'NER_ENGLISH_PERSON'
data_args.output_dir = 'ner-english-test'
model_args.model_name_or_path = 'xlm-roberta-base'
training_args.batch_size = 32
training_args.learning_rate = 5e-05
training_args.accelerator = 'gpu'
training_args.strategy = None
training_args.devices = 1
training_args.num_nodes = 1
training_args.precision = 16
training_args.num_epochs = 50
# The FLERT context size lives on flert_args (it is what feeds
# TransformerWordEmbeddings(use_context=...)); setting it on training_args
# would silently leave the context at its default of 0.
flert_args.context_size = 64

In [None]:
from dataclasses import fields


def _snapshot_args(args):
    """Return a dataclass instance holding the current values of ``args``.

    The configuration cell assigns to class attributes of the argument
    dataclasses instead of creating instances. ``torch.save`` pickles a class
    by reference (module + qualified name) only, so saving the class itself
    would record none of the configured values. Rebuilding an instance from the
    class attributes captures them. If ``args`` is already an instance it is
    returned unchanged.
    """
    if isinstance(args, type):
        return args(**{f.name: getattr(args, f.name) for f in fields(args)})
    return args


set_seed(training_args.seed)

flair.device = training_args.device

corpus = get_flair_corpus(data_args)
logger.info(corpus)

tag_type: str = "ner"
tag_dictionary = corpus.make_label_dictionary(tag_type)
logger.info(tag_dictionary)

# fine_tune=True trains the transformer weights together with the tagger;
# use_context / respect_document_boundaries carry the FLERT settings.
embeddings = TransformerWordEmbeddings(
    model=model_args.model_name_or_path,
    layers=model_args.layers,
    subtoken_pooling=model_args.subtoken_pooling,
    fine_tune=True,
    use_context=flert_args.context_size,
    respect_document_boundaries=flert_args.respect_document_boundaries,
)

# No RNN layer and no re-projection of the embeddings; a CRF head is
# optional via model_args.use_crf.
tagger = SequenceTagger(
    hidden_size=model_args.hidden_size,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_type=tag_type,
    use_crf=model_args.use_crf,
    use_rnn=False,
    reproject_embeddings=False,
)

# changed: LiteTrainer (from custom_trainer) takes the Lightning Lite launch options.
trainer = LiteTrainer(
    accelerator=training_args.accelerator,
    strategy=training_args.strategy,
    devices=training_args.devices,
    num_nodes=training_args.num_nodes,
    precision=training_args.precision,
)

# changed
trainer.train(
    tagger,
    corpus,
    data_args.output_dir,
    learning_rate=training_args.learning_rate,
    mini_batch_size=training_args.batch_size,
    mini_batch_chunk_size=training_args.mini_batch_chunk_size,
    max_epochs=training_args.num_epochs,
    embeddings_storage_mode=training_args.embeddings_storage_mode,
    weight_decay=training_args.weight_decay,
)

# Persist the run configuration next to the model. Snapshot the values into
# instances first; saving the bare classes would store only a class reference.
os.makedirs(data_args.output_dir, exist_ok=True)
torch.save(_snapshot_args(model_args), os.path.join(data_args.output_dir, "model_args.bin"))
torch.save(_snapshot_args(training_args), os.path.join(data_args.output_dir, "training_args.bin"))

# finally, print model card for information
tagger.print_model_card()