<a href="https://colab.research.google.com/github/fansie1/chemRxnResource/blob/main/chemrxnresource.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install seqeval
!pip install transformers
!pip install -U --no-cache-dir gdown --pre

import pdb

In [None]:
# model
from dataclasses import dataclass
from typing import List, Optional, Tuple
import logging

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from transformers import BertForTokenClassification

logger = logging.getLogger(__name__)



class BertForTagging(BertForTokenClassification):
    """BERT token-classification model with an optional [CLS]-concatenation head.

    Args:
        config: BERT config; ``config.num_labels`` sets the classifier output width.
        use_cls: when True, every token representation is concatenated with
            ``outputs[1]`` of ``self.bert`` before the classifier, so the
            classifier input is ``2 * config.hidden_size`` wide.
    """

    def __init__(self, config, use_cls=False):
        super(BertForTagging, self).__init__(config)

        self.use_cls = use_cls
        if self.use_cls:
            # Replace the parent's classifier: its input is [token_h ; cls_h].
            self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        decoder_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None
    ):
        """Compute per-token logits and, when ``labels`` is given, a cross-entropy loss.

        ``decoder_mask`` is accepted so a whole dataset batch can be passed as
        ``**inputs``; it is not used here (see ``decode``).

        When ``attention_mask`` is given, labels at positions where it is not 1
        are replaced by the loss's ``ignore_index`` before computing the loss.

        Returns:
            tuple ``(loss, logits, *rest)`` when ``labels`` is given, otherwise
            ``(logits, *rest)``, where ``rest`` is ``outputs[2:]`` of ``self.bert``.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs[0]
        if self.use_cls:
            # NOTE(review): outputs[1] is assumed to be the pooled [CLS] vector.
            # Some transformers versions build the token-classification encoder
            # without a pooling layer, in which case outputs[1] is something
            # else — confirm against the installed version before using use_cls.
            batch_size, seq_length, hidden_dim = sequence_output.shape
            extended_cls_h = outputs[1].unsqueeze(1).expand(batch_size, seq_length, hidden_dim)
            sequence_output = torch.cat([sequence_output, extended_cls_h], 2)

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)

    def decode(self, logits, mask):
        """Greedy (argmax) decoding keeping only positions where ``mask`` is true.

        Args:
            logits: tensor indexed as ``[batch, position, label]``.
            mask: boolean-like tensor (or array) of shape ``[batch, position]``.

        Returns:
            list with one list per sentence of predicted label ids (numpy
            integers), in position order, for the masked-in positions.
        """
        preds = torch.argmax(logits, dim=2).cpu().numpy()
        # Move the mask to host memory once; indexing a device tensor element by
        # element (as mask[i, j] did) forces one device sync per position.
        keep = torch.as_tensor(mask).to(torch.bool).cpu().numpy()
        return [list(row[row_keep]) for row, row_keep in zip(preds, keep)]

In [None]:
# Training: IETrainer, a transformers.Trainer subclass for evaluation/prediction
import logging
import math
import os
import re
from typing import Any, Callable, Optional
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from tqdm.auto import tqdm, trange

from transformers import Trainer
from transformers import PreTrainedModel
# from transformers import is_wandb_available
from transformers import TrainingArguments
from transformers.data.data_collator import DataCollator
from transformers import AdamW, get_linear_schedule_with_warmup

logger = logging.getLogger(__name__)

class IETrainer(Trainer):
    """
    IETrainer is inherited from transformers.Trainer, optimized for IE tasks.

    ``evaluate`` and ``predict`` run ``_prediction_loop``. That loop decodes
    each batch with ``model.decode(logits, mask=decoder_mask)``, so predictions
    and gold labels come back as ragged lists of per-sentence label ids rather
    than padded arrays. ``compute_metrics`` is therefore called as
    ``compute_metrics(predictions, label_ids)`` with those two lists.
    """
    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics=None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
        use_crf: Optional[bool]=False,
        epoch: int = 1
    ):
        super(IETrainer, self).__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
            optimizers=optimizers
        )
        self.use_crf = use_crf
        # Reported as the "epoch" entry by _log().
        self.epoch = epoch
        # _log() falls back to step 0 while this is still None.
        self.global_step = None

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict:
        """Evaluate on ``eval_dataset`` (``None`` -> the trainer's own) and log the metrics.

        Args:
            eval_dataset: dataset to evaluate; ``None`` lets ``get_eval_dataloader``
                fall back to the trainer's eval dataset.
            ignore_keys: accepted for compatibility with callers using the newer
                ``Trainer.evaluate`` keyword arguments; unused because the model
                returns plain tuples.
            metric_key_prefix: prefix put on every metric name (default "eval").

        Returns:
            dict with keys "predictions", "label_ids" and "metrics".
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        output = self._prediction_loop(
            eval_dataloader,
            description="Evaluation",
            metric_key_prefix=metric_key_prefix,
        )

        self._log(output['metrics'])

        return output

    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict:
        """Run the prediction loop on ``test_dataset`` without logging.

        ``ignore_keys`` is accepted for API compatibility and unused. The
        default ``metric_key_prefix`` is "eval" to keep the historical metric
        names of this class.

        Returns:
            dict with keys "predictions", "label_ids" and "metrics".
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self._prediction_loop(
            test_dataloader,
            description="Prediction",
            metric_key_prefix=metric_key_prefix,
        )

    def _prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        metric_key_prefix: str = "eval",
    ) -> Dict:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`
        Works both with or without labels.

        Each batch must contain a "decoder_mask" entry; only masked-in positions
        are decoded and compared. Every metric key that does not already start
        with ``f"{metric_key_prefix}_"`` gets that prefix.
        """
        model = self.model
        batch_size = dataloader.batch_size

        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)

        model.eval()

        eval_losses: List[float] = []
        preds_ids = []
        label_ids = []

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(
                inputs.get(k) is not None
                for k in ["labels", "lm_labels", "masked_lm_labels"]
            )

            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            mask = inputs["decoder_mask"].to(torch.bool)
            preds = model.decode(logits, mask=mask)
            preds_ids.extend(preds)
            if inputs.get("labels") is not None:
                # Gold labels restricted to the same positions as the predictions.
                labels = [inputs["labels"][i, mask[i]].tolist()
                          for i in range(inputs["labels"].shape[0])]
                label_ids.extend(labels)
                assert len(preds) == len(labels)
                assert len(preds[0]) == len(labels[0])

        if self.compute_metrics is not None and \
                len(preds_ids) > 0 and \
                len(label_ids) > 0:
            metrics = self.compute_metrics(preds_ids, label_ids)
        else:
            metrics = {}

        prefix = f"{metric_key_prefix}_"
        if len(eval_losses) > 0:
            metrics[f"{prefix}loss"] = np.mean(eval_losses)

        # Prefix all keys that are not prefixed yet.
        for key in list(metrics.keys()):
            if not key.startswith(prefix):
                metrics[f"{prefix}{key}"] = metrics.pop(key)

        return {'predictions': preds_ids, 'label_ids': label_ids, 'metrics': metrics}

    def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """Log ``logs`` plus "epoch" and "step" via ``iterator.write`` or the module logger.

        Note: "epoch" is added to ``logs`` in place.
        """
        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        output = {**logs, **{"step": self.global_step}}
        if iterator is not None:
            # tqdm.write expects a string, not a dict.
            iterator.write(str(output))
        else:
            logger.info(
                {k: round(v, 4) if isinstance(v, float) else v for k, v in output.items()}
            )





