In [None]:
!pip install datasets
!pip install crosslingual-coreference

In [None]:
from crosslingual_coreference import Predictor
from collections import defaultdict
import ast
import re
import torch
import nltk
import json

In [None]:
# Select the torch device once for the notebook: the first CUDA GPU when one is
# visible to torch, otherwise the CPU. A short status line is printed either way.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using GPU." if device.type == "cuda" else "No GPU available, using CPU.")

In [11]:
# Predictors already built, keyed by (language, device index, model name).
# Building a Predictor loads a spaCy pipeline and a coreference model (presumably the
# slowest step of this notebook), so it is done once and reused for every document
# instead of once per call.
_PREDICTOR_CACHE = {}


def _get_predictor(language, device, model_name):
  """Return the cached Predictor for this configuration, building it on first use."""
  key = (language, device, model_name)
  predictor = _PREDICTOR_CACHE.get(key)
  if predictor is None:
    predictor = Predictor(language=language, device=device, model_name=model_name)
    _PREDICTOR_CACHE[key] = predictor
  return predictor


def extract_crossreferences(text, *, language="it_core_news_lg", model_name="minilm"):
  """
    Extracts coreference clusters from a given text and cleans them up.

    @:param text: The text from which to extract coreferences.
    @:param language: spaCy pipeline name passed to Predictor (default: Italian large model).
    @:param model_name: Coreference model name passed to Predictor (default: "minilm").

    @:return coreferences: The predictor output after adjust_coreferences, i.e. a dict
      with "cluster_heads", "clusters" and "span2head".

  """
  # Predictor takes a device index: 0 = first GPU, -1 = CPU. Hard-coding 0 would
  # presumably fail on CPU-only machines, so pick the index from CUDA availability.
  device_index = 0 if torch.cuda.is_available() else -1
  predictor = _get_predictor(language, device_index, model_name)

  clusters = predictor.predict(text)
  # The rewritten text is not used downstream; only heads and clusters are kept.
  clusters.pop("resolved_text", None)

  return adjust_coreferences(clusters)


# Heuristic for pulling capitalised names (optionally "First Last") out of a long
# nominal phrase. NOTE(review): the character classes are ASCII-only, so names that
# start with or contain accented letters (e.g. "Nicolò") are not matched.
_NAME_PATTERN = re.compile(r'[A-Z](?:[a-z]+|[A-Z]+)(?:\s[A-Z][a-z]+)?\b')


def adjust_coreferences(coreferences, *, max_head_length=35):
  """
    Adjusts the given coreferences (avoid long nominal phrases and just extract names).

    - Heads with no uppercase character are dropped together with their clusters.
    - Heads longer than `max_head_length` characters are replaced by every substring
      matched by the name pattern. Each match becomes a new head whose span is the
      match position shifted by the start of the original head span. Every cluster
      of the original head is duplicated with its first mention replaced by that new
      span. A long head with no match is dropped.
    - All other heads and their clusters are kept unchanged.
    A cluster belongs to a head when the cluster's first mention has the head's span.

    @:param coreferences: dict with "cluster_heads" (head text -> [start, end]) and
      "clusters" (list of clusters, each a list of [start, end] mentions).
    @:param max_head_length: Heads longer than this are reduced to the names they contain.

    @:return new_coreferences: dict with "cluster_heads", "clusters" and "span2head"
      ((start, end) tuple -> head text).

  """
  # Index clusters by their first mention once, so each head finds its clusters
  # directly instead of rescanning every cluster; original order is preserved.
  clusters_by_head_span = {}
  for cluster in coreferences["clusters"]:
    if not cluster:
      continue
    clusters_by_head_span.setdefault((cluster[0][0], cluster[0][1]), []).append(cluster)

  new_coreferences = {"cluster_heads": {}, "clusters": [], "span2head": {}}
  for coreference, span in coreferences["cluster_heads"].items():
    # Exclude coreferences that do not contain any uppercase letters (probably no names)
    if not any(char.isupper() for char in coreference):
      continue

    head_clusters = clusters_by_head_span.get((span[0], span[1]), [])

    if len(coreference) > max_head_length:
      # Try to reduce nominal phrases and only keep the names they contain.
      for match in _NAME_PATTERN.finditer(coreference):
        new_coref = match.group()
        new_span = [span[0] + match.start(), span[0] + match.end()]
        new_coreferences["cluster_heads"][new_coref] = new_span
        new_coreferences["span2head"][(new_span[0], new_span[1])] = new_coref
        for cluster in head_clusters:
          new_coreferences["clusters"].append([new_span] + cluster[1:])
    else:
      new_coreferences["cluster_heads"][coreference] = span
      new_coreferences["span2head"][(span[0], span[1])] = coreference
      new_coreferences["clusters"].extend(head_clusters)

  return new_coreferences

