### MDLM

In [None]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AdamW, GPT2TokenizerFast
import argparse
import os
import tqdm
import inspect
import logging

from models.teacher import Teacher
from models.configuration_teacher import TeacherConfig
from data import CoTDataset, CoTDataCollator, extract_answer

from utils import get_sep_position
from transformers import AutoModelForMaskedLM

# Silence WARNING, INFO and DEBUG records from every logger.
logging.disable(logging.WARNING)

# Permit TF32 math in CUDA matmuls and in cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Run on CUDA when it is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class Args:
    """Notebook stand-in for parsed command-line options of the MDLM/GSM8K run."""

    # Data files
    train_path = 'data/gsm8k/train.txt'
    val_path = 'data/gsm8k/valid.txt'

    # Where the trained model is written
    save_model = 'train_models/gsm8k/mdlm/teacher'

    # Model selection and generation budget
    base_model = 'sedd'
    max_new_tokens = 128

    # Optimisation settings
    epochs = 1
    batch_size = 32
    lr = 5e-5
    max_grad_norm = 1.0


args = Args()

### Evaluation

In [None]:
import math
import time
import re
import torch
import sys
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW
import argparse
import os
import inspect
import tqdm
import logging
import random
import torch.nn as nn

from src.data import CoTDataset, CoTDataCollator, extract_answer
from src.models.emulator import Emulator
from src.models.student import Student
from src.utils import get_sep_position

# Silence WARNING, INFO and DEBUG records from every logger.
logging.disable(logging.WARNING)

# Fix the seeds of Python's `random` module and of PyTorch.
random.seed(1234)
torch.manual_seed(1234)

# Permit TF32 math in cuDNN and in CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Run on CUDA when it is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
import contextlib

# Notebook stand-in for the command-line flags that main() parses
# (--test_path, --batch_size, --max_new_tokens, --student_path, --emulator_path).
bsz = 1
student_path = "models/4_by_4_mult/gpt2/student"
emulator_path = "models/4_by_4_mult/gpt2/emulator"

class Args:
    test_path = "data/4_by_4_mult/test_bigbench.txt"
    batch_size = bsz
    max_new_tokens = 128
    student_path = student_path
    emulator_path = emulator_path

args = Args()
# The settings live on the class, so print them; printing `args` itself only
# shows an object address.
print({name: value for name, value in vars(Args).items() if not name.startswith('_')})

dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Autocast must target the device actually in use. On CPU, skip it instead of
# requesting CUDA autocast (which PyTorch disables, with a warning, when CUDA
# is unavailable).
if device.type == 'cpu':
    ctx = contextlib.nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device.type, dtype=ptdtype)
print(ptdtype, dtype, device)

# Load models
emulator = Emulator.from_pretrained(args.emulator_path).to(device).to(ptdtype)
student = Student.from_pretrained(args.student_path).to(device).to(ptdtype)
emulator.eval()
student.eval()

# Load data
tokenizer = emulator.tokenizer
collate_fn = CoTDataCollator(tokenizer)
test_dataset = CoTDataset(tokenizer, args.test_path, 1024)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

# NOTE: `evaluate` is defined in the next cell. Run that cell before this one,
# otherwise the call below raises NameError.
accuracy, throughput = evaluate(test_dataloader, tokenizer, ctx, emulator, student, args.max_new_tokens)
print(f"Test Accuracy: {accuracy}. Throughput: {throughput}")