In [None]:
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
from transformers import AutoTokenizer

import torch
from torch import nn
from torch.utils.data.dataset import Dataset



@dataclass
class InputExample:
    """
    One word-level sentence for token classification.

    Attributes:
        guid: identifier string of the example.
        words: the sentence's words, in order.
        metainfo: optional header text associated with the sentence, or None.
        labels: optional label strings parallel to ``words`` (one per word);
            None when no labels are attached.
    """
    guid: str
    words: List[str]
    metainfo: Union[str, None] = None
    labels: Union[List[str], None] = None


@dataclass
class InputFeatures:
    """
    Padded, id-level features for one example.

    Attribute names follow the model's input names (``label_ids`` is the
    per-token label field).

    Attributes:
        input_ids: token ids.
        attention_mask: per-token attention flags.
        token_type_ids: optional segment ids.
        label_ids: optional per-token label ids.
        decoder_mask: optional per-token booleans marking positions to decode.
    """

    input_ids: List[int]
    attention_mask: List[int]
    token_type_ids: Union[List[int], None] = None
    label_ids: Union[List[int], None] = None
    decoder_mask: Union[List[bool], None] = None

class ProdDataset(Dataset):
    """Token-classification dataset read from a tab-separated token file, with a feature cache.

    Features come from ``read_examples_from_file`` followed by
    ``convert_examples_to_features``. They are saved with ``torch.save`` so later
    runs with the same file name, tokenizer class and ``max_seq_length`` skip
    tokenization.

    Args:
        data_file: path to the input file.
        tokenizer: tokenizer supplying cls/sep/pad tokens.
        labels: ordered label list; the index of a label is its id.
        model_type: accepted for API compatibility; not used here.
        max_seq_length: length every feature is truncated/padded to.
        overwrite_cache: rebuild features even if a cache file exists.
        cache_dir: directory for the cache file. ``None`` (default) keeps the
            original ``/kaggle/working/`` location when that directory exists.
            Otherwise it uses the directory containing ``data_file``; for
            example on Colab, where ``/kaggle/working/`` is absent and saving
            there would fail.

    NOTE(review): the cache key does not include ``labels``; pass
    ``overwrite_cache=True`` after changing the label set.
    """
    features: List[InputFeatures]
    # Use cross entropy ignore_index as padding label id so that only
    # real label ids contribute to the loss later.
    pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index
    # Historical cache location (Kaggle notebooks).
    default_cache_dir: str = "/kaggle/working/"

    def __init__(
        self,
        data_file: str,
        tokenizer: AutoTokenizer,
        labels: List[str],
        model_type: str,
        max_seq_length: Optional[int] = None,
        overwrite_cache=False,
        cache_dir: Optional[str] = None,
    ):
        if cache_dir is None:
            cache_dir = (
                self.default_cache_dir
                if os.path.isdir(self.default_cache_dir)
                else os.path.dirname(os.path.abspath(data_file))
            )

        # Load data features from cache or dataset file
        fname = os.path.basename(data_file)
        cached_features_file = os.path.join(
            cache_dir,
            "cached_{}_{}_{}".format(
                fname,
                tokenizer.__class__.__name__,
                str(max_seq_length)
            ),
        )

        if os.path.exists(cached_features_file) and not overwrite_cache:
            logger.info(f"Loading features from cached file {cached_features_file}")
            self.features = self._load_cached_features(cached_features_file)
        else:
            logger.info(f"Creating features from dataset file at {data_file}")
            examples = read_examples_from_file(data_file)
            self.features = convert_examples_to_features(
                examples,
                labels,
                max_seq_length,
                tokenizer,
                cls_token=tokenizer.cls_token,
                cls_token_segment_id=0,
                sep_token=tokenizer.sep_token,
                pad_token=tokenizer.pad_token_id,
                pad_token_segment_id=tokenizer.pad_token_type_id,
                pad_token_label_id=self.pad_token_label_id,
            )
            logger.info(f"Saving features into cached file {cached_features_file}")
            os.makedirs(cache_dir, exist_ok=True)
            torch.save(self.features, cached_features_file)

    @staticmethod
    def _load_cached_features(path):
        """Load a feature list previously written by ``torch.save`` in ``__init__``."""
        # The cache holds pickled InputFeatures objects rather than tensors, so the
        # weights_only=True default of recent torch releases refuses to load it.
        # This unpickles arbitrary objects: only use caches this class wrote.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older torch releases have no ``weights_only`` keyword.
            return torch.load(path)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        return self.features[i]

