<a href="https://colab.research.google.com/github/lucarinelli/conditional_text_generation/blob/main/notebooks/One_Notebook_To_Rule_Them_All.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

## Check allocated GPU

In [1]:
!nvidia-smi

Wed May 26 21:33:31 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   75C    P0    33W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Install needed python packages

In [2]:
!pip install --quiet transformers datasets tokenizers sacrebleu wandb

## Connect to WandB

In [3]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mlucarinelli[0m (use `wandb login --relogin` to force relogin)


True

## Set experiment parameters

In [4]:
%env WANDB_PROJECT=ctrl_dry_runs
%env WANDB_ENTITY=polito_aiml2021_textgen

experiment_parameters = dict(
    run_name = "exp1",  # String, experiment name
    use_control_codes = True,  # True/False, enable conditional text generation or do basic text generation
    force_dataset_update = True, # True/False, enable database updates even if it is already present on the file system
    control_codes_type = "special_token",  # "special_token"/"separators"
    use_supercategories = True,  # True/False, add supercategories as control codes 
    use_categories = False, # True/False, add categories as control codes    
    use_control_codes_powerset = False,  # True/False, use powerset of control codes for each caption to augment dataset
    max_control_codes_per_caption = 3,  # positive integer, maximum number of control codes to use with one caption during training
    limited_run = True, # if set to True, the datasets will be reduced in size
    max_train_set_len = 1500,  # positive integer, maximum number of items for the training set used
    max_val_set_len = 1000,  # positive integer, maximum number of items for the validation set used
    model="gpt2",  # we tested "distilgpt2" and "gpt2" for now
    #save_model_path = "OUTPUT",
    #random_seed = 42,  # integer, random seed used anywhere it could be useful to add some determinism
)

env: WANDB_PROJECT=ctrl_dry_runs
env: WANDB_ENTITY=polito_aiml2021_textgen


In [5]:
%env WANDB_LOG_MODEL=true
%env WANDB_WATCH=all
%env WANDB_SILENT=true

env: WANDB_LOG_MODEL=true
env: WANDB_WATCH=all
env: WANDB_SILENT=true


In [6]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./data/results",  # output directory
    save_total_limit=3,
    num_train_epochs=3,  # total # of training epochs
    per_device_train_batch_size=64,  # batch size per device during training
    per_device_eval_batch_size=1,  # batch size for evaluation
    warmup_steps=500,  # number of warmup steps for learning rate scheduler
    weight_decay=0.01,
    logging_dir='./data/logs',  # directory for storing logs
    evaluation_strategy="epoch",
    report_to="wandb",
    load_best_model_at_end=True,
    remove_unused_columns=False
)

In [7]:
#TODO integrations with drive for checkpoints? It would work only in colab... not on azure or locally... should be parametrized?

#TODO integration with WandB

# Dataset
We download and load the COCO captions dataset.

We join in a single item the caption for an image with the categories and/or supercategories associated to objects present in the image.
Categories and/or supercategories are used as control codes depending on the experiment settings.

The dataset is then post processed to train the model with different combinations of control codes for each caption, depending on the experiment parameters. The output of the postprocessing is saved on .txt files that are then loaded and further handled by the Dataset class provided by HuggingFace datasets (used for its performance and caching abilities).

Here we start with the functions needed to download the COCO captions dataset and preprocess it for our use case.

In [8]:
#TODO: Should we move this to an external file? Probably not since it is interesting to show?

import os
import sys
import subprocess  # to run sh commands
import json
from torch.utils.data import Dataset
import torch
from pathlib import Path
from itertools import chain, combinations, groupby

!mkdir data
DATA_PATH="./data"

def download_annotations_dataset(data_path=DATA_PATH):
    # download only if don't have it already
    if not os.path.isdir(os.path.join(data_path,"annotations")):
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        subprocess.run(["wget","-P", data_path, "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"])
        subprocess.run(["unzip", "-d", data_path, os.path.join(data_path,"annotations_trainval2017.zip")])

