References:
- [dask-pytorch-ddp example](https://pypi.org/project/dask-pytorch-ddp/)
- [Integrated dask and pytorch](https://saturncloud.io/blog/combining-dask-and-py-torch-for-better-faster-transfer-learning/)
- [Torch DDP](https://pytorch.org/docs/stable/notes/ddp.html)
- [More on DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)
- [Pytorch across dask cluster](https://saturncloud.io/docs/examples/python/pytorch/qs-03-pytorch-gpu-dask-single-model/)
- [DDP and Output](https://github.com/saturncloud/dask-pytorch-ddp/blob/main/README.md)

# Where are we at right now?
We need to figure out how to get results (estimated time, loss data, etc.) out through `rh.process_results()`. Right now it's hard to tell what's going on, because I think `process_results()` swallows the normal output of the `Trainer.train()` method. I would recommend reading the "DDP and Output" article above if you want to work on this part.

We could integrate TFRecords; however, that might change our training function, which takes Datasets. We can convert from TFRecords to Datasets if necessary. Luckily, the Torch Dataset structure is similar in effect to TFRecords: it only represents the paths to the data, not the data itself. The train method uses a data loader that doesn't hold the data in memory, so these things are integrated automatically because we're using the Dataset data type. We should verify this, maybe with a couple of memory tests comparing the size of a Dataset variable to the size of the actual file.

In [None]:
# !pip install -q wandb
# !pip install datasets
# !pip install seqeval
# !pip install evaluate
# !pip install datasets transformers==4.28.0
# !pip install transformers[torch]
# !pip install dask-pytorch-ddp

In [None]:
# Utils
import time
import uuid
import datetime
import pickle
import json

# Stand
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from collections import Counter
from tqdm import tqdm

# Torch
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer, TrainerCallback
from transformers import DataCollatorForTokenClassification
import evaluate
from datasets import load_dataset
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

# Dask
from dask_pytorch_ddp import dispatch, results
from dask.distributed import Client
from distributed.worker import logger

# W and B
import wandb

In [None]:
# Authenticate with Weights & Biases before any run is started.
wandb.login()

### Helper Functions

In [None]:
def tokenize_words_with_corresponding_labels(sample):
    """Tokenize a batch of pre-split documents and align a label id to every token.

    Reads the module-level ``model_checkpoint`` (to load the tokenizer) and
    ``label2id`` (to turn tag strings into ids).

    Args:
        sample: Batch with a ``"document"`` entry (word lists) and a
            ``"doc_bio_tags"`` entry (one tag per word, per document).

    Returns:
        The tokenizer output with an added ``"labels"`` entry holding one list
        of label ids per document.
    """
    tok = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

    # truncation=True cuts sequences at the maximum length;
    # is_split_into_words=True because the input is already split into words.
    encoded = tok(sample["document"], truncation=True, is_split_into_words=True)

    # word_ids() gives, for each token, the index of the word it came from.
    # Tokens without a word (special tokens such as [CLS] / [SEP]) get -100,
    # the index the cross-entropy loss ignores; every other token takes the
    # label of its source word.
    encoded["labels"] = [
        [
            -100 if word_idx is None else label2id[word_tags[word_idx]]
            for word_idx in encoded.word_ids(batch_index=row)
        ]
        for row, word_tags in enumerate(sample["doc_bio_tags"])
    ]

    return encoded

In [None]:
metric = evaluate.load("seqeval")


def compute_metrics(preds):
    """Score token-level predictions with the module-level seqeval ``metric``.

    Args:
        preds: A ``(logits, labels)`` pair. The predicted class of each token
            is the arg-max of ``logits`` over the last axis.

    Returns:
        Whatever ``metric.compute`` returns for the label-string sequences.
    """
    ignore_index = -100  # label id marking positions that must not be scored

    def to_tags(ids):
        # Map label ids back to tag strings via the module-level id2label.
        return [id2label[i] for i in ids]

    logits, labels = preds
    predicted_ids = np.argmax(logits, axis=-1)

    # Drop ignored positions from the references ...
    references = [
        to_tags(gold for gold in gold_row if gold != ignore_index)
        for gold_row in labels
    ]
    # ... and drop the predictions sitting at those same positions.
    hypotheses = [
        to_tags(
            pred for pred, gold in zip(pred_row, gold_row) if gold != ignore_index
        )
        for pred_row, gold_row in zip(predicted_ids, labels)
    ]

    return metric.compute(predictions=hypotheses, references=references)

## Loading in Inspec Dataset with samples

In [None]:
# Prefer the GPU, but fall back to the CPU so that later `.to(device)` calls
# do not fail on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# NOTE(review): the tokenization helper reads `label2id` and `compute_metrics`
# reads `id2label`, but neither mapping is defined anywhere in the visible
# notebook — make sure both exist before running this cell.

# Load the "extraction" configuration of the Inspec dataset.
dataset = load_dataset("midas/inspec", "extraction")
# Base checkpoint; also read as a global by the tokenization helper.
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
# Tokenize in batches and attach the aligned per-token labels.
tokenized_dataset = dataset.map(tokenize_words_with_corresponding_labels, batched=True)
# Collator that is later handed to the trainer.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [None]:
# Count how many beginning-keyword, middle-keyword and non-keyword labels
# (label ids 0, 1 and 2) occur in the training split.
label_totals = Counter()
for token_labels in tokenized_dataset["train"]["labels"]:
    label_totals.update(token_labels)

count_0s = label_totals[0]
count_1s = label_totals[1]
count_2s = label_totals[2]

# Weights for the weighted cross-entropy: the most frequent label gets 1 and
# rarer labels get proportionally larger weights.
max_ = max(count_0s, count_1s, count_2s)
weights = [max_ / total for total in (count_0s, count_1s, count_2s)]

#defining loss function
class CustomTrainer(Trainer):
    """Trainer whose loss is a class-weighted cross-entropy.

    The per-class weights are read from the module-level ``weights`` list.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run a forward pass and return the weighted cross-entropy loss.

        Args:
            model: The model to run. It may arrive wrapped (for example by
                DDP), so no model-specific attribute is read from it here.
            inputs: Batch mapping. ``"labels"`` is popped from it before the
                forward pass, so the model receives no labels.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just ``loss``.
            **kwargs: Ignored. Accepted so that Trainer releases which pass
                extra keyword arguments to this hook (e.g.
                ``num_items_in_batch``) can still call the override.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``
            is true.
        """
        labels = inputs.pop("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Build the weights on the device the logits actually live on rather
        # than on the global `device`, so each process uses its own GPU.
        class_weights = torch.tensor(weights, device=logits.device)
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)

        # The last logits axis is the number of classes; reading it from the
        # tensor avoids `self.model.config`, which a wrapped model lacks.
        loss = loss_fct(
            logits.view(-1, logits.size(-1)),
            labels.to(logits.device).view(-1),
        )
        return (loss, outputs) if return_outputs else loss

In [None]:
# Start a Dask client; called without arguments, `Client()` sets up its own
# local cluster.
# NOTE(review): nothing in this call is GPU-specific — confirm that the workers
# it starts map onto the available GPUs the way the training code expects.
client = Client()

In [None]:
# A test to see how we can get output from train function

# class MyCallback(TrainerCallback):
#     "A callback that prints a message at the beginning of training"

#     def on_train_begin(self, args, state, control, **kwargs):
#         print("Starting training")

In [None]:
# Unique key for this run; the results handler is constructed with it.
key = uuid.uuid4().hex
# NOTE(review): per the dask-pytorch-ddp README, a worker function publishes
# data with `rh.submit_result(path, data)` and `rh.process_results(...)` then
# collects it. `train` never calls `submit_result` — confirm whether that is
# why no training output is visible.
rh = results.DaskResultsHandler(key)

# Define the training loop
def train(model_checkpoint, dataset, data_collator, compute_metrics, tokenizer, callbacks=None):
    """Fine-tune a token-classification model with ``CustomTrainer`` and log the run to W&B.

    Reads the module-level ``id2label``, ``label2id`` and ``device``.

    Args:
        model_checkpoint: Checkpoint name passed to ``from_pretrained``; also
            used to name the output directory and the W&B run.
        dataset: Tokenized dataset indexed by ``"train"`` and ``"validation"``.
        data_collator: Collator handed to the trainer.
        compute_metrics: Metric function handed to the trainer.
        tokenizer: Tokenizer handed to the trainer.
        callbacks: Optional list of ``TrainerCallback`` classes or instances
            forwarded to the trainer. Defaults to ``None`` (no extra callbacks).

    Returns:
        The ``CustomTrainer`` instance after ``trainer.train()`` has completed.
    """
    import os

    batch_size = 8
    learning_rate = 4e-6
    epochs = 1
    model_name = model_checkpoint.split("/")[-1]

    # NOTE(review): the Trainer is expected to do the DDP wrapping itself, and
    # it keys distributed mode off LOCAL_RANK. When a process group is already
    # initialised (as dask-pytorch-ddp's dispatch is expected to do), default
    # LOCAL_RANK to 0 — the same single-GPU-per-worker assumption the previous
    # `device_ids=[0]` made. Confirm against the installed transformers version.
    if dist.is_available() and dist.is_initialized():
        os.environ.setdefault("LOCAL_RANK", "0")

    args = TrainingArguments(
        f"{model_checkpoint}_finetuned_keyword_extract",
        evaluation_strategy="epoch",
        logging_strategy='epoch',
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        lr_scheduler_type='linear',
        weight_decay=0.01,
        seed=0,
    )

    # Hand the *unwrapped* model to the trainer: it inspects the model's
    # `forward` signature and `config`, neither of which is available on a
    # DistributedDataParallel wrapper.
    model = AutoModelForTokenClassification.from_pretrained(
        model_checkpoint,
        id2label=id2label,
        label2id=label2id,
    ).to(device)  # need GPU to train

    # Initialize a W&B run
    wandb.init(
        project='ppp-keyword-extraction',
        config={
            "learning_rate": learning_rate,
            "epochs": epochs,
            "batch_size": batch_size,
            "model_name": model_name,
        },
        name=model_name,
    )

    try:
        # Train model
        start_time = time.time()
        print('hey')

        # Uses dataloaders (≈TFRecords)
        trainer = CustomTrainer(
            model=model,
            args=args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["validation"],
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            tokenizer=tokenizer,
            # Previously `[MyCallback]`, a name that only exists in
            # commented-out code; callers now pass callbacks explicitly.
            callbacks=callbacks,
        )
        trainer.train()

        # Update W&B with the wall-clock training time in minutes
        execution_time = (time.time() - start_time) / 60.0
        wandb.config.update({"execution_time": execution_time})
    finally:
        # Close the W&B run even when training raises, so it is not left open.
        wandb.run.finish()

    return trainer

In [None]:
# Keyword arguments handed to `dispatch.run`, which receives them alongside
# the `train` function.
start_params = dict(
    model_checkpoint=model_checkpoint,
    dataset=tokenized_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

# Submit `train` to the Dask cluster through dask-pytorch-ddp.
futures = dispatch.run(client, train, **start_params)

In [None]:
# Process the dispatched training futures; raise_errors=True is passed so
# worker failures are not silenced.
# NOTE(review): the first argument "/" looks like the location results are
# written under — the filesystem root is probably not intended; confirm.
rh.process_results(
    "/",
    futures,
    raise_errors=True)