In [1]:
import random
import torch
from transformers import AutoConfig, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
import argparse
import sys
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
import numpy as np
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import logging
import time
from utils.cogs_utils import *
import _pickle as cPickle
from transformers import AutoModelForMaskedLM, AutoTokenizer, BertModel, BertConfig
from model.encoder_decoder_hf import EncoderDecoderConfig, EncoderDecoderModel
from model.encoder_decoder_lstm import EncoderDecoderLSTMModel
import pandas as pd  

torch.cuda.empty_cache()

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def find_partition_name(name, lf):
    if lf == "cogs":
        return name
    else:
        return name+f"_{lf}"

In [2]:
class COGSTrainer(object):
    def __init__(
        self, model,
        is_master,
        src_tokenizer, 
        tgt_tokenizer, 
        device,
        logger,
        lr=5e-5,
        apex_enable=False,
        n_gpu=1,
        early_stopping=5,
        do_statistic=False,
        is_wandb=False,
        model_name="",
        eval_acc=True,
    ):
        self.model = model
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.is_master = is_master
        self.logger = logger
        self.is_wandb = is_wandb
        self.model_name = model_name
        self.eval_acc = eval_acc
        
        self.device = device
        self.lr = lr
        self.n_gpu = n_gpu
    
        self.early_stopping = early_stopping
    
    def evaluate(
        self, eval_dataloader,
    ):
        logging.info("Evaluating ...")
        loss_sum = 0.0
        eval_step = 0
        correct_count = 0
        total_count = 0
        self.model.eval()
        for step, inputs in enumerate(eval_dataloader):
            for k, v in inputs.items():
                if v is not None and isinstance(v, torch.Tensor):
                    inputs[k] = v.to(self.device)
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
            labels = inputs["labels"]
            outputs = self.model(**inputs)
            loss = outputs.loss.mean() if self.n_gpu > 1 else outputs.loss
            loss_sum += loss.item()
            eval_step += 1
        self.model.train()
        if total_count == 0:
            return loss_sum / eval_step, 0
        return loss_sum / eval_step, correct_count / total_count
    
    def train(
        self, train_dataloader, eval_dataloader,
        optimizer, scheduler, output_dir,
        log_step, valid_steps, epochs, 
        gradient_accumulation_steps,
        save_after_epoch
    ):
        self.model.train()
        train_iterator = trange(
            0, int(epochs), desc="Epoch"
        )
        total_step = 0
        total_log_step = 0
        patient = 0
        min_eval_loss = 100
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True)
            for step, inputs in enumerate(epoch_iterator):
                if patient == self.early_stopping:
                    logging.info("Early stopping the training ...")
                    break
                for k, v in inputs.items():
                    if v is not None and isinstance(v, torch.Tensor):
                        inputs[k] = v.to(self.device)
                outputs = self.model(**inputs)
                loss = outputs.loss.mean() if self.n_gpu > 1 else outputs.loss
                
                if total_step % log_step == 0 and self.is_wandb:
                    wandb.log(
                        {
                            "train/loss": loss.item(),
                        },
                        step=total_log_step
                    )
                    total_log_step += 1
                loss_str = round(loss.item(), 2)
                epoch_iterator.set_postfix({'loss': loss_str})
                
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps
                
                if total_step % gradient_accumulation_steps == 0:
                    loss.backward()
                    optimizer.step()
                    scheduler.step()
                    self.model.zero_grad()
                    
                total_step += 1
                
                if valid_steps != -1 and total_step % valid_steps == 0:
                    eval_loss, eval_acc = self.evaluate(eval_dataloader)
                    logging.info(f"Eval Loss: {eval_loss}; Eval Acc: {eval_acc}")
                    if self.is_wandb:
                        wandb.log(
                            {
                                "eval/loss": eval_loss.item(),
                                "eval/acc": eval_acc,
                            },
                            step=total_step
                        )
                    if eval_loss < min_eval_loss:
                        if self.is_master:
                            if self.n_gpu > 1:
                                self.model.module.save_pretrained(os.path.join(output_dir, 'model-best'))
                            else:
                                self.model.save_pretrained(os.path.join(output_dir, 'model-best'))
                        min_eval_loss = eval_loss
                        patient = 0
                    else:
                        patient += 1
                        
            if self.is_master:
                if save_after_epoch is not None and epoch % save_after_epoch == 0:
                    dir_name = f"model-epoch-{epoch}"
                else:
                    dir_name = "model-last"
                if self.n_gpu > 1:
                    self.model.module.save_pretrained(os.path.join(output_dir, dir_name))
                else:
                    self.model.save_pretrained(os.path.join(output_dir, dir_name))
            if patient == self.early_stopping:
                break
        logging.info("Training is finished ...") 
        if self.is_master:
            if self.n_gpu > 1:
                self.model.module.save_pretrained(os.path.join(output_dir, 'model-last'))
            else:
                self.model.save_pretrained(os.path.join(output_dir, 'model-last'))