def map_and_join_dataset(data_instances, data_captions):
    if not experiment_parameters["use_categories"] and not experiment_parameters["use_supercategories"]:
        print("One of categories and supercategories has to be used!")
        sys.exit()
    categories_data_dict = dict(map(lambda c: (c["id"], c), data_instances["categories"])) # <category_id, category>
    annotations_data_mapped = map(lambda c: (c["image_id"], c), data_instances["annotations"]) # <image_id, annotation>
    annotations_data_dict = {}
    for a in annotations_data_mapped:
        if a[0] in annotations_data_dict:
            annotations_data_dict[a[0]] += [a[1]]
        else:
            annotations_data_dict[a[0]] = [a[1]]
    captions_data_list = list(map(lambda c: (c["image_id"], c), data_captions["annotations"]))
    captions_data_dict = dict()
    for image_id, image_captions in groupby(sorted(captions_data_list, key=lambda x: x[0]), lambda x: x[0]):
      image_captions_dict = dict()
      for caption in image_captions:
        image_captions_dict[caption[1]["id"]]=caption[1]
      captions_data_dict[image_id]=image_captions_dict

    dataset = []
    control_codes_dict = {}
    no_category_counter = 0
    references_dict = {}

    for image_id, captions in captions_data_dict.items():
        references_dict[image_id] = list(map(lambda x: x[1]["caption"], captions.items()))
        for _, caption in captions.items():
          item = {"caption": caption, "categories": [], "image_id": image_id}
          if image_id in annotations_data_dict:
              tmp_categories_dict = {}
              for a in annotations_data_dict[image_id]:
                  category_name = categories_data_dict[a["category_id"]]["name"]
                  supercategory_name = categories_data_dict[a["category_id"]]["supercategory"]
                  if experiment_parameters["use_categories"]:
                      tmp_categories_dict[category_name] = 1
                      control_codes_dict[category_name] = 1
                  if experiment_parameters["use_supercategories"]:
                    tmp_categories_dict[supercategory_name] = 1
                    control_codes_dict[supercategory_name] = 1
              item["categories"]=list(tmp_categories_dict.keys())
          if len(item["categories"])==0:
              no_category_counter += 1
          else: dataset += [item]

    #TODO compute total of captions?

    print("There are "+str(no_category_counter)+" captions without a category")
    return dataset, references_dict, list(control_codes_dict.keys())

def load_or_setup_dataset(data_path=DATA_PATH, split='train'):
    if not split in ['train', 'val']:
        print("Unknown split: "+split)
        sys.exit()
    if not experiment_parameters["force_dataset_update"] and os.path.isfile(os.path.join(data_path, "dataset_with_ctrl_"+split+".json")):
        print ("Dataset json file, loading dataset...")
        with open(os.path.join(data_path, "dataset_with_ctrl_"+split+".json"), "r") as read_file:
            dataset = json.load(read_file)
        with open(os.path.join(data_path, "control_codes_"+split+".json"), "r") as read_file:
            control_codes = json.load(read_file)
        with open(os.path.join(data_path, "references_"+split+".json"), "r") as read_file:
            references_dict = json.load(read_file)
    else:
        print ("Dataset json file does not exist, creating dataset from scratch...")
        download_annotations_dataset(data_path=data_path)
        with open(os.path.join(data_path,"annotations/instances_"+split+"2017.json"), "r") as read_file:
            data_instances = json.load(read_file)

        with open(os.path.join(data_path,"annotations/captions_"+split+"2017.json"), "r") as read_file:
            data_captions = json.load(read_file)

        dataset, references_dict, control_codes = map_and_join_dataset(data_instances, data_captions)

        with open(os.path.join(data_path,"control_codes_"+split+".json"), 'w') as outfile:
            json.dump(control_codes, outfile)

        with open(os.path.join(data_path,"references_"+split+".json"), 'w') as outfile:
            json.dump(references_dict, outfile)
        
        with open(os.path.join(data_path,"dataset_with_ctrl_"+split+".json"), 'w') as outfile:
            json.dump(dataset, outfile)
    return dataset, references_dict, control_codes

mkdir: cannot create directory ‘data’: File exists


Actually call the functions previously defined

In [9]:
data_path=DATA_PATH

dataset_train, _, categories = load_or_setup_dataset(data_path=data_path, split="train")
dataset_val, references, _ = load_or_setup_dataset(data_path=data_path, split="val")

print("There are "+str(len(dataset_train))+" captions considered in total (train)")
print("There are "+str(len(dataset_val))+" captions considered in total (val)")

print("The following "+str(len(categories))+" categories are present in the dataset:")
print(categories)

if experiment_parameters["use_control_codes"] and experiment_parameters["control_codes_type"] == "special_token":
    control_codes = []
    for category in categories:
        control_codes += ["<CTRL:"+category.replace(" ","_")+">"]

    print("Processed control codes:")
    print(control_codes)

Dataset json file does not exist, creating dataset from scratch...
There are 5107 captions without a category
Dataset json file does not exist, creating dataset from scratch...
There are 240 captions without a category
There are 586646 captions considered in total (train)
There are 24774 captions considered in total (val)
The following 12 categories are present in the dataset:
['kitchen', 'food', 'animal', 'furniture', 'indoor', 'accessory', 'person', 'vehicle', 'outdoor', 'sports', 'appliance', 'electronic']
Processed control codes:
['<CTRL:kitchen>', '<CTRL:food>', '<CTRL:animal>', '<CTRL:furniture>', '<CTRL:indoor>', '<CTRL:accessory>', '<CTRL:person>', '<CTRL:vehicle>', '<CTRL:outdoor>', '<CTRL:sports>', '<CTRL:appliance>', '<CTRL:electronic>']


