In [None]:
# Fetch the project sources; the rest of the notebook imports modules and runs
# scripts relative to the repository root.
!git clone https://github.com/ziadtarek12/language_distilling
# %cd (rather than !cd) so the directory change persists for subsequent cells.
%cd language_distilling
# Switch to the `eval` branch before anything from the repository is used.
!git checkout eval

In [None]:
! pip uninstall -y torch torchvision torchaudio
!pip install transformers==4.26.0
!pip install pytorch-pretrained-bert
!pip install cytoolz
!pip install tqdm
!pip install torchtext==0.16.0
!pip install torchvision==0.16.0
!pip install torch==2.1.0
!pip install torchaudio==2.1.0
!pip install configargparse
!pip install tensorboardX
!pip install ipdb

In [None]:
import os
import sys
import torch
import numpy as np
import random
import shelve
import io
import argparse
import yaml
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import tensorboardX

In [None]:
# Make the repository root and the bundled OpenNMT directory importable.
for _extra_path in ('.', './opennmt'):
    sys.path.append(_extra_path)

# Seed every random number generator used by the notebook for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer CUDA when present; the CUDA generators are only seeded in that case.
_has_cuda = torch.cuda.is_available()
if _has_cuda:
    torch.cuda.manual_seed_all(SEED)
device = torch.device('cuda' if _has_cuda else 'cpu')

print(f"Using device: {device}")

In [None]:
# Create directories for data and outputs
!mkdir -p data/
!mkdir -p output/cmlm_model
!mkdir -p output/bert_dump
!mkdir -p output/kd-model/ckpt
!mkdir -p output/kd-model/log
!mkdir -p output/translation

# Download IWSLT German-English dataset using the provided script
!bash scripts/download-iwslt_deen.sh

In [None]:
from scripts.bert_tokenize import tokenize, process

# Load BERT tokenizer (lower-casing is enabled only for "uncased" checkpoints)
bert_model = "bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case='uncased' in bert_model)

# Define data directories
data_dir = "data/de-en"

# BERT-tokenize every split of both languages, writing `<split>.<lang>.bert`
# next to the raw file.
for language in ['de', 'en']:
    for split in ['train', 'valid', 'test']:
        input_file = f"{data_dir}/{split}.{language}"
        output_file = f"{data_dir}/{split}.{language}.bert"
        print(f"Tokenizing {input_file}...")

        # Explicit UTF-8: the corpus contains non-ASCII text (German umlauts),
        # and the platform default encoding is not guaranteed to be UTF-8.
        with open(input_file, 'r', encoding='utf-8') as reader, \
                open(output_file, 'w', encoding='utf-8') as writer:
            process(reader, writer, tokenizer)

In [None]:
# Create dataset DB for BERT training
from scripts.bert_prepro import main as bert_prepro

# Set up args for bert_prepro
prepro_args = argparse.Namespace(
    src=f"{data_dir}/train.de.bert",
    tgt=f"{data_dir}/train.en.bert",
    output='data/DEEN.db'
)

# Run preprocessing
bert_prepro(prepro_args)

# Create vocabulary file using OpenNMT's preprocess.py
print("Creating vocabulary files with OpenNMT preprocess.py...")
!python opennmt/preprocess.py \
    -train_src {data_dir}/train.de.bert \
    -train_tgt {data_dir}/train.en.bert \
    -valid_src {data_dir}/valid.de.bert \
    -valid_tgt {data_dir}/valid.en.bert \
    -save_data data/DEEN \
    -src_seq_length 150 -tgt_seq_length 150

vocab_file = "data/DEEN.vocab.pt"

In [None]:
from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup

# Import needed modules
from cmlm.data import BertDataset, TokenBucketSampler
from cmlm.model import convert_embedding, BertForSeq2seq
from cmlm.util import Logger, RunningMeter
from run_cmlm_finetuning import noam_schedule

# Load vocabulary using our compatibility module
from vocab_loader import safe_load_vocab

# Inputs produced by the preprocessing cells and the Stage 1 output directory.
vocab_file = "data/DEEN.vocab.pt"
train_file = "data/DEEN.db"
valid_src = f"{data_dir}/valid.de.bert"
valid_tgt = f"{data_dir}/valid.en.bert"
output_dir = "output/cmlm_model"

# Target-side token -> index mapping, read through the repository's loader.
vocab_dump = safe_load_vocab(vocab_file)
vocab = vocab_dump['tgt'].fields[0][1].vocab.stoi

