### MDLM teacher on GSM8K: sampling and evaluation scratchpad

In [1]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import argparse
import os
import tqdm
import inspect
import logging

from models.teacher import Teacher
from models.configuration_teacher import TeacherConfig
from data import CoTDataset, CoTDataCollator, extract_answer

from utils import get_sep_position
from transformers import AutoModelForMaskedLM
from safetensors.torch import load_file

# Pick the compute device once: CUDA when it is available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Silence every log record at WARNING level and below (WARNING, INFO, DEBUG), process-wide.
logging.disable(logging.WARNING)

# Allow TF32 both in cuDNN and in CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

def load_pretrained_model(args):
    """Load the pretrained base model and record its settings on ``args``.

    Args:
        args: Mutable attribute container. ``args.base_model`` selects the
            checkpoint and must be ``"sedd"`` or ``"mdlm"``.

    Returns:
        A ``(model, args)`` tuple. ``args`` is the same object that was passed
        in, updated in place with ``num_vocabs``, ``length``,
        ``noise_schedule`` and ``graph``.

    Raises:
        ValueError: If ``args.base_model`` is not one of the supported names.
            (Previously this fell through to ``return`` with ``model`` unbound.)
    """
    if args.base_model == "sedd":
        # Imported only on this branch.
        from ddms.sedd import SEDD
        model = SEDD.from_pretrained("louaaron/sedd-small")

        # Copy the settings stored in the checkpoint's config onto args.
        args.num_vocabs = model.config.tokens
        args.length = model.config.model.length
        args.noise_schedule = model.config.noise.type
        args.graph = 'absorb'
    elif args.base_model == "mdlm":
        model = AutoModelForMaskedLM.from_pretrained("kuleshov-group/mdlm-owt", trust_remote_code=True)

        # vocab_size minus one -- presumably excludes the extra mask token; TODO confirm.
        args.num_vocabs = model.config.vocab_size - 1
        args.length = model.config.model_length
        args.noise_schedule = 'loglinear'
        args.graph = 'absorb'
    else:
        raise ValueError(
            f"Unsupported base_model {args.base_model!r}; expected 'sedd' or 'mdlm'."
        )

    return model, args

def load_diffusion_scheduler(args):
    """Build the sampling scheduler selected by ``args``.

    Args:
        args: Attribute container with ``base_model`` and, for the MDLM base
            model, ``scheduler_name`` (``"euler"`` or ``"maskgit"``). The whole
            object is passed to the scheduler constructor.

    Returns:
        An ``mdlm.EulerScheduler`` or ``mdlm.MaskGITScheduler`` instance.

    Raises:
        NotImplementedError: If ``args.base_model`` is ``"sedd"``; no scheduler
            is wired up for it here.
        ValueError: If ``args.base_model`` or ``args.scheduler_name`` is not a
            supported name.

    Previously all three failure cases reached ``return scheduler`` with the
    name unbound and surfaced as an UnboundLocalError.
    """
    if args.base_model == "sedd":
        raise NotImplementedError(
            "No diffusion scheduler is implemented for base_model='sedd'."
        )
    if args.base_model != "mdlm":
        raise ValueError(
            f"Unsupported base_model {args.base_model!r}; expected 'sedd' or 'mdlm'."
        )
    # Validate the name before importing the project module so that a typo
    # fails fast with a clear message.
    if args.scheduler_name not in ("euler", "maskgit"):
        raise ValueError(
            f"Unsupported scheduler_name {args.scheduler_name!r}; expected 'euler' or 'maskgit'."
        )

    from ddms import mdlm
    if args.scheduler_name == "euler":
        return mdlm.EulerScheduler(args)
    return mdlm.MaskGITScheduler(args)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Args:
    """Plain attribute container holding this notebook's settings."""

    # Which base model and which sampler to use.
    base_model = 'mdlm'
    scheduler_name = 'maskgit'

    # Fine-tuned teacher weights (safetensors file) and the test split to read.
    model_path = "../train_models/gsm8k/mdlm/teacher/checkpoint_0/model.safetensors"
    test_path = "../data/gsm8k/test.txt"

    # Loader batch size and number of sampling steps.
    batch_size = 1
    num_inf = 16

    # Not read by any cell visible in this notebook.
    max_new_tokens = 1024


