### RelationPrompt: Leveraging Prompts to Generate Synthetic Data for Zero-Shot Relation Triplet Extraction

GitHub: https://github.com/declare-lab/RelationPrompt

In [None]:
!git clone https://github.com/declare-lab/RelationPrompt.git
!cd RelationPrompt && git checkout 8ce3656
!cp -a RelationPrompt/* .
!wget -q -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/zero_rte_data.zip
!wget -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/model_fewrel_unseen_10_seed_0.tar
!tar -xf model_fewrel_unseen_10_seed_0.tar
# !wget -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/model_wiki_unseen_10_seed_0.tar
!unzip -nq zero_rte_data.zip
!pip install -q -r requirements.txt

fatal: destination path 'RelationPrompt' already exists and is not an empty directory.
HEAD is now at 8ce3656 Upgrade torch version 1.9.0 -> 1.10.0
File ‘model_fewrel_unseen_10_seed_0.tar’ already there; not retrieving.



In [None]:
#@title Data Parameters
# The `#@param` annotations below are Colab form fields; keep them on the assignment lines.
data_name = "fewrel" #@param ["fewrel", "wiki"]
num_unseen_labels = 10 #@param [5,10,15]
random_seed = 0 #@param [0,1,2,3,4]
data_limit = 5000 #@param {type:"number"}
# Directory holding the train/dev/test split selected by the parameters above.
data_dir = (
    "outputs/data/splits/zero_rte/"
    f"{data_name}/unseen_{num_unseen_labels}_seed_{random_seed}"
)
print({"data_dir": data_dir})

{'data_dir': 'outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0'}


In [None]:
# Data Setup
import json
import random
from pathlib import Path
from wrapper import Generator, Extractor, Dataset

def truncate_data(path:str, limit:int, path_out:str, seed:int=0):
    """Save a shuffled subset of a dataset, for a quick demo on Colab.

    Loads the dataset at ``path`` with ``Dataset.load``, shuffles its ``sents``
    list, keeps the first ``limit`` entries and writes the result to
    ``path_out`` with ``data.save``.

    Note: this reseeds the module-level ``random`` generator, so later calls
    to ``random`` functions in the notebook are affected by ``seed`` as well.

    Args:
        path: File to load with ``Dataset.load``.
        limit: Maximum number of sentences to keep (``None`` keeps all).
        path_out: Destination passed to ``data.save``.
        seed: Seed used for the shuffle; defaults to 0, the value that was
            previously hard-coded.

    Raises:
        ValueError: If ``limit`` is negative. A negative slice bound would
            silently drop sentences from the end instead of limiting the size.
    """
    if limit is not None and limit < 0:
        raise ValueError(f"limit must be non-negative, got {limit}")
    data = Dataset.load(path)
    random.seed(seed)
    random.shuffle(data.sents)
    data.sents = data.sents[:limit]
    data.save(path_out)

# Truncated copies are written under short names that the later cells refer to.
path_train, path_dev, path_test = "train.jsonl", "dev.jsonl", "test.jsonl"
# Train keeps `data_limit` sentences; dev and test each keep a tenth of that.
truncate_data(
    path=f"{data_dir}/train.jsonl",
    limit=data_limit,
    path_out=path_train,
)
truncate_data(
    path=f"{data_dir}/dev.jsonl",
    limit=data_limit // 10,
    path_out=path_dev,
)
truncate_data(
    path=f"{data_dir}/test.jsonl",
    limit=data_limit // 10,
    path_out=path_test,
)

In [None]:
# Data Exploration

def explore_data(path: str, num_samples: int = 3):
    """Print the label set and a few randomly chosen sentences of a dataset.

    Loads the dataset at ``path`` with ``Dataset.load``, prints the result of
    ``get_labels()``, then prints the tokens of up to ``num_samples`` randomly
    sampled sentences together with the head, tail and label of each triplet.

    Args:
        path: File to load with ``Dataset.load``.
        num_samples: Number of sentences to show (default 3). Capped at the
            number of sentences available.
    """
    data = Dataset.load(path)
    print("labels:", data.get_labels())
    print()
    # random.sample raises ValueError when k exceeds the population size,
    # so cap k for datasets with fewer sentences than requested.
    k = max(0, min(num_samples, len(data.sents)))
    for s in random.sample(data.sents, k=k):
        print("tokens:", s.tokens)
        for t in s.triplets:
            print("head:", t.head)
            print("tail:", t.tail)
            print("relation:", t.label)
        print()

# Show the labels and a few random sentences of the truncated training file.
explore_data(path=path_train)

labels: ['after a work by', 'applies to jurisdiction', 'architect', 'characters', 'child', 'constellation', 'contains administrative territorial entity', 'country', 'country of citizenship', 'country of origin', 'crosses', 'developer', 'director', 'distributed by', 'father', 'field of work', 'followed by', 'follows', 'genre', 'has part', 'head of government', 'headquarters location', 'heritage designation', 'instance of', 'instrument', 'language of work or name', 'league', 'licensed to broadcast to', 'located in or next to body of water', 'located in the administrative territorial entity', 'located on terrain feature', 'location of formation', 'manufacturer', 'member of', 'military branch', 'military rank', 'mother', 'mountain range', 'mouth of the watercourse', 'movement', 'notable work', 'occupant', 'occupation', 'operator', 'original language of film or TV show', 'part of', 'participant', 'participating team', 'performer', 'place served by transport hub', 'publisher', 'record label'

In [None]:
# Use Pretrained Model for Generation
# NOTE(review): save_dir presumably points at the generator extracted from the
# model archive downloaded in the setup cell — confirm.
model = Generator(
    load_dir="gpt2",
    save_dir="outputs/wrapper/fewrel/unseen_10_seed_0/generator",
)
# Generate sentences for two example labels, then print a few of them.
model.generate(
    labels=["location", "religion"],
    path_out="synthetic.jsonl",
)
explore_data(path="synthetic.jsonl")

labels: ['location', 'religion']

tokens: ['In', '2007', ',', 'he', 'joined', 'a', 'group', 'of', 'artists', 'known', 'as', 'the', 'Moth', 'Boys', ',', 'an', 'annual', 'neo', '-', 'pop', 'quartet', 'that', 'plays', 'in', 'several', 'venues', 'around', 'the', 'country', 'in', 'Las', 'Vegas', '.']
head: [12, 13]
tail: [30, 31]
relation: location

tokens: ['There', 'is', 'a', 'section', 'of', 'the', 'town', 'under', '"', 'the', 'Graziano', '"', 'River', ',', 'a', 'channel', 'flowing', 'the', 'river', 'in', 'southwestern', 'Italy', 'from', 'the', 'island', 'of', 'Sardinia', 'to', 'Italy', '.']
head: [10]
tail: [26]
relation: location

tokens: ['In', 'August', '2012', ',', 'the', 'station', 'opened', 'on', 'its', 'regular', 'schedule', 'between', 'Minto', 'Plaza', 'in', 'Osaka', 'and', 'Keito', 'Station', 'in', 'the', 'city', 'of', 'Nara', 'in', 'Japan', '.']
head: [17, 18]
tail: [15]
relation: location



In [None]:
# Use Pretrained Model for Extraction
# NOTE(review): save_dir presumably points at the final extractor extracted from
# the model archive downloaded in the setup cell — confirm.
model = Extractor(
    load_dir="facebook/bart-base",
    save_dir="outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final",
)
# Predict on the truncated test file, then print a few of the predictions.
model.predict(
    path_in=path_test,
    path_out="pred.jsonl",
)
explore_data(path="pred.jsonl")

{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/data', model_name='facebook/bart-base', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=64, grad_accumulation=2, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}


100%|██████████| 8/8 [00:07<00:00,  1.11it/s]

labels: ['', 'competition class', 'location', 'member of political party', 'nominated for', 'operating system', 'original broadcaster', 'owned by', 'position played on team / speciality', 'religion']

tokens: ['The', 'South', 'Bank', 'Show', 'is', 'a', 'television', 'arts', 'magazine', 'show', 'that', 'was', 'produced', 'by', 'ITV', 'between', '1978', 'and', '2010', ',', 'and', 'by', 'Sky', 'Arts', 'from', '2012', '.']
head: [1, 2, 3]
tail: [14]
relation: original broadcaster

tokens: ['Then', 'Senator', 'Neptali', 'Gonzales', ',', 'whom', 'Maceda', 'helped', ',', 'was', 'installed', 'as', 'Senate', 'President', 'from', '1992', '-', '1993', 'and', '1995', '-', '1996', 'succeeded', 'him', '.']
head: [2, 3]
tail: [12, 13]
relation: position played on team / speciality

tokens: ['In', '1908', 'she', 'won', 'the', 'singles', 'title', 'at', 'the', 'Welsh', 'Championships', 'in', 'Newport', 'and', 'successfully', 'defended', 'it', 'in', '1909', '.', 'she', 'also', 'won', 'the', 'Scottish', '




In [None]:
# Full Training
# All artifacts of this run go under one directory derived from the data parameters.
save_dir = f"outputs/wrapper/{data_name}/unseen_{num_unseen_labels}_seed_{random_seed}"
print({"save_dir": save_dir})
model_kwargs = {"batch_size": 32, "grad_accumulation": 4}  # For lower memory on Colab

# Step 1: fit the generator and the extractor on the truncated train/dev files.
generator = Generator(
    load_dir="gpt2",
    save_dir=str(Path(save_dir) / "generator"),
    model_kwargs=model_kwargs,
)
extractor = Extractor(
    load_dir="facebook/bart-base",
    save_dir=str(Path(save_dir) / "extractor"),
    model_kwargs=model_kwargs,
)
generator.fit(path_train, path_dev)
extractor.fit(path_train, path_dev)

# Step 2: generate synthetic sentences for the labels found in the dev and test files.
labels_dev = Dataset.load(path_dev).get_labels()
labels_test = Dataset.load(path_test).get_labels()
path_synthetic = str(Path(save_dir) / "synthetic.jsonl")
generator.generate(labels_dev + labels_test, path_out=path_synthetic)

# Step 3: start from the model directory of the extractor above and fit it
# on the synthetic data.
extractor_final = Extractor(
    load_dir=str(Path(save_dir) / "extractor" / "model"),
    save_dir=str(Path(save_dir) / "extractor_final"),
    model_kwargs=model_kwargs,
)
extractor_final.fit(path_synthetic, path_dev)

# Step 4: predict on the test file and score the predictions against it.
path_pred = str(Path(save_dir) / "pred.jsonl")
extractor_final.predict(path_in=path_test, path_out=path_pred)
results = extractor_final.score(path_pred, path_test)
print(json.dumps(results, indent=2))

{'save_dir': 'outputs/wrapper/fewrel/unseen_10_seed_0'}
{'select_model': RelationGenerator(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/generator/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/generator/data', model_name='gpt2', do_pretrain=False, encoder_name='generate', pipe_name='text-generation', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, block_size=128)}
{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/data', model_name='facebook/bart-base', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}
{'select_mod

100%|██████████| 16/16 [00:09<00:00,  1.72it/s]

{
  "path_pred": "outputs/wrapper/fewrel/unseen_10_seed_0/pred.jsonl",
  "path_gold": "test.jsonl",
  "precision": 0.328,
  "recall": 0.3215686274509804,
  "score": 0.32475247524752476
}



