In [1]:
from SPARQLWrapper import SPARQLWrapper, JSON
import pandas as pd
from collections import Counter
#import pywikibot
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification
from typing import List
import re

In [2]:
# initialize tokenizer, model

tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")

NER_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")

# aggregation strategy: if left undefined, will default to breaking up some entities into subwords
nlp = pipeline('ner',tokenizer = tokenizer,  model = NER_model, aggregation_strategy = "max" )

In [57]:
text = "This is a sentence about spacewoman Adri and NASA and SpaceX and the German Aerospace Center."
entities = nlp(text)

In [58]:
entities

[{'entity_group': 'PER',
  'score': 0.9989241,
  'word': 'Adri',
  'start': 36,
  'end': 40},
 {'entity_group': 'ORG',
  'score': 0.9992776,
  'word': 'NASA',
  'start': 45,
  'end': 49},
 {'entity_group': 'ORG',
  'score': 0.99895793,
  'word': 'SpaceX',
  'start': 54,
  'end': 60},
 {'entity_group': 'ORG',
  'score': 0.99821335,
  'word': 'German Aerospace Center',
  'start': 69,
  'end': 92}]

In [36]:
type(entities[0]['entity_group'])

str

In [59]:
def extract_id_entity(entities: List[dict]) -> List[(tuple)]:
    """
    :param entities: List of dictionaries that contain one dict for each entitiy
    :return: List of ( 'Name', 'wiki id') tuples, for the tuples that exist
    """

    q_ids = []
    sparql = SPARQLWrapper("https://query.wikidata.org/sparql", agent='Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36')
    for ent in entities:
        entity_group = ent['entity_group']
        if entity_group == 'ORG':
            new_query = "SELECT ?item WHERE {?item rdfs:label " + "'" + str(entity_group) + "'" +  "@en}"

            sparql.setQuery( new_query)
            sparql.setReturnFormat(JSON)
            results = sparql.query().convert()
            intermediate = results['results']['bindings']
            if len(intermediate)>0:
                #print(intermediate)
                #print("length of intermediate is " + str(len(intermediate)))

                first_result = intermediate[0]
                url = first_result['item']['value']
                q_id = url[31:]

                tuple = (ent['word'], q_id)
            else:
                raise ValueError( " no Wikidata entry for the entity: " + str(ent))

            q_ids.append(tuple)

    return q_ids



In [60]:
ids = extract_id_entity(entities)

In [61]:
ids

[('NASA', 'Q448480'),
 ('SpaceX', 'Q448480'),
 ('German Aerospace Center', 'Q448480')]

In [62]:
def company_checker(ids: List[tuple]) -> List[dict] :
    sparql = SPARQLWrapper("https://query.wikidata.org/sparql", agent='Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36')


    sparql.setQuery("""
    SELECT
    ?company ?companyLabel ?countryLabel

    WHERE
    {
    ?article schema:inLanguage "en" .
    ?article schema:isPartOf <https://en.wikipedia.org/>.
    ?article schema:about ?company .

    ?company p:P31/ps:P31/wdt:P279* wd:Q11753232.

    OPTIONAL {?company wdt:P17 ?country.}

    SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
    }

    """)
    sparql.setReturnFormat(JSON)
    companies = sparql.query().convert()
    results_comp = companies['results']['bindings']

    country_labels = []

    for id in ids:
        for result in results_comp:
            #print(type(id))


            name = result['companyLabel']['value']

            #print(type(id_name_str))
            #print(type(id_name))
            print("ids: " + id[0] + " name: " +  name)


            if id[0] == name:
                print('success')
                org_label = result['companyLabel']['value']
                #print("company label is " + company_label)

                if 'countryLabel' in result.keys():
                    country_label = result['countryLabel']['value']

                else:
                    country_label = 'Not in Wikidata'
                #company_label = 'not a company'
                #country_label = 'not a company'

                tuple = (org_label,country_label)
                country_labels.append(tuple)
            else:
                continue
    return country_labels


In [63]:
test = company_checker(ids)

ids: NASA name: Rocket Lab
ids: NASA name: SpaceX
ids: NASA name: Arianespace
ids: NASA name: ISC Kosmotras
ids: NASA name: Starsem
ids: NASA name: Mitsubishi Heavy Industries
ids: NASA name: Sea Launch
ids: NASA name: Convair
ids: NASA name: Blue Origin
ids: NASA name: Orbital Sciences Corporation
ids: NASA name: United Launch Alliance
ids: NASA name: International Launch Services
ids: NASA name: Eurockot Launch Services
ids: NASA name: Antrix Corporation
ids: NASA name: COSMOS International
ids: NASA name: China Aerospace Science and Industry Corporation
ids: NASA name: Firefly Aerospace
ids: NASA name: Generation Orbit Launch Services
ids: NASA name: Astra, Inc.
ids: NASA name: Northrop Grumman Innovation Systems
ids: NASA name: PLD Space
ids: NASA name: Vector Space Systems
ids: NASA name: Virgin Orbit
ids: NASA name: LandSpace
ids: NASA name: LinkSpace
ids: NASA name: SpinLaunch
ids: NASA name: Skyrora
ids: NASA name: OneSpace
ids: NASA name: Alba Orbital
ids: NASA name: Skyroot A

In [64]:
test

[('SpaceX', 'United States of America')]