In [None]:
import os

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR, SequentialLR
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm, trange
import wandb

from gsm_dataset import GSMDataset, gsm_collate, gsm_prompt, sample
from biscuit import Biscuit

In [None]:
biscuit_model = Biscuit()

num_epochs = 10
# Number of optimizer steps of linear warmup. Must be an int: it is also the
# SequentialLR milestone at which the cosine schedule takes over.
warmup_steps = 100
learning_rate = 5e-5

# Only the latent trunk and the two boundary embeddings are optimized in this run.
# NOTE(review): bot_embedding / eot_embedding are passed directly as params, so they are
# presumably nn.Parameter tensors on Biscuit — confirm.
latent_optimizer = optim.AdamW(
    [*biscuit_model.latent_trunk.parameters(), biscuit_model.bot_embedding, biscuit_model.eot_embedding],
    lr=learning_rate,
    weight_decay=0.01,
)

# LambdaLR sets lr = base_lr * lr_lambda(step), so the lambda must return a unitless ramp
# factor in (0, 1]. Returning `learning_rate * ramp` here would multiply the learning rate
# by itself (~2.5e-9 during warmup).
latent_warmup_scheduler = LambdaLR(
    latent_optimizer,
    lr_lambda=lambda step: (step + 1) / (warmup_steps + 1),
)
# After warmup: cosine annealing from base lr down to eta_min, restarting every 50 steps.
latent_cosine_scheduler = CosineAnnealingWarmRestarts(latent_optimizer, T_0=50, T_mult=1, eta_min=1e-5)
# Stepped once per optimizer update; switches from warmup to cosine after `warmup_steps` steps.
latent_combined_scheduler = SequentialLR(
    latent_optimizer,
    schedulers=[latent_warmup_scheduler, latent_cosine_scheduler],
    milestones=[warmup_steps],
)

In [None]:
dataset = GSMDataset()

# Split: 90% train, 2% few-shot example pool, remainder held-out test.
train_size = int(0.9 * len(dataset))
example_size = int(0.02 * len(dataset))  # reserve some data for few shot prompting
test_size = len(dataset) - train_size - example_size

# Seed the split so the held-out test set is identical across runs and checkpoints stay comparable.
split_seed = 0
train_dataset, example_dataset, test_dataset = random_split(
    dataset,
    [train_size, example_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)
train_loader = DataLoader(train_dataset, batch_size=6, shuffle=True, collate_fn=gsm_collate)
# Evaluation order does not change the loss; keep it deterministic.
test_loader = DataLoader(test_dataset, batch_size=6, shuffle=False, collate_fn=gsm_collate)

COT_MAX_LENGTH = 6
NUM_FEW_SHOT_EXAMPLES = 4  # examples drawn from example_dataset for each prompt
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # torch.save does not create missing directories

run = wandb.init(project="biscuit", name="two-trunk")
test_step = 0
try:
    for epoch in trange(num_epochs, desc='Epoch'):
        # --- train ---
        for segments, keep_indices_lst in tqdm(train_loader, desc="Batch"):
            examples = sample(example_dataset, num_samples=NUM_FEW_SHOT_EXAMPLES)
            latent_optimizer.zero_grad()
            latent_loss = biscuit_model.compute_batch(gsm_prompt(examples), segments, keep_indices_lst, token_batch=False)
            latent_loss.backward()
            # Record the lr actually used for this update (before the scheduler advances it).
            current_lr = latent_optimizer.param_groups[0]["lr"]
            latent_optimizer.step()
            latent_combined_scheduler.step()
            wandb.log({"latent_loss": latent_loss.item(), "latent_lr": current_lr})

        # --- compute test loss ---
        # Switch the trained trunk to eval mode for evaluation (disables dropout etc.), then
        # restore whatever mode it was in so training behaviour is unchanged.
        was_training = biscuit_model.latent_trunk.training
        biscuit_model.latent_trunk.eval()
        test_losses = []
        try:
            with torch.no_grad():
                for segments, keep_indices_lst in tqdm(test_loader, desc="Test Batch"):
                    test_step += 1
                    examples = sample(example_dataset, num_samples=NUM_FEW_SHOT_EXAMPLES)
                    latent_loss = biscuit_model.compute_batch(gsm_prompt(examples), segments, keep_indices_lst, token_batch=False)
                    test_losses.append(latent_loss.item())
                    wandb.log({"test_step": test_step, "test_latent_loss": test_losses[-1]})
        finally:
            biscuit_model.latent_trunk.train(was_training)
        if test_losses:
            # Per-batch test losses are noisy; the epoch mean is the comparable metric.
            wandb.log({"epoch": epoch, "test_latent_loss_epoch_mean": sum(test_losses) / len(test_losses)})

        # --- save checkpoint ---
        artifact = wandb.Artifact('checkpoint', type='model')
        latent_trunk_checkpoint_path = os.path.join(CHECKPOINT_DIR, f'latent_trunk_epoch_{epoch}.pth')
        torch.save(biscuit_model.latent_trunk.state_dict(), latent_trunk_checkpoint_path)
        artifact.add_file(latent_trunk_checkpoint_path)
        wandb.log_artifact(artifact)
finally:
    # Close the wandb run even if training raises or is interrupted from the notebook.
    run.finish()