# Training data: dataset, token-bucketed batch sampler, and loader.
train_dataset = BertDataset(train_file, tokenizer, vocab, seq_len=512, max_len=150)

BUCKET_SIZE = 8192
train_sampler = TokenBucketSampler(train_dataset.lens, BUCKET_SIZE, 6144, batch_multiple=1)
train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    num_workers=4,
    collate_fn=BertDataset.pad_collate,
)

# Pretrained BERT and its input word-embedding matrix.
model = BertForSeq2seq.from_pretrained(bert_model)
bert_embedding = model.bert.embeddings.word_embeddings.weight
hidden_size = model.config.hidden_size

# Report sizes before the output layer is replaced.
print(f"Original model: BERT hidden size = {hidden_size}")
print(f"Original model: BERT vocab size = {bert_embedding.size(0)}")
print(f"Target vocabulary size = {len(vocab)}")

# Embedding rows re-indexed for the target vocabulary.
embedding = convert_embedding(tokenizer, vocab, bert_embedding)

# Swap the prediction head's output projection for one sized to the new
# vocabulary, then initialise its weights from the converted embedding.
print(f"Updating model architecture for vocabulary size: {embedding.size(0)}")
prediction_head = model.cls.predictions
prediction_head.decoder = torch.nn.Linear(hidden_size, embedding.size(0), bias=True)
prediction_head.bias = torch.nn.Parameter(torch.zeros(embedding.size(0)))
model.config.vocab_size = embedding.size(0)
prediction_head.decoder.weight.data.copy_(embedding.data)

# Place the adapted model on the selected device.
model.to(device)
print(f"Model adapted with vocabulary size: {model.config.vocab_size}")

In [None]:
# Training parameters
learning_rate = 5e-5
warmup_proportion = 0.1  # Fraction of max_steps used for linear LR warm-up
max_steps = 100000  # Length of the LR schedule (full training uses 100k steps)
num_steps_to_run = 100000  # Steps actually executed; lower this for a quick demo run
log_every = 100  # Report the loss and release cached GPU memory every N steps

# AdamW with weight decay disabled for biases and LayerNorm parameters
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(max_steps * warmup_proportion),
    num_training_steps=max_steps
)

# Training loop
running_loss = RunningMeter('loss')
model.train()

print("Starting CMLM fine-tuning...")
# A plain iterator is used (and restarted on exhaustion) so the loop is driven
# by the step count rather than by the length of the loader.
train_iter = iter(train_loader)
for step in range(num_steps_to_run):
    try:
        batch = next(train_iter)
    except StopIteration:
        # Restart iterator if we run out of batches
        train_iter = iter(train_loader)
        batch = next(train_iter)

    # Move batch to device
    batch = tuple(t.to(device) for t in batch)
    input_ids, input_mask, segment_ids, lm_label_ids = batch

    # Zero gradients
    optimizer.zero_grad()

    # Positions whose label is not -1 are the ones the loss is computed on
    output_mask = lm_label_ids != -1

    # Forward pass with output_mask parameter
    loss = model(input_ids, segment_ids, input_mask, lm_label_ids, output_mask)

    # Backward pass and parameter / learning-rate update
    loss.backward()
    optimizer.step()
    scheduler.step()

    running_loss(loss.item())
    if step % log_every == 0:
        # Log periodically instead of on every step: one line per step over
        # 100k steps floods the cell output and bloats the saved notebook.
        print(f"Step {step}, Loss: {running_loss.val:.4f}")

        # Clear CUDA cache periodically to avoid memory issues
        torch.cuda.empty_cache()

# Save model checkpoint
torch.save(model.state_dict(), f"{output_dir}/model_step_{num_steps_to_run}.pt")
print(f"Model saved to {output_dir}/model_step_{num_steps_to_run}.pt")

In [None]:
# Show allocator statistics and release cached blocks. The summary is printed
# explicitly (a non-final expression in a cell is otherwise discarded), and the
# whole step is skipped on CPU-only machines where memory_summary would raise.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
    torch.cuda.empty_cache()

In [None]:
# Import extraction functions
from dump_teacher_hiddens import tensor_dumps, gather_hiddens, BertSampleDataset, batch_features, process_batch

