# 🍟Imports

In [5]:
import torch
import transformers
from transformers import TopKLogitsWarper
from transformers import TextStreamer
import pickle
from pathlib import Path
from tqdm import tqdm_notebook as tqdm
from torch import tensor, Tensor, concat, argmax, argmin, sort, argsort, no_grad, hstack, concatenate, zeros, ones, float32, arange
from torch.nn.functional import softmax
from abc import ABC, abstractmethod
from functools import cache
from typing import List, Tuple
import japanize_matplotlib
from nanoid import generate as nanoid
import json

from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib.pyplot import Figure, Axes, subplot, subplots
from matplotlib.colors import LinearSegmentedColormap, TABLEAU_COLORS, to_rgb

# Device handles shared by the rest of the notebook.
cpu = torch.device("cpu")
# First CUDA device; the cells below move models and tokenized inputs onto it.
cuda0 = torch.device("cuda:0")

In [6]:
from importlib import reload
import llmcbf
reload(llmcbf)
from llmcbf import LLMCBF, LLMCBFResult

# 🧃Utils

In [56]:
def save(fn: str, obj: object):
    """Serialize ``obj`` to the file ``fn``, asking before overwriting.

    The format is chosen from the file name: names ending in ``.json`` are
    written as UTF-8 JSON with non-ASCII characters kept verbatim; every
    other name (``.pkl``, ``.pt``, ...) is pickled, which is the format
    that ``load`` reads back.

    Parameters
    ----------
    fn: str
        Destination path.
    obj: object
        Object to store. It must be JSON-serializable for ``.json`` targets
        and picklable otherwise.
    """
    fp = Path(fn)
    if fp.exists():
        res = input("Overwrite?(y/other)")
        if res != "y":
            print("Not saved")
            return
    if fn.endswith(".json"):
        # Context manager so the handle is flushed and closed even on error.
        with fp.open("w", encoding="utf-8") as f:
            json.dump(obj, f, ensure_ascii=False)
        print("Json Saved")
    else:
        # Anything that is not JSON is pickled: an unrecognized suffix must
        # not turn the call into a silent no-op.
        with fp.open("wb") as f:
            pickle.dump(obj, f)
        print("Pickle Saved")


def load(fn: str) -> object:
    """Read and return the pickled object stored in the file ``fn``.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    you created yourself or otherwise trust.
    """
    # Context manager so the file handle is closed instead of leaked.
    with Path(fn).open("rb") as f:
        return pickle.load(f)

In [3]:
def darken(color: list) -> list:
    """Return ``color`` with every channel halved (scaled toward 0)."""
    return [channel*0.5 for channel in color]


def lighten(color: list) -> list:
    """Return ``color`` with every channel moved halfway toward 1."""
    return [1-(1-channel)*0.5 for channel in color]


# Darkened and lightened variants of matplotlib's Tableau palette, keyed by
# the same names as TABLEAU_COLORS and stored as RGB lists.
DARK_TABLEAU_COLORS = {
    name: darken(to_rgb(spec)) for name, spec in TABLEAU_COLORS.items()
}
LIGHT_TABLEAU_COLORS = {
    name: lighten(to_rgb(spec)) for name, spec in TABLEAU_COLORS.items()
}

In [4]:
def tofloat(x: Tensor) -> float:
    """Convert a tensor to a Python float by way of a detached CPU NumPy copy."""
    as_numpy = x.detach().cpu().numpy()
    return float(as_numpy)

def toint(x:Tensor)->int:
    """Convert a tensor to a Python int by way of a detached CPU NumPy copy."""
    as_numpy = x.detach().cpu().numpy()
    return int(as_numpy)

def tolist(x: Tensor) -> list:
    """Convert a tensor to (nested) Python lists by way of a detached CPU NumPy copy."""
    as_numpy = x.detach().cpu().numpy()
    return as_numpy.tolist()


oo = float("Inf")

# 🍳LLMの導入
ここでは、応答タスクには対応しておらず、ただ文の続きを予測することしか考えていないモデルが好ましい。

In [5]:
# https://huggingface.co/bigscience/bloom-1b7#tokenization
# Generator option: BLOOM-1b7 model (Gm) and tokenizer (Gt) loaded from a local directory.
# NOTE(review): the following cell rebinds Gm/Gt to a different model, so whichever
# loader cell ran last determines the generator used below. The path is an absolute,
# machine-specific Windows path.
from transformers import BloomForCausalLM, BloomTokenizerFast
name = "d:\\TextGenerationModels\\bloom-1b7"
Gm = BloomForCausalLM.from_pretrained(name)
Gt = BloomTokenizerFast.from_pretrained(name)