In [None]:
@torch.no_grad()
def evaluate(dataloader, tokenizer, ctx, emulator, student, max_new_tokens):
    """Measure answer accuracy and generation throughput of ``student``.

    For each batch, ``batch['input_ids_nocot']`` is cut right after the
    separator, which is located with ``get_sep_position`` using the tokenizer's
    EOS token id. The emulator then produces teacher states from that prefix,
    and the student generates a continuation conditioned on them. The text
    after each row's separator is decoded for both the target and the first
    generated sequence. Answers pulled out by ``extract_answer`` are compared
    for exact equality.

    Args:
        dataloader: yields batches containing ``'input_ids_nocot'`` tensors.
        tokenizer: provides ``eos_token_id`` and ``decode``.
        ctx: context manager wrapped around the model calls (e.g. autocast).
        emulator: called on the question ids to obtain emulated teacher states.
        student: model whose ``generate`` consumes ``teacher_states``.
        max_new_tokens: generation budget passed to ``student.generate``.

    Returns:
        ``(accuracy, throughput)``: the fraction of exactly matching answers,
        and examples per second, with the timed span covering generation plus
        decoding/scoring.

    Raises:
        ValueError: if ``dataloader`` yields no examples.
    """
    total_time = 0.0
    total_instances = 0
    total_correct = 0

    for batch in tqdm.tqdm(dataloader):
        input_ids_all = batch['input_ids_nocot'].to(device)
        # Remove the answer part. The whole batch is cut at its latest separator,
        # so rows whose separator comes earlier keep some tokens after it. The
        # cut is exact only when separators align (e.g. batch_size=1).
        sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id)
        input_ids = input_ids_all[:, :sep_positions.max() + 1]

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        with ctx:
            emulated_teacher_states = emulator(input_ids)

            # Generate from the student, conditioned on the emulated states.
            beam_output = student.generate(
                input_ids=input_ids,
                teacher_states=emulated_teacher_states,
                max_new_tokens=max_new_tokens,
            )

        for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)):
            sep_position = sep_positions[i].item()
            tgt = input_ids_all_i[sep_position + 1:]
            tgt_text = tokenizer.decode(tgt, skip_special_tokens=True)
            ans = extract_answer(tgt_text)
            # Score the first sequence returned for this example.
            pred_text = tokenizer.decode(beam_output_i[0][sep_position + 1:], skip_special_tokens=True)
            pred_ans = extract_answer(pred_text)
            total_instances += 1
            if ans == pred_ans:
                total_correct += 1
        total_time += time.perf_counter() - start_time

    if total_instances == 0:
        raise ValueError("evaluate() received an empty dataloader; nothing to score")
    throughput = total_instances / total_time
    accuracy = total_correct / total_instances
    return accuracy, throughput


def main():
    """Command-line entry point: evaluate a student/emulator pair on a test file.

    Parses ``--test_path``, ``--batch_size``, ``--max_new_tokens``,
    ``--student_path`` and ``--emulator_path``. Loads both models in float32,
    runs ``evaluate`` on the test set, and prints accuracy and throughput.
    """
    import contextlib

    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_new_tokens', type=int, default=128)
    parser.add_argument('--student_path', type=str, required=True)
    parser.add_argument('--emulator_path', type=str, required=True)
    args = parser.parse_args()

    print(args)

    dtype = 'float32'
    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Autocast must target the device the models run on. On CPU, skip it rather
    # than requesting CUDA autocast (which PyTorch disables, with a warning,
    # when CUDA is unavailable).
    if device.type == 'cpu':
        ctx = contextlib.nullcontext()
    else:
        ctx = torch.amp.autocast(device_type=device.type, dtype=ptdtype)
    print(ptdtype, dtype, device)

    # Load models
    emulator = Emulator.from_pretrained(args.emulator_path).to(device).to(ptdtype)
    student = Student.from_pretrained(args.student_path).to(device).to(ptdtype)
    emulator.eval()
    student.eval()

    # Load data
    tokenizer = emulator.tokenizer
    collate_fn = CoTDataCollator(tokenizer)
    test_dataset = CoTDataset(tokenizer, args.test_path, 1024)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

    accuracy, throughput = evaluate(test_dataloader, tokenizer, ctx, emulator, student, args.max_new_tokens)
    print(f"Test Accuracy: {accuracy}. Throughput: {throughput}")


### Dataset generation (N-digit by N-digit multiplication)

