In [None]:
import os
import re

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
import transformers
import umap
from tqdm.auto import tqdm

from disco.distributions import LMDistribution
from disco.samplers import AccumulationSampler, QuasiRejectionSampler
from disco.scorers import BooleanScorer

# QRS sampler in the Disco library

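We recommend setting $\beta = Z$, where $Z$ is the partition function (normalizing constant) of the target EBM. We do not cover the theoretical justification here; that is left for future work.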

In [None]:
# Library QRS: accept proposals from `model` to approximate `target_ebm`.
# NOTE(review): `target_ebm` and `model` are not defined anywhere in this notebook —
# presumably they stand for a target EBM such as `g` and a proposal such as `a2`
# (both created further below), so this cell only runs once those names exist.
beta=0.5
sampler = QuasiRejectionSampler(target_ebm, model, beta=beta)
samples, log_scores = sampler.sample(sampling_size=2**7)

However, to compare QRS with MCMC on the same pool of proposal samples, we implement the filtering step ourselves below instead of using the library sampler.

# Experimental Setup

For this experiment, we recommend using a deliberately poor proposal model: it makes the differences between the samplers easier to see.

In [None]:
# Read the Hugging Face token from the environment rather than hard-coding it in the
# notebook (falls back to "" exactly like the original placeholder).
token = os.environ.get("HF_TOKEN", "")

# Base language model.
a = LMDistribution("models/gemma-2b", token=token, LLM=True)


def b(s, c):
    """Hard constraint: True iff the sample's text contains the whole word "amazing".

    The second argument (context) is accepted for the scorer interface but ignored.
    """
    return bool(re.search(r"\bamazing\b", s.text))


# Deliberately poor proposal. Recipe: very low batch_size for DPG training.
a2 = LMDistribution("models/amazing/dpg-fail", token=token, LLM=True)

In [None]:
# Wrap the hard constraint `b` as a scorer and combine it with the base LM `a`
# to form the target EBM `g` that QRS and IMH sample from below.
scorer = BooleanScorer(b)
g = a * scorer

In [None]:
# Draw a large pool of samples from the poor proposal `a2` with an empty context.
# NOTE(review): presumably `total_size` is the number of samples to draw and
# `sampling_size` the per-batch size — confirm against the Disco documentation.
# The second output (`distr_q2`) is used below as the proposal log-probabilities.
distr = AccumulationSampler(distribution=a2, total_size=500000)
samples_q2, distr_q2 = distr.sample(sampling_size=500, context="")
len(samples_q2)

In [None]:
# Keep the proposal samples that already satisfy the hard constraint `b`.
# `b` ignores its second (context) argument, so pass None rather than IPython's
# `_` (the previous cell's output), which does not exist outside IPython.
start = [sample for sample in samples_q2 if b(sample, None)]
len(start)

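With many samples in hand, let's track how the distribution changes. The `partition` helper below summarizes where the word *amazing* appears in each text, as a normalized histogram over its relative position in bins of width 0.1.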

In [None]:
def partition(text_list, n_bins=10):
    """Histogram of where the word "amazing" first occurs within each sample.

    For every sample, the character offset of the first whole-word match of
    ``amazing`` (the same ``\\bamazing\\b`` pattern as the hard constraint) is
    divided by the text length. This gives a relative position in [0, 1).
    Positions are bucketed into ``n_bins`` equal-width bins, and the counts are
    normalized to fractions.

    Args:
        text_list: iterable of samples whose element ``[1]`` is the text.
        n_bins: number of equal-width position bins (default 10, i.e. 0.1 wide).

    Returns:
        A list of ``n_bins + 1`` floats that sums to 1, or all zeros when
        ``text_list`` is empty. The final slot keeps the original 11-entry layout
        and is always 0, since relative positions are strictly below 1.

    Raises:
        ValueError: if a sample's text does not contain the word "amazing".
    """
    counts = [0] * (n_bins + 1)
    for sample in text_list:
        text = sample[1]
        match = re.search(r"\bamazing\b", text)
        if match is None:
            raise ValueError(f"sample text does not contain the word 'amazing': {text[:50]!r}")
        # Integer arithmetic: float floor division misbins exact boundaries
        # (e.g. 0.3 // 0.1 == 2.0), which put position 3/10 into bin 2.
        counts[match.start() * n_bins // len(text)] += 1
    total = sum(counts)
    if total == 0:
        return [0.0] * len(counts)
    return [count / total for count in counts]

In [None]:
partition(start)

In [None]:
def QRS_filter(target, beta, y_a2, loga2):
    """Quasi-rejection sampling (QRS) applied to pre-drawn proposal samples.

    Sample ``y_a2[i]`` is kept with probability
    ``min(1, exp(log P(y) - log q(y)) / beta)``. Here ``log P`` comes from
    ``target.log_score`` and ``log q`` is the proposal log-probability ``loga2[i]``.

    Args:
        target: EBM exposing ``log_score(samples=[...], context="")``.
        beta: QRS parameter; must be > 0. Larger values accept fewer samples.
        y_a2: proposal samples to filter.
        loga2: proposal log-probabilities aligned index-by-index with ``y_a2``.
            It may be longer than ``y_a2``; extra entries are ignored.

    Returns:
        ``(accepted_samples, accepted_proposal_log_probs)`` as two lists.

    Raises:
        ValueError: if ``beta`` <= 0 or ``loga2`` is shorter than ``y_a2``.
    """
    if beta <= 0:
        raise ValueError(f"beta must be positive, got {beta}")
    if len(loga2) < len(y_a2):
        raise ValueError(f"loga2 has {len(loga2)} entries but y_a2 has {len(y_a2)}")
    out = []
    out_prob = []
    for sample, log_q in zip(y_a2, loga2):
        log_p = target.log_score(samples=[sample], context="")
        # Compute on the CPU instead of forcing "cuda", so this also runs without a
        # GPU. The uniform draw still comes from the CPU RNG, as before.
        log_ratio = torch.as_tensor(log_p).detach().cpu() - torch.as_tensor(log_q).detach().cpu()
        rs = torch.exp(log_ratio) / beta
        us = torch.rand(rs.shape)
        if bool((us < rs).all()):
            out.append(sample)
            out_prob.append(log_q)
    return out, out_prob

In [None]:
# Constraint-satisfying proposal samples, paired with their proposal log-probabilities.
# `b` ignores its context argument, so pass None rather than IPython's `_`.
samples_g2 = []
distr_g2 = []
for sample, log_q in zip(samples_q2, distr_q2):
    if b(sample, None):
        samples_g2.append(sample)
        distr_g2.append(log_q)
len(samples_g2)

In [None]:
QRS_BETA = 0.05         # QRS beta used for this comparison
QRS_POOL_SIZE = 20000   # number of constraint-satisfying samples passed to QRS

# QRS only sees this pool, so the acceptance rate is measured against the pool
# rather than against all of samples_g2 (which underestimated it).
qrs_pool = samples_g2[:QRS_POOL_SIZE]
samples_g3, distr_g3 = QRS_filter(target=g, beta=QRS_BETA, y_a2=qrs_pool, loga2=distr_g2[:len(qrs_pool)])

print('AR degradation is ', len(samples_g3) / len(qrs_pool) if qrs_pool else float('nan'))

In [None]:
partition(samples_g3)

In [None]:
def IMH(text, ebm, proposal, gop, n=50, sampling_size=250):
    """Independent Metropolis-Hastings run over a population of parallel chains.

    Each element of ``text`` is the initial state of one chain. At every step, a
    fresh batch of proposals is drawn from ``proposal`` (conditioned on ``gop``).
    Chain ``j`` then moves to proposal ``j`` with probability
    ``min(1, P(x') q(x) / (P(x) q(x')))``, where ``log P`` is ``ebm.log_score``
    and ``log q`` is ``proposal.log_score``.

    Args:
        text: initial samples, one per chain. Not modified; a copy is made.
        ebm: target EBM exposing ``log_score(samples, context=...)``. It is scored
            with an empty context.
        proposal: distribution used both to score the current states and to draw
            proposals through ``AccumulationSampler``.
        gop: context passed to the proposal for scoring and sampling.
        n: number of MH steps.
        sampling_size: batch size forwarded to ``AccumulationSampler.sample``.

    Returns:
        A new list holding the final state of every chain (same length as ``text``).
    """
    # Copy so the caller's list (e.g. `start`) is not overwritten in place.
    samples_q = list(text)
    log_q = proposal.log_score(samples_q, context=gop)
    log_p = ebm.log_score(samples_q, context="")
    n_chains = len(log_p)
    for _step in range(n):
        sampler = AccumulationSampler(distribution=proposal, total_size=n_chains)
        proposals, log_q_new = sampler.sample(sampling_size=sampling_size, context=gop)
        log_p_new = ebm.log_score(proposals, context="")
        for j in range(n_chains):
            # IMH acceptance ratio [P(x') / q(x')] / [P(x) / q(x)], computed in log space.
            ratio = torch.exp(log_p_new[j] - log_q_new[j] + log_q[j] - log_p[j])
            # Draw from the CPU RNG as before, but compare on the ratio's own device
            # instead of forcing "cuda".
            u = torch.rand(1).to(ratio.device)
            if u < ratio:
                samples_q[j] = proposals[j]
                log_q[j] = log_q_new[j]
                log_p[j] = log_p_new[j]
    return samples_q

In [None]:
print(partition(start))
# The samples in `start` were drawn from `a2` with an empty context (see the
# sampling cell above), so IMH proposes with the same empty context.
gop = ""
# Pass a copy so `start` stays untouched even if IMH updates its input list in place.
IMH_out = IMH(text=list(start), ebm=g, proposal=a2, gop=gop, n=50)

# IMH keeps every chain, so unlike QRS this ratio is 1 by construction.
print('AR degradation is ', len(IMH_out)/len(start))
print(partition(IMH_out))

# Visualization with UMAP

In [None]:
def chunks(l, n):
    """Lazily yield consecutive slices of ``l`` of length ``n``; the last may be shorter."""
    yield from (l[offset:offset + n] for offset in range(0, len(l), n))

def fetch_vectors(string_list, batch_size=64, max_len=512):
    """Embed texts with DistilBERT and return the hidden state at position 0 ([CLS]).

    Each text is whitespace-normalized and cut to its first 300 words. It is then
    tokenized with special tokens, truncated to ``max_len`` token ids, and run
    through ``distilbert/distilbert-base-uncased`` on GPU when available.

    Args:
        string_list: list of strings to embed.
        batch_size: number of texts per forward pass.
        max_len: maximum number of token ids kept per text.

    Returns:
        numpy array with one row per input string. It is an empty
        ``(0, hidden_dim)`` array when ``string_list`` is empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
    model = transformers.DistilBertModel.from_pretrained("distilbert/distilbert-base-uncased")
    model.to(device)

    if len(string_list) == 0:
        # np.vstack([]) would raise; return an empty feature matrix instead.
        return np.empty((0, model.config.dim), dtype=np.float32)

    features = []
    n_batches = (len(string_list) + batch_size - 1) // batch_size  # exact ceil for the progress bar
    for batch in tqdm(chunks(string_list, batch_size), total=n_batches):
        token_ids = []
        for text in batch:
            text = " ".join(text.strip().split()[:300])
            token_ids.append(tokenizer.encode(text, add_special_tokens=True)[:max_len])

        # Pad only to the longest sequence in the batch instead of always to
        # max_len. Pad positions (id 0) are excluded via attention_mask, so the
        # [CLS] output is unchanged up to floating-point rounding, with far less work.
        width = max(len(ids) for ids in token_ids)
        padded = np.array([ids + [0] * (width - len(ids)) for ids in token_ids])
        attention_mask = np.where(padded != 0, 1, 0)
        input_ids = torch.tensor(padded).to(device)
        attention_mask = torch.tensor(attention_mask).to(device)

        with torch.no_grad():
            last_hidden_states = model(input_ids, attention_mask=attention_mask)

        features.append(last_hidden_states[0][:, 0, :].cpu().numpy())

    return np.vstack(features)

In [None]:
# Load the three pre-generated text pools (JSON Lines: one record per line).
gemma, prompt, dpg = (
    pd.read_json(path, lines=True)
    for path in (
        'texts/gemma-amazing.jsonl',
        'texts/prompt-amazing.jsonl',
        'texts/dpg-amazing.jsonl',
    )
)

In [None]:
# Pool the first N texts of each source into one list, in the order gemma, prompt, dpg.
# The plotting cell relies on this fixed block layout (offsets 0, N, 2N), so fail
# loudly if a source is too small instead of silently shifting the blocks.
N_PER_SOURCE = 3500
for source_name, frame in (("gemma", gemma), ("prompt", prompt), ("dpg", dpg)):
    if len(frame) < N_PER_SOURCE:
        raise ValueError(f"{source_name} has only {len(frame)} texts; need {N_PER_SOURCE}")
texts = (
    gemma['text'].tolist()[:N_PER_SOURCE]
    + prompt['text'].tolist()[:N_PER_SOURCE]
    + dpg['text'].tolist()[:N_PER_SOURCE]
)

out_vec = fetch_vectors(texts)

In [None]:
from sklearn.manifold import TSNE  # alternative: TSNE(n_components=2, perplexity=200).fit_transform(out_vec)
from sklearn.preprocessing import StandardScaler

# Standardize the DistilBERT features, then project them to 2-D with UMAP.
# A fixed random_state makes the layout, and hence the saved figure, reproducible
# (umap-learn then runs single-threaded).
UMAP_SEED = 42
reducer = umap.UMAP(random_state=UMAP_SEED)
out_vec2 = StandardScaler().fit_transform(out_vec)
out = reducer.fit_transform(out_vec2)

out.shape

In [None]:
length = 200        # points plotted per source
per_source = 3500   # block size used when the texts were pooled: gemma, prompt, dpg

# (legend name, colour, block index in the pooled array), in the original trace order.
trace_specs = [
    ("Gemma g", 'red', 0),
    ("DPG g'", 'lime', 2),
    ("prompted g'", 'blue', 1),
]

fig = go.Figure()
for trace_name, trace_color, block_idx in trace_specs:
    offset = block_idx * per_source
    fig.add_trace(go.Scatter(
        x=out[offset:offset + length, 0],
        y=out[offset:offset + length, 1],
        mode="markers",
        name=trace_name,
        marker=dict(color=trace_color),
    ))

fig.update_layout(
    margin=dict(l=0, r=0, t=0, b=0),
    legend=dict(orientation="h", font=dict(size=30)),
    height=500,
)
# write_image fails if the output directory is missing.
os.makedirs("fig", exist_ok=True)
fig.write_image("fig/lexical_umap.pdf")