In [5]:
# https://huggingface.co/rinna/bilingual-gpt-neox-4b
# Generator option: rinna bilingual GPT-NeoX 4b model (Gm) and tokenizer (Gt) loaded
# from a local directory.
# NOTE(review): the preceding cell binds the same names Gm/Gt to BLOOM; whichever
# loader cell ran last determines the generator used below. The path is an absolute,
# machine-specific Windows path.
from transformers import GPTNeoXForCausalLM, T5Tokenizer
name = "D:\\TextGenerationModels\\rinna_bilingual-gpt-neox-4b"
Gm = GPTNeoXForCausalLM.from_pretrained(name)
Gt = T5Tokenizer.from_pretrained(name, padding_side="left")

In [6]:
# Move the generator to the GPU and build the helpers derived from its tokenizer.
Gm = Gm.to(cuda0)
# Streams decoded text to stdout as tokens are produced.
streamer = TextStreamer(Gt)
vocab = Gt.get_vocab()
# Inverse of vocab: maps each value of vocab back to its key.
ivocab = {v: k for k, v in vocab.items()}

## 決定論的な生成を再現できるか

In [8]:
inputs = Gt("君って私のこと", return_tensors="pt", add_special_tokens=False).to(cuda0)
inputs

In [9]:
with no_grad():
    outputs = Gm.generate(**inputs, streamer=streamer,
                             max_new_tokens=100, do_sample=False)

In [10]:
x0 = inputs.input_ids[0]
x0

In [11]:
# Hand-written greedy decoding: repeatedly append the arg-max token of the last
# position's logits, for comparison with Gm.generate(..., do_sample=False) above.
x = x0.clone()
with no_grad():
    for t in range(100):
        # x[None] adds a leading batch dimension of size 1.
        output = Gm(x[None])
        # Logits of the last position of the single batch element.
        l = output.logits[0][-1]
        next_token = l.argmax()
        x = hstack((x, next_token))
generated = Gt.decode(x)
print(generated)

## 確率論的な生成

In [31]:
inputs = Gt("お前は本当に", return_tensors="pt", add_special_tokens=False).to(cuda0)
x0 = inputs.input_ids[0]
x0

In [54]:
with no_grad():
    outputs = Gm.generate(**inputs, streamer=streamer,
                             max_new_tokens=100, do_sample=True, temperature=0.1)

In [58]:
# Hand-written temperature sampling: draw each next token from softmax(logits / 0.2).
# NOTE(review): the Gm.generate call in the previous cell uses temperature=0.1,
# while this loop divides by 0.2 — confirm whether the two are meant to match.
x = x0.clone()
with no_grad():
    for t in range(100):
        output = Gm(x[None])
        # Logits of the last position of the single batch element.
        s = output.logits[0][-1]
        token_distribution = softmax(s/0.2, dim=0)
        # Sample one token index from the distribution.
        iast = token_distribution.multinomial(num_samples=1)
        # Echo the sampled token through the text streamer.
        streamer.put(iast)
        x = hstack((x, iast))
generated = Gt.decode(x)
# print(generated)

# 🍜制約関数モデル

## 🧡kit_nlp/bert-base-japanese-sentiment-cyberbullying
https://huggingface.co/kit-nlp/bert-base-japanese-sentiment-cyberbullying

In [8]:
from transformers import BertForSequenceClassification, AutoTokenizer

In [9]:
name = "./ConstraintLanguageModels/kit_nlpbert-base-japanese-sentiment-cyberbullying/"
hm = BertForSequenceClassification.from_pretrained(name)
ht = AutoTokenizer.from_pretrained(name)

In [10]:
hm = hm.to(cuda0)

In [11]:
def get_hm_logit(xstr:str)->Tensor:
    """
    Score the sentence ``xstr`` with the constraint classifier ``hm``.

    logit[0] = positive score
    logit[1] = negative score
    The sign of a score can be used for classification.

    Returns the logits of the first (only) batch element.
    """
    hinputs = ht(xstr, return_tensors="pt").to(cuda0)
    # Inference only: disable autograd so no graph or activations are kept
    # for a value that is never differentiated.
    with no_grad():
        houtputs = hm(**hinputs)
    logit = houtputs.logits[0]
    return logit

In [14]:
@cache
def get_h(xstr:str)->float:
    """Constraint value for ``xstr``: element 0 of ``get_hm_logit`` (the positive score) as a float, memoized per string."""
    return tofloat(get_hm_logit(xstr)[0])

## 💙mr4/bert-base-jp-sentiment-analysis

In [152]:
# https://huggingface.co/mr4/bert-base-jp-sentiment-analysis/tree/main
# Constraint-model option: mr4/bert-base-jp-sentiment-analysis loaded from a local
# directory and moved to the GPU.
# NOTE(review): this rebinds hm/ht, which other sections of the notebook also
# assign; the classifier actually used depends on which loader cell ran last.
from transformers import BertForSequenceClassification, BertJapaneseTokenizer
name = "./ConstraintLanguageModels/mr4_bert-base-jp-sentiment-analysis/"
hm = BertForSequenceClassification.from_pretrained(name)
ht = BertJapaneseTokenizer.from_pretrained(name)
hm = hm.to(cuda0)

