# Автоматическое резюмирование медицинских публикаций: извлечение результатов и выводов

Импорт библиотек

In [2]:
import torch
print(torch.__version__)

2.1.0+cu118


In [3]:
import numpy as np
print(np.__version__)

1.23.5


In [4]:
import accelerate
print(accelerate.__version__)

1.8.1


In [43]:
import os
import zipfile
import pandas as pd
import re
from sklearn.model_selection import train_test_split
from datasets import Dataset
from rouge_score import rouge_scorer
import random

In [9]:
from transformers import LEDTokenizer, LEDForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq

Датасет SUMPUBMED содержит ~33 000 научных медицинских статей, собранных с платформы PubMed / BioMed Central (BMC). Предназначен для задачи абстрактного резюмирования — генерации кратких, содержательных аннотаций на основе полного текста статьи.

In [8]:
!kaggle datasets download -d chandrasekhardcs/sumpubmed-dataset

sumpubmed-dataset.zip: Skipping, found more recently modified local copy (use --force to force download)


In [9]:
zip_path = "./sumpubmed-dataset.zip"
extract_to = "./sumpubmed_dataset"

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)

In [17]:
df = pd.read_csv("./sumpubmed_dataset/SUMPUBMED.csv")
df.head()

Unnamed: 0.1,Unnamed: 0,line_text,filename_text,text,shorter_abstract,abstract
0,0,BACKGROUND\nthe skeleton is a multifunctional ...,700,BACKGROUND\nthe skeleton is a multifunctional ...,"during mouse development, obif is initially ob...",BACKGROUND\nwhile several cell types are known...
1,1,BACKGROUND\nthe exact interactions of the sing...,10862,BACKGROUND\nthe exact interactions of the sing...,elective analysis of the effects of this drug ...,BACKGROUND\nglucocorticoids have been proven t...
2,2,BACKGROUND\nultra-high-throughput sequencing i...,21467,BACKGROUND\nultra-high-throughput sequencing i...,the processing and statistical analysis of suc...,BACKGROUND\nsolexa/illumina short-read ultra-h...
3,3,BACKGROUND\nlignin is a phenolic heteropolymer...,26308,BACKGROUND\nlignin is a phenolic heteropolymer...,many cad/cad-like genes do not seem to be asso...,BACKGROUND\ncinnamyl alcohol dehydrogenase pr...
4,4,BACKGROUND\nidentification of genetic variants...,9879,BACKGROUND\nidentification of genetic variants...,the frequency of lof variants differed greatly...,"BACKGROUND\nover the last few years, continuou..."


Модель LED (Longformer-Encoder-Decoder), которая поддерживает обработку длинных статей (до 16к токенов)

