<a href="https://colab.research.google.com/github/cxrlton/AMR/blob/main/amr_trial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [161]:
import datasets
from datasets import load_dataset
import penman
import re
import pandas as pd
from collections import defaultdict
from typing import Dict, List, Tuple, Set

In [174]:
class AMRPreprocessor:
    def __init__(self):
        self.ne_type_mapping = {
            'person': 'person',
            'team': 'organization',
            'country': 'location',
            'city': 'location',
            'state': 'location',
            'organization': 'organization',
            'company': 'organization',
            'government-organization': 'organization',
            'group': 'organization'
        }
        self.anon_counter = defaultdict(int)

    def preprocess_amr(self, amr_str: str) -> Tuple[str, Dict]:
        try:
            # Initialize entity map
            entity_map = {}

            # Handle multi-line AMR string
            amr_str = amr_str.replace('\n', ' ')

            # Process named entities
            for ne_type in self.ne_type_mapping:
                # Updated pattern to match the exact structure
                pattern = fr'\(\s*(?:z\d+\s*/\s*)?{ne_type}[^)]*:name\s*\([^)]*:op1\s*"([^"]+)"'
                matches = list(re.finditer(pattern, amr_str))

                for match in matches:
                    entity_name = match.group(1)
                    coarse_type = self.ne_type_mapping[ne_type]
                    anon_id = f"{coarse_type}_{self.anon_counter[coarse_type]}"
                    self.anon_counter[coarse_type] += 1

                    # Store the mapping
                    entity_map[anon_id] = entity_name

                    # Create replacement pattern
                    name_section = f':name\\s*\\([^)]*:op1\\s*"{entity_name}"[^)]*\\)'
                    replacement = f' {anon_id}'

                    # Replace the name section with the anonymous ID
                    amr_str = re.sub(name_section, replacement, amr_str)

            # Process dates
            amr_str = re.sub(r':year\s+(\d{4})', r':year year_\1', amr_str)
            amr_str = re.sub(r':month\s+(\d{1,2})', r':month month_\1', amr_str)
            amr_str = re.sub(r':day\s+(\d{1,2})', r':day day_\1', amr_str)

            # Clean up variables
            amr_str = re.sub(r'\(z\d+\s*/', '(', amr_str)
            amr_str = re.sub(r'\sz\d+\s', ' ', amr_str)
            amr_str = re.sub(r'\s*\/\s*', ' ', amr_str)

            return amr_str, entity_map

        except Exception as e:
            print(f"Error in preprocess_amr: {e}")
            return amr_str, {}

# # Test with specific examples
# test_cases = [
#     """(z0 / and
#     :op1 (z1 / beat-03
#              :ARG0 (z2 / team
#                        :name (z3 / name
#                                  :op1 "Carlton"))
#              :ARG1 (z4 / team
#                        :name (z5 / name
#                                  :op1 "Melbourne"))
#              :time (z6 / date-entity
#                        :year 2016)))""",

#     """(z0 / urge-01
#     :ARG0 (z1 / person
#               :name (z2 / name
#                         :op1 "Gillum")
#               :medium (z3 / television))
#     :ARG1 (z4 / person
#               :ARG0-of (z5 / reside-01)))"""
# ]

# preprocessor = AMRPreprocessor()

# print("Testing examples:")
# for i, test_amr in enumerate(test_cases):
#     print(f"\nTest case {i + 1}:")
#     print("Original:", test_amr)
#     processed_amr, entity_map = preprocessor.preprocess_amr(test_amr)
#     print("\nProcessed:", processed_amr)
#     print("Entity Map:", entity_map)
#     print("-" * 80)


In [175]:
dataset = load_dataset("Tverous/anli_amr_new2", split="train")

In [176]:
print(f"Dataset loaded. Size: {len(dataset)}")

Dataset loaded. Size: 100459


In [177]:
processed_data = []