In [15]:
def get_hm_logit(xstr:str)->Tensor:
    """
    Score the sentence ``xstr`` with the constraint classifier ``hm``.

    logit[0] = negative score
    logit[1] = positive score
    The sign of a score can be used for classification.

    Returns the logits of the first (only) batch element.
    """
    hinputs = ht(xstr, return_tensors="pt").to(cuda0)
    # Inference only: disable autograd so no graph or activations are kept
    # for a value that is never differentiated.
    with no_grad():
        houtputs = hm(**hinputs)
    logit = houtputs.logits[0]
    return logit

In [16]:
@cache
def get_h(xstr:str)->float:
    """Constraint value for ``xstr``: element 1 of ``get_hm_logit`` (the positive score) as a float, memoized per string."""
    return tofloat(get_hm_logit(xstr)[1])

## 💚ユーモア判定BERT

In [23]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
name = "./mohameddhiab_humor-no-humor/"
hm = DistilBertForSequenceClassification.from_pretrained(name)

In [24]:
ht = DistilBertTokenizer.from_pretrained(name)

In [25]:
hm = hm.to(cuda0)

In [29]:
def get_hm_logit(xstr:str)->Tensor:
    """
    Score the sentence ``xstr`` with the humor classifier ``hm``.

    logit[0] = NO_HUMOR score
    logit[1] = HUMOR score
    The sign of a score can be used for classification.

    Returns the logits of the first (only) batch element.
    """
    hinputs = ht(xstr, return_tensors="pt").to(cuda0)
    # Inference only: disable autograd so no graph or activations are kept
    # for a value that is never differentiated.
    with no_grad():
        houtputs = hm(**hinputs)
    logit = houtputs.logits[0]
    return logit

In [50]:
@cache
def get_h(xstr:str)->float:
    """Constraint value for ``xstr``: element 1 of ``get_hm_logit`` (the HUMOR score) as a float, memoized per string."""
    return tofloat(get_hm_logit(xstr)[1])

In [49]:
# https://simplicable.com/storytelling/humor-examples
# https://literarydevices.net/humor/
get_hm_logit("A fire station that burns down")

## 💜スパム判定BERT
なんか数字が多発すれば無条件でスパム判定してるみたい

https://huggingface.co/mrm8488/bert-tiny-finetuned-sms-spam-detection

https://huggingface.co/datasets/sms_spam/viewer/plain_text/train?p=1
* Good news for you! You are one of the selected 100! Click the link to get $
* Update_Now - Xmas Offer! Latest Motorola, SonyEricsson

In [22]:
# Spam classifier (bert-tiny fine-tuned on SMS spam) and its tokenizer, loaded from
# a local directory.
# NOTE(review): the EthicalEye section later rebinds Dm/Dt to a different model;
# code that reads Dm/Dt uses whichever loader cell ran last.
from transformers import BertForSequenceClassification, BertTokenizer
name = "./ConstraintLanguageModels/bert-tiny-finetuned-sms-spam-detection/"
Dm = BertForSequenceClassification.from_pretrained(name)
Dt = BertTokenizer.from_pretrained(name)

In [9]:
Dm = Dm.to(cuda0)

In [154]:
def get_spamlogit(sentence:str)->Tensor:
    """Return the logits that the classifier ``Dm`` assigns to ``sentence``.

    The sentence is tokenized with ``Dt``, moved to ``cuda0`` and run through
    ``Dm``; the logits of the first (only) batch element are returned as a
    tensor (not a bool).
    """
    Dinputs = Dt(sentence, return_tensors="pt").to(cuda0)
    # Inference only: disable autograd so no graph is kept for the result.
    with no_grad():
        Doutput = Dm(**Dinputs)
    return Doutput.logits[0]

In [10]:
test_prompts = [
    "Camera - You are awarded a SiPix Digital Camera! call 09061221066 fromm landline. Delivery within 28 days.",
    "Congrats! 1 year special cinema pass for 2 is yours. call 09061209465 now! C Suprman V, Matrix3, StarWars3, etc all 4 FREE! bx420-ip4-5we. 150pm. Dont miss out!",
    "I want to check the resume for your class yesterday, since I was not present in the class...",
    "The interesting about this study is that CBF is introduced to LLM, treating LLMs in the context of control engineering.",
]

In [11]:
# One row of two logits per test prompt.
logit_list = zeros((len(test_prompts),2), dtype=float32)

for n, prompt in enumerate(test_prompts):
    Dinputs = Dt(prompt, return_tensors="pt").to(cuda0)
    Doutput = Dm(**Dinputs)
    # Store the logits of the single batch element in row n.
    logit_list[n,:] = Doutput.logits[0]

In [13]:
# Print each prompt with a mark chosen by its arg-max class: index 0 -> "✅", index 1 -> "❌".
for prompt, logit in zip(test_prompts, logit_list):
    choice = "✅❌"[logit.argmax()]
    print(f"{choice}{logit} {prompt}")

## 🤍EthicalEye