args = Args()

In [3]:
# Numeric precision used for the model weights and for the autocast context.
dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Follow the selected device instead of hard-coding 'cuda', so the context
# matches the CPU fallback above.
ctx = torch.amp.autocast(device_type=device.type, dtype=ptdtype)
# torch.cuda.current_device() fails without CUDA, so only query it on a CUDA device.
cuda_index = torch.cuda.current_device() if device.type == 'cuda' else None
print(ptdtype, dtype, device, cuda_index)

# Load finetuned model: pretrained base model first, then the fine-tuned
# weights from the safetensors checkpoint on top of it.
teacher, args = load_pretrained_model(args)
teacher.load_state_dict(load_file(args.model_path))
scheduler = load_diffusion_scheduler(args)
teacher = teacher.to(device, ptdtype)

# Load data
# Imported here rather than in the top import cell; a later cell reuses this name.
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
collate_fn = CoTDataCollator(tokenizer)
test_dataset = CoTDataset(tokenizer, args.test_path, 1024)
# NOTE(review): shuffle=True with no seed set in the visible cells means the
# example inspected by the following cells changes from run to run.
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)

torch.float32 float32 cuda 0
Creating features from dataset file at ../data/gsm8k/test.txt
tgt_avg:  27.708870356330554
src_avg:  57.5352539802881
ratios:  2.076420050344752
tgt_avg:  6.0962850644427595
src_avg:  57.5352539802881
ratios:  9.437756497948017
 Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? <|endoftext|> <<16-3-4=9>> <<9*2=18>> <|endoftext|> #### 18 <|endoftext|>
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -1

In [4]:
# Bind the names used by the evaluation code. Only three of them actually
# change: the loader, the number of inference steps, and "no loss function".
dataloader = test_dataloader
num_inf = args.num_inf
loss_fn = None

teacher.eval()

# Running totals consumed by the evaluation code further down.
total_instances = 0
total_tokens = 0
total_correct = 0
total_correct_tokens = 0
total_loss = 0

# Pull a single batch for interactive inspection; the loop stops after the first one.
for batch in tqdm.tqdm(dataloader):
    input_ids_all = batch['input_ids_all'].to(device)
    labels = batch['labels_all'].to(device)
    break

# Remove answer part: everything from the separator position onwards is
# replaced by the scheduler's mask token.
sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id)
# Work on a copy. Plain assignment only aliases the tensor, so the in-place
# masking below would also overwrite `input_ids_all`, which the evaluation
# code reads afterwards to recover the target text.
input_ids = input_ids_all.clone()
# NOTE(review): one cut-off (the largest separator position in the batch) is
# applied to every row; with batch_size > 1, rows whose separator comes earlier
# keep tokens after their own separator -- confirm this is intended.
input_ids[:, sep_positions.max():] = scheduler.mask_idx
batch_size = input_ids.shape[0]

  0%|          | 0/1319 [00:00<?, ?it/s]


In [5]:
# Generate
# gen_output = scheduler.euler_sample(
#     teacher, xt=input_ids, 
#     t=1, s=1e-5, num_inference_steps=num_inf
# )

In [13]:
# Scratch version of a count-based (MaskGIT-style) sampling loop, run only up
# to the first model call so that a later cell can inspect `output`, `xt`, `dk`.
model=teacher
xt=input_ids
num_inference_steps=num_inf

# Number of positions per row that currently hold the mask token.
length = (xt == scheduler.mask_idx).sum(dim=1)

