In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Tokenizer and seq2seq weights are both read from the same fine-tuned
# checkpoint folder; `tok` and `model` are used as globals by later cells.
model_folder = '../FineTuning/FinetunedModels/RUCAIBox_mvp-data-to-text'
tok = AutoTokenizer.from_pretrained(model_folder)
model = AutoModelForSeq2SeqLM.from_pretrained(model_folder)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
def reformatData(arr):
    """Linearise a list of event dicts into a data-to-text prompt for the model.

    Each event is read through the keys ``'subject'`` (a list of dicts mapping
    category -> entity; only the first dict is used), ``'action'`` (appended
    as text) and ``'object'`` (a dict mapping category -> iterable of entity
    strings). Events are numbered from 1 by their position in ``arr``; every
    fact becomes a ``'<Tag><n> | <category> | <value>'`` segment and the
    segments are joined with ``' [SEP] '`` after a fixed instruction prefix.

    An event whose ``'subject'`` is empty is skipped entirely (its number is
    still consumed). Falsy subject entities are skipped; list-valued subject
    entities are comma-joined and their category gets an ``'s'`` appended.

    Args:
        arr: list of event dicts as described above.

    Returns:
        The prompt string; just the instruction prefix if no segments result.
    """
    prefix = 'Write a live blog post describing the following events in a Formula 1 race: '
    segments = []
    # enumerate numbers each event by its own position; arr.index() returned
    # the first *equal* dict, so duplicate events shared a number (and the
    # lookup made the loop quadratic).
    for i, event in enumerate(arr, start=1):
        n = str(i)
        sub_i, act_i, obj_i = 'Agent' + n, 'Action' + n, 'Object' + n
        sub, act, obj = event['subject'], event['action'], event['object']
        if not sub:
            # No subject at all: drop the whole event (action and objects too).
            continue
        for cat, ent in sub[0].items():
            if not ent:
                continue
            if isinstance(ent, list):
                ent = ', '.join(ent)
                cat += 's'
            segments.append(' | '.join((sub_i, cat, str(ent))))
        segments.append(' | '.join((act_i, act)))
        for obj_cat, obj_lst in obj.items():
            for obj_ent in obj_lst:
                segments.append(' | '.join((obj_i, obj_cat, obj_ent)))
    # Joining (rather than appending ' [SEP] ' and slicing off the last 7
    # characters) leaves the prefix intact when no segments were produced.
    return prefix + ' [SEP] '.join(segments)

In [3]:
def generatePosts(prompts, max_new_tokens=500, batch_size=None):
    """Generate one post per prompt using the module-level ``model`` and ``tok``.

    Args:
        prompts: list of prompt strings.
        max_new_tokens: generation length cap passed to ``model.generate``.
        batch_size: if given, prompts are tokenised and generated in chunks of
            at most this many to bound peak memory; ``None`` (default) sends
            all prompts as a single padded batch.

    Returns:
        List of decoded strings (special tokens skipped), in the same order
        as ``prompts``. Empty ``prompts`` returns ``[]`` without calling the
        tokenizer or the model.

    Raises:
        ValueError: if ``batch_size`` is given and is less than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError('batch_size must be a positive integer or None')
    if not prompts:
        return []
    step = batch_size or len(prompts)
    posts = []
    for start in range(0, len(prompts), step):
        chunk = prompts[start:start + step]
        model_inputs = tok(chunk, return_tensors='pt', padding=True)
        model_output = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        posts.extend(tok.batch_decode(model_output, skip_special_tokens=True))
    return posts

In [4]:
import json
import os


def loadFiles(folder='../EventIdentification/Events/'):
    """Return the paths of the race event ``.json`` files in ``folder``.

    Only regular files ending in ``.json`` are kept, which also excludes
    Jupyter's ``.ipynb_checkpoints`` directory and any stray non-JSON files
    that would otherwise fail in ``json.load``. Names are sorted because
    ``os.listdir`` returns entries in arbitrary order, so runs process races
    in a reproducible order.

    Args:
        folder: directory to scan.

    Returns:
        List of ``os.path.join(folder, name)`` paths, sorted by file name.
    """
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.endswith('.json') and os.path.isfile(os.path.join(folder, name))
    ]

def loadData(file):
    """Parse and return the JSON content of ``file``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    exchanged in) instead of the platform's locale default, so non-ASCII
    text loads identically on every OS.

    Args:
        file: path to a JSON file.

    Returns:
        The deserialised JSON value.
    """
    with open(file, encoding='utf-8') as f:
        return json.load(f)

In [34]:
def getPrompts(data, lap):
    prompts = []
    for actions in data:
        for act in actions:
            prompts.append((act, reformatData(act)))
    return prompts

In [None]:
import pandas as pd


# For every race event file: build one prompt per event group, generate a post
# for each with the fine-tuned model, and save the results as a CSV in
# OUTPUT_FOLDER with a header row of COLUMNS.
OUTPUT_FOLDER = './GeneratedPosts/'
COLUMNS = ['Lap', 'Data', 'Prompt', 'Output']

os.makedirs(OUTPUT_FOLDER, exist_ok=True)
files = loadFiles()

for f in files:
    print(f)
    data = loadData(f)
    laps, acts, prompts = [], [], []
    for lap in data:
        for act, prompt in getPrompts(data[lap], lap):
            laps.append(lap)
            acts.append(act)
            prompts.append(prompt)
    # Avoid calling the tokenizer/model with an empty batch.
    posts = generatePosts(prompts) if prompts else []
    assert len(posts) == len(prompts), 'expected exactly one post per prompt'
    output = list(zip(laps, acts, prompts, posts))
    # Previously the header row was created and then overwritten by the zip
    # above, so the CSVs were written without it; pass it as columns instead.
    df = pd.DataFrame(output, columns=COLUMNS)
    # Swap only the extension (str.replace('json', ...) hit any occurrence).
    stem = os.path.splitext(os.path.basename(f))[0]
    fn = os.path.join(OUTPUT_FOLDER, stem + '.csv')
    df.to_csv(fn, index=False)
    print('Saved', fn)

../EventIdentification/Events/2018_monaco.json
Saved ./GeneratedPosts/2018_monaco.csv
../EventIdentification/Events/2018_shanghai.json
Saved ./GeneratedPosts/2018_shanghai.csv
../EventIdentification/Events/2019_marina_bay.json
Saved ./GeneratedPosts/2019_marina_bay.csv
../EventIdentification/Events/2019_catalunya.json
Saved ./GeneratedPosts/2019_catalunya.csv
../EventIdentification/Events/2021_baku.json
Saved ./GeneratedPosts/2021_baku.csv
../EventIdentification/Events/2021_americas.json
Saved ./GeneratedPosts/2021_americas.csv
../EventIdentification/Events/2022_spa.json
Saved ./GeneratedPosts/2022_spa.csv
../EventIdentification/Events/2022_jeddah.json
Saved ./GeneratedPosts/2022_jeddah.csv
../EventIdentification/Events/2023_baku.json
Saved ./GeneratedPosts/2023_baku.csv
../EventIdentification/Events/2023_silverstone.json