# Path to model checkpoint from Stage 1
ckpt_path = f"{output_dir}/model_step_{num_steps_to_run}.pt"
bert_dump_path = "output/bert_dump"

# Load the checkpoint onto the CPU first so it can be read regardless of the
# device it was saved from; the assembled model is moved to `device` below.
state_dict = torch.load(ckpt_path, map_location='cpu')
vsize = state_dict['cls.predictions.decoder.weight'].size(0)
bert = BertForSeq2seq.from_pretrained(bert_model).eval()

# Instead of using update_output_layer_by_size, which pads to multiples of 8,
# resize the output layer to the exact dimensions stored in the checkpoint.
print(f"Resizing model to exact vocabulary size: {vsize}")
hidden_size = bert.config.hidden_size

# Create exact-sized layers without padding to multiples of 8
bert.cls.predictions.decoder = torch.nn.Linear(hidden_size, vsize, bias=True)
bert.cls.predictions.bias = bert.cls.predictions.decoder.bias
bert.config.vocab_size = vsize

# Now load the state dict - should have matching dimensions
bert.load_state_dict(state_dict)

# Move to the target device only after the new output layer is attached;
# moving first would leave the freshly created decoder on the CPU while the
# rest of the model sits on the GPU.
bert.to(device)

# Save the final projection layer (kept on CPU; consumers move it as needed)
linear = torch.nn.Linear(bert.config.hidden_size, bert.config.vocab_size)
linear.weight.data = state_dict['cls.predictions.decoder.weight']
linear.bias.data = state_dict['cls.predictions.bias']
torch.save(linear, f'{bert_dump_path}/linear.pt')

In [None]:
# Function to extract hidden states - with debugging option
def build_db_batched(corpus_path, out_db, bert, toker, batch_size=8, debug_mode=False, max_samples=100):
    """Run the teacher over a corpus and store one serialized output per sample.

    Args:
        corpus_path: passed to ``BertSampleDataset`` together with ``toker``.
        out_db: mapping-like store; ``out_db[id_]`` receives ``tensor_dumps(output)``.
        bert: model handed to ``process_batch``.
        toker: tokenizer handed to the dataset and to ``process_batch``.
        batch_size: batch size of the ``DataLoader``.
        debug_mode: when True, only the first ``max_samples`` dataset ids are kept.
        max_samples: upper bound on the number of ids kept in debug mode.

    Samples for which ``process_batch`` yields ``None`` are not written.
    """
    dataset = BertSampleDataset(corpus_path, toker)

    if debug_mode:
        # Shrink the dataset in place to at most `max_samples` ids.
        all_ids = dataset.ids
        subset_ids = all_ids if len(all_ids) <= max_samples else all_ids[:max_samples]
        dataset.ids = subset_ids
        print(f"DEBUG MODE: Processing only {len(subset_ids)} samples instead of {len(all_ids)}")

    loader = DataLoader(dataset, batch_size=batch_size,
                        num_workers=4, collate_fn=batch_features)

    # In debug mode a batch at least as large as the cap already holds every
    # kept sample, so the loop stops after the first batch.
    single_batch_run = debug_mode and batch_size >= max_samples

    with tqdm(desc='Computing BERT features', total=len(dataset)) as progress:
        for ids, *features in loader:
            results = process_batch(features, bert, toker)
            for sample_id, result in zip(ids, results):
                if result is None:
                    continue
                out_db[sample_id] = tensor_dumps(result)
            progress.update(len(ids))

            if single_batch_run:
                print("First batch processed, breaking early due to debug mode")
                break

# Corpus database to read and progress banner.
db_path = "data/DEEN.db"
print("Extracting hidden states...")

# Run-mode switches.
# NOTE(review): with debug_mode = True only `max_samples` entries are written,
# and the later top-k cell iterates over this same DB — set it to False before
# a full training run.
debug_mode = True  # Toggle this for quick debugging
max_samples = 100  # Number of samples to process in debug mode

# Write the teacher outputs into a shelve DB, with autograd disabled.
hidden_db_path = f'{bert_dump_path}/db'
with shelve.open(hidden_db_path, 'c') as out_db:
    with torch.no_grad():
        build_db_batched(db_path, out_db, bert, tokenizer,
                         batch_size=8, debug_mode=debug_mode, max_samples=max_samples)

# Release the teacher and cached GPU memory now that extraction is finished.
print("Clearing GPU memory...")
bert.cpu()
del bert
linear.cpu()
torch.cuda.empty_cache()
print("GPU memory cleared after hidden states extraction")

