In [None]:
class Test:
    """Minimal object with a single attribute, used to probe access to a missing attribute."""

    def __init__(self):
        self.val = 1


test = Test()

# `nonval` is never assigned on Test. getattr() with a default expresses
# "print only if the attribute exists and equals 1" directly, replacing the
# original try/except whose handler was a bare no-op expression (`0`).
if getattr(test, "nonval", None) == 1:
    print("Exists")

# T5を使ったJP to Protocolのテストコード

In [None]:
!pip install pathlib pandas torch transformers sentencepiece scikit-learn

In [None]:
# 必要なパッケージのインポート
from pathlib import Path
import re
import math
import time
import copy
from tqdm import tqdm
import pandas as pd
import tarfile
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer
import settings

In [None]:
# Hugging Face model id of the pretrained Japanese T5 checkpoint.
MODEL_NAME = "sonoisa/t5-base-japanese"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Length limits for source and target sequences.
# NOTE(review): the later cells read these values as `settings.<name>` rather than
# from these globals — confirm that the `settings` module mirrors this cell.
max_length_src = 400
max_length_target = 200

# Batch sizes for the training and validation loaders.
batch_size_train = 8
batch_size_valid = 8

# Upper bound on epochs, and the patience used by the early-stopping counter.
epochs = 1000
patience = 20

In [None]:
import os
# Working folder on the mounted Google Drive (Colab layout).
# The original literal was missing its closing quote, which is a SyntaxError.
path = '/content/drive/MyDrive/google_colaboratory/document_summarization/'

In [None]:
# Directory (relative to the working directory) that holds the input CSV files.
data_dir_path = Path('data')

In [None]:
# Load article bodies and their summaries from CSV files under `data_dir_path`.
body_data = pd.read_csv(data_dir_path.joinpath('body_data.csv'))
summary_data = pd.read_csv(data_dir_path.joinpath('summary_data.csv'))

# Preview only (the result is displayed, not stored): inner-join bodies that have
# non-null text with the summaries on `article_id` and show the first 10 rows.
pd.merge(
    body_data.query('text.notnull()', engine='python').rename(columns={'text': 'body'}),
    summary_data.rename(columns={'text': 'summary'}),
    on='article_id', how='inner'
).sort_values('article_id').head(10)

In [None]:
def join_text(x, add_char='。'):
    """Concatenate the strings in ``x``, placing ``add_char`` between neighbours.

    Args:
        x: iterable of strings to join.
        add_char: separator inserted between items (defaults to the Japanese full stop).

    Returns:
        The joined string.
    """
    joined = add_char.join(x)
    return joined

def preprocess_text(text):
    """Normalize one document before tokenization.

    Removes carriage returns, tabs, newlines and full-width spaces, applies
    ``neologdn.normalize``, lower-cases the result and strips surrounding
    whitespace.

    Args:
        text: the raw string.

    Returns:
        The normalized string.
    """
    # The notebook's import cell never imports neologdn, so the original call
    # raised NameError on a fresh kernel; import it where it is used.
    import neologdn

    text = re.sub(r'[\r\t\n\u3000]', '', text)
    text = neologdn.normalize(text)
    text = text.lower()
    text = text.strip()
    return text

# Collapse the summary rows of each article into one string (joined by join_text),
# ignoring rows whose text is null.
# NOTE: this rebinds `summary_data` / `body_data`, so the cell is not idempotent.
summary_data = summary_data.query('text.notnull()', engine='python').groupby(
    'article_id'
).agg({'text': join_text})

# Keep only bodies that actually have text.
body_data = body_data.query('text.notnull()', engine='python')

# Pair each body with its joined summary and normalize both text columns.
data = pd.merge(
    body_data.rename(columns={'text': 'body_text'}),
    summary_data.rename(columns={'text': 'summary_text'}),
    on='article_id', how='inner'
).assign(
    body_text=lambda x: x.body_text.map(lambda y: preprocess_text(y)),
    summary_text=lambda x: x.summary_text.map(lambda y: preprocess_text(y))
)

