# In [None]:
import tensorflow.compat.v2 as tf
import multilingual_t5.preprocessors
import multilingual_t5.tasks
import multilingual_t5.utils
from __future__ import print_function
import collections
import re
import string
import sys
import unicodedata
from t5.evaluation import qa_utils
from absl.testing import absltest
from multilingual_t5.evaluation import metrics
from t5.evaluation import test_utils
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union      
from flax import linen as nn
from flax import optim
from flax.core import scope as flax_scope
from flax.training import common_utils
import jax
from jax import lax
import jax.numpy as jnp
import ml_collections
import numpy as np
import seqio
from t5x import losses as t5x_losses
from t5x import metrics as metrics_lib
from t5x import models as t5x_models
from t5x import utils as t5x_utils
from t5x_retrieval import utils
import tensorflow as tf
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union
from flask import Flask
from flask import request
from flask import jsonify
from flask import request
import requests
from googletrans import Translator
import string


from transformers import pipeline



def _string_join(lst):
  """Joins string tensors with single spaces and collapses whitespace runs."""
  joined = tf.strings.join(lst, separator=' ')
  # Empty pieces would otherwise leave consecutive separators behind.
  collapsed = tf.strings.regex_replace(joined, r'\s+', ' ')
  return collapsed
def _pad_punctuation(text):
  """Surrounds punctuation with spaces, then collapses whitespace runs."""
  # Pad every character that is not an underscore, whitespace, a number (\p{N}),
  # a letter (\p{L}) or a combining mark (\p{M}).
  padded = tf.strings.regex_replace(
      text, r'([^_\s\p{N}\p{L}\p{M}])', r' \1 ')
  return tf.strings.regex_replace(padded, r'\s+', ' ')