if not debug_mode:
    print(f"Hidden states extracted and saved to {bert_dump_path}/db")
else:
    print(f"DEBUG MODE: Hidden states for {max_samples} samples extracted to {bert_dump_path}/db")
    print("To run full extraction, set debug_mode=False")

In [None]:
# Import functions for top-k computation
from dump_teacher_topk import tensor_loads, dump_topk

# Top-K parameter
k = 8  # Following the paper

# Load the saved projection layer on the CPU (works whatever device it was
# saved from), cast it to half precision, then move it to the target device.
linear = torch.load(f'{bert_dump_path}/linear.pt', map_location='cpu')
# NOTE(review): the cast assumes the stored hidden states are FP16 — confirm
# against what tensor_dumps writes.
linear = linear.half()
linear.to(device)

# Compute top-k logits for every stored hidden state.
# torch.no_grad(): this is pure inference; without it every projection records
# an autograd graph and the top-k tensors handed to dump_topk require grad.
print("Computing top-k logits...")
with shelve.open(f'{bert_dump_path}/db', 'r') as db, \
     shelve.open(f'{bert_dump_path}/topk', 'c') as topk_db, \
     torch.no_grad():
    for key, value in tqdm(db.items(), total=len(db), desc='Computing topk...'):
        bert_hidden = torch.tensor(tensor_loads(value)).to(device)
        topk = linear(bert_hidden).topk(dim=-1, k=k)
        # Per-item tensors are released when the names are rebound on the next
        # iteration; emptying the CUDA cache on every item only adds overhead,
        # so the cache is cleared once after the loop.
        topk_db[key] = dump_topk(topk)

# Final memory cleanup
print("Clearing GPU memory...")
linear.cpu()  # Move linear layer to CPU
del linear     # Delete the linear layer
torch.cuda.empty_cache()  # Empty the CUDA cache
print("GPU memory cleared after top-k computation")
print(f"Top-k logits computed and saved to {bert_dump_path}/topk")

In [None]:
# Import required modules for training
from onmt.inputters.bert_kd_dataset import BertKdDataset, TokenBucketSampler
from onmt.utils.optimizers import Optimizer
from onmt.train_single import build_model_saver, build_trainer, cycle_loader
import torch.nn as nn  # Add missing import
import os  # Add import for checking file existence

# Define paths
data_db = "data/DEEN.db"
bert_dump = "output/bert_dump"
data = "data/DEEN"
config_path = "opennmt/config/config-transformer-base-mt-deen.yml"
output_path = "output/kd-model"

# Check if required files exist and provide guidance
print("Checking for required database files...")
topk_db_file = f"{bert_dump}/topk"
topk_db_dir = os.path.dirname(topk_db_file)

# First make sure the directory exists
if not os.path.exists(topk_db_dir):
    print(f"Creating directory: {topk_db_dir}")
    os.makedirs(topk_db_dir, exist_ok=True)

# shelve may store the DB under the bare name or with a backend-specific
# suffix, so every candidate file name is probed.
if not any(os.path.exists(f"{topk_db_file}{ext}") for ext in ["", ".db", ".dat", ".bak", ".dir"]):
    print(f"Warning: Top-k database not found at {topk_db_file}")
    print("Running top-k computation from Stage 2...")

    # Import functions for top-k computation if they haven't been imported yet
    from dump_teacher_topk import tensor_loads, dump_topk

    # Reuse a projection layer left in the session, otherwise load it from disk
    if 'linear' not in locals():
        linear_path = f'{bert_dump}/linear.pt'
        if os.path.exists(linear_path):
            print(f"Loading linear layer from {linear_path}")
            linear = torch.load(linear_path, map_location='cpu')
        else:
            raise ValueError(f"Linear layer not found at {linear_path}. Please run Stage 2 first.")
    # Always place the layer on the target device: a `linear` left over from
    # the extraction cell was moved to the CPU there and would otherwise clash
    # with hidden states that are moved to `device` below.
    linear.to(device)

    # Check if hidden states database exists
    db_path = f"{bert_dump}/db"
    if not any(os.path.exists(f"{db_path}{ext}") for ext in ["", ".db", ".dat", ".bak", ".dir"]):
        raise ValueError(f"Hidden states database not found at {db_path}. Please run Stage 2 first.")

    print("Computing top-k logits...")
    # Set k value for top-k computation
    k = 8  # Following the paper

    # Reconcile dtypes once, before the loop: hidden states are read as
    # float32; a half-precision layer gets half inputs, any other layer dtype
    # is converted to float32.
    if linear.weight.dtype != torch.float32:
        print(f"Converting tensors to match dtypes - hidden: {torch.float32}, linear: {linear.weight.dtype}")
    if linear.weight.dtype != torch.float16:
        linear = linear.float()
    hidden_dtype = linear.weight.dtype

    # Create the topk database in create mode; inference only, so autograd is
    # disabled for the whole loop.
    with shelve.open(f'{bert_dump}/db', 'r') as db, \
         shelve.open(f'{bert_dump}/topk', 'c') as topk_db, \
         torch.no_grad():
        for key, value in tqdm(db.items(), total=len(db), desc='Computing topk...'):
            bert_hidden = torch.tensor(tensor_loads(value), dtype=torch.float32).to(device)
            if hidden_dtype == torch.float16:
                bert_hidden = bert_hidden.half()

            # Compute top-k
            topk = linear(bert_hidden).topk(dim=-1, k=k)
            dump = dump_topk(topk)
            topk_db[key] = dump

    print(f"Top-k logits computed and saved to {bert_dump}/topk")