In [9]:
# EthicalEye classifier (XLM-RoBERTa) and tokenizer, loaded from a local directory;
# the model is moved to the GPU.
# NOTE(review): this rebinds Dm/Dt, which the spam-classifier section also assigns;
# code that reads Dm/Dt uses whichever loader cell ran last.
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
name = "./ConstraintLanguageModels/autopilot-ai_EthicalEye"
Dm = XLMRobertaForSequenceClassification.from_pretrained(name).to(cuda0)
Dt = XLMRobertaTokenizer.from_pretrained(name)

In [77]:
test_prompts = [
    "お前は本当に馬鹿だ",
    "お前の席ネェから！",
    "お前はよく頑張ってるよ",
    "つい夜更かししちゃうんだよね！何とかならないかな～",
    "お前は本当に日本人なのか、と問い質したい。",
    "お前は本当に何もわかっていないんだな。それと、>>187のレスはスルーするんだな",
    "「日本人なら日本人らしく、自分の国を守れ」",
    "君って本当はとっても素直で純粋で、優しい心の持ち主なんですよ。",
    "そういうの、気持ちは分かるんだけど、なんかね...「お前にだけは、言われたくない」 みたいな。"
]

In [65]:
@cache
def EthicalEye_get_judgement_report(prompt:str)->str:
    """Classify ``prompt`` with ``Dm`` and return a one-line report (memoized per prompt).

    The report is "<mark> <prompt> <logits>", where the mark is 👍 when the
    arg-max class index is 0 and 🔴 otherwise.
    """
    encoded = Dt(prompt, return_tensors="pt").to(cuda0)
    scores = Dm(**encoded).logits[0]
    if scores.argmax() == 0:
        mark = "👍"
    else:
        mark = "🔴"
    return f"{mark} {prompt} {tolist(scores)}"

In [78]:
for prompt in test_prompts:
    print(EthicalEye_get_judgement_report(prompt))

# 🍥LLMの制御

In [18]:
top_k = 20
topk = TopKLogitsWarper(top_k=top_k)

## ⬛LLM-CBF


In [16]:
@no_grad
def generate_with_KCBF2(
        x0:Tensor,
        temperature:float,
        top_k:int,
        alpha:float,
        max_new_tokens:int=30
    )->Tuple[Tensor, List[List[int]]]:
    """
    Sample a continuation of ``x0`` from ``Gm``, passing each step's logits
    through an ``LLMCBF(top_k, alpha, Gt, get_h)`` filter before sampling.

    Generation stops after ``max_new_tokens`` tokens or as soon as the sampled
    token equals ``Gt.eos_token_id`` (the EOS token is not appended).

    Returns
    -------
    xf: Tensor
        The generated token sequence, starting with x0.
    banned_tokens: List[List[int]]
        For each generation step, the tokens excluded by LLM-CBF.
    """
    controller = LLMCBF(top_k, alpha, Gt, get_h)
    tokens = x0.clone()
    banned_per_step = []

    for _ in range(max_new_tokens):
        # Logits of the last position of the single batch element.
        last_logits = Gm(tokens[None]).logits[0][-1]
        filtered = controller(tokens, last_logits)
        banned_per_step.append(filtered.banned_tokens)
        # Temperature-scaled distribution over the filter's output scores.
        probs = softmax(filtered.sdash/temperature, dim=0)
        sampled = probs.multinomial(num_samples=1)
        if sampled == Gt.eos_token_id:
            break
        tokens = hstack((tokens, sampled))
    return tokens, banned_per_step

In [34]:
def get_htrj(x0len:int, xf:Tensor)->List[float]:
    """Evaluate the constraint function along the generated sequence ``xf``.

    Parameters
    ----------
    x0len: int
        Length of the prompt prefix of ``xf``.
    xf: Tensor
        Full token sequence (prompt followed by generated tokens).

    Returns
    -------
    List[float]
        ``get_h`` of the decoded prefix ``xf[:k]`` for every
        ``k = x0len, ..., len(xf)``; the first entry is the prompt alone and
        the last is the whole sequence.
    """
    htrj = []
    # Use the x0len argument, not the notebook-global x0, so the result
    # depends only on what the caller passed in.
    for k in range(x0len, len(xf)+1):
        xstr = Gt.decode(xf[:k])
        htrj.append(get_h(xstr))
    return htrj

In [71]:
# Single LLM-CBF generation run from a short prompt.
x0str = "実際は"
# The constraint function must be positive for the initial prompt.
assert get_h(x0str) > 0, "初期の制約関数は正でないといけない"
Ginputs = Gt(x0str, return_tensors="pt", add_special_tokens=False).to(cuda0)
# First row of the tokenized batch.
x0 = Ginputs.input_ids[0]
xf, banned_tokens = generate_with_KCBF2(x0, temperature=0.1, top_k=top_k, alpha=0.5)

In [72]:
xf, banned_tokens,Gt.decode(xf)

In [106]:
Gt.decode(xf)