In [7]:
model_name = "allenai/led-base-16384"
tokenizer = LEDTokenizer.from_pretrained(model_name)
model = LEDForConditionalGeneration.from_pretrained(model_name, gradient_checkpointing=True, use_safetensors=True).to("cuda")
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [10]:
special_tokens = ["<results>", "<conclusions>"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
model.resize_token_embeddings(len(tokenizer))

Embedding(50267, 768, padding_idx=1)

Подготовка данных к обучению

In [18]:
df.drop(columns=["Unnamed: 0", "line_text", "filename_text", "shorter_abstract"], inplace=True)

In [19]:
def extract_results_conclusions(text):
    pattern = r"(BACKGROUND.*?)(RESULTS.*?)(CONCLUSIONS.*?)(?=\n\n|$)"
    matches = re.search(pattern, text, re.DOTALL)
    cleaned = "".join([matches.group(2), matches.group(3)])
    return cleaned

In [20]:
def remove_section_headers(text):
    cleaned = re.sub(r"\b(?:RESULTS|CONCLUSIONS)\b[:\s]*", "", text, flags=re.IGNORECASE)
    cleaned = re.sub(r"\s*\n\s*", " ", cleaned)
    cleaned = re.sub(r"\s{2,}", " ", cleaned).strip()
    return cleaned

In [21]:
def preprocess_abstract(abstract):
    processed_abstract = (
        "<results> " + re.search(r"RESULTS(.*?)CONCLUSIONS", abstract, re.DOTALL).group(1).strip() +
        " <conclusions> " + abstract.split("CONCLUSIONS")[1].strip()
    )
    return processed_abstract

In [22]:
df['text'] = df['text'].apply(extract_results_conclusions)
df['abstract'] = df['abstract'].apply(extract_results_conclusions)

df['text'] = df['text'].apply(remove_section_headers)
df['abstract'] = df['abstract'].apply(preprocess_abstract)

Оставляем тексты до 4096 токенов и абстракты до 256 токенов

In [23]:
def checker(processed_text, processed_abstract):
    if len(tokenizer.tokenize(processed_text)) <= 4096 and len(tokenizer.tokenize(processed_abstract)) <= 256:
        return True
    return False

In [24]:
def preprocess_and_filter(df):
    to_drop = []
    
    for i in range(len(df)):
        if not checker(df.loc[i, 'text'], df.loc[i, 'abstract']):
            to_drop.append(i)
            
    df = df.drop(to_drop)
    return df.reset_index(drop=True)

In [25]:
df = preprocess_and_filter(df)
print(len(df))

11807


In [26]:
def create_dataset(df):
    return Dataset.from_dict({
        "input_text": df['text'].tolist(),
        "target_text": df['abstract'].tolist()
    })

In [27]:
train_df, val_df = train_test_split(df, test_size=0.1)

In [28]:
train_dataset = create_dataset(train_df)
val_dataset = create_dataset(val_df)

In [29]:
def tokenize_function(examples):
    inputs = tokenizer(
        examples["input_text"],
        max_length=4096,
        truncation=True,
        padding="max_length"
    )
    outputs = tokenizer(
        examples["target_text"],
        max_length=256,
        truncation=True,
        padding="max_length"
    )
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": outputs["input_ids"]
    }

In [30]:
tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_val = val_dataset.map(tokenize_function, batched=True)

Map: 100%|██████████| 10626/10626 [03:07<00:00, 56.72 examples/s]
Map: 100%|██████████| 1181/1181 [00:21<00:00, 53.81 examples/s]


Обучение модели

In [24]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./led_summarization",  
    eval_strategy="steps",          
    eval_steps=500,                     
    save_steps=500,                       
    logging_steps=100,                    
    learning_rate=3e-5,                     
    per_device_train_batch_size=2,        
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,         
    warmup_steps=500,                    
    num_train_epochs=3,                     
    predict_with_generate=True,           
    fp16=True,                             
    load_best_model_at_end=True,          
    metric_for_best_model="eval_loss",      
    greater_is_better=False,
    report_to="tensorboard",               
)

In [25]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    processing_class=tokenizer
)

In [26]:
trainer.train()

  3%|▎         | 100/3987 [10:45<6:57:45,  6.45s/it]

{'loss': 5.8725, 'grad_norm': 17.915178298950195, 'learning_rate': 5.82e-06, 'epoch': 0.08}


  5%|▌         | 200/3987 [21:30<6:47:02,  6.45s/it]

{'loss': 3.5465, 'grad_norm': 12.345606803894043, 'learning_rate': 1.1760000000000001e-05, 'epoch': 0.15}


  8%|▊         | 300/3987 [32:14<6:36:23,  6.45s/it]

{'loss': 2.5325, 'grad_norm': 5.908603191375732, 'learning_rate': 1.776e-05, 'epoch': 0.23}


 10%|█         | 400/3987 [42:59<6:25:37,  6.45s/it]

{'loss': 2.0477, 'grad_norm': 2.879793643951416, 'learning_rate': 2.3760000000000003e-05, 'epoch': 0.3}


 13%|█▎        | 500/3987 [53:44<6:14:47,  6.45s/it]