def read_examples_from_file(file_path) -> List[InputExample]:
    """Parse a tab-separated, blank-line-delimited token file into ``InputExample``s.

    Line handling (trailing whitespace is stripped first):
      * a line starting with "#" TAB "passage" is stored verbatim as the
        current ``metainfo``. It is attached to every following sentence until
        the next such line.
      * an empty line closes the current sentence (if it has any words).
      * any other line is split on tabs: column 0 is the word and the last
        column is the label. A single-column line gets the label "O".

    Guids are the strings "1", "2", ... in file order.
    """
    examples: List[InputExample] = []
    pending_words: List[str] = []
    pending_tags: List[str] = []
    passage_header = None

    def emit_pending():
        # Turn the buffered sentence (if any) into an example and reset the buffers.
        nonlocal pending_words, pending_tags
        if not pending_words:
            return
        examples.append(InputExample(
            guid=f"{len(examples) + 1}",
            words=pending_words,
            metainfo=passage_header,
            labels=pending_tags
        ))
        pending_words, pending_tags = [], []

    with open(file_path, encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.rstrip()
            if text.startswith("#\tpassage"):
                passage_header = text
                continue
            if not text:
                emit_pending()
                continue
            columns = text.split("\t")
            pending_words.append(columns[0])
            # Examples could have no label for plain test files
            pending_tags.append(columns[-1] if len(columns) > 1 else "O")

    emit_pending()
    return examples


def convert_examples_to_features(
    examples: List[InputExample],
    label_list: List[str],
    max_seq_length: int,
    tokenizer: AutoTokenizer,
    cls_token="[CLS]",
    cls_token_segment_id=0,
    sep_token="[SEP]",
    pad_token=0,
    pad_token_segment_id=0,
    pad_token_label_id=-100,
    sequence_a_segment_id=0,
    sequence_b_segment_id=1,
    mask_padding_with_zero=True,
    verbose=False
) -> List[InputFeatures]:
    """Convert word-level examples into fixed-length ``InputFeatures``.

    For each example:
      * every word is split with ``tokenizer.tokenize``. The first sub-token
        gets the word's label id and the remaining sub-tokens get
        ``pad_token_label_id``. Words that produce no sub-tokens are dropped
        together with their label.
      * the sub-token sequence is cut to ``max_seq_length - 2`` (with a
        warning), then wrapped as ``[cls_token] ... [sep_token]``.
      * ids, attention mask, segment ids and label ids are padded to
        ``max_seq_length``.
      * ``decoder_mask`` is True wherever the label id differs from
        ``pad_token_label_id``.

    ``sequence_b_segment_id`` is accepted but unused.

    Raises:
        KeyError: if an example label is not in ``label_list``.
    """
    label_to_id = {name: idx for idx, name in enumerate(label_list)}
    # attention_mask values for real tokens / padding.
    real_mark, pad_mark = (1, 0) if mask_padding_with_zero else (0, 1)

    features = []
    for ex_index, example in enumerate(examples):
        if ex_index % 10_000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        pieces = []
        piece_labels = []
        for word, label in zip(example.words, example.labels):
            sub_tokens = tokenizer.tokenize(word)
            if not sub_tokens:
                continue
            pieces.extend(sub_tokens)
            piece_labels.append(label_to_id[label])
            piece_labels.extend([pad_token_label_id] * (len(sub_tokens) - 1))

        budget = max_seq_length - 2  # room left for the cls and sep tokens
        if len(pieces) > budget:
            logger.warning("Sequence length exceed {} (cut).".format(max_seq_length))
            pieces = pieces[:budget]
            piece_labels = piece_labels[:budget]

        tokens = [cls_token, *pieces, sep_token]
        label_ids = [pad_token_label_id, *piece_labels, pad_token_label_id]
        segment_ids = [cls_token_segment_id] + [sequence_a_segment_id] * (len(pieces) + 1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        seq_length = len(input_ids)
        padding_length = max_seq_length - seq_length

        input_mask = [real_mark] * seq_length + [pad_mark] * padding_length
        input_ids += [pad_token] * padding_length
        segment_ids += [pad_token_segment_id] * padding_length
        label_ids += [pad_token_label_id] * padding_length

        decoder_mask = [lid != pad_token_label_id for lid in label_ids]

        if verbose and ex_index < 1:
            shown = slice(None, seq_length)
            logger.info("*** Example ***")
            logger.info("guid: {} (length: {})".format(example.guid, seq_length))
            logger.info("tokens: %s", " ".join(str(x) for x in tokens[shown]))
            logger.info("input_ids: %s", " ".join(str(x) for x in input_ids[shown]))
            logger.info("label_ids: %s", " ".join(str(x) for x in label_ids[shown]))
            logger.info("decode_mask: %s", " ".join(str(x) for x in decoder_mask[shown]))

        features.append(
            InputFeatures(
                input_ids=input_ids,
                attention_mask=input_mask,
                token_type_ids=segment_ids,
                label_ids=label_ids,
                decoder_mask=decoder_mask
            )
        )
    return features


In [None]:
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from tqdm.auto import tqdm, trange

from seqeval.metrics import f1_score, precision_score, recall_score
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler

from transformers import AutoConfig, AutoTokenizer
from transformers.data.data_collator import default_data_collator
from transformers import set_seed


logger = logging.getLogger(__name__)

def _num_training_steps(dataset_len, train_args):
    """Return the number of optimizer updates for ``dataset_len`` training examples.

    Uses ``train_args.max_steps`` when positive. Otherwise it computes
    ceil(len / train_batch_size) batches per epoch, floor-divided by
    ``gradient_accumulation_steps`` (at least 1), times ``num_train_epochs``
    (rounded up). This is intended to match ``transformers.Trainer``'s schedule.
    """
    import math

    if train_args.max_steps > 0:
        return train_args.max_steps
    batches_per_epoch = math.ceil(dataset_len / train_args.train_batch_size)
    updates_per_epoch = max(batches_per_epoch // train_args.gradient_accumulation_steps, 1)
    return math.ceil(updates_per_epoch * train_args.num_train_epochs)


def train(model_args, data_args, train_args):
    """Train, evaluate and/or predict with ``BertForTagging`` according to the parsed arguments.

    Each phase is gated by ``train_args.do_train`` / ``do_eval`` / ``do_predict``:
      * train: fine-tune, then save the model and tokenizer to ``train_args.output_dir``.
      * eval: evaluate on ``data_args.data_dir_dev``; write ``eval_results.txt``
        and ``eval_predictions.txt`` into the output dir.
      * predict: run on ``data_args.data_dir_test``; write ``test_results.txt``
        and ``test_predictions.txt``.

    CRF tagging is not supported here: ``model_args.use_crf`` is forced to False
    (with a warning if it was set).

    Raises:
        ValueError: if training into a non-empty ``output_dir`` without
            ``overwrite_output_dir``.
    """
    if (
        os.path.exists(train_args.output_dir)
        and os.listdir(train_args.output_dir)
        and train_args.do_train
        and not train_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({train_args.output_dir}) already exists and is not empty."
             " Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info("Training/evaluation parameters %s", train_args)

    # Set seed
    set_seed(train_args.seed)

    # Prepare prod-ext task
    # NOTE(review): get_labels and write_predictions are not defined anywhere in
    # this notebook — confirm they are provided before running.
    labels = get_labels(data_args.labels)
    label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)

    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        num_labels=num_labels,
        id2label=label_map,
        label2id={label: i for i, label in enumerate(labels)},
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast,
    )
    if model_args.use_crf:
        logger.warning("use_crf=True is not supported here; falling back to BertForTagging without CRF.")
    # Downstream (IETrainer, optimizer grouping) reads this flag, so keep it consistent.
    model_args.use_crf = False
    model = BertForTagging.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        use_cls=model_args.use_cls
    )

    # Get datasets
    train_dataset = (
        ProdDataset(
            data_file=data_args.data_dir_train,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache
        )
        if train_args.do_train
        else None
    )
    eval_dataset = (
        ProdDataset(
            data_file=data_args.data_dir_dev,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache
        )
        if train_args.do_eval
        else None
    )

    def compute_metrics(predictions, label_ids) -> Dict:
        # Entity-level scores (seqeval) over label strings.
        label_list = [[label_map[x] for x in seq] for seq in label_ids]
        preds_list = [[label_map[x] for x in seq] for seq in predictions]

        return {
            "precision": precision_score(label_list, preds_list),
            "recall": recall_score(label_list, preds_list),
            "f1": f1_score(label_list, preds_list),
        }

    metrics_fn = compute_metrics
    # The scheduler must span exactly the optimizer updates the Trainer performs,
    # so derive it from the configured batch size / accumulation rather than a
    # hard-coded batch size. Eval/predict-only runs have no training set.
    total_steps = (
        _num_training_steps(len(train_dataset), train_args)
        if train_dataset is not None
        else 0
    )
    # 10% linear warm-up.
    train_args.warmup_steps = 0.1 * total_steps

    # Initialize our Trainer
    trainer = IETrainer(
        model=model,
        args=train_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=metrics_fn,
        use_crf=model_args.use_crf,
        optimizers=get_optimizer_grouped_parameters(
            use_crf=model_args.use_crf,
            args=train_args,
            model=model,
            num_training_steps=total_steps),
        epoch=train_args.num_train_epochs
    )

    # Training
    if train_args.do_train:
        trainer.train()
        trainer.save_model()
        tokenizer.save_pretrained(train_args.output_dir)

    # Evaluation
    if train_args.do_eval:
        logger.info("*** Evaluate ***")

        output = trainer.evaluate()
        predictions = output['predictions']
        metrics = output['metrics']

        output_eval_file = os.path.join(train_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key, value in metrics.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

        preds_list = [[label_map[x] for x in seq] for seq in predictions]

        # Save predictions
        write_predictions(
            data_args.data_dir_dev,
            os.path.join(train_args.output_dir, "eval_predictions.txt"),
            preds_list
        )

    # Predict
    if train_args.do_predict:
        test_dataset = ProdDataset(
            data_file=data_args.data_dir_test,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
        )

        output = trainer.predict(test_dataset)

        predictions = output['predictions']
        metrics = output['metrics']

        preds_list = [[label_map[x] for x in seq] for seq in predictions]

        output_test_results_file = os.path.join(train_args.output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            for key, value in metrics.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

        # Save predictions
        write_predictions(
            data_args.data_dir_test,
            os.path.join(train_args.output_dir, "test_predictions.txt"),
            preds_list
        )

def get_optimizer_grouped_parameters(use_crf, args, model, num_training_steps):
    """Build an AdamW optimizer and a linear warm-up/decay scheduler for ``model``.

    Parameters whose name contains "bias" or "LayerNorm.weight" get no weight
    decay; the others use ``args.weight_decay``. With ``use_crf``, all
    parameters of ``model.crf`` form a separate group with learning rate
    ``args.crf_learning_rate``, and parameters whose name contains "crf" are
    excluded from the other two groups.

    Args:
        use_crf: whether ``model`` has a ``crf`` sub-module to treat separately.
        args: training arguments (``learning_rate``, ``adam_epsilon``,
            ``weight_decay``, ``warmup_steps``, and ``crf_learning_rate`` when
            ``use_crf``).
        model: the model whose parameters are optimized.
        num_training_steps: total scheduler length in optimizer steps.

    Returns:
        ``(optimizer, scheduler)``.
    """
    no_decay = ["bias", "LayerNorm.weight"]

    def skips_decay(name):
        return any(nd in name for nd in no_decay)

    if use_crf:
        crf = "crf"
        crf_lr = args.crf_learning_rate
        logger.info(f"Learning rate for CRF: {crf_lr}")
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not skips_decay(n) and crf not in n
                ],
                "weight_decay": args.weight_decay
            },
            {
                "params": list(model.crf.parameters()),
                "weight_decay": args.weight_decay,
                "lr": crf_lr
            },
            {
                # Non-CRF bias/LayerNorm weights. CRF parameters already live in
                # the group above and must not appear twice. The previous test
                # `(not crf not in n)` meant `crf in n` and dropped these parameters.
                "params": [
                    p for n, p in model.named_parameters()
                    if skips_decay(n) and crf not in n
                ],
                "weight_decay": 0.0,
            },
        ]
    else:
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not skips_decay(n)
                ],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if skips_decay(n)
                ],
                "weight_decay": 0.0,
            },
        ]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        eps=args.adam_epsilon
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps
    )

    return optimizer, scheduler