In [10]:
import multiprocessing as mp

chunk_size = 500

def powerset(iterable, max_size=None):
    "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    if max_size is None:
        return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
    else:
        return chain.from_iterable(combinations(s, r) for r in range(min(max_size, len(s)+1)))

def process_chunk(chunk):
    chunk_number = chunk[0]
    chunk_items = chunk[1]
    data_path = chunk[2]
    split = chunk[3]
    json_file = os.path.join(data_path, "captions_"+split+"_"+str(chunk_number)+".json")
    captions_array_for_json = []
    for item in chunk_items:
        if experiment_parameters["use_control_codes"]:
            if experiment_parameters["use_control_codes_powerset"]:
                control_codes_combinations = powerset(item['categories'], experiment_parameters["max_control_codes_per_caption"])
            else:
                control_codes_combinations = [item['categories']]
        else:
            control_codes_combinations = [[]]
        for control_codes_combination in control_codes_combinations:
            pre_control_codes_string=""
            for category in sorted(control_codes_combination):
                if experiment_parameters["control_codes_type"] == "special_token":
                    pre_control_codes_string+="<CTRL:"+category.replace(" ","_")+">"
                elif experiment_parameters["control_codes_type"] == "separators":
                    pre_control_codes_string+=category+", "
                else:
                    print("ERROR: wrong control code type")
                    return -1  # TODO here we could fail better
            captions_array_for_json += [{"caption": pre_control_codes_string+'<|endoftext|>'+item["caption"]["caption"]+'<|endoftext|>',"image_id": item["caption"]["image_id"]}]
    with open(json_file, 'w') as captions_json:
        json.dump({"data": captions_array_for_json}, captions_json)


def write_json_chunks(dataset, split, data_path, chunk_size):
    chunks = [dataset[start:min(start+chunk_size,len(dataset))] for start in range(0, len(dataset), chunk_size)]
    pool = mp.Pool(processes=8)
    pool.map(process_chunk, [(chunk_n, chunk_items, data_path, split) for chunk_n, chunk_items in enumerate(chunks)])

In [11]:
write_json_chunks(dataset_train, "train", data_path, chunk_size)
write_json_chunks(dataset_val, "val", data_path, chunk_size)

In [12]:
from datasets import load_dataset, Dataset
import glob

dataset_train, dataset_val = load_dataset('json', data_files={'train': glob.glob('./data/captions_train_*.json'), 'val': glob.glob('./data/captions_val_*.json')}, split=['train', 'val'], field="data")
print("Augmented dataset has: "+str(len(dataset_train))+" train elements and "+str(len(dataset_val))+" validation elements")

if experiment_parameters["limited_run"]: # shuffle and cut the datasets
  dataset_train = dataset_train.shuffle(42).select(range(experiment_parameters["max_train_set_len"]))
  dataset_val = dataset_val.shuffle(42).select(range(experiment_parameters["max_val_set_len"]))
  print("We take only a small part of that: "+str(len(dataset_train))+" train elements and "+str(len(dataset_val))+" validation elements")
else: # just shuffle them
  dataset_train = dataset_train.shuffle(42)
  dataset_val = dataset_val.shuffle(42)
  print("Train elements: "+str(len(dataset_train))+"\nValidation elements: "+str(len(dataset_val)))

Using custom data configuration default-5aec630f072696ca


Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/json/default-5aec630f072696ca/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02...


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/json/default-5aec630f072696ca/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02. Subsequent calls will reuse this data.
Augmented dataset has: 586646 train elements and 24774 validation elements
We take only a small part of that: 1500 train elements and 1000 validation elements


# Tokenization

In [13]:
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained(experiment_parameters['model'])
tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer before added special tokens "+str(len(tokenizer)))

if experiment_parameters["use_control_codes"] and experiment_parameters["control_codes_type"] == "special_token":
    special_tokens_dict = {'additional_special_tokens': control_codes}
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    print("added "+str(num_added_toks)+" tokens to the pretrained tokenizer")

Tokenizer before added special tokens 50257
added 12 tokens to the pretrained tokenizer


In [14]:
def encode(examples):
    encoded = tokenizer(examples['caption'], truncation=True, max_length=64, padding="max_length")
    encoded['labels'] = encoded['input_ids']
    encoded['image_id'] = examples['image_id']
    return encoded

dataset_train_encoded = dataset_train.map(encode, batched=True)
dataset_val_encoded = dataset_val.map(encode, batched=True)

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




# Model

In [15]:
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained(experiment_parameters['model'], pad_token_id=tokenizer.eos_token_id)
model.resize_token_embeddings(len(tokenizer))