def update_data(text, labels, coreferences):

  """
  Split `text` into sentences, prefix coreference heads and re-base entity offsets.

  For every sentence produced by NLTK's Punkt tokenizer:
   1) each cluster mention (every element of a cluster except the first) lying fully
      inside the sentence causes that cluster's head text to be prepended as
      "[head] ", unless the head text already occurs in the (possibly already
      prefixed) sentence;
   2) entities lying fully inside the sentence are kept, with offsets made relative
      to the sentence start and shifted by the number of prepended characters;
   3) the sentence gets label 1 if it keeps at least one entity, 0 otherwise.
  Entities that straddle a sentence boundary are not assigned to any sentence.

  @:param
    text: The text where coreferences need to be added
    labels: Entity dicts with "label", "start_offset" and "end_offset", compared
      against the sentence character spans of `text`.
    coreferences: Dict with "clusters" and "span2head" (as built by
      adjust_coreferences); a falsy value disables head prefixing.

  @return: new_sentences, new_labels, new_entities — parallel lists, one item per sentence.

  """
  # (start, end) character spans of every sentence in `text`.
  sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))

  new_sentences, new_labels, new_entities = [], [], []

  for sent_start, sent_end in sentence_spans:
    sentence = text[sent_start:sent_end]
    prefixed = sentence

    if coreferences:
      for cluster in coreferences["clusters"]:
        for mention in cluster[1:]:
          if not (sent_start <= mention[0] <= mention[1] <= sent_end):
            continue
          head = coreferences["span2head"][(cluster[0][0], cluster[0][1])]
          # Skip heads already present, including ones prepended for earlier mentions.
          if head not in prefixed:
            prefixed = "[" + head + "] " + prefixed

    # Characters added in front of the sentence; entity offsets shift by this much.
    shift = len(prefixed) - len(sentence)
    new_sentences.append(prefixed)

    entity_group = [
        {
            "label": label["label"],
            "start_offset": label["start_offset"] - sent_start + shift,
            "end_offset": label["end_offset"] - sent_start + shift,
        }
        for label in labels
        if sent_start <= label["start_offset"] <= label["end_offset"] <= sent_end
    ]
    new_entities.append(entity_group)
    # Binary target: does this sentence contain at least one entity?
    new_labels.append(1 if entity_group else 0)

  return new_sentences, new_labels, new_entities

def prepare_data(data):

  """
  Turn one document into sentence-level examples with coreference heads added.

  @:param data: A document of the current fold; must provide "words" (the raw text)
    and "labels" (entity dicts with "label", "start_offset", "end_offset").

  @return: new_data: list of {"text", "label", "entities"} dicts, one per sentence.

  """
  text = data["words"]
  # Coreferences are computed before the labels are read, as in the original flow.
  coreferences = extract_crossreferences(text)
  sentences, sentence_labels, entity_groups = update_data(text, data["labels"], coreferences)

  return [
      {"text": sentence, "label": flag, "entities": group}
      for sentence, flag, group in zip(sentences, sentence_labels, entity_groups)
  ]

def main(num_folds=5, input_dir="dataset/baseline", output_dir="dataset/pipeline"):
  """
  Build the sentence-level, coreference-augmented version of every baseline fold.

  For each fold index i in range(num_folds), reads `<input_dir>/foldXX.json` (XX = i
  zero-padded to two digits) and runs prepare_data on every document in its "Data"
  list. It then writes {"Data": [...]} to `<output_dir>/data_sent_coref_fold<i>.json`.

  @:param num_folds: Number of folds to process (default 5, the baseline folds).
  @:param input_dir: Directory containing the baseline fold files.
  @:param output_dir: Directory where the processed folds are written; created if missing.
  """
  import os

  # Create the output folder before the slow coreference step, so a missing
  # directory does not abort the run after a whole fold has been processed.
  os.makedirs(output_dir, exist_ok=True)

  for i in range(num_folds):
    # f"{i:02d}" reproduces the original "fold0" + str(i) names for i < 10.
    input_path = os.path.join(input_dir, f"fold{i:02d}.json")
    with open(input_path, "r", encoding="utf-8") as fold_file:
      fold = json.load(fold_file)["Data"]

    # Divide each document into sentences and apply coreference.
    new_data = []
    for data in fold:
      new_data.extend(prepare_data(data))

    # Store the new fold in a json file.
    output_path = os.path.join(output_dir, f"data_sent_coref_fold{i}.json")
    with open(output_path, "w", encoding="utf-8") as out_file:
      json.dump({"Data": new_data}, out_file, ensure_ascii=False)


# Run when executed as a script or notebook cell (where __name__ is "__main__"),
# but not when this file is imported as a module.
if __name__ == "__main__":
  main()