def predict(model_args, predict_args):
    """Tag ``predict_args.input_file`` with a fine-tuned model and write the predictions.

    Loads config/tokenizer/model from ``model_args.model_name_or_path`` and builds
    a ``ProdDataset`` from the input file. Each batch is decoded with
    ``model.decode`` at ``decoder_mask`` positions, and the resulting label
    strings go to ``write_predictions(input_file, output_file, preds)``.

    CRF tagging is not supported here. ``use_crf=True`` logs a warning and
    falls back to ``BertForTagging``, matching ``train``.
    """
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info("Predict parameters %s", predict_args)

    # Prepare prod-ext task
    # NOTE(review): get_labels and write_predictions are not defined anywhere in
    # this notebook — confirm they are provided before running.
    labels = get_labels(predict_args.labels)
    label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        num_labels=num_labels,
        id2label=label_map,
        label2id={label: i for i, label in enumerate(labels)},
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast,
    )
    if model_args.use_crf:
        logger.warning("use_crf=True is not supported here; falling back to BertForTagging without CRF.")
    model = BertForTagging.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        # Must match training: with use_cls=True the classifier takes
        # 2 * hidden_size inputs, so loading without it mismatches weights.
        use_cls=model_args.use_cls
    )

    device = torch.device(
        "cuda"
        if (not predict_args.no_cuda and torch.cuda.is_available())
        else "cpu"
    )
    model = model.to(device)

    # load test dataset
    test_dataset = ProdDataset(
        data_file=predict_args.input_file,
        tokenizer=tokenizer,
        labels=labels,
        model_type=config.model_type,
        max_seq_length=predict_args.max_seq_length,
        overwrite_cache=predict_args.overwrite_cache,
    )

    sampler = SequentialSampler(test_dataset)
    data_loader = DataLoader(
        test_dataset,
        sampler=sampler,
        batch_size=predict_args.batch_size,
        collate_fn=default_data_collator
    )

    logger.info("***** Running Prediction *****")
    logger.info("  Num examples = {}".format(len(data_loader.dataset)))
    logger.info("  Batch size = {}".format(predict_args.batch_size))

    model.eval()

    all_preds = []
    for inputs in tqdm(data_loader, desc="Predicting"):
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)
        with torch.no_grad():
            outputs = model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                token_type_ids=inputs['token_type_ids']
            )
            logits = outputs[0]

        preds = model.decode(logits, inputs['decoder_mask'].bool())
        all_preds += [[label_map[x] for x in seq] for seq in preds]

    write_predictions(
        predict_args.input_file,
        predict_args.output_file,
        all_preds
    )



