In [None]:
import os

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR, SequentialLR
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm, trange
import wandb


from biscuit import Biscuit
from gsm_dataset import GSMDataset, gsm_collate, gsm_prompt, sample

In [None]:
# Model, optimizer and learning-rate schedule (linear warmup followed by
# cosine annealing with warm restarts).
biscuit_model = Biscuit()

num_epochs = 10
# Integer step count: it is used as a SequentialLR milestone, which is compared
# against the scheduler's integer step counter.
warmup_steps = 100
learning_rate = 5e-5
optimizer = optim.AdamW(biscuit_model.model.parameters(), lr=learning_rate, weight_decay=0.01)

# LambdaLR multiplies the optimizer's base lr by the value returned from
# `lr_lambda`, so the lambda must return a *factor* and not an absolute
# learning rate. Multiplying by `learning_rate` here (as was done before) made
# the warmup lr roughly learning_rate**2, i.e. effectively zero.
# The factor ramps linearly from 1/(warmup_steps+1) towards 1 and is clamped at
# 1 so it can never overshoot the base lr.
warmup_scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda s: min(1.0, (s + 1) / (warmup_steps + 1)),
)
cosine_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=1, eta_min=1e-5)
# Use the warmup schedule for the first `warmup_steps` scheduler steps, then
# hand over to the cosine schedule.
combined_scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler],
                                  milestones=[warmup_steps])

In [None]:
# Build the dataset and split it into train / few-shot-example / test subsets.
dataset = GSMDataset()

train_size = int(0.9 * len(dataset))
example_size = int(0.02 * len(dataset)) # reserve some data for few shot prompting
# Whatever is left after the two int() truncations goes to the test split, so
# the three sizes always sum to len(dataset) as random_split requires.
test_size = len(dataset) - train_size - example_size

# Seed the split so that the same examples land in the same subset on every
# kernel restart. With an unseeded split, a checkpoint reloaded in a later
# session could be evaluated on examples it was trained on.
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, example_dataset, test_dataset = random_split(
    dataset, [train_size, example_size, test_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True, collate_fn=gsm_collate)
# Evaluation does not need shuffling; a fixed order keeps test batches stable.
test_loader = DataLoader(test_dataset, batch_size=12, shuffle=False, collate_fn=gsm_collate)

COT_MAX_LENGTH = 6
# Number of few-shot examples drawn from `example_dataset` for every batch prompt.
NUM_FEW_SHOT = 4
CHECKPOINT_DIR = 'checkpoints'

# Create the checkpoint directory up front so a missing folder fails (or is
# fixed) immediately instead of after a full epoch of training.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

wandb.init(project="biscuit", name="baseline-0")
for epoch in trange(num_epochs, desc='Epoch'):
    # NOTE(review): the model's train/eval mode is never toggled in this cell;
    # confirm whether Biscuit handles that internally.
    for segments, keep_indices_lst in tqdm(train_loader, desc="Batch"):
        # Fresh few-shot examples are sampled for every batch.
        examples = sample(example_dataset, num_samples=NUM_FEW_SHOT)
        loss = biscuit_model.compute_batch(gsm_prompt(examples), segments, keep_indices_lst, no_latent=True)
        # Log the lr that was used for this step (before the scheduler advances).
        wandb.log({"lr": combined_scheduler.get_last_lr()[0], "loss": loss.item()})
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The schedule is stepped once per batch, not once per epoch.
        combined_scheduler.step()

    # compute test loss: average over all test batches and log it once per
    # epoch, so the curve is comparable between epochs instead of being a
    # stream of per-batch values.
    test_losses = []
    with torch.no_grad():
        for segments, keep_indices_lst in tqdm(test_loader, desc="Test Batch"):
            examples = sample(example_dataset, num_samples=NUM_FEW_SHOT)
            loss = biscuit_model.compute_batch(gsm_prompt(examples), segments, keep_indices_lst, no_latent=True)
            test_losses.append(loss.item())
    if test_losses:  # guard against an empty test split
        # Unweighted mean over batches; a smaller final batch counts as much
        # as a full one.
        wandb.log({"test_loss": sum(test_losses) / len(test_losses), "epoch": epoch})

    # save checkpoint locally and upload it as a wandb artifact
    checkpoint_path = f'{CHECKPOINT_DIR}/epoch_{epoch}.pth'
    torch.save(biscuit_model.model.state_dict(), checkpoint_path)
    artifact = wandb.Artifact('checkpoint', type='model')
    artifact.add_file(checkpoint_path)
    wandb.log_artifact(artifact)