{'loss': 1.9461, 'grad_norm': 2.8475658893585205, 'learning_rate': 2.976e-05, 'epoch': 0.38}



  0%|          | 0/591 [00:00<?, ?it/s][A
  0%|          | 2/591 [00:00<01:39,  5.90it/s][A
  1%|          | 3/591 [00:00<02:22,  4.12it/s][A
  1%|          | 4/591 [00:01<02:45,  3.56it/s][A
  1%|          | 5/591 [00:01<02:57,  3.30it/s][A
  1%|          | 6/591 [00:01<03:05,  3.15it/s][A
  1%|          | 7/591 [00:02<03:11,  3.06it/s][A
  1%|▏         | 8/591 [00:02<03:13,  3.01it/s][A
  2%|▏         | 9/591 [00:02<03:15,  2.97it/s][A
  2%|▏         | 10/591 [00:03<03:17,  2.94it/s][A
  2%|▏         | 11/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 12/591 [00:03<03:18,  2.91it/s][A
  2%|▏         | 13/591 [00:04<03:19,  2.90it/s][A
  2%|▏         | 14/591 [00:04<03:19,  2.90it/s][A
  3%|▎         | 15/591 [00:04<03:18,  2.90it/s][A
  3%|▎         | 16/591 [00:05<03:18,  2.89it/s][A
  3%|▎         | 17/591 [00:05<03:18,  2.89it/s][A
  3%|▎         | 18/591 [00:05<03:17,  2.90it/s][A
  3%|▎         | 19/591 [00:06<03:17,  2.90it/s][A
  3%|▎         | 20/591 [00:

{'eval_loss': 1.8689097166061401, 'eval_runtime': 204.347, 'eval_samples_per_second': 5.779, 'eval_steps_per_second': 2.892, 'epoch': 0.38}


 15%|█▌        | 600/3987 [1:08:27<6:04:01,  6.45s/it]

{'loss': 1.9092, 'grad_norm': 3.3743855953216553, 'learning_rate': 2.9174075136220246e-05, 'epoch': 0.45}


 18%|█▊        | 700/3987 [1:19:12<5:53:01,  6.44s/it]

{'loss': 1.8978, 'grad_norm': 2.475923776626587, 'learning_rate': 2.831373673644967e-05, 'epoch': 0.53}


 20%|██        | 800/3987 [1:29:57<5:42:21,  6.45s/it]

{'loss': 1.8639, 'grad_norm': 3.261599063873291, 'learning_rate': 2.7453398336679095e-05, 'epoch': 0.6}


 23%|██▎       | 900/3987 [1:40:42<5:31:39,  6.45s/it]

{'loss': 1.8311, 'grad_norm': 2.476743221282959, 'learning_rate': 2.659305993690852e-05, 'epoch': 0.68}


 25%|██▌       | 1000/3987 [1:51:26<5:20:50,  6.44s/it]

{'loss': 1.8085, 'grad_norm': 2.4143636226654053, 'learning_rate': 2.5732721537137943e-05, 'epoch': 0.75}



  0%|          | 0/591 [00:00<?, ?it/s][A
  0%|          | 2/591 [00:00<01:40,  5.88it/s][A
  1%|          | 3/591 [00:00<02:22,  4.13it/s][A
  1%|          | 4/591 [00:01<02:45,  3.55it/s][A
  1%|          | 5/591 [00:01<02:57,  3.30it/s][A
  1%|          | 6/591 [00:01<03:05,  3.15it/s][A
  1%|          | 7/591 [00:02<03:10,  3.06it/s][A
  1%|▏         | 8/591 [00:02<03:14,  3.00it/s][A
  2%|▏         | 9/591 [00:02<03:16,  2.97it/s][A
  2%|▏         | 10/591 [00:03<03:17,  2.94it/s][A
  2%|▏         | 11/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 12/591 [00:03<03:18,  2.91it/s][A
  2%|▏         | 13/591 [00:04<03:19,  2.90it/s][A
  2%|▏         | 14/591 [00:04<03:19,  2.90it/s][A
  3%|▎         | 15/591 [00:04<03:19,  2.89it/s][A
  3%|▎         | 16/591 [00:05<03:18,  2.89it/s][A
  3%|▎         | 17/591 [00:05<03:18,  2.90it/s][A
  3%|▎         | 18/591 [00:05<03:17,  2.90it/s][A
  3%|▎         | 19/591 [00:06<03:17,  2.89it/s][A
  3%|▎         | 20/591 [00:

{'eval_loss': 1.7816846370697021, 'eval_runtime': 204.3857, 'eval_samples_per_second': 5.778, 'eval_steps_per_second': 2.892, 'epoch': 0.75}


 28%|██▊       | 1100/3987 [2:06:09<5:10:17,  6.45s/it] 

{'loss': 1.7939, 'grad_norm': 2.500584125518799, 'learning_rate': 2.4872383137367364e-05, 'epoch': 0.83}


 30%|███       | 1200/3987 [2:16:55<4:59:28,  6.45s/it]

{'loss': 1.7863, 'grad_norm': 2.6468656063079834, 'learning_rate': 2.401204473759679e-05, 'epoch': 0.9}


 33%|███▎      | 1300/3987 [2:27:40<4:48:40,  6.45s/it]

{'loss': 1.7708, 'grad_norm': 2.524620532989502, 'learning_rate': 2.315170633782621e-05, 'epoch': 0.98}


 35%|███▌      | 1400/3987 [2:38:19<4:38:01,  6.45s/it]

{'loss': 1.7069, 'grad_norm': 2.3090741634368896, 'learning_rate': 2.2291367938055637e-05, 'epoch': 1.05}


 38%|███▊      | 1500/3987 [2:49:04<4:27:19,  6.45s/it]

{'loss': 1.6759, 'grad_norm': 2.3924901485443115, 'learning_rate': 2.143102953828506e-05, 'epoch': 1.13}



  0%|          | 0/591 [00:00<?, ?it/s][A
  0%|          | 2/591 [00:00<01:40,  5.88it/s][A
  1%|          | 3/591 [00:00<02:22,  4.13it/s][A
  1%|          | 4/591 [00:01<02:45,  3.56it/s][A
  1%|          | 5/591 [00:01<02:57,  3.30it/s][A
  1%|          | 6/591 [00:01<03:05,  3.15it/s][A
  1%|          | 7/591 [00:02<03:10,  3.06it/s][A
  1%|▏         | 8/591 [00:02<03:14,  3.00it/s][A
  2%|▏         | 9/591 [00:02<03:16,  2.97it/s][A
  2%|▏         | 10/591 [00:03<03:17,  2.94it/s][A
  2%|▏         | 11/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 12/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 13/591 [00:04<03:18,  2.91it/s][A
  2%|▏         | 14/591 [00:04<03:19,  2.90it/s][A
  3%|▎         | 15/591 [00:04<03:18,  2.90it/s][A
  3%|▎         | 16/591 [00:05<03:18,  2.90it/s][A
  3%|▎         | 17/591 [00:05<03:18,  2.90it/s][A
  3%|▎         | 18/591 [00:05<03:17,  2.90it/s][A
  3%|▎         | 19/591 [00:06<03:17,  2.90it/s][A
  3%|▎         | 20/591 [00:

{'eval_loss': 1.7548778057098389, 'eval_runtime': 204.2668, 'eval_samples_per_second': 5.782, 'eval_steps_per_second': 2.893, 'epoch': 1.13}


 40%|████      | 1600/3987 [3:03:47<4:16:37,  6.45s/it] 

{'loss': 1.7159, 'grad_norm': 2.5115864276885986, 'learning_rate': 2.0570691138514483e-05, 'epoch': 1.2}


 43%|████▎     | 1700/3987 [3:14:32<4:05:43,  6.45s/it]

{'loss': 1.6748, 'grad_norm': 2.550893545150757, 'learning_rate': 1.9710352738743904e-05, 'epoch': 1.28}


 45%|████▌     | 1800/3987 [3:25:17<3:55:00,  6.45s/it]

{'loss': 1.6754, 'grad_norm': 2.282221555709839, 'learning_rate': 1.885001433897333e-05, 'epoch': 1.35}


 48%|████▊     | 1900/3987 [3:36:01<3:44:11,  6.45s/it]

{'loss': 1.6881, 'grad_norm': 2.3332366943359375, 'learning_rate': 1.7989675939202756e-05, 'epoch': 1.43}


 50%|█████     | 2000/3987 [3:46:46<3:33:29,  6.45s/it]

{'loss': 1.675, 'grad_norm': 2.2867679595947266, 'learning_rate': 1.7129337539432177e-05, 'epoch': 1.51}



  0%|          | 0/591 [00:00<?, ?it/s][A
  0%|          | 2/591 [00:00<01:40,  5.88it/s][A
  1%|          | 3/591 [00:00<02:22,  4.13it/s][A
  1%|          | 4/591 [00:01<02:45,  3.55it/s][A
  1%|          | 5/591 [00:01<02:58,  3.29it/s][A
  1%|          | 6/591 [00:01<03:05,  3.15it/s][A
  1%|          | 7/591 [00:02<03:10,  3.06it/s][A
  1%|▏         | 8/591 [00:02<03:14,  3.00it/s][A
  2%|▏         | 9/591 [00:02<03:15,  2.97it/s][A
  2%|▏         | 10/591 [00:03<03:17,  2.94it/s][A
  2%|▏         | 11/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 12/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 13/591 [00:04<03:18,  2.91it/s][A
  2%|▏         | 14/591 [00:04<03:19,  2.90it/s][A
  3%|▎         | 15/591 [00:04<03:19,  2.89it/s][A
  3%|▎         | 16/591 [00:05<03:18,  2.90it/s][A
  3%|▎         | 17/591 [00:05<03:18,  2.89it/s][A
  3%|▎         | 18/591 [00:05<03:18,  2.89it/s][A
  3%|▎         | 19/591 [00:06<03:17,  2.90it/s][A
  3%|▎         | 20/591 [00:

{'eval_loss': 1.7262248992919922, 'eval_runtime': 204.3277, 'eval_samples_per_second': 5.78, 'eval_steps_per_second': 2.892, 'epoch': 1.51}


 53%|█████▎    | 2100/3987 [4:01:29<3:22:52,  6.45s/it] 

{'loss': 1.6459, 'grad_norm': 2.3459630012512207, 'learning_rate': 1.62689991396616e-05, 'epoch': 1.58}


 55%|█████▌    | 2200/3987 [4:12:14<3:12:04,  6.45s/it]

{'loss': 1.6732, 'grad_norm': 2.169931173324585, 'learning_rate': 1.5408660739891022e-05, 'epoch': 1.66}


 58%|█████▊    | 2300/3987 [4:22:58<3:01:21,  6.45s/it]

{'loss': 1.6541, 'grad_norm': 2.369981288909912, 'learning_rate': 1.4548322340120447e-05, 'epoch': 1.73}


 60%|██████    | 2400/3987 [4:33:43<2:50:33,  6.45s/it]

{'loss': 1.6687, 'grad_norm': 2.4304423332214355, 'learning_rate': 1.3687983940349873e-05, 'epoch': 1.81}


 63%|██████▎   | 2500/3987 [4:44:28<2:39:49,  6.45s/it]

{'loss': 1.661, 'grad_norm': 2.2869393825531006, 'learning_rate': 1.2827645540579295e-05, 'epoch': 1.88}



  0%|          | 0/591 [00:00<?, ?it/s][A
  0%|          | 2/591 [00:00<01:39,  5.89it/s][A
  1%|          | 3/591 [00:00<02:22,  4.13it/s][A
  1%|          | 4/591 [00:01<02:44,  3.56it/s][A
  1%|          | 5/591 [00:01<02:57,  3.30it/s][A
  1%|          | 6/591 [00:01<03:05,  3.15it/s][A
  1%|          | 7/591 [00:02<03:10,  3.06it/s][A
  1%|▏         | 8/591 [00:02<03:14,  3.00it/s][A
  2%|▏         | 9/591 [00:02<03:16,  2.97it/s][A
  2%|▏         | 10/591 [00:03<03:17,  2.94it/s][A
  2%|▏         | 11/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 12/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 13/591 [00:04<03:19,  2.90it/s][A
  2%|▏         | 14/591 [00:04<03:19,  2.90it/s][A
  3%|▎         | 15/591 [00:04<03:19,  2.89it/s][A
  3%|▎         | 16/591 [00:05<03:18,  2.89it/s][A
  3%|▎         | 17/591 [00:05<03:18,  2.89it/s][A
  3%|▎         | 18/591 [00:05<03:18,  2.89it/s][A
  3%|▎         | 19/591 [00:06<03:17,  2.90it/s][A
  3%|▎         | 20/591 [00:

{'eval_loss': 1.71266770362854, 'eval_runtime': 204.3245, 'eval_samples_per_second': 5.78, 'eval_steps_per_second': 2.892, 'epoch': 1.88}


 65%|██████▌   | 2600/3987 [4:59:11<2:29:02,  6.45s/it] 

{'loss': 1.6726, 'grad_norm': 2.223902702331543, 'learning_rate': 1.1967307140808718e-05, 'epoch': 1.96}


 68%|██████▊   | 2700/3987 [5:09:50<2:18:10,  6.44s/it]

{'loss': 1.6268, 'grad_norm': 2.4879283905029297, 'learning_rate': 1.1106968741038142e-05, 'epoch': 2.03}


 70%|███████   | 2800/3987 [5:20:35<2:07:25,  6.44s/it]

{'loss': 1.5503, 'grad_norm': 2.3781542778015137, 'learning_rate': 1.0246630341267565e-05, 'epoch': 2.11}


 73%|███████▎  | 2900/3987 [5:31:20<1:57:20,  6.48s/it]

{'loss': 1.5801, 'grad_norm': 2.354300022125244, 'learning_rate': 9.38629194149699e-06, 'epoch': 2.18}


 75%|███████▌  | 3000/3987 [5:42:04<1:46:01,  6.44s/it]

{'loss': 1.5842, 'grad_norm': 2.443019390106201, 'learning_rate': 8.525953541726412e-06, 'epoch': 2.26}



  0%|          | 0/591 [00:00<?, ?it/s][A
  0%|          | 2/591 [00:00<01:40,  5.89it/s][A
  1%|          | 3/591 [00:00<02:22,  4.14it/s][A
  1%|          | 4/591 [00:01<02:44,  3.56it/s][A
  1%|          | 5/591 [00:01<02:57,  3.29it/s][A
  1%|          | 6/591 [00:01<03:05,  3.16it/s][A
  1%|          | 7/591 [00:02<03:10,  3.07it/s][A
  1%|▏         | 8/591 [00:02<03:13,  3.01it/s][A
  2%|▏         | 9/591 [00:02<03:15,  2.97it/s][A
  2%|▏         | 10/591 [00:03<03:17,  2.95it/s][A
  2%|▏         | 11/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 12/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 13/591 [00:04<03:18,  2.91it/s][A
  2%|▏         | 14/591 [00:04<03:18,  2.90it/s][A
  3%|▎         | 15/591 [00:04<03:18,  2.90it/s][A
  3%|▎         | 16/591 [00:05<03:18,  2.90it/s][A
  3%|▎         | 17/591 [00:05<03:18,  2.90it/s][A
  3%|▎         | 18/591 [00:05<03:17,  2.90it/s][A
  3%|▎         | 19/591 [00:06<03:17,  2.90it/s][A
  3%|▎         | 20/591 [00:

{'eval_loss': 1.7099876403808594, 'eval_runtime': 203.9146, 'eval_samples_per_second': 5.792, 'eval_steps_per_second': 2.898, 'epoch': 2.26}


 78%|███████▊  | 3100/3987 [5:56:47<1:35:20,  6.45s/it] 

{'loss': 1.589, 'grad_norm': 2.580486297607422, 'learning_rate': 7.665615141955837e-06, 'epoch': 2.33}


 80%|████████  | 3200/3987 [6:07:32<1:24:33,  6.45s/it]

{'loss': 1.5565, 'grad_norm': 2.5114998817443848, 'learning_rate': 6.805276742185259e-06, 'epoch': 2.41}


 83%|████████▎ | 3300/3987 [6:18:16<1:13:49,  6.45s/it]

{'loss': 1.5692, 'grad_norm': 2.3136117458343506, 'learning_rate': 5.944938342414684e-06, 'epoch': 2.48}


 85%|████████▌ | 3400/3987 [6:29:01<1:03:02,  6.44s/it]

{'loss': 1.5473, 'grad_norm': 2.5220417976379395, 'learning_rate': 5.084599942644107e-06, 'epoch': 2.56}


 88%|████████▊ | 3500/3987 [6:39:46<52:20,  6.45s/it]  

{'loss': 1.5862, 'grad_norm': 2.3528950214385986, 'learning_rate': 4.224261542873531e-06, 'epoch': 2.63}



  0%|          | 0/591 [00:00<?, ?it/s][A
  0%|          | 2/591 [00:00<01:40,  5.88it/s][A
  1%|          | 3/591 [00:00<02:22,  4.13it/s][A
  1%|          | 4/591 [00:01<02:45,  3.56it/s][A
  1%|          | 5/591 [00:01<02:57,  3.30it/s][A
  1%|          | 6/591 [00:01<03:05,  3.16it/s][A
  1%|          | 7/591 [00:02<03:10,  3.06it/s][A
  1%|▏         | 8/591 [00:02<03:13,  3.01it/s][A
  2%|▏         | 9/591 [00:02<03:16,  2.97it/s][A
  2%|▏         | 10/591 [00:03<03:17,  2.94it/s][A
  2%|▏         | 11/591 [00:03<03:18,  2.93it/s][A
  2%|▏         | 12/591 [00:03<03:18,  2.92it/s][A
  2%|▏         | 13/591 [00:04<03:18,  2.91it/s][A
  2%|▏         | 14/591 [00:04<03:18,  2.90it/s][A
  3%|▎         | 15/591 [00:04<03:18,  2.90it/s][A
  3%|▎         | 16/591 [00:05<03:18,  2.90it/s][A
  3%|▎         | 17/591 [00:05<03:18,  2.90it/s][A
  3%|▎         | 18/591 [00:05<03:17,  2.90it/s][A
  3%|▎         | 19/591 [00:06<03:17,  2.90it/s][A
  3%|▎         | 20/591 [00:

{'eval_loss': 1.701791763305664, 'eval_runtime': 204.202, 'eval_samples_per_second': 5.783, 'eval_steps_per_second': 2.894, 'epoch': 2.63}


 90%|█████████ | 3600/3987 [6:54:28<41:34,  6.44s/it]   

{'loss': 1.5607, 'grad_norm': 2.10062575340271, 'learning_rate': 3.3639231431029538e-06, 'epoch': 2.71}


 93%|█████████▎| 3700/3987 [7:05:13<30:49,  6.44s/it]

{'loss': 1.5792, 'grad_norm': 2.513568878173828, 'learning_rate': 2.5035847433323777e-06, 'epoch': 2.78}


 95%|█████████▌| 3800/3987 [7:15:58<20:05,  6.45s/it]

{'loss': 1.5455, 'grad_norm': 2.3291306495666504, 'learning_rate': 1.643246343561801e-06, 'epoch': 2.86}


 98%|█████████▊| 3900/3987 [7:26:42<09:20,  6.45s/it]

{'loss': 1.5717, 'grad_norm': 2.3142495155334473, 'learning_rate': 7.829079437912246e-07, 'epoch': 2.94}


100%|██████████| 3987/3987 [7:35:58<00:00,  5.00s/it]There were missing keys in the checkpoint model loaded: ['led.encoder.embed_tokens.weight', 'led.decoder.embed_tokens.weight', 'lm_head.weight'].
100%|██████████| 3987/3987 [7:36:32<00:00,  6.87s/it]

{'train_runtime': 27392.8616, 'train_samples_per_second': 1.164, 'train_steps_per_second': 0.146, 'train_loss': 1.8613673301756097, 'epoch': 3.0}





TrainOutput(global_step=3987, training_loss=1.8613673301756097, metrics={'train_runtime': 27392.8616, 'train_samples_per_second': 1.164, 'train_steps_per_second': 0.146, 'total_flos': 8.607712972426445e+16, 'train_loss': 1.8613673301756097, 'epoch': 3.0})

In [27]:
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)

('./led_summarization/tokenizer_config.json',
 './led_summarization/special_tokens_map.json',
 './led_summarization/vocab.json',
 './led_summarization/merges.txt',
 './led_summarization/added_tokens.json')

# Inference на тестовой выборке

In [42]:
pretrained_tokenizer = LEDTokenizer.from_pretrained("./led_summarization")
pretrained_model = LEDForConditionalGeneration.from_pretrained("./led_summarization").to("cuda")

In [51]:
sample_indices = random.sample(range(len(tokenized_val)), 3)

In [52]:
def format_medical_summary(generated_text):
    clean_text = generated_text.replace('</s>', '').replace('<s>', '').strip()
    results_section = ''
    conclusions_section = ''
    
    if '<results>' in clean_text:
        results_part = clean_text.split('<results>')[1]
        results_section = results_part.split('<conclusions>')[0].strip()
    
    if '<conclusions>' in clean_text:
        conclusions_part = clean_text.split('<conclusions>')[1]
        conclusions_section = conclusions_part.split('<dig>')[0].strip()
    
    formatted_output = ""
    if results_section:
        results_section = results_section[0].upper() + results_section[1:]
        formatted_output += "RESULTS:\n" + results_section + "\n\n"
    
    if conclusions_section:
        conclusions_section = conclusions_section[0].upper() + conclusions_section[1:]
        formatted_output += "CONCLUSIONS:\n" + conclusions_section
    
    return formatted_output.strip()

In [54]:
for idx in sample_indices:
    sample = tokenized_val[idx]
    input_ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to("cuda")
    
    outputs = pretrained_model.generate(
        input_ids,
        max_length=256,
        num_beams=4,
        no_repeat_ngram_size=4,      
        repetition_penalty=2.0,
        early_stopping=True
    )
    
    generated_summary = pretrained_tokenizer.decode(outputs[0], skip_special_tokens=False)
    
    print("\nGenerated Summary:")
    print(format_medical_summary(generated_summary))
    print("\n" + "="*50)


Generated Summary:
RESULTS:
We developed a milp approach to compute for a given large metabolic network one or more minimum subnetworks preserving biological requirements that can be specified by the user. compared to previous work  <cit> , our method guarantees minimality of the subnetwork regarding the number of active reactions while preserving all the given requirements. in case there exist several minimum solutions, we are able to enumerate all of them. this may give additional insight how the network is functioning and which reactions are really needed to satisfy the requirements. we applied our algorithms to several genome-scale metabolic networks and we always found all the maximum subnetworks in reasonable time.

CONCLUSIONS:
We developed an milp approach using indicator variables and some other features of cplex. we implemented our algorithms in matlab and we were able to find all the minimum subsnetworks in normal time.


Generated Summary:
RESULTS:
Here we present peakanal

# Оценка качества модели: метрика ROUGE

In [14]:
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

In [32]:
def calculate_rouge(predictions, references):
    rouge1_scores = []
    rouge2_scores = []
    rougeL_scores = []
    
    for pred, ref in zip(predictions, references):
        scores = scorer.score(ref, pred)
        rouge1_scores.append(scores['rouge1'].fmeasure)
        rouge2_scores.append(scores['rouge2'].fmeasure)
        rougeL_scores.append(scores['rougeL'].fmeasure)
    
    return {
        'rouge1': np.mean(rouge1_scores),
        'rouge2': np.mean(rouge2_scores),
        'rougeL': np.mean(rougeL_scores)
    }

In [38]:
def generate_predictions(model, tokenizer, tokenized_dataset, device="cuda"):
    predictions = []
    references = []
    
    model.to(device)
    model.eval()
    
    for sample in tokenized_dataset:
        input_ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to(device)
        
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_length=256,
                num_beams=4,
                early_stopping=True,
                length_penalty=1.2,
                no_repeat_ngram_size=3,
                repetition_penalty=1.5
            )
        
        pred = tokenizer.decode(output[0], skip_special_tokens=True)
        ref = tokenizer.decode(sample["labels"], skip_special_tokens=True)
        
        predictions.append(pred)
        references.append(ref)
    
    return predictions, references

In [40]:
predictions, references = generate_predictions(pretrained_model, pretrained_tokenizer, tokenized_val)

In [41]:
rouge_scores = calculate_rouge(predictions, references)
print(f"ROUGE-1: {rouge_scores['rouge1']:.4f}")
print(f"ROUGE-2: {rouge_scores['rouge2']:.4f}")
print(f"ROUGE-L: {rouge_scores['rougeL']:.4f}")

ROUGE-1: 0.4547
ROUGE-2: 0.1718
ROUGE-L: 0.2569