In [3]:
if __name__ == '__main__':
    is_notebook = False
    try:
        cmd = argparse.ArgumentParser('The testing components of')
        cmd.add_argument('--gpu', default=-1, type=int, help='use id of gpu, -1 if cpu.')
        cmd.add_argument('--train_batch_size', default=128, type=int, help='training batch size')
        cmd.add_argument('--eval_batch_size', default=128, type=int, help='training batch size')
        cmd.add_argument('--lr', default=0.01, type=float, help='learning rate')
        cmd.add_argument('--data_path', required=True, type=str, help='path to the training corpus')
        cmd.add_argument(
            '--encoder_config_path', 
            type=str, help='path to the encoder config'
        )
        cmd.add_argument(
            '--decoder_config_path', 
            type=str, help='path to the decoder config'
        )
        cmd.add_argument('--max_seq_len', default=512, type=int)
        cmd.add_argument('--seed', default=42, type=int)
        cmd.add_argument('--gradient_accumulation_steps', default=1, type=int)
        cmd.add_argument('--output_dir', required=True, type=str, help='save dir')
        cmd.add_argument('--local_rank', default=-1, type=int, help='multi gpu training')
        cmd.add_argument('--epochs', default=10, type=int, help='training epochs')
        cmd.add_argument('--model_path', type=str, required=False, default=None)
        cmd.add_argument('--warm_up', type=float, default=0.1)
        cmd.add_argument('--is_wandb', default=False, action='store_true')
        cmd.add_argument('--spanformer', default=False, action='store_true')
        cmd.add_argument('--log_step', default=10, type=int)
        cmd.add_argument('--valid_steps', default=500, type=int)
        cmd.add_argument('--early_stopping', default=5, type=int)
        cmd.add_argument('--device', default="cuda", type=str, help='')
        cmd.add_argument('--do_train', default=False, action='store_true')
        cmd.add_argument('--do_eval', default=False, action='store_true')
        cmd.add_argument('--do_test', default=False, action='store_true')
        cmd.add_argument('--do_gen', default=False, action='store_true')
        cmd.add_argument('--least_to_most', default=False, action='store_true')
        cmd.add_argument('--use_glove', default=False, action='store_true')
        cmd.add_argument('--eval_acc', default=False, action='store_true')
        cmd.add_argument('--use_span_match', default=False, action='store_true')
        cmd.add_argument('--save_after_epoch', type=int, default=None)
        cmd.add_argument('--lf', default="cogs", type=str, help='')
        cmd.add_argument('--model_name', default="cogs", type=str, help='')
        
        args = cmd.parse_args(sys.argv[1:])
    except:
        # LSTM settings best: {batch = 512, lr = 8e-4, epoch = 200}
        # Transformer settings best: {batch = 128, lr = 1e-4, epoch = 200}
        is_notebook = True
        parser = argparse.ArgumentParser()
        args = parser.parse_args([])
        args.gpu = 1
        args.train_batch_size = 128
        args.eval_batch_size = 128
        args.gradient_accumulation_steps = 1
        args.lr = 1e-4
        args.data_path = "./cogs_participle_verb/"
        args.model_data_path = "./model/"
        args.encoder_config_path = None
        args.decoder_config_path = None
        args.max_seq_len = 512
        args.seed = 77
        args.output_dir = "./results_cogs_notebook/"
        args.epochs = 200
        args.warm_up = 0.1
        args.is_wandb = False
        args.log_step = 10
        # args.valid_steps = 500 # -1 not do training eval!
        args.valid_steps = -1
        args.early_stopping = None # large == never early stop!
        args.device = "cuda:0"
        args.spanformer = False
        args.model_path = None
        args.do_train = True
        args.do_eval = False
        args.do_test = True
        args.do_gen = True
        args.least_to_most = False
        args.use_glove = False
        args.eval_acc = False
        args.save_after_epoch = None
        args.use_span_match = False
        args.model_name = "ende_transformer"
        # args.lf = "no_()" # cogs, es, noexp
        # args.model_path = "./results_cogs_notebook/cogs_pipeline.model.ende_lstm.lf.cogs.glove.False.seed.42/model-last/"
        print("Using in a notebook env.")

Using in a notebook env.


usage: The testing components of [-h] [--gpu GPU]
                                 [--train_batch_size TRAIN_BATCH_SIZE]
                                 [--eval_batch_size EVAL_BATCH_SIZE] [--lr LR]
                                 --data_path DATA_PATH
                                 [--encoder_config_path ENCODER_CONFIG_PATH]
                                 [--decoder_config_path DECODER_CONFIG_PATH]
                                 [--max_seq_len MAX_SEQ_LEN] [--seed SEED]
                                 [--gradient_accumulation_steps GRADIENT_ACCUMULATION_STEPS]
                                 --output_dir OUTPUT_DIR
                                 [--local_rank LOCAL_RANK] [--epochs EPOCHS]
                                 [--model_path MODEL_PATH] [--warm_up WARM_UP]
                                 [--is_wandb] [--spanformer]
                                 [--log_step LOG_STEP]
                                 [--valid_steps VALID_STEPS]
                                 

In [4]:
results = {}

