In [1]:
from tqdm import tqdm, trange
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR, SequentialLR
from torch.utils.data import DataLoader, random_split
import wandb


from gsm_dataset import GSMDataset, gsm_collate, gsm_prompt, sample
from biscuit import Biscuit

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
biscuit_model = Biscuit()

num_epochs = 10
warmup_steps = 1e2
learning_rate = 5e-5

token_optimizer = optim.AdamW(biscuit_model.token_trunk.parameters(), lr=learning_rate, weight_decay=0.01)
token_warmup_scheduler = LambdaLR(token_optimizer, lr_lambda=lambda s: learning_rate * (s + 1) / (warmup_steps + 1))
token_cosine_scheduler = CosineAnnealingWarmRestarts(token_optimizer, T_0=50, T_mult=1, eta_min=1e-5)
token_combined_scheduler = SequentialLR(token_optimizer, schedulers=[token_warmup_scheduler, token_cosine_scheduler],
                                        milestones=[warmup_steps])

latent_optimizer = optim.AdamW(biscuit_model.latent_trunk.parameters(), lr=learning_rate, weight_decay=0.01)
latent_warmup_scheduler = LambdaLR(latent_optimizer, lr_lambda=lambda s: learning_rate * (s + 1) / (warmup_steps + 1))
latent_cosine_scheduler = CosineAnnealingWarmRestarts(latent_optimizer, T_0=50, T_mult=1, eta_min=1e-5)
latent_combined_scheduler = SequentialLR(latent_optimizer, schedulers=[latent_warmup_scheduler, latent_cosine_scheduler],
                                         milestones=[warmup_steps])

In [None]:
dataset = GSMDataset()

train_size = int(0.9 * len(dataset))
example_size = int(0.02 * len(dataset)) # reserve some data for few shot prompting
test_size = len(dataset) - train_size - example_size

train_dataset, example_dataset, test_dataset = random_split(dataset, [train_size, example_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=6, shuffle=True, collate_fn=gsm_collate)
test_loader = DataLoader(test_dataset, batch_size=6, shuffle=True, collate_fn=gsm_collate)

COT_MAX_LENGTH = 6
run = wandb.init(project="biscuit", name="test")
train_step = 0
test_step = 0
token_loss_frequency = 10 # once every n steps
for epoch in trange(num_epochs, desc='Epoch'):
    for segments, keep_indices_lst in tqdm(train_loader, desc="Batch"):
        train_step += 1
        examples = sample(example_dataset, num_samples=4)
        latent_optimizer.zero_grad()
        latent_loss = biscuit_model.compute_batch(gsm_prompt(examples), segments, keep_indices_lst, token_batch=False)
        latent_loss.backward()
        latent_optimizer.step()
        latent_combined_scheduler.step()
        wandb.log({"latent_loss": latent_loss.item()})

        if train_step % token_loss_frequency == 1:
            token_optimizer.zero_grad()
            token_loss = biscuit_model.compute_batch(gsm_prompt(examples), segments, keep_indices_lst, token_batch=True)
            token_loss.backward()
            token_optimizer.step()
            token_combined_scheduler.step()
            wandb.log({"token_loss": token_loss.item()})

    # compute test loss
    for segments, keep_indices_lst in tqdm(test_loader, desc="Test Batch"):
        test_step += 1
        examples = sample(example_dataset, num_samples=4)
        with torch.no_grad():
            token_loss = biscuit_model.compute_batch(gsm_prompt(examples), segments, keep_indices_lst, token_batch=True)
            latent_loss = biscuit_model.compute_batch(gsm_prompt(examples), segments, keep_indices_lst, token_batch=False)
        wandb.log({"test_step": test_step, "test_token_loss": token_loss.item(), "test_latent_loss": latent_loss.item()})

    # save checkpoint
    token_trunk_checkpoint_path = f'checkpoints/token_trunk_epoch_{epoch}.pth'
    torch.save(biscuit_model.token_trunk.state_dict(), token_trunk_checkpoint_path)
    artifact = wandb.Artifact('checkpoint', type='model')
    artifact.add_file(token_trunk_checkpoint_path)
    latent_trunk_checkpoint_path = f'checkpoints/latent_trunk_epoch_{epoch}.pth'
    torch.save(biscuit_model.latent_trunk.state_dict(), latent_trunk_checkpoint_path)
    artifact.add_file(latent_trunk_checkpoint_path)
    wandb.log_artifact(artifact)
run.finish()

[34m[1mwandb[0m: Currently logged in as: [33mchristyjestin[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Batch:   0%|          | 0/3363 [00:02<?, ?it/s]
Test Batch:   0%|          | 0/300 [00:00<?, ?it/s]
Epoch:   0%|          | 0/10 [00:03<?, ?it/s]
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
latent_loss,▁
test_latent_loss,▁
test_step,▁
test_token_loss,▁
token_loss,▁

0,1
latent_loss,2.71691
test_latent_loss,3.61977
test_step,1.0
test_token_loss,61.3125
token_loss,76.875