In [107]:
htrj = get_htrj(len(x0), xf)
htrj

In [108]:
# Plot the constraint-function trajectory of the generated sequence.
ax = subplot(facecolor="white")
# NOTE(review): cmap is assigned but not used in this cell.
cmap = cm.winter
ax.plot(htrj, c="k")
ax.set_xlabel("Generation Loop $k$")
ax.set_ylabel("Constraint Function $h(x(k))$")
plt.show()

In [109]:
ax.figure.savefig(f"E5/7A-{nanoid()}.pdf")

In [110]:
len(banned_tokens)

In [111]:
# For every generation step, list the banned tokens together with the constraint
# value the sequence would have had if that token had been appended.
x0len = len(x0)
xflen = len(xf)
# Number of generated tokens.
kmax = xflen-x0len
for k in range(0, kmax):
    print(f"{k=}")
    for i in banned_tokens[k]:
        # Prefix that existed at step k, extended by the banned token i.
        banned_x = hstack((xf[:x0len+k].cpu(), tensor(i)))
        banned_x_str = Gt.decode(banned_x)
        banned_h = get_h(banned_x_str)
        print(f" {i=} h={banned_h} {banned_x_str}")
    print()

In [112]:

# Constraint trajectory with the branches rejected by LLM-CBF.
# htrj[j] is h of the sequence after j generated tokens (htrj[0] is the prompt),
# so the state at generation step k sits at x=k, y=htrj[k].
fig,ax = subplots(facecolor="white", figsize=(10,5))
cmap = cm.winter
ax.plot(htrj, c="k", alpha=0.5)
# Label each state k with the token that was actually sampled from it.
for k in range(kmax):
    sk = ivocab[toint(xf[x0len+k])]
    ax.text(k, htrj[k], sk)
# Draw every banned token as a red branch leaving state k and ending at the
# constraint value it would have produced at step k+1. The branch starts at
# htrj[k]; using htrj[k-1] would wrap around to the last element at k=0.
for k in range(kmax):
    prefix = xf[:x0len+k].cpu()
    for i in banned_tokens[k]:
        iast = ivocab[i]
        banned_x = hstack((prefix, tensor(i)))
        banned_x_str = Gt.decode(banned_x)
        banned_h = get_h(banned_x_str)
        ax.plot([k, k+1], [htrj[k], banned_h], c="tab:red", alpha=0.5)
        ax.text(k+1,banned_h,iast,color="tab:red")
ax.set_xlabel("Generation Loop $k$")
ax.set_ylabel("Constraint Function $h(x(k))$")
plt.show()

In [113]:
fig.savefig(f"E5/7B-{nanoid()}.pdf")

In [92]:
def does_violate(htrj:list)->bool:
    """Return True when the trajectory goes below zero; an empty trajectory never violates."""
    return bool(htrj) and min(htrj) < 0
# NOTE(review): `result` is not defined above this cell in notebook order; it is
# assigned in later cells, so this only works with leftover kernel state.
violates_list = [does_violate(htrj) for htrj in result["htrj_list"]]

# Print every generated sentence, marked by whether its trajectory violated the constraint.
for xf, violates in zip(result["xf_list"], violates_list):
    marker = "✅" if not violates else "🔴"
    xfstr = Gt.decode(xf)
    print(f"{marker}{xfstr}")

# Experiments

In [None]:
# Bundle the trajectories and generated sequences of one experimental condition.
# NOTE(review): htrj_list is not defined anywhere in the visible notebook, and
# xf_list is only built in the following cell — this cell depends on hidden state.
result = {"htrj_list":htrj_list, "xf_list":xf_list}
save("Experiment4/A.pt", result)

In [None]:
# Uncontrolled baseline: 10 samples of up to 30 tokens each, with top-k filtering only.
xf_list = []
with no_grad():
    for n in range(10):
        x = x0.clone()
        for t in range(30):
            output = Gm(x[None])
            # Logits of the last position of the single batch element.
            s = output.logits[0][-1]
            # s, _ = topk_KCBF(x, s, top_k=top_k, alpha=0.95)
            # The warper is called with input_ids=None; only the scores are passed.
            s = topk(None, s)
            # multinomial is given exp(score) as (unnormalized) sampling weights.
            iast = s.exp().multinomial(num_samples=1)
            if iast == Gt.eos_token_id:
                break
            x = hstack((x, iast))
        xf_list.append(x)
        xfstr = Gt.decode(x)
        print(f"{n=} {xfstr=}")

### Experiment 4 結果

In [88]:
A = load("Experiment4/A.pt")
B = load("Experiment4/B.pt")
C = load("Experiment4/C.pt")

In [91]:
# Experiment 4: constraint trajectories of the three loaded conditions, side by side.
fig, (axA, axB, axC) = subplots(1,3,facecolor="white",figsize=(10,4))