In [None]:
for lf in [
    "cogs",
]:
    for seed in [42]: # 42, 66, 77, 88, 99
        set_seed(args.seed)
        
        args.lf = lf
        args.seed = seed

        model_name = args.model_name
        run_name = f"cogs_pipeline.model.{model_name}.lf.{args.lf}.glove.{args.use_glove}.seed.{args.seed}"
        if args.do_train == False:
            args.model_path = f"./results_cogs_notebook/{run_name}/model-last/"
        
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        device = torch.device(args.device)
        
        encoder_config_filename = "encoder_config_lstm.json" if model_name == "ende_lstm" else "encoder_config.json"
        decoder_config_filename = "decoder_config_lstm.json" if model_name == "ende_lstm" else "decoder_config.json"
        
        if "participle_verb" in args.data_path:
            config_encoder = AutoConfig.from_pretrained(
                os.path.join(args.data_path, encoder_config_filename)
            )
            config_decoder = AutoConfig.from_pretrained(
                    os.path.join(args.data_path, decoder_config_filename) if args.decoder_config_path is None else \
                        args.decoder_config_path
            )
        else:
            config_encoder = AutoConfig.from_pretrained(
                os.path.join(args.model_data_path, encoder_config_filename)
            )
            config_decoder = AutoConfig.from_pretrained(
                    os.path.join(args.model_data_path, decoder_config_filename) if args.decoder_config_path is None else \
                        args.decoder_config_path
            )

        if "participle_verb" in args.data_path:
            src_tokenizer = WordLevelTokenizer(
                os.path.join(args.data_path, "src_vocab.txt"), 
                config_encoder,
                max_seq_len=args.max_seq_len
            )
            tgt_tokenizer = WordLevelTokenizer(
                os.path.join(args.data_path, "tgt_vocab.txt"), 
                config_decoder,
                max_seq_len=args.max_seq_len
            )  
        else:
            src_tokenizer = WordLevelTokenizer(
                os.path.join(args.model_data_path, "src_vocab.txt"), 
                config_encoder,
                max_seq_len=args.max_seq_len
            )
            tgt_tokenizer = WordLevelTokenizer(
                os.path.join(args.model_data_path, "tgt_vocab.txt"), 
                config_decoder,
                max_seq_len=args.max_seq_len
            )

        if args.least_to_most:
            logging.info("Preparing training set to be least to most order.")
        train_dataset = COGSDataset(
            cogs_path=args.data_path, 
            src_tokenizer=src_tokenizer, 
            tgt_tokenizer=tgt_tokenizer, 
            partition=find_partition_name("train", args.lf),
            least_to_most=args.least_to_most
        )
        train_dataloader = DataLoader(
            train_dataset, batch_size=args.train_batch_size, 
            sampler=SequentialSampler(train_dataset),
            collate_fn=train_dataset.collate_batch
        )

        eval_dataset = COGSDataset(
            cogs_path=args.data_path, 
            src_tokenizer=src_tokenizer, 
            tgt_tokenizer=tgt_tokenizer, 
            partition=find_partition_name("dev", args.lf),
        )
        eval_dataloader = DataLoader(
            eval_dataset, batch_size=args.eval_batch_size, 
            sampler=SequentialSampler(eval_dataset),
            collate_fn=train_dataset.collate_batch
        )

        test_dataset = COGSDataset(
            cogs_path=args.data_path, 
            src_tokenizer=src_tokenizer, 
            tgt_tokenizer=tgt_tokenizer, 
            partition=find_partition_name("test", args.lf),
        )
        test_dataloader = DataLoader(
            test_dataset, batch_size=args.eval_batch_size, 
            sampler=SequentialSampler(test_dataset),
            collate_fn=train_dataset.collate_batch
        )

        gen_dataset = COGSDataset(
            cogs_path=args.data_path, 
            src_tokenizer=src_tokenizer, 
            tgt_tokenizer=tgt_tokenizer, 
            partition=find_partition_name("gen", args.lf),
        )
        gen_dataloader = DataLoader(
            gen_dataset, batch_size=args.eval_batch_size, 
            sampler=SequentialSampler(gen_dataset),
            collate_fn=train_dataset.collate_batch
        )
        
        if model_name == "ende_transformer":
            logging.info("Baselining the Transformer Encoder-Decoder Model")
            model_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            model_config.decoder_start_token_id = config_encoder.bos_token_id
            model_config.pad_token_id = config_encoder.pad_token_id
            model_config.eos_token_id = config_encoder.eos_token_id
            model = EncoderDecoderModel(config=model_config)
        elif model_name == "ende_lstm":
            logging.info("Baselining the LSTM Encoder-Decoder Model")
            model_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            model_config.decoder_start_token_id = config_encoder.bos_token_id
            model_config.pad_token_id = config_encoder.pad_token_id
            model_config.eos_token_id = config_encoder.eos_token_id
            model = EncoderDecoderLSTMModel(config=model_config)
            
        if args.model_path is not None and model_name == "ende_transformer":
            logging.info("Loading pretrained model.")
            model = model.from_pretrained(args.model_path)
        elif args.model_path is not None and model_name == "ende_lstm":
            logging.info("Loading pretrained model.")
            raw_weights = torch.load(os.path.join(args.model_path, 'pytorch_model.bin'))
            model.load_state_dict(raw_weights)
            
        

        if "cuda:" not in args.device:
            n_gpu = torch.cuda.device_count()
            logging.info(f'__Number CUDA Devices: {n_gpu}')
        else:
            n_gpu = 1
            logging.info(f'__Number CUDA Devices: {n_gpu}')

        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
        _ = model.to(device)

        t_total = int(len(train_dataloader) * args.epochs)

        warm_up_steps = args.warm_up * t_total
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=args.lr
        )
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warm_up_steps,
                                                    num_training_steps=t_total)
        is_master = True
        apex_enable = False                                       
        if not os.path.exists(args.output_dir) and is_master:
            os.mkdir(args.output_dir)

        os.environ["WANDB_PROJECT"] = f"COGS"

        output_dir = os.path.join(args.output_dir, run_name)
        if args.is_wandb:
            import wandb
            run = wandb.init(
                project="COGS-CKY-Transformer", 
                entity="wuzhengx",
                name=run_name,
            )
            wandb.config.update(args)
        if not os.path.exists(args.output_dir) and is_master:
            os.mkdir(args.output_dir)

        trainer = COGSTrainer(
            model, device=device, 
            src_tokenizer=src_tokenizer, 
            tgt_tokenizer=tgt_tokenizer, 
            logger=logger,
            is_master=is_master, 
            n_gpu=n_gpu,
            is_wandb=args.is_wandb, 
            model_name=model_name,
            eval_acc=args.eval_acc,
            # early_stopping=args.early_stopping
        )
        num_params = count_parameters(model)
        logging.info(f'Number of model params: {num_params}')

        if args.do_train:
            logging.info(f"OUTPUT DIR: {output_dir}")
            trainer.train(
                train_dataloader, eval_dataloader,
                optimizer, scheduler, 
                log_step=args.log_step, valid_steps=args.valid_steps,
                output_dir=output_dir, epochs=args.epochs, 
                gradient_accumulation_steps=args.gradient_accumulation_steps,
                save_after_epoch=args.save_after_epoch,
            )
        
        if args.do_test:
            trainer.model.eval()
            epoch_iterator = tqdm(test_dataloader, desc="Iteration", position=0, leave=True)
            total_count = 0
            correct_count = 0
            for step, inputs in enumerate(epoch_iterator):
                input_ids = inputs["input_ids"].to(device)
                attention_mask = inputs["attention_mask"].to(device)
                labels = inputs["labels"].to(device)
                if model_name == "ende_lstm":
                    outputs = trainer.model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                    )
                else:
                    outputs = trainer.model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        eos_token_id=model_config.eos_token_id,
                        max_length=args.max_seq_len,
                    )
                decoded_preds = tgt_tokenizer.batch_decode(outputs)
                decoded_labels = tgt_tokenizer.batch_decode(labels)

                for i in range(len(decoded_preds)):
                    if args.use_span_match:
                        if set(decoded_preds[i].split(" ; ")) == set(decoded_labels[i].split(" ; ")):
                            correct_count += 1
                        else:
                            print(decoded_preds[i])
                            print(decoded_labels[i])
                    else:
                        if decoded_preds[i] == decoded_labels[i]:
                            correct_count += 1
                        else:
                            print(decoded_preds[i])
                            print(decoded_labels[i])
                    total_count += 1
                current_acc = round(correct_count/total_count, 2)
                epoch_iterator.set_postfix({'acc': current_acc})
            test_acc = current_acc

        if args.do_gen:
            per_cat_eval = {}
            for cat in set(gen_dataset.eval_cat):
                per_cat_eval[cat] = [0, 0] # correct, total
            trainer.model.eval()
            epoch_iterator = tqdm(gen_dataloader, desc="Iteration", position=0, leave=True)
            total_count = 0
            correct_count = 0
            for step, inputs in enumerate(epoch_iterator):
                input_ids = inputs["input_ids"].to(device)
                attention_mask = inputs["attention_mask"].to(device)
                labels = inputs["labels"].to(device)
                if model_name == "ende_lstm":
                    outputs = trainer.model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                    )
                else:
                    outputs = trainer.model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        eos_token_id=model_config.eos_token_id,
                        max_length=args.max_seq_len,
                    )
                decoded_preds = tgt_tokenizer.batch_decode(outputs)
                decoded_labels = tgt_tokenizer.batch_decode(labels)

                input_labels = src_tokenizer.batch_decode(input_ids)
                for i in range(len(decoded_preds)):
                    cat = gen_dataset.eval_cat[total_count]
                    if args.use_span_match:
                        if set(decoded_preds[i].split(" ; ")) == set(decoded_labels[i].split(" ; ")):
                            correct_count += 1
                            per_cat_eval[cat][0] += 1
                        else:
                            if cat == "obj_pp_to_subj_pp":
                                pass
        #                         print("input: ", input_labels[i])
        #                         print("pred: ", decoded_preds[i])
        #                         print("actual: ", decoded_labels[i])
        #                         print("cat: ", cat)
        #                         print()
                    else:
                        if decoded_preds[i] == decoded_labels[i]:
                            correct_count += 1
                            per_cat_eval[cat][0] += 1
                        else:
                            if cat == "prim_to_obj_proper":
                                pass
        #                             print("input: ", input_labels[i])
        #                             print("pred: ", decoded_preds[i])
        #                             print("actual: ", decoded_labels[i])
        #                             print("cat: ", cat)
        #                             print()
                    total_count += 1
                    per_cat_eval[cat][1] += 1
                current_acc = correct_count/total_count
                epoch_iterator.set_postfix({'acc': current_acc})

            struct_pp_acc = 0
            struct_cp_acc = 0
            struct_obj_subj_acc = 0

            lex_acc = 0
            lex_count = 0
            for k, v in per_cat_eval.items():
                if k  == "pp_recursion":
                    struct_pp_acc = 100 * v[0]/v[1]
                elif k  == "cp_recursion":
                    struct_cp_acc = 100 * v[0]/v[1]
                elif k  == "obj_pp_to_subj_pp":
                    struct_obj_subj_acc = 100 * v[0]/v[1]
                elif k  == "subj_to_obj_proper":
                    subj_to_obj_proper_acc = 100 * v[0]/v[1]
                elif k  == "prim_to_obj_proper":
                    prim_to_obj_proper_acc = 100 * v[0]/v[1]
                elif k  == "prim_to_subj_proper": 
                    prim_to_subj_proper_acc = 100 * v[0]/v[1]
                else:
                    lex_acc += v[0]
                    lex_count += v[1]
            lex_acc /= lex_count
            lex_acc *= 100
            current_acc *= 100

            print(f"obj_pp_to_subj_pp: {struct_obj_subj_acc}")
            print(f"cp_recursion: {struct_cp_acc}")
            print(f"pp_recursion: {struct_pp_acc}")
            print(f"subj_to_obj_proper: {subj_to_obj_proper_acc}")
            print(f"prim_to_obj_proper: {prim_to_obj_proper_acc}")
            print(f"prim_to_subj_proper: {prim_to_subj_proper_acc}")
            print(f"LEX: {lex_acc}")
            print(f"OVERALL: {current_acc}")

            results[(seed, lf)] = {
                "obj_pp_to_subj_pp" : struct_obj_subj_acc,
                "cp_recursion" : struct_cp_acc,
                "pp_recursion" : struct_pp_acc,
                "subj_to_obj_proper" : subj_to_obj_proper_acc,
                "prim_to_obj_proper" : prim_to_obj_proper_acc,
                "prim_to_subj_proper" : prim_to_subj_proper_acc,
                "lex_acc" : lex_acc,
                "overall_acc" : current_acc,
                "test_acc" : test_acc
            }



