

In this exercise, we are going to fine-tune a model on a supervised **relation extraction** dataset.

The goal of the model is to predict, given a sentence and the character spans of two entities within the sentence, the relationship between the entities.

For example given the sentence:


**John, who played last night, is Doe's father.**


The model given the sentence and the spans of the entities John a Doe the model will have to predict what is the relation between Dohn and Doe from a set of pre-defined relations in this case the relation is parents (please note that some of the relations are one-way relations)

The dataset we will use is a subset of the [TACRED](https://nlp.stanford.edu/projects/tacred/) dataset, a supervised relation extraction dataset by Stanford University. 


As you realize by now, the straight forward supervised approach is just to take one of the transformers and use the sentence and the entity spans as input. However, in this exercise we will try a different approach using **Question Answering (QA)**.

Instead of just using the entities span and the sentence, we will train a model to answer the following questions "Who are the parents of John?","Who are the children of Doe?". If the question answering model will be able to answer the question succusfully than we will be able to conclude that the relation between the two entites exists.

In general **for each realtion** we will need to come up with template questions: In the example above the template questions corresponding to the parents relation are: 

*   Who are the parents of E1?
*   Who are the children of E2?

where E1 and E2 are the entities.



# Your part

You are required to fine-tune a model for relation extraction using the question answering framework.

Notes:


* In previous lectures we have seen a demo notebook demonstrating how to fine-tune a transformer on SQUAD, **from a technical prespective we are doing the same.**
*   For each one of the seven relations in the dataset you will need to find appropriate questions (please note that the questions must be SQUAD-like, meaning that if the answer exists it must be contained within the sentence in a contiguous way.)
* There are several issues that you will need to consider. Please provide a brief explanation whenever you tackle such issues


## Data

Let's download the data from the web, hosted on Dropbox.

In [32]:
!pip install -q transformers 

In [2]:
import requests, zipfile, io

def download_data():
    url = "https://www.dropbox.com/s/izi2x4sjohpzoot/relation_extraction_dataset.zip?dl=1"
    r = requests.get(url)
    z = zipfile.ZipFile(io.BytesIO(r.content))
    z.extractall()

download_data()

Each row in the dataframe consists of a news article, and a sentence in which a certain relationship was found (just as "invested_in", or "founded_by"). There were some patterns used to gather the data, so it might contain some noise. 

In [3]:
import pandas as pd

df = pd.read_pickle("relation_extraction_dataset.pkl")
df.reset_index(drop=True, inplace=True)
df.head()

Unnamed: 0,end_idx,entities,entity_spans,match,original_article,sentence,start_idx,string_id
0,1024,"[Lilium, Baillie Gifford]","[[3, 9], [151, 166]]",raising $35,Happy Friday!\n\nWe sincerely hope you and you...,"3) Lilium, a German startup that’s making an a...",1013,invested_in
1,1762,"[Facebook ’s, Giphy]","[[92, 102], [148, 153]]",acquisition,Happy Friday!\n\nWe sincerely hope you and you...,"Meanwhile, the UK’s watchdog on Friday announc...",1751,acquired_by
2,2784,"[Global-e, Vitruvian Partners]","[[27, 35], [94, 112]]",raised $60,Happy Friday!\n\nWe sincerely hope you and you...,Israeli e-commerce startup Global-e has raised...,2774,invested_in
3,680,"[Joris Van Der Gucht, Silverfin]","[[0, 19], [35, 44]]",founder,Hg is a leading investor in tax and accounting...,"Joris Van Der Gucht, co-founder at Silverfin c...",673,founded_by
4,2070,"[Tim Vandecasteele, Silverfin]","[[0, 17], [71, 80]]",founder,Hg is a leading investor in tax and accounting...,"Tim Vandecasteele, co-founder added: ""We want ...",2063,founded_by


Let's create 2 dictionaries, one that maps each label to a unique integer, and one that does it the other way around.

In [4]:
id2label = dict()
for idx, label in enumerate(df.string_id.value_counts().index):
  id2label[idx] = label

As we can see, there are 7 labels (7 unique relationships):

In [5]:
id2label

{0: 'founded_by',
 1: 'acquired_by',
 2: 'invested_in',
 3: 'CEO_of',
 4: 'subsidiary_of',
 5: 'partners_with',
 6: 'owned_by'}

In [6]:
label2id = {v:k for k,v in id2label.items()}
label2id

{'CEO_of': 3,
 'acquired_by': 1,
 'founded_by': 0,
 'invested_in': 2,
 'owned_by': 6,
 'partners_with': 5,
 'subsidiary_of': 4}

## Good Luck




In [7]:
model_checkpoint = "distilbert-base-uncased"
batch_size = 16

In [8]:
df[df['string_id']=='founded_by'].sample(10)

Unnamed: 0,end_idx,entities,entity_spans,match,original_article,sentence,start_idx,string_id
7643,6059,"[Keleya, Demodesk]","[[10, 16], [161, 169]]",Founded,Germany has one of the most vital startup scen...,Apps like Keleya might benefit from that Demod...,6052,founded_by
8161,2463,"[DoorDash, Tony Xu]","[[0, 8], [17, 24]]",founder,"Feedzai: Feedzai is a San Mateo, Ca.-based dat...","DoorDash founder Tony Xu, StockX CEO Scott Cut...",2456,founded_by
7630,7366,"[Henrik Torstensson, Lifesum]","[[0, 18], [139, 146]]",founder,"The land of Abba, Volvo, and Ikea is so much m...",Henrik Torstensson is a Swedish entrepreneur a...,7359,founded_by
1489,5202,"[Elizabeth Varley, TechHub]","[[0, 16], [43, 50]]",founder,"With today being International Women's Day, we...",Elizabeth Varley is the founder and CEO of Tec...,5195,founded_by
8427,8531,"[NFX, James Currier]","[[67, 70], [82, 95]]",founder,2What s in a name More than two years ago Fast...,the name of your company is what gets passed b...,8524,founded_by
11300,3072,"[Craig Rosenberg, TOPO]","[[122, 137], [171, 175]]",founder,Acquisition Strengthens #1 Sales Engagement Pl...,When we look across our dataset of high-growth...,3065,founded_by
6971,1346,"[Konrad Feldman, Quantcast]","[[198, 212], [236, 245]]",founder,Powerful AI and Machine Learning Technology De...,"""As champions of a free and open internet, we ...",1339,founded_by
3085,1086,"[Monerium, Sveinn Valfells]","[[15, 23], [43, 58]]",founder,Licensed e money issuer Monerium will support ...,In a statement Monerium co founder and CEO Sve...,1079,founded_by
7046,4139,"[Emma Best, Denial of Secrets]","[[169, 178], [206, 223]]",founder,The report has sparked calls by lawmakers and ...,"""It contains pretty much everything on Gab, in...",4132,founded_by
7697,1426,"[Saurabh Singh, Flickstree]","[[106, 119], [143, 153]]",founder,BENGALURU: Mumbai-based digital video curation...,", it also allows our publisher partners to ear...",1419,founded_by


Creating the dataset:

We choose the setence and not the entire article as the context and generate 2 questions for each row (one for each side of the relationship), doubling the number of rows.

As was told in the office hour we always define thw relation direction as E1 relation E2, even though part of the data relations are in the opposite direction. 

In [9]:
pd.options.mode.chained_assignment = None  # default='warn'

# Dataset with questions for founded_by relationship

who_founded = df[df['string_id']=='founded_by']
who_founded['answers'] = who_founded.apply(lambda row: {'answer_start': [row.entity_spans[1][0]],
                                                      'text': [row.entities[1]]}, axis = 1)
who_founded['context'] = who_founded['sentence']
who_founded['id'] = 'wf' + who_founded.index.map(str)
who_founded['question'] = who_founded.apply(lambda row: 'Who founded ' + row.entities[0] + '?', axis = 1)
who_founded['title'] = who_founded.apply(lambda row: row.entities[0] + ' founded by ' + row.entities[1], axis = 1)

what_was_found = df[df['string_id']=='founded_by']
what_was_found['answers'] = what_was_found.apply(lambda row: {'answer_start': [row.entity_spans[0][0]],
                                                      'text': [row.entities[0]]}, axis = 1)
what_was_found['context'] = what_was_found['sentence']
what_was_found['id'] = 'wwf' + what_was_found.index.map(str)
what_was_found['question'] = what_was_found.apply(lambda row: 'What was founded by ' + row.entities[1] + '?', axis = 1)
what_was_found['title'] = what_was_found.apply(lambda row: row.entities[0] + ' founded by ' + row.entities[1], axis = 1)

frames = [who_founded, what_was_found]
founded_by = pd.concat(frames)
founded_by.drop(['end_idx','entities','entity_spans','match','original_article','sentence','start_idx','string_id'],axis=1,inplace=True)
founded_by.sample(6)

Unnamed: 0,answers,context,id,question,title
429,"{'answer_start': [28], 'text': ['Wasowski Vent...",", Piotr is founder & CEO of Wasowski Ventures,...",wf429,Who founded Piotr?,Piotr founded by Wasowski Ventures
1162,"{'answer_start': [24], 'text': ['Finless Foods']}","In October, 28-year-old Finless Foods co-found...",wwf1162,What was founded by Mike Selden?,Finless Foods founded by Mike Selden
1808,"{'answer_start': [17], 'text': ['Charlie Lee']}",Litecoin founder Charlie Lee and Ethereum co-f...,wf1808,Who founded Litecoin?,Litecoin founded by Charlie Lee
1451,"{'answer_start': [52], 'text': ['cleantech Hol...",Carlota Pi is co-founder and Executive Preside...,wf1451,Who founded Carlota Pi?,Carlota Pi founded by cleantech Holaluz
3362,"{'answer_start': [295], 'text': ['WhizHack Tec...",Private Limited signed an MoU for the establis...,wf3362,Who founded Sanjay Sengupta?,Sanjay Sengupta founded by WhizHack Technologies
9279,"{'answer_start': [52], 'text': ['bastien Rover...",Pazzi was founded by two young French inventor...,wf9279,Who founded Pazzi?,Pazzi founded by bastien Roverso


In [10]:
# Dataset with questions for acquired_by relationship

who_acquired = df[df['string_id']=='acquired_by']
who_acquired['answers'] = who_acquired.apply(lambda row: {'answer_start': [row.entity_spans[1][0]],
                                                      'text': [row.entities[1]]}, axis = 1)
who_acquired['context'] = who_acquired['sentence']
who_acquired['id'] = 'wa' + who_acquired.index.map(str)
who_acquired['question'] = who_acquired.apply(lambda row: 'Who acquired ' + row.entities[0] + '?', axis = 1)
who_acquired['title'] = who_acquired.apply(lambda row: row.entities[0] + ' acquired by ' + row.entities[1], axis = 1)

what_was_acquired = df[df['string_id']=='acquired_by']
what_was_acquired['answers'] = what_was_acquired.apply(lambda row: {'answer_start': [row.entity_spans[0][0]],
                                                      'text': [row.entities[0]]}, axis = 1)
what_was_acquired['context'] = what_was_acquired['sentence']
what_was_acquired['id'] = 'wwa' + what_was_acquired.index.map(str)
what_was_acquired['question'] = what_was_acquired.apply(lambda row: 'What was acquired by ' + row.entities[1] + '?', axis = 1)
what_was_acquired['title'] = what_was_acquired.apply(lambda row: row.entities[0] + ' acquired by ' + row.entities[1], axis = 1)

frames = [who_acquired, what_was_acquired]
acquired_by = pd.concat(frames)
acquired_by.drop(['end_idx','entities','entity_spans','match','original_article','sentence','start_idx','string_id'],axis=1,inplace=True)
acquired_by.sample(6)

Unnamed: 0,answers,context,id,question,title
5614,"{'answer_start': [39], 'text': ['TikTok']}",The real reason Microsoft wants to buy TikTok,wa5614,Who acquired Microsoft?,Microsoft acquired by TikTok
11615,"{'answer_start': [4], 'text': ['Hummingbird Ba...","The Hummingbird Bakery, a London-based America...",wwa11615,What was acquired by Acropolis Capital?,Hummingbird Bakery acquired by Acropolis Capital
5549,"{'answer_start': [43], 'text': ['Spell Securit...",Qualys has acquired the software assets of Spe...,wa5549,Who acquired Qualys?,Qualys acquired by Spell Security
2565,"{'answer_start': [0], 'text': ['Kaseya']}",Kaseya acquires RocketCyber to bring SOC solut...,wwa2565,What was acquired by RocketCyber?,Kaseya acquired by RocketCyber
10104,"{'answer_start': [143], 'text': ['Hino']}","While Hyundai's sales were up, Toyota sold few...",wa10104,Who acquired Hyundai?,Hyundai acquired by Hino
5463,"{'answer_start': [204], 'text': ['Shafqat Isla...","“Over the years, NewsCred has established itse...",wa5463,Who acquired NewsCred?,NewsCred acquired by Shafqat Islam


In [11]:
# Dataset with questions for invested_in relationship

who_invested_in = df[df['string_id']=='invested_in']
who_invested_in['answers'] = who_invested_in.apply(lambda row: {'answer_start': [row.entity_spans[1][0]],
                                                      'text': [row.entities[1]]}, axis = 1)
who_invested_in['context'] = who_invested_in['sentence']
who_invested_in['id'] = 'wii' + who_invested_in.index.map(str)
who_invested_in['question'] = who_invested_in.apply(lambda row: 'Who invested in ' + row.entities[0] + '?', axis = 1)
who_invested_in['title'] = who_invested_in.apply(lambda row: row.entities[0] + ' invested in ' + row.entities[1], axis = 1)

what_was_invested_in = df[df['string_id']=='invested_in']
what_was_invested_in['answers'] = what_was_invested_in.apply(lambda row: {'answer_start': [row.entity_spans[0][0]],
                                                      'text': [row.entities[0]]}, axis = 1)
what_was_invested_in['context'] = what_was_invested_in['sentence']
what_was_invested_in['id'] = 'wwii' + what_was_invested_in.index.map(str)
what_was_invested_in['question'] = what_was_invested_in.apply(lambda row: 'In what did ' + row.entities[1] + ' invested?', axis = 1)
what_was_invested_in['title'] = what_was_invested_in.apply(lambda row: row.entities[0] + ' invested in ' + row.entities[1], axis = 1)

frames = [who_invested_in, what_was_invested_in]
invested_in = pd.concat(frames)
invested_in.drop(['end_idx','entities','entity_spans','match','original_article','sentence','start_idx','string_id'],axis=1,inplace=True)
invested_in.sample(6)

Unnamed: 0,answers,context,id,question,title
586,"{'answer_start': [235], 'text': ['Highland Eur...",Today Modulr a digital alternative to commerci...,wii586,Who invested in Modulr?,Modulr invested in Highland Europe
50,"{'answer_start': [1], 'text': ['ActionIQ']}","""ActionIQ, the leading customer data platform ...",wwii50,In what did March Capital Partners invested?,ActionIQ invested in March Capital Partners
3982,"{'answer_start': [61], 'text': ['Blossom Capit...","According to data from Beauhurst, however, inv...",wii3982,Who invested in Beauhurst?,Beauhurst invested in Blossom Capital
6840,"{'answer_start': [123], 'text': ['Green Invoic...",- Israeli private equity fund Fortissimo Capit...,wii6840,Who invested in Fortissimo Capital?,Fortissimo Capital invested in Green Invoice
2151,"{'answer_start': [109], 'text': ['CynLr']}",China-based Elite Technology raised $14 millio...,wii2151,Who invested in Elite Technology?,Elite Technology invested in CynLr
2351,"{'answer_start': [41], 'text': ['Founders Fund']}","Notably, he was instrumental in enabling Found...",wwii2351,In what did Postmates invested?,Founders Fund invested in Postmates


In [12]:
# Dataset with questions for CEO_of relationship

ceo_of_what = df[df['string_id']=='CEO_of']
ceo_of_what['answers'] = ceo_of_what.apply(lambda row: {'answer_start': [row.entity_spans[1][0]],
                                                      'text': [row.entities[1]]}, axis = 1)
ceo_of_what['context'] = ceo_of_what['sentence']
ceo_of_what['id'] = 'cow' + ceo_of_what.index.map(str)
ceo_of_what['question'] = ceo_of_what.apply(lambda row: 'In which company ' + row.entities[0] + ' is the CEO?', axis = 1)
ceo_of_what['title'] = ceo_of_what.apply(lambda row: row.entities[0] + ' CEO of ' + row.entities[1], axis = 1)

who_is_ceo = df[df['string_id']=='invested_in']
who_is_ceo['answers'] = who_is_ceo.apply(lambda row: {'answer_start': [row.entity_spans[0][0]],
                                                      'text': [row.entities[0]]}, axis = 1)
who_is_ceo['context'] = who_is_ceo['sentence']
who_is_ceo['id'] = 'wic' + who_is_ceo.index.map(str)
who_is_ceo['question'] = who_is_ceo.apply(lambda row: 'Who is the CEO of ' + row.entities[1] + '?', axis = 1)
who_is_ceo['title'] = who_is_ceo.apply(lambda row: row.entities[0] + ' CEO of ' + row.entities[1], axis = 1)

frames = [ceo_of_what, who_is_ceo]
CEO_of = pd.concat(frames)
CEO_of.drop(['end_idx','entities','entity_spans','match','original_article','sentence','start_idx','string_id'],axis=1,inplace=True)
CEO_of.sample(6)

Unnamed: 0,answers,context,id,question,title
3434,"{'answer_start': [135], 'text': ['nTopology']}","""nTopology is excited to announce our first re...",cow3434,In which company Bradley Rothenberg is the CEO?,Bradley Rothenberg CEO of nTopology
6488,"{'answer_start': [0], 'text': ['OneTrust']}","OneTrust, a provider of privacy, security and ...",wic6488,Who is the CEO of TCV?,OneTrust CEO of TCV
4002,"{'answer_start': [18], 'text': ['Jim']}","In 2018 and 2019, Jim invested in and led Rece...",wic4002,Who is the CEO of Receptra Naturals?,Jim CEO of Receptra Naturals
9845,"{'answer_start': [28], 'text': ['Company']}","applicationsBy end of 2020, Company expects to...",wic9845,Who is the CEO of Ingredion Incorporated?,Company CEO of Ingredion Incorporated
8665,"{'answer_start': [0], 'text': ['Memphis Meats']}",Memphis Meats has raised over $20 million in f...,wic8665,Who is the CEO of Crunchbase?,Memphis Meats CEO of Crunchbase
7596,"{'answer_start': [0], 'text': ['Foundation Cap...",Foundation Capital led its $4.3 million seed r...,wic7596,Who is the CEO of Y Combinator?,Foundation Capital CEO of Y Combinator


In [13]:
# Dataset with questions for subsidiary_of relationship

subsidiary_of_what = df[df['string_id']=='subsidiary_of']
subsidiary_of_what['answers'] = subsidiary_of_what.apply(lambda row: {'answer_start': [row.entity_spans[1][0]],
                                                      'text': [row.entities[1]]}, axis = 1)
subsidiary_of_what['context'] = subsidiary_of_what['sentence']
subsidiary_of_what['id'] = 'sow' + subsidiary_of_what.index.map(str)
subsidiary_of_what['question'] = subsidiary_of_what.apply(lambda row: row.entities[0] + ' is a subsidiary of which company?', axis = 1)
subsidiary_of_what['title'] = subsidiary_of_what.apply(lambda row: row.entities[0] + ' subsidiary of ' + row.entities[1], axis = 1)

parent_company_of = df[df['string_id']=='subsidiary_of']
parent_company_of['answers'] = parent_company_of.apply(lambda row: {'answer_start': [row.entity_spans[0][0]],
                                                      'text': [row.entities[0]]}, axis = 1)
parent_company_of['context'] = parent_company_of['sentence']
parent_company_of['id'] = 'pco' + parent_company_of.index.map(str)
parent_company_of['question'] = parent_company_of.apply(lambda row: 'Of which comapny is ' + row.entities[1] + ' a parent company?', axis = 1)
parent_company_of['title'] = parent_company_of.apply(lambda row: row.entities[0] + ' subsidiary of ' + row.entities[1], axis = 1)

frames = [subsidiary_of_what, parent_company_of]
subsidiary_of = pd.concat(frames)
subsidiary_of.drop(['end_idx','entities','entity_spans','match','original_article','sentence','start_idx','string_id'],axis=1,inplace=True)
subsidiary_of.sample(6)

Unnamed: 0,answers,context,id,question,title
12002,"{'answer_start': [53], 'text': ['Tel Aviv-base...","It’s worth noting that Metro Skyway, a subsidi...",sow12002,Metro Skyway is a subsidiary of which company?,Metro Skyway subsidiary of Tel Aviv-based Urba...
5339,"{'answer_start': [230], 'text': ['Microsoft']}","I mean, we've sort of joked for years about Mi...",sow5339,AAPL is a subsidiary of which company?,AAPL subsidiary of Microsoft
1034,"{'answer_start': [151], 'text': ['BP']}",When it comes to bringing electricity to UK fa...,sow1034,Zennor is a subsidiary of which company?,Zennor subsidiary of BP
9007,"{'answer_start': [76], 'text': ['Mosa Meat']}",It was announced this morning that Netherlands...,pco9007,Of which comapny is Blue Horizon Ventures a pa...,Mosa Meat subsidiary of Blue Horizon Ventures
10122,"{'answer_start': [0], 'text': ['Capital One']}",Capital One deserves credit for expanding its ...,pco10122,Of which comapny is Spark Miles for Business a...,Capital One subsidiary of Spark Miles for Busi...
1789,"{'answer_start': [0], 'text': ['Kosala Hemacha...",Kosala Hemachandra of MyEtherWallet explained ...,pco1789,Of which comapny is UX a parent company?,Kosala Hemachandra subsidiary of UX


In [14]:
# Dataset with questions for partners_with relationship

partners_with_1 = df[df['string_id']=='partners_with']
partners_with_1['answers'] = partners_with_1.apply(lambda row: {'answer_start': [row.entity_spans[1][0]],
                                                      'text': [row.entities[1]]}, axis = 1)
partners_with_1['context'] = partners_with_1['sentence']
partners_with_1['id'] = 'pwo' + partners_with_1.index.map(str)
partners_with_1['question'] = partners_with_1.apply(lambda row: 'Which company is a partner of ' + row.entities[0] + '?', axis = 1)
partners_with_1['title'] = partners_with_1.apply(lambda row: row.entities[0] + ' partners with ' + row.entities[1], axis = 1)

partners_with_2 = df[df['string_id']=='partners_with']
partners_with_2['answers'] = partners_with_2.apply(lambda row: {'answer_start': [row.entity_spans[0][0]],
                                                      'text': [row.entities[0]]}, axis = 1)
partners_with_2['context'] = partners_with_2['sentence']
partners_with_2['id'] = 'pw2' + partners_with_2.index.map(str)
partners_with_2['question'] = partners_with_2.apply(lambda row: 'Which company is a partner of ' + row.entities[1] + '?', axis = 1)
partners_with_2['title'] = partners_with_2.apply(lambda row: row.entities[0] + ' partners with ' + row.entities[1], axis = 1)

frames = [partners_with_1, partners_with_2]
partners_with = pd.concat(frames)
partners_with.drop(['end_idx','entities','entity_spans','match','original_article','sentence','start_idx','string_id'],axis=1,inplace=True)
partners_with.sample(6)

Unnamed: 0,answers,context,id,question,title
8242,"{'answer_start': [45], 'text': ['Stripe']}",The CRM business unveiled a partnership with S...,pwo8242,Which company is a partner of CRM?,CRM partners with Stripe
8177,"{'answer_start': [100], 'text': ['Ooredoo Kuwa...","Meanwhile, mobility solutions firm Comviva has...",pwo8177,Which company is a partner of Comviva?,Comviva partners with Ooredoo Kuwait
6333,"{'answer_start': [98], 'text': ['Uniqorn']}",While work on its SDK has still remained of pr...,pwo6333,Which company is a partner of XRApplied?,XRApplied partners with Uniqorn
9419,"{'answer_start': [56], 'text': ['PayPal']}","The Data Security Council of India, in partner...",pwo9419,Which company is a partner of Data Security Co...,Data Security Council of India partners with P...
2473,"{'answer_start': [67], 'text': ['IBM']}",Samsung Electronics today announced a new plan...,pwo2473,Which company is a partner of Samsung Electron...,Samsung Electronics partners with IBM
5438,"{'answer_start': [64], 'text': ['Iot Evolution']}",Huawei and Trustonic Expand Partnership with P...,pwo5438,Which company is a partner of Huawei?,Huawei partners with Iot Evolution


In [15]:
# Dataset with questions for owned_by relationship

who_owns = df[df['string_id']=='owned_by']
who_owns['answers'] = who_owns.apply(lambda row: {'answer_start': [row.entity_spans[1][0]],
                                                      'text': [row.entities[1]]}, axis = 1)
who_owns['context'] = who_owns['sentence']
who_owns['id'] = 'wo' + who_owns.index.map(str)
who_owns['question'] = who_owns.apply(lambda row: 'Who owns ' + row.entities[0] + '?', axis = 1)
who_owns['title'] = who_owns.apply(lambda row: row.entities[0] + ' owned by ' + row.entities[1], axis = 1)

ownes_what = df[df['string_id']=='owned_by']
ownes_what['answers'] = ownes_what.apply(lambda row: {'answer_start': [row.entity_spans[0][0]],
                                                      'text': [row.entities[0]]}, axis = 1)
ownes_what['context'] = ownes_what['sentence']
ownes_what['id'] = 'ow' + ownes_what.index.map(str)
ownes_what['question'] = ownes_what.apply(lambda row: 'Which company does ' + row.entities[1] + ' own?', axis = 1)
ownes_what['title'] = ownes_what.apply(lambda row: row.entities[0] + ' owned by ' + row.entities[1], axis = 1)

frames = [who_owns, ownes_what]
owned_by = pd.concat(frames)
owned_by.drop(['end_idx','entities','entity_spans','match','original_article','sentence','start_idx','string_id'],axis=1,inplace=True)
owned_by.sample(6)

Unnamed: 0,answers,context,id,question,title
7340,"{'answer_start': [0], 'text': ['Thomson-Reuter...",Thomson-Reuters is owned by the powerful Canad...,ow7340,Which company does Canadian Thomson own?,Thomson-Reuters owned by Canadian Thomson
6729,"{'answer_start': [57], 'text': ['Nielsen']}","SuperData Research, another market analyst fir...",wo6729,Who owns SuperData Research?,SuperData Research owned by Nielsen
5649,"{'answer_start': [69], 'text': ['TikTok']}",Recall that Trump's executive order has prohib...,ow5649,Which company does ByteDance own?,TikTok owned by ByteDance
3008,"{'answer_start': [130], 'text': ['Bakkt']}","Venuto likes Intercontinental Exchange (ICE), ...",wo3008,Who owns New York Stock Exchange?,New York Stock Exchange owned by Bakkt
9689,"{'answer_start': [0], 'text': ['Uber Ele']}",Uber Ele me owned by China s Alibaba and priva...,ow9689,Which company does Alibaba own?,Uber Ele owned by Alibaba
5422,"{'answer_start': [161], 'text': ['Google']}",Google allows competitors to bid on trademarke...,wo5422,Who owns Google?,Google owned by Google


In [37]:
# Combine all and split to train and test

!pip install -q datasets

from sklearn.model_selection import train_test_split
from datasets import Dataset

frames = [founded_by, acquired_by, invested_in, CEO_of, subsidiary_of, partners_with, owned_by]
full_df = pd.concat(frames)
df_train, df_test = train_test_split(full_df, test_size=0.2)
train_dataset = Dataset.from_pandas(df_train)
test_dataset = Dataset.from_pandas(df_test)
train_dataset = train_dataset.remove_columns('__index_level_0__')
test_dataset = test_dataset.remove_columns('__index_level_0__')

In [42]:
# Preprocessing the training data

def prepare_train_features(examples):
    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [43]:
# Perform tokenizing

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
max_length = 512 # The maximum allowed length of a feature (question and context)
doc_stride = 128 # The authorized overlap between two part of the context when splitting it is needed.
tokenized_train = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_dataset.column_names)
tokenized_val = test_dataset.map(prepare_train_features, batched=True, remove_columns=test_dataset.column_names)