eps = 1e-3
# num_inference_steps + 1 time points going from 1 down to eps.
t = torch.linspace(1, eps, num_inference_steps + 1, device=xt.device)
# (1 - exp(-sigma_bar(t))) scaled by the masked-token count and truncated to
# integers -- presumably the number of tokens still masked at each time point;
# TODO confirm against scheduler.sigma_bar.
# NOTE(review): `t` has num_inference_steps + 1 entries and `length` has one
# entry per row, so this product only broadcasts cleanly for batch size 1.
k = (1 - (-scheduler.sigma_bar(t)).exp()) * length
k = k.long()
# Force the final count to exactly zero.
k[-1] = 0

for i in range(num_inference_steps):
    # Difference between consecutive counts for this step.
    dk = k[i] - k[i+1]
    # NOTE(review): sigma_bar is evaluated on the integer count k[i], not on
    # t[i]; the result is only used as the shape/dtype template for zeros_like.
    sigma_bar_t = scheduler.sigma_bar(k[None, i])
    output = model(xt, torch.zeros_like(sigma_bar_t))
    # Deliberate stop after the first forward pass; the two lines below are
    # never reached while this `break` is in place.
    break
    output = scheduler.step(output, xt, dk)
    xt = output.xt

  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=False):


In [29]:
# Passed as `k` to torch.topk further down in this cell. `dk` is left over from
# the sampling-loop cell (k[i] - k[i+1] at the iteration where that loop stopped).
step_size = dk

def sample_categorical(categorical_probs, eps=1e-6, generator=None):
    """Draw one index per distribution along the last dimension.

    Works on probabilities rather than logits: each probability is divided by
    a positive noise term ``eps - log(eps + (1 - eps) * u)`` with ``u`` uniform
    in [0, 1), and the arg-max of the ratio is returned (the original author
    describes this as the Gumbel-max trick applied to probabilities).

    Args:
        categorical_probs: Tensor whose last dimension indexes the categories.
        eps: Small constant keeping the logarithm and the division finite.
        generator: Optional ``torch.Generator``; when given, the uniform draw
            is made on the generator's device and then moved to match
            ``categorical_probs``.

    Returns:
        ``(indices, noise)``: the sampled indices (last dimension reduced) and
        the noise tensor, which has the same shape as ``categorical_probs``.
    """
    if generator is not None:
        uniform = torch.rand(
            categorical_probs.shape, generator=generator, device=generator.device
        ).to(categorical_probs)
    else:
        uniform = torch.rand_like(categorical_probs)

    noise = eps - torch.log(eps + (1 - eps) * uniform)
    sampled = torch.argmax(categorical_probs / noise, dim=-1)
    return sampled, noise

# Unfinished scratch for one confidence-based unmasking step. It reuses
# `output`, `xt` and `step_size` from earlier cells and is halted on purpose by
# `assert False` (the recorded output of this cell is an AssertionError).

# generate x0 ~ p_x0
logits = scheduler.output_to_logits(output, xt)
# Exponentiating treats `logits` as log-probabilities -- TODO confirm.
p_x0 = logits.exp()
x0, noise = sample_categorical(p_x0)

# mask x0 w.r.t confidence
# Probability assigned to the sampled token at each position; the trailing
# singleton dimension from the gather index is kept.
conf = torch.gather(p_x0, -1, x0[..., None])
# NOTE(review): this compares the freshly sampled tokens `x0` with the mask id.
# If the intent is to exclude positions that are already unmasked, the
# comparison would be on `xt` instead -- confirm.
conf[x0 != scheduler.mask_idx] = -torch.inf
# Largest `step_size` confidences along the sequence dimension.
conf_v, _ = torch.topk(conf, step_size, dim=1)
# Intentional stop; nothing below this line has been executed.
assert False
# NOTE(review): a raw difference of confidences is used as a mask here, not a
# comparison against a threshold -- unfinished.
mask = (conf - conf_v[None, None, :]).to(xt.dtype)
xs = mask * xt + (1 - mask) * x0

AssertionError: 

In [19]:
# Inspect the shape of the confidence tensor built in the unmasking scratch cell.
conf.shape

torch.Size([1, 153, 1])

