In [None]:
# coding=utf-8
import logging
import os
import sys
import pdb
import subprocess

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
from seqeval.metrics import f1_score, precision_score, recall_score
from torch import nn

# from transformers import AutoTokenizer,AutoModelForCausalLM
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from utils_ner import NerDataset, Split, get_labels

logger = logging.getLogger(__name__)

def main():
    # 固定参数
    model_name_or_path = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2/"
    data_dir = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2_for_test"
    output_dir = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2_for_test"
    labels_path = '/home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2_for_test/'+ 'labels.txt' # 或者 "path/to/labels.txt"
    max_seq_length = 384
    overwrite_cache = False
    cache_dir = None    
    use_fast_tokenizer = False
    seed = 42


    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info("Training/evaluation parameters %s", output_dir)

    set_seed(seed)

    labels = get_labels(labels_path)
    label_map = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)


    # labels = get_labels(data_args.labels)
    label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)

    config = AutoConfig.from_pretrained(
        model_name_or_path,
        num_labels=num_labels,
        id2label=label_map,
        label2id={label: i for i, label in enumerate(labels)},
        cache_dir=cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        cache_dir=cache_dir,
        use_fast=use_fast_tokenizer,
    )
    model = AutoModelForTokenClassification.from_pretrained(
        model_name_or_path,
        config=config,
        cache_dir=cache_dir,
    )

    def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[List[int], List[int]]:
        preds = np.argmax(predictions, axis=2)
        batch_size, seq_len = preds.shape

        out_label_list = [[] for _ in range(batch_size)]
        preds_list = [[] for _ in range(batch_size)]

        for i in range(batch_size):
            for j in range(seq_len):
                if label_ids[i, j] != nn.CrossEntropyLoss().ignore_index:
                    out_label_list[i].append(label_map[label_ids[i][j]])
                    preds_list[i].append(label_map[preds[i][j]])

        return preds_list, out_label_list

    # 固定的训练参数
    training_args = TrainingArguments(
        output_dir=output_dir,
        do_predict=True,
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
    )

    # Predict
    test_dataset = NerDataset(
        data_dir=data_dir,
        tokenizer=tokenizer,
        labels=labels,
        model_type=config.model_type,
        max_seq_length=max_seq_length,
        overwrite_cache=overwrite_cache,
        mode=Split.test,
    )

    predictions, label_ids, metrics = trainer.predict(test_dataset)
    preds_list, _ = align_predictions(predictions, label_ids)

    # Save predictions
    # output_test_results_file = os.path.join(output_dir, "test_results.txt")
    # with open(output_test_results_file, "w") as writer:
    #     logger.info("***** Test results *****")
    #     for key, value in metrics.items():
    #         logger.info("  %s = %s", key, value)
    #         writer.write("%s = %s\n" % (key, value))

    output_test_predictions_file = os.path.join(output_dir, "test_predictions.txt")
    with open(output_test_predictions_file, "w") as writer:
        with open(os.path.join(data_dir, "test.txt"), "r") as f:
            example_id = 0
            for line in f:
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    writer.write(line)
                    if not preds_list[example_id]:
                        example_id += 1
                elif preds_list[example_id]:
                    entity_label = preds_list[example_id].pop(0)
                    output_line = line.split()[0] + " " + entity_label + "\n"
                    writer.write(output_line)
                else:
                    logger.warning(
                        "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
                    )

if __name__ == "__main__":
    main()


In [None]:
# coding=utf-8
import logging
import os
import sys
import subprocess

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from seqeval.metrics import f1_score, precision_score, recall_score
from torch import nn

from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from utils_ner import NerDataset, Split, get_labels

logger = logging.getLogger(__name__)