In [None]:
import random
def generate_n_by_n(order):
    """Generate one random ``order``-digit by ``order``-digit multiplication example.

    The result has the form ``"<question>||<chain of thought> #### <answer>"``.
    Every number is written least-significant digit first, with its digits
    separated by single spaces:

    * question: ``"<multiplicand> * <multiplier>"``;
    * chain of thought: one partial product per multiplier digit, joined by
      ``" + "``. Every partial product except the first and the last is
      followed by the running sum in parentheses. A single-digit multiplier
      gives a single term;
    * answer: the product, zero-padded to ``2 * order`` digits.

    Operands are drawn with the module-level ``random`` generator.

    Args:
        order: number of digits in each operand; must be >= 1.

    Returns:
        The formatted example, without a trailing newline.

    Raises:
        ValueError: if ``order`` is smaller than 1.
    """
    if order < 1:
        raise ValueError(f"order must be >= 1, got {order}")

    # randint is inclusive on both ends, so the upper bound must be
    # 10**order - 1; otherwise 10**order (order + 1 digits) can be drawn.
    low, high = 10 ** (order - 1), 10 ** order - 1
    multiplicand = random.randint(low, high)
    multiplier = random.randint(low, high)

    partial_products = []  # reversed, zero-padded digit strings
    cumulative_sums = []   # reversed, zero-padded running totals
    cumulative_sum = 0
    # Walk the multiplier's digits from least to most significant.
    for i, digit in enumerate(int(d) for d in reversed(str(multiplier))):
        partial_product = multiplicand * digit * 10 ** i
        cumulative_sum += partial_product
        # Reverse and right-pad with zeros to order + 1 + i digits, the widest
        # either value can be at step i.
        width = order + 1 + i
        partial_products.append(str(partial_product)[::-1].ljust(width, '0'))
        cumulative_sums.append(str(cumulative_sum)[::-1].ljust(width, '0'))

    def spaced(digits):
        # "123" -> "1 2 3"
        return ' '.join(digits)

    # The last running sum is the answer itself, so it is not repeated in
    # parentheses. With a single multiplier digit, the first and last partial
    # products are the same term and must not be emitted twice.
    cot_terms = [spaced(partial_products[0])]
    for pp, cs in zip(partial_products[1:-1], cumulative_sums[1:-1]):
        cot_terms.append(f"{spaced(pp)} ( {spaced(cs)} )")
    if len(partial_products) > 1:
        cot_terms.append(spaced(partial_products[-1]))
    cot_str = ' + '.join(cot_terms)

    q_str = f"{spaced(str(multiplicand)[::-1])} * {spaced(str(multiplier)[::-1])}"
    a_str = spaced(cumulative_sums[-1])
    return f"{q_str}||{cot_str} #### {a_str}"

In [None]:
import os

# Generate 808k training, 1k validation and 1k test examples for each order.
num_train = 808_000
num_valid = 1_000
num_test = 1_000

for order in [7, 8, 9, 10]:
    data_dir = f'data/{order}_by_{order}_mult'

    # Output file stem -> number of examples to generate, in generation order.
    split_sizes = {'train': num_train, 'valid': num_valid, 'test_bigbench': num_test}

    # Questions (text before '||') already used by any split. The chain of
    # thought and answer are computed from the question alone, so deduplicating
    # on it keeps every split duplicate-free and the splits disjoint.
    seen_questions = set()
    # Lists rather than sets, so the files follow generation order. The output
    # is then reproducible for a fixed `random` seed, independent of
    # string-hash randomization.
    splits = {}
    for split, size in split_sizes.items():
        examples = []
        while len(examples) < size:
            example = generate_n_by_n(order)
            question = example.split('||', 1)[0]
            if question not in seen_questions:
                seen_questions.add(question)
                examples.append(example)
        splits[split] = examples

    os.makedirs(data_dir, exist_ok=True)

    # Writes train.txt, valid.txt and test_bigbench.txt, one example per line.
    for split, examples in splits.items():
        with open(f'{data_dir}/{split}.txt', 'w') as f:
            f.writelines(example + '\n' for example in examples)