else:
    print(f"Top-k database exists at {topk_db_file}")

# Build the option namespace used by the OpenNMT model builder, optimizer,
# model saver and trainer in the following cells. It starts from the YAML
# config and is then extended attribute by attribute; assignments made with
# getattr(...) keep a value from the YAML when present and fall back to the
# literal default otherwise, while plain assignments always override.

# Load configuration
with open(config_path, 'r') as stream:
    config = yaml.safe_load(stream)

# Create args object (one attribute per top-level YAML key)
args = argparse.Namespace(**config)

# Setup KD parameters (these unconditionally override any YAML values)
args.train_from = None
args.max_grad_norm = None
args.kd_topk = 8
args.train_steps = 100000
args.kd_temperature = 10.0
args.kd_alpha = 0.5
args.warmup_steps = 8000
args.learning_rate = 2.0
args.bert_dump = bert_dump
args.data_db = data_db
args.bert_kd = True
args.data = data

# Add missing required parameters
args.model_type = "text"  # Required for OpenNMT model builder
args.copy_attn = False    # Common OpenNMT parameter
args.global_attention = "general"  # Common OpenNMT parameter

# Add embeddings parameters
# Both embedding sizes are copied from word_vec_size, which must therefore be
# present in the YAML config (this raises AttributeError otherwise).
args.src_word_vec_size = args.word_vec_size
args.tgt_word_vec_size = args.word_vec_size
# Add any other required embedding parameters
args.feat_merge = "concat"
args.feat_vec_size = -1
args.feat_vec_exponent = 0.7

# Add pretrained word vectors parameters
args.pre_word_vecs_enc = None  # Path to pretrained word vectors for encoder
args.pre_word_vecs_dec = None  # Path to pretrained word vectors for decoder
args.pre_word_vecs = None      # General pretrained word vectors

# Add fix_word_vecs parameters that were missing
args.fix_word_vecs_enc = False
args.fix_word_vecs_dec = False

# Add critical RNN and transformer parameters
# Encoder/decoder sizes are copied from rnn_size, which must be in the YAML config.
args.enc_rnn_size = args.rnn_size  # This was missing
args.dec_rnn_size = args.rnn_size
# Additional transformer-specific parameters
args.transformer_ff = getattr(args, 'transformer_ff', 2048)
args.heads = getattr(args, 'heads', 8)

# Add transformer position parameters
args.max_relative_positions = 0  # Default for standard transformer without relative positions
args.position_encoding = True  # Enable position encoding
args.param_init = 0.0  # Parameter initialization
args.param_init_glorot = True  # Use Glorot initialization

# Fix share_embeddings - set to False since we don't have shared vocabulary
args.share_embeddings = False  # This was causing the assertion error
args.share_decoder_embeddings = False  # Also disable this to be safe