In [64]:
# Step-by-step Euler sampling, starting from the masked prompt in `input_ids`.
model = teacher
xt = input_ids
# Per-row start time `t` and end time `s` of the reverse process.
t = 1. * torch.ones(batch_size, device=xt.device)
s = 1e-5 * torch.ones(batch_size, device=xt.device)
num_inference_steps = num_inf

# Base grid from 1 down to scheduler.eps, then mapped per row with
# (t - s) * grid + s, giving a (batch, num_inference_steps + 1) table.
timesteps = torch.linspace(1, scheduler.eps, num_inference_steps+1, device=xt.device)
timesteps = (t[:, None] - s[:, None]) * timesteps[None, :] + s[:, None]

# Pure inference: run without autograd so no graph is built for the forward
# passes.
with torch.no_grad():
    for i in range(num_inference_steps):
        curr_t = timesteps[:, i]
        dt = curr_t - timesteps[:, i+1]

        # sigma_bar_t only provides the shape/dtype of the all-zero
        # conditioning tensor handed to the model.
        sigma_bar_t = scheduler.sigma_bar(curr_t)
        output = model(xt, torch.zeros_like(sigma_bar_t))
        output = scheduler.step(output, xt, curr_t, dt)
        xt = output.xt

  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=False):


In [65]:
# Decode the first row of the sampled sequence `xt` back to text.
tokenizer.decode(xt[0].tolist())

' Mrs. Tatiana owns a grocery store that sells different fruits and vegetables, which includes carrots. The price of carrots in the grocery store increases by 5% of the original price every year. What would be the price of carrots after three years if it was $120 initially? (Round to the nearest integer)  <<106 #### 633 587=26.26*=6=46>>126+46= 8.26 <|endoftext|> # 26 <|endoftext|>'

In [66]:
# Decode the first row of the masked prompt `input_ids`, for comparison with the sample.
tokenizer.decode(input_ids[0].tolist())

' Mrs. Tatiana owns a grocery store that sells different fruits and vegetables, which includes carrots. The price of carrots in the grocery store increases by 5% of the original price every year. What would be the price of carrots after three years if it was $120 initially? (Round to the nearest integer) '

In [None]:

# gen_output = scheduler.generate(
#     input_ids=input_ids,
#     num_inf=num_inf,
# )
# Evaluate
#import pdb; pdb.set_trace()
# Per-example scoring of generated sequences against the reference answers.
# NOTE(review): this cell looks like the tail of a function pasted into a cell:
#   * `gen_output` is not assigned by any executed cell visible in this notebook
#     (the generation calls are commented out);
#   * the trailing `return` is at the top level of the cell, which Python rejects;
#   * `total_instances` / `total_tokens` start at 0 and no visible executed cell
#     increments them, so the divisions below would divide by zero.
for i, (input_ids_all_i, gen_output_i) in enumerate(zip(input_ids_all, gen_output)):
    sep_position = sep_positions[i].item()
    # Reference text: everything after this row's separator.
    # NOTE(review): check that `input_ids_all` still holds the unmasked sequence
    # here; if `input_ids` is an alias of it, the in-place masking done while
    # preparing the batch also overwrites the target.
    tgt = input_ids_all_i[sep_position+1:]
    tgt_text = tokenizer.decode(tgt, skip_special_tokens=True)
    ans = extract_answer(tgt_text)
    # NOTE(review): `gen_output_i[0]` assumes each generated item is itself
    # indexable by sequence -- confirm against the sampler's return value.
    pred_text = tokenizer.decode(gen_output_i[0][sep_position+1:], skip_special_tokens=True)
    pred_ans = extract_answer(pred_text)
    if ans == pred_ans:
        total_correct += 1
    # Show the first example of the batch.
    if i == 0:
        print(f'Input: {tokenizer.decode(input_ids_all_i[:sep_position], skip_special_tokens=True)}')
        print(f'Target: {tgt_text}')
        print(f'Predicted: {pred_text}')
        print('')