In [None]:
def convert_batch_data(train_data, valid_data, tokenizer):
    """Wrap (source, target) text pairs in shuffling DataLoaders that tokenize per batch.

    Args:
        train_data: sequence of ``(src, tgt)`` string pairs used for training.
        valid_data: sequence of ``(src, tgt)`` string pairs used for validation.
        tokenizer: callable tokenizer applied to each batch of strings.

    Returns:
        ``(train_iter, valid_iter)``; each yields ``(encoded_src, encoded_tgt)``
        as produced by ``tokenizer(..., return_tensors="pt")``.
    """

    def generate_batch(data):
        """Collate a list of (src, tgt) pairs into two padded tokenizer outputs."""
        sources = [src for src, _ in data]
        targets = [tgt for _, tgt in data]

        encoded_src = tokenizer(
            sources,
            max_length=settings.max_length_src,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        encoded_tgt = tokenizer(
            targets,
            max_length=settings.max_length_target,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        return encoded_src, encoded_tgt

    train_iter = DataLoader(
        train_data,
        batch_size=settings.batch_size_train,
        shuffle=True,
        collate_fn=generate_batch,
    )
    valid_iter = DataLoader(
        valid_data,
        batch_size=settings.batch_size_valid,
        shuffle=True,
        collate_fn=generate_batch,
    )
    return train_iter, valid_iter

In [None]:
# Load the tokenizer that belongs to the pretrained checkpoint named in `settings`.
# NOTE(review): `is_fast=True` does not look like a documented `from_pretrained`
# option — confirm it has any effect for T5Tokenizer.
tokenizer = T5Tokenizer.from_pretrained(settings.MODEL_NAME, is_fast=True)

In [None]:
# Hold out 20% of the (body, summary) pairs for validation; the seed fixes the split.
X_train, X_test, y_train, y_test = train_test_split(
    data['body_text'],
    data['summary_text'],
    test_size=0.2,
    shuffle=True,
    random_state=42,
)

# Lists of (source, target) tuples, the shape convert_batch_data's collate function unpacks.
train_data = list(zip(X_train, y_train))
valid_data = list(zip(X_test, y_test))

train_iter, valid_iter = convert_batch_data(train_data, valid_data, tokenizer)

In [None]:

class T5FineTuner(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained ``T5ForConditionalGeneration``.

    The wrapped model is loaded from ``settings.MODEL_NAME`` and exposed as
    ``self.model``; ``forward`` simply delegates to it.
    """

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(settings.MODEL_NAME)

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None,
                decoder_attention_mask=None, labels=None):
        """Run the wrapped model and return its output unchanged."""
        optional_inputs = dict(
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )
        return self.model(input_ids, **optional_inputs)

In [None]:
def train(model, data, optimizer, PAD_IDX):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: module whose forward returns a mapping with a ``'loss'`` entry.
        data: iterable of ``(src, tgt)`` tokenizer outputs; must support ``len()``.
        optimizer: optimizer stepped once per batch.
        PAD_IDX: padding token id; those label positions are set to -100.

    Returns:
        Sum of the batch losses divided by ``len(data)``.
    """
    model.train()

    running_loss = 0
    progress = tqdm(data)
    for step, (src, tgt) in enumerate(progress, start=1):
        optimizer.zero_grad()

        # Padding positions get -100 in the labels.
        labels = tgt['input_ids'].to(settings.device)
        pad_positions = labels[:, :] == PAD_IDX
        labels[pad_positions] = -100

        forward_kwargs = dict(
            input_ids=src['input_ids'].to(settings.device),
            attention_mask=src['attention_mask'].to(settings.device),
            decoder_attention_mask=tgt['attention_mask'].to(settings.device),
            labels=labels,
        )
        batch_loss = model(**forward_kwargs)['loss']

        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()

        # Show the running mean loss on the progress bar.
        progress.set_postfix(loss=running_loss / step)

    return running_loss / len(data)

In [None]:
def evaluate(model, data, PAD_IDX):
    """Compute the mean batch loss over ``data`` without updating the model.

    Args:
        model: module whose forward returns a mapping with a ``'loss'`` entry.
        data: iterable of ``(src, tgt)`` tokenizer outputs; must support ``len()``.
        PAD_IDX: padding token id; those label positions are set to -100.

    Returns:
        Sum of the batch losses divided by ``len(data)``.
    """
    model.eval()

    total_loss = 0
    with torch.no_grad():
        for src, tgt in data:
            # Padding positions get -100 in the labels.
            labels = tgt['input_ids'].to(settings.device)
            pad_positions = labels[:, :] == PAD_IDX
            labels[pad_positions] = -100

            forward_kwargs = dict(
                input_ids=src['input_ids'].to(settings.device),
                attention_mask=src['attention_mask'].to(settings.device),
                decoder_attention_mask=tgt['attention_mask'].to(settings.device),
                labels=labels,
            )
            total_loss += model(**forward_kwargs)['loss'].item()

    return total_loss / len(data)

In [None]:
# Build the model and move it to the configured device.
model = T5FineTuner()
model = model.to(settings.device)

# Adam with its default hyper-parameters.
optimizer = optim.Adam(model.parameters())

PAD_IDX = tokenizer.pad_token_id
best_loss = float('Inf')  # lowest validation loss seen so far
best_model = None         # deep copy of the model at that point
counter = 1               # epochs since the last improvement (starts at 1)

for loop in range(1, settings.epochs + 1):

    start_time = time.time()

    loss_train = train(model=model, data=train_iter, optimizer=optimizer, PAD_IDX=PAD_IDX)

    # Only the training pass is timed.
    elapsed_time = time.time() - start_time

    loss_valid = evaluate(model=model, data=valid_iter, PAD_IDX=PAD_IDX)

    # Progress line: losses, elapsed time as "<m>m<s>s", the patience counter,
    # and '**' when this epoch improved on the best validation loss.
    print('[{}/{}] train loss: {:.4f}, valid loss: {:.4f} [{}{:.0f}s] counter: {} {}'.format(
        loop, settings.epochs, loss_train, loss_valid,
        str(int(math.floor(elapsed_time / 60))) + 'm' if math.floor(elapsed_time / 60) > 0 else '',
        elapsed_time % 60,
        counter,
        '**' if best_loss > loss_valid else ''
    ))

    if best_loss > loss_valid:
        # Improvement: remember the model and reset the patience counter.
        best_loss = loss_valid
        best_model = copy.deepcopy(model)
        counter = 1
    else:
        # Early stopping once the counter exceeds the configured patience.
        if counter > settings.patience:
            break

        counter += 1


In [None]:
# Output directory for the fine-tuned model and tokenizer.
model_dir_path = Path('model')
# exist_ok=True makes the cell idempotent and removes the check-then-create race
# of the original `if not exists(): mkdir()` pattern.
model_dir_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Persist the tokenizer and the wrapped T5 model of the best checkpoint to `model_dir_path`.
# NOTE(review): `best_model` stays None if no epoch ever improved on the initial
# best loss — confirm that cannot happen before relying on this cell.
tokenizer.save_pretrained(model_dir_path)
best_model.model.save_pretrained(model_dir_path)

In [None]:
def generate_text_from_model(text, trained_model, tokenizer, num_return_sequences=1):
    """Generate output text for one input document.

    Args:
        text: raw input string; it is passed through ``preprocess_text`` first.
        trained_model: model exposing ``eval()`` and ``generate()``.
        tokenizer: tokenizer used for both encoding and decoding.
        num_return_sequences: number of sequences requested from ``generate``.

    Returns:
        List of decoded strings, one per generated sequence.
    """
    trained_model.eval()

    encoded = tokenizer(
        [preprocess_text(text)],
        max_length=settings.max_length_src,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )

    # Run generation. repetition_penalty discourages repeating the same sentence.
    # Options the author left disabled: temperature, num_beams,
    # diversity_penalty, num_beam_groups.
    generation_kwargs = dict(
        input_ids=encoded['input_ids'].to(settings.device),
        attention_mask=encoded['attention_mask'].to(settings.device),
        max_length=settings.max_length_target,
        repetition_penalty=8.0,
        num_return_sequences=num_return_sequences,
    )
    sequences = trained_model.generate(**generation_kwargs)

    decoded = []
    for ids in sequences:
        decoded.append(
            tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        )
    return decoded

In [None]:
# Reload the tokenizer and model from the directory written by the save cell.
tokenizer = T5Tokenizer.from_pretrained(model_dir_path)
trained_model = T5ForConditionalGeneration.from_pretrained(model_dir_path)

In [None]:
# Move the reloaded model to the configured device before generating.
trained_model = trained_model.to(settings.device)

In [None]:
# Qualitative check on one validation pair: generate from the body and print the
# generated text, the reference summary, and the body itself.
index = 1
body = valid_data[index][0]
summaries = valid_data[index][1]
generated_texts = generate_text_from_model(
    text=body, trained_model=trained_model, tokenizer=tokenizer, num_return_sequences=1
)
# Sentences are split on the Japanese full stop and printed one per line.
print('□ 生成本文')
print('\n'.join(generated_texts[0].split('。')))
print()
print('□ 教師データ要約')   
print('\n'.join(summaries.split('。')))
print()
print('□ 本文')
print(body)

## enumをlistやdictから作る
参考：https://stackoverflow.com/questions/62120732/generate-an-enum-class-from-a-list-in-python

In [None]:
import enum

def generate_enum(enumClass, enumDict):
    """
    Generate Python source code for a ``@unique`` ``Enum`` subclass.

    Args:
        enumClass: name of the class to generate.
        enumDict: either a mapping ``{member_name: value}`` (each value is emitted
            as a string literal) or a plain iterable of member names (values are
            numbered from 1, like the functional ``Enum(name, names)`` API).

    Returns:
        The source text. Code that executes it must have ``Enum`` and ``unique``
        (from the ``enum`` module) in scope.
    """

    # The class header needs its trailing ':'; without it the generated code
    # is a SyntaxError.
    enum_template = """
@unique
class {enumClass}(Enum):
{enumBody}
"""

    if hasattr(enumDict, 'items'):
        # repr(str(v)) gives 'v' for ordinary values and stays valid Python
        # when the value itself contains quotes or backslashes.
        members = [(name, repr(str(value))) for (name, value) in enumDict.items()]
    else:
        # A list of names was previously a crash (lists have no .items()).
        members = [(name, str(index)) for index, name in enumerate(enumDict, start=1)]

    if members:
        enumBody = '\n'.join([f"    {name} = {literal}" for (name, literal) in members])
    else:
        enumBody = "    pass"  # keep the generated class syntactically valid

    return enum_template.format(enumClass=enumClass, enumBody=enumBody)


species_list = ['HUMAN', "WEREWOLF", "ANY"]
print(generate_enum("ProtocolToken", species_list))


In [None]:
#
from enum import Enum
# Build an Enum directly from a list of names with the functional API
# (members are numbered automatically) and display its members.
species_list = ['HUMAN',"WEREWOLF","ANY"]
ProtocolToken = Enum('ProtocolToken',species_list)
list(ProtocolToken)

In [None]:
from enum import Enum
# Candidate tokens for each slot of a protocol sentence.
subject_list = ["Agent01","Agent02", "Agent03", "Agent04", "Agent05",
                   "Agent06", "Agent07", "Agent08", "Agent09", "Agent10", 
                   "Agent11", "Agent12", "Agent13", "Agent14", "Agent15","UNSPEC","ANY"] # TODO: handled loosely here; ideally the value itself would be used instead of classification
verb_list = ['ESTIMATE', 'COMINGOUT', 'DIVINATION', 'GUARD', 'VOTE',
            'ATTACK', 'DIVINED', 'IDENTIFIED', 'GUARDED', 'VOTED',
            'ATTACKED', 'AGREE', 'DISAGREE', 'Skip', 'Over' ] # REVIEW: whether Skip and Over need to be upper-cased
target_list = subject_list
species_list = ['HUMAN',"WEREWOLF","ANY"]
role_list = ['VILLAGER','SEER', 'MEDIUM','BODYGUARD','WEREWOLF','POSSESSED','ANY']

# The concatenation contains repeated names (target_list is subject_list; 'ANY'
# and 'WEREWOLF' occur in several lists). In the dict each name keeps the index
# of its last occurrence, so the resulting values are unique but not contiguous.
protocol_token_list = subject_list + verb_list + target_list + species_list + role_list
protocol_token_dict = {token: i for i, token in enumerate(protocol_token_list)}

# Enum built from the name -> value mapping; the members are displayed.
ProtocolToken = Enum('ProtocolToken',protocol_token_dict)
list(ProtocolToken)

In [None]:
def next_terminal_tokens(partial_sentence: str):
    """Return the token list of the first category that `partial_sentence` ends with.

    The categories are scanned in the order of the table below; the first one
    containing a token that is a suffix of ``partial_sentence`` decides the result.

    Args:
        partial_sentence: the partially written protocol sentence.

    Returns:
        The token list of the matching category, or ``[]`` when no token matches
        (the original raised UnboundLocalError in that case).

    NOTE(review): the table maps a trailing token to its *own* category; it does
    not model grammar transitions, so the result is not yet the set of tokens
    that may legally follow.
    """
    terminal_tokens = {
        'sentence_start': ['Skip', 'Over', 'Agent', 'ANY', 'UNSPEC'],
        # 'COMINGOUT' was misspelled 'COMMINGOUT'; spelling now matches verb_list.
        # NOTE(review): 'Agree'/'Disagree' differ in case from verb_list's 'AGREE'/'DISAGREE' — confirm.
        'VTR_VT_VTS_AGG_OTS_OS1_OS2_OSS_DAY': ['ESTIMATE', 'COMINGOUT', 'DIVINATION', 'GUARD', 'VOTE', 'ATTACK', 'GUARDED', 'VOTED', 'ATTACKED', 'DIVINED', 'IDENTIFIED', 'Agree', 'Disagree', 'REQUEST', 'INQUIRE', 'NOT', 'BECAUSE', 'XOR', 'AND', 'OR', 'DAY'],
        'TR': ['Agent', 'ANY'],
        'T': ['Agent', 'ANY'],
        'TSp': ['Agent', 'ANY'],
        'TSe': ['Agent', 'ANY'],
        'S2': ['Skip', 'Over', 'Agent', 'ANY', 'UNSPEC'],
        'SS': ['Skip', 'Over', 'Agent', 'ANY', 'UNSPEC'],
        'recsentence': ['Skip', 'Over', 'Agent', 'ANY', 'UNSPEC'],
        'rec2sentence': ['Skip', 'Over', 'Agent', 'ANY', 'UNSPEC', 'eps'],
        'species': ['HUMAN', 'WEREWOLF', 'ANY'],
        'role': ['VILLAGER', 'SEER', 'MEDIUM', 'BODYGUARD', 'WEREWOLF', 'POSSESSED'],
        'talk_number': ['day'],
        'agent_number': [str(i) for i in range(1, 16)],
        'day_number': [str(i) for i in range(1, 1000)],  # upper bound of 1000 is arbitrary; adjust as needed
        'ID_number': [str(i) for i in range(1, 1000)]  # upper bound of 1000 is arbitrary; adjust as needed
    }

    # Stop at the first matching category. The original `break` only left the
    # inner loop, so the *last* matching category silently won.
    for tokens in terminal_tokens.values():
        if any(partial_sentence.endswith(token) for token in tokens):
            return tokens

    # Nothing matched.
    return []

# Usage example
partial_sentence = "Agent1 ESTIMATE Agent2"
print(next_terminal_tokens(partial_sentence)) # the sentence ends with "2", so this prints the agent numbers ['1', ..., '15']


In [None]:
# Lists of the labels to classify
subject_list = ["Agent","UNSPEC","ANY"] # TODO: handled loosely here; ideally the value itself would be used instead of classification
verb_list = ['ESTIMATE', 'COMINGOUT', 'DIVINATION', 'GUARD', 'VOTE',
            'ATTACK', 'DIVINED', 'IDENTIFIED', 'GUARDED', 'VOTED',
            'ATTACKED', 'AGREE', 'DISAGREE', 'Skip', 'Over' ] # REVIEW: whether Skip and Over need to be upper-cased
target_list = subject_list
species_list = ['HUMAN',"WEREWOLF","ANY"]
role_list = ['VILLAGER','SEER', 'MEDIUM','BODYGUARD','WEREWOLF','POSSESSED','ANY']

# Plain concatenation; names that occur in several lists are repeated here.
protocol_token_list = subject_list + verb_list + target_list + species_list + role_list

In [None]:
# Model selection
# Pretrained model (the Japanese-only checkpoint is kept as a trailing comment)
PRETRAINED_MODEL_NAME = "sonoisa/t5-base-english-japanese" #"sonoisa/t5-base-japanese"

# Fine-tuned (transfer-learned) model directory
# NOTE(review): absolute path; presumably a Colab location — confirm before running elsewhere.
MODEL_DIR = "/content/model"

In [None]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, LogitsProcessorList, LogitsProcessor
import numpy as np

# Load the pretrained model and its tokenizer; this rebinds the global
# `model` and `tokenizer` names used by earlier cells.
model :T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(PRETRAINED_MODEL_NAME)
tokenizer :T5Tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_MODEL_NAME)


In [None]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, LogitsProcessorList, LogitsProcessor
import numpy as np

# NOTE(review): this cell repeats the previous cell verbatim and loads the same
# model and tokenizer a second time — consider removing one of the two.
model :T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(PRETRAINED_MODEL_NAME)
tokenizer :T5Tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_MODEL_NAME)


In [None]:
# Check whether each protocol token is present in the tokenizer's vocabulary;
# tokens that are missing are printed with their encoding and round-trip decoding.
tokenizer_vocab = tokenizer.get_vocab()  # fetched once instead of on every iteration
for p in protocol_token_list:
    if p not in tokenizer_vocab:
        print(f"protocol:{p}", f"{tokenizer.encode(p)},{tokenizer.decode(tokenizer.encode(p))}")

In [None]:
# Lower-case the protocol tokens (this looks like the better choice).
lower_protocol_token_list = [p.lower() for p in protocol_token_list]
# Check whether each lower-cased protocol token is present in the tokenizer's
# vocabulary, and print how it is tokenized either way.
tokenizer_vocab = tokenizer.get_vocab()  # fetched once instead of on every iteration
for p in lower_protocol_token_list:
    if p not in tokenizer_vocab:
        print(f"some tokens protocol:{p}", f"{tokenizer.tokenize(p)},{tokenizer.encode(p)},{tokenizer.decode(tokenizer.encode(p))}")
    else:
        print(f"one token protocol:{p}", f"{tokenizer.tokenize(p)},{tokenizer.encode(p)},{tokenizer.decode(tokenizer.encode(p))}")
        

In [None]:
# The author observed that 'agree' is, for some reason, not in the vocab; a plain
# [] lookup then raises KeyError and stops "Run All". .get() evaluates to the id
# when present and to None (no cell output) when absent.
tokenizer.get_vocab().get('agree')

In [None]:
# Inspect the ids the tokenizer produces for a single space.
tokenizer.encode(' ')

In [None]:
def set_scores_to_inf_for_banned_tokens(scores, banned_tokens):
    """
    Return `scores` with the banned token positions set to `-inf`.

    The input tensor is not modified: a new tensor is returned, except when
    nothing is banned, in which case `scores` itself is returned.

    Args:
        scores: logits of shape (batch size, vocabulary size)
        banned_tokens: one iterable of banned token ids per batch row
            (Python ints, numpy integers or 0-dim tensors)

    Returns:
        Tensor shaped like `scores` with `-inf` at every banned (row, token) position.
    """
    rows, cols = [], []
    for batch_idx, batch_banned_tokens in enumerate(banned_tokens):
        for token in batch_banned_tokens:
            rows.append(batch_idx)
            cols.append(int(token))
    if not rows:
        return scores

    # Build the boolean mask by direct indexing. This replaces the legacy
    # torch.sparse.LongTensor constructor and is unaffected by duplicate entries.
    banned_mask = torch.zeros(scores.shape, dtype=torch.bool, device=scores.device)
    row_index = torch.tensor(rows, dtype=torch.long, device=scores.device)
    col_index = torch.tensor(cols, dtype=torch.long, device=scores.device)
    banned_mask[row_index, col_index] = True

    return scores.masked_fill(banned_mask, -float("inf"))

In [None]:
class EvenLogits(LogitsProcessor):
  """Logits processor that bans every vocabulary token whose string has even length.

  The vocabulary is read from the module-level `tokenizer` via `get_vocab()`,
  the accessor used elsewhere in this notebook. The banned ids are computed once,
  on the first call, instead of once per beam on every generation step.
  """

  def __init__(self):
    super().__init__()
    # Filled lazily so that constructing the processor does not touch the tokenizer.
    self._even_length_token_ids = None

  def _get_even_length_token_ids(self):
    """Return (and cache) the ids of all tokens with an even number of characters."""
    if self._even_length_token_ids is None:
      vocab = tokenizer.get_vocab()
      self._even_length_token_ids = np.array(
          [token_id for token, token_id in vocab.items() if len(token) % 2 == 0],
          dtype=np.int64,
      )
    return self._even_length_token_ids

  def __call__(self, input_ids, scores):
    """Set the scores of even-length tokens to -inf for every beam."""
    banned_ids = self._get_even_length_token_ids()
    # Same ban list for each beam (one entry per row of input_ids/scores).
    banned_tokens = [banned_ids for _ in zip(input_ids, scores)]

    scores = set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
    return scores

class ABCLogits(LogitsProcessor):
  """Logits processor forcing the first letters of successive tokens to cycle a -> b -> c -> a.

  After a token starting with 'a' only tokens starting with 'b' stay allowed,
  after 'b' only 'c', and after 'c' (or anything else) only 'a'.
  """

  # Characters ignored at the edges of a token when looking for its first letter:
  # the byte-level BPE marker 'Ġ', the SentencePiece marker '▁' (T5 is expected to
  # use this one) and plain spaces.
  _BOUNDARY_CHARS = 'Ġ▁ '

  @staticmethod
  def _first_char(token):
    """Lower-cased first character of `token` after stripping boundary markers.

    Returns '' for tokens that are empty after stripping; the original lambda
    raised IndexError on those.
    """
    stripped = str(token).strip(ABCLogits._BOUNDARY_CHARS)
    return stripped[0].lower() if stripped else ''

  def __init__(self, vocab):
    """
    vocab is a dictionary where the keys are tokens
    and the values are the corresponding ids.
    If anything else is passed (the notebook passes `tokenizer.vocab_size`),
    the vocabulary of the module-level `tokenizer` is used, as before.
    """
    token_to_id = vocab if isinstance(vocab, dict) else tokenizer.get_vocab()

    # parallel arrays of tokens and their ids
    self.keys = np.array(list(token_to_id.keys()))
    self.values = np.array(list(token_to_id.values()))

    # vectorized first-character lookup, kept as an attribute as in the original;
    # otypes lets it work on an empty vocabulary too
    self.first_char = np.vectorize(self._first_char, otypes=[str])
    first_chars = self.first_char(self.keys)

    # ids of all tokens that do NOT start with 'a', 'b' or 'c' respectively
    self.not_a_values = self.values[first_chars != 'a']
    self.not_b_values = self.values[first_chars != 'b']
    self.not_c_values = self.values[first_chars != 'c']

  def __call__(self, input_ids, scores):
    # which ban list applies after a token starting with the given letter
    banned_after = {
        'a': self.not_b_values,
        'b': self.not_c_values,
        'c': self.not_a_values,
    }
    banned_tokens = []
    # for every beam (partially generated sentence)
    for beam_input_ids, _ in zip(input_ids, scores):
      # first character of the last token of this beam
      last_word = tokenizer.decode(beam_input_ids[-1])
      starting_char = self._first_char(last_word)
      # any other starting character restarts the cycle at 'a'
      banned_tokens.append(banned_after.get(starting_char, self.not_a_values))
    # set the scores of all banned tokens over the beams to -inf
    scores = set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
    return scores

In [None]:
class LL1Grammar2:
    """Small hand-written state machine over the states 'S', 'A' and 'B'.

    Feeding a token that the current state accepts moves the machine to the
    next state and returns the tokens allowed afterwards; any other token
    leaves the state unchanged and returns an empty list.
    """

    # state -> ordered (accepted token, next state, tokens allowed afterwards)
    _RULES = {
        'S': (
            ('hello', 'A', ('sec', 'apple')),
        ),
        'A': (
            ('sec', 'B', ('cost', 'done')),
            ('apple', 'A', ('sec', 'apple')),
        ),
        'B': (
            ('cost', 'B', ('cost', 'done')),
            ('done', 'A', ('sec', 'apple', 'bad')),
        ),
    }

    def __init__(self):
        self.state = 'S'

    def next_valid_tokens(self, token):
        """Consume ``token`` and return the list of tokens valid after it."""
        for accepted, next_state, allowed in self._RULES.get(self.state, ()):
            if token == accepted:
                self.state = next_state
                # fresh list on every call, as the original literals were
                return list(allowed)
        return []


In [None]:
class ConstrainedLogitsProcessor(LogitsProcessor):
    """Logits processor that only keeps the tokens a grammar allows next.

    `grammar` must expose `next_valid_tokens(token) -> list[str]` and may be
    stateful. For every sequence in the batch a fresh copy of the grammar is fed
    that sequence's tokens in order, and the result of the last call is the set
    of allowed next tokens. The grammar passed in is never mutated, so beams do
    not influence each other.
    """

    def __init__(self, grammar, tokenizer):
        super().__init__()
        self.grammar = grammar
        self.tokenizer = tokenizer

    def _valid_token_ids(self, sequence, vocab_size):
        """Replay `sequence` through a copy of the grammar and return the allowed ids."""
        import copy  # local import keeps the block self-contained

        grammar = copy.deepcopy(self.grammar)
        token_ids = [int(token_id) for token_id in sequence]
        valid_tokens = []
        for token in self.tokenizer.convert_ids_to_tokens(token_ids):
            valid_tokens = grammar.next_valid_tokens(token)

        valid_ids = self.tokenizer.convert_tokens_to_ids(list(valid_tokens))
        # Drop ids outside the score vector (and None for unknown tokens).
        return sorted({i for i in valid_ids if i is not None and 0 <= i < vocab_size})

    def __call__(self, input_ids, scores):
        """Set every score to -inf except those of grammar-approved tokens.

        Args:
            input_ids: token ids generated so far, one row per sequence/beam.
            scores: next-token scores, one row per sequence/beam; modified in place.

        Returns:
            The same `scores` tensor.
        """
        # The original indexed `input_ids[-1]`, which selects the last *row*
        # (a whole sequence), so the grammar never received a single token.
        banned = torch.ones_like(scores, dtype=torch.bool)
        vocab_size = scores.shape[-1]
        for row, sequence in enumerate(input_ids):
            valid_ids = self._valid_token_ids(sequence, vocab_size)
            if valid_ids:
                banned[row, valid_ids] = False

        # One vectorised fill instead of a Python loop over the whole vocabulary.
        scores.masked_fill_(banned, float('-inf'))
        return scores


In [None]:
# Encode one Japanese sentence as the encoder input.
encoder_input_str = "Agent1はAgent2を人狼だと思っている"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
grammar = LL1Grammar2()
# Token-id lists for every lower-cased protocol token (no special tokens added),
# passed to generate() below as `force_words_ids`.
force_words_ids = tokenizer(lower_protocol_token_list, add_special_tokens=False).input_ids
# Initialize the custom logits processor
logits_processor = LogitsProcessorList([
    ConstrainedLogitsProcessor(grammar, tokenizer)
])

# Beam search with 5 beams, combining the forced words with the custom processor.
# NOTE(review): LL1Grammar2 is a toy grammar ('hello', 'sec', ...), unrelated to
# the protocol tokens forced here — confirm this combination is intended.
outputs = model.generate(
    input_ids,
    force_words_ids=force_words_ids,
    num_beams=5,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
    logits_processor=logits_processor
)


print("Output:\n" + 100 * '-')
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
    MaxLengthCriteria
)
# how many beams to track during beam search
num_beams = 10
# how many beams to return after the search
num_return_beams = 10

# the prompt to continue
prompt = 'My cute dog is a'

# tokenizing the prompt
prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']

# instantiating a BeamSearchScorer
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

# instantiating a list of LogitsProcessor instances
# using our custom ABCLogits class
# (the argument passed here is the vocabulary size, an int, not a token->id dict)
logits_processor = LogitsProcessorList([ABCLogits(tokenizer.vocab_size)])

# running beam search using our custom LogitsProcessor;
# the prompt is repeated once per beam
# NOTE(review): `model.beam_search` / `BeamSearchScorer` may not be available in
# recent transformers releases — confirm against the installed version.
generated = model.beam_search(
    input_ids=torch.cat([prompt_tokenized] * num_beams).to(dtype = torch.long),
    beam_scorer =beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=12)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  print(f'beam {index}: {output}')

In [None]:
# Inspect the ids the tokenizer produces for a single space (same check as an earlier cell).
tokenizer.encode(' ')

In [None]:
class LL1Parser:
    """Table-driven predictive parser for a small grammar without epsilon rules.

    `grammar` maps each non-terminal to a list of productions; a production is a
    string of whitespace-separated symbols. Symbols that are not keys of
    `grammar` are terminals. The first key is the start symbol.
    """

    def __init__(self, grammar):
        self.grammar = grammar
        self.first_sets = self.compute_first_sets()
        self.parse_table = self.construct_parse_table()

    def compute_first_sets(self):
        """Return {non_terminal: set of terminals that can start it} (fixed-point iteration)."""
        first_sets = {non_terminal: set() for non_terminal in self.grammar}

        changed = True
        while changed:
            changed = False
            for non_terminal, rules in self.grammar.items():
                for rule in rules:
                    first_symbol = rule.split()[0]
                    if first_symbol in self.grammar:  # Non-terminal symbol
                        candidates = first_sets[first_symbol]
                    else:  # Terminal symbol
                        candidates = {first_symbol}
                    new_symbols = candidates - first_sets[non_terminal]
                    if new_symbols:
                        first_sets[non_terminal] |= new_symbols
                        changed = True
        return first_sets

    def construct_parse_table(self):
        """Return {non_terminal: {lookahead terminal: production}}.

        A production is selected by the FIRST set of its leading symbol. The
        original compared `rule.startswith(terminal)`, which never selected a
        production that begins with a non-terminal.
        """
        parse_table = {}
        for non_terminal, rules in self.grammar.items():
            row = parse_table[non_terminal] = {}
            for rule in rules:
                first_symbol = rule.split()[0]
                if first_symbol in self.grammar:
                    lookaheads = self.first_sets[first_symbol]
                else:
                    lookaheads = {first_symbol}
                for terminal in lookaheads:
                    row[terminal] = rule
        return parse_table

    def _tokenize(self, input_string):
        """Split the input into terminals of the grammar (longest first), else `\\w+` runs.

        Matching the grammar's terminals lets unseparated input such as "aa" be
        read as the two terminals "a", "a"; whitespace-separated input is
        tokenized as before.
        """
        terminals = {
            symbol
            for rules in self.grammar.values()
            for rule in rules
            for symbol in rule.split()
            if symbol not in self.grammar
        }
        ordered = sorted(terminals, key=lambda t: (-len(t), t))
        pattern = "|".join([re.escape(t) for t in ordered] + [r"\w+"])
        return re.findall(pattern, input_string)

    def parse(self, input_string):
        """Check whether `input_string` is derivable from the start symbol.

        Returns:
            `(True, set())` on success. On failure `(False, expected)`, where
            `expected` is the FIRST set of the non-terminal that could not be
            expanded, or an empty set for any other mismatch.
        """
        tokens = self._tokenize(input_string) + ["$"]
        # "$" goes to the bottom of the stack and the start symbol on top; the
        # original order popped "$" first and therefore rejected every input.
        stack = ["$", list(self.grammar.keys())[0]]
        cursor = 0

        while stack and cursor < len(tokens):
            top = stack.pop()
            lookahead = tokens[cursor]
            if top in self.grammar:  # Non-terminal
                production = self.parse_table[top].get(lookahead)
                if production is None:
                    return False, self.first_sets[top]
                stack.extend(reversed(production.split()))
            elif top == lookahead:
                cursor += 1
                # Accept only when the end marker was matched and nothing is left.
                if cursor == len(tokens) and not stack:
                    return True, set()
            else:
                return False, set()

        return False, set()

if __name__ == "__main__":
    # Toy grammar: S -> A a | b B, A -> a, B -> b
    grammar = {
        "S": ["A a", "b B"],
        "A": ["a"],
        "B": ["b"],
    }

    parser = LL1Parser(grammar)

    # Each case pairs an input string with the acceptance result the author expects.
    test_cases = [
        ("aa", True),
        ("ba", False),
        ("bb", True),
        ("ab", False),
        ("b", False),
    ]

    # The loop only prints expected and actual results side by side; it does not assert.
    for string, expected in test_cases:
        result, next_symbols = parser.parse(string)
        print(f"Input: {string}, Expected: {expected}, Result: {result}, Next symbols: {next_symbols}")


In [None]:
import re

class LL1Parser:
    """Predictive (LL(1)-style) parser for grammars without epsilon productions.

    `grammar` is a dict from non-terminal to a list of productions, each a
    string of whitespace-separated symbols. Any symbol that is not a key of
    `grammar` is treated as a terminal; the first key is the start symbol.
    """

    def __init__(self, grammar):
        self.grammar = grammar
        self.first_sets = self.compute_first_sets()
        self.parse_table = self.construct_parse_table()

    def compute_first_sets(self):
        """Compute the FIRST set of every non-terminal by iterating to a fixed point."""
        first_sets = {non_terminal: set() for non_terminal in self.grammar}

        changed = True
        while changed:
            changed = False
            for non_terminal, rules in self.grammar.items():
                for rule in rules:
                    first_symbol = rule.split()[0]
                    if first_symbol not in self.grammar:  # Terminal symbol
                        candidates = {first_symbol}
                    else:  # Non-terminal symbol
                        candidates = first_sets[first_symbol]
                    new_symbols = candidates - first_sets[non_terminal]
                    if new_symbols:
                        first_sets[non_terminal] |= new_symbols
                        changed = True
        return first_sets

    def construct_parse_table(self):
        """Build {non_terminal: {lookahead: production}} from the FIRST sets."""
        parse_table = {}
        for non_terminal in self.grammar:
            parse_table[non_terminal] = {}
            for rule in self.grammar[non_terminal]:
                first_symbol = rule.split()[0]
                if first_symbol not in self.grammar:  # Terminal symbol
                    parse_table[non_terminal][first_symbol] = rule
                else:
                    for symbol in self.first_sets[first_symbol]:
                        parse_table[non_terminal][symbol] = rule
        return parse_table

    def _tokenize(self, input_string):
        """Tokenize into the grammar's terminals (longest match first), else `\\w+` runs.

        This lets unseparated input such as "aa" be read as "a", "a" while
        whitespace-separated input keeps its previous tokenization.
        """
        terminals = {
            symbol
            for rules in self.grammar.values()
            for rule in rules
            for symbol in rule.split()
            if symbol not in self.grammar
        }
        ordered = sorted(terminals, key=lambda t: (-len(t), t))
        pattern = "|".join([re.escape(t) for t in ordered] + [r"\w+"])
        return re.findall(pattern, input_string)

    def parse(self, input_string):
        """Parse `input_string` and report whether the grammar accepts it.

        Returns:
            `(True, set())` when the whole input is derived from the start
            symbol; `(False, first_set)` when a non-terminal cannot be expanded
            on the current token (`first_set` is that non-terminal's FIRST
            set); `(False, set())` for any other mismatch.
        """
        tokens = self._tokenize(input_string) + ["$"]
        # End marker at the bottom, start symbol on top. With the original
        # order the end marker was popped first and every input was rejected.
        stack = ["$", list(self.grammar.keys())[0]]
        cursor = 0

        while stack and cursor < len(tokens):
            top = stack.pop()
            lookahead = tokens[cursor]
            if top in self.grammar:  # Non-terminal
                production = self.parse_table[top].get(lookahead)
                if production is None:
                    return False, self.first_sets[top]
                for symbol in reversed(production.split()):
                    stack.append(symbol)
            elif top == lookahead:
                cursor += 1
                if cursor == len(tokens) and not stack:
                    return True, set()
            else:
                return False, set()

        return False, set()

if __name__ == "__main__":
    # Toy grammar: S -> A a | b B, A -> a, B -> b
    grammar = {
        "S": ["A a", "b B"],
        "A": ["a"],
        "B": ["b"],
    }

    parser = LL1Parser(grammar)

    # Input strings paired with the acceptance result the author expects.
    test_cases = [
        ("aa", True),
        ("ba", False),
        ("bb", True),
        ("ab", False),
        ("b", False),
    ]

    # Results are printed next to the expectation; nothing is asserted here.
    for string, expected in test_cases:
        result, next_symbols = parser.parse(string)
        print(f"Input: {string}, Expected: {expected}, Result: {result}, Next symbols: {next_symbols}")


In [None]:
from collections import defaultdict

class LL1Parser:
    """Third parser experiment: productions are lists of symbols, input is read per character.

    `grammar` maps a symbol to a list of productions; the start symbol is the
    key "S". Symbols for which `str.isupper()` is true are handled as
    non-terminals by `parse` and `compute_follow_sets`.

    NOTE(review): `compute_first_set` returns `{symbol}` for upper-case symbols,
    i.e. for the symbols the rest of the class treats as non-terminals — this
    looks inverted; confirm before relying on the FIRST/FOLLOW sets.
    """

    def __init__(self, grammar):
        self.grammar = grammar
        self.first_sets = self.compute_first_sets()
        self.follow_sets = self.compute_follow_sets()

    def compute_first_sets(self):
        """Return a defaultdict mapping every grammar key to `compute_first_set(key)`."""
        first_sets = defaultdict(set)

        for non_terminal in self.grammar:
            first_sets[non_terminal] = self.compute_first_set(non_terminal)

        return first_sets

    def compute_first_set(self, symbol):
        """Return the set computed for `symbol`.

        `{symbol}` is returned when `symbol.isupper()` is true or when `symbol`
        is not a key of the grammar. Otherwise the first element of each of its
        productions is added as-is if upper-case, or recursed into if not.
        """
        if symbol.isupper():
            return {symbol}
        
        if not symbol in self.grammar:
            return {symbol}

        first_set = set()
        for production in self.grammar[symbol]:
            first_symbol = production[0]
            if first_symbol.isupper():
                first_set.add(first_symbol)
            else:
                first_set |= self.compute_first_set(first_symbol)

        return first_set

    # Other functions are unchanged


    def compute_follow_sets(self):
        """Collect, for every upper-case symbol in a production, what may follow it.

        If another symbol follows inside the production, that symbol itself is
        added when upper-case, otherwise its entry in `self.first_sets`. For a
        symbol at the end of a production, `compute_follow_set` of the
        production's left-hand side is merged in.
        """
        follow_sets = defaultdict(set)

        for non_terminal in self.grammar:
            for production in self.grammar[non_terminal]:
                for i, symbol in enumerate(production):
                    if symbol.isupper():
                        if i + 1 < len(production):
                            next_symbol = production[i + 1]
                            if next_symbol.isupper():
                                follow_sets[symbol].add(next_symbol)
                            else:
                                follow_sets[symbol] |= self.first_sets[next_symbol]
                        else:
                            follow_sets[symbol] |= self.compute_follow_set(non_terminal)

        return follow_sets

    def compute_follow_set(self, non_terminal):
        """Return what may follow `non_terminal`, looking at its first occurrence per production.

        When `non_terminal` ends a production of a different symbol `nt`, the
        result of `compute_follow_set(nt)` is merged in recursively.
        """
        follow_set = set()

        for nt in self.grammar:
            for production in self.grammar[nt]:
                if non_terminal in production:
                    idx = production.index(non_terminal)
                    if idx + 1 < len(production):
                        next_symbol = production[idx + 1]
                        if next_symbol.isupper():
                            follow_set.add(next_symbol)
                        else:
                            follow_set |= self.first_sets[next_symbol]
                    else:
                        # skip self-reference to avoid immediate infinite recursion
                        if nt != non_terminal:
                            follow_set |= self.compute_follow_set(nt)

        return follow_set

    def parse(self, input_str):
        """Run a stack-based parse over the characters of `input_str`.

        Returns:
            `True` when the input list is exhausted, `False` when no production
            of the popped non-terminal has the current input symbol in the set
            computed for its first element, or a one-element list holding the
            top of the stack when a terminal does not match the input.

        NOTE(review): three different result types are returned — callers must
        distinguish `True`, `False` and a list; confirm this is intended.
        """
        # The stack starts with "$" below the reversed first production of "S".
        stack = ["$"]
        stack.extend(list(reversed(list(self.grammar["S"][0]))))
        input_str = list(input_str) + ["$"]

        while input_str:
            if stack[-1].isupper():
                # Upper-case top of stack: expand it with the first production
                # whose leading symbol's set contains the current input symbol.
                top_symbol = stack.pop()
                input_symbol = input_str[0]

                for production in self.grammar[top_symbol]:
                    first_set = self.compute_first_set(production[0])

                    if input_symbol in first_set:
                        stack.extend(list(reversed(production)))
                        break
                else:
                    return False
            else:
                # Otherwise the top of the stack must equal the next input symbol.
                if stack[-1] == input_str[0]:
                    stack.pop()
                    input_str.pop(0)
                else:
                    return stack[-1:]

        return True

# 使用例
if __name__ == "__main__":
    # Grammar written in BNF notation (arithmetic expressions; [""] is the empty production)
    grammar = {
        "S": [["E"]],
        "E": [["T", "E'"]],
        "E'": [["+", "T", "E'"], [""]],
        "T": [["F", "T'"]],
        "T'": [["*", "F", "T'"], [""]],
        "F": [["(", "E", ")"], ["a"]]
    }

    parser = LL1Parser(grammar)

    # Test cases: input string and the value parse() is expected to return
    # (True, False, or a list).
    test_cases = [
        ("a+a", True),
        ("a+a*a", True),
        ("(a+a)*a", True),
        ("a+a)", False),
        ("a+)", ["+", "a"]),
        ("a+*", False),
        ("(a+", [")", "a"]),
    ]

    # Unlike the earlier demo cells this one asserts, so the first mismatch
    # stops the cell with an AssertionError.
    for input_str, expected in test_cases:
        result = parser.parse(input_str)
        assert result == expected, f"Expected {expected}, but got {result}. Input: {input_str}"
        print(f"Input: {input_str}, Result: {result}")



In [None]:
from collections import deque
from typing import List,Dict,Union

from enum import Enum
# Lists of the labels used for classification
subject_list = ["Agent01","Agent02", "Agent03", "Agent04", "Agent05",
                   "Agent06", "Agent07", "Agent08", "Agent09", "Agent10", 
                   "Agent11", "Agent12", "Agent13", "Agent14", "Agent15","UNSPEC","ANY"] # TODO: this part is done roughly; ideally the values themselves should be used instead of classification labels
verb_list = ['ESTIMATE', 'COMINGOUT', 'DIVINATION', 'GUARD', 'VOTE',
            'ATTACK', 'DIVINED', 'IDENTIFIED', 'GUARDED', 'VOTED',
            'ATTACKED', 'AGREE', 'DISAGREE', 'Skip', 'Over' ] # REVIEW: decide whether Skip and Over need to be uppercase
# NOTE: target_list is the same list object as subject_list, not a copy.
target_list = subject_list
species_list = ['HUMAN',"WEREWOLF","ANY"]
role_list = ['VILLAGER','SEER', 'MEDIUM','BODYGUARD','WEREWOLF','POSSESSED','ANY']

# Some tokens occur more than once in this concatenation (all of subject_list
# again via target_list, plus 'ANY' and 'WEREWOLF'); the dict comprehension
# below keeps the index of the last occurrence of each token.
protocol_token_list = ['EOS']+ subject_list + verb_list + target_list + species_list + role_list
protocol_token_dict = {token: i for i, token in enumerate(protocol_token_list)}

# Enums used as protocol labels
Terminal = Enum('Terminal',protocol_token_dict) # TODO: needs fixing
NonTerminal = Enum('NonTerminal',protocol_token_dict) # TODO: needs fixing

# Lookup tables read by next_terminals(); for now they only hold a "test"
# placeholder entry.
terminal_next_gen_rule_list: Dict[NonTerminal,Dict[Terminal,List[Union[NonTerminal,Terminal]]]] = {"test":{}} # TODO: needs fixing
non_terminal_first: Dict[NonTerminal,List[Terminal]] = {"test":[]} # TODO: needs fixing
next_terminals_in_gen_rule: Dict[NonTerminal,Dict[Terminal,List[List[Terminal]]]] = {"test":{}} # TODO: needs fixing

def next_terminals(terminals: List[Terminal]) -> List[Terminal]:
    """Return the terminals that may follow the already generated prefix.

    The prefix is parsed top-down starting from the ``"sentence"``
    non-terminal, using three module-level lookup tables:

    * ``non_terminal_first[nt]`` - terminals that can start ``nt``;
    * ``terminal_next_gen_rule_list[nt][t]`` - the production (list of
      symbols) to use for ``nt`` when the next terminal is ``t``;
    * ``next_terminals_in_gen_rule[nt][first_symbol][i]`` - terminals allowed
      after position ``i`` of the production of ``nt`` that starts with
      ``first_symbol``.

    Args:
        terminals: Terminals generated so far (may be empty).

    Returns:
        The list of terminals that can come next, or ``[]`` when the prefix
        is already a complete sentence.

    Raises:
        KeyError: If a terminal has no entry in the lookup tables.
        IndexError: If terminals remain after the sentence is complete.
    """
    if not terminals:
        return non_terminal_first["sentence"]

    # Stack of (non-terminal, production, index to resume at) for the
    # productions we descended out of.
    transitions = deque()
    cur_non_terminal = "sentence"  # start non-terminal
    gen_rule = terminal_next_gen_rule_list[cur_non_terminal][terminals[0]]
    idx = 0

    # Parse the prefix
    for t in terminals:
        # The symbol at idx is a non-terminal: descend into its production.
        while t != gen_rule[idx]:
            transitions.append((cur_non_terminal, gen_rule, idx + 1))
            cur_non_terminal = gen_rule[idx]
            # Rebind instead of assigning into gen_rule[idx]; writing into the
            # list would corrupt the shared production table.
            gen_rule = terminal_next_gen_rule_list[cur_non_terminal][t]
            idx = 0

        idx += 1
        # The current production is fully matched: return to the parent.
        while idx == len(gen_rule) and transitions:
            cur_non_terminal, gen_rule, idx = transitions.pop()

    # After the loop above, idx == len(gen_rule) implies the stack is empty,
    # i.e. the whole sentence has been parsed and nothing can follow.
    if idx == len(gen_rule):
        return []
    return next_terminals_in_gen_rule[cur_non_terminal][gen_rule[0]][idx - 1]
    

In [None]:
# FIRST sets of the protocol grammar, keyed by non-terminal name.
# The recurring token groups are spelled out once so that every entry stays
# consistent; the verb is written 'COMINGOUT', matching verb_list.
_FIRST_AGENT_TOKENS = ['Agent', 'ANY']
_FIRST_OPERATOR_TOKENS = [
    'ESTIMATE', 'COMINGOUT', 'DIVINATION', 'GUARD', 'VOTE', 'ATTACK',
    'GUARDED', 'VOTED', 'ATTACKED', 'DIVINED', 'IDENTIFIED',
    'AGREE', 'DISAGREE', 'REQUEST', 'INQUIRE',
    'NOT', 'BECAUSE', 'XOR', 'AND', 'OR', 'DAY',
]
_FIRST_SENTENCE_TOKENS = ['Skip', 'Over'] + _FIRST_AGENT_TOKENS + _FIRST_OPERATOR_TOKENS

# Every entry gets its own list object so that editing one FIRST set never
# changes another.
First = {}
First['sentence'] = list(_FIRST_SENTENCE_TOKENS)
First['VTR_VT_VTS_AGG_OTS_OS1_OS2_OSS_DAY'] = list(_FIRST_OPERATOR_TOKENS)
First['TR'] = list(_FIRST_AGENT_TOKENS)
First['T'] = list(_FIRST_AGENT_TOKENS)
First['TSp'] = list(_FIRST_AGENT_TOKENS)
First['TSe'] = list(_FIRST_AGENT_TOKENS)
First['S2'] = list(_FIRST_SENTENCE_TOKENS)
First['SS'] = list(_FIRST_SENTENCE_TOKENS)
First['recsentence'] = list(_FIRST_SENTENCE_TOKENS)
First['rec2sentence'] = _FIRST_SENTENCE_TOKENS + ['r']
First['species'] = ['HUMAN', 'WEREWOLF', 'ANY']
First['role'] = ['VILLAGER', 'SEER', 'MEDIUM', 'BODYGUARD', 'WEREWOLF', 'POSSESSED']
First['talk_number'] = ['day']
First['agent_number'] = [str(number) for number in range(1, 16)]
First['day_number'] = [str(digit) for digit in range(10)]
First['ID_number'] = [str(digit) for digit in range(10)]


In [None]:
# FOLLOW sets of the protocol grammar, keyed by non-terminal name.
Follow = {}

# These non-terminals can only be followed by the 'r' token.
for _follow_key in [
    'sentence', 'VTR_VT_VTS_AGG_OTS_OS1_OS2_OSS_DAY', 'TR', 'T', 'TSp', 'TSe',
    'S2', 'SS', 'recsentence', 'rec2sentence', 'species', 'role', 'talk_number',
]:
    Follow[_follow_key] = ['r']

# The verb is written 'COMINGOUT', matching verb_list.
Follow['agent_number'] = [
    'ESTIMATE', 'COMINGOUT', 'DIVINATION', 'GUARD', 'VOTE', 'ATTACK',
    'GUARDED', 'VOTED', 'ATTACKED', 'DIVINED', 'IDENTIFIED',
    'AGREE', 'DISAGREE', 'REQUEST', 'INQUIRE',
    'NOT', 'BECAUSE', 'XOR', 'AND', 'OR', 'DAY',
    'VILLAGER', 'SEER', 'MEDIUM', 'BODYGUARD', 'WEREWOLF', 'POSSESSED',
    'r', 'HUMAN', 'ANY', 'l',
]
Follow['day_number'] = ['l', 'IDcol']
Follow['ID_number'] = ['r']



In [None]:
class LL1Grammar:
    """Context-free grammar with its FIRST/FOLLOW sets and LL(1) parse table.

    In this version a production is a sequence of single-symbol items; with
    string productions such as ``'aA'`` every character is one symbol.
    ``'ε'`` denotes the empty string and ``'$'`` the end of the input.

    Attributes set by the constructor:
        non_terminals, terminals, start_symbol, production_rules: the
            arguments, stored as given (``production_rules`` maps each
            non-terminal to a list of productions).
        first_sets: non-terminal -> set of terminals (plus ``'ε'`` when the
            non-terminal is nullable).
        follow_sets: non-terminal -> set of terminals and/or ``'$'``.
        parse_table: non-terminal -> {lookahead: production or ``''``}.
    """

    EPSILON = 'ε'
    END_MARKER = '$'

    def __init__(self, non_terminals, terminals, start_symbol, production_rules):
        self.non_terminals = non_terminals
        self.terminals = terminals
        self.start_symbol = start_symbol
        self.production_rules = production_rules
        self._rebuild_tables()

    def __repr__(self):
        return f"LL1Grammar(non_terminals={self.non_terminals}, terminals={self.terminals}, start_symbol={self.start_symbol}, production_rules={self.production_rules})"

    def _rebuild_tables(self):
        """Recompute FIRST, FOLLOW and the parse table from the current rules."""
        self.first_sets = self.compute_first_sets()
        self.follow_sets = self.compute_follow_sets()
        self.parse_table = self.construct_parse_table()

    @staticmethod
    def _symbols(production):
        """Split a production into its symbols (one item per element/character)."""
        return list(production)

    def _first_of_sequence(self, symbols, first_sets):
        """Return FIRST of a symbol sequence; contains 'ε' iff it is nullable.

        ``first_sets`` is passed in explicitly because this helper is also
        used while the FIRST sets themselves are still being computed.
        """
        result = set()
        for symbol in symbols:
            if symbol == self.EPSILON:
                # ε contributes nothing and never blocks what comes after it.
                continue
            if symbol in self.terminals:
                result.add(symbol)
                return result
            result |= first_sets[symbol] - {self.EPSILON}
            if self.EPSILON not in first_sets[symbol]:
                return result
        # Every symbol was nullable (or the sequence is empty).
        result.add(self.EPSILON)
        return result

    def add_production_rule(self, non_terminal, production):
        """Append a production and refresh the derived tables.

        Raises:
            ValueError: If ``non_terminal`` is not part of the grammar.
        """
        if non_terminal not in self.non_terminals:
            raise ValueError(f"Non-terminal '{non_terminal}' not found in grammar.")
        self.production_rules[non_terminal].append(production)
        # Without this the FIRST/FOLLOW sets and parse table would go stale.
        self._rebuild_tables()

    def remove_production_rule(self, non_terminal, production):
        """Remove a production and refresh the derived tables.

        Raises:
            ValueError: If the non-terminal or the production is unknown.
        """
        if non_terminal not in self.non_terminals:
            raise ValueError(f"Non-terminal '{non_terminal}' not found in grammar.")
        if production not in self.production_rules[non_terminal]:
            raise ValueError(f"Production '{production}' not found for non-terminal '{non_terminal}'.")
        self.production_rules[non_terminal].remove(production)
        self._rebuild_tables()

    def compute_first_sets(self):
        """Compute FIRST for every non-terminal by fixed-point iteration."""
        first_sets = {nt: set() for nt in self.non_terminals}
        change = True

        while change:
            change = False
            for nt in self.non_terminals:
                for prod in self.production_rules[nt]:
                    found = self._first_of_sequence(self._symbols(prod), first_sets)
                    if not found <= first_sets[nt]:
                        first_sets[nt] |= found
                        change = True
        return first_sets

    def compute_follow_sets(self):
        """Compute FOLLOW for every non-terminal (requires ``self.first_sets``)."""
        follow_sets = {nt: set() for nt in self.non_terminals}
        follow_sets[self.start_symbol].add(self.END_MARKER)

        change = True
        while change:
            change = False
            for nt in self.non_terminals:
                for prod in self.production_rules[nt]:
                    symbols = self._symbols(prod)
                    # The last symbol is included: it inherits FOLLOW(nt).
                    for i, symbol in enumerate(symbols):
                        if symbol not in self.non_terminals:
                            continue
                        rest_first = self._first_of_sequence(symbols[i + 1:], self.first_sets)
                        found = rest_first - {self.EPSILON}
                        if self.EPSILON in rest_first:
                            # Everything after `symbol` can vanish.
                            found |= follow_sets[nt]
                        if not found <= follow_sets[symbol]:
                            follow_sets[symbol] |= found
                            change = True
        return follow_sets

    def construct_parse_table(self):
        """Build the LL(1) table: ``table[non_terminal][lookahead] -> production``.

        Cells without an applicable production hold ``''``. The columns are
        the grammar's terminals plus ``'$'``. As in the original table, a
        nullable production is also recorded under the ``'ε'`` column, and an
        (empty) ``'$'`` row is kept. When two productions compete for a cell
        the later one wins; conflicts are not reported.
        """
        columns = set(self.terminals) | {self.END_MARKER}
        parse_table = {nt: {t: '' for t in columns} for nt in self.non_terminals}
        parse_table[self.END_MARKER] = {t: '' for t in columns}

        for nt, prods in self.production_rules.items():
            for prod in prods:
                lookaheads = self._first_of_sequence(self._symbols(prod), self.first_sets)
                if self.EPSILON in lookaheads:
                    # A nullable production applies on everything in FOLLOW(nt),
                    # including the end marker.
                    lookaheads = lookaheads | self.follow_sets[nt]
                for terminal in lookaheads:
                    parse_table[nt][terminal] = prod

        return parse_table

In [None]:

# The following builds an LL(1) grammar instance from the grammar used in the
# previous example. Each production is a string whose characters are the
# individual symbols.
non_terminals = {'S', 'A', 'B'}
terminals = {'a', 'b', 'c', 'd', 'ε'}
start_symbol = 'S'
production_rules = {
    'S': ['AB'],
    'A': ['aA', 'b', 'ε'],
    'B': ['cB', 'd', 'ε']
}

grammar = LL1Grammar(non_terminals, terminals, start_symbol, production_rules)
# Print the parse table one row at a time.
for key, value in grammar.parse_table.items():
    print(key, value)
#print(grammar.parse_table)

In [None]:
from typing import Dict, List, Set

class LL1Grammar:
    """Context-free grammar with its FIRST/FOLLOW sets and LL(1) parse table.

    In this version a production is a string of whitespace-separated symbols
    (for example ``"T E'"``), so symbols may be longer than one character.
    ``'ε'`` denotes the empty string and ``'$'`` the end of the input.

    Attributes set by the constructor:
        non_terminals, terminals, start_symbol, production_rules: the
            arguments, stored as given.
        first_sets: non-terminal -> set of terminals (plus ``'ε'`` when the
            non-terminal is nullable).
        follow_sets: non-terminal -> set of terminals and/or ``'$'``.
        parse_table: non-terminal -> {lookahead: production or ``''``}.
    """

    EPSILON = 'ε'
    END_MARKER = '$'

    def __init__(self, non_terminals: Set[str], terminals: Set[str], start_symbol: str, production_rules: Dict[str, List[str]]):
        self.non_terminals = non_terminals
        self.terminals = terminals
        self.start_symbol = start_symbol
        self.production_rules = production_rules
        self._rebuild_tables()

    def __repr__(self):
        return f"LL1Grammar(non_terminals={self.non_terminals}, terminals={self.terminals}, start_symbol={self.start_symbol}, production_rules={self.production_rules})"

    def _rebuild_tables(self) -> None:
        """Recompute FIRST, FOLLOW and the parse table from the current rules."""
        self.first_sets = self.compute_first_sets()
        self.follow_sets = self.compute_follow_sets()
        self.parse_table = self.construct_parse_table()

    @staticmethod
    def _symbols(production: str) -> List[str]:
        """Split a production into its whitespace-separated symbols."""
        return production.split()

    def _first_of_sequence(self, symbols: List[str], first_sets: Dict[str, Set[str]]) -> Set[str]:
        """Return FIRST of a symbol sequence; contains 'ε' iff it is nullable.

        ``first_sets`` is passed in explicitly because this helper is also
        used while the FIRST sets themselves are still being computed.
        """
        result: Set[str] = set()
        for symbol in symbols:
            if symbol == self.EPSILON:
                # ε contributes nothing and never blocks what comes after it.
                continue
            if symbol in self.terminals:
                result.add(symbol)
                return result
            result |= first_sets[symbol] - {self.EPSILON}
            if self.EPSILON not in first_sets[symbol]:
                return result
        # Every symbol was nullable (or the sequence is empty).
        result.add(self.EPSILON)
        return result

    def add_production_rule(self, non_terminal, production):
        """Append a production and refresh the derived tables.

        Raises:
            ValueError: If ``non_terminal`` is not part of the grammar.
        """
        if non_terminal not in self.non_terminals:
            raise ValueError(f"Non-terminal '{non_terminal}' not found in grammar.")
        self.production_rules[non_terminal].append(production)
        # Without this the FIRST/FOLLOW sets and parse table would go stale.
        self._rebuild_tables()

    def remove_production_rule(self, non_terminal, production):
        """Remove a production and refresh the derived tables.

        Raises:
            ValueError: If the non-terminal or the production is unknown.
        """
        if non_terminal not in self.non_terminals:
            raise ValueError(f"Non-terminal '{non_terminal}' not found in grammar.")
        if production not in self.production_rules[non_terminal]:
            raise ValueError(f"Production '{production}' not found for non-terminal '{non_terminal}'.")
        self.production_rules[non_terminal].remove(production)
        self._rebuild_tables()

    def compute_first_sets(self) -> Dict[str, Set[str]]:
        """Compute FIRST for every non-terminal by fixed-point iteration."""
        first_sets: Dict[str, Set[str]] = {nt: set() for nt in self.non_terminals}
        change = True

        while change:
            change = False
            for nt in self.non_terminals:
                for prod in self.production_rules[nt]:
                    found = self._first_of_sequence(self._symbols(prod), first_sets)
                    if not found <= first_sets[nt]:
                        first_sets[nt] |= found
                        change = True
        return first_sets

    def compute_follow_sets(self) -> Dict[str, Set[str]]:
        """Compute FOLLOW for every non-terminal (requires ``self.first_sets``)."""
        follow_sets: Dict[str, Set[str]] = {nt: set() for nt in self.non_terminals}
        follow_sets[self.start_symbol].add(self.END_MARKER)

        change = True
        while change:
            change = False
            for nt in self.non_terminals:
                for prod in self.production_rules[nt]:
                    prod_symbols = self._symbols(prod)
                    # The last symbol is included: it inherits FOLLOW(nt).
                    for i, symbol in enumerate(prod_symbols):
                        if symbol not in self.non_terminals:
                            continue
                        rest_first = self._first_of_sequence(prod_symbols[i + 1:], self.first_sets)
                        found = rest_first - {self.EPSILON}
                        if self.EPSILON in rest_first:
                            # Everything after `symbol` can vanish.
                            found |= follow_sets[nt]
                        if not found <= follow_sets[symbol]:
                            follow_sets[symbol] |= found
                            change = True
        return follow_sets

    def construct_parse_table(self):
        """Build the LL(1) table: ``table[non_terminal][lookahead] -> production``.

        Cells without an applicable production hold ``''``. The columns are
        the grammar's terminals plus ``'$'``; an (empty) ``'$'`` row is kept
        as in the original table. ``'ε'`` is never used as a lookahead, so a
        nullable production is entered under FOLLOW(non_terminal) instead.
        When two productions compete for a cell the later one wins; conflicts
        are not reported.
        """
        columns = set(self.terminals) | {self.END_MARKER}
        parse_table = {nt: {t: '' for t in columns} for nt in self.non_terminals}
        parse_table[self.END_MARKER] = {t: '' for t in columns}

        for nt, prods in self.production_rules.items():
            for prod in prods:
                first_symbols = self._first_of_sequence(self._symbols(prod), self.first_sets)
                # Decide nullability BEFORE dropping ε from the lookaheads.
                nullable = self.EPSILON in first_symbols
                lookaheads = first_symbols - {self.EPSILON}
                if nullable:
                    lookaheads |= self.follow_sets[nt]
                for terminal in lookaheads:
                    parse_table[nt][terminal] = prod

        return parse_table



In [None]:
# Same example grammar as before, but with the symbols of each production
# separated by spaces, as expected by the prod.split() based LL1Grammar.
non_terminals = {'S', 'A', 'B'}
terminals = {'a', 'b', 'c', 'd', 'ε'}
start_symbol = 'S'
production_rules = {
    'S': ['A B'],
    'A': ['a A', 'b', 'ε'],
    'B': ['c B', 'd', 'ε']
}

grammar = LL1Grammar(non_terminals, terminals, start_symbol, production_rules)
# Print the parse table one row at a time.
for key, value in grammar.parse_table.items():
    print(key, value)

In [None]:
import re

def convert_day(s):
    """Rewrite every "day<digits>" in ``s`` as "day d d ...".

    The digits that directly follow "day" are separated from it, and from
    each other, by single spaces (``"day10"`` -> ``"day 1 0"``). Text that
    does not match is returned unchanged.
    """
    def replace_func(match):
        # Group 1 is the literal "day", group 2 the run of digits after it.
        day_str, digits = match.group(1, 2)
        spaced_digits = ' '.join(digits)
        return f"{day_str} {spaced_digits}"

    return re.sub(r"(day)(\d+)", replace_func, s)

# Usage example
input_str = "Today is day10 and tomorrow is day32."
output_str = convert_day(input_str)
print(output_str)


In [None]:
import re

def revert_day(s):
    """Rewrite every "day d d ..." in ``s`` back to "day<digits>".

    This undoes the spacing of single digits after "day"
    (``"day 1 0"`` -> ``"day10"``). Text that does not match is returned
    unchanged.
    """
    # Group 2 captures the WHOLE run of whitespace-separated digits. A
    # repeated group such as (?:\s+(\d))+ would only keep the last digit.
    pattern = r"(day)((?:\s+\d)+)"

    def replace_func(match):
        day_str = match.group(1)
        # Drop the whitespace between the digits.
        concatenated_digits = re.sub(r"\s+", "", match.group(2))
        return f"{day_str}{concatenated_digits}"

    return re.sub(pattern, replace_func, s)

# Usage example
input_str = "Today is day 1 0 and tomorrow is day 3 2."
output_str = revert_day(input_str)
print(output_str)


In [None]:
import re

def add_brackets(s):
    """Rewrite every "Agent <digits>" in ``s`` as "Agent[<digits>]".

    Exactly one space is expected between "Agent" and the number
    (``"Agent 07"`` -> ``"Agent[07]"``). Text that does not match is returned
    unchanged.
    """
    pattern = r"(Agent) (\d+)"

    def replace_func(match):
        agent_str = match.group(1)
        digits = match.group(2)
        return f"{agent_str}[{digits}]"

    return re.sub(pattern, replace_func, s)

# Usage example
input_str = "This is Agent 07 and that is Agent 11."
output_str = add_brackets(input_str)
print(output_str)


# Add Tokens test

In [4]:
# Importing stock libraries
import numpy as np
import pandas as pd
import torch
import torch.nn as nn 
import torch.nn.functional as F

# Importing the T5 modules from huggingface/transformers
from transformers import T5ForConditionalGeneration, T5Tokenizer, LogitsProcessorList, LogitsProcessor

# 日本語プロトコル変換用
from aiwolfk2b.agentLPS.jp_to_protocol_converter import JPToProtocolConverter
from aiwolfk2b.utils.ll1_grammar import LL1Grammar, aiwolf_protocol_grammar, convert_ll1_to_protocol,convert_protocol_to_ll1
from typing import Any, Callable, Dict, List,Optional, Tuple, Union



# Setting up the device for GPU usage
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'

In [5]:

# Logits processor that only allows tokens which follow the grammar.
class ConstrainedLogitsProcessor(LogitsProcessor):
    """Mask the scores of all tokens the grammar does not allow next.

    At each generation step the sequences produced so far are decoded back to
    text, the set of admissible next token ids is derived from the grammar,
    and every other vocabulary entry is set to ``-inf``.

    The token ids below are hard-coded for the tokenizer used in this
    notebook.
    """

    DAY_TOKEN_ID = 2726          # "day"
    SPACE_TOKEN_ID = 262         # " "
    CLOSE_PAREN_TOKEN_ID = 268   # ")" (emitted after the space piece)
    OPEN_PAREN_TOKEN_ID = 290    # "("
    # Ids of the agent numbers 01-15.
    AGENT_NUMBER_TOKEN_IDS = frozenset({2072, 1905, 2078, 2458, 2357, 2195, 2227, 2246, 2224,
                                        333, 359, 350, 491, 506, 423})
    # Ids of the digits 0-9.
    DIGIT_TOKEN_IDS = frozenset({942, 291, 293, 294, 306, 320, 331, 334, 337, 341})

    def __init__(self, grammar: LL1Grammar, tokenizer: T5Tokenizer, verbose: bool = True):
        """
        Args:
            grammar: Grammar providing ``first_sets``, ``start_symbol`` and
                ``get_next_terminals``.
            tokenizer: Tokenizer used to decode the sequences and to look up
                token ids.
            verbose: Print the decoded sequences and the admissible ids at
                every step. Defaults to True to keep the original diagnostic
                output.
        """
        super().__init__()
        self.grammar: LL1Grammar = grammar
        self.tokenizer: T5Tokenizer = tokenizer
        self.verbose = verbose

    def _allowed_token_ids(self, protocol, sequence_ids):
        """Return the set of token ids that may follow one decoded sequence."""
        eos_id = self.tokenizer.eos_token_id
        allowed = set()

        if protocol == "":
            # Nothing generated yet: anything that can start a sentence.
            for terminal in self.grammar.first_sets[self.grammar.start_symbol]:
                allowed.add(self.tokenizer.convert_tokens_to_ids(terminal))
        else:
            # Convert to a plain int: a tensor element is hashed by identity
            # and would never be found in a set of ints.
            last_id = int(sequence_ids[-1])
            if (last_id == self.DAY_TOKEN_ID
                    or last_id in self.AGENT_NUMBER_TOKEN_IDS
                    or last_id in self.DIGIT_TOKEN_IDS):
                # "day", an agent number or a digit must be followed by a space.
                allowed.add(self.SPACE_TOKEN_ID)
            else:
                # Otherwise the next grammar terminals follow directly.
                for terminal in self.grammar.get_next_terminals(protocol):
                    if terminal == "(":
                        allowed.add(self.OPEN_PAREN_TOKEN_ID)
                    elif terminal == ")":
                        # ")" is written as "<space> )": emit the space first
                        # unless it is already there.
                        if last_id == self.SPACE_TOKEN_ID:
                            allowed.add(self.CLOSE_PAREN_TOKEN_ID)
                        else:
                            allowed.add(self.SPACE_TOKEN_ID)
                    elif terminal == "ε":
                        # The sentence may end here.
                        allowed.add(eos_id)
                    else:
                        allowed.add(self.tokenizer.convert_tokens_to_ids(terminal))

        # No candidate at all: only allow the sequence to end.
        if not allowed:
            allowed.add(eos_id)
        return allowed

    def __call__(self, input_ids, scores):
        """Set the score of every disallowed token to ``-inf`` (in place).

        Args:
            input_ids: Token ids generated so far, one row per sequence.
            scores: Next-token scores, indexed as ``scores[batch, token_id]``.

        Returns:
            ``scores`` with disallowed entries replaced by ``-inf``.
        """
        # Decode back to text first.
        protocol_batch = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        valid_tokens_ids = []
        for idx, protocol in enumerate(protocol_batch):
            if self.verbose:
                print("protocol:", protocol)
                print("protocol_ids:", input_ids[idx])
            valid_tokens_ids.append(self._allowed_token_ids(protocol, input_ids[idx]))

        if self.verbose:
            print("valid_token_ids:{}".format(valid_tokens_ids))

        vocab_size = scores.shape[1]
        for batch_idx in range(scores.shape[0]):
            # Ids outside the score vector (or missing ids) cannot be selected.
            keep = [token_id for token_id in valid_tokens_ids[batch_idx]
                    if token_id is not None and 0 <= token_id < vocab_size]
            blocked = torch.ones(vocab_size, dtype=torch.bool, device=scores.device)
            if keep:
                blocked[torch.tensor(keep, dtype=torch.long, device=scores.device)] = False
            scores[batch_idx].masked_fill_(blocked, float('-inf'))

        return scores

class T5JPToProtocolConverter(JPToProtocolConverter):
    """Convert Japanese sentences to protocol text with a T5 model.

    Generation is constrained by ``ConstrainedLogitsProcessor`` built from the
    grammar returned by ``aiwolf_protocol_grammar()``.
    """

    def __init__(self, model_name: str = "default", model_path: str = "default"):
        """
        Args:
            model_name: Name passed to ``from_pretrained``; ``"default"``
                selects the built-in model name.
            model_path: Path of a saved state dict to load; ``"default"``
                selects the built-in path and ``""`` skips loading.
        """
        # Resolve the "default" sentinels.
        resolved_name = "sonoisa/t5-base-english-japanese" if model_name == "default" else model_name
        resolved_path = (
            "/home/takuya/HDD1/work/AI_Wolf/2023S_AIWolfK2B/aiwolfk2b/agentLPS/jp2protocol_model/t5_upper_20230514_4.pth"
            if model_path == "default"
            else model_path
        )
        self.model_name = resolved_name
        self.model_path = resolved_path

        # Load the pretrained model and its tokenizer.
        self.model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(resolved_name)
        self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(resolved_name)

        # Grammar and the logits processor that enforces it.
        self.protocol_grammar: LL1Grammar = aiwolf_protocol_grammar()
        constrained = ConstrainedLogitsProcessor(self.protocol_grammar, self.tokenizer)
        self.logits_processor = LogitsProcessorList([constrained])

        # NOTE: adding the (lower-cased) grammar terminals to the tokenizer
        # and resizing the model's token embeddings is disabled here.

        # Load the fine-tuned weights when a path is given.
        if resolved_path != "":
            state_dict = torch.load(resolved_path, map_location=torch.device(device))
            self.model.load_state_dict(state_dict)

    def convert(self, text_list: List[str]) -> List[str]:
        """Convert a batch of sentences and return the decoded outputs.

        Args:
            text_list: Input sentences.

        Returns:
            One decoded string per generated sequence, special tokens removed.
        """
        encoded = self.tokenizer.batch_encode_plus(
            text_list,
            max_length=128,
            padding='max_length',
            return_tensors='pt',
            truncation=True,
        )

        # force_words_ids is intentionally not passed.
        generation_options = dict(
            inputs=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            num_beams=5,
            do_sample=True,
            num_return_sequences=1,
            no_repeat_ngram_size=0,
            remove_invalid_values=True,
            logits_processor=self.logits_processor,
            max_length=16,
            early_stopping=True,
        )
        generated = self.model.generate(**generation_options)

        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)




In [6]:
def unit_test_T5JPToProtocolConverter(converter):
    """Smoke-test ``converter.convert`` and print what it returns.

    The converter is called twice: once with a single sentence wrapped in a
    list, and once with the batch of sample sentences below.
    """
    single_sentence = "Agent[08]が襲われたAgent[05]を霊媒すると人間だった"

    # Sentences converted as one batch; the commented-out entries are further
    # samples that are currently disabled.
    batch_sentences = [
        "Agent[03]はAgent[08]が狼だと推測する",
        # "Agent[06]はAgent[06]が占い師だとカミングアウトする",
        # "Agent[12]が占った結果Agent[10]は人狼だった",
        # "Agent[12]が占った結果Agent[10]は人間だった",
        # "Agent[08]が襲われたAgent[05]を霊媒すると人間だった",
        # "Agent[05]はAgent[10]を護衛した",
        #  "Agent[10]はAgent[12]に投票する",
        # "Agent[06]はAgent[08]が狼だと思う",
        # "私が占い師です",
        # "Agent[12]が占った結果、Agent[10]は人狼でした",
        # "Agent[12]が占った結果、Agent[10]は人間でした",
        # "Agent[12]がAgent[05]を霊媒すると人間でした",
        # "Agent[12]はAgent[10]を守った",
        # "Agent[10]はAgent[12]に投票します",
        # "Agent[08]が狼だと思う",
        # "私が占い師です",
        # "占った結果、Agent[10]は人狼でした",
        # "占った結果、Agent[10]は人間でした",
        # "Agent[05]を霊媒すると人間でした",
        # "私はAgent[10]を守った",
        # "私はAgent[12]に投票します",
    ]

    print("one text:", converter.convert([single_sentence]))
    print("text_list:", converter.convert(batch_sentences))


In [4]:
# Build the converter with its default model and checkpoint.
# NOTE: the recorded run of this cell failed with a state_dict size mismatch
# (checkpoint embeddings have 32133 rows, the freshly loaded model 32123).
converter = T5JPToProtocolConverter()

RuntimeError: Error(s) in loading state_dict for T5ForConditionalGeneration:
	size mismatch for shared.weight: copying a param with shape torch.Size([32133, 768]) from checkpoint, the shape in current model is torch.Size([32123, 768]).
	size mismatch for encoder.embed_tokens.weight: copying a param with shape torch.Size([32133, 768]) from checkpoint, the shape in current model is torch.Size([32123, 768]).
	size mismatch for decoder.embed_tokens.weight: copying a param with shape torch.Size([32133, 768]) from checkpoint, the shape in current model is torch.Size([32123, 768]).
	size mismatch for lm_head.weight: copying a param with shape torch.Size([32133, 768]) from checkpoint, the shape in current model is torch.Size([32123, 768]).

In [None]:
# Run the smoke test with the converter created above.
unit_test_T5JPToProtocolConverter(converter)

protocol: 
protocol_ids: tensor([0])
protocol: 
protocol_ids: tensor([0])
protocol: 
protocol_ids: tensor([0])
protocol: 
protocol_ids: tensor([0])
protocol: 
protocol_ids: tensor([0])
valid_token_ids:[{32128, 32129, 32130, 32131, 32132, 32100, 32101, 32103, 32105, 32107, 32108, 32109, 32110, 32111, 32112, 32113, 32114, 32116, 32117, 32118, 32119, 32120, 32122, 32123, 32125}, {32128, 32129, 32130, 32131, 32132, 32100, 32101, 32103, 32105, 32107, 32108, 32109, 32110, 32111, 32112, 32113, 32114, 32116, 32117, 32118, 32119, 32120, 32122, 32123, 32125}, {32128, 32129, 32130, 32131, 32132, 32100, 32101, 32103, 32105, 32107, 32108, 32109, 32110, 32111, 32112, 32113, 32114, 32116, 32117, 32118, 32119, 32120, 32122, 32123, 32125}, {32128, 32129, 32130, 32131, 32132, 32100, 32101, 32103, 32105, 32107, 32108, 32109, 32110, 32111, 32112, 32113, 32114, 32116, 32117, 32118, 32119, 32120, 32122, 32123, 32125}, {32128, 32129, 32130, 32131, 32132, 32100, 32101, 32103, 32105, 32107, 32108, 32109, 32110

In [7]:
# An empty model_path skips loading a saved state dict, so only the
# pretrained weights are used.
converter = T5JPToProtocolConverter(model_path="")

In [None]:
# Show how the tokenizer encodes a few sample protocol strings.
test_protocols = [
    "INQUIRE Agent 02 ( Agent 03 DISAGREE day 2 ID 2 12 )",
    "Agent 01 ESTIMATE ANY POSSESSED",
    "REQUEST",
    "INQUIRE Agent"]
for protocol in test_protocols:
    token_ids = converter.tokenizer.encode(protocol)
    print("protocol:{} \n token_ids:{}\n".format(protocol,token_ids))

In [11]:
# Find the grammar terminals (lower-cased) that are not represented by a
# single token id.
terminals_lower = [terminal.lower() for terminal in converter.protocol_grammar.terminals]
terminals_lower = set(terminals_lower)
for terminal in terminals_lower:
    tokens = converter.tokenizer.encode(terminal)
    # Every recorded encoding ends with id 1 (presumably the EOS id), so more
    # than two ids means the terminal itself needs more than one piece.
    if len(tokens) > 2:
        print("terminal:{} \t tokens:{}".format(terminal,tokens))

terminal:divination 	 tokens:[1048, 3971, 799, 1]
terminal:comingout 	 tokens:[8602, 2845, 1]
terminal:werewolf 	 tokens:[400, 29949, 1]
terminal:possessed 	 tokens:[23501, 312, 1]
terminal:estimate 	 tokens:[262, 29475, 1]
terminal:08 	 tokens:[262, 2078, 1]
terminal:seer 	 tokens:[1458, 358, 1]
terminal:01 	 tokens:[262, 1905, 1]
terminal:villager 	 tokens:[1725, 370, 1]
terminal:skip 	 tokens:[7528, 408, 1]
terminal:02 	 tokens:[262, 2246, 1]
terminal:03 	 tokens:[262, 2227, 1]
terminal:09 	 tokens:[262, 2195, 1]
terminal:xor 	 tokens:[1047, 614, 1]
terminal:bodyguard 	 tokens:[3724, 18339, 1]
terminal:05 	 tokens:[262, 2072, 1]
terminal:) 	 tokens:[262, 268, 1]
terminal:guarded 	 tokens:[7218, 312, 1]
terminal:07 	 tokens:[262, 2357, 1]
terminal:divined 	 tokens:[28316, 321, 1]
terminal:06 	 tokens:[262, 2224, 1]
terminal:inquire 	 tokens:[279, 3144, 530, 1]
terminal:04 	 tokens:[262, 2458, 1]


In [7]:
# Tokens that would have to be newly added: lower-cased terminals that are
# not already keys of the tokenizer vocabulary.
new_tokens = terminals_lower - set(converter.tokenizer.get_vocab().keys())
print("new_tokens:",new_tokens)

new_tokens: set()


In [None]:
# Collect the token ids of the strings "0" through "15".
token_ids = []
for i in range(0,16):
    token_id = converter.tokenizer.convert_tokens_to_ids(f"{i}")
    token_ids.append(token_id)
token_ids

[942,
 291,
 293,
 294,
 306,
 320,
 331,
 334,
 337,
 341,
 333,
 359,
 350,
 491,
 506,
 423]

In [None]:
# For every terminal encoded with two or more token ids (not counting the
# trailing id), collect its second token id and print the list.
second_ids = []
for terminal in converter.protocol_grammar.terminals:
    tokens = converter.tokenizer.encode(terminal)
    if len(tokens) > 2:
        second_ids.append(tokens[1])
        
print(second_ids)

In [None]:
# Look up the vocabulary id of the token "10".
converter.tokenizer.convert_tokens_to_ids("10")

333

In [None]:
# Look up the vocabulary id of the token ")".
converter.tokenizer.convert_tokens_to_ids(")")

In [None]:
# Decode a hand-written id sequence to see the resulting text. 2726 and 262
# are the ids noted as "day" and " " in ConstrainedLogitsProcessor; 320 is the
# id listed for "5" in the number lookup above.
converter.tokenizer.decode([2726,262,320,262],skip_special_tokens=True)

In [None]:
# Remove the converter from the notebook namespace.
# NOTE(review): cells further down still reference `converter`, so they need
# it to be re-created first when the notebook is run top to bottom.
del converter

In [None]:
# Show the token ids of the Japanese word passed below.
# NOTE(review): `converter` is deleted by the preceding cell; this only runs
# if a converter has been created again.
converter.tokenizer.encode("護衛")

[262, 9521, 1]

In [13]:
# Encode several short strings without special tokens to see how many ids
# each of them needs.
converter.tokenizer.batch_encode_plus(["agent","possess","possessed","(",")","03"], add_special_tokens=False)

{'input_ids': [[12631], [23501], [23501, 312], [290], [262, 268], [262, 2227]], 'attention_mask': [[1], [1], [1, 1], [1], [1, 1], [1, 1]]}

# 既存のTokenをできるだけ流用したConverterのテスト
終端記号を小文字にして、既存のTokenをできるだけ流用してみる


In [1]:
# Importing stock libraries
import numpy as np
import pandas as pd
import torch
import torch.nn as nn 
import torch.nn.functional as F

# Importing the T5 modules from huggingface/transformers
from transformers import T5ForConditionalGeneration, T5Tokenizer, LogitsProcessorList, LogitsProcessor

# 日本語プロトコル変換用
from aiwolfk2b.agentLPS.jp_to_protocol_converter import JPToProtocolConverter
from aiwolfk2b.utils.ll1_grammar import LL1Grammar, aiwolf_protocol_grammar, convert_ll1_to_protocol,convert_protocol_to_ll1
from typing import Any, Callable, Dict, List, Optional, Tuple, Union,Set
from collections import defaultdict


# Setting up the device for GPU usage
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
#device = 'cpu'


#現在のプログラムが置かれているディレクトリを取得
import os
# current_dir = os.path.dirname(os.path.abspath(__file__))

In [2]:
class Tree:
    """Trie keyed by tokenizer ids.

    Every path from the root spells the token-id sequence of one terminal; the
    node at the end of the path stores that terminal string.  The root is a
    sentinel node created with ``token_id=-1``.
    """

    def __init__(self):
        self.root:TreeNode = TreeNode(token_id=-1,parent=None)

    def add_branch(self, token_ids:List[int],terminal:Optional[str] = None):
        """Insert the path ``token_ids`` and, if given, label its last node.

        Args:
            token_ids: token ids from the root downwards.
            terminal: terminal string to store on the final node; ``None``
                leaves the node's current label untouched.
        """
        cur_node:TreeNode = self.root
        for token_id in token_ids:
            # .get() instead of [] so that probing for a child never stores a
            # placeholder entry when the child map is a defaultdict.
            child = cur_node.children.get(token_id)
            if child is None:
                child = TreeNode(token_id=token_id,parent=cur_node)
                cur_node.add_child(child)
            cur_node = child
        # Label the last node of the path with the terminal, if there is one
        if terminal is not None:
            cur_node.terminal = terminal

    def print(self, node=None, _prefix="", _last=True):
        """Write the subtree below ``node`` (default: the whole tree) to stdout.

        Produces the same text as ``_print_str``; ``_prefix`` and ``_last``
        control the indentation and branch marker of the first line.
        """
        from io import StringIO
        buf = StringIO()
        self._print_str(self.root if node is None else node, buf, _prefix, _last)
        print(buf.getvalue(), end="")

    def __str__(self):
        from io import StringIO
        buf = StringIO()
        self._print_str(self.root, buf)
        return buf.getvalue()

    def _print_str(self, node, buf, _prefix="", _last=True):
        """Recursively render ``node`` and its descendants into ``buf``, one node per line."""
        buf.write(_prefix)
        buf.write("`- " if _last else "|- ")
        buf.write(str(node))
        buf.write("\n")
        _prefix += "   " if _last else "|  "
        # A lookup of a missing key on a defaultdict child map leaves a None
        # entry behind; ignore such placeholders instead of recursing into them.
        children = [child for child in node.children.values() if child is not None]
        last_index = len(children) - 1
        for i, child in enumerate(children):
            self._print_str(child, buf, _prefix, i == last_index)
        
class _ChildMap(dict):
    """Child table that answers ``None`` for an unknown token id without storing anything.

    A ``defaultdict(lambda: None)`` gives the same answer but also inserts the
    key, which leaves ``None`` "children" behind after every failed lookup.
    """

    def __missing__(self, key):
        return None


class TreeNode:
    """One node of the token tree.

    Attributes are set from the constructor arguments: ``token_id`` (the id on
    the edge leading to this node), ``terminal`` (terminal string ending at this
    node, or ``None``), ``parent`` and ``is_leaf``.
    """

    def __init__(self,token_id, parent, terminal=None, is_leaf=True):
        self.token_id = token_id
        self.terminal = terminal
        self.parent = parent
        self.is_leaf = is_leaf
        # token id -> child node; ``children[x]`` is None when there is no such child
        self.__children = _ChildMap()

    @property
    def children(self):
        """Mapping from token id to child node (``None`` for unknown ids)."""
        return self.__children

    def add_child(self, child):
        """Register ``child`` under its own ``token_id``; this node stops being a leaf."""
        self.__children[child.token_id] = child
        self.is_leaf = False

    def get_subtree_terminals(self)->Set[str]:
        """Return the terminals stored on this node and on all of its descendants."""
        terminals:Set[str] = set()
        if self.terminal is not None:
            terminals.add(self.terminal)

        for child in self.children.values():
            if child is None:  # tolerate an explicitly stored placeholder
                continue
            terminals |= child.get_subtree_terminals()
        return terminals

    def is_instant(self)->Tuple[bool,List[Tuple[int,str]]]:
        """Check whether the subtree rooted here is an instantaneous (prefix-free) code.

        Returns:
            ``(ok, conflicts)``: ``ok`` is False as soon as a non-leaf node
            carries a terminal; ``conflicts`` lists ``(token_id, terminal)``
            for every such node.
        """
        if self.is_leaf:
            return (True,[])
        conflict_pairs = []
        is_ok = True
        # A terminal on a non-leaf node is a prefix of every terminal below it
        if self.terminal is not None:
            conflict_pairs.append((self.token_id,self.terminal))
            is_ok = False

        # Apply the same check to every child
        for child in self.children.values():
            if child is None:  # tolerate an explicitly stored placeholder
                continue
            child_ok,child_pairs = child.is_instant()
            is_ok = is_ok and child_ok
            conflict_pairs.extend(child_pairs)
        # Prefix-free only if this node and all children are
        return (is_ok,conflict_pairs)


    def __str__(self):
        return f"Node({self.token_id}, Terminal={self.terminal})"

In [3]:


# Logits processor that only lets through tokens allowed by the grammar
class ConstrainedLogitsProcessor(LogitsProcessor):
    """Mask generation scores so that only grammar-conforming tokens survive.

    On every call the sequences generated so far are decoded to text, the set
    of token ids allowed next is derived from the grammar plus a few
    hard-coded token-id rules, and every other score is set to ``-inf``.

    NOTE(review): the token ids below are literal vocabulary ids; they match
    the ids printed elsewhere in this notebook and are presumably only valid
    for that one tokenizer - confirm before using another model.
    """

    SPACE_TOKEN_ID = 262        # " "
    OPEN_PAREN_TOKEN_ID = 290   # "("
    CLOSE_PAREN_TOKEN_ID = 268  # ")" when it follows SPACE_TOKEN_ID
    DAY_TOKEN_ID = 2726         # "day"
    # ids treated as the agent numbers 01..15
    AGENT_NUMBER_TOKEN_IDS = frozenset(
        {2072, 1905, 2078, 2458, 2357, 2195, 2227, 2246, 2224, 333, 359, 350, 491, 506, 423}
    )
    # ids treated as the numbers 0..9
    DIGIT_TOKEN_IDS = frozenset({942, 291, 293, 294, 306, 320, 331, 334, 337, 341})

    def __init__(self, grammar: LL1Grammar, tokenizer: T5Tokenizer, verbose: bool = False):
        """
        Args:
            grammar: grammar providing ``terminals``, ``first_sets``,
                ``start_symbol`` and ``get_next_terminals``.
            tokenizer: tokenizer used to decode the partial output and to map
                terminals to token ids.
            verbose: when True, print the decoded prefix, its ids and the
                allowed id sets on every call (debugging aid).
        """
        super().__init__()
        self.grammar:LL1Grammar = grammar
        self.tokenizer:T5Tokenizer = tokenizer
        self.verbose = verbose
        # Built for inspection; __call__ does not consult it.
        self.token_tree = self.generate_token_tree()

    def generate_token_tree(self)->Tree:
        """Build a token tree holding the token-id sequence of every lower-cased terminal."""
        tree = Tree()
        terminals = [terminal.lower() for terminal in self.grammar.terminals]
        terminals_tokens_ids = self.tokenizer.batch_encode_plus(terminals,add_special_tokens=False)["input_ids"]
        for terminal,terminal_tokens_ids in zip(terminals,terminals_tokens_ids):
            tree.add_branch(terminal_tokens_ids,terminal)

        return tree

    def _allowed_token_ids(self, protocol, sequence_ids) -> Set[int]:
        """Return the token ids allowed after one partially generated sequence.

        Args:
            protocol: the sequence decoded to text with special tokens removed.
            sequence_ids: the ids of that sequence (only the last one is read).
        """
        eos_token_id = self.tokenizer.eos_token_id
        allowed = set()
        if protocol == "":
            # Nothing generated yet: candidates come from the start symbol's first set
            for terminal in self.grammar.first_sets[self.grammar.start_symbol]:
                allowed.add(self.tokenizer.convert_tokens_to_ids(terminal))
        else:
            last_id = int(sequence_ids[-1])
            if (
                last_id == self.DAY_TOKEN_ID
                or last_id in self.AGENT_NUMBER_TOKEN_IDS
                or last_id in self.DIGIT_TOKEN_IDS
            ):
                # "day" and number tokens may only be followed by the space token
                allowed.add(self.SPACE_TOKEN_ID)
            else:
                # Otherwise the terminals the grammar reports for this prefix follow directly
                for terminal in self.grammar.get_next_terminals(protocol):
                    if terminal == "(":
                        allowed.add(self.OPEN_PAREN_TOKEN_ID)
                    elif terminal == ")":
                        if last_id == self.SPACE_TOKEN_ID:
                            allowed.add(self.CLOSE_PAREN_TOKEN_ID)
                        else:
                            # ")" is produced as space + ")", so the space has to come first
                            allowed.add(self.SPACE_TOKEN_ID)
                    elif terminal == "ε":
                        allowed.add(eos_token_id)
                    else:
                        allowed.add(self.tokenizer.convert_tokens_to_ids(terminal))

        # No candidate at all: let the sequence end
        if len(allowed) == 0:
            allowed.add(eos_token_id)
        return allowed

    def __call__(self, input_ids, scores):
        """Set the score of every disallowed token to ``-inf`` (in place) and return ``scores``.

        Args:
            input_ids: ids generated so far, one row per sequence.
            scores: next-token scores, indexed ``[sequence, token_id]``.
        """
        protocol_batch = self.tokenizer.batch_decode(input_ids,skip_special_tokens=True)
        valid_tokens_ids = []
        for idx,protocol in enumerate(protocol_batch):
            if self.verbose:
                print("protocol:",protocol)
                print("protocol_ids:",input_ids[idx])
            valid_tokens_ids.append(self._allowed_token_ids(protocol, input_ids[idx]))

        if self.verbose:
            print("valid_token_ids:{}".format(valid_tokens_ids))

        # Build one boolean keep-mask and apply it in a single tensor op instead
        # of visiting every (sequence, token) pair in Python.
        vocab_size = scores.shape[1]
        keep = torch.zeros_like(scores, dtype=torch.bool)
        for batch_idx in range(scores.shape[0]):
            in_range = [
                token_id for token_id in valid_tokens_ids[batch_idx]
                if isinstance(token_id, int) and 0 <= token_id < vocab_size
            ]
            keep[batch_idx, torch.tensor(in_range, dtype=torch.long, device=scores.device)] = True
        scores.masked_fill_(~keep, float('-inf'))

        return scores

In [4]:
class T5JPToProtocolConverter(JPToProtocolConverter):
    """Japanese-sentence to protocol converter built on a T5 model.

    Generation is constrained by ``ConstrainedLogitsProcessor`` using the
    grammar returned by ``aiwolf_protocol_grammar()``.
    """

    DEFAULT_MODEL_NAME = "sonoisa/t5-base-english-japanese"
    # NOTE(review): machine-specific absolute path kept for backward
    # compatibility; callers on other machines must pass ``model_path``
    # explicitly (or "" to skip loading fine-tuned weights).
    DEFAULT_MODEL_PATH = "/home/takuya/HDD1/work/AI_Wolf/2023S_AIWolfK2B/aiwolfk2b/agentLPS/jp2protocol_model/t5_upper_20230514_4.pth"

    def __init__(self, model_name: str = "default", model_path: str = "default"):
        """
        Args:
            model_name: name passed to ``from_pretrained`` for both model and
                tokenizer; ``"default"`` selects ``DEFAULT_MODEL_NAME``.
            model_path: state-dict file to load into the model; ``"default"``
                selects ``DEFAULT_MODEL_PATH`` and ``""`` loads nothing.
        """
        resolved_name = self.DEFAULT_MODEL_NAME if model_name == "default" else model_name
        resolved_path = self.DEFAULT_MODEL_PATH if model_path == "default" else model_path

        self.model_name = resolved_name
        self.model_path = resolved_path

        # Load the pretrained model and its tokenizer
        self.model:T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(resolved_name)
        self.tokenizer:T5Tokenizer = T5Tokenizer.from_pretrained(resolved_name)

        # Grammar used for protocol conversion
        self.protocol_grammar: LL1Grammar = aiwolf_protocol_grammar()
        # Logits processor that restricts generation to that grammar
        self.logits_processor = LogitsProcessorList([
            ConstrainedLogitsProcessor(self.protocol_grammar, self.tokenizer)
        ])
        # Disabled: add the grammar terminals to the tokenizer and resize the model embeddings
        # new_tokens = self.protocol_grammar.terminals - set(self.tokenizer.get_vocab().keys())
        # self.tokenizer.add_tokens(list(new_tokens))
        # self.model.resize_token_embeddings(len(self.tokenizer))
        # Load fine-tuned weights when a path is given
        if resolved_path != "":
            self.model.load_state_dict(torch.load(resolved_path,map_location=torch.device(device)))

    def convert(self, text_list: List[str]) -> List[str]:
        """Convert each sentence in ``text_list`` to a protocol string.

        Args:
            text_list: input sentences; an empty list yields an empty list.

        Returns:
            One decoded string per input (``num_return_sequences=1``), with
            special tokens removed.
        """
        if len(text_list) == 0:
            # Nothing to encode or generate
            return []

        encoded = self.tokenizer.batch_encode_plus(text_list, max_length=512, padding=True, return_tensors='pt', truncation=True)
        # Keep the inputs on the same device as the model so that convert()
        # also works after the model has been moved to a GPU.
        model_device = self.model.device

        outputs = self.model.generate(
        inputs=encoded["input_ids"].to(model_device),
        attention_mask=encoded["attention_mask"].to(model_device),
        #force_words_ids=self.force_words_ids,
        num_beams=5,
        temperature=1,
        #do_sample=True,
        num_return_sequences=1,
        no_repeat_ngram_size=0,
        remove_invalid_values=False,
        logits_processor=self.logits_processor,
        max_length = 16,
        early_stopping=True,
        )

        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [5]:

def unit_test_T5JPToProtocolConverter():
    """Smoke-run the converter on one sentence and on a list of sentences, printing both results."""
    converter = T5JPToProtocolConverter()

    # Single-sentence call
    single_sentence = "Agent[08]が襲われたAgent[05]を霊媒すると人間だった"
    print("one text:",converter.convert([single_sentence]))

    # Batch call; the commented-out entries are further sample inputs that are currently disabled
    batch_sentences = [
        "Agent[03]はAgent[08]が狼だと推測する",
        # "Agent[06]はAgent[06]が占い師だとカミングアウトする",
        # "Agent[12]が占った結果Agent[10]は人狼だった",
        # "Agent[12]が占った結果Agent[10]は人間だった",
        # "Agent[08]が襲われたAgent[05]を霊媒すると人間だった",
        # "Agent[05]はAgent[10]を護衛した",
        # "Agent[10]はAgent[12]に投票する",
        # "Agent[06]はAgent[08]が狼だと思う",
        # "私が占い師です",
        # "Agent[12]が占った結果、Agent[10]は人狼でした",
        # "Agent[12]が占った結果、Agent[10]は人間でした",
        # "Agent[12]がAgent[05]を霊媒すると人間でした",
        # "Agent[12]はAgent[10]を守った",
        # "Agent[10]はAgent[12]に投票します",
        # "Agent[08]が狼だと思う",
        # "私が占い師です",
        # "占った結果、Agent[10]は人狼でした",
        # "占った結果、Agent[10]は人間でした",
        # "Agent[05]を霊媒すると人間でした",
        # "私はAgent[10]を守った",
        # "私はAgent[12]に投票します",
    ]
    batch_result = converter.convert(batch_sentences)
    print("text_list:", batch_result)

In [6]:
# Build the converter on the pretrained model only: an empty model_path skips the load_state_dict step.
converter = T5JPToProtocolConverter(model_path="")

In [7]:
# Dump the token tree built by the logits processor (rendered through Tree.__str__).
print(converter.logits_processor[0].token_tree)

`- Node(-1, Terminal=None)
   |- Node(2674, Terminal=0)
   |- Node(637, Terminal=over)
   |- Node(386, Terminal=or)
   |- Node(3724, Terminal=None)
   |  `- Node(18339, Terminal=bodyguard)
   |- Node(451, Terminal=not)
   |- Node(7218, Terminal=guard)
   |  `- Node(312, Terminal=guarded)
   |- Node(685, Terminal=5)
   |- Node(262, Terminal=None)
   |  |- Node(2227, Terminal=03)
   |  |- Node(29475, Terminal=estimate)
   |  |- Node(2357, Terminal=07)
   |  |- Node(2246, Terminal=02)
   |  |- Node(2072, Terminal=05)
   |  |- Node(268, Terminal=))
   |  |- Node(2224, Terminal=06)
   |  |- Node(2195, Terminal=09)
   |  |- Node(1905, Terminal=01)
   |  |- Node(2078, Terminal=08)
   |  `- Node(2458, Terminal=04)
   |- Node(834, Terminal=7)
   |- Node(7528, Terminal=None)
   |  `- Node(408, Terminal=skip)
   |- Node(1725, Terminal=None)
   |  `- Node(370, Terminal=villager)
   |- Node(4830, Terminal=attack)
   |- Node(861, Terminal=8)
   |- Node(1233, Terminal=13)
   |- Node(605, Terminal=4)


## token列から終端記号列の復元

In [8]:
# Check that the token tree contains every terminal of the grammar
# (i.e. the tree really represents all terminals).
token_tree_terminals = converter.logits_processor[0].token_tree.root.get_subtree_terminals()
# The tree stores lower-cased terminals, so compare against the lower-cased grammar terminals
terminals = {terminal.lower() for terminal in converter.protocol_grammar.terminals}

token_tree_terminals >= terminals  # evaluates to True, so nothing is missing

True

In [9]:
# Encode a lower-cased protocol sentence without special tokens and show its token ids.
text = 'agent 12 divined agent 10 human'
token_ids = converter.tokenizer.encode_plus(text,add_special_tokens=False)["input_ids"]
print(token_ids)

In [20]:
# Same kind of sentence with "guarded": compare the ids from encode_plus with the
# sub-word pieces from tokenize() and the ids those pieces map back to.
text = 'agent 12 guarded agent 10 human'
token_ids = converter.tokenizer.encode_plus(text,add_special_tokens=False)["input_ids"]
print(token_ids)
tokens = converter.tokenizer.tokenize(text)
print(tokens)
print(converter.tokenizer.convert_tokens_to_ids(tokens))


[12631, 886, 7218, 312, 12631, 746, 2199]
['▁agent', '▁12', '▁guard', 'ed', '▁agent', '▁10', '▁human']
[12631, 886, 7218, 312, 12631, 746, 2199]


In [18]:
# Check whether the token tree is an instantaneous (prefix-free) code
token_tree:Tree =converter.logits_processor[0].token_tree
token_tree.root.is_instant() # returns False: the tree is not a prefix-free code

(False, [(7218, 'guard')])

In [None]:
# The tree is not prefix-free ("guard" is a prefix of "guarded"), but the token
# sequence can still be decoded by following the tree as deep as possible.
def decode_token_ids(tree, ids):
    """Greedy longest-match decoding of ``ids`` into the terminals stored in ``tree``.

    The walk descends while the current node has a child for the next token id.
    When it has none, the terminal of the current node is emitted and the same
    token id is retried from the root, so no token is dropped.  The terminal
    of the node reached by the last token is emitted at the end.

    Raises:
        ValueError: if a token id does not start any terminal, or the walk has
            to stop on a node that carries no terminal.
    """
    decoded = []
    node = tree.root
    for token_id in ids:
        # .get() so that a miss never stores a placeholder in a defaultdict child map
        child = node.children.get(token_id)
        if child is None and node is not tree.root:
            if node.terminal is None:
                raise ValueError(f"token sequence stops inside a terminal before token id {token_id}")
            decoded.append(node.terminal)
            node = tree.root
            child = node.children.get(token_id)
        if child is None:
            raise ValueError(f"token id {token_id} does not start any terminal")
        node = child
    # Flush the terminal that the final token(s) belong to
    if node is not tree.root:
        if node.terminal is None:
            raise ValueError("token sequence ends inside a terminal")
        decoded.append(node.terminal)
    return decoded


token_tree:Tree =converter.logits_processor[0].token_tree
terminals_list = decode_token_ids(token_tree, token_ids)
terminals_list

In [24]:
# Encode two upper-case terminals without special tokens to see how many ids each one is split into.
converter.tokenizer.batch_encode_plus(["COMINGOUT","XOR"],add_special_tokens=False)

{'input_ids': [[262, 10170, 84, 82, 78, 83, 76, 84, 90, 89], [262, 93, 84, 87]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1]]}

In [25]:
# Terminals reachable from the root, from the root's child for token id 7218,
# and from the root's child for token id 262.
print(converter.logits_processor[0].token_tree.root.get_subtree_terminals())
print(converter.logits_processor[0].token_tree.root.children[7218].get_subtree_terminals())
print(converter.logits_processor[0].token_tree.root.children[262].get_subtree_terminals())

{'01', 'seer', '08', 'guarded', '14', '10', 'because', 'estimate', '0', 'comingout', '09', 'vote', '6', 'agree', 'bodyguard', 'attack', '07', '13', 'skip', '7', '03', 'id', 'possessed', '1', 'divined', 'day', 'guard', 'agent', '11', 'identified', 'inquire', 'divination', 'and', 'request', '3', 'werewolf', ')', 'voted', 'medium', 'over', '8', '04', '4', '(', 'villager', '15', '12', 'any', '05', 'disagree', '5', 'human', 'or', '2', 'not', 'attacked', '06', 'xor', 'ε', '9', '02'}
{'guard', 'guarded'}
{'01', '08', '07', '05', 'estimate', '03', ')', '09', '06', '04', '02'}


In [26]:
# Terminals below the root's child for token id 7218 (same query as the second print in the previous cell).
print(converter.logits_processor[0].token_tree.root.children[7218].get_subtree_terminals())

{'guard', 'guarded'}


In [27]:
# Lower-case a protocol string that uses the "Agent[NN]" bracket notation.
"Agent[12] DIVINED Agent[10] HUMAN".lower()

'agent[12] divined agent[10] human'

In [31]:
# Check that the token tree contains every terminal of the grammar
token_tree_terminals = converter.logits_processor[0].token_tree.root.get_subtree_terminals()
# The tree stores lower-cased terminals, so compare against the lower-cased grammar terminals
terminals = {terminal.lower() for terminal in converter.protocol_grammar.terminals}

token_tree_terminals >= terminals  # evaluates to True, so nothing is missing

True