INFO:root:Baselining the Transformer Encoder-Decoder Model
INFO:root:__Number CUDA Devices: 1
INFO:root:Number of model params: 4384278
INFO:root:OUTPUT DIR: ./results_cogs_notebook/cogs_pipeline.model.ende_transformer.lf.cogs.glove.False.seed.42
Epoch: 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 189/189 [00:10<00:00, 18.82it/s, loss=5.73]
Epoch: 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 189/189 [00:09<00:00, 19.60it/s, loss=4.23]
Epoch: 2: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 189/189 [00:09<00:00

In [6]:
result_tables = []
for k,v in results.items():
    result_tables += [[
        k[1], v['obj_pp_to_subj_pp'], 
        v['cp_recursion'], v['pp_recursion'], 
        v['subj_to_obj_proper'], v['prim_to_obj_proper'], v['prim_to_subj_proper'],
        v['lex_acc'], v['overall_acc'], v['test_acc']
    ]]
result_df = pd.DataFrame(result_tables, columns = [
    'exp', 'obj_pp_to_subj_pp', 'cp_recursion', 'pp_recursion',
    'subj_to_obj_proper', 'prim_to_obj_proper', 'prim_to_subj_proper',
    'lex_acc', 'overall_acc', 'test_acc'
]) 
result_df.groupby(['exp'], as_index=False).mean()

Unnamed: 0,exp,obj_pp_to_subj_pp,cp_recursion,pp_recursion,subj_to_obj_proper,prim_to_obj_proper,prim_to_subj_proper,lex_acc,overall_acc,test_acc
0,cogs,0.0,0.0,0.0,0.266667,0.366667,0.1,0.042222,0.065079,0.0


In [16]:
result_df

Unnamed: 0,exp,obj_pp_to_subj_pp,cp_recursion,pp_recursion,subj_to_obj_proper,prim_to_obj_proper,prim_to_subj_proper,lex_acc,overall_acc,test_acc
0,cogs,75.6,0.2,5.7,29.6,0.0,0.0,34.346667,29.82381,100.0
1,cogs,80.8,0.6,6.4,4.0,50.3,98.2,48.58,46.142857,1.0
2,cogs,81.2,0.4,7.3,1.9,1.3,90.2,49.66,44.152381,1.0
3,cogs,75.6,0.2,5.7,29.6,0.0,0.0,34.346667,29.82381,1.0
4,cogs,80.9,0.8,2.7,1.7,0.3,92.5,54.586667,47.509524,1.0


### Eval without absolute indexing

In [7]:
trainer.model.eval()
epoch_iterator = tqdm(test_dataloader, desc="Iteration", position=0, leave=True)
total_count = 0
correct_count = 0
for step, inputs in enumerate(epoch_iterator):
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    labels = inputs["labels"].to(device)
    outputs = trainer.model.generate(
        input_ids,
        attention_mask=attention_mask,
#         eos_token_id=model_config.eos_token_id,
#         max_length=512,
    )
    decoded_preds = tgt_tokenizer.batch_decode(outputs)
    decoded_labels = tgt_tokenizer.batch_decode(labels)

    for i in range(len(decoded_preds)):
        
        index_mapping = {}
        current_idx = 0
        for t in decoded_labels[i].split():
            if t.isnumeric():
                if int(t) not in index_mapping:
                    index_mapping[int(t)] = current_idx
                    current_idx += 1
        decoded_labels_ii = []
        for t in decoded_labels[i].split():
            if t.isnumeric():
                decoded_labels_ii += [str(index_mapping[int(t)])]
            else:
                decoded_labels_ii += [t]
                
        index_mapping = {}
        current_idx = 0
        for t in decoded_preds[i].split():
            if t.isnumeric():
                if int(t) not in index_mapping:
                    index_mapping[int(t)] = current_idx
                    current_idx += 1
        decoded_preds_ii = []
        for t in decoded_preds[i].split():
            if t.isnumeric():
                decoded_preds_ii += [str(index_mapping[int(t)])]
            else:
                decoded_preds_ii += [t]
            
        
        decoded_labels_ii_str = " ".join(decoded_labels_ii)
        decoded_preds_ii_str = " ".join(decoded_preds_ii)
            
        if decoded_preds_ii_str == decoded_labels_ii_str:
            correct_count += 1
        else:
            print(decoded_labels_ii_str)
            print(decoded_preds_ii_str)

        total_count += 1
    current_acc = round(correct_count/total_count, 2)
    epoch_iterator.set_postfix({'acc': current_acc})

Iteration:   8%|████████████████▊                                                                                                                                                                                        | 2/24 [00:00<00:02,  9.11it/s, acc=0.16]

* cake ( 0 ) ; agent . like ( 1 , Mila ) AND ccomp . like ( 1 , 2 ) AND theme . offer ( 2 , 0 ) AND recipient . offer ( 2 , Emma )
* cake ( 0 ) ; agent . like ( 1 , Mila ) AND ccomp . like
coach ( 0 ) ; * cake ( 1 ) ; agent . support ( 2 , 0 ) AND ccomp . support ( 2 , 3 ) AND theme . snap ( 3 , 1 )
coach ( 0 ) ; * cake ( 1 ) ; agent . support ( 2 , 0
* moose ( 0 ) ; agent . want ( 1 , 0 ) AND xcomp . want ( 1 , 2 ) AND agent . read ( 2 , 0 )
* moose ( 0 ) ; agent . want ( 1 , 0 ) AND xcomp . want
box ( 0 ) ; * cat ( 1 ) ; theme . give ( 2 , 0 ) AND recipient . give ( 2 , 1 ) AND agent . give ( 2 , Aiden )
box ( 0 ) ; * cat ( 1 ) ; theme . give ( 2 , 0
* boy ( 0 ) ; agent . clean ( 1 , Emma ) AND theme . clean ( 1 , 0 )
* boy ( 0 ) ; agent . clean ( 1 , Emma ) AND theme . clean
* dog ( 0 ) ; * boy ( 1 ) ; agent . paint ( 2 , 0 ) AND theme . paint ( 2 , 1 )
* dog ( 0 ) ; * boy ( 1 ) ; agent . paint ( 2 ,
* customer ( 0 ) ; * priest ( 1 ) ; box ( 2 ) ; nmod . in ( 1 , 2 ) AND agent . hol

Iteration:  17%|█████████████████████████████████▌                                                                                                                                                                       | 4/24 [00:00<00:02,  9.22it/s, acc=0.16]

donut ( 0 ) ; table ( 1 ) ; nmod . beside ( 0 , 1 ) AND recipient . return ( 2 , Charlotte ) AND theme . return ( 2 , 0 )
donut ( 0 ) ; table ( 1 ) ; nmod . beside ( 0 , 1 )
horse ( 0 ) ; agent . see ( 1 , Penelope ) AND theme . see ( 1 , 0 )
horse ( 0 ) ; agent . see ( 1 , Penelope ) AND theme . see (
* tiger ( 0 ) ; doll ( 1 ) ; * bed ( 2 ) ; nmod . beside ( 1 , 2 ) AND agent . lend ( 3 , Ava ) AND recipient . lend ( 3 , 0 ) AND theme . lend ( 3 , 1 )
* tiger ( 0 ) ; doll ( 1 ) ; * bed ( 2 ) ; nmod
donut ( 0 ) ; theme . stab ( 1 , 0 ) AND agent . stab ( 1 , Abigail )
donut ( 0 ) ; theme . stab ( 1 , 0 ) AND agent . stab (
* cake ( 0 ) ; backpack ( 1 ) ; nmod . in ( 0 , 1 ) AND recipient . wire ( 2 , Noah ) AND theme . wire ( 2 , 0 )
* cake ( 0 ) ; backpack ( 1 ) ; nmod . in ( 0 , 1
girl ( 0 ) ; agent . like ( 1 , 0 ) AND theme . like ( 1 , Ethan )
girl ( 0 ) ; agent . like ( 1 , 0 ) AND theme . like (
* hamburger ( 0 ) ; * road ( 1 ) ; nmod . on ( 0 , 1 ) AND recipient . lend ( 2 , M

Iteration:  25%|██████████████████████████████████████████████████▎                                                                                                                                                      | 6/24 [00:00<00:01,  9.26it/s, acc=0.16]

* penguin ( 0 ) ; * lollipop ( 1 ) ; agent . wish ( 2 , 0 ) AND ccomp . wish ( 2 , 3 ) AND theme . help ( 3 , 1 )
* penguin ( 0 ) ; * lollipop ( 1 ) ; agent . wish ( 2 ,
donut ( 0 ) ; agent . appreciate ( 1 , Olivia ) AND theme . appreciate ( 1 , 0 )
donut ( 0 ) ; agent . appreciate ( 1 , Olivia ) AND theme . appreciate (
yogurt ( 0 ) ; cat ( 1 ) ; theme . inflate ( 2 , 0 ) AND agent . inflate ( 2 , 1 )
yogurt ( 0 ) ; cat ( 1 ) ; theme . inflate ( 2 , 0 )
cake ( 0 ) ; * turkey ( 1 ) ; theme . find ( 2 , 0 ) AND agent . find ( 2 , 1 )
cake ( 0 ) ; * turkey ( 1 ) ; theme . find ( 2 , 0
boy ( 0 ) ; * donut ( 1 ) ; agent . wish ( 2 , 0 ) AND ccomp . wish ( 2 , 3 ) AND theme . appreciate ( 3 , 1 )
boy ( 0 ) ; * donut ( 1 ) ; agent . wish ( 2 , 0
dog ( 0 ) ; agent . declare ( 1 , Liam ) AND ccomp . declare ( 1 , 2 ) AND agent . clean ( 2 , 0 )
dog ( 0 ) ; agent . declare ( 1 , Liam ) AND ccomp . declare (
dog ( 0 ) ; * book ( 1 ) ; agent . serve ( 2 , 0 ) AND recipient . serve ( 2 , Emma ) A

Iteration:  33%|███████████████████████████████████████████████████████████████████                                                                                                                                      | 8/24 [00:00<00:01,  9.25it/s, acc=0.15]

* cake ( 0 ) ; * house ( 1 ) ; nmod . in ( 0 , 1 ) AND recipient . feed ( 2 , Grace ) AND theme . feed ( 2 , 0 )
* cake ( 0 ) ; * house ( 1 ) ; nmod . in ( 0 ,
hamburger ( 0 ) ; * judge ( 1 ) ; theme . cook ( 2 , 0 ) AND agent . cook ( 2 , 1 )
hamburger ( 0 ) ; * judge ( 1 ) ; theme . cook ( 2 , 0
* cake ( 0 ) ; * pedestal ( 1 ) ; nmod . on ( 0 , 1 ) AND agent . lend ( 2 , Audrey ) AND recipient . lend ( 2 , Sophia ) AND theme . lend ( 2 , 0 )
* cake ( 0 ) ; * pedestal ( 1 ) ; nmod . on ( 0 ,
journalist ( 0 ) ; agent . nurse ( 1 , 0 ) AND theme . nurse ( 1 , Emma )
journalist ( 0 ) ; agent . nurse ( 1 , 0 ) AND theme . nurse (
* girl ( 0 ) ; * needle ( 1 ) ; agent . worship ( 2 , 0 ) AND theme . worship ( 2 , 1 )
* girl ( 0 ) ; * needle ( 1 ) ; agent . worship ( 2 ,
cake ( 0 ) ; hen ( 1 ) ; theme . offer ( 2 , 0 ) AND recipient . offer ( 2 , 1 ) AND agent . offer ( 2 , Olivia )
cake ( 0 ) ; hen ( 1 ) ; theme . offer ( 2 , 0 )
* rose ( 0 ) ; theme . dust ( 1 , 0 ) AND agent . dust ( 1 ,

Iteration:  42%|███████████████████████████████████████████████████████████████████████████████████▎                                                                                                                    | 10/24 [00:01<00:01,  9.27it/s, acc=0.15]

* donut ( 0 ) ; theme . snap ( 1 , 0 ) AND agent . snap ( 1 , Mia )
* donut ( 0 ) ; theme . snap ( 1 , 0 ) AND agent . snap
balloon ( 0 ) ; block ( 1 ) ; nmod . on ( 0 , 1 ) AND agent . paint ( 2 , William ) AND theme . paint ( 2 , 0 )
balloon ( 0 ) ; block ( 1 ) ; nmod . on ( 0 , 1 )
* cake ( 0 ) ; theme . draw ( 1 , 0 ) AND agent . draw ( 1 , Elijah )
* cake ( 0 ) ; theme . draw ( 1 , 0 ) AND agent . draw
* girl ( 0 ) ; * cake ( 1 ) ; agent . respect ( 2 , 0 ) AND theme . respect ( 2 , 1 )
* girl ( 0 ) ; * cake ( 1 ) ; agent . respect ( 2 ,
* bowl ( 0 ) ; theme . roll ( 1 , 0 ) AND agent . roll ( 1 , Emma )
* bowl ( 0 ) ; theme . roll ( 1 , 0 ) AND agent . roll
* donut ( 0 ) ; theme . throw ( 1 , 0 ) AND agent . throw ( 1 , Luke )
* donut ( 0 ) ; theme . throw ( 1 , 0 ) AND agent . throw
block ( 0 ) ; * girl ( 1 ) ; theme . miss ( 2 , 0 ) AND agent . miss ( 2 , 1 )
block ( 0 ) ; * girl ( 1 ) ; theme . miss ( 2 , 0
* cookie ( 0 ) ; boy ( 1 ) ; theme . help ( 2 , 0 ) AND agent . help (

Iteration:  50%|████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                    | 12/24 [00:01<00:01,  9.26it/s, acc=0.16]

girl ( 0 ) ; book ( 1 ) ; house ( 2 ) ; nmod . in ( 1 , 2 ) AND recipient . award ( 3 , 0 ) AND theme . award ( 3 , 1 )
girl ( 0 ) ; book ( 1 ) ; house ( 2 ) ; nmod . in
cake ( 0 ) ; theme . give ( 1 , 0 ) AND recipient . give ( 1 , Isabella ) AND agent . give ( 1 , Amelia )
cake ( 0 ) ; theme . give ( 1 , 0 ) AND recipient . give (
* girl ( 0 ) ; * mirror ( 1 ) ; agent . give ( 2 , 0 ) AND recipient . give ( 2 , Emma ) AND theme . give ( 2 , 1 )
* girl ( 0 ) ; * mirror ( 1 ) ; agent . give ( 2 ,
* deer ( 0 ) ; * child ( 1 ) ; * trailer ( 2 ) ; nmod . in ( 1 , 2 ) AND agent . roll ( 3 , 0 ) AND theme . roll ( 3 , 1 )
* deer ( 0 ) ; * child ( 1 ) ; * trailer ( 2 ) ;
* girl ( 0 ) ; * yogurt ( 1 ) ; agent . pass ( 2 , 0 ) AND recipient . pass ( 2 , Emma ) AND theme . pass ( 2 , 1 )
* girl ( 0 ) ; * yogurt ( 1 ) ; agent . pass ( 2 ,
farmer ( 0 ) ; * stage ( 1 ) ; * vase ( 2 ) ; nmod . on ( 0 , 1 ) AND nmod . in ( 1 , 2 ) AND agent . roll ( 3 , Emma ) AND theme . roll ( 3 , 0 )
farmer ( 0 )

Iteration:  58%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                   | 14/24 [00:01<00:01,  9.27it/s, acc=0.16]

* mandarin ( 0 ) ; bed ( 1 ) ; nmod . on ( 0 , 1 ) AND recipient . give ( 2 , Emma ) AND theme . give ( 2 , 0 ) AND agent . give ( 2 , Samuel )
* mandarin ( 0 ) ; bed ( 1 ) ; nmod . on ( 0 , 1
* child ( 0 ) ; * girl ( 1 ) ; agent . freeze ( 2 , 0 ) AND theme . freeze ( 2 , 1 )
* child ( 0 ) ; * girl ( 1 ) ; agent . freeze ( 2 ,
* cat ( 0 ) ; agent . draw ( 1 , 0 ) AND theme . draw ( 1 , William )
* cat ( 0 ) ; agent . draw ( 1 , 0 ) AND theme . draw
cake ( 0 ) ; agent . tolerate ( 1 , Jacob ) AND theme . tolerate ( 1 , 0 )
cake ( 0 ) ; agent . tolerate ( 1 , Jacob ) AND theme . tolerate (
* boy ( 0 ) ; carpet ( 1 ) ; * stage ( 2 ) ; nmod . on ( 0 , 1 ) AND nmod . on ( 1 , 2 ) AND agent . eat ( 3 , Mia ) AND theme . eat ( 3 , 0 )
* boy ( 0 ) ; carpet ( 1 ) ; * stage ( 2 ) ; nmod
* cake ( 0 ) ; * giraffe ( 1 ) ; theme . pass ( 2 , 0 ) AND recipient . pass ( 2 , 1 )
* cake ( 0 ) ; * giraffe ( 1 ) ; theme . pass ( 2 ,
cake ( 0 ) ; * bed ( 1 ) ; nmod . on ( 0 , 1 ) AND agent . award ( 2 , O

Iteration:  67%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                  | 16/24 [00:01<00:00,  9.28it/s, acc=0.17]

* boy ( 0 ) ; * penny ( 1 ) ; * bird ( 2 ) ; recipient . lend ( 3 , 0 ) AND theme . lend ( 3 , 1 ) AND agent . lend ( 3 , 2 )
* boy ( 0 ) ; * penny ( 1 ) ; * bird ( 2 ) ;
* cake ( 0 ) ; table ( 1 ) ; goose ( 2 ) ; nmod . on ( 0 , 1 ) AND recipient . offer ( 3 , Olivia ) AND theme . offer ( 3 , 0 ) AND agent . offer ( 3 , 2 )
* cake ( 0 ) ; table ( 1 ) ; goose ( 2 ) ; nmod .
toothbrush ( 0 ) ; theme . roll ( 1 , 0 ) AND agent . roll ( 1 , Liam )
toothbrush ( 0 ) ; theme . roll ( 1 , 0 ) AND agent . roll (
bean ( 0 ) ; * baby ( 1 ) ; theme . split ( 2 , 0 ) AND agent . split ( 2 , 1 )
bean ( 0 ) ; * baby ( 1 ) ; theme . split ( 2 , 0
* cake ( 0 ) ; theme . sell ( 1 , 0 ) AND recipient . sell ( 1 , Emma ) AND agent . sell ( 1 , Liam )
* cake ( 0 ) ; theme . sell ( 1 , 0 ) AND recipient . sell
* biscuit ( 0 ) ; agent . paint ( 1 , Madison ) AND theme . paint ( 1 , 0 )
* biscuit ( 0 ) ; agent . paint ( 1 , Madison ) AND theme . paint
cake ( 0 ) ; * lamb ( 1 ) ; agent . give ( 2 , Ava ) AND 

Iteration:  75%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                  | 18/24 [00:01<00:00,  9.27it/s, acc=0.17]

girl ( 0 ) ; agent . burn ( 1 , 0 ) AND theme . burn ( 1 , Henry )
girl ( 0 ) ; agent . burn ( 1 , 0 ) AND theme . burn (
teacher ( 0 ) ; * pizza ( 1 ) ; agent . give ( 2 , 0 ) AND recipient . give ( 2 , Olivia ) AND theme . give ( 2 , 1 )
teacher ( 0 ) ; * pizza ( 1 ) ; agent . give ( 2 , 0
* cake ( 0 ) ; * room ( 1 ) ; nmod . in ( 0 , 1 ) AND recipient . slip ( 2 , Emma ) AND theme . slip ( 2 , 0 )
* cake ( 0 ) ; * room ( 1 ) ; nmod . in ( 0 ,
turtle ( 0 ) ; bottle ( 1 ) ; table ( 2 ) ; stage ( 3 ) ; nmod . beside ( 1 , 2 ) AND nmod . beside ( 2 , 3 ) AND recipient . lend ( 4 , 0 ) AND theme . lend ( 4 , 1 ) AND agent . lend ( 4 , Michael )
turtle ( 0 ) ; bottle ( 1 ) ; table ( 2 ) ; stage ( 3
* cookie ( 0 ) ; girl ( 1 ) ; cat ( 2 ) ; theme . send ( 3 , 0 ) AND recipient . send ( 3 , 1 ) AND agent . send ( 3 , 2 )
* cookie ( 0 ) ; girl ( 1 ) ; cat ( 2 ) ; theme .
* gumball ( 0 ) ; theme . draw ( 1 , 0 ) AND agent . draw ( 1 , Emma )
* gumball ( 0 ) ; theme . draw ( 1 , 0 ) AND agent 

Iteration:  83%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                 | 20/24 [00:02<00:00,  9.27it/s, acc=0.17]

cake ( 0 ) ; table ( 1 ) ; nmod . on ( 0 , 1 ) AND agent . tolerate ( 2 , Emma ) AND theme . tolerate ( 2 , 0 )
cake ( 0 ) ; table ( 1 ) ; nmod . on ( 0 , 1 )
* cake ( 0 ) ; * tiger ( 1 ) ; theme . admire ( 2 , 0 ) AND agent . admire ( 2 , 1 )
* cake ( 0 ) ; * tiger ( 1 ) ; theme . admire ( 2 ,
* biscuit ( 0 ) ; theme . roll ( 1 , 0 ) AND agent . roll ( 1 , Victoria )
* biscuit ( 0 ) ; theme . roll ( 1 , 0 ) AND agent . roll
* crayon ( 0 ) ; boy ( 1 ) ; theme . eat ( 2 , 0 ) AND agent . eat ( 2 , 1 )
* crayon ( 0 ) ; boy ( 1 ) ; theme . eat ( 2 , 0
* donut ( 0 ) ; theme . float ( 1 , 0 ) AND agent . float ( 1 , Isabella )
* donut ( 0 ) ; theme . float ( 1 , 0 ) AND agent . float
child ( 0 ) ; * strawberry ( 1 ) ; recipient . wire ( 2 , 0 ) AND theme . wire ( 2 , 1 )
child ( 0 ) ; * strawberry ( 1 ) ; recipient . wire ( 2 , 0
* cookie ( 0 ) ; theme . roll ( 1 , 0 ) AND agent . roll ( 1 , Joshua )
* cookie ( 0 ) ; theme . roll ( 1 , 0 ) AND agent . roll
dog ( 0 ) ; agent . hate ( 1 , 0 )

Iteration:  92%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                | 22/24 [00:02<00:00,  9.27it/s, acc=0.17]

girl ( 0 ) ; agent . like ( 1 , 0 ) AND ccomp . like ( 1 , 2 ) AND agent . sleep ( 2 , Mia )
girl ( 0 ) ; agent . like ( 1 , 0 ) AND ccomp . like (
agent . want ( 0 , Sebastian ) AND xcomp . want ( 0 , 1 ) AND agent . run ( 1 , Sebastian )
agent . want ( 0 , Sebastian ) AND xcomp . want ( 0 , 1 ) AND
girl ( 0 ) ; agent . roll ( 1 , 0 ) AND theme . roll ( 1 , Riley )
girl ( 0 ) ; agent . roll ( 1 , 0 ) AND theme . roll (
girl ( 0 ) ; box ( 1 ) ; recipient . sell ( 2 , 0 ) AND theme . sell ( 2 , 1 ) AND agent . sell ( 2 , Emma )
girl ( 0 ) ; box ( 1 ) ; recipient . sell ( 2 , 0 )
* cat ( 0 ) ; * cake ( 1 ) ; table ( 2 ) ; nmod . beside ( 1 , 2 ) AND agent . like ( 3 , 0 ) AND theme . like ( 3 , 1 )
* cat ( 0 ) ; * cake ( 1 ) ; table ( 2 ) ; nmod
* cake ( 0 ) ; bunny ( 1 ) ; theme . eat ( 2 , 0 ) AND agent . eat ( 2 , 1 )
* cake ( 0 ) ; bunny ( 1 ) ; theme . eat ( 2 , 0
* girl ( 0 ) ; * drink ( 1 ) ; vessel ( 2 ) ; nmod . on ( 1 , 2 ) AND agent . collapse ( 3 , 0 ) AND theme . collapse ( 

Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00,  9.26it/s, acc=0.17]

bear ( 0 ) ; house ( 1 ) ; nmod . beside ( 0 , 1 ) AND agent . freeze ( 2 , Emma ) AND theme . freeze ( 2 , 0 )
bear ( 0 ) ; house ( 1 ) ; nmod . beside ( 0 , 1 )
* box ( 0 ) ; * teacher ( 1 ) ; theme . give ( 2 , 0 ) AND recipient . give ( 2 , 1 )
* box ( 0 ) ; * teacher ( 1 ) ; theme . give ( 2 ,
yogurt ( 0 ) ; * girl ( 1 ) ; recipient . lend ( 2 , Mason ) AND theme . lend ( 2 , 0 ) AND agent . lend ( 2 , 1 )
yogurt ( 0 ) ; * girl ( 1 ) ; recipient . lend ( 2 , Mason
* chicken ( 0 ) ; * cart ( 1 ) ; * stage ( 2 ) ; nmod . in ( 0 , 1 ) AND nmod . beside ( 1 , 2 ) AND agent . crumple ( 3 , Aria ) AND theme . crumple ( 3 , 0 )
* chicken ( 0 ) ; * cart ( 1 ) ; * stage ( 2 ) ;
boy ( 0 ) ; cake ( 1 ) ; * doctor ( 2 ) ; agent . hope ( 3 , 0 ) AND ccomp . hope ( 3 , 4 ) AND theme . pass ( 4 , 1 ) AND recipient . pass ( 4 , 2 )
boy ( 0 ) ; cake ( 1 ) ; * doctor ( 2 ) ; agent .
* girl ( 0 ) ; cake ( 1 ) ; agent . like ( 2 , 0 ) AND ccomp . like ( 2 , 3 ) AND theme . worship ( 3 , 1 )
* girl ( 




In [None]:
per_cat_eval = {}
for cat in set(gen_dataset.eval_cat):
    per_cat_eval[cat] = [0, 0] # correct, total
trainer.model.eval()
epoch_iterator = tqdm(gen_dataloader, desc="Iteration", position=0, leave=True)
total_count = 0
correct_count = 0
for step, inputs in enumerate(epoch_iterator):
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    labels = inputs["labels"].to(device)
    outputs = trainer.model.generate(
        input_ids,
        attention_mask=attention_mask,
#         eos_token_id=model_config.eos_token_id,
#         max_length=args.max_seq_len
    )
    decoded_preds = tgt_tokenizer.batch_decode(outputs)
    decoded_labels = tgt_tokenizer.batch_decode(labels)

    input_labels = src_tokenizer.batch_decode(input_ids)
    for i in range(len(decoded_preds)):
        
        index_mapping = {}
        current_idx = 0
        for t in decoded_labels[i].split():
            if t.isnumeric():
                if int(t) not in index_mapping:
                    index_mapping[int(t)] = current_idx
                    current_idx += 1
        decoded_labels_ii = []
        for t in decoded_labels[i].split():
            if t.isnumeric():
                decoded_labels_ii += [str(index_mapping[int(t)])]
            else:
                decoded_labels_ii += [t]
                
        index_mapping = {}
        current_idx = 0
        for t in decoded_preds[i].split():
            if t.isnumeric():
                if int(t) not in index_mapping:
                    index_mapping[int(t)] = current_idx
                    current_idx += 1
        decoded_preds_ii = []
        for t in decoded_preds[i].split():
            if t.isnumeric():
                decoded_preds_ii += [str(index_mapping[int(t)])]
            else:
                decoded_preds_ii += [t]
            
        
        decoded_labels_ii_str = " ".join(decoded_labels_ii)
        decoded_preds_ii_str = " ".join(decoded_preds_ii)

        cat = gen_dataset.eval_cat[total_count]
        if decoded_preds_ii_str == decoded_labels_ii_str:
            correct_count += 1
            per_cat_eval[cat][0] += 1
            if cat == "obj_pp_to_subj_pp":
                pass
        else:
            if cat == "obj_pp_to_subj_pp":
                # pass
                print("input: ", input_labels[i])
                print("pred: ", decoded_preds_ii_str)
                print("actual: ", decoded_labels_ii_str)
                print("cat: ", cat)
                print()
        total_count += 1
        per_cat_eval[cat][1] += 1
    current_acc = correct_count/total_count
    epoch_iterator.set_postfix({'acc': current_acc})

struct_pp_acc = 0
struct_cp_acc = 0
struct_obj_subj_acc = 0

lex_acc = 0
lex_count = 0
for k, v in per_cat_eval.items():
    if k  == "pp_recursion":
        struct_pp_acc = 100 * v[0]/v[1]
    elif k  == "cp_recursion":
        struct_cp_acc = 100 * v[0]/v[1]
    elif k  == "obj_pp_to_subj_pp":
        struct_obj_subj_acc = 100 * v[0]/v[1]
    elif k  == "subj_to_obj_proper":
        subj_to_obj_proper_acc = 100 * v[0]/v[1]
    elif k  == "prim_to_obj_proper":
        prim_to_obj_proper_acc = 100 * v[0]/v[1]
    elif k  == "prim_to_subj_proper":
        prim_to_subj_proper_acc = 100 * v[0]/v[1]
    else:
        lex_acc += v[0]
        lex_count += v[1]
lex_acc /= lex_count
lex_acc *= 100
current_acc *= 100

print(f"obj_pp_to_subj_pp: {struct_obj_subj_acc}")
print(f"cp_recursion: {struct_cp_acc}")
print(f"pp_recursion: {struct_pp_acc}")
print(f"subj_to_obj_proper: {subj_to_obj_proper_acc}")
print(f"prim_to_obj_proper: {prim_to_obj_proper_acc}")
print(f"prim_to_subj_proper: {prim_to_subj_proper_acc}")
print(f"LEX: {lex_acc}")
print(f"OVERALL: {current_acc}")

results[lf] = {
    "obj_pp_to_subj_pp" : struct_obj_subj_acc,
    "cp_recursion" : struct_cp_acc,
    "pp_recursion" : struct_pp_acc,
    "subj_to_obj_proper" : subj_to_obj_proper_acc,
    "prim_to_obj_proper" : prim_to_obj_proper_acc,
    "prim_to_subj_proper" : prim_to_subj_proper_acc,
    "lex_acc" : lex_acc,
    "overall_acc" : current_acc
}