def xnli_map_hypothesis_premise(dataset, target_language):
  """Prepares XNLI examples: one row per hypothesis, kept for one language.

  Each example's `hypothesis` entry supplies parallel `language` and
  `translation` values. The example's `label` and its `premise` in
  `target_language` are repeated to the same shape, the result is unbatched
  into individual rows, and only rows whose `language` equals
  `target_language` are kept.

  Args:
    dataset: a tf.data.Dataset of XNLI examples.
    target_language: language code used both to pick the premise and to filter
      the hypotheses.

  Returns:
    A tf.data.Dataset with `language`, `translation`, `label` and `premise`.
  """

  def _expand(example):
    hypothesis = example['hypothesis']
    # Label and premise are broadcast to line up with the hypotheses.
    row_shape = tf.shape(hypothesis['language'])
    return {
        'language': hypothesis['language'],
        'translation': hypothesis['translation'],
        'label': tf.fill(row_shape, example['label']),
        'premise': tf.fill(row_shape, example['premise'][target_language]),
    }

  def _is_target_language(row):
    return tf.math.equal(row['language'], target_language)

  expanded = dataset.map(
      _expand, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return expanded.unbatch().filter(_is_target_language)

def wikiann(dataset):
  """Turns WikiAnn examples into text-to-text tagging examples.

  `inputs` is the space-joined token sequence prefixed with 'tag: ' and
  `targets` is the span strings joined with ' $$ '. The original `tokens`,
  `tags`, `langs` and `spans` features are passed through unchanged.

  Args:
    dataset: a tf.data.Dataset with `tokens`, `tags`, `langs` and `spans`.

  Returns:
    The mapped tf.data.Dataset.
  """

  def _to_text_to_text(example, delimiter=' $$ '):
    sentence = tf.strings.reduce_join(example['tokens'], separator=' ')
    joined_spans = tf.strings.reduce_join(
        example['spans'], separator=delimiter)
    features = {'inputs': 'tag: ' + sentence, 'targets': joined_spans}
    for passthrough_key in ('tokens', 'tags', 'langs', 'spans'):
      features[passthrough_key] = example[passthrough_key]
    return features

  return dataset.map(
      _to_text_to_text, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def process_mnli(dataset):
  """Formats MNLI examples as 'xnli: premise: ... hypothesis: ...' pairs.

  Args:
    dataset: a tf.data.Dataset with `premise`, `hypothesis` and `label`.

  Returns:
    A tf.data.Dataset with string `inputs` and `targets` (the label rendered
    as a string).
  """

  def _format(example):
    prompt_pieces = [
        'xnli: premise: ', example['premise'],
        ' hypothesis: ', example['hypothesis'],
    ]
    return {
        'inputs': tf.strings.join(prompt_pieces),
        'targets': tf.strings.as_string(example['label']),
    }

  return dataset.map(_format, num_parallel_calls=tf.data.experimental.AUTOTUNE)


def process_xnli(dataset, target_languages):
  """Builds XNLI text-to-text examples for each requested language.

  For every language the dataset is flattened and filtered with
  `xnli_map_hypothesis_premise`, formatted as
  'xnli: premise: ... hypothesis: ...', and the per-language datasets are
  concatenated in the order given.

  Args:
    dataset: a tf.data.Dataset of XNLI examples.
    target_languages: non-empty sequence of language codes.

  Returns:
    The concatenated tf.data.Dataset with `inputs` and `targets`.
  """

  def _format(example):
    prompt_pieces = [
        'xnli: premise: ', example['premise'],
        ' hypothesis: ', example['translation'],
    ]
    return {
        'inputs': tf.strings.join(prompt_pieces),
        'targets': tf.strings.as_string(example['label']),
    }

  per_language = [
      xnli_map_hypothesis_premise(dataset, target_language=language).map(
          _format, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      for language in target_languages
  ]

  combined = per_language[0]
  for extra in per_language[1:]:
    combined = combined.concatenate(extra)
  return combined

# In [None]:

def normalize_mlqa(s, lang, punct):
  """Normalizes an MLQA answer string before exact-match / F1 scoring.

  The text is lower-cased, stripped of every character found in `punct`,
  stripped of language-specific articles, and finally re-tokenized and joined
  with single spaces.

  Args:
    s: the answer text to normalize.
    lang: language code; one of 'en', 'es', 'hi', 'vi', 'de', 'ar', 'zh'.
    punct: container of punctuation characters to drop (only `in` is used).

  Returns:
    The normalized string.

  Raises:
    ValueError: if `lang` is not one of the supported language codes.
  """
  whitespace_langs = ('en', 'es', 'hi', 'vi', 'de', 'ar')
  mixed_segmentation_langs = ('zh',)

  # Fail early with a clear message; otherwise an unknown language would
  # surface later as an unrelated-looking error.
  if lang not in whitespace_langs and lang not in mixed_segmentation_langs:
    raise ValueError(
        'Unsupported MLQA language %r; expected one of %s.' %
        (lang, sorted(whitespace_langs + mixed_segmentation_langs)))

  # Languages without an entry ('hi', 'zh') have no articles to drop.
  article_patterns = {
      'en': r'\b(a|an|the)\b',
      'es': r'\b(un|una|unos|unas|el|la|los|las)\b',
      'vi': r'\b(của|là|cái|chiếc|những)\b',
      'de': r'\b(ein|eine|einen|einem|eines|einer|der|die|das|den|dem|des)\b',
      # Raw string avoids the invalid "\s" escape; the pattern is unchanged.
      'ar': r'\sال^|ال',
  }

  def whitespace_tokenize(text):
    return text.split()

  def mixed_segmentation(text):
    # CJK ideographs and punctuation become single-character tokens; all other
    # runs of characters are whitespace-tokenized.
    segs_out = []
    pending = ''
    for char in text:
      if re.search(r'[\u4e00-\u9fa5]', char) or char in punct:
        if pending:
          segs_out.extend(whitespace_tokenize(pending))
          pending = ''
        segs_out.append(char)
      else:
        pending += char
    if pending:
      segs_out.extend(whitespace_tokenize(pending))
    return segs_out

  def drop_articles(text):
    pattern = article_patterns.get(lang)
    return text if pattern is None else re.sub(pattern, ' ', text)

  def white_space_fix(text):
    if lang in whitespace_langs:
      tokens = whitespace_tokenize(text)
    else:
      tokens = mixed_segmentation(text)
    return ' '.join(t for t in tokens if t.strip())

  def drop_punc(text):
    return ''.join(c for c in text if c not in punct)

  return white_space_fix(drop_articles(drop_punc(s.lower())))


# Memo for the punctuation set: scanning every Unicode code point is expensive
# and the result never changes, so it is computed at most once per process.
_MLQA_PUNCT_CACHE = []


def _mlqa_punctuation():
  """Returns the set of Unicode 'P*' characters plus `string.punctuation`."""
  if not _MLQA_PUNCT_CACHE:
    unicode_punct = {
        chr(i)
        # +1 so that the last code point is inspected as well.
        for i in range(sys.maxunicode + 1)
        if unicodedata.category(chr(i)).startswith('P')
    }
    _MLQA_PUNCT_CACHE.append(frozenset(unicode_punct.union(string.punctuation)))
  return _MLQA_PUNCT_CACHE[0]


def mlqa(targets, predictions, lang=None):
  """Computes MLQA metrics after language-aware normalization.

  Args:
    targets: list of lists of reference answer strings (one list per example).
    predictions: list of predicted answer strings.
    lang: language code understood by `normalize_mlqa`; required.

  Returns:
    Whatever `qa_utils.qa_metrics` returns for the normalized inputs.

  Raises:
    AssertionError: if `lang` is None.
  """
  # Kept as an assert so callers see the same exception type as before.
  assert lang is not None
  punct = _mlqa_punctuation()
  targets = [[normalize_mlqa(t, lang, punct) for t in u] for u in targets]
  predictions = [normalize_mlqa(p, lang, punct) for p in predictions]
  return qa_utils.qa_metrics(targets, predictions)


def span_f1(targets, predictions):
  """Computes span-level F1 over 'tag: entity' sequences.

  Each target/prediction is a string of 'tag: entity' items separated by
  ' $$ '. Items that do not split into exactly two parts on ':' are ignored.
  A predicted span counts as a true positive when an identical (tag, entity)
  pair is still unmatched in the gold spans of the same example.

  Args:
    targets: iterable of gold span strings.
    predictions: iterable of predicted span strings, aligned with `targets`.

  Returns:
    A dict with the single key 'span_f1'.
  """

  def _parse_spans(sequence, delimiter=' $$ '):
    spans = []
    for item in sequence.split(delimiter):
      pieces = item.strip().split(':')
      if len(pieces) == 2:
        spans.append((pieces[0].strip(), pieces[1].strip()))
    return spans

  # Counts are tracked per tag and summed at the end.
  hits = collections.Counter()
  spurious = collections.Counter()
  missed = collections.Counter()

  for target, prediction in zip(targets, predictions):
    unmatched_gold = _parse_spans(target)
    for span in _parse_spans(prediction):
      if span in unmatched_gold:
        hits[span[0]] += 1
        # Each gold span may be matched only once.
        unmatched_gold.remove(span)
      else:
        spurious[span[0]] += 1
    # Gold spans that no prediction matched.
    for tag, _ in unmatched_gold:
      missed[tag] += 1

  tp = sum(hits.values())
  fp = sum(spurious.values())
  fn = sum(missed.values())

  # The small epsilon keeps the divisions defined when a count is zero.
  precision = float(tp) / float(tp + fp + 1e-13)
  recall = float(tp) / float(tp + fn + 1e-13)
  f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13))

  return {'span_f1': f1_measure}

# In [None]:

class MetricsTest(test_utils.BaseMetricsTest):
  """Tests for `metrics.mlqa`."""

  def test_same_mlqa(self):
    """A prediction equal to one of the references scores 100 / 100."""
    text = "this is a string"
    expected = {
        "em": 100,
        "f1": 100,
    }
    actual = metrics.mlqa([["", text], [text, text]], [text, text], lang="en")
    self.assertDictClose(actual, expected)

  def test_different_mlqa(self):
    """Empty predictions score 0 / 0 against non-empty references."""
    text = "this is a string"
    expected = {
        "em": 0,
        "f1": 0
    }
    actual = metrics.mlqa([[text, text], [text, text]], ["", ""], lang="en")
    self.assertDictClose(actual, expected)

  def test_article_drop_mlqa(self):
    """The Spanish article 'unas' is ignored when comparing answers."""
    reference = "this unas a string"
    prediction = "this a string"
    expected = {
        "em": 100,
        "f1": 100,
    }
    actual = metrics.mlqa([[reference]], [prediction], lang="es")
    self.assertDictClose(actual, expected)

  def test_mlqa_small(self):
    """A partial token overlap yields fractional F1 and zero exact match."""
    expected = {"f1": 100 * 2.0 / 3.0, "em": 0.}
    actual = metrics.mlqa([["abc abd", "$$$$"]], ["abd"], lang="en")
    self.assertDictClose(actual, expected)
# Script entry point for the metrics tests above.
# NOTE(review): this guard sits in the middle of the file; when the file is run
# as a script, `absltest.main()` is invoked before any of the code below has
# been executed — confirm that this is intended.
if __name__ == "__main__":
  absltest.main()

  # NOTE(review): because of its indentation the string below is an inert
  # expression statement inside the `if` body, not a module docstring.
  """Add Tasks to registry."""
import functools

import seqio
import t5.data

# Directory that holds every TSV split referenced below.
_DATASET_DIR = "/home/mac/datasets"


def _split_paths(name, train_suffix=""):
  """Returns the {"train", "validation"} TSV paths for dataset `name`.

  `train_suffix` is appended to the training file stem only; the validation
  file is always `dev_<name>.tsv`.
  """
  return {
      "train": f"{_DATASET_DIR}/train_{name}{train_suffix}.tsv",
      "validation": f"{_DATASET_DIR}/dev_{name}.tsv",
  }


# Per-language (inputs, targets) TSV files.
tsv_english_path = _split_paths("english")
tsv_german_path = _split_paths("german")
tsv_arabic_path = _split_paths("arabic")
tsv_chinese_path = _split_paths("chinese")
tsv_dutch_path = _split_paths("dutch")
tsv_french_path = _split_paths("french")
tsv_hindi_path = _split_paths("hindi")
tsv_indonesian_path = _split_paths("indonesian")
tsv_japanese_path = _split_paths("japanese")
tsv_portuguese_path = _split_paths("portuguese")
tsv_russian_path = _split_paths("russian")

# Key order matters: it is the order in which tasks are registered.
language_to_path = {
    "arabic": tsv_arabic_path,
    "english": tsv_english_path,
    "russian": tsv_russian_path,
    "portuguese": tsv_portuguese_path,
    "japanese": tsv_japanese_path,
    "hindi": tsv_hindi_path,
    "indonesian": tsv_indonesian_path,
    "dutch": tsv_dutch_path,
    "german": tsv_german_path,
    "chinese": tsv_chinese_path,
    "french": tsv_french_path,
}

# Per-language files whose training split carries the "_negatives" suffix; the
# validation split is the same file as in the tables above.
_NEGATIVES = "_negatives"
tsv_english_negatives_path = _split_paths("english", _NEGATIVES)
tsv_german_negatives_path = _split_paths("german", _NEGATIVES)
tsv_arabic_negatives_path = _split_paths("arabic", _NEGATIVES)
tsv_chinese_negatives_path = _split_paths("chinese", _NEGATIVES)
tsv_dutch_negatives_path = _split_paths("dutch", _NEGATIVES)
tsv_french_negatives_path = _split_paths("french", _NEGATIVES)
tsv_hindi_negatives_path = _split_paths("hindi", _NEGATIVES)
tsv_indonesian_negatives_path = _split_paths("indonesian", _NEGATIVES)
tsv_japanese_negatives_path = _split_paths("japanese", _NEGATIVES)
tsv_portuguese_negatives_path = _split_paths("portuguese", _NEGATIVES)
tsv_russian_negatives_path = _split_paths("russian", _NEGATIVES)
tsv_spanish_negatives_path = _split_paths("spanish", _NEGATIVES)

# NOTE(review): "spanish" appears here but not in `language_to_path`.
language_negatives_to_path = {
    "arabic": tsv_arabic_negatives_path,
    "english": tsv_english_negatives_path,
    "russian": tsv_russian_negatives_path,
    "portuguese": tsv_portuguese_negatives_path,
    "japanese": tsv_japanese_negatives_path,
    "hindi": tsv_hindi_negatives_path,
    "indonesian": tsv_indonesian_negatives_path,
    "dutch": tsv_dutch_negatives_path,
    "german": tsv_german_negatives_path,
    "chinese": tsv_chinese_negatives_path,
    "french": tsv_french_negatives_path,
    "spanish": tsv_spanish_negatives_path,
}

tsv_clirmatrix_multi_path = _split_paths("clirmatrix")



# SentencePiece model file used to build the multilingual vocabulary.
# NOTE(review): both constant names spell "MULTILIGUAL" (missing "N"); they are
# left untouched because other modules may import them under this spelling.
MULTILIGUAL_SPM_PATH = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model"  # GCS
# NOTE(review): not referenced anywhere in the visible code — confirm whether
# it is still needed.
MULTILIGUAL_EXTRA_IDS = 100


def get_multilingual_vocabulary():
  """Builds a SentencePiece vocabulary from `MULTILIGUAL_SPM_PATH`."""
  vocabulary = seqio.SentencePieceVocabulary(MULTILIGUAL_SPM_PATH)
  return vocabulary


# Feature spec built on the default T5 vocabulary.
DEFAULT_VOCAB = t5.data.get_default_vocabulary()
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=seqio.Feature(
        vocabulary=DEFAULT_VOCAB, add_eos=True, required=False),
    targets=seqio.Feature(vocabulary=DEFAULT_VOCAB, add_eos=True),
)

