In [1]:
import os
from functools import partial
import datetime

import wandb
import torch
from transformers import Trainer, TrainingArguments, AutoConfig, DataCollatorWithPadding
from transformers.trainer_utils import set_seed
from transformers.integrations import WandbCallback
from scipy.special import expit

from data import NERDataModule
from config import get_configs
from model import get_pretrained
from utils import (
    kaggle_metrics,
    kaggle_metrics2,
    DataCollatorWithMasking,
    set_wandb_env_vars,
    reinit_model_weights,
)
from callbacks import NewWandbCB, SaveCallback, MaskingProbCallback, BasicSWACallback
from sift import SiftTrainer

if __name__ == "__main__":
    # Evaluation-only script: for each of the first three folds, load that
    # fold's checkpoint, run prediction on the fold's eval split, and print
    # both metric variants (`kaggle_metrics` and `kaggle_metrics2`).

    config_file = "j-dv3l-repl-2-pp-cv.yml"
    output = config_file.split(".")[0]
    cfg, args = get_configs(config_file)
    set_seed(args["seed"])

    # `model_name_or_path` is a template containing a `{fold}` placeholder.
    # Fold 0 is used to build the data module once, up front.
    # NOTE(review): presumably only the tokenizer/config of that path matter
    # for dataset preparation, so any fold would do -- confirm in NERDataModule.
    cfg["model_name_or_path"] = cfg["model_name_or_path"].format(fold=0)

    datamodule = NERDataModule(cfg)
    datamodule.prepare_datasets()

    # Only these columns are handed to the model at predict time; every other
    # column of the eval dataset is dropped before calling `predict`.
    model_input_columns = {"input_ids", "attention_mask", "token_type_ids"}

    for fold in range(3):

        # Reload the configs every fold: `TrainingArguments(**args)` below
        # replaces the `args` dict, and the path template must be re-formatted
        # from its unformatted form.
        cfg, args = get_configs(config_file)
        cfg["fold"] = fold
        cfg["model_name_or_path"] = cfg["model_name_or_path"].format(fold=fold)
        args["output_dir"] = f"{output}-f{fold}"

        args = TrainingArguments(**args)

        eval_dataset = datamodule.get_eval_dataset(fold=fold)

        print(f"Eval dataset length: {len(eval_dataset)}")

        model_config = AutoConfig.from_pretrained(
            cfg["model_name_or_path"],
            use_auth_token=os.environ.get("HUGGINGFACE_HUB_TOKEN", True),
        )

        model = get_pretrained(
            model_config, cfg["model_name_or_path"] + "/pytorch_model.bin"
        )

        data_collator = DataCollatorWithPadding(
            tokenizer=datamodule.tokenizer,
            return_tensors="pt",
            padding=True,
        )

        # Pick the trainer class into its own local name. Rebinding the
        # imported `Trainer` here would permanently shadow it for later folds
        # (and would raise UnboundLocalError if this code lived in a function).
        trainer_cls = SiftTrainer if cfg.get("use_sift") else Trainer
        trainer = trainer_cls(
            model=model,
            args=args,
            tokenizer=datamodule.tokenizer,
            data_collator=data_collator,
        )

        # No W&B logging is wanted for a prediction-only run.
        trainer.remove_callback(WandbCallback)

        extra_columns = [
            name
            for name in eval_dataset.column_names
            if name not in model_input_columns
        ]
        preds = trainer.predict(eval_dataset.remove_columns(extra_columns))

        # The metrics receive the full eval dataset (not the column-stripped
        # copy used for prediction).
        print(kaggle_metrics(preds, eval_dataset))
        print("v2")
        print(kaggle_metrics2(preds, eval_dataset))

        # Drop the references to this fold's model before emptying the CUDA
        # cache; while they are alive the allocator cannot release their
        # memory, and the next fold would load its model on top of this one.
        del trainer, model, preds
        torch.cuda.empty_cache()


        

#0:   0%|          | 0/2043 [00:00<?, ?ex/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


 

#1:   0%|          | 0/2043 [00:00<?, ?ex/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


 

#2:   0%|          | 0/2043 [00:00<?, ?ex/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


  

#3:   0%|          | 0/2043 [00:00<?, ?ex/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


#4:   0%|          | 0/2043 [00:00<?, ?ex/s]

 

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


#5:   0%|          | 0/2043 [00:00<?, ?ex/s]

 

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


#6:   0%|          | 0/2042 [00:00<?, ?ex/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Eval dataset length: 2860


{'precision': 0.874704550543627, 'recall': 0.9079803719008265, 'f1': 0.891031895885342}
v2
{'precision': 0.7489993084225748, 'recall': 0.9230371900826446, 'f1': 0.8269606765621204}
Eval dataset length: 2860


{'precision': 0.8690023630946265, 'recall': 0.8992132681267276, 'f1': 0.8838497309159308}
v2
{'precision': 0.740963725426728, 'recall': 0.9126355517754625, 'f1': 0.8178884058834038}
Eval dataset length: 2860


{'precision': 0.8675240046404857, 'recall': 0.9099523612261806, 'f1': 0.888231800548416}
v2
{'precision': 0.7441647024467575, 'recall': 0.9236743993371996, 'f1': 0.8242592271703344}