In [178]:
for idx, example in enumerate(dataset):
    try:
        entry = {
            'premise': example.get('premise', ''),
            'hypothesis': example.get('hypothesis', ''),
            'label': example.get('label', -1),
            'amr': example.get('amr_penman', '')
        }
        processed_data.append(entry)

        if idx % 1000 == 0:
            print(f"Processed {idx} examples...")
    except Exception as e:
        print(f"Error processing example {idx}: {e}")
        continue

Processed 0 examples...
Processed 1000 examples...
Processed 2000 examples...
Processed 3000 examples...
Processed 4000 examples...
Processed 5000 examples...
Processed 6000 examples...
Processed 7000 examples...
Processed 8000 examples...
Processed 9000 examples...
Processed 10000 examples...
Processed 11000 examples...
Processed 12000 examples...
Processed 13000 examples...
Processed 14000 examples...
Processed 15000 examples...
Processed 16000 examples...
Processed 17000 examples...
Processed 18000 examples...
Processed 19000 examples...
Processed 20000 examples...
Processed 21000 examples...
Processed 22000 examples...
Processed 23000 examples...
Processed 24000 examples...
Processed 25000 examples...
Processed 26000 examples...
Processed 27000 examples...
Processed 28000 examples...
Processed 29000 examples...
Processed 30000 examples...
Processed 31000 examples...
Processed 32000 examples...
Processed 33000 examples...
Processed 34000 examples...
Processed 35000 examples...
Proce

In [179]:
df = pd.DataFrame(processed_data)

In [180]:
df