# Feature spec built on the multilingual SentencePiece vocabulary.
MULTILINGUAL_VOCAB = get_multilingual_vocabulary()
MULTILINGUAL_OUTPUT_FEATURES = dict(
    inputs=seqio.Feature(
        vocabulary=MULTILINGUAL_VOCAB, add_eos=True, required=False),
    targets=seqio.Feature(vocabulary=MULTILINGUAL_VOCAB, add_eos=True),
)

# Same as above plus a `negative_targets` feature.
MULTILINGUAL_OUTPUT_FEATURES_NEGATIVES = dict(
    inputs=seqio.Feature(
        vocabulary=MULTILINGUAL_VOCAB, add_eos=True, required=False),
    targets=seqio.Feature(vocabulary=MULTILINGUAL_VOCAB, add_eos=True),
    negative_targets=seqio.Feature(
        vocabulary=MULTILINGUAL_VOCAB, add_eos=True),
)

# Retrieval task backed by the `beir/msmarco:1.0.0` TFDS dataset: the `query`
# field is rekeyed to `inputs` and the `passage` field to `targets`.
seqio.TaskRegistry.add(
    "beir_msmarco_retrieval",
    source=seqio.TfdsDataSource(
        tfds_name="beir/msmarco:1.0.0",
        splits={name: name for name in ("train", "validation")},
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.rekey,
            key_map={"inputs": "query", "targets": "passage"},
        ),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

# One task per language in `language_to_path`; each line of the TSV file is
# parsed into `inputs` and `targets`.
for language in language_to_path:
    seqio.TaskRegistry.add(
        f"mmarco_retrieval_{language}",
        source=seqio.TextLineDataSource(
            split_to_filepattern=language_to_path[language],
        ),
        preprocessors=[
            functools.partial(
                t5.data.preprocessors.parse_tsv,
                field_names=["inputs", "targets"],
            ),
            seqio.preprocessors.tokenize,
            seqio.CacheDatasetPlaceholder(),
            seqio.preprocessors.append_eos_after_trim,
        ],
        metric_fns=[],
        output_features=MULTILINGUAL_OUTPUT_FEATURES,
    )

# Mixture of all per-language tasks, each with rate 1.
seqio.MixtureRegistry.add(
    "multilingual_marco_mixture",
    [(f"mmarco_retrieval_{language}", 1) for language in language_to_path],
)

# One "negatives" task per language; each TSV line is parsed into `inputs`,
# `targets` and `negative_targets`.
# NOTE(review): the loop iterates over `language_to_path`, so the "spanish"
# entry of `language_negatives_to_path` is never registered — confirm intent.
for language in language_to_path:
    seqio.TaskRegistry.add(
        f"mmarco_retrieval_{language}_negatives",
        source=seqio.TextLineDataSource(
            split_to_filepattern=language_negatives_to_path[language],
        ),
        preprocessors=[
            functools.partial(
                t5.data.preprocessors.parse_tsv,
                field_names=["inputs", "targets", "negative_targets"],
            ),
            seqio.preprocessors.tokenize,
            seqio.CacheDatasetPlaceholder(),
            seqio.preprocessors.append_eos_after_trim,
        ],
        metric_fns=[],
        output_features=MULTILINGUAL_OUTPUT_FEATURES_NEGATIVES,
    )

# Mixture of all per-language negatives tasks, each with rate 1.
seqio.MixtureRegistry.add(
    "multilingual_marco_mixture_negatives",
    [(f"mmarco_retrieval_{language}_negatives", 1)
     for language in language_to_path],
)
# Task reading the CLIRMatrix TSV files; each line is parsed into `inputs` and
# `targets`.
seqio.TaskRegistry.add(
    "clirmatrix_pretraining",
    source=seqio.TextLineDataSource(
        split_to_filepattern=tsv_clirmatrix_multi_path,
    ),
    preprocessors=[
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["inputs", "targets"],
        ),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[],
    output_features=MULTILINGUAL_OUTPUT_FEATURES,
)

# Separate tasks for the "query" and "passage" splits of `beir/msmarco`: the
# text field becomes `inputs` and the matching `<split>_id` field `targets`.
for split in ("query", "passage"):
  seqio.TaskRegistry.add(
      f"beir_msmarco_retrieval_{split}",
      source=seqio.TfdsDataSource(
          tfds_name="beir/msmarco:1.0.0",
          splits={split: split},
      ),
      preprocessors=[
          functools.partial(
              t5.data.preprocessors.rekey,
              key_map={"inputs": split, "targets": f"{split}_id"},
          ),
          seqio.preprocessors.tokenize,
          seqio.CacheDatasetPlaceholder(),
          seqio.preprocessors.append_eos_after_trim,
      ],
      metric_fns=[],
      output_features=DEFAULT_OUTPUT_FEATURES,
  )
# Flask application object; the `@app.route` handlers below register on it.
app = Flask(__name__)


@app.route('/')
def string_return():
    """Serves a fixed greeting string at the site root."""
    return "ML 서버 열심히 만들자"




@app.route('/sick')
def breadpage():
    """Serves a fixed string at /sick."""
    bread_name = '식빵맨'
    return bread_name


@app.route("/curry")
def curry():
    """Serves a static HTML page with a heading and one image."""
    # The <img> tag was previously left unclosed, which made the markup
    # malformed; it is now terminated with ">".
    return '''<!DOCTYPE HTML><html>
  <head>
    <title>카레빵맨 이미지 받아랑!!!</title>
  </head>
  <body>
    <h1>Curry bread man</h1>
    <img src = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRrfTgmkoNaKpFBUZ_DmdzbkGc66dsch5CWYg&usqp=CAU">
  </body>
</html>'''






@app.route('/bread', methods=['GET','POST'])
def test_get():
    """Echoes the `breadname` query parameter inside a fixed sentence.

    The value comes straight from the request and the response is served as
    HTML, so it is escaped before being interpolated.
    """
    import html  # Local import: `html` is not imported at the top of the file.

    bread_receive = request.args.get('breadname')
    # str() keeps the previous rendering ("None") when the parameter is absent.
    return f'이것의 이름은 {html.escape(str(bread_receive))}'


# Earlier prototype of the /turk endpoint, kept as an inert string.
# The API key that used to be embedded here has been redacted.
'''
import openai


openai.api_key = "<redacted>"


@app.route('/turk', methods=['GET','POST'])


def test_conv():
    talk = request.args.get("humantalk")
    #prompt = f'{message}'
    prompt = "bana türkçe cevap ver" + talk
    response = openai.Completion.create(
        engine="gpt-3.5-turbo",
        prompt=prompt,
        temperature=0.5,
        max_tokens=35,
        n=1,
        stop=None,
        timeout=5,
        frequency_penalty=0,
        presence_penalty=0
    )
    return f"ai's response is : {response.choices[0].text.strip()}"
'''


import os

import openai

# The key is read from the OPENAI_API_KEY environment variable so that no
# credential is stored in the source file.
# NOTE(review): the key that was previously committed here should be treated as
# compromised and rotated.
openai.api_key = os.environ.get("OPENAI_API_KEY")
@functools.lru_cache(maxsize=1)
def _turkish_text_generator():
    """Builds the Turkish GPT-2 text-generation pipeline once and reuses it."""
    return pipeline('text-generation', model="redrussianarmy/gpt2-turkish-cased",
                    tokenizer="redrussianarmy/gpt2-turkish-cased")


@app.route('/turk', methods=['GET','POST'])
def test_conv():
    """Runs two Turkish GPT-2 prompts on the posted text and translates them.

    Expects a POST with a JSON object containing `content`. Returns a JSON
    object with the first generation (`Aanswer`), its Korean translation
    (`Ttrans_answer`), the second generation (`Iisrightvalue`) and its Korean
    translation (`transIS`).

    Non-POST requests get a short plain-text reply, and a body without
    `content` gets a 400, instead of the previous unhandled errors.
    """
    if request.method != 'POST':
        return "get"

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'content' not in payload:
        return jsonify({"error": "a JSON body with a 'content' field is required"}), 400
    content = str(payload['content'])

    # The pipeline is cached so the model is not reloaded on every request.
    pipe = _turkish_text_generator()
    text = pipe(content + "gramer doğruluğunu kontrol et:")[0]['generated_text']
    text2 = pipe("Aşağıdaki cümle dilbilgisi açısından yanlışsa, lütfen x ile işaretleyin:" + content)[0]['generated_text']

    translator = Translator()
    text_ko = translator.translate(text, src="tr", dest='ko').text
    text2_ko = translator.translate(text2, src="tr", dest='ko').text
    return jsonify({'Aanswer': text, "Ttrans_answer": text_ko,
                    "Iisrightvalue": text2, "transIS": text2_ko})






@app.route('/echo_call/<param>') #get echo api
def get_echo_call(param):
    """Echoes the URL path segment back as JSON under the `param` key."""
    echoed = {"param": param}
    return jsonify(echoed)


@app.route('/answer',methods=['GET', 'POST']) #post echo api
def post_echo_call():
    """Sends the posted text to the chat model and returns reply + translation.

    POST expects a JSON object containing `content`. The response is a JSON
    object with the model reply (`answer`) and its Korean translation
    (`trans_answer`). GET returns "get"; any other method returns "hi".
    A POST body without `content` yields a 400 instead of an unhandled error.
    """
    if request.method == 'GET':
        return "get"
    if request.method != 'POST':
        return "hi"

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'content' not in payload:
        return jsonify({"error": "a JSON body with a 'content' field is required"}), 400
    param = payload['content']

    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "Sadece türkçe bilen bir arkadaşsın, elinden geldiği kadar doğal sohbet etmeye çalış."},
            {"role": "user", "content": param},
        ],
    )
    answer = completion.choices[0].message.content
    transans = Translator().translate(answer, src="tr", dest='ko').text
    return jsonify({'answer': answer, "trans_answer": transans})