def main(input_text=None):
    # 固定参数
    model_name_or_path = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2/"
    data_dir = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/datasets/for_train/datasets_from_download/NER/lipid/2_LipidCorpus_Normalized.Name"
    output_dir = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2_for_test"
    labels_path = '/home/data/t200404/bioinfo/P_subject/NLP/biobert/datasets/for_train/datasets_from_download/NER/lipid/2_LipidCorpus_Normalized.Name/'+ 'labels.txt' # 或者 "path/to/labels.txt"
    max_seq_length = 384
    overwrite_cache = False
    cache_dir = None
    use_fast_tokenizer = False
    seed = 42

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info("Training/evaluation parameters %s", output_dir)

    set_seed(seed)

    labels = get_labels(labels_path)
    label_map = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)

    config = AutoConfig.from_pretrained(
        model_name_or_path,
        num_labels=num_labels,
        id2label=label_map,
        label2id={label: i for i, label in enumerate(labels)},
        cache_dir=cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        cache_dir=cache_dir,
        use_fast=use_fast_tokenizer,
    )
    model = AutoModelForTokenClassification.from_pretrained(
        model_name_or_path,
        config=config,
        cache_dir=cache_dir,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[List[int], List[int]]:
        preds = np.argmax(predictions, axis=2)
        batch_size, seq_len = preds.shape

        out_label_list = [[] for _ in range(batch_size)]
        preds_list = [[] for _ in range(batch_size)]

        for i in range(batch_size):
            for j in range(seq_len):
                if label_ids[i, j] != nn.CrossEntropyLoss().ignore_index:
                    out_label_list[i].append(label_map[label_ids[i][j]])
                    preds_list[i].append(label_map[preds[i][j]])

        return preds_list, out_label_list

    def predict_text(input_text):
        tokens = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_seq_length).to(device)
        model.eval()
        with torch.no_grad():
            output = model(**tokens)
        predictions = output.logits.cpu().numpy()
        preds_list, _ = align_predictions(predictions, tokens['input_ids'].cpu().numpy())
        return preds_list

    # 固定的训练参数
    training_args = TrainingArguments(
        output_dir=output_dir,
        do_predict=True,
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
    )

    # Predict
    if input_text:
        preds_list = predict_text(input_text)
        print("Predictions:", preds_list)
    else:
        test_dataset = NerDataset(
            data_dir=data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=max_seq_length,
            overwrite_cache=overwrite_cache,
            mode=Split.test,
        )

        predictions, label_ids, metrics = trainer.predict(test_dataset)
        preds_list, _ = align_predictions(predictions, label_ids)

        # Save predictions
        output_test_results_file = os.path.join(output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            logger.info("***** Test results *****")
            for key, value in metrics.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

        output_test_predictions_file = os.path.join(output_dir, "test_predictions.txt")
        with open(output_test_predictions_file, "w") as writer:
            with open(os.path.join(data_dir, "test.txt"), "r") as f:
                example_id = 0
                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)
                        if not preds_list[example_id]:
                            example_id += 1
                    elif preds_list[example_id]:
                        entity_label = preds_list[example_id].pop(0)
                        output_line = line.split()[0] + " " + entity_label + "\n"
                        writer.write(output_line)
                    else:
                        logger.warning(
                            "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
                        )

if __name__ == "__main__":
    input_text = "Thus, changes in plasma  PC(20:1) levels, plasma S1P d18:1 levels, plasma MonCer d18:1 levels or plasma LacCer d18:1 levels were inferred to be disease-induced changes in Alzheimer's disease or DLB"  # 在这里输入你的文本
    main(input_text)


In [13]:
# coding=utf-8
import logging
import os
import sys
import subprocess

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from seqeval.metrics import f1_score, precision_score, recall_score
from torch import nn

from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from utils_ner import NerDataset, Split, get_labels

logger = logging.getLogger(__name__)

@dataclass
class InputExample:
    guid: str
    words: List[str]
    labels: List[str]

