<a href="https://colab.research.google.com/github/cxrlton/AMR/blob/main/pre_processed_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [106]:
!pip3 install datasets

import datasets
from datasets import load_dataset



In [107]:
dataset = load_dataset("Tverous/anli_amr_new2", split="train")

In [108]:
dataset[0]

{'uid': '2093cfb3-a15f-4282-81e3-0cb793ffd0d7',
 'premise': 'TOKYO, Dec 18 (Reuters) - Japan’s Shionogi & Co said on Tuesday that it has applied to health regulators in the United States, Canada and Europe for approval of its HIV drug Dolutegravir. Shionogi developed Dolutegravir with a Viiv Healthcare, an AIDS drug joint venture between GlaxoSmithKline and Pfizer, in exchange for its rights to the drug.',
 'hypothesis': 'The article was written on December 18th.',
 'label': 0,
 'reason': 'TOKYO, Dec 18 (Reuters) is when the article was written as it states in the first words of the sentence',
 'linearized_amr': '( z0 write :ARG1 ( z1 article ) :time ( z2 date-entity :day 18 :month 12 ) )',
 'amr_penman': '(z0 / write-01\n    :ARG1 (z1 / article)\n    :time (z2 / date-entity\n              :day 18\n              :month 12))',
 'amr_tokens': ['The',
  'article',
  'was',
  'written',
  'on',
  'December',
  '18th',
  '.'],
 'amr_nodes': "{'z1': 'article', 'z0': 'write-01', 'z2': 'date-e

In [109]:
!pip install penman



In [111]:
import penman
import re
import pandas as pd
from collections import defaultdict
from typing import Dict, List, Tuple, Set

class AMRPreprocessor:
    def __init__(self):
        self.ne_type_mapping = {
            'person': 'person',
            'team': 'organization',
            'country': 'location',
            'city': 'location',
            'state': 'location',
            'organization': 'organization',
            'company': 'organization',
            'government-organization': 'organization',
            'group': 'organization'
        }
        self.anon_counter = defaultdict(int)

    def preprocess_amr(self, amr_str: str) -> Tuple[str, Dict]:
        try:
            # Initialize entity map
            entity_map = {}

            # Handle multi-line AMR string
            amr_str = amr_str.replace('\n', ' ')

            # Process named entities
            for ne_type in self.ne_type_mapping:
                # Updated pattern to match the exact structure
                pattern = fr'\(\s*(?:z\d+\s*/\s*)?{ne_type}[^)]*:name\s*\([^)]*:op1\s*"([^"]+)"'
                matches = list(re.finditer(pattern, amr_str))

                for match in matches:
                    entity_name = match.group(1)
                    coarse_type = self.ne_type_mapping[ne_type]
                    anon_id = f"{coarse_type}_{self.anon_counter[coarse_type]}"
                    self.anon_counter[coarse_type] += 1

                    # Store the mapping
                    entity_map[anon_id] = entity_name

                    # Create replacement pattern
                    name_section = f':name\\s*\\([^)]*:op1\\s*"{entity_name}"[^)]*\\)'
                    replacement = f' {anon_id}'

                    # Replace the name section with the anonymous ID
                    amr_str = re.sub(name_section, replacement, amr_str)

            # Process dates
            amr_str = re.sub(r':year\s+(\d{4})', r':year year_\1', amr_str)
            amr_str = re.sub(r':month\s+(\d{1,2})', r':month month_\1', amr_str)
            amr_str = re.sub(r':day\s+(\d{1,2})', r':day day_\1', amr_str)

            # Clean up variables
            amr_str = re.sub(r'\(z\d+\s*/', '(', amr_str)
            amr_str = re.sub(r'\sz\d+\s', ' ', amr_str)
            amr_str = re.sub(r'\s*\/\s*', ' ', amr_str)

            return amr_str, entity_map

        except Exception as e:
            print(f"Error in preprocess_amr: {e}")
            return amr_str, {}

# # Test with specific examples
# test_cases = [
#     """(z0 / and
#     :op1 (z1 / beat-03
#              :ARG0 (z2 / team
#                        :name (z3 / name
#                                  :op1 "Carlton"))
#              :ARG1 (z4 / team
#                        :name (z5 / name
#                                  :op1 "Melbourne"))
#              :time (z6 / date-entity
#                        :year 2016)))""",

#     """(z0 / urge-01
#     :ARG0 (z1 / person
#               :name (z2 / name
#                         :op1 "Gillum")
#               :medium (z3 / television))
#     :ARG1 (z4 / person
#               :ARG0-of (z5 / reside-01)))"""
# ]

# preprocessor = AMRPreprocessor()

# print("Testing examples:")
# for i, test_amr in enumerate(test_cases):
#     print(f"\nTest case {i + 1}:")
#     print("Original:", test_amr)
#     processed_amr, entity_map = preprocessor.preprocess_amr(test_amr)
#     print("\nProcessed:", processed_amr)
#     print("Entity Map:", entity_map)
#     print("-" * 80)


In [112]:
import os
os.environ["WANDB_MODE"] = "dryrun"

In [113]:
def extract_amr_and_text(dataset):
    amrs = []
    texts = []

    for example in dataset:
        amr_graph = example.get('amr_penman', '')
        text = example.get('hypothesis', '')

        if amr_graph and text:
            amrs.append(amr_graph)
            texts.append(text)

    return amrs, texts

In [114]:
amrs, texts = extract_amr_and_text(dataset)

In [115]:
preprocessor = AMRPreprocessor()
processed_amrs = []

In [116]:
data_amr = pd.DataFrame({
    'amr_graph': amrs,
    'text': texts
})

In [117]:
for idx, row in data_amr.iterrows():
    try:
        processed_amr, _ = preprocessor.preprocess_amr(row['amr_graph'])
        processed_amrs.append(processed_amr)

        if idx % 1000 == 0:
            print(f"Preprocessed {idx} AMRs...")

    except Exception as e:
        print(f"Error processing row {idx}: {e}")
        processed_amrs.append(row['amr_graph'])

# Add processed AMRs to DataFrame
data_amr['processed_amr'] = processed_amrs

Preprocessed 0 AMRs...
Preprocessed 1000 AMRs...
Preprocessed 2000 AMRs...
Preprocessed 3000 AMRs...
Preprocessed 4000 AMRs...
Preprocessed 5000 AMRs...
Preprocessed 6000 AMRs...
Preprocessed 7000 AMRs...
Preprocessed 8000 AMRs...
Preprocessed 9000 AMRs...
Preprocessed 10000 AMRs...
Preprocessed 11000 AMRs...
Preprocessed 12000 AMRs...
Preprocessed 13000 AMRs...
Preprocessed 14000 AMRs...
Preprocessed 15000 AMRs...
Preprocessed 16000 AMRs...
Preprocessed 17000 AMRs...
Preprocessed 18000 AMRs...
Preprocessed 19000 AMRs...
Preprocessed 20000 AMRs...
Preprocessed 21000 AMRs...
Preprocessed 22000 AMRs...
Preprocessed 23000 AMRs...
Preprocessed 24000 AMRs...
Preprocessed 25000 AMRs...
Preprocessed 26000 AMRs...
Preprocessed 27000 AMRs...
Preprocessed 28000 AMRs...
Preprocessed 29000 AMRs...
Preprocessed 30000 AMRs...
Preprocessed 31000 AMRs...
Preprocessed 32000 AMRs...
Preprocessed 33000 AMRs...
Preprocessed 34000 AMRs...
Preprocessed 35000 AMRs...
Preprocessed 36000 AMRs...
Preprocessed 3

In [118]:
data_amr

Unnamed: 0,amr_graph,text,processed_amr
0,(z0 / write-01\n :ARG1 (z1 / article)\n ...,The article was written on December 18th.,( write-01 :ARG1 ( article) :time ( da...
1,(z0 / urge-01\n :ARG0 (z1 / person\n ...,Gillum was on TV urging residents to stay out ...,( urge-01 :ARG0 ( person pe...
2,(z0 / and\n :op1 (z1 / beat-03\n ...,Carlton beat Melbourne in 2016 and will attemp...,( and :op1 ( beat-03 :ARG0 ( ...
3,(z0 / close-01\n :ARG1 (z1 / road)\n :du...,The road was closed for more than two hours af...,( close-01 :ARG1 ( road) :duration ( m...
4,(z0 / advise-01\n :ARG2 (z1 / slow-down-03)),Its advisible to slow down,( advise-01 :ARG2 ( slow-down-03))
...,...,...,...
100454,(z0 / commit-02\n :ARG0 (z1 / person\n ...,The Soldier committed these crimes in the unit...,( commit-02 :ARG0 ( person :...
100455,(z0 / decline-01\n :ARG1 (z1 / number\n ...,there is a continuous decline in the number of...,( decline-01 :ARG1 ( number ...
100456,(z0 / experience-01\n :ARG0 (z1 / person\n ...,Bechtolsheim is experienced with computers.,( experience-01 :ARG0 ( person ...
100457,(z0 / show-01\n :ARG0 (z1 / figure)\n :A...,figures shows New Zealanders are the lowest or...,( show-01 :ARG0 ( figure) :ARG1 ( have...


In [119]:
data_amr['processed_amr'] = processed_amrs

In [120]:
data_amr

Unnamed: 0,amr_graph,text,processed_amr
0,(z0 / write-01\n :ARG1 (z1 / article)\n ...,The article was written on December 18th.,( write-01 :ARG1 ( article) :time ( da...
1,(z0 / urge-01\n :ARG0 (z1 / person\n ...,Gillum was on TV urging residents to stay out ...,( urge-01 :ARG0 ( person pe...
2,(z0 / and\n :op1 (z1 / beat-03\n ...,Carlton beat Melbourne in 2016 and will attemp...,( and :op1 ( beat-03 :ARG0 ( ...
3,(z0 / close-01\n :ARG1 (z1 / road)\n :du...,The road was closed for more than two hours af...,( close-01 :ARG1 ( road) :duration ( m...
4,(z0 / advise-01\n :ARG2 (z1 / slow-down-03)),Its advisible to slow down,( advise-01 :ARG2 ( slow-down-03))
...,...,...,...
100454,(z0 / commit-02\n :ARG0 (z1 / person\n ...,The Soldier committed these crimes in the unit...,( commit-02 :ARG0 ( person :...
100455,(z0 / decline-01\n :ARG1 (z1 / number\n ...,there is a continuous decline in the number of...,( decline-01 :ARG1 ( number ...
100456,(z0 / experience-01\n :ARG0 (z1 / person\n ...,Bechtolsheim is experienced with computers.,( experience-01 :ARG0 ( person ...
100457,(z0 / show-01\n :ARG0 (z1 / figure)\n :A...,figures shows New Zealanders are the lowest or...,( show-01 :ARG0 ( figure) :ARG1 ( have...


In [121]:
data_amr['amr_graph'] = data_amr['processed_amr']

In [122]:
data_amr

Unnamed: 0,amr_graph,text,processed_amr
0,( write-01 :ARG1 ( article) :time ( da...,The article was written on December 18th.,( write-01 :ARG1 ( article) :time ( da...
1,( urge-01 :ARG0 ( person pe...,Gillum was on TV urging residents to stay out ...,( urge-01 :ARG0 ( person pe...
2,( and :op1 ( beat-03 :ARG0 ( ...,Carlton beat Melbourne in 2016 and will attemp...,( and :op1 ( beat-03 :ARG0 ( ...
3,( close-01 :ARG1 ( road) :duration ( m...,The road was closed for more than two hours af...,( close-01 :ARG1 ( road) :duration ( m...
4,( advise-01 :ARG2 ( slow-down-03)),Its advisible to slow down,( advise-01 :ARG2 ( slow-down-03))
...,...,...,...
100454,( commit-02 :ARG0 ( person :...,The Soldier committed these crimes in the unit...,( commit-02 :ARG0 ( person :...
100455,( decline-01 :ARG1 ( number ...,there is a continuous decline in the number of...,( decline-01 :ARG1 ( number ...
100456,( experience-01 :ARG0 ( person ...,Bechtolsheim is experienced with computers.,( experience-01 :ARG0 ( person ...
100457,( show-01 :ARG0 ( figure) :ARG1 ( have...,figures shows New Zealanders are the lowest or...,( show-01 :ARG0 ( figure) :ARG1 ( have...


In [123]:
data_amr = data_amr.drop('processed_amr', axis=1)

In [125]:
data_amr

Unnamed: 0,amr_graph,text
0,( write-01 :ARG1 ( article) :time ( da...,The article was written on December 18th.
1,( urge-01 :ARG0 ( person pe...,Gillum was on TV urging residents to stay out ...
2,( and :op1 ( beat-03 :ARG0 ( ...,Carlton beat Melbourne in 2016 and will attemp...
3,( close-01 :ARG1 ( road) :duration ( m...,The road was closed for more than two hours af...
4,( advise-01 :ARG2 ( slow-down-03)),Its advisible to slow down
...,...,...
100454,( commit-02 :ARG0 ( person :...,The Soldier committed these crimes in the unit...
100455,( decline-01 :ARG1 ( number ...,there is a continuous decline in the number of...
100456,( experience-01 :ARG0 ( person ...,Bechtolsheim is experienced with computers.
100457,( show-01 :ARG0 ( figure) :ARG1 ( have...,figures shows New Zealanders are the lowest or...


In [126]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [127]:
from datasets import Dataset

dataset = Dataset.from_pandas(data_amr)


In [128]:
small_dataset = dataset.select([i for i in range(25000)])

In [129]:
small_dataset

Dataset({
    features: ['amr_graph', 'text'],
    num_rows: 25000
})

In [130]:
#Define prompt and answer templates
prompt_template = """Below is an instruction that describes a task. Write a response that appropriately completes the request. Instruction: {instruction}\n Response:"""
answer_template = """{response}"""

In [131]:
def _add_text(rec):
    instruction = rec["amr_graph"]  # Use amr_graph as instruction
    response = rec["text"]  # Use text as response

    # Check if both exist; raise error if not
    if not instruction:
        raise ValueError(f"Expected an instruction (amr_graph) in: {rec}")
    if not response:
        raise ValueError(f"Expected a response (text) in: {rec}")

    # Create prompt, answer, and combined text
    rec["prompt"] = prompt_template.format(instruction=instruction)
    rec["answer"] = answer_template.format(response=response)
    rec["text"] = rec["prompt"] + rec["answer"]
    return rec


In [132]:
small_dataset = small_dataset.map(_add_text)

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

In [133]:
small_dataset

Dataset({
    features: ['amr_graph', 'text', 'prompt', 'answer'],
    num_rows: 25000
})

In [100]:
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)



In [134]:
from functools import partial
from typing import Dict, List
import copy

tokenizer.pad_token = tokenizer.eos_token

MAX_LENGTH = 256

# Function to generate token embeddings from the text part of the batch
def _preprocess_batch(batch: Dict[str, List]):
    model_inputs = tokenizer(batch["text"], max_length=MAX_LENGTH, truncation=True, padding='max_length')
    model_inputs["labels"] = copy.deepcopy(model_inputs['input_ids'])
    return model_inputs

_preprocessing_function = partial(_preprocess_batch)


In [135]:
from datasets import DatasetDict

# Define the split ratios
train_test_split = small_dataset.train_test_split(test_size=0.2)  # Split off 20% as test set
train_valid_split = train_test_split['train'].train_test_split(test_size=0.1)  # From train, split 10% as validation

# Combine splits into a DatasetDict
dataset_dict = DatasetDict({
    'train': train_valid_split['train'],
    'validation': train_valid_split['test'],
    'test': train_test_split['test']
})

In [136]:
#Print the size of each split to verify
print(f"Train set size: {len(dataset_dict['train'])}")
print(f"Validation set size: {len(dataset_dict['validation'])}")
print(f"Test set size: {len(dataset_dict['test'])}")

# Example check for first item in each split
print("Sample from train:", dataset_dict['train'][0])
print("Sample from validation:", dataset_dict['validation'][0])
print("Sample from test:", dataset_dict['test'][0])

Train set size: 18000
Validation set size: 2000
Test set size: 5000
Sample from train: {'amr_graph': '( write-01     :ARG0 ( person                person_4421)     :ARG1 ( fugue))', 'text': 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Instruction: ( write-01     :ARG0 ( person                person_4421)     :ARG1 ( fugue))\n Response:Bach wrote a fugue', 'prompt': 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Instruction: ( write-01     :ARG0 ( person                person_4421)     :ARG1 ( fugue))\n Response:', 'answer': 'Bach wrote a fugue'}
Sample from validation: {'amr_graph': '( possible-01     :ARG0-of ( cause-01                  :ARG1 ( drink-01                            :ARG0 ( i)                            :ARG1 ( coffee)))     :ARG1 ( stay-01               :ARG1               :ARG3 ( awake))     :polarity -)', 'text': "Below is an instruction that descr

In [137]:
dataset_dict

DatasetDict({
    train: Dataset({
        features: ['amr_graph', 'text', 'prompt', 'answer'],
        num_rows: 18000
    })
    validation: Dataset({
        features: ['amr_graph', 'text', 'prompt', 'answer'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['amr_graph', 'text', 'prompt', 'answer'],
        num_rows: 5000
    })
})

In [138]:
encoded_train_dataset = dataset_dict['train'].map(
    _preprocessing_function,
    batched=True,
    remove_columns=["amr_graph", "text", "prompt", "answer"],
)

encoded_validation_dataset = dataset_dict['validation'].map(
    _preprocessing_function,
    batched=True,
    remove_columns=["amr_graph", "text", "prompt", "answer"],
)

encoded_test_dataset = dataset_dict['test'].map(
    _preprocessing_function,
    batched=True,
    remove_columns=["amr_graph", "text", "prompt", "answer"],
)
processed_train_dataset = encoded_train_dataset.filter(lambda rec: len(rec["input_ids"]) <= MAX_LENGTH)
processed_validation_dataset = encoded_validation_dataset.filter(lambda rec: len(rec["input_ids"]) <= MAX_LENGTH)
processed_test_dataset = encoded_test_dataset.filter(lambda rec: len(rec["input_ids"]) <= MAX_LENGTH)


Map:   0%|          | 0/18000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/18000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [139]:
processed_train_dataset = encoded_train_dataset.filter(lambda rec: len(rec["input_ids"]) <= MAX_LENGTH)
processed_validation_dataset = encoded_validation_dataset.filter(lambda rec: len(rec["input_ids"]) <= MAX_LENGTH)
processed_test_dataset = encoded_test_dataset.filter(lambda rec: len(rec["input_ids"]) <= MAX_LENGTH)

Filter:   0%|          | 0/18000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [140]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [141]:
training_args = TrainingArguments(
    output_dir='/mnt/disks/disk1/results',  ## give the directory name where you want to save the model
    evaluation_strategy='epoch',
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,  # Accumulate gradients for 4 steps
    warmup_steps=50,
    learning_rate=5e-5,        # Lowered learning rate
    weight_decay=0.1,          # Reduced weight decay to prevent over-penalizing weights
    logging_dir='/mnt/disks/disk1/logs' ## give the directory name where you want to save the model
)




In [143]:
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


In [144]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_train_dataset,
    eval_dataset=processed_validation_dataset,
    data_collator=data_collator
)



In [145]:
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch,Training Loss,Validation Loss
1,1.8719,0.449668
2,1.6629,0.420521
3,1.4929,0.414896


TrainOutput(global_step=13500, training_loss=1.8008159677010995, metrics={'train_runtime': 1839.956, 'train_samples_per_second': 29.349, 'train_steps_per_second': 7.337, 'total_flos': 7054884864000000.0, 'train_loss': 1.8008159677010995, 'epoch': 3.0})

In [146]:
model_output_dir = '/mnt/disks/disk1/results'
model.save_pretrained(model_output_dir)
tokenizer.save_pretrained(model_output_dir)
model.save_pretrained(model_output_dir)
tokenizer.save_pretrained(model_output_dir)

('/mnt/disks/disk1/results/tokenizer_config.json',
 '/mnt/disks/disk1/results/special_tokens_map.json',
 '/mnt/disks/disk1/results/vocab.json',
 '/mnt/disks/disk1/results/merges.txt',
 '/mnt/disks/disk1/results/added_tokens.json',
 '/mnt/disks/disk1/results/tokenizer.json')

In [147]:
def get_model_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    return total_params


In [149]:
def main(input_text):
    # Load the tokenizer and model from the saved directory
    model_path = '/mnt/disks/disk1/results'
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # Calculate the number of parameters in the model being used for inference
    total_params = get_model_parameters(model)
    """
    i have commented this print stetement to avoid any print in the average calculation:
    """
    #print(f"Total number of parameters: {total_params}")

    # Prepare the input text for generation
    inputs = tokenizer(input_text, return_tensors='pt')

    # Generate text
    outputs = model.generate(**inputs, max_length=500, num_return_sequences=1)

    # Decode the generated text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extract only the response part by splitting based on "Response:"

    match = re.search(r"Response:\s*(.*)", generated_text)
    if match:
        response_text = match.group(1)
        # Remove extra spaces between sentences
        response_text = re.sub(r'\s{2,}', ' ', response_text)
        # Keep only up to the first sensible sentence-ending punctuation
        response_text = re.split(r'[.!?]', response_text)[0].strip() + '.'
        #print("Response text:", response_text)
        return response_text
    else:
        #print("Response not found in generated text")
        return "Response not found in generated text"

# Example input for inference
example_input = """
(z0 / easy-05
    :ARG1 (z1 / scare-01
              :ARG1 (z2 / person
                        :ARG0-of (z3 / have-rel-role-91
                                     :ARG1 (z4 / i)
                                     :ARG2 (z5 / uncle))))
    :mod (z6 / certain))
"""
output = main(example_input)


print(processed_train_dataset)
print(processed_validation_dataset)
print(processed_test_dataset)


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 18000
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 2000
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 5000
})


In [150]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
def calculate_bleu(predicted_text, ground_truth_text):
    # Tokenize the texts into lists of words
    reference = [ground_truth_text.split()]  # BLEU expects a list of references
    hypothesis = predicted_text.split()

    # Return 0 BLEU score if the hypothesis is empty
    if not hypothesis:
        return 0.0

    # Calculate BLEU score with smoothing
    smoothie = SmoothingFunction().method4  # Use smoothing to handle short texts
    bleu_score = sentence_bleu(reference, hypothesis, smoothing_function=smoothie)

    return bleu_score

bleu_score = 0
valid_count = 0
k = 10
# Loop through the dataset
for i in range(k):
    example_input = dataset_dict['test'][i]['amr_graph']
    ground_truth_text = dataset_dict['test'][i]['answer']

    # Tokenize and check input length
    tokenized_input = tokenizer(example_input, return_tensors='pt')
    input_length = tokenized_input['input_ids'].shape[1]

    # Skip examples with input length greater than 500
    if input_length > 500:
        continue

    # Generate model output and calculate BLEU score
    model_output_text = main(example_input)
    bleu = calculate_bleu(model_output_text, ground_truth_text)

    # Only add BLEU score if it’s valid (greater than zero)
    if bleu > 0:
        bleu_score += bleu
        valid_count += 1  # Increment count of valid scores

# Calculate the average BLEU score only if there are valid scores
if valid_count > 0:
    avg_bleu_score = bleu_score / valid_count
else:
    avg_bleu_score = 0.0  # Set average to zero if no valid scores were found

print("Average BLEU score:", avg_bleu_score)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Average BLEU score: 0.14109918266662358
