In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import re

import pandas as pd
import torch
from tqdm import tqdm
from OmniEvent.infer import infer, AttrDict, get_model, get_tokenizer

MODEL_PATH = "../models/s2s-mt5-ed/"

# Force CPU-only execution. Instead of only choosing a CPU device, torch's CUDA
# probes are monkeypatched, so every later `torch.cuda.is_available()` /
# `torch.cuda.device_count()` call in this kernel reports "no GPU". Presumably
# this is because OmniEvent's `infer` picks its own device internally
# (TODO confirm). Set to False to keep torch's real CUDA detection.
FORCE_CPU = True
if FORCE_CPU:
    torch.cuda.is_available = lambda: False
    torch.cuda.device_count = lambda: 0

In [3]:
# Always a `torch.device`. The original yielded a bare string on the CUDA
# branch. `torch.cuda.is_available` may have been patched in the config cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_model_and_tokenizer(model_name_or_path, paradigm="seq2seq", model_type="mt5"):
    """Load an OmniEvent model plus its tokenizer and move the model to ``device``.

    Args:
        model_name_or_path: Path passed to both OmniEvent's ``get_model`` and
            ``get_tokenizer``.
        paradigm: OmniEvent paradigm given to ``get_model`` (default ``"seq2seq"``).
        model_type: Backbone type given to ``get_model`` (default ``"mt5"``).

    Returns:
        A ``(model, tokenizer)`` tuple. The model is already on the
        module-level ``device``.
    """
    model_args = AttrDict({
        "paradigm": paradigm,
        "model_type": model_type,
    })
    model = get_model(model_args, model_name_or_path).to(device)
    tokenizer = get_tokenizer(model_name_or_path)
    return model, tokenizer

In [4]:
# Set N_SAMPLES to an int (e.g. 50) for a quick, reproducible subsample.
# None runs on the full dataset.
N_SAMPLES = None
SAMPLE_SEED = 42

df = pd.read_csv("../data/data.csv")
if N_SAMPLES is not None:
    df = df.sample(n=N_SAMPLES, random_state=SAMPLE_SEED)

In [5]:
# Load the ED checkpoint once; later cells reuse `model` and `tokenizer`.
model, tokenizer = get_model_and_tokenizer(model_name_or_path=MODEL_PATH)

load from local file: ../models/s2s-mt5-ed/ model
load from local file: ../models/s2s-mt5-ed/ tokenizer


In [None]:
# Event detection on headlines, one `infer` call per title. The results get
# their own name so the article-level cell (which builds `ed_results`) no
# longer overwrites them.
# TODO: persist alongside the article results if they are needed downstream.
title_ed_results = [
    infer(text=title, task="ED", model=model, tokenizer=tokenizer)
    for title in tqdm(df["Title"], total=len(df), desc="titles")
]

[{'text': 'Fed officials warn of inflation risks from tariff surge', 'events': []}]
[{'text': 'While You Were Sleeping: 5 stories you might have missed, Feb 4, 2025', 'events': []}]
[{'text': "Slower growth, souring business sentiment: How Trump's tariffs could hurt Singapore's economy", 'events': []}]
[{'text': "Malaysia's ECRL: A closer look at the US$11.2b railway’s promises of boosting jobs for locals and making money", 'events': []}]
[{'text': 'S&P 500, Nasdaq, pare losses as Trump’s Mexico tariffs paused', 'events': []}]
[{'text': 'Democrats blast Musk as USAID agency HQ shutters', 'events': [{'type': 'attack', 'trigger': 'blast', 'offset': [10, 15]}]}]
[{'text': "Trump to speak with China's Xi after raising tariffs, White House says", 'events': [{'type': 'meet', 'trigger': 'speak', 'offset': [9, 14]}]}]
[{'text': '‘Good opportunity’ for Rubio, Wang to meet at UN on Feb 18, says China’s UN envoy', 'events': [{'type': 'meet', 'trigger': 'meet', 'offset': [38, 42]}]}]
[{'text': 'Ex

In [None]:
# Event detection over article bodies, one sentence at a time. Only sentences
# with at least one detected event are kept for each article.

def split_sentences(text):
    """Split ``text`` into stripped, non-empty sentences.

    Splits on runs of ``.``, ``!`` or ``?`` that are followed by whitespace or
    end of text, so decimals such as ``11.2`` stay intact. Empty fragments
    (e.g. after a trailing period) are dropped so they never reach the model.
    Non-string input (e.g. a NaN cell read by pandas) yields an empty list.
    """
    if not isinstance(text, str):
        return []
    parts = re.split(r"[.!?]+(?=\s|$)", text)
    return [part.strip() for part in parts if part.strip()]


ed_results = []
for row in tqdm(df.itertuples(index=False), total=len(df)):
    ed_article = []
    for sentence in split_sentences(row.Content):
        sentence_result = infer(text=sentence, task="ED", model=model, tokenizer=tokenizer)[0]
        if sentence_result["events"]:
            ed_article.append(sentence_result)
    ed_results.append(ed_article)

In [None]:
# Attach article-level ED results in a new frame so `df` stays the raw input.
assert len(ed_results) == len(df), "ed_results must have exactly one entry per article row"
ed_output = df.assign(ed_results=ed_results)
ed_output.to_parquet("../data/ed_output.parquet", engine="pyarrow")