@app.route('/grammar',methods=['GET', 'POST']) #post echo api
def post_grm():
    """Asks the chat model for a corrected sentence and an explanation.

    POST expects a JSON object containing `content`. Two chat completions are
    requested: one for the grammar-corrected text and one explaining what is
    wrong. GET returns "get"; any other method returns "hi".

    NOTE(review): the original function computed `answer` and `reason` but
    never returned a response. The JSON shape returned here (`answer`,
    `reason`) is a proposal — confirm it against the client.
    """
    if request.method == 'GET':
        return "get"
    if request.method != 'POST':
        return "hi"

    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'content' not in payload:
        return jsonify({"error": "a JSON body with a 'content' field is required"}), 400
    param = payload['content']

    # First pass: corrected sentence only.
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        temperature=0.2,
        messages=[
            {"role": "system", "content": "Lütfen yalnızca dilbilgisi düzeltilmiş cümleleri yazdırın"},
            {"role": "user", "content": param},
        ],
    )
    answer0 = completion.choices[0].message.content
    answer = "doğru cevap : " + answer0

    # Second pass: explanation of why the text is wrong.
    completion2 = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        temperature=0,
        max_tokens=2000,
        top_p=1.0,
        messages=[
            {"role": "system", "content": "Düzeltmek için bir araç olmasını istiyorum. Size metni vereceğim ve yazım, dilbilgisi veya noktalama hataları için incelemenizi isteyeceğim. Metni inceledikten sonra cümlelerden herhangi biri yanlışsa neden yanlış olduğunu açıklayınız."},
            {"role": "user", "content": param},
        ],
    )
    reason = completion2.choices[0].message.content

    return jsonify({'answer': answer, 'reason': reason})