In [None]:
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import os

from transformers import HfArgumentParser
from transformers import TrainingArguments


@dataclass
class ModelArguments:
    """Options choosing the pretrained checkpoint and the tagging head.

    Parsed by ``HfArgumentParser`` in ``parse_train_args`` / ``parse_predict_args``.
    """

    # The only required option: hub identifier or local checkpoint directory.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier."},
    )
    use_fast: bool = field(
        metadata={"help": "Set this flag to use fast tokenization."},
        default=False,
    )
    cache_dir: Optional[str] = field(
        metadata={"help": "Where do you want to store the pretrained models downloaded from s3"},
        default=None,
    )
    use_crf: bool = field(
        metadata={"help": "Whether using CRF for inference."},
        default=False,
    )
    use_cls: bool = field(
        metadata={"help": "Whether concatenating token representation with [CLS]."},
        default=False,
    )


@dataclass
class ExTrainingArguments(TrainingArguments):
    """``TrainingArguments`` extended with a separate learning rate for CRF parameters."""

    # Used by get_optimizer_grouped_parameters for the model.crf group when use_crf is set.
    crf_learning_rate: float = field(
        metadata={"help": "The initial learning rate of CRF parameters for Adam."},
        default=5e-3,
    )


@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    ``data_dir_train`` / ``data_dir_dev`` / ``data_dir_test`` are paths to the
    individual split files. Any of them left unset (None) is filled in by
    ``__post_init__`` as ``train.txt`` / ``dev.txt`` / ``test.txt`` inside
    ``data_dir``.
    """
    data_dir: str = field(
        metadata={"help": "The input data dir. Should contain the .txt files for a CoNLL-2003-formatted task."}
    )
    data_dir_train: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the training file (CoNLL-2003 format). Defaults to <data_dir>/train.txt."},
    )
    data_dir_dev: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the evaluation (dev) file (CoNLL-2003 format). Defaults to <data_dir>/dev.txt."},
    )
    data_dir_test: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the test file (CoNLL-2003 format). Defaults to <data_dir>/test.txt."},
    )
    labels: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."},
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    evaluate_during_training: bool = field(
        default=False, metadata={"help": "evaluate_during_training"}
    )

    def __post_init__(self):
        # Fall back to the conventional split file names inside data_dir.
        for attr, default_name in (
            ("data_dir_train", "train.txt"),
            ("data_dir_dev", "dev.txt"),
            ("data_dir_test", "test.txt"),
        ):
            if getattr(self, attr) is None:
                setattr(self, attr, os.path.join(self.data_dir, default_name))


@dataclass
class PredictArguments:
    """Options for standalone prediction with ``predict``.

    Only ``input_file`` is required; every other option has a default.
    """

    input_file: str = field(
        metadata={"help": "Path to a file containing sentences to be extracted (can be a single column file without labels)."}
    )
    output_file: str = field(
        metadata={"help": "Path to a file saving the outputs."},
        default="output_file",
    )
    labels: Optional[str] = field(
        metadata={"help": "Path to a file containing all labels."},
        default=None,
    )
    max_seq_length: int = field(
        metadata={"help": "The maximum total input sequence length after tokenization. Sequences longer "
                          "than this will be truncated, sequences shorter will be padded."},
        default=128,
    )
    batch_size: int = field(
        metadata={"help": "Batch size for prediction."},
        default=8,
    )
    overwrite_cache: bool = field(
        metadata={"help": "Overwrite the cached test data."},
        default=False,
    )
    # Forces CPU inference when True.
    no_cuda: bool = field(
        metadata={"help": "Do not use CUDA even when it is available."},
        default=False,
    )


def parse_train_args(args):
    """Parse training options into ``(model_args, data_args, train_args)``.

    ``args`` is either a one-element list holding a ``.json`` config path, or
    an argv-style list of ``--option value`` strings.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, ExTrainingArguments))
    is_json_config = len(args) == 1 and args[0].endswith(".json")
    parsed = (
        parser.parse_json_file(json_file=os.path.abspath(args[0]))
        if is_json_config
        else parser.parse_args_into_dataclasses(args=args)
    )
    model_args, data_args, train_args = parsed
    return model_args, data_args, train_args