loading configuration file https://huggingface.co/distilbert-base-uncased/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.d423bdf2f58dc8b77d5f5d18028d7ae4a72dcfd8f468e81fe979ada957a8c361
Model config DistilBertConfig {
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.8.2",
  "vocab_size": 30522
}

loading file https://huggingface.co/distilbert-base-uncased/resolve/main/vocab.txt from cache at /root/.cache/huggingface/transformers/0e1bbfda7f63a99bb52e3915dcf10c3c92122b827d92eb2d34ce94ee79ba486c.d789d64ebfe

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




In [46]:
# We will finetune a pretrained model

from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer, default_data_collator
import torch

model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

loading configuration file https://huggingface.co/distilbert-base-uncased/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.d423bdf2f58dc8b77d5f5d18028d7ae4a72dcfd8f468e81fe979ada957a8c361
Model config DistilBertConfig {
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.8.2",
  "vocab_size": 30522
}

loading weights file https://huggingface.co/distilbert-base-uncased/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/9c169103d7e5a73936dd2b627e42851bec0831212b677c637033ee4bce9a

In [47]:
# Additional training with our data (3 epochs)

args = TrainingArguments(
    f"test-questions",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

data_collator = default_data_collator

trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
***** Running training *****
  Num examples = 19876
  Num Epochs = 3
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 3729


Epoch,Training Loss,Validation Loss
1,0.5464,0.375702
2,0.2951,0.321313
3,0.1793,0.356019


Saving model checkpoint to test-questions/checkpoint-500
Configuration saved in test-questions/checkpoint-500/config.json
Model weights saved in test-questions/checkpoint-500/pytorch_model.bin
tokenizer config file saved in test-questions/checkpoint-500/tokenizer_config.json
Special tokens file saved in test-questions/checkpoint-500/special_tokens_map.json
Saving model checkpoint to test-questions/checkpoint-1000
Configuration saved in test-questions/checkpoint-1000/config.json
Model weights saved in test-questions/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in test-questions/checkpoint-1000/tokenizer_config.json
Special tokens file saved in test-questions/checkpoint-1000/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 4970
  Batch size = 16
Saving model checkpoint to test-questions/checkpoint-1500
Configuration saved in test-questions/checkpoint-1500/config.json
Model weights saved in test-questions/checkpoint-1500/pytorch_model.bin
tokenizer 

TrainOutput(global_step=3729, training_loss=0.4066481606894442, metrics={'train_runtime': 3547.0158, 'train_samples_per_second': 16.811, 'train_steps_per_second': 1.051, 'total_flos': 1.2156449330700288e+16, 'train_loss': 0.4066481606894442, 'epoch': 3.0})

In [82]:
raw_predictions = trainer.predict(tokenized_val)

***** Running Prediction *****
  Num examples = 4970
  Batch size = 16


In [97]:
# Get the best predictions for the val dataset

import numpy as np

n_best_size = 1 # want the top prediction
max_answer_length = 40
answers = []

for i in range (len(df_test)):
  start_logits = raw_predictions.predictions[0][i]
  end_logits = raw_predictions.predictions[1][i]
  offset_mapping = validation_features[i]["offset_mapping"]
  # The first feature comes from the first example. For the more general case, we will need to be match the example_id to
  # an example index
  context = test_dataset[i]["context"]

  # Gather the indices the best start/end logits:
  start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
  end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
  valid_answers = []
  for start_index in start_indexes:
      for end_index in end_indexes:
          # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
          # to part of the input_ids that are not in the context.
          if (
              start_index >= len(offset_mapping)
              or end_index >= len(offset_mapping)
              or offset_mapping[start_index] is None
              or offset_mapping[end_index] is None
          ):
              continue
          # Don't consider answers with a length that is either < 0 or > max_answer_length.
          if end_index < start_index or end_index - start_index + 1 > max_answer_length:
              continue
          if start_index <= end_index: # We need to refine that test to check the answer is inside the context
              start_char = offset_mapping[start_index][0]
              end_char = offset_mapping[end_index][1]
              valid_answers.append(
                  {
                      "score": start_logits[start_index] + end_logits[end_index],
                      "text": context[start_char: end_char]
                  }
              )

  valid_answers = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[:n_best_size]
  answers.append(valid_answers)

In [134]:
# Looking at some random questions and predictions

import random

random_rows = random.sample(range(len(df_test)), 10)

for i in random_rows:
  print("\nQuestion:")
  print(df_test.iloc[i]['question'])
  print("Answer:")
  print(df_test.iloc[i]['answers']['text'][0])
  print("Predicted:")
  print(answers[i])



Question:
Which company is a partner of DisruptAD?
Answer:
MENA
Predicted:
[{'score': 11.063078, 'text': 'MENA'}]

Question:
Which company does Jaak own?
Answer:
Dot Blockchain Media
Predicted:
[{'score': 16.93885, 'text': 'um that can b'}]

Question:
In which company Evan Gappelberg is the CEO?
Answer:
Nextech AR
Predicted:
[{'score': 15.248435, 'text': 'hange u'}]

Question:
In what did Third Swedish National Pension Fund invested?
Answer:
Readly
Predicted:
[{'score': 13.873442, 'text': 'In June '}]

Question:
Who acquired Uber?
Answer:
Postmates
Predicted:
[{'score': 14.369099, 'text': 'Meanwhile, according to Axios'}]

Question:
Who invested in CaptivateIQ?
Answer:
Accel
Predicted:
[{'score': 17.347366, 'text': 'Accel'}]

Question:
Who is the CEO of Husqvarna Group?
Answer:
Soil Scout
Predicted:
[{'score': 17.897842, 'text': 'T se'}]

Question:
In what did TikTok invested?
Answer:
Tencent
Predicted:
[{'score': 13.327038, 'text': 'ByteDance, briefly'}]

Question:
Who is the CEO of 

Not perfect, but we do get some correct answers