# Separate tasks for the "query" and "passage" splits: the text field becomes
# `inputs` and the matching `<split>_id` field becomes `targets`.
# NOTE(review): the task name says "de" while the TFDS dataset is
# "mrtydi/mmarco-en" — confirm which language is intended.
for split in ("query", "passage"):
  seqio.TaskRegistry.add(
      f"mmarco_retrieval_de_{split}",
      source=seqio.TfdsDataSource(
          tfds_name="mrtydi/mmarco-en:1.0.0",
          splits={split: split},
      ),
      preprocessors=[
          functools.partial(
              t5.data.preprocessors.rekey,
              key_map={"inputs": split, "targets": f"{split}_id"},
          ),
          seqio.preprocessors.tokenize,
          seqio.CacheDatasetPlaceholder(),
          seqio.preprocessors.append_eos_after_trim,
      ],
      metric_fns=[],
      output_features=DEFAULT_OUTPUT_FEATURES,
  )

# In [None]:
# Type aliases used in the model signatures below.
# NOTE(review): `jax.pxla.ShardedDeviceArray`, `jax.tree_structure` and
# `flax.optim` may be unavailable in recent JAX/Flax releases — confirm against
# the pinned versions.
# Any array-like value accepted by the models (NumPy, JAX or TensorFlow).
Array = Union[np.ndarray, jnp.ndarray, jax.pxla.ShardedDeviceArray, tf.Tensor]
DType = jnp.dtype
ConfigDict = ml_collections.ConfigDict
# The type of the object returned by `jax.tree_structure`.
PyTreeDef = type(jax.tree_structure(None))
Optimizer = optim.Optimizer