# Aggregate metrics: exact-match accuracy, token accuracy, and perplexity as
# exp of the per-token loss.
accuracy = total_correct / total_instances
token_accuracy = total_correct_tokens / total_tokens
loss = total_loss / total_tokens
ppl = math.exp(loss)
return accuracy, token_accuracy, ppl


In [None]:

    # NOTE(review): this cell is an indented loop body with no loop or function
    # header in the cell, followed by a top-level `return`; it cannot run as
    # written. It appears to be the per-batch part of an evaluation routine.
    input_ids_all = batch['input_ids_all'].to(device)
    labels = batch['labels_all'].to(device)
    # Remove answer part
    # NOTE(review): the slice below overwrites positions 0 .. max separator
    # (inclusive), i.e. the leading part of each row, whereas the earlier
    # batch-preparation cell masks from the separator onwards. One of the two
    # contradicts the "remove answer part" intent -- confirm which is meant.
    sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id)
    # NOTE(review): plain assignment aliases the tensor, so the in-place write
    # below also changes `input_ids_all`, which is passed to `loss_fn` and used
    # for the target text further down.
    input_ids = input_ids_all
    input_ids[:, :sep_positions.max()+1] = scheduler.mask_idx
    batch_size = input_ids.shape[0]
    # Loss statistics are only accumulated when a loss function is supplied.
    # NOTE(review): `total_instances` is incremented only inside this branch, so
    # with `loss_fn` unset the accuracy division at the end divides by zero.
    if loss_fn:
        with ctx:
            outputs = loss_fn(input_ids=input_ids_all, labels=labels)
        total_loss += outputs.total_loss.item()
        total_correct_tokens += outputs.total_correct.item()
        total_tokens += outputs.total_tokens
        total_instances += batch_size

    # Generate
    gen_output = scheduler.euler_sample(
        teacher, xt=input_ids, 
        t=1, s=1e-5, num_inference_steps=num_inf
    )
    # gen_output = scheduler.generate(
    #     input_ids=input_ids,
    #     num_inf=num_inf,
    # )
    # Evaluate
    #import pdb; pdb.set_trace()
    for i, (input_ids_all_i, gen_output_i) in enumerate(zip(input_ids_all, gen_output)):
        sep_position = sep_positions[i].item()
        # Reference text: everything after this row's separator.
        tgt = input_ids_all_i[sep_position+1:]
        tgt_text = tokenizer.decode(tgt, skip_special_tokens=True)
        ans = extract_answer(tgt_text)
        # NOTE(review): `gen_output_i[0]` assumes each generated item is itself
        # indexable by sequence -- confirm against euler_sample's return value.
        pred_text = tokenizer.decode(gen_output_i[0][sep_position+1:], skip_special_tokens=True)
        pred_ans = extract_answer(pred_text)
        if ans == pred_ans:
            total_correct += 1
        # Show the first example of each batch.
        if i == 0:
            print(f'Input: {tokenizer.decode(input_ids_all_i[:sep_position], skip_special_tokens=True)}')
            print(f'Target: {tgt_text}')
            print(f'Predicted: {pred_text}')
            print('')
# Aggregate metrics: exact-match accuracy, token accuracy, and perplexity as
# exp of the per-token loss.
accuracy = total_correct / total_instances
token_accuracy = total_correct_tokens / total_tokens
loss = total_loss / total_tokens
ppl = math.exp(loss)
# NOTE(review): `return` at the top level of a cell is a syntax error.
return accuracy, token_accuracy, ppl


In [None]:
# Same setup as the test-time cell, but with the plain pretrained model (no
# fine-tuned checkpoint is loaded here) and the train / validation splits.
dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Follow the selected device instead of hard-coding 'cuda', so the context
# matches the CPU fallback above.
ctx = torch.amp.autocast(device_type=device.type, dtype=ptdtype)
print(ptdtype, dtype, device)

# Create Teacher
teacher, args = load_pretrained_model(args)
scheduler = load_diffusion_scheduler(args)

