<a href="https://colab.research.google.com/github/AxelAllen/Pre-trained-Multimodal-Text-Image-Classifier-in-a-Sparse-Data-Application/blob/master/run_mmbt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Run MMBT Experiments

This notebook shows the end-to-end pipeline for fine-tuning a pre-trained MMBT model for multimodal (text and image) classification on our dataset.

Parts of this pipeline are adapted from the
Huggingface `run_mmimdb.py` script to execute the MMBT model. This code can
be accessed [here](https://github.com/huggingface/transformers/blob/8ea412a86faa8e9edeeb6b5c46b08def06aa03ea/examples/research_projects/mm-imdb/run_mmimdb.py#L305).

In [1]:
import torch

# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: NVIDIA GeForce RTX 2070 with Max-Q Design


## Install Huggingface Library

These should have been installed during your environment set-up; you only need to run these cells in Google Colab.

In [9]:
!pip install transformers



# Data directories and file paths

Paths to the data files are provided in the following cell. Uncomment the train/val/test partitions that match the desired labeling scheme:

- filenames with 'major' are labeled with the 'major' metadata column text
- filenames without are labeled with the 'impression' metadata column text
- filenames with 'multi' are labeled for multiclass classification
- filenames without 'multi' are labeled for binary classification
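
The naming rules above can be expressed as a small lookup. This is only a sketch: `describe_label_file` is a hypothetical helper that is not part of this repository, and it applies nothing beyond the two substring rules listed here.

```python
def describe_label_file(filename: str) -> dict:
    """Infer the labeling scheme from a partition filename (hypothetical helper)."""
    return {
        # 'major' in the name -> labels come from the 'major' metadata column
        "label_column": "major" if "major" in filename else "impression",
        # 'multi' in the name -> multiclass labels, otherwise binary
        "task": "multiclass" if "multi" in filename else "binary",
    }


for name in (
    "image_labels_impression_frontal_train.jsonl",
    "image_labels_major_findings_frontal_train.jsonl",
    "image_multi_labels_major_findings_frontal_train.jsonl",
):
    print(name, describe_label_file(name))
```

When a 'multi' file is selected, the `args.multiclass` flag set in the arguments cell presumably has to be switched to `True` as well, since the training and evaluation functions branch on it.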


In [10]:
#train_file = "image_labels_impression_frontal_train.jsonl"
#val_file = "image_labels_impression_frontal_val.jsonl"
#test_file = "image_labels_impression_frontal_test.jsonl"

#train_file = "image_multi_labels_major_findings_frontal_train.jsonl"
#val_file = "image_multi_labels_major_findings_frontal_val.jsonl"
#test_file = "image_multi_labels_major_findings_frontal_test.jsonl"


#train_file = "image_labels_major_findings_frontal_train.jsonl"
#val_file = "image_labels_major_findings_frontal_val.jsonl"
#test_file = "image_labels_major_findings_frontal_test.jsonl"


train_file = "image_labels_findings_frontal_train.jsonl"
val_file = "image_labels_findings_frontal_val.jsonl"
test_file = "image_labels_findings_frontal_test.jsonl"

## Import Required Modules

In [11]:
from textBert_utils import set_seed
from MMBT.image import ImageEncoderDenseNet
from MMBT.mmbt_config import MMBTConfig
from MMBT.mmbt import MMBTForClassification
import tqdm as notebook_tqdm

In [12]:
from MMBT.mmbt_utils import JsonlDataset, get_image_transforms, get_labels, load_examples, collate_fn, get_multiclass_labels, get_multiclass_criterion

In [13]:
import argparse

In [14]:
import glob
import logging
import random
import json
import os
from collections import Counter
import numpy as np
from matplotlib.pyplot import imshow

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

In [15]:
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm, trange

from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

# Set-up Experiment Hyperparameters and Arguments

Specify the training, validation, and test files to run the experiment on. The files used are the ones left uncommented in the data-paths cell above (as written, the `image_labels_findings_frontal_*` partitions).  

To re-make the training, validation, and test data, please refer to the information in the **data/** directory.  

In the following cell, change the `default` values in the `parser.add_argument` calls for any hyperparameters you want to set, or keep the defaults.  

For multiple experiment runs, please make sure to change the `output_dir` argument so that new results don't overwrite existing ones.

The arguments specified here are the same as in the `run_mmimdb.py` file 
in the [Huggingface example implementation of MMBT.](https://github.com/huggingface/transformers/blob/8ea412a86faa8e9edeeb6b5c46b08def06aa03ea/examples/research_projects/mm-imdb/run_mmimdb.py#L305)
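
As an alternative to editing the defaults, values can be overridden by passing a list of strings to `parse_args`. The sketch below uses a reduced parser with three of the notebook's arguments; the directory name `mmbt_output_run2` is made up for illustration. It also shows a caveat of `type=bool`: argparse calls `bool()` on the raw string, so any non-empty value, including `"False"`, parses as `True`.

```python
import argparse

# A reduced parser with three of the notebook's arguments, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", default="mmbt_output_findings_10epochs_n", type=str)
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--do_train", default=True, type=bool)

# Defaults only, equivalent to the notebook's parse_args("")
defaults = parser.parse_args([])

# Override selected values without editing the add_argument calls.
# "mmbt_output_run2" is a made-up directory name.
run2 = parser.parse_args(["--output_dir", "mmbt_output_run2", "--learning_rate", "3e-5"])
print(defaults.output_dir, "->", run2.output_dir)
print(defaults.learning_rate, "->", run2.learning_rate)

# Caveat: type=bool calls bool() on the string, and any non-empty string is truthy.
still_true = parser.parse_args(["--do_train", "False"]).do_train
print(still_true)  # True
```

To disable a boolean option in this notebook, it is therefore safer to change its `default` or to set the attribute on `args` after parsing.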

In [16]:
parser = argparse.ArgumentParser(description="Project hyperparameters and other configuration arguments")

# Required parameters
parser.add_argument(
    "--data_dir",
    default="data/json",
    type=str,
    help="The input data dir. Should contain the .jsonl files.",
)
parser.add_argument(
    "--model_name",
    default="bert-base-uncased",
    type=str,
    help="model identifier from huggingface.co/models",
)
parser.add_argument(
    "--output_dir",
    default="mmbt_output_findings_10epochs_n",
    type=str,
    help="The output directory where the model predictions and checkpoints will be written.",
)

    
parser.add_argument(
    "--config_name", default="bert-base-uncased", type=str, help="Pretrained config name if not the same as model_name"
)
parser.add_argument(
    "--tokenizer_name",
    default="bert-base-uncased",
    type=str,
    help="Pretrained tokenizer name or path if not the same as model_name",
)

parser.add_argument("--train_batch_size", default=32, type=int, help="Batch size for training.")
parser.add_argument(
    "--eval_batch_size", default=32, type=int, help="Batch size for evaluation."
)
parser.add_argument(
    "--max_seq_length",
    default=300,
    type=int,
    help="The maximum total input sequence length after tokenization. Sequences longer "
    "than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
    "--num_image_embeds", default=3, type=int, help="Number of Image Embeddings from the Image Encoder"
)
parser.add_argument("--do_train", default=True, type=bool, help="Whether to run training.")
parser.add_argument("--do_eval", default=True, type=bool, help="Whether to run eval on the dev set.")
parser.add_argument(
    "--evaluate_during_training", default=True, type=bool, help="Run evaluation during training at each logging step."
)


parser.add_argument(
    "--gradient_accumulation_steps",
    type=int,
    default=1,
    help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.1, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
    "--num_train_epochs", default=10.0, type=float, help="Total number of training epochs to perform."
)
parser.add_argument("--patience", default=5, type=int, help="Patience for Early Stopping.")
parser.add_argument(
    "--max_steps",
    default=-1,
    type=int,
    help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

parser.add_argument("--logging_steps", type=int, default=25, help="Log every X updates steps.")
parser.add_argument("--save_steps", type=int, default=25, help="Save checkpoint every X updates steps.")
parser.add_argument(
    "--eval_all_checkpoints",
    default=True, type=bool,
    help="Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number",
)

parser.add_argument("--num_workers", type=int, default=8, help="number of worker threads for dataloading")

parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")


args = parser.parse_args("")

# Setup CUDA, GPU & distributed training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
args.device = device

# for multiclass labeling
args.multiclass = False

In [17]:
# Setup Train/Val/Test filenames
args.train_file = train_file
args.val_file = val_file
args.test_file = test_file

## Showing a sample from JsonlDataset
i.e. calling "\_\_getitem\_\_"

Note:   
image_end_token is the BERT token id for [SEP].   
image_start_token is the BERT token id for [CLS]. 
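
A minimal sketch of where these two tokens come from, modeled on the dataset class in the Huggingface mm-imdb example: the text is encoded with special tokens, then `[CLS]` and `[SEP]` are split off and reused to delimit the image embeddings. It is an assumption that this project's `JsonlDataset` does the same; `split_special_tokens` is a hypothetical name, and plain lists stand in for tensors.

```python
CLS_ID, SEP_ID = 101, 102  # bert-base-uncased ids, matching the sample output in this notebook


def split_special_tokens(encoded, max_seq_length):
    """Split a [CLS] ... [SEP] encoding into image start/end tokens and text ids."""
    start_token, sentence, end_token = encoded[0], encoded[1:-1], encoded[-1]
    return {
        "image_start_token": start_token,
        "image_end_token": end_token,
        "sentence": sentence[:max_seq_length],  # truncate overly long text
    }


# First three text ids of the sample shown in this notebook, wrapped in [CLS]/[SEP]
encoded = [CLS_ID, 8948, 3961, 23760, SEP_ID]
sample = split_special_tokens(encoded, max_seq_length=300)
print(sample)
```

This is consistent with the sample output, where `sentence` contains neither 101 nor 102.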


In [18]:
tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name,
        do_lower_case=True,
        cache_dir=None,
    )
train_dataset = load_examples(tokenizer, args)

In [20]:
train_dataset[0]

{'image_start_token': tensor(101),
 'image_end_token': tensor(102),
 'sentence': tensor([ 8948,  3961, 23760, 10288,  9739,  5732,  1012,  2053,  2689,  1999,
          1996,  2157,  2690, 21833,  6728,  6305,  9031,  1012,  2053, 22038,
         20348, 29543,  2015,  2030, 11678,  1012, 21908, 28915,  2024,  4069,
         25497,  1012]),
 'image': tensor([[[-0.7650, -0.7479, -0.7308,  ..., -0.3541, -0.3369, -0.3198],
          [-0.7137, -0.7137, -0.6794,  ..., -0.2171, -0.1828, -0.1999],
          [-0.6109, -0.6109, -0.6109,  ..., -0.1143, -0.0801, -0.0801],
          ...,
          [ 1.8722,  1.9064,  1.9064,  ...,  1.6324,  1.6667,  1.7523],
          [ 1.8893,  1.9064,  1.9407,  ...,  1.6153,  1.6838,  1.7523],
          [ 1.8722,  1.9064,  1.9407,  ...,  1.6324,  1.7180,  1.7694]],
 
         [[-0.6527, -0.6352, -0.6176,  ..., -0.2325, -0.2150, -0.1975],
          [-0.6001, -0.6001, -0.5651,  ..., -0.0924, -0.0574, -0.0749],
          [-0.4951, -0.4951, -0.4951,  ...,  0.0126,  0


### Training and Evaluating Functions.

In [21]:
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    
    # add the specified batch size and output dir for the current run to the
    # TensorBoard run name for easy identification
    comment = f"train_{args.output_dir}_{args.train_batch_size}"
    tb_writer = SummaryWriter(comment=comment)

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.train_batch_size,
        collate_fn=collate_fn
    )

    # num_train_epochs is parsed as a float, so cast the step count to an int
    t_total = int(len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info(
        "  Total train batch size = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_eval_metric, n_no_improve = 0, 0
    model.train()
    model.zero_grad()
    optimizer.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Training Batch Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()  # evaluate() leaves the model in eval mode, so switch back before each step
            # each sample in batch is a tuple
            # batch is the return of the collate_fn function
            # see function definition for data tuple order
            batch = tuple(t.to(args.device) for t in batch)
            labels = batch[5]
            input_ids = batch[0]
            input_modal = batch[2]
            attention_mask = batch[1]
            modal_start_tokens = batch[3]
            modal_end_tokens = batch[4]

            #inputs = {
            #    "input_ids": batch[0],
            #    "input_modal": batch[2],
            #    "attention_mask": batch[1],
            #    "modal_start_tokens": batch[3],
            #    "modal_end_tokens": batch[4],
            #    "labels": batch[5]
            #}

            if args.multiclass:
                outputs = model(
                    input_modal,
                    input_ids=input_ids,
                    modal_start_tokens=modal_start_tokens,
                    modal_end_tokens=modal_end_tokens,
                    attention_mask=attention_mask,
                    token_type_ids=None,
                    modal_token_type_ids=None,
                    position_ids=None,
                    modal_position_ids=None,
                    head_mask=None,
                    inputs_embeds=None,
                    labels=None,
                    return_dict=True
                )
            else:
                outputs = model(
                    input_modal,
                    input_ids=input_ids,
                    modal_start_tokens=modal_start_tokens,
                    modal_end_tokens=modal_end_tokens,
                    attention_mask=attention_mask,
                    token_type_ids=None,
                    modal_token_type_ids=None,
                    position_ids=None,
                    modal_position_ids=None,
                    head_mask=None,
                    inputs_embeds=None,
                    labels=labels,
                    return_dict=True
                )
            #logits = outputs[0]  # model outputs are always tuple in transformers (see doc)
            logits = outputs.logits
            if args.multiclass:
                criterion = get_multiclass_criterion(train_dataset)
                loss = criterion(logits, labels)
            else:
                loss = outputs.loss
            
            
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps


            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:

                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.evaluate_during_training:  
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_last_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["training_loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
                    # uncomment below to be able to save args
                    # torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)


        results = evaluate(args, model, tokenizer)
        if args.multiclass:
            eval_result = results["micro_f1"]
        else:
            eval_result = results["accuracy"]

        if eval_result > best_eval_metric:
            best_eval_metric = eval_result
            n_no_improve = 0
        else:
            n_no_improve += 1

        if n_no_improve > args.patience:
            train_iterator.close()
            break

    tb_writer.close()

    return global_step, tr_loss / global_step
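
The end-of-epoch check in `train` stops only when `n_no_improve > args.patience`, so training continues for `patience + 1` consecutive non-improving epochs before it ends. The sketch below replays that rule on a list of scores; `epochs_until_stop` is a hypothetical helper and the scores are made up.

```python
def epochs_until_stop(epoch_scores, patience):
    """Return how many epochs run before the patience rule stops training."""
    best, n_no_improve = 0, 0
    for epoch, score in enumerate(epoch_scores, start=1):
        if score > best:
            best, n_no_improve = score, 0
        else:
            n_no_improve += 1
        if n_no_improve > patience:  # same strict comparison as in train()
            return epoch
    return len(epoch_scores)


# Made-up validation scores: improvement stalls after the second epoch.
scores = [0.5, 0.6, 0.6, 0.6, 0.6, 0.6]
print(epochs_until_stop(scores, patience=2))  # 5
```

With the notebook defaults (`patience=5`, 10 epochs), early stopping can only trigger if the metric fails to improve for six epochs in a row.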

In [22]:
def evaluate(args, model, tokenizer, evaluate=True, test=False, prefix=""):
    
    if test:
        # start a separate tensorboard to track testing eval result
        comment = f"test_{args.output_dir}_{args.eval_batch_size}"
        tb_writer = SummaryWriter(comment=comment)

    eval_output_dir = args.output_dir
    eval_dataset = load_examples(tokenizer, args, evaluate=evaluate, test=test)

    if not os.path.exists(eval_output_dir):
        os.makedirs(eval_output_dir)

    
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn
    )

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    preds = []
    out_label_ids = []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            labels = batch[5]
            input_ids = batch[0]
            input_modal = batch[2]
            attention_mask = batch[1]
            modal_start_tokens = batch[3]
            modal_end_tokens = batch[4]
            
            if args.multiclass:
                outputs = model(
                    input_modal,
                    input_ids=input_ids,
                    modal_start_tokens=modal_start_tokens,
                    modal_end_tokens=modal_end_tokens,
                    attention_mask=attention_mask,
                    token_type_ids=None,
                    modal_token_type_ids=None,
                    position_ids=None,
                    modal_position_ids=None,
                    head_mask=None,
                    inputs_embeds=None,
                    labels=None,
                    return_dict=True
                )
            else:
                outputs = model(
                    input_modal,
                    input_ids=input_ids,
                    modal_start_tokens=modal_start_tokens,
                    modal_end_tokens=modal_end_tokens,
                    attention_mask=attention_mask,
                    token_type_ids=None,
                    modal_token_type_ids=None,
                    position_ids=None,
                    modal_position_ids=None,
                    head_mask=None,
                    inputs_embeds=None,
                    labels=labels,
                    return_dict=True
                )
            #logits = outputs[0]  # model outputs are always tuple in transformers (see doc)
            #tmp_eval_loss = criterion(logits, labels)
            logits = outputs.logits
            if args.multiclass:
                criterion = get_multiclass_criterion(eval_dataset)
                tmp_eval_loss = criterion(logits, labels)
            else:
                tmp_eval_loss = outputs.loss
            eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        # Move logits and labels to CPU
        if args.multiclass:
            pred = torch.sigmoid(logits).cpu().detach().numpy() > 0.5
        else:            
            pred = torch.nn.functional.softmax(logits, dim=1).argmax(dim=1).cpu().detach().numpy()
        out_label_id = labels.detach().cpu().numpy()
        preds.append(pred)
        out_label_ids.append(out_label_id)

    eval_loss = eval_loss / nb_eval_steps

    result = {"loss": eval_loss}

    if args.multiclass:
        tgts = np.vstack(out_label_ids)
        preds = np.vstack(preds)
        result["macro_f1"] = f1_score(tgts, preds, average="macro")
        result["micro_f1"] = f1_score(tgts, preds, average="micro")
    else:
        preds = [l for sl in preds for l in sl]
        out_label_ids = [l for sl in out_label_ids for l in sl]
        result["accuracy"] = accuracy_score(out_label_ids, preds)

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
            if test:
                tb_writer.add_scalar(f'eval_{key}', result[key], nb_eval_steps)
    
    if test:
        tb_writer.close()


    return result
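
`evaluate` turns logits into predictions in two different ways: the `args.multiclass` branch thresholds a per-class sigmoid at 0.5, while the default branch takes the arg-max of a softmax. The following standard-library sketch (made-up logits, no PyTorch) shows both rules side by side.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(row):
    shifted = [math.exp(v - max(row)) for v in row]  # subtract max for stability
    total = sum(shifted)
    return [v / total for v in shifted]


def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


logits = [[2.0, -1.0, 0.5], [-0.3, 0.9, -2.0]]  # made-up logits for two samples

# Thresholded sigmoid: one independent yes/no decision per class.
multilabel = [[sigmoid(v) > 0.5 for v in row] for row in logits]

# Softmax + arg-max: exactly one class per sample.
single = [argmax(softmax(row)) for row in logits]

print(multilabel)  # [[True, False, True], [False, True, False]]
print(single)      # [0, 1]
```

Because softmax preserves the ordering of the logits, taking the arg-max of the raw logits gives the same class; likewise `sigmoid(x) > 0.5` is equivalent to `x > 0`.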


## Training MMBT Model 

Set up logging and the MMBT Model. As with the text-only model, checkpoints
are saved at a customizable interval (the `save_steps` argument).
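
To get a feel for how many checkpoints a run produces, the sketch below lists the `checkpoint-<global_step>` sub-directories that `train` would create if no early stopping occurs. `checkpoint_dirs` is a hypothetical helper, and 60 batches per epoch is a made-up figure.

```python
def checkpoint_dirs(num_batches, num_epochs, grad_accum_steps, save_steps):
    """List the checkpoint sub-directories a full run would create."""
    updates_per_epoch = num_batches // grad_accum_steps
    total_updates = updates_per_epoch * num_epochs
    return [f"checkpoint-{step}" for step in range(save_steps, total_updates + 1, save_steps)]


# 60 batches per epoch is a made-up figure; save_steps=25 is the notebook default.
print(checkpoint_dirs(num_batches=60, num_epochs=2, grad_accum_steps=1, save_steps=25))
# ['checkpoint-25', 'checkpoint-50', 'checkpoint-75', 'checkpoint-100']
```

Each of these directories holds a full copy of the model weights, so a small `save_steps` combined with many epochs can use considerable disk space.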



In [23]:
# Setup logging
logger = logging.getLogger(__name__)
if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
                    datefmt="%m/%d/%Y %H:%M:%S",
                    filename=os.path.join(args.output_dir, f"{os.path.splitext(args.train_file)[0]}_logging.txt"),
                    level=logging.INFO)
logger.warning("device: %s, n_gpu: %s",
        args.device,
        args.n_gpu
)
# Set seed
set_seed(args)

In [25]:
# Setup model
if args.multiclass:
    labels = get_multiclass_labels()
    num_labels = len(labels)
else:
    labels = get_labels()
    num_labels = len(labels)
transformer_config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name,
        do_lower_case=True,
        cache_dir=None,
    )
transformer = AutoModel.from_pretrained(args.model_name, config=transformer_config, cache_dir=None)
img_encoder = ImageEncoderDenseNet(num_image_embeds=args.num_image_embeds)
multimodal_config = MMBTConfig(transformer, img_encoder, num_labels=num_labels, modal_hidden_size=1024)
model = MMBTForClassification(transformer_config, multimodal_config)

model.to(args.device)

logger.info(f"Training/evaluation parameters: {args}")

# Training
if args.do_train:
    train_dataset = load_examples(tokenizer, args)
    # criterion = nn.CrossEntropyLoss
    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best practices: if you use default names for the model, you can reload it using from_pretrained()
    logger.info("Saving model checkpoint to %s", args.output_dir)
    # Save a trained model, configuration and tokenizer using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    model_to_save = (model.module if hasattr(model, "module") else model)  # Take care of distributed/parallel training
    torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
    tokenizer.save_pretrained(args.output_dir)
    transformer_config.save_pretrained(args.output_dir)
    # Good practice: save your training arguments together with the trained model
    torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    # Load a trained model and vocabulary that you have fine-tuned
    model = MMBTForClassification(transformer_config, multimodal_config)
    model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME)))
    tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
    model.to(args.device)
logger.info("***** Training Finished *****")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Epoch:   0%|          | 0/10 [00:00<?, ?it/s]


## Evaluating on the Test Set



In [None]:
# Evaluation
results = {}
if args.do_eval:
    checkpoints = [args.output_dir]
    if args.eval_all_checkpoints:
        checkpoints = list(os.path.dirname(c) 
        for c in sorted(glob.glob(args.output_dir + "/**/" + 
                                  WEIGHTS_NAME, recursive=False)))
        # recursive=False because otherwise the parent directory gets included,
        # which is not what we want; only subdirectories

    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    for checkpoint in checkpoints:
        global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
        prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
        model = MMBTForClassification(transformer_config, multimodal_config)
        checkpoint = os.path.join(checkpoint, 'pytorch_model.bin')
        model.load_state_dict(torch.load(checkpoint))
        model.to(args.device)
        result = evaluate(args, model, tokenizer, evaluate=True, test=True, prefix=prefix)
        result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
        results.update(result)

results.keys()

Evaluating: 100%|██████████| 18/18 [03:47<00:00, 12.64s/it]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.79it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.79it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.78it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.76it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.75it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.75it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.75it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.76it/s]
Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.77it/s]
Evaluating: 100%|██████████| 18/18 [00:1

In [None]:
results

## Saving Test Eval Results

The code automatically saves the evaluation results of each checkpoint in its respective folder. The next cell simply collects all of them in one file.

In [None]:
with open(os.path.join(args.output_dir, f"{os.path.splitext(args.test_file)[0]}_eval_results.txt"), mode='w', encoding='utf-8') as out_f:
    print(results, file=out_f)
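
Since the evaluation loop stores each metric under a key of the form `<metric>_<global_step>`, the combined dictionary can also be used to pick the best checkpoint. This is a sketch under the assumption that more than one checkpoint was evaluated (otherwise the step suffix is empty); `best_checkpoint` is a hypothetical helper, and the values are placeholders, not results from this experiment.

```python
def best_checkpoint(results, metric="accuracy"):
    """Return (global_step, value) of the highest `metric` among keys like 'accuracy_25'."""
    scored = {}
    for key, value in results.items():
        name, _, step = key.rpartition("_")
        if name == metric and step.isdigit():
            scored[int(step)] = value
    step = max(scored, key=scored.get)
    return step, scored[step]


# Placeholder values for illustration only, not results from this experiment.
example_results = {"accuracy_25": 0.5, "loss_25": 0.75, "accuracy_50": 0.625, "loss_50": 0.5}
print(best_checkpoint(example_results))  # (50, 0.625)
```

For the multiclass setting, `metric="micro_f1"` or `metric="macro_f1"` would select on the F1 scores instead.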