def parse_predict_args(args):
    """Parse prediction options into ``(model_args, predict_args)``.

    ``args`` is either a one-element list holding a ``.json`` config path, or
    an argv-style list of ``--option value`` strings.
    """
    parser = HfArgumentParser((ModelArguments, PredictArguments))
    is_json_config = len(args) == 1 and args[0].endswith(".json")
    parsed = (
        parser.parse_json_file(json_file=os.path.abspath(args[0]))
        if is_json_config
        else parser.parse_args_into_dataclasses(args=args)
    )
    model_args, predict_args = parsed
    return model_args, predict_args



In [None]:
!gdown --no-cookies https://drive.google.com/uc?id=1--UFHJY3ldytNTsLcseIMDXEyzFeYEuj
!gdown --no-cookies https://drive.google.com/uc?id=1-oYW7JgHyhFj1HpbYvpHdsK8y-RwF2pW
!gdown --no-cookies https://drive.google.com/uc?id=1-KkDBrlyr3DF91BOTgp-vgCLwMe2JAmX
!gdown --no-cookies https://drive.google.com/uc?id=105JaXg7iAYF-0A7_srvSnaU0pfPE5ChH
!gdown --no-cookies https://drive.google.com/uc?id=1-6QHFiBYVWaNtbgCXdNbVfZ0WwdIB68N
train = "/content/train_optim_concat.txt"
dev = "/content/dev_optim_concat.txt"
test = "/content/test_optim_concat.txt"
prod_labels = "/content/prod_labels.txt"
prod_train = "/content/prod_train.json"

In [None]:
model_args, data_args, train_args = parse_train_args([prod_train])

model_args.model_name_or_path = "jiangg/chembert_cased"
data_args.data_dir = "/content"
data_args.data_dir_train = train
data_args.data_dir_dev = dev
data_args.data_dir_test = test
data_args.labels = prod_labels

train(model_args, data_args, train_args)