# Add training parameters needed by OpenNMT trainer
args.truncated_decoder = 0  # Truncated BPTT
args.max_generator_batches = getattr(args, 'max_generator_batches', 32)
args.normalization = getattr(args, 'normalization', 'sents')
args.accum_count = getattr(args, 'accum_count', 1)
args.accum_steps = [0]
args.average_decay = 0.0  # Exponential moving average decay
args.average_every = 1  # Average every N updates
args.report_manager = None
args.valid_steps = getattr(args, 'valid_steps', 10000)
args.early_stopping = 0
args.early_stopping_criteria = None
args.valid_batch_size = 32

# Add the missing transformer attention parameters
args.self_attn_type = "scaled-dot"  # Default self-attention type for transformer
args.input_feed = 1  # Input feeding for RNN decoders
args.copy_attn_type = None  # Type of copy attention
args.generator_function = "softmax"  # Generator function

# Add distributed training parameters
args.local_rank = -1  # For distributed training (not used here)
args.gpu_ranks = getattr(args, 'gpu_ranks', [0])  # List of GPUs to use
args.gpu_verbose_level = 0  # GPU logging verbosity
args.world_size = getattr(args, 'world_size', 1)  # Number of processes for distributed

# Add other required parameters
args.encoder_type = getattr(args, 'encoder_type', "transformer")
args.decoder_type = getattr(args, 'decoder_type', "transformer") 
# Encoder and decoder depth both come from the single `layers` option (default 6).
args.enc_layers = getattr(args, 'layers', 6)
args.dec_layers = getattr(args, 'layers', 6)
args.dropout = getattr(args, 'dropout', 0.1)
# Reads `dropout`, which the previous line has just guaranteed to exist, so
# attention_dropout always equals dropout here.
args.attention_dropout = getattr(args, 'dropout', 0.1)
args.bridge = ""
args.aux_tune = False
args.subword_prefix = "▁"
args.subword_prefix_is_joiner = False

# Checkpoint prefix and log locations, all under output_path.
args.save_model = os.path.join(output_path, 'ckpt', 'model')
args.log_file = os.path.join(output_path, 'log', 'log')
args.tensorboard_log_dir = os.path.join(output_path, 'log')

In [None]:
# Load the saved vocabulary object and pull out the source/target
# token -> index mappings. Note that `vocab` now refers to the whole loaded
# object rather than the target mapping used during CMLM fine-tuning.
vocab = torch.load(data + '.vocab.pt')
src_vocab, tgt_vocab = (vocab[side].fields[0][1].vocab.stoi for side in ('src', 'tgt'))

# Distillation dataset: parallel data from the DB plus the teacher's top-k dump.
train_dataset = BertKdDataset(
    data_db, bert_dump,
    src_vocab, tgt_vocab,
    max_len=150, k=args.kd_topk,
)

# Token-bucketed batches, same bucket and batch budget as in Stage 1.
BUCKET_SIZE = 8192
train_sampler = TokenBucketSampler(train_dataset.keys, BUCKET_SIZE, 6144, batch_multiple=1)
train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    num_workers=4,
    collate_fn=BertKdDataset.pad_collate,
)

# Batch iterator handed to the trainer; cycle_loader also receives the device.
train_iter = cycle_loader(train_loader, device)

In [None]:
# Build the model
from onmt.model_builder import build_model

# Build the student translation model from the option namespace and the loaded
# vocabulary object, starting from scratch (no checkpoint).
# `args` is passed twice — presumably once as model options and once as general
# options; TODO confirm against onmt.model_builder.build_model.
# NOTE: this rebinds `model`, which previously held the CMLM BERT model.
model = build_model(args, args, fields=vocab, checkpoint=None)
model.to(device)

# Build optimizer (fresh state, since no checkpoint is supplied)
optim = Optimizer.from_opt(model, args, checkpoint=None)

# Build model saver (receives the same namespace twice, the model, vocab and optimizer)
model_saver = build_model_saver(args, args, model, vocab, optim)

# Build trainer; the second argument is 0 here — presumably the device id, TODO confirm
trainer = build_trainer(args, 0, model, vocab, optim, model_saver=model_saver)

In [None]:
# the problem is in the following cell

In [None]:
# Number of distillation training steps passed to trainer.train below; the
# final checkpoint path used by the translation cell is derived from it.
# NOTE(review): this is the full-length setting, not a short demonstration run —
# lower it for a quick smoke test.
num_steps_to_run_kd = 100000  # Full training (paper used 100k steps)

# Make sure the optimizer is tracking the step correctly.
# The check reads `training_step` while the assignment writes `_training_step`;
# presumably `training_step` is a property backed by `_training_step` —
# TODO confirm in onmt.utils.optimizers.Optimizer.
if not hasattr(optim, 'training_step'):
    optim._training_step = 0
    