class CustomNerDataset(NerDataset):
    def __init__(self, tokenizer, input_text, labels, max_seq_length):
        self.tokenizer = tokenizer
        self.input_text = input_text
        self.labels = labels
        self.max_seq_length = max_seq_length
        self.examples = self._create_examples()

    def _create_examples(self):
        words = self.input_text.split()
        labels = ["O"] * len(words)
        return [InputExample(guid="input_text", words=words, labels=labels)]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        tokens = []
        label_ids = []
        for word, label in zip(example.words, example.labels):
            word_tokens = self.tokenizer.tokenize(word)
            tokens.extend(word_tokens)
            label_ids.extend([self.labels.index(label)] * len(word_tokens))
        if len(tokens) > self.max_seq_length - 2:
            tokens = tokens[: (self.max_seq_length - 2)]
            label_ids = label_ids[: (self.max_seq_length - 2)]
        tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
        label_ids = [self.labels.index("O")] + label_ids + [self.labels.index("O")]
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        attention_mask = [1] * len(input_ids)
        padding_length = self.max_seq_length - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * padding_length
        attention_mask += [0] * padding_length
        label_ids += [self.labels.index("O")] * padding_length

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(label_ids, dtype=torch.long),
        }

def main(input_text=None):
    # 固定参数
    model_name_or_path = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2/"
    data_dir = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/datasets/for_train/datasets_from_download/NER/lipid/2_LipidCorpus_Normalized.Name"
    output_dir = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2_for_test"
    labels_path = '/home/data/t200404/bioinfo/P_subject/NLP/biobert/datasets/for_train/datasets_from_download/NER/lipid/2_LipidCorpus_Normalized.Name/'+ 'labels.txt' # 或者 "path/to/labels.txt"
    max_seq_length = 384
    overwrite_cache = False
    cache_dir = None
    use_fast_tokenizer = False
    seed = 42

    logging.basicConfig(
        format="%(asctime)s - %(levellevelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info("Training/evaluation parameters %s", output_dir)

    set_seed(seed)

    labels = get_labels(labels_path)
    label_map = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)

    config = AutoConfig.from_pretrained(
        model_name_or_path,
        num_labels=num_labels,
        id2label=label_map,
        label2id={label: i for i, label in enumerate(labels)},
        cache_dir=cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        cache_dir=cache_dir,
        use_fast=use_fast_tokenizer,
    )
    model = AutoModelForTokenClassification.from_pretrained(
        model_name_or_path,
        config=config,
        cache_dir=cache_dir,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[List[int], List[int]]:
        preds = np.argmax(predictions, axis=2)
        batch_size, seq_len = preds.shape

        out_label_list = [[] for _ in range(batch_size)]
        preds_list = [[] for _ in range(batch_size)]

        for i in range(batch_size):
            for j in range(seq_len):
                if label_ids[i, j] != nn.CrossEntropyLoss().ignore_index:
                    out_label_list[i].append(label_map[label_ids[i][j]])
                    preds_list[i].append(label_map[preds[i][j]])

        return preds_list, out_label_list

    def predict_text(input_text):
        # 创建临时数据集
        dataset = CustomNerDataset(tokenizer, input_text, labels, max_seq_length)
        predictions, label_ids, metrics = trainer.predict(dataset)
        preds_list, _ = align_predictions(predictions, label_ids)
        return preds_list

    # 固定的训练参数
    training_args = TrainingArguments(
        output_dir=output_dir,
        do_predict=True,
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
    )

    # Predict
    if input_text:
        preds_list = predict_text(input_text)
        print("Predictions:", preds_list)
    else:
        test_dataset = NerDataset(
            data_dir=data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=max_seq_length,
            overwrite_cache=overwrite_cache,
            mode=Split.test,
        )

        predictions, label_ids, metrics = trainer.predict(test_dataset)
        preds_list, _ = align_predictions(predictions, label_ids)

        # Save predictions
        output_test_results_file = os.path.join(output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            logger.info("***** Test results *****")
            for key, value in metrics.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

        output_test_predictions_file = os.path.join(output_dir, "test_predictions.txt")
        with open(output_test_predictions_file, "w") as writer:
            with open(os.path.join(data_dir, "test.txt"), "r") as f:
                example_id = 0
                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)
                        if not preds_list[example_id]:
                            example_id += 1
                    elif preds_list[example_id]:
                        entity_label = preds_list[example_id].pop(0)
                        output_line = line.split()[0] + " " + entity_label + "\n"
                        writer.write(output_line)
                    else:
                        logger.warning(
                            "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
                        )

if __name__ == "__main__":
    input_text = "Thus, changes in plasma  PC(20:1) levels, plasma S1P d18:1 levels, plasma MonCer d18:1 levels or plasma LacCer d18:1 levels were inferred to be disease-induced changes in Alzheimer's disease or DLB"  # 在这里输入你的文本
    main(input_text)


07/18/2024 16:03:33 - INFO - __main__ -   Training/evaluation parameters /home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2_for_test


dataloader_config = DataLoaderConfiguration(dispatch_batches=None)


Predictions: [['O', 'O', 'O', 'O', 'O', 'O', 'B-lipid', 'O', 'I-lipid', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-lipid', 'I-lipid', 'I-lipid', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-lipid', 'B-lipid', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-disease', 'I-disease', 'I-disease', 'I-disease', 'I-disease', 'I-disease', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-lipid', 'B-lipid', 'O', 'I-disease', 'O', 'I-lipid', 'O', 'O', 'O', 'I-lipid', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-disease', 'I-disease', 'I-disease', 'O', 'I-disease', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-lipid', 'B-lipid', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-lipid', 'B-lipid', 'I-disease', 'B-lipid', 'O', 'I-lipid', 'I-disease', 'I-disease', 'I-dis

In [14]:
# coding=utf-8
import logging
import os
import sys
import subprocess

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from seqeval.metrics import f1_score, precision_score, recall_score
from torch import nn

from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from utils_ner import NerDataset, Split, get_labels

logger = logging.getLogger(__name__)

@dataclass
class InputExample:
    guid: str
    words: List[str]
    labels: List[str]

class CustomNerDataset(NerDataset):
    def __init__(self, tokenizer, input_text, labels, max_seq_length):
        self.tokenizer = tokenizer
        self.input_text = input_text
        self.labels = labels
        self.max_seq_length = max_seq_length
        self.examples = self._create_examples()

    def _create_examples(self):
        words = self.input_text.split()
        labels = ["O"] * len(words)
        return [InputExample(guid="input_text", words=words, labels=labels)]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        tokens = []
        label_ids = []
        for word, label in zip(example.words, example.labels):
            word_tokens = self.tokenizer.tokenize(word)
            tokens.extend(word_tokens)
            label_ids.extend([self.labels.index(label)] * len(word_tokens))
        if len(tokens) > self.max_seq_length - 2:
            tokens = tokens[: (self.max_seq_length - 2)]
            label_ids = label_ids[: (self.max_seq_length - 2)]
        tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
        label_ids = [self.labels.index("O")] + label_ids + [self.labels.index("O")]
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        attention_mask = [1] * len(input_ids)
        padding_length = self.max_seq_length - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * padding_length
        attention_mask += [0] * padding_length
        label_ids += [self.labels.index("O")] * padding_length

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(label_ids, dtype=torch.long),
        }

def main(input_text=None):
    # 固定参数
    model_name_or_path = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2/"
    data_dir = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/datasets/for_train/datasets_from_download/NER/lipid/2_LipidCorpus_Normalized.Name"
    output_dir = "/home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2_for_test"
    labels_path = '/home/data/t200404/bioinfo/P_subject/NLP/biobert/datasets/for_train/datasets_from_download/NER/lipid/2_LipidCorpus_Normalized.Name/'+ 'labels.txt' # 或者 "path/to/labels.txt"
    max_seq_length = 384
    overwrite_cache = False
    cache_dir = None
    use_fast_tokenizer = False
    seed = 42

    logging.basicConfig(
        format="%(asctime)s - %(levellevelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info("Training/evaluation parameters %s", output_dir)

    set_seed(seed)

    labels = get_labels(labels_path)
    label_map = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)

    config = AutoConfig.from_pretrained(
        model_name_or_path,
        num_labels=num_labels,
        id2label=label_map,
        label2id={label: i for i, label in enumerate(labels)},
        cache_dir=cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        cache_dir=cache_dir,
        use_fast=use_fast_tokenizer,
    )
    model = AutoModelForTokenClassification.from_pretrained(
        model_name_or_path,
        config=config,
        cache_dir=cache_dir,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[List[int], List[int]]:
        preds = np.argmax(predictions, axis=2)
        batch_size, seq_len = preds.shape

        out_label_list = [[] for _ in range(batch_size)]
        preds_list = [[] for _ in range(batch_size)]

        for i in range(batch_size):
            for j in range(seq_len):
                if label_ids[i, j] != nn.CrossEntropyLoss().ignore_index:
                    out_label_list[i].append(label_map[label_ids[i][j]])
                    preds_list[i].append(label_map[preds[i][j]])

        return preds_list, out_label_list

    def predict_text(input_text):
        # 创建临时数据集
        dataset = CustomNerDataset(tokenizer, input_text, labels, max_seq_length)
        predictions, label_ids, metrics = trainer.predict(dataset)
        preds_list, _ = align_predictions(predictions, label_ids)
        return preds_list

    # 固定的训练参数
    training_args = TrainingArguments(
        output_dir=output_dir,
        do_predict=True,
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
    )

    # Predict
    if input_text:
        preds_list = predict_text(input_text)
        print("Predictions:", preds_list)
    else:
        test_dataset = NerDataset(
            data_dir=data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=max_seq_length,
            overwrite_cache=overwrite_cache,
            mode=Split.test,
        )

        predictions, label_ids, metrics = trainer.predict(test_dataset)
        preds_list, _ = align_predictions(predictions, label_ids)

        # Save predictions
        output_test_results_file = os.path.join(output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            logger.info("***** Test results *****")
            for key, value in metrics.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

        output_test_predictions_file = os.path.join(output_dir, "test_predictions.txt")
        with open(output_test_predictions_file, "w") as writer:
            with open(os.path.join(data_dir, "test.txt"), "r") as f:
                example_id = 0
                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)
                        if not preds_list[example_id]:
                            example_id += 1
                    elif preds_list[example_id]:
                        entity_label = preds_list[example_id].pop(0)
                        output_line = line.split()[0] + " " + entity_label + "\n"
                        writer.write(output_line)
                    else:
                        logger.warning(
                            "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
                        )

if __name__ == "__main__":
    input_text = "Thus, changes in plasma  PC(20:1) levels, plasma S1P d18:1 levels, plasma MonCer d18:1 levels or plasma LacCer d18:1 levels were inferred to be disease-induced changes in Alzheimer's disease or DLB"
    main(input_text)


07/18/2024 16:06:19 - INFO - __main__ -   Training/evaluation parameters /home/data/t200404/bioinfo/P_subject/NLP/biobert/biobertModelWarehouse/model_from_trained/NER/4_combine_lipid_disease_ture_combine_2_for_test
dataloader_config = DataLoaderConfiguration(dispatch_batches=None)


Predictions: [['O', 'O', 'O', 'O', 'O', 'O', 'B-lipid', 'O', 'I-lipid', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-lipid', 'I-lipid', 'I-lipid', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-lipid', 'B-lipid', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-disease', 'I-disease', 'I-disease', 'I-disease', 'I-disease', 'I-disease', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-lipid', 'B-lipid', 'O', 'I-disease', 'O', 'I-lipid', 'O', 'O', 'O', 'I-lipid', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-disease', 'I-disease', 'I-disease', 'O', 'I-disease', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-lipid', 'B-lipid', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-lipid', 'B-lipid', 'I-disease', 'B-lipid', 'O', 'I-lipid', 'I-disease', 'I-disease', 'I-dis