Copyright 2023 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

```
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

```

# Decoder of the paper
# [Coreference Resolution through a seq2seq Transition-Based System](https://arxiv.org/abs/2211.12142)



```
@misc{https://doi.org/10.48550/arxiv.2211.12142,
  url = {https://arxiv.org/abs/2211.12142},
  author = {Bohnet, Bernd and Alberti, Chris and Collins, Michael},
  title = {Coreference Resolution through a seq2seq Transition-Based System},
  publisher = {TACL},
  year = {2023},
}

```



# Requirements


1.   Start an mT5 server with the checkpoint [Coref-mT5-XXL model](https://console.cloud.google.com/storage/browser/gresearch/correference_seq2seq/checkpoint_1140000)
2.   Send requests to the server and retrieve the results, for instance via RPC calls. This part is not included in this Colab.

**Important !!**

**The python predictor_fn needs to be implemented. This code is missing.**
**This method calls a prediction server.**

# Imports and libs

In [None]:
import functools
import multiprocessing
import time

import tensorflow as tf

import nltk
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

In [None]:
!git clone https://github.com/huggingface/transformers.git
!pip install ./transformers
!pip install sentencepiece

Cloning into 'transformers'...
remote: Enumerating objects: 132356, done.[K
remote: Counting objects: 100% (698/698), done.[K
remote: Compressing objects: 100% (238/238), done.[K
remote: Total 132356 (delta 434), reused 623 (delta 391), pack-reused 131658[K
Receiving objects: 100% (132356/132356), 127.53 MiB | 17.40 MiB/s, done.
Resolving deltas: 100% (99808/99808), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing ./transformers
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m49.7 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingfa

In [None]:
# Load the mT5 SentencePiece tokenizer. State.encode uses this global to
# measure the length of the model input in tokens.
from transformers import T5Tokenizer
tokenizer_mt5 =  T5Tokenizer.from_pretrained('google/mt5-xxl')
# Sanity check: encode a short string (the resulting ids are shown below).
tokenizer_mt5.encode('Hello world')

[30273, 4836, 1]

In [None]:
# NEEDS IMPLEMENTATION depending on the infrastructure
def predictor_fn(batches):
  """Send one batch of model inputs to the mT5 prediction server.

  This is a placeholder and must be implemented for the serving
  infrastructure in use (for instance as an RPC call to a server that runs
  the Coref-mT5-XXL checkpoint).

  predict_coreferences calls this function once per batch through
  ThreadPool.map. The argument is therefore a single batch: a list of model
  input strings, as built by create_next_batch. The demo uses a single batch
  holding a single document. Several batches can be processed in parallel
  when decoding larger document sets.

  Args:
    batches: one batch, i.e. a list of model input strings.

  Returns:
    One server response. extract_result_string reads
    `outputs['output_text']` and `outputs['scores']` from it.

  Raises:
    NotImplementedError: always, until an implementation is provided.
  """
  reason = ('method "predictor_fn" needs to be implement depending on the '
            'used infrastructure')
  raise NotImplementedError(reason)

In [None]:

def predict_coreferences(batches, threads_to_use=1, predictor=None):
  """Predict coreferences of the focus part (e.g. one sentence) of each input.

  Every batch is handed, as a whole, to one predictor call; the calls run
  on a thread pool.

  Args:
    batches: list of batches, each a list of model input strings.
    threads_to_use: number of worker threads. Values below 1 (e.g.
      len(batches) for an empty list) are raised to 1. None lets ThreadPool
      pick its default.
    predictor: callable applied to each batch; defaults to predictor_fn.

  Returns:
    A list with one predictor result per batch, in the order of `batches`.
  """
  # `import multiprocessing` alone does not load the `pool` submodule, so
  # import it explicitly instead of relying on another module having done so.
  from multiprocessing.pool import ThreadPool

  if predictor is None:
    predictor = predictor_fn

  # ThreadPool raises ValueError for fewer than one worker.
  processes = None if threads_to_use is None else max(1, threads_to_use)
  with ThreadPool(processes) as pool:
    return pool.map(predictor, batches)

def extract_result_string(predictions):
  """Decode the output texts of all prediction responses into strings.

  Args:
    predictions: server responses, one per batch (as returned by
      predict_coreferences). Each must expose `outputs['output_text']` and
      `outputs['scores']`; presumably TensorProtos, since they are converted
      with tf.make_ndarray. TODO confirm against the serving setup.

  Returns:
    A flat list of UTF-8 decoded output strings, in response order.
  """
  decoded = []
  for response in predictions:
    text_rows = tf.make_ndarray(response.outputs['output_text'])
    score_rows = tf.make_ndarray(response.outputs['scores'])
    # Scores are unused, but pairing with them keeps the original behaviour
    # of stopping at the shorter of the two arrays.
    decoded.extend(row[0].decode('utf-8')
                   for row, _ in zip(text_rows, score_rows))
  return decoded

# Extract coreferences

In [None]:
# @title helper
def normalize_speaker(speaker_in):
  """Turn a speaker name into a single '_'-wrapped token.

  Spaces become '_', surrounding whitespace is stripped, and a leading and
  trailing '_' are added when missing. The placeholders '-' and '__' map
  to '_'.
  """
  if speaker_in in ('-', '__'):
    return '_'

  name = speaker_in.replace(' ', '_').strip()
  if name[:1] != '_':
    name = '_' + name
  # name is non-empty here, so indexing the last character is safe.
  return name if name[-1] == '_' else name + '_'


def match_mention_state(m, inputs, maps, position=None, debug=False, start_index=0):
  """Locate a predicted mention in the annotated model input.

  A predicted mention is a token list such as ['It', '##', 'is', 'named'].
  The tokens before '##' are the mention itself. The tokens after it are
  right context, used only to disambiguate where the mention occurs.

  Args:
    m: mention tokens, normally containing the '##' separator.
    inputs: annotated input tokens (State.annotation_full).
    maps: token ids parallel to `inputs`; -1 marks non-document tokens
      (cluster markers, speaker names, separators).
    position: unused; kept for interface compatibility.
    debug: print a note when the mention matches more than once.
    start_index: first position in `inputs` to search from.

  Returns:
    One (first_token_id, last_token_id) tuple per match. Returns [] when
    nothing matched, the mention was empty, or a match mapped to (-1, -1).
  """
  if '##' in m:
    index_num = m.index('##')  # number of mention tokens before '##'
  else:
    if m and not m[0].startswith('['):
      print('get_chains::error ## not in split', m)
    index_num = len(m)

  # Only positions before the first ']]' (or the '**' end-of-focus marker)
  # may start a match.
  if ']]' in inputs:
    end_index = inputs.index(']]')
  elif '**' in inputs:
    end_index = inputs.index('**')
  else:
    end_index = len(inputs)

  # Tokens to match: mention + right context without '##', cut right after
  # the first '**'.
  m_clean = []
  for x in m:
    if x != '##':
      m_clean.append(x)
    if x == '**':
      break

  # Nothing to match, or no mention token before '##': previously this
  # raised IndexError or produced a span ending before it starts.
  if not m_clean or index_num < 1:
    print('empty mention in match_mention', m)
    return []

  span = len(m_clean)
  indices = []
  for i in range(start_index, end_index):
    if inputs[i] != m_clean[0] or inputs[i:i + span] != m_clean:
      continue
    last = i + index_num - 1  # position of the mention's last token
    indices.append((maps[i], maps[last]))

    if maps[last] == -1:
      print('index negative', maps[i:], ' index_num', index_num)
      print('index negative', inputs[i:], ' index_num', index_num)
      print(f'i {i} maps_index {i}')

  if not indices:
    print('none found match_mention', m)
    print('inputs', inputs)
    return []
  if len(indices) > 1 and debug:
    print('match_mention: too many ', m, indices, 'm_clean - use both')

  if (-1, -1) in indices:
    print('error for ', m, indices)
    return []

  return indices

def match_link_state(link, inputs, maps, cluster_name_to_cluster,
                     debug=True, node_wise=True):
  """Resolve a predicted link to token-id spans of the annotated input.

  Args:
    link: list of one or two mention strings (as produced by
      get_mentions_for_link_state).
      - The first is the mention in the focus.
      - The optional second is its antecedent: either a mention string or
        a cluster reference such as '[1'.
    inputs: annotated input tokens (State.annotation_full).
    maps: token ids parallel to `inputs` (-1 for non-document tokens).
    cluster_name_to_cluster: maps '[<name>' to that cluster's list of spans.
    debug: print intermediate matches.
    node_wise: accept single-mention links (mention detection only).

  Returns:
    - [] if the link cannot be resolved.
    - For a single mention: a one-element list holding the list of all
      matching spans.
    - For a pair: [[antecedent_span, mention_span]].
  """
  link_mentions = [m.split(' ') for m in link]

  if len(link_mentions) == 1 and node_wise:
    try:
      index_m0 = match_mention_state(link_mentions[0], inputs, maps,
                                     position=None)
    except Exception as e:  # best effort: skip an unparsable prediction
      print(str(e))
      return []
    # Return [] rather than [[]] when nothing matched. Otherwise the caller
    # treats the link as found and opens an empty cluster.
    return [index_m0] if index_m0 else []

  if len(link_mentions) < 2:
    print('link has only one mention - skipping mention', link)
    return []

  m0 = link_mentions[0]
  m1 = link_mentions[1]

  if debug:
    print('match_link', m0, m1)

  # An antecedent like '[3' refers to an existing cluster; link to that
  # cluster's most recently added span.
  if m1[0].startswith('['):
    cluster = cluster_name_to_cluster.get(m1[0], None)
    if cluster is None:
      print('cluster does not exists')
      return []
    index_m1 = [cluster[-1]]
  else:
    try:
      index_m1 = match_mention_state(m1, inputs, maps, position=None)
    except Exception as e:  # same best-effort handling as for m0
      print('error', str(e))
      index_m1 = []

  if debug:
    print(index_m1, 'match', m1)
  if len(index_m1) > 1:
    print('index_m1', index_m1)

  try:
    index_m0 = match_mention_state(m0, inputs, maps, position=None)
  except Exception as e:
    print('error', str(e))
    index_m0 = []

  if debug:
    print(index_m0, 'match', m0)
  if len(index_m0) > 1:
    print('index_m0', index_m0)

  if not index_m1 or not index_m0:
    return []

  # With several matches, the last (right-most in the input) one is used.
  return [[index_m1[-1], index_m0[-1]]]


def get_mentions_for_link_state(link, node_wise):
  """Split one predicted link 'mention -> antecedent' into its parts.

  Args:
    link: a single link string from the model output.
    node_wise: whether a lone mention (no '->') is acceptable.

  Returns:
    [mention] for a lone mention when node_wise is set. [] for a lone
    mention otherwise. [mention, antecedent] when the link has two or more
    parts; extra parts are ignored with a warning. All parts are stripped.
  """
  parts = link.split('->')

  if len(parts) == 1:
    if node_wise:
      return [parts[0].strip()]
    print('link has only one mention - skipping mention', link)
    return []

  if len(parts) > 2:
    print('link has too many mentions - using first two.', link)
  return [part.strip() for part in parts[:2]]

In [None]:
# use mt5 and large context

class State(object):
  """Incremental decoding state for one document.

  The document is decoded one sentence (the "focus") at a time. The model
  input built by get_input_annotation is

    'coref:' <genre> <annotated context> '|' <focus> '**' [<right context>]

  The context holds the already processed sentences, with the predicted
  clusters marked as '[<name> ... ]'. The model output for the focus is
  passed back to extend. extend merges the links into the clusters,
  updates the context markers and advances to the next sentence.
  """

  def __init__(self, input_document, node_wise=True, max_len_doc=3000):
    """Create a State and move it to the first sentence.

    Args:
      input_document: document dict (see create_document). It holds:
        - 'sentences' and 'token_maps': dicts keyed by consecutive sentence
          numbers starting at 0;
        - 'speakers': (token_id, speaker) pairs;
        - 'genres'.
      node_wise: also predict mentions (accept single-mention links).
      max_len_doc: max SentencePiece tokens of the model input, e.g. 2000
        or 3000 (a bit better).
    """
    self.sentence_num = -1
    self.clusters_num = 0

    self.token_map_context, self.annotation_context = [], []
    self.annotation_coreference_start, self.annotation_coreference_end = [], []
    self.token_map, self.annotation = [], []

    # a mention index to cluster mapping, e.g. (23, 24) -> [(23, 24), (41, 42)]
    self.mention_index_to_cluster = {}

    # the first link names the cluster, e.g. (23, 24) -> '1'
    self.mention_index_to_cluster_name = {}
    self.cluster_name_to_cluster = {}

    self.input_document = input_document
    # NOTE(review): with create_document's per-token 'wi' genres this takes
    # the first character, 'w' (as in the recorded input 'coref: w');
    # confirm that this is the intended genre code.
    self.genre = input_document['genres'][0][0]
    self.speakers = {t: spk for (t, spk) in self.input_document['speakers']}

    self.done = False
    self.predictions_str = {}  # keep the predictions, per sentence number
    self.node_wise = node_wise

    self.max_len_doc = max_len_doc

    # move to initial position.
    self.extend()

  def extend_done(self):
    """Return True once every sentence has been consumed."""
    return self.done

  def extend(self, prediction_str=None, use_gold_cluster=False, move=True):
    """Apply the prediction for the current focus and advance.

    Steps:
      1. Append the current focus to the context.
      2. Resolve the ';;'-separated links of `prediction_str` against the
         last model input (self.annotation_full) and merge them into the
         clusters.
      3. Add the new cluster markers to the context.

    Args:
      prediction_str: model output for the current focus, or None (initial
        call). An output containing 'None [' is treated as having no links.
      use_gold_cluster: name new clusters after the gold clusters in
        input_document['clusters'] (evaluation only).
      move: always advance to the next sentence. When False, advance only
        if no link was found.

    Returns:
      True when there is no further sentence (the state is done), else False.
    """
    # move annotation to context
    self.token_map_context += self.token_map
    self.annotation_context += self.annotation

    for _ in range(len(self.annotation)):
      self.annotation_coreference_start.append([])
      self.annotation_coreference_end.append([])

    assert len(self.annotation_context) == len(self.annotation_coreference_start)

    self.annotation, self.token_map = [], []

    link_found = False
    if prediction_str is not None and 'None [' not in prediction_str:
      links = [l for l in prediction_str.split(';;') if l != '']

      annotation_update = []
      for link in links:
        link_found = True
        link_mentions = get_mentions_for_link_state(link, self.node_wise)

        if len(link_mentions) < 2 and not (self.node_wise and len(link_mentions)):
          print('less mentions as needed skip', link_mentions)
          continue
        indices = match_link_state(link_mentions, self.annotation_full,
                                   self.annotation_full_map,
                                   self.cluster_name_to_cluster,
                                   debug=False, node_wise=self.node_wise)

        # An unmatched single mention can come back as [[]]. Treat it as
        # "not found" so that no empty cluster (and no cluster number) is
        # created for it.
        if not indices or not indices[0]:
          print('not found !!')
          print('indices not found', link, indices)
          print('self.annotation_full', self.annotation_full)
          # Print the last model input without rebuilding it. Calling
          # get_input_annotation() here would rebuild (and possibly
          # shorten) annotation_full while the remaining links are still
          # matched against it.
          print('annotation + context',
                ' '.join(self.annotation_full +
                         getattr(self, 'annotation_context_right', [])))
          continue

        index = indices[0]

        # Reuse the cluster of any span that is already clustered.
        cluster = []
        for mention_index in index:
          if str(mention_index) in self.mention_index_to_cluster:
            cluster = self.mention_index_to_cluster[str(mention_index)]
            break

        if not cluster:
          self.clusters_num += 1
          cluster_name = str(self.clusters_num)

          if use_gold_cluster:  # just to evaluate on gold
            for ni, cx in enumerate(self.input_document['clusters']):
              for mx in cx:
                if mx in index:
                  cluster_name = str(ni+1)
                  break
        else:
          cluster_name = self.mention_index_to_cluster_name[str(cluster[0])]

        for mention_index in index:
          if mention_index not in cluster:
            cluster.append(mention_index)
            self.mention_index_to_cluster[str(mention_index)] = cluster
            self.cluster_name_to_cluster['['+cluster_name] = cluster
            self.mention_index_to_cluster_name[str(mention_index)] = cluster_name
            annotation_update.append([mention_index, cluster_name])

      # Update the context markers. Each new span gets an opening entry on
      # its first token and a ']' on its last token. Opening entries are
      # kept sorted by end token id, descending, so enclosing spans open
      # first.
      for update in annotation_update:
        update_index = update[0]
        for _, coref_starts, coref_end, tid in zip(self.annotation_context,
                                                   self.annotation_coreference_start,
                                                   self.annotation_coreference_end,
                                                   self.token_map_context):
          if update_index[0] == tid:
            coref_starts.append(update)
            coref_starts.sort(key=lambda x: x[0][1], reverse=True)

          if update_index[1] == tid:
            coref_end.append(']')

    # link_found is False whenever prediction_str is None or contains
    # 'None [', so this covers both cases (and never evaluates
    # `'None [' in None` when move is False).
    if move or not link_found:
      self.sentence_num += 1

      if self.sentence_num not in self.input_document['sentences']:
        self.done = True
        return True

      tokens = self.input_document['sentences'][self.sentence_num]
      token_map = self.input_document['token_maps'][self.sentence_num]

      for position, (tid, t) in enumerate(zip(token_map, tokens)):
        if position == 0:
          # The sentence is prefixed with its speaker token (token id -1).
          self.token_map.append(-1)
          speaker = normalize_speaker(self.speakers[tid])
          self.annotation.append(speaker)
        self.token_map.append(tid)
        self.annotation.append(t)

    if self.sentence_num not in self.predictions_str:
      self.predictions_str[self.sentence_num] = ''

    if prediction_str is not None:
      self.predictions_str[self.sentence_num] += prediction_str

    return False

  def input_annotation(self):
    """Rebuild self.annotation_full and self.annotation_full_map.

    annotation_full is ['coref:', genre], then the context with '[<name>'
    before and ']' after each clustered span, then '|', the focus and '**'.
    annotation_full_map holds the token id of each entry: -1 for markers,
    speaker tokens and the other non-document entries.
    """
    self.annotation_full = ['coref:', self.genre]
    self.annotation_full_map = [-1, -1]
    for t, coref_starts, coref_end, tid in zip(self.annotation_context,
                                               self.annotation_coreference_start,
                                               self.annotation_coreference_end,
                                               self.token_map_context):

      for coref_start in coref_starts:
        coref_name = coref_start[-1]  # entries are [span, cluster_name]

        self.annotation_full.append('[' + coref_name)
        self.annotation_full_map.append(-1)

      self.annotation_full.append(t)
      self.annotation_full_map.append(tid)

      for end in coref_end:
        coref_name = end[-1]  # ']'
        self.annotation_full.append(coref_name)
        self.annotation_full_map.append(-1)

    self.annotation_full += ['|'] + self.annotation
    self.annotation_full_map += [-1] + self.token_map
    self.annotation_full += ['**']
    self.annotation_full_map += [-1]

  def encode(self, annotation_str):
    """Tokenize with the module-level mT5 tokenizer (used to measure length)."""
    return tokenizer_mt5.encode(annotation_str)

  def get_input_annotation(self, context_right=True):
    """Build the model input string for the current focus.

    Length handling:
      - If the input is longer than max_len_doc tokens, context tokens are
        dropped from the left until it fits or the context is empty. This
        permanently shortens the stored context.
      - If it fits without shortening and `context_right` is set, the
        following sentences (unannotated, with speaker tokens) are appended
        while the total still fits.

    Args:
      context_right: append right context when there is room.

    Returns:
      The model input string.
    """
    self.input_annotation()
    annotation_str = ' '.join(self.annotation_full)

    enc = self.encode(annotation_str)
    shorten = len(enc) > self.max_len_doc

    # Stop once the context is exhausted. Otherwise a focus that alone
    # exceeds max_len_doc would loop forever, since dropping from an empty
    # context changes nothing.
    while len(enc) > self.max_len_doc and self.annotation_context:  # inefficient ...
      self.annotation_context = self.annotation_context[1:]
      self.token_map_context = self.token_map_context[1:]
      self.annotation_coreference_start = self.annotation_coreference_start[1:]
      self.annotation_coreference_end = self.annotation_coreference_end[1:]

      self.input_annotation()
      annotation_str = ' '.join(self.annotation_full)
      enc = self.encode(annotation_str)

    self.annotation_context_right = []

    if not shorten and context_right:
      sentence_num = self.sentence_num
      total_len = len(enc)

      while True:
        sentence_num += 1
        if sentence_num not in self.input_document['sentences']:
          break

        first = True
        annotation_context_next = []

        for t, tid in zip(self.input_document['sentences'][sentence_num],
                          self.input_document['token_maps'][sentence_num]):
          if first:
            speaker = normalize_speaker(self.speakers[tid])
            annotation_context_next.append(speaker)
            first = False
          annotation_context_next.append(t)

        annotation_context_right = self.annotation_context_right + annotation_context_next
        enc = self.encode(' '.join(annotation_context_right))

        if (len(enc) + total_len) > self.max_len_doc:
          break
        self.annotation_context_right = annotation_context_right
      if self.annotation_context_right:
        annotation_str = annotation_str + ' ' + ' '.join(self.annotation_context_right)

    enc = self.encode(annotation_str)
    if len(enc) > self.max_len_doc:
      print('warning: document too long', len(enc))

    return annotation_str


In [None]:
# @title function def to create an input state
# Word-level tokenizer used by create_document to split each line into tokens.
tokenizer_nltk = nltk.WordPunctTokenizer()


def create_document(document: str, title: str = 'not_named'):
  """Build the input-document dict used by State, tokenizing with nltk.

  Every line of `document` becomes one sentence. Tokens get consecutive
  ids across the whole document. Each token is assigned the speaker '_'
  and the genre 'wi'.

  Args:
    document: sentences separated with newline ('\n').
    title: the name of the document (stored as 'doc_key').

  Returns:
    dict with 'doc_key', 'sentences' and 'token_maps' (both keyed by line
    number), 'speakers' ((token_id, speaker) pairs) and 'genres'.
  """
  sentences, token_maps = {}, {}
  speakers, genres = [], []
  next_tid = 0

  for line_no, line in enumerate(document.split('\n')):
    words = tokenizer_nltk.tokenize(text=line)
    ids = list(range(next_tid, next_tid + len(words)))
    next_tid += len(words)

    sentences[line_no] = words
    token_maps[line_no] = ids
    speakers.extend((tid, '_') for tid in ids)
    genres.extend(['wi'] * len(ids))

  return {
      'doc_key': title,
      'sentences': sentences,
      'token_maps': token_maps,
      'speakers': speakers,
      'genres': genres,
  }

In [None]:
# @title function def to create batches
def create_next_batch(states_dict, batche_size=1, num_batches=1):
  """Collect the next model inputs from the unfinished states.

  Args:
    states_dict: dict of State objects; finished states are skipped.
    batche_size: number of inputs per batch (values below 1 are treated as
      1 when splitting into batches).
    num_batches: maximum number of batches.

  Returns:
    (states, batches):
      - states: up to batche_size * num_batches unfinished states (at least
        one, if any is left).
      - batches: their input strings, split in the same order into
        consecutive batches of at most batche_size. It never contains a
        trailing empty batch; it is [[]] when no state is left.
  """
  capacity = batche_size * num_batches
  states = []
  for state in states_dict.values():
    if state.extend_done():
      continue

    states.append(state)
    if len(states) >= capacity:
      break

  if not states:
    return states, [[]]

  size = max(batche_size, 1)
  annotations = [state.get_input_annotation() for state in states]
  batches = [annotations[i:i + size] for i in range(0, len(annotations), size)]
  return states, batches


In [None]:
# @title sample document - from wikipedia
# One sentence per line: create_document splits the text on '\n'.
d1_title = "Eiffel Tower Wiki"
d1 = """The Eiffel Tower (French: tour Eiffel) is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.
Locally nicknamed "La dame de fer" (French for "Iron Lady"), it was constructed from 1887 to 1889 as the centerpiece of the 1889 World's Fair.
Although initially criticised by some of France's leading artists and intellectuals for its design, it has since become a global cultural icon of France and one of the most recognisable structures in the world.
The Eiffel Tower is the most visited monument with an entrance fee in the world: 6.91 million people ascended it in 2015.
It was designated a monument historique in 1964, and was named part of a UNESCO World Heritage Site ("Paris, Banks of the Seine") in 1991."""

# Tokenize the sample and create its decoding state, keyed by the title.
input_document = create_document(d1, d1_title)
states_dict = {d1_title: State(input_document)}

In [None]:
# When True, the decoding loop replays the recorded model outputs below
# instead of calling the prediction server. Entry k holds the outputs for
# decoding step k.
# Each output is a ';;'-separated list of links of the form
# 'mention ## right context -> antecedent ## right context'. An antecedent
# such as '[1' refers to an already existing cluster.
emulate_predictions = True # @param

if emulate_predictions:
  predictioned_results = [
      [
          'It ## is named after -> The Eiffel Tower ( French : tour Eiffel ) ##'
          ' is a wrought ;; the tower ## . ** _ -> It ## is named after ;;'
      ],
      ['it ## was constructed from -> [1 ;;'],
      [
          'its ## design , it -> [1 ;; it ## has since become -> its ## design'
          " , it ;; France ' s ## leading artists and -> France ## . [1 It ;;"
          " France ## and one of -> France ' s ## leading artists and ;;"
      ],
      [
          'The Eiffel Tower ## is the most -> [1 ;; it ## in 2015 . -> The'
          ' Eiffel Tower ## is the most ;; the world ## : 6 . -> the world ## .'
          ' | _ ;;'
      ],
      [
          'It ## was designated a -> [1 ;; Paris ## , Banks of -> Paris , [2'
          ' France ## ] . [1 ;;'
      ],
  ]
else:
  predictioned_results = []

In [None]:
# Decoding loop. Each step:
#   1. builds the next batch of inputs from the unfinished states;
#   2. obtains the model outputs (recorded or from the server);
#   3. feeds each output back into its state via State.extend.
expand_only = False
total_time = time.time()
total_results = 0

debug = True

for step in range(100):  # while states
  t = time.time()
  states, batches = create_next_batch(states_dict)

  if not states:
    break

  documents_processing = {x.input_document['doc_key'] for x in states}

  print(f'Processing documents: {documents_processing}')

  if predictioned_results:
    if step >= len(predictioned_results):
      print('no more emulated predictions - stopping')
      break
    results = predictioned_results[step]
  else:
    predictions = predict_coreferences(batches, len(batches))
    results = extract_result_string(predictions)

  # `batches` is a list of batches (lists of input strings). Flatten it so
  # each state is paired with its own input, not with a whole batch.
  model_inputs = [x for batch in batches for x in batch]
  for state, result, model_input in zip(states, results, model_inputs):
    state.extend(result)

    if debug:
      print('model input:    ', model_input)
      print('mt5 output:     ', result)

  total_results += len(results)
  elapsed = time.time() - t
  # max(..., 1) avoids ZeroDivisionError when a step yields no results.
  print(
      f'time {elapsed}, round time/seq : {elapsed / max(len(results), 1)}'
      f' total time/seq: {(time.time() - total_time) / max(total_results, 1)}'
  )
  print()
  print()

Processing documents: {'Eiffel Tower Wiki'}
input batch[0]:  ['coref: w | _ The Eiffel Tower ( French : tour Eiffel ) is a wrought - iron lattice tower on the Champ de Mars in Paris , France . It is named after the engineer Gustave Eiffel , whose company designed and built the tower . ** _ Locally nicknamed " La dame de fer " ( French for " Iron Lady "), it was constructed from 1887 to 1889 as the centerpiece of the 1889 World \' s Fair . _ Although initially criticised by some of France \' s leading artists and intellectuals for its design , it has since become a global cultural icon of France and one of the most recognisable structures in the world . _ The Eiffel Tower is the most visited monument with an entrance fee in the world : 6 . 91 million people ascended it in 2015 . _ It was designated a monument historique in 1964 , and was named part of a UNESCO World Heritage Site (" Paris , Banks of the Seine ") in 1991 .']
mt5 output:      ['It ## is named after -> The Eiffel Tower ( F

In [None]:
# @title get and print the output as annotated document

# For every document: print the predicted clusters, then the document text
# with each predicted mention wrapped as '[<cluster name> ... ]'.
for doc_name, s in states_dict.items():
  pred_clusters = list(s.cluster_name_to_cluster.values())
  print('predicted clusters with word indexes', pred_clusters)

  doc = states_dict[doc_name].input_document
  text, text_map = [], []
  for k, snt in doc['sentences'].items():
    text.extend(snt)
    text_map.extend(doc['token_maps'][k])

  # Index every mention once by its first token id. Also count how many
  # mentions end at each token id.
  mentions_by_start, ends_by_tid = {}, {}
  for ci in pred_clusters:
    for m in ci:
      mentions_by_start.setdefault(m[0], []).append(m)
      ends_by_tid[m[1]] = ends_by_tid.get(m[1], 0) + 1

  # For each token: open the mentions starting there (longest first), write
  # the token, then close the mentions ending there.
  all_text = []
  for tok, tid in zip(text, text_map):
    opening = [(s.mention_index_to_cluster_name[str(m)], m[1] - m[0])
               for m in mentions_by_start.get(tid, [])]
    for name, _ in sorted(opening, key=lambda x: x[1], reverse=True):
      all_text.append('[' + str(name))

    all_text.append(tok)

    if ends_by_tid.get(tid):
      all_text.append(']' * ends_by_tid[tid])

  print(' '.join(all_text))

predicted clusters with word indexes [[(0, 8), (26, 26), (40, 41), (58, 58), (90, 90), (93, 93), (114, 116), (136, 136), (140, 140)], [(24, 24), (82, 84), (102, 102)], [(111, 112), (127, 128)], [(22, 24), (160, 160)]]
[1 The Eiffel Tower ( French : tour Eiffel ) ] is a wrought - iron lattice tower on the Champ de Mars in [4 Paris , [2 France ]] . [1 It ] is named after the engineer Gustave Eiffel , whose company designed and built [1 the tower ] . Locally nicknamed " La dame de fer " ( French for " Iron Lady "), [1 it ] was constructed from 1887 to 1889 as the centerpiece of the 1889 World ' s Fair . Although initially criticised by some of [2 France ' s ] leading artists and intellectuals for [1 its ] design , [1 it ] has since become a global cultural icon of [2 France ] and one of the most recognisable structures in [3 the world ] . [1 The Eiffel Tower ] is the most visited monument with an entrance fee in [3 the world ] : 6 . 91 million people ascended [1 it ] in 2015 . [1 It ] was