Embedding(50269, 768)

In [16]:
#TODO add the possibility to freeze some layers? Add an experiment parameter for this?
#TODO print the model structure?

# Training

In [17]:
import random
import torch
import numpy as np

seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

In [18]:
import datasets

def compute_metrics(pred, image_ids):
  labels = pred.label_ids
  preds = pred.predictions
  metric = datasets.load_metric('sacrebleu')

  preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
  labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
  references_local_list = [references[image_id.item()] for image_id in image_ids]

  final_score = metric.compute(predictions=preds, references=references_local_list)
  return {
      'bleu': final_score
  }

In [19]:
#TODO: Should we move this to an external file?

from transformers import Trainer

import collections
import inspect
import math
import os
import random
import re
import shutil
import sys
import tempfile
import time
import warnings
from logging import StreamHandler
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from tqdm.auto import tqdm


# Integrations must be imported before ML frameworks:
from transformers.integrations import (  # isort: split
    default_hp_search_backend,
    get_reporting_integration_callbacks,
    hp_params,
    is_fairscale_available,
    is_optuna_available,
    is_ray_tune_available,
    run_hp_search_optuna,
    run_hp_search_ray,
    deepspeed_init,
    is_deepspeed_zero3_enabled,
)

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from transformers import __version__
from transformers.configuration_utils import PretrainedConfig
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.dependency_versions_check import dep_version_check
from transformers.file_utils import (
    CONFIG_NAME,
    WEIGHTS_NAME,
    PushToHubMixin,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_tpu_available,
    is_training_run_on_sagemaker,
)
from transformers.modelcard import TrainingSummary
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from transformers.optimization import Adafactor, AdamW, get_scheduler
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    DistributedSamplerWithLoop,
    DistributedTensorGatherer,
    IterableDatasetShard,
    LabelSmoother,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    ShardSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_truncate,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
    ShardedDDPOption,
    TrainerMemoryTracker,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    denumpify_detensorize,
    get_last_checkpoint,
    set_seed,
    speed_metrics,
)
from transformers.training_args import ParallelMode, TrainingArguments
from transformers.utils import logging
from transformers.utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES


_is_torch_generator_available = False
_is_native_amp_available = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_torch_generator_available = True
    _is_native_amp_available = True
    from torch.cuda.amp import autocast

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

if is_sagemaker_dp_enabled():
    import smdistributed.dataparallel.torch.distributed as dist
    from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
else:
    import torch.distributed as dist

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat

if is_training_run_on_sagemaker():
    logging.add_handler(StreamHandler(sys.stdout))


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)

class MyTrainer(Trainer):
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        # if eval is called w/o train init deepspeed here
        if self.args.deepspeed and not self.deepspeed:

            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False)

        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
        # ``train`` is running, halve it first and then put on device
        if not self.is_in_train and self.args.fp16_full_eval:
            model = model.half().to(self.args.device)

        batch_size = dataloader.batch_size

        logger.info(f"***** Running {description} *****")
        if isinstance(dataloader.dataset, collections.abc.Sized):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = dataloader.dataset

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        # Will be useful when we have an iterable dataset so don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size

            # Prediction step
            if isinstance(inputs, list):
                inputs_for_prediction = [dict(filter(lambda i: i[0]!='image_id', input.items())) for input in inputs]
            else:
                inputs_for_prediction = dict(filter(lambda i: i[0]!='image_id', inputs.items()))
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            ############################
            if logits is not None:
                logits = self._pad_across_processes(logits)
                logits = self._nested_gather(logits)
                logits_reduced = np.argmax(logits.cpu(), axis=-1) # Obtain a single value instead of a vector, for memory efficiency
                preds_host = logits_reduced if preds_host is None else nested_concat(preds_host, logits_reduced, padding_index=-100)
                # preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            ############################
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host = None, None, None

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if not isinstance(eval_dataset, IterableDataset):
            num_samples = len(eval_dataset)
        elif isinstance(eval_dataset, IterableDatasetShard):
            num_samples = eval_dataset.num_examples
        else:
            num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)

        if isinstance(inputs, list):
            image_ids = [input["image_id"] for input in inputs]
        else:
            image_ids = inputs["image_id"]
        
        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels), image_ids)
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)


In [20]:
dataset_train_encoded.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
dataset_val_encoded.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'image_id'])

trainer = MyTrainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=dataset_train_encoded,         # training dataset
    eval_dataset=dataset_val_encoded,
    compute_metrics=compute_metrics,
    )

In [21]:
trainer.train()

config = wandb.config
config.update(experiment_parameters)

Epoch,Training Loss,Validation Loss


TypeError: ignored

In [None]:
trainer.save_model("./data/results")
wandb.finish()