# Load data
# Imported locally so this cell does not depend on an earlier cell having run.
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
collate_fn = CoTDataCollator(tokenizer)
# NOTE(review): the `Args` class defined in this notebook has no `train_path` /
# `val_path` attributes; they must be set on `args` before this cell runs.
train_dataset = CoTDataset(tokenizer, args.train_path, 1024)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataset = CoTDataset(tokenizer, args.val_path, 1024)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

In [52]:
# Fetch only the first training batch and move its full-sequence ids and labels
# to the compute device; the loop exits after one iteration.
for batch in tqdm.tqdm(train_dataloader):
    input_ids_all, labels = (
        batch[field].to(device) for field in ('input_ids_all', 'labels_all')
    )
    break

  0%|          | 0/12020 [00:00<?, ?it/s]


In [59]:
# Scratch check of boolean-mask assignment: copy into `xt` only the entries of
# `x0` that exceed 0.5 and leave the rest at zero.
x0 = torch.rand(4, 4)
cond = x0.gt(0.5)
xt = torch.zeros_like(x0)
xt[cond] = x0[cond]

In [65]:
# Print the second prompt of the batch as text, then what token id 220 decodes
# to on its own (the recorded output suggests whitespace).
print(tokenizer.decode(batch['input_ids_only'][1].tolist()))
print(tokenizer.decode([220]))

 A lion needs to gain 500 pounds for the winter. In the summer, it feasts on zebras and during autumn, it hunts gazelles and buffalos. It gained half its weight from zebras during summer and during autumn, it gained a quarter of that amount from gazelles. Buffalos made up the rest of its diet. How many pounds did it gain eating buffalos? <|endoftext|> 
 


In [70]:
# Sanity check: torch.rand draws from [0, 1), so `x > 1` selects nothing and the
# mean over the empty selection is NaN (see the recorded output).
x = torch.rand(3,4,5,6)
x[x > 1].mean()

tensor(nan)

In [54]:
# Display the raw token ids of the prompt-only field of the current batch.
batch['input_ids_only']

tensor([[ 1002, 27775, 13267,  ..., 50256, 50256, 50256],
        [  317, 18744,  2476,  ...,   220, 50256,   220],
        [ 8114,   468,   642,  ..., 50256, 50256, 50256],
        ...,
        [ 3362,  6593, 19132,  ..., 50256, 50256, 50256],
        [25737,  6134,  3126,  ..., 50256, 50256, 50256],
        [ 1629,   257,  3807,  ..., 50256, 50256, 50256]])

In [53]:
# Decode the first two rows of the prompt-only field; in the recorded output
# row 0 ends in a long run of end-of-text tokens.
print(tokenizer.decode(batch['input_ids_only'][0].tolist()))
print(tokenizer.decode(batch['input_ids_only'][1].tolist()))

 If Kanye decides to jog around the park for 3 hours and a bottle of water is 500ml and costs $0.5. He drinks 1 bottle after each hour to stay hydrated, how much does he spend on water in total? <|endoftext|> <|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
 A lion needs to gain 500 pounds for the winter. In the summer, it feasts on zebras and during autumn, it hunts gazelles and buffalos. It gained half its weight from zebras during summer and during autumn, it gained a quarter of that amount from gazelles. Buffalos made up the rest of its diet. How many pounds did it gain

In [None]:
# Compare the tensor shapes of the four id fields carried by a batch.
print(batch['input_ids_only'].shape)
print(batch['input_ids_cot'].shape)
print(batch['input_ids_nocot'].shape)
print(batch['input_ids_all'].shape)

In [None]:
# Shape of the prompt-only ids (also printed by the shape-comparison cell).
batch['input_ids_only'].shape

In [None]:
# Shape of the full-sequence ids (also printed by the shape-comparison cell).
batch['input_ids_all'].shape

In [None]:
# Fetch only the first validation batch and move its full-sequence ids and
# labels to the compute device; the loop exits after one iteration.
for batch in tqdm.tqdm(val_dataloader):
    input_ids_all, labels = (
        batch[field].to(device) for field in ('input_ids_all', 'labels_all')
    )
    break