Unnamed: 0,premise,hypothesis,label,amr
0,"TOKYO, Dec 18 (Reuters) - Japan’s Shionogi & C...",The article was written on December 18th.,0,(z0 / write-01\n :ARG1 (z1 / article)\n ...
1,Tallahassee Mayor and Democratic gubernatorial...,Gillum was on TV urging residents to stay out ...,0,(z0 / urge-01\n :ARG0 (z1 / person\n ...
2,MELBOURNE will look to avoid stumbling against...,Carlton beat Melbourne in 2016 and will attemp...,0,(z0 / and\n :op1 (z1 / beat-03\n ...
3,"by Ted Raymond, Newstalk 580 CFRA A stretch of...",The road was closed for more than two hours af...,0,(z0 / close-01\n :ARG1 (z1 / road)\n :du...
4,Drivers are reporting heavy traffic on the nor...,Its advisible to slow down,0,(z0 / advise-01\n :ARG2 (z1 / slow-down-03))
...,...,...,...,...
100454,A U.S. soldier accused of participating in the...,The Soldier committed these crimes in the unit...,1,(z0 / commit-02\n :ARG0 (z1 / person\n ...
100455,Figures released today show that the number of...,there is a continuous decline in the number of...,1,(z0 / decline-01\n :ARG1 (z1 / number\n ...
100456,"For Bechtolsheim, who designed the prototype f...",Bechtolsheim is experienced with computers.,0,(z0 / experience-01\n :ARG0 (z1 / person\n ...
100457,Figures released today show that the number of...,figures shows New Zealanders are the lowest or...,1,(z0 / show-01\n :ARG0 (z1 / figure)\n :A...


In [181]:
preprocessor = AMRPreprocessor()

In [182]:
print("\nProcessing AMRs...")
processed_amrs = []
entity_mappings = []

for idx, row in df.iterrows():
    try:
        processed_amr, entity_map = preprocessor.preprocess_amr(row['amr'])
        processed_amrs.append(processed_amr)
        entity_mappings.append(entity_map)

        if idx % 1000 == 0:
            print(f"Preprocessed {idx} AMRs...")

    except Exception as e:
        print(f"Error processing row {idx}: {e}")
        processed_amrs.append(row['amr'])
        entity_mappings.append({})


Processing AMRs...
Preprocessed 0 AMRs...
Preprocessed 1000 AMRs...
Preprocessed 2000 AMRs...
Preprocessed 3000 AMRs...
Preprocessed 4000 AMRs...
Preprocessed 5000 AMRs...
Preprocessed 6000 AMRs...
Preprocessed 7000 AMRs...
Preprocessed 8000 AMRs...
Preprocessed 9000 AMRs...
Preprocessed 10000 AMRs...
Preprocessed 11000 AMRs...
Preprocessed 12000 AMRs...
Preprocessed 13000 AMRs...
Preprocessed 14000 AMRs...
Preprocessed 15000 AMRs...
Preprocessed 16000 AMRs...
Preprocessed 17000 AMRs...
Preprocessed 18000 AMRs...
Preprocessed 19000 AMRs...
Preprocessed 20000 AMRs...
Preprocessed 21000 AMRs...
Preprocessed 22000 AMRs...
Preprocessed 23000 AMRs...
Preprocessed 24000 AMRs...
Preprocessed 25000 AMRs...
Preprocessed 26000 AMRs...
Preprocessed 27000 AMRs...
Preprocessed 28000 AMRs...
Preprocessed 29000 AMRs...
Preprocessed 30000 AMRs...
Preprocessed 31000 AMRs...
Preprocessed 32000 AMRs...
Preprocessed 33000 AMRs...
Preprocessed 34000 AMRs...
Preprocessed 35000 AMRs...
Preprocessed 36000 AM

In [183]:
# Add to DataFrame
df['processed_amr'] = processed_amrs
df['entity_mappings'] = entity_mappings

In [184]:
df

Unnamed: 0,premise,hypothesis,label,amr,processed_amr,entity_mappings
0,"TOKYO, Dec 18 (Reuters) - Japan’s Shionogi & C...",The article was written on December 18th.,0,(z0 / write-01\n :ARG1 (z1 / article)\n ...,( write-01 :ARG1 ( article) :time ( da...,{}
1,Tallahassee Mayor and Democratic gubernatorial...,Gillum was on TV urging residents to stay out ...,0,(z0 / urge-01\n :ARG0 (z1 / person\n ...,( urge-01 :ARG0 ( person pe...,{'person_0': 'Gillum'}
2,MELBOURNE will look to avoid stumbling against...,Carlton beat Melbourne in 2016 and will attemp...,0,(z0 / and\n :op1 (z1 / beat-03\n ...,( and :op1 ( beat-03 :ARG0 ( ...,"{'organization_0': 'Carlton', 'organization_1'..."
3,"by Ted Raymond, Newstalk 580 CFRA A stretch of...",The road was closed for more than two hours af...,0,(z0 / close-01\n :ARG1 (z1 / road)\n :du...,( close-01 :ARG1 ( road) :duration ( m...,{}
4,Drivers are reporting heavy traffic on the nor...,Its advisible to slow down,0,(z0 / advise-01\n :ARG2 (z1 / slow-down-03)),( advise-01 :ARG2 ( slow-down-03)),{}
...,...,...,...,...,...,...
100454,A U.S. soldier accused of participating in the...,The Soldier committed these crimes in the unit...,1,(z0 / commit-02\n :ARG0 (z1 / person\n ...,( commit-02 :ARG0 ( person :...,{'location_15956': 'United'}
100455,Figures released today show that the number of...,there is a continuous decline in the number of...,1,(z0 / decline-01\n :ARG1 (z1 / number\n ...,( decline-01 :ARG1 ( number ...,{'person_37066': 'New'}
100456,"For Bechtolsheim, who designed the prototype f...",Bechtolsheim is experienced with computers.,0,(z0 / experience-01\n :ARG0 (z1 / person\n ...,( experience-01 :ARG0 ( person ...,{'person_37067': 'Bechtolsheim'}
100457,Figures released today show that the number of...,figures shows New Zealanders are the lowest or...,1,(z0 / show-01\n :ARG0 (z1 / figure)\n :A...,( show-01 :ARG0 ( figure) :ARG1 ( have...,{'location_15957': 'New'}


In [185]:
# Print statistics
print("\nProcessing complete. Statistics:")
total_amrs = len(df)
amrs_with_entities = sum(1 for m in entity_mappings if m)
print(f"Total AMRs processed: {total_amrs}")
print(f"AMRs with recognized entities: {amrs_with_entities}")
print(f"Percentage with entities: {(amrs_with_entities/total_amrs)*100:.2f}%")

# Show some examples
print("\nExample processed AMRs:")
for i in range(3):
    print(f"\nExample {i+1}:")
    print("Original AMR:")
    print(df['amr'].iloc[i])
    print("\nProcessed AMR:")
    print(df['processed_amr'].iloc[i])
    print("Entity Map:")
    print(df['entity_mappings'].iloc[i])
    print("-" * 80)


Processing complete. Statistics:
Total AMRs processed: 100459
AMRs with recognized entities: 48482
Percentage with entities: 48.26%

Example processed AMRs:

Example 1:
Original AMR:
(z0 / write-01
    :ARG1 (z1 / article)
    :time (z2 / date-entity
              :day 18
              :month 12))

Processed AMR:
( write-01     :ARG1 ( article)     :time ( date-entity               :day day_18               :month month_12))
Entity Map:
{}
--------------------------------------------------------------------------------

Example 2:
Original AMR:
(z0 / urge-01
    :ARG0 (z1 / person
              :name (z2 / name
                        :op1 "Gillum")
              :medium (z3 / television))
    :ARG1 (z4 / person
              :ARG0-of (z5 / reside-01))
    :ARG2 (z6 / stay-01
              :ARG1 z4
              :ARG3 (z7 / out-06
                        :ARG1 z4
                        :ARG2 (z8 / storm))))

Processed AMR:
( urge-01     :ARG0 ( person                person_0         

In [189]:
df.to_csv('processed_amrs.csv', index=False)
print("Data saved to 'processed_amrs.csv'")

Data saved to 'processed_amrs.csv'


In [186]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AdamW, get_linear_schedule_with_warmup
import pandas as pd
from tqdm import tqdm
import numpy as np

In [190]:
class AMRDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        processed_amr = row['processed_amr']
        premise = row['premise']

        # Prepare inputs (processed AMR)
        inputs = self.tokenizer.encode_plus(
            processed_amr,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Prepare targets (premise text)
        targets = self.tokenizer.encode_plus(
            premise,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'labels': targets['input_ids'].squeeze()
        }


In [191]:
# Training function
def train_epoch(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0

    progress_bar = tqdm(dataloader, desc='Training')
    for batch in progress_bar:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        scheduler.step()

        progress_bar.set_postfix({'loss': loss.item()})

    return total_loss / len(dataloader)

# Validation function
def validate(model, dataloader, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            total_loss += outputs.loss.item()

    return total_loss / len(dataloader)


In [192]:
def main():
    # Load the processed data
    df = pd.read_csv('processed_amrs.csv')
    print(f"Loaded dataset with {len(df)} examples")

    # Initialize tokenizer and model
    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    model = T5ForConditionalGeneration.from_pretrained('t5-base')

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Split data into train and validation
    train_size = int(0.8 * len(df))
    train_df = df[:train_size]
    val_df = df[train_size:]

    # Create datasets and dataloaders
    train_dataset = AMRDataset(train_df, tokenizer)
    val_dataset = AMRDataset(val_df, tokenizer)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=8,
        shuffle=True
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=8,
        shuffle=False
    )

    # Training settings
    num_epochs = 3
    warmup_steps = 0
    total_steps = len(train_dataloader) * num_epochs

    # Initialize optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=5e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # Training loop
    print("Starting training...")
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Train
        train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, device)
        print(f"Training loss: {train_loss:.4f}")

        # Validate
        val_loss = validate(model, val_dataloader, device)
        print(f"Validation loss: {val_loss:.4f}")

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
        }, f'checkpoint_epoch_{epoch+1}.pt')

if __name__ == "__main__":
    main()


Loaded dataset with 100459 examples


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]



Starting training...

Epoch 1/3


Training:   0%|          | 0/10046 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Training: 100%|██████████| 10046/10046 [1:27:31<00:00,  1.91it/s, loss=0.435]


Training loss: 0.4954


Validating: 100%|██████████| 2512/2512 [07:07<00:00,  5.88it/s]


Validation loss: 0.4410

Epoch 2/3


Training: 100%|██████████| 10046/10046 [1:27:34<00:00,  1.91it/s, loss=0.364]


Training loss: 0.3914


Validating: 100%|██████████| 2512/2512 [07:07<00:00,  5.88it/s]


Validation loss: 0.4347

Epoch 3/3


Training: 100%|██████████| 10046/10046 [1:27:31<00:00,  1.91it/s, loss=0.507]


Training loss: 0.3576


Validating: 100%|██████████| 2512/2512 [07:06<00:00,  5.88it/s]


Validation loss: 0.4335


In [205]:

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
def calculate_bleu_score():
    # Load the processed data
    df = pd.read_csv('processed_amrs.csv')
    print(f"Loaded dataset with {len(df)} examples")

    # Initialize tokenizer and model
    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    model = T5ForConditionalGeneration.from_pretrained('t5-base')

    # Load the trained model
    checkpoint = torch.load('checkpoint_epoch_3.pt', weights_only=True)  # Load the last checkpoint
    model.load_state_dict(checkpoint['model_state_dict'])

    # Set device and model to evaluation mode
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model = model.half()  # Convert to FP16 for A100

    # Split data (use the same validation split as during training)
    train_size = int(0.8 * len(df))
    val_df = df[train_size:]

    # Create validation dataset and dataloader
    val_dataset = AMRDataset(val_df, tokenizer)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=32,  # Larger batch size for A100
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    def calculate_bleu(model, tokenizer, dataloader, device):
        model.eval()
        all_predictions = []
        all_references = []
        smooth = SmoothingFunction()

        with torch.no_grad():
            with torch.amp.autocast('cuda'):  # Proper autocast for A100
                for batch in tqdm(dataloader, desc='Generating predictions'):
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['labels'].to(device)

                    outputs = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        max_length=256,
                        num_beams=4,
                        length_penalty=2.0,
                        early_stopping=True,
                        max_new_tokens=150,
                        no_repeat_ngram_size=3
                    )

                    predictions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
                    references = tokenizer.batch_decode(labels, skip_special_tokens=True)

                    predictions = [nltk.word_tokenize(pred.lower()) for pred in predictions]
                    references = [[nltk.word_tokenize(ref.lower())] for ref in references]

                    all_predictions.extend(predictions)
                    all_references.extend(references)

                    # Clear cache periodically
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

        bleu_score = corpus_bleu(all_references, all_predictions, smoothing_function=smooth.method1)
        return bleu_score

    # Calculate BLEU score
    print("Starting BLEU score calculation...")
    bleu_score = calculate_bleu(model, tokenizer, val_dataloader, device)
    print(f"BLEU Score: {bleu_score:.4f}")

    # Clear GPU memory
    torch.cuda.empty_cache()

# Run the BLEU score calculation
calculate_bleu_score()


Loaded dataset with 100459 examples
Starting BLEU score calculation...


Generating predictions:   0%|          | 0/628 [00:00<?, ?it/s]Both `max_new_tokens` (=150) and `max_length`(=256) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Generating predictions:   0%|          | 1/628 [00:10<1:54:06, 10.92s/it]Both `max_new_tokens` (=150) and `max_length`(=256) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Generating predictions:   0%|          | 2/628 [00:21<1:50:13, 10.56s/it]Both `max_new_tokens` (=150) and `max_length`(=256) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Generating predictions:   0%|          | 

BLEU Score: 0.0195