# Define a custom iterator that provides batches without its own step limitation
def manual_train_iter():
    """Yield training batches forever, leaving step counting to the caller.

    Batches are pulled from the module-level ``train_iter``. If that iterator
    is ever exhausted, it is rebuilt with ``cycle_loader(train_loader, device)``
    (rebinding the global) and iteration continues with its first batch.
    """
    global train_iter  # rebound below when the iterator has to be recreated
    while True:
        try:
            next_batch = next(train_iter)
        except StopIteration:
            # Underlying iterator ran dry: rebuild it and carry on.
            print("Restarting data iterator")
            train_iter = cycle_loader(train_loader, device)
            next_batch = next(train_iter)

        yield next_batch

print("Starting model training with knowledge distillation...")
# Run distillation training: the trainer pulls batches from the endless
# generator and is given the total number of steps to run.
# No validation iterator is supplied (valid_iter=None).
# NOTE(review): save_checkpoint_steps=100 together with 100000 steps requests
# 1000 checkpoint saves — confirm the model saver prunes old checkpoints
# (e.g. a keep_checkpoint setting in the config) or raise this interval.
trainer.train(
    manual_train_iter(),
    num_steps_to_run_kd,
    save_checkpoint_steps=100,  # Save every 100 steps
    valid_iter=None
)

print(f"Model trained for {num_steps_to_run_kd} steps and saved to {output_path}/ckpt")

In [None]:
# Define paths for translation
model_path = f"{output_path}/ckpt/model_step_{num_steps_to_run_kd}.pt"
src_file = f"{data_dir}/test.de.bert"
tgt_file = f"{data_dir}/test.en.bert"
out_dir = "output/translation"
ref_file = f"{data_dir}/test.en"

# Ensure the output directory exists
os.makedirs(out_dir, exist_ok=True)

# Run translation if model exists
if os.path.exists(model_path):
    print(f"Model found at {model_path}. Running translation...")
    try:
        # Run translation
        !python opennmt/translate.py -model {model_path} \
                                    -src {src_file} \
                                    -tgt {tgt_file} \
                                    -output {out_dir}/result.en \
                                    -gpu 0 \
                                    -beam_size 5 -alpha 0.6 \
                                    -length_penalty wu

        print("Translation completed. Detokenizing output...")
        # Check if translation output exists
        if os.path.exists(f"{out_dir}/result.en"):
            # Detokenize output
            !python scripts/bert_detokenize.py --file {out_dir}/result.en \
                                          --output_dir {out_dir}

            # Check if detokenized output exists
            if os.path.exists(f"{out_dir}/result.en.detok"):
                print("Evaluating with BLEU score...")
                # Evaluate with BLEU
                !perl opennmt/tools/multi-bleu.perl {ref_file} \
                                               < {out_dir}/result.en.detok \
                                               > {out_dir}/result.bleu

                # Display BLEU score if file exists
                if os.path.exists(f"{out_dir}/result.bleu"):
                    with open(f"{out_dir}/result.bleu", "r") as f:
                        bleu_score = f.read().strip()
                        print(f"BLEU Score: {bleu_score}")
                else:
                    print("Warning: BLEU score file was not generated. This might indicate an issue with the evaluation.")
            else:
                print("Warning: Detokenized output file was not generated.")
        else:
            print("Warning: Translation output file was not generated.")
            
    except Exception as e:
        print(f"Error during translation process: {str(e)}")
        import traceback
        traceback.print_exc()
else:
    print(f"Model file {model_path} not found. Skipping translation.")
    print("You need to train the model first or adjust the model path to point to an existing checkpoint.")

In [None]:
import matplotlib.pyplot as plt

# Show the three figures shipped with the repository side by side.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (panel title, image file) for each subplot, left to right.
panels = [
    ('CMLM Finetuning', 'figures/cmlm-finetuning.png'),
    ('Translation Losses', 'figures/translation-losses.png'),
    ('Translation Accuracy', 'figures/translation-accuracy.png'),
]

for ax, (title, image_path) in zip(axes, panels):
    ax.set_title(title)
    img = plt.imread(image_path)
    ax.imshow(img)
    ax.axis('off')  # images only; hide ticks and frame

plt.tight_layout()
plt.show()