class DualEncoderBase(t5x_models.BaseTransformerModel):
  """Base class for dual-encoder models.

  Provides variable initialization and the inference entry point
  (`score_batch`) for a module that is applied to a
  `left_encoder_input_tokens` / `right_encoder_input_tokens` pair.
  """

  FEATURE_CONVERTER_CLS: Callable[..., seqio.FeatureConverter]

  # Values of `inference_mode` accepted by `score_batch`.
  ALLOWED_INFERENCE_MODE = frozenset({'encode', 'similarity'})

  def __init__(
      self,
      module: nn.Module,
      feature_converter_cls: Callable[[bool], seqio.FeatureConverter],
      input_vocabulary: seqio.Vocabulary,
      output_vocabulary: seqio.Vocabulary,
      optimizer_def: optim.OptimizerDef,
      inference_mode: str = 'encode',
  ):
    """Stores the converter class and inference mode, then defers to the base.

    Args:
      module: the Flax module implementing the dual encoder.
      feature_converter_cls: stored as `FEATURE_CONVERTER_CLS` on the instance.
      input_vocabulary: forwarded to the base class.
      output_vocabulary: forwarded to the base class.
      optimizer_def: forwarded to the base class.
      inference_mode: selects what `score_batch` computes.
    """
    self._inference_mode = inference_mode
    self.FEATURE_CONVERTER_CLS = feature_converter_cls  # pylint: disable=invalid-name
    super().__init__(
        module=module,
        input_vocabulary=input_vocabulary,
        output_vocabulary=output_vocabulary,
        optimizer_def=optimizer_def)

  def get_initial_variables(
      self,
      rng: jnp.ndarray,
      input_shapes: Mapping[str, Array],
      input_types: Optional[Mapping[str, DType]] = None
  ) -> flax_scope.FrozenVariableDict:
    """Initializes the module's variables from dummy all-ones inputs.

    Both dummy inputs use the dtype registered for
    `left_encoder_input_tokens` in `input_types` (float32 when absent).
    """
    dtypes = input_types if input_types is not None else {}
    token_dtype = dtypes.get('left_encoder_input_tokens', jnp.float32)
    left_dummy = jnp.ones(
        input_shapes['left_encoder_input_tokens'], token_dtype)
    right_dummy = jnp.ones(
        input_shapes['right_encoder_input_tokens'], token_dtype)
    return self.module.init(
        rng, left_dummy, right_dummy, enable_dropout=False)

  def loss_weights(self, batch: Mapping[str,
                                        jnp.ndarray]) -> Optional[jnp.ndarray]:
    """Not supported by dual encoders; always raises NotImplementedError."""
    raise NotImplementedError('Not implemented for dual encoder.')

  def predict_batch_with_aux(
      self,
      params: Mapping[str, Array],
      batch: Mapping[str, jnp.ndarray],
      rng: Optional[jax.random.KeyArray] = None,
  ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Not supported by dual encoders; always raises NotImplementedError."""
    raise NotImplementedError(
        'Autoregressive prediction is not implemented for dual encoder.')

  def _encode_batch(self, params: Mapping[str, Array],
                    batch: Mapping[str, jnp.ndarray]) -> Array:
    """Runs the module's `encode` method on the left inputs, dropout off."""
    variables = {'params': params}
    return self.module.apply(
        variables,
        batch['left_encoder_input_tokens'],
        enable_dropout=False,
        method=self.module.encode)

  def _similarity_batch(self,
                        params: Mapping[str, Array],
                        batch: Mapping[str, jnp.ndarray],
                        return_intermediates: bool = False) -> Array:
    """Returns the logits of a full forward pass on left and right inputs.

    `return_intermediates` is accepted for signature compatibility but is not
    used here.
    """
    outputs = self.module.apply({'params': params},
                                batch['left_encoder_input_tokens'],
                                batch['right_encoder_input_tokens'],
                                enable_dropout=False)
    # The module returns a 3-tuple; only its last element is needed.
    _, _, logits = outputs
    return logits

  def score_batch(self,
                  params: Mapping[str, Array],
                  batch: Mapping[str, jnp.ndarray],
                  return_intermediates: bool = False) -> jnp.ndarray:
    """Dispatches to encoding or similarity according to the inference mode.

    Raises:
      ValueError: if the configured mode is not in `ALLOWED_INFERENCE_MODE`.
    """
    mode = self._inference_mode
    if mode not in self.ALLOWED_INFERENCE_MODE:
      raise ValueError(
          'Invalid `inference_mode`: %s. Supported inference mode: %s.' %
          (mode, self.ALLOWED_INFERENCE_MODE))
    if mode == 'encode':
      return self._encode_batch(params, batch)
    if mode == 'similarity':
      return self._similarity_batch(params, batch, return_intermediates)


class DualEncoderModel(DualEncoderBase):
  """Dual-encoder model trained with a symmetric in-batch cross-entropy loss.

  Optionally consumes `right_negative_encoder_input_tokens` (when
  `use_negatives` is set) and optionally reports alignment/uniformity losses
  as metrics (when `use_align_uniform` is set).
  """

  ALLOWED_INFERENCE_MODE = frozenset(
      {'encode', 'similarity', 'pointwise_similarity'})

  def __init__(
      self,
      module: nn.Module,
      feature_converter_cls: Callable[[bool], seqio.FeatureConverter],
      input_vocabulary: seqio.Vocabulary,
      output_vocabulary: seqio.Vocabulary,
      optimizer_def: optim.OptimizerDef,
      inference_mode: str = 'encode',
      use_negatives: bool = False,
      use_align_uniform: bool = False,
      logit_scale: float = 100,
      logit_margin: float = 0.0,
  ):
    """Stores the training options and initializes the base class.

    Args:
      module: the Flax module implementing the dual encoder.
      feature_converter_cls: forwarded to `DualEncoderBase`.
      input_vocabulary: forwarded to `DualEncoderBase`.
      output_vocabulary: forwarded to `DualEncoderBase`.
      optimizer_def: forwarded to `DualEncoderBase`.
      inference_mode: selects what `score_batch` computes.
      use_negatives: whether batches carry
        `right_negative_encoder_input_tokens`.
      use_align_uniform: whether to report align/uniform losses as metrics.
      logit_scale: multiplier applied to both logit matrices.
      logit_margin: value subtracted from the diagonal of the logits when a
        dropout RNG is supplied.
    """
    self._use_negatives = use_negatives
    self._use_align_uniform = use_align_uniform
    self._logit_scale = logit_scale
    self._logit_margin = logit_margin
    super().__init__(
        module=module,
        feature_converter_cls=feature_converter_cls,
        input_vocabulary=input_vocabulary,
        output_vocabulary=output_vocabulary,
        optimizer_def=optimizer_def,
        inference_mode=inference_mode)

  def _compute_logits(
      self,
      params: Mapping[str, Any],
      batch: Mapping[str, jnp.ndarray],
      dropout_rng: Optional[jnp.ndarray] = None
  ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Computes encodings and scaled logits via a forward pass of `self.module`.

    Args:
      params: model parameters.
      batch: must contain `left_encoder_input_tokens` and
        `right_encoder_input_tokens`, plus
        `right_negative_encoder_input_tokens` iff `use_negatives` is set.
      dropout_rng: when given, dropout is enabled and the logit margin is
        applied.

    Returns:
      `(left_encodings, right_encodings, left_logits, right_logits)`.

    Raises:
      ValueError: if the presence of negative inputs in `batch` does not match
        `use_negatives`.
    """
    rngs = {'dropout': dropout_rng} if dropout_rng is not None else None
    has_negatives = 'right_negative_encoder_input_tokens' in batch

    # These errors were previously constructed but never raised.
    if not self._use_negatives and has_negatives:
      raise ValueError(
          'Invalid module. Please select `DualEncoderWithNegativesModel` for negative inputs.'
      )

    if self._use_negatives and not has_negatives:
      raise ValueError(
          'Invalid inputs. Please prepare negative inputs for DualEncoderWithNegativesModel.'
      )

    if self._use_negatives:
      left_tokens = batch['left_encoder_input_tokens']
      right_positive_tokens = batch['right_encoder_input_tokens']
      right_negative_tokens = batch['right_negative_encoder_input_tokens']

      # Left and positive-right inputs must be 2-D.
      assert left_tokens.ndim == 2
      assert right_positive_tokens.ndim == 2
      # Negatives may be 2-D or 3-D (the latter is flattened below).
      assert right_negative_tokens.ndim == 2 or right_negative_tokens.ndim == 3

      # All inputs share the same leading (batch) dimension.
      batch_size = right_positive_tokens.shape[0]
      assert left_tokens.shape[0] == batch_size
      assert right_negative_tokens.shape[0] == batch_size

      if right_negative_tokens.ndim == 3:
        # 3-D negatives: fold the second axis into the batch axis so that the
        # module receives a 2-D tensor.
        right_seq_length = right_positive_tokens.shape[1]
        assert right_seq_length == right_negative_tokens.shape[2]

        num_negatives = right_negative_tokens.shape[1]
        right_negative_tokens = jnp.reshape(
            right_negative_tokens,
            (batch_size * num_negatives, right_seq_length))

      (left_encodings, right_encodings,
       logits), _ = self.module.apply({'params': params},
                                      left_tokens,
                                      right_positive_tokens,
                                      right_negative_tokens,
                                      enable_dropout=rngs is not None,
                                      rngs=rngs,
                                      mutable='dropout')

      # Right-to-left logits are recomputed from the encodings rather than
      # taken as the transpose of `logits`.
      left_logits, right_logits = logits, jnp.dot(right_encodings,
                                                  left_encodings.transpose())
    else:
      (left_encodings, right_encodings, logits), _ = self.module.apply(
          {'params': params},
          batch['left_encoder_input_tokens'],
          batch['right_encoder_input_tokens'],
          enable_dropout=rngs is not None,
          rngs=rngs,
          mutable='dropout')

      left_logits, right_logits = logits, logits.transpose()

    left_logits *= self._logit_scale
    right_logits *= self._logit_scale

    # The margin is applied only when a dropout RNG is supplied (presumably
    # i.e. during training — confirm against the trainer).
    if dropout_rng is not None:
      left_logits = (
          left_logits - self._logit_margin *
          jnp.eye(N=left_logits.shape[0], M=left_logits.shape[1]))
      right_logits = (
          right_logits - self._logit_margin * jnp.eye(right_logits.shape[0]))

    return left_encodings, right_encodings, left_logits, right_logits

  def _compute_loss(
      self,
      batch: Mapping[str, jnp.ndarray],
      left_logits: jnp.ndarray,
      right_logits: jnp.ndarray,
  ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Returns the mean of the left and right in-batch cross-entropy losses.

    Returns:
      `(loss, z_loss, weight_sum)` where `z_loss` is always 0.0 and
      `weight_sum` is the first dimension of `left_logits`.
    """
    left_loss = utils.in_batch_cross_entropy(left_logits)
    right_loss = utils.in_batch_cross_entropy(right_logits)
    loss = jnp.mean(left_loss + right_loss)
    return loss, 0.0, left_logits.shape[0]

  def _compute_metrics(
      self,
      params: Mapping[str, Any],
      batch: Mapping[str, jnp.ndarray],
      left_logits: jnp.ndarray,
      loss: jnp.ndarray,
      total_z_loss: jnp.ndarray,
      weight_sum: jnp.ndarray,
      align_loss: Optional[jnp.ndarray] = None,
      uniform_loss: Optional[jnp.ndarray] = None,
  ) -> metrics_lib.MetricsMap:
    """Computes metrics given the logits, targets and loss.

    Adds `mrr` to the base metrics and, when `use_align_uniform` is set, also
    `align_loss` and `uniform_loss`.
    """
    labels = utils.sparse_labels_for_in_batch_cross_entropy(left_logits)
    metrics = t5x_models.compute_base_metrics(
        logits=left_logits,
        targets=labels,
        mask=None,
        loss=loss,
        z_loss=total_z_loss)
    metrics.update({
        'mrr':
            metrics_lib.AveragePerStep.from_model_output(
                utils.compute_rr(
                    left_logits,
                    utils.sparse_labels_for_in_batch_cross_entropy(
                        left_logits)))
    })
    if self._use_align_uniform:
      metrics.update({
          'align_loss':
              metrics_lib.AveragePerStep.from_model_output(align_loss),
          'uniform_loss':
              metrics_lib.AveragePerStep.from_model_output(uniform_loss),
      })
    return metrics

  def loss_fn(
      self,
      params: Mapping[str, Any],
      batch: Mapping[str, jnp.ndarray],
      dropout_rng: Optional[jnp.ndarray],
  ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Runs the forward pass and returns `(loss, metrics)`."""
    left_encodings, right_encodings, left_logits, right_logits = (
        self._compute_logits(params, batch, dropout_rng))

    loss, z_loss, weight_sum = self._compute_loss(batch, left_logits,
                                                  right_logits)

    # Align/uniform losses are reported as metrics only; they are not added
    # to `loss`.
    align_loss = None
    uniform_loss = None
    if self._use_align_uniform:
      align_loss = utils.compute_align_loss(left_encodings, right_encodings)
      uniform_loss = utils.compute_uniform_loss(
          left_encodings) + utils.compute_uniform_loss(right_encodings)

    metrics = self._compute_metrics(params, batch, left_logits, loss, z_loss,
                                    weight_sum, align_loss, uniform_loss)
    return loss, metrics

  def score_batch(self,
                  params: Mapping[str, Array],
                  batch: Mapping[str, jnp.ndarray],
                  return_intermediates: bool = False) -> jnp.ndarray:
    """Dispatches on the inference mode.

    'encode' returns the left encodings, 'similarity' the full logit matrix
    and 'pointwise_similarity' the diagonal of that matrix.

    Raises:
      ValueError: if the configured mode is not in `ALLOWED_INFERENCE_MODE`.
    """
    if self._inference_mode not in self.ALLOWED_INFERENCE_MODE:
      raise ValueError(
          'Invalid `inference_mode`: %s. Supported inference mode: %s.' %
          (self._inference_mode, self.ALLOWED_INFERENCE_MODE))
    if self._inference_mode == 'encode':
      return self._encode_batch(params, batch)
    elif self._inference_mode == 'similarity':
      return self._similarity_batch(params, batch, return_intermediates)
    elif self._inference_mode == 'pointwise_similarity':
      logits = self._similarity_batch(params, batch, return_intermediates)
      return jnp.diagonal(logits)