for ax, result,  label in [
    (axA, A, "No Control"),
    (axB, B, "$K_\\mathrm{CBF}, \\alpha=0.9$"),
    (axC, C, "$K_\\mathrm{CBF}, \\alpha=0.5$"),
]:
    # One semi-transparent line per trial.
    for htrj in result["htrj_list"]:
        ax.plot(htrj, c="tab:red", alpha=0.4)
    ax.set_title(label)
    ax.set_ylim(-5,5)
    ax.set_xlabel("$k$")
    # Dotted line at h = 0.
    ax.hlines(0, 0, 30, ls=":", color="k")
axA.set_ylabel("$h(x(k))$")
plt.show()

In [92]:
fig.savefig("Experiment4/1.pdf")

### Experiment 2 結果

In [4]:
A = load("Experiment2/A.pt")
B = load("Experiment2/B.pt")
C = load("Experiment2/C.pt")

In [5]:
# Experiment 2: constraint trajectories of the three loaded conditions, side by side.
fig, (axA, axB, axC) = subplots(1,3,facecolor="white",figsize=(10,4))

for ax, result,  label in [
    (axA, A, "No Control"),
    (axB, B, "$K_\\mathrm{CBF}, \\alpha=0.5$"),
    (axC, C, "$K_\\mathrm{CBF}, \\alpha=0.9$"),
]:
    # One semi-transparent line per trial.
    for htrj in result["htrj_list"]:
        ax.plot(htrj, c="tab:red", alpha=0.4)
    ax.set_title(label)
    ax.set_ylim(-1.5,1.5)
    ax.set_xlabel("$k$")
    # Dotted line at h = 0.
    ax.hlines(0, 0, 30, ls=":", color="k")
axA.set_ylabel("$h(x(k))$")
plt.show()

In [487]:
fig.savefig("Experiment2/1.pdf")

### Experiment 1 結果

In [387]:
B = load("Experiment1/B.pt")
A = load("Experiment1/A.pt")
C = load("Experiment1/C.pt")

In [420]:
# Experiment 1: constraint trajectories of the three loaded conditions, side by side.
# Here B is the "No Control" condition and is drawn in the leftmost panel.
fig, (axB, axA, axC) = subplots(1,3,facecolor="white",figsize=(10,4))

for ax, result,  label in [
    (axB, B, "No Control"),
    (axA, A, "$K_\\mathrm{CBF}, \\alpha=0.5$"),
    (axC, C, "$K_\\mathrm{CBF}, \\alpha=0.9$"),
]:
    # One semi-transparent line per trial.
    for htrj in result["htrj_list"]:
        ax.plot(htrj, c="tab:red", alpha=0.4)
    ax.set_title(label)
    ax.set_ylim(-1.5,1.5)
    ax.set_xlabel("$k$")
    # Dotted line at h = 0.
    ax.hlines(0, 0, 30, ls=":", color="k")
axB.set_ylabel("$h(x(k))$")
plt.show()

In [421]:
fig.savefig("Experiment1/1.pdf")

### Experiment 共通 結果

In [93]:
result = load("Experiment4/A.pt")

## 💜スパム判定BERT

In [None]:
class TokenController(ABC):
    """Abstract interface for strategies that pick the next token of a sequence."""

    @abstractmethod
    def get_next_token(self, x:Tensor):
        """Return the next token to append to the token sequence ``x``.

        NOTE(review): the original signature was cut off after ``x:Tensor,``;
        any further parameters and the return type still need to be decided.
        """
        raise NotImplementedError

In [207]:
# Seed sentence for the spam-avoidance experiment, tokenized for the generator.
# NOTE(review): unlike the other prompts in this notebook, Gt is called here
# without add_special_tokens=False — confirm that this is intended.
seed_sentence = "How about getting in touch with folks waiting for company? Just txt back your NAME and AGE to opt in! Enjoy the community"
x0 = Gt(seed_sentence, return_tensors="pt").input_ids[0].to(cuda0)
get_spamlogit(seed_sentence)

In [208]:
# @no_grad()
def get_controlled_next_token(x:Tensor, best_tokens:Tensor, max_candidates:int=1)->Tensor:
    """Pick the first candidate token that keeps the sentence below the spam threshold.

    Candidates are tried in the order given by ``best_tokens``. Each one is
    appended to ``x``, the result is decoded with ``Gt`` and scored by the
    classifier ``Dm``; the first candidate whose logit at index 1 (the spam
    level) is negative is returned.

    Parameters
    ----------
    x: Tensor
        Current token sequence.
    best_tokens: Tensor
        Candidate next tokens, most preferred first.
    max_candidates: int
        How many candidates to try before giving up. The default of 1 only
        checks the top candidate, so no alternative can be chosen.

    Returns
    -------
    Tensor
        The chosen element of ``best_tokens``. If no tried candidate is
        acceptable, "can't avoid spam" is printed and the most preferred
        candidate ``best_tokens[0]`` is returned.
    """
    for i in range(max_candidates):
        next_token = best_tokens[i]
        next_x = hstack((x, next_token))
        next_sentence = Gt.decode(next_x)
        Dinputs = Dt(next_sentence, return_tensors="pt").to(cuda0)
        Doutput = Dm(**Dinputs)
        Dlogit = Doutput.logits[0]
        spamlevel = Dlogit[1]
        if spamlevel < 0:
            if i>0:
                # A less preferred candidate was needed to stay below the threshold.
                print(f"AVOIDED {i=}")
            return next_token
    print("can't avoid spam")
    return best_tokens[0]

