### MDLM

In [None]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

# See the `MDLM` collection page on the hub for list of available models.
model_name = 'kuleshov-group/mdlm-owt'
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)

In [None]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AdamW, GPT2TokenizerFast
import argparse
import os
import tqdm
import inspect
import logging

from models.teacher import Teacher
from models.configuration_teacher import TeacherConfig
from data import CoTDataset, CoTDataCollator, extract_answer

from utils import get_sep_position
from transformers import AutoModelForMaskedLM

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class Args:
    train_path = '../data/gsm8k/train.txt'
    val_path = '../data/gsm8k/valid.txt'
    save_model = 'train_models/gsm8k/mdlm/teacher'
    max_new_tokens = 128
    base_model = 'mdlm'
    epochs = 1
    batch_size = 32
    lr = 5e-5
    max_grad_norm = 1.0

args = Args()

In [None]:
def load_pretrained_model(args):
    if args.base_model == "sedd":
        # load model
        from ddms.sedd import SEDD
        model = SEDD.from_pretrained("louaaron/sedd-small")

        # load config
        args.num_vocabs = model.config.tokens
        args.length = model.config.model.length
        args.noise_schedule = model.config.noise.type
        args.graph = 'absorb'
    
    if args.base_model == "mdlm":
        model = AutoModelForMaskedLM.from_pretrained("kuleshov-group/mdlm-owt", trust_remote_code=True)
        
        # load config
        args.num_vocabs = model.config.vocab_size
        args.length = model.config.model_length
        args.noise_schedule = 'loglinear'
        args.graph = 'absorb'
    
    return model, args

def load_diffusion_scheduler(args):
    if args.base_model == "sedd":
        from ddms import sedd
        scheduler = sedd.EulerScheduler(args)
    if args.base_model == "mdlm":
        from ddms import mdlm
        scheduler = mdlm.EulerScheduler(args)
    return scheduler

In [None]:
dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype)
print (ptdtype, dtype, device)

# Create Teacher 
teacher, args = load_pretrained_model(args)
scheduler = load_diffusion_scheduler(args)

# Load data
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
collate_fn = CoTDataCollator(tokenizer)
train_dataset = CoTDataset(tokenizer, args.train_path, 1024)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataset = CoTDataset(tokenizer, args.val_path, 1024)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

In [52]:
for batch in tqdm.tqdm(train_dataloader):
    input_ids_all = batch['input_ids_all'].to(device)
    labels = batch['labels_all'].to(device)
    break

  0%|          | 0/12020 [00:00<?, ?it/s]


In [59]:
x0 = torch.rand(4,4)
xt = torch.zeros(4,4)
cond = x0 > 0.5
xt[cond] = x0[cond]

In [65]:
print(tokenizer.decode(batch['input_ids_only'][1].tolist()))
print(tokenizer.decode([220]))

 A lion needs to gain 500 pounds for the winter. In the summer, it feasts on zebras and during autumn, it hunts gazelles and buffalos. It gained half its weight from zebras during summer and during autumn, it gained a quarter of that amount from gazelles. Buffalos made up the rest of its diet. How many pounds did it gain eating buffalos? <|endoftext|> 
 


In [70]:
x = torch.rand(3,4,5,6)
x[x > 1].mean()

tensor(nan)

In [54]:
batch['input_ids_only']

tensor([[ 1002, 27775, 13267,  ..., 50256, 50256, 50256],
        [  317, 18744,  2476,  ...,   220, 50256,   220],
        [ 8114,   468,   642,  ..., 50256, 50256, 50256],
        ...,
        [ 3362,  6593, 19132,  ..., 50256, 50256, 50256],
        [25737,  6134,  3126,  ..., 50256, 50256, 50256],
        [ 1629,   257,  3807,  ..., 50256, 50256, 50256]])

In [53]:
print(tokenizer.decode(batch['input_ids_only'][0].tolist()))
print(tokenizer.decode(batch['input_ids_only'][1].tolist()))

 If Kanye decides to jog around the park for 3 hours and a bottle of water is 500ml and costs $0.5. He drinks 1 bottle after each hour to stay hydrated, how much does he spend on water in total? <|endoftext|> <|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
 A lion needs to gain 500 pounds for the winter. In the summer, it feasts on zebras and during autumn, it hunts gazelles and buffalos. It gained half its weight from zebras during summer and during autumn, it gained a quarter of that amount from gazelles. Buffalos made up the rest of its diet. How many pounds did it gain

In [None]:
print(batch['input_ids_only'].shape)
print(batch['input_ids_cot'].shape)
print(batch['input_ids_nocot'].shape)
print(batch['input_ids_all'].shape)

In [None]:
batch['input_ids_only'].shape

In [None]:
batch['input_ids_all'].shape

In [None]:
for batch in tqdm.tqdm(val_dataloader):
    input_ids_all = batch['input_ids_all'].to(device)
    labels = batch['labels_all'].to(device)
    break