In [210]:
# Generate up to T tokens, letting get_controlled_next_token choose each token
# from the candidates sorted by descending generator logit.
x = x0.clone()
# Snapshot of the sequence after every accepted token.
x_list = []
T = 30
with no_grad():
    for t in tqdm(range(T)):
        Goutput = Gm(x[None])
        # Logits of the last position of the single batch element.
        Glogit = Goutput.logits[0][-1]
        # Token indices ordered from highest to lowest logit.
        best_tokens = argsort(-Glogit)
        next_token = get_controlled_next_token(x, best_tokens)
        if bool(next_token == Gt.eos_token_id):
            break
        x = hstack((x, next_token))
        
        x_list.append(x.detach().clone())

In [211]:
# Re-score every intermediate sentence with the classifier and plot both logits
# against the generated tokens.
sentence_list = [Gt.decode(x) for x in x_list]

# One row of two logits per intermediate sentence.
Dlogit_list = zeros((len(sentence_list),2),dtype=float32)
for t, sentence in enumerate(sentence_list):
    Dinputs = Dt(sentence, return_tensors="pt").to(cuda0)
    Doutput = Dm(**Dinputs)
    Dlogit_list[t,:] = Doutput.logits[0]

# Tokens appended after the seed, each decoded on its own for the tick labels.
# NOTE(review): x_list[-1] raises IndexError if no token was accepted above.
generated_tokens = x_list[-1][len(x0):]
generated_words = list(map(Gt.decode, generated_tokens))

fig, ax = plt.subplots(facecolor="white", figsize=(11,4))
ax.plot(tolist(Dlogit_list[:,0]), label="Safe Level")
ax.plot(tolist(Dlogit_list[:,1]), label="Spam Level")
# Horizontal line at logit 0.
ax.hlines(0, 0, len(x_list), "k")

t_s = list(range(len(generated_tokens)))
ax.set_xticks(t_s, generated_words)
for t in ax.xaxis.get_ticklabels():
    t.set_rotation(60)

fig.legend()
plt.show()

In [68]:
Dinputs = Dt(Gt.decode(x), return_tensors="pt").to(cuda0)

In [117]:
sentence_list[-1]

## E6

In [173]:
def get_x(s:str)->Tensor:
    """Tokenize ``s`` with ``Gt`` (no special tokens), move it to ``cuda0`` and return the first row of ``input_ids``."""
    return Gt(s, return_tensors="pt", add_special_tokens=False).to(cuda0).input_ids[0]

In [21]:
# Prompt shared by the E6 runs below.
x0str = "君って私のこと"
# The constraint function must be positive for the initial prompt.
assert get_h(x0str) > 0, "初期の制約関数は正でないといけない"
Ginputs = Gt(x0str, return_tensors="pt", add_special_tokens=False).to(cuda0)
# First row of the tokenized batch.
x0 = Ginputs.input_ids[0]


In [185]:
@no_grad
def get_likelihoods(x0: Tensor, xf: Tensor, temperature: float) -> List[float]:
    """Probability that ``Gm`` assigns to each generated token of ``xf``.

    Starting from the prompt ``x0``, every token of ``xf`` after the first
    ``len(x0)`` positions is scored under ``softmax(logits / temperature)`` of
    the last position given the tokens before it, then appended to the
    prefix. Returns one probability per such token, in order.
    """
    prefix = x0.clone()
    probs = []
    for token in xf[len(x0):]:
        # Logits of the last position of the single batch element.
        last_logits = Gm(prefix[None]).logits[0][-1]
        dist = softmax(last_logits/temperature, dim=0)
        probs.append(tofloat(dist[token]))
        prefix = hstack((prefix, token))
    return probs

In [187]:
TRIAL_SIZE = 10
TEMPERATURE = 1.0
MAX_NEW_TOKENS = 30

In [189]:
# Condition with LLM-CBF attached.
alpha = 0.9
# One record per trial: token ids, decoded text and per-token likelihoods.
result = []
for trial_count in range(TRIAL_SIZE):
    print(f"Trial {trial_count}:",end="")
    xf, banned_tokens = generate_with_KCBF2(
        x0, 
        temperature=TEMPERATURE, 
        top_k=top_k, 
        alpha=alpha,
        max_new_tokens=MAX_NEW_TOKENS
    )
    xfstr = Gt.decode(xf)
    print(xfstr)
    likelihoods = get_likelihoods(x0, xf, temperature=TEMPERATURE)
    result.append({
        "xf":tolist(xf),
        "xfstr":xfstr,
        "likelihoods":likelihoods
    })

# The f-string "{alpha=}" expands to "alpha=0.9" inside the file name.
save(f"./E6/1.{alpha=}.json", result)

In [179]:
@no_grad
def generate_with_NoControl(x0:Tensor, temperature:float, top_k:int=None, max_new_tokens:int=30)->Tensor:
    """Sample a continuation of ``x0`` from ``Gm`` without any constraint filter.

    Parameters
    ----------
    x0: Tensor
        Prompt token sequence.
    temperature: float
        Divisor applied to the logits before the softmax.
    top_k: int, optional
        When truthy, only the ``top_k`` highest-logit tokens can be sampled.
    max_new_tokens: int
        Maximum number of tokens to append.

    Returns
    -------
    Tensor
        The generated token sequence, starting with x0. Generation stops early
        when the sampled token equals ``Gt.eos_token_id`` (the EOS token is
        not appended), matching ``generate_with_KCBF2``.
    """
    x = x0.clone()
    for k in range(max_new_tokens):
        output = Gm(x[None])
        # Logits of the last position of the single batch element.
        s = output.logits[0][-1]
        if top_k:
            # Mask everything outside the top_k highest logits.
            _, sorted_idx = sort(s, descending=True)
            s[sorted_idx[top_k:]] = -oo
        token_distribution = softmax(s/temperature, dim=0)
        iast = token_distribution.multinomial(num_samples=1)
        # Stop at end-of-sequence so controlled and uncontrolled runs are
        # truncated by the same rule.
        if iast == Gt.eos_token_id:
            break
        x = hstack((x, iast))
    return x

In [190]:
# No control.
# One record per trial: token ids, decoded text and per-token likelihoods.
result = []
for trial_count in range(TRIAL_SIZE):
    print(f"Trial {trial_count}:", end="")
    xf = generate_with_NoControl(
        x0, temperature=TEMPERATURE, top_k=top_k, max_new_tokens=MAX_NEW_TOKENS)
    xfstr = Gt.decode(xf)
    print(xfstr)
    likelihoods = get_likelihoods(x0, xf, temperature=TEMPERATURE)
    result.append({
        "xf": tolist(xf),
        "xfstr": xfstr,
        "likelihoods": likelihoods
    })

save(f"./E6/1.NoControl.json", result)

In [192]:
# Hand-written, deliberately arbitrary continuations used as a comparison condition.
random_xgstr_list = [
    "馬鹿にしてるの？」「もちろんだぜベイベー！！さあ、お昼ごはんにしようぜ！今日は何を食べようかな",
    "どう思ってるの？」「うーーん、お前の頭は安定システムではないことだけは分かるねえ」",
    "と私のことと君のことと私のことと君のことと私のことと君のことと私のことと君のことと",
    "です。はい、本当にありがとうございました～～！！",
    "がどうでもいいと思ってるんだ！へー！じゃあ、この世が滅んじゃってもいいと思うんだ！！へえーー！！",
    "面白い人だと思ってるの？def main():\n    do_foolish_thing()\n正体表したわね",
    "つまらない人だと思ってるの？def main():\n    do_boring_thing()\n正体表したわね",
    "何かの天才と間違えてない？そうよ、私は天災の神様（地震・雷・火事・おやじが同時に起きて眠気が吹き飛ぶ",
    "巨大なトランスフォーマーか何かだと思ってるでしょ？研究に毒されているわ。今すぐNLPをやめなさい。",
    "が気になるの？この、琴。これはね、私の親族に代々伝わる大切な楽器なの。"
]
# Each continuation is tokenized on its own and appended to the prompt tokens x0.
random_xf_list = [
    hstack((x0, get_x(xgstr))) for xgstr in random_xgstr_list
]

In [197]:
# Arbitrary hand-written continuations: score them with the generator.
# One record per sequence: token ids, decoded text and per-token likelihoods.
result = []
for trial_count, xf in enumerate(random_xf_list):
    print(f"Trial {trial_count}")
    likelihoods = get_likelihoods(x0, xf, temperature=TEMPERATURE)
    xfstr = Gt.decode(xf)
    result.append({
        "xf": tolist(xf),
        "xfstr": xfstr,
        "likelihoods": likelihoods
    })

save(f"./E6/1.Random.json", result)

In [78]:
s = Gm(x0[None]).logits[0][-1]

In [108]:
token_distribution = softmax(s, dim=0)
iast = token_distribution.multinomial(num_samples=1)
iast, token_distribution[iast]

In [150]:
# Toy distribution for checking multinomial sampling. dim is given explicitly
# because calling softmax without it relies on a deprecated implicit choice.
distribution = softmax(tensor([1.,1.,8.]), dim=0)
distribution

In [165]:
distribution[distribution.multinomial(num_samples=1)]

In [172]:
get_x("君って私のことどれくらい知ってるの？さあ勝負だ、制御工学の知識で競い合おうじゃないの　あーはっはっはｈ")