In [1]:
import torch
from src.datasets.discrete_synthetic.discrete_synthetic import DiscreteSyntheticDataset
from src.datasets.discrete_helper import collate_fn
from src.tokenizers.discrete_synthetic.discrete_synthetic_tokenizer import DiscreteSyntheticTokenizer
from src.nn.models.discrete_model import DiscreteModel
from src.training.training import train_discrete_model
from matplotlib import pyplot as plt
from src.inference.discrete_inference import dis_t, bayesian_inference
from accelerate import Accelerator
from accelerate.utils import merge_fsdp_weights
from src.training.checkpoint import save_checkpoint, load_checkpoint

## FSDP checkpoint to non-FSDP

In [11]:
merge_fsdp_weights("checkpoint/pytorch_model_fsdp_0", "checkpoints")

In [12]:
merge_fsdp_weights("checkpoint/optimizer_0", "checkpoints/optimizer.bin", safe_serialization=False)

## non-FSDP checkpoint to FSDP

In [14]:
from safetensors.torch import load_file

In [15]:
result = load_file("checkpoints/model.safetensors")

In [16]:
result

{'classifier': tensor([[-6.1993e-01, -3.3674e-02,  3.9473e-01, -9.8560e-01,  5.7942e-03,
          -1.4059e-01,  6.4171e-01, -2.9709e-01, -3.4139e-01, -9.5085e-01,
          -2.3260e-01],
         [ 3.1059e-03,  4.6115e-01, -2.7397e-01,  1.1259e-01,  5.5562e-01,
           1.0437e+00,  8.5906e-01,  4.8378e-01, -1.5345e+00,  5.1380e-01,
           5.5663e-01],
         [-8.4661e-01, -8.2405e-01,  5.3505e-01, -1.0770e+00,  6.7262e-01,
           3.9758e-01,  3.0170e-01,  4.2540e-01, -1.9679e+00,  6.3208e-01,
          -8.2033e-01],
         [-1.2007e+00,  1.2715e+00,  1.4130e+00,  2.2191e-01,  1.5783e+00,
           7.5850e-01, -1.0560e+00, -2.2760e+00, -1.2134e+00, -1.4211e+00,
           1.2931e-01],
         [ 7.9976e-01, -2.1493e-01,  5.3031e-01, -1.1161e-01, -9.5826e-01,
          -3.4487e-01, -3.6048e-01, -7.6354e-01, -1.5223e+00, -1.7514e-01,
           4.7254e-01],
         [-2.1998e-01, -1.2222e+00,  6.7143e-01, -1.5188e+00,  1.8963e+00,
          -1.3096e-01,  2.0753e-01,  4.83

In [9]:
hasattr(accelerator.state, "fsdp_plugin")

False

In [None]:
tokenizer = DiscreteSyntheticTokenizer()
max_seq_len = 32
train_ds = DiscreteSyntheticDataset(tokenizer, tokenized_length=max_seq_len)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn)

model = DiscreteModel(max_seq_len, tokenizer.vocab_size(), hidden_dim=64, num_heads=8)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

In [5]:
current_classifier = model.classifier.clone()

In [8]:
model, opt, train_dl = accelerate.prepare(model, opt, train_dl)

In [9]:
accelerate.load_state("checkpoints")

In [13]:
current_classifier = current_classifier.to(accelerate.device)

In [14]:
torch.allclose(model.classifier, current_classifier)

False

In [15]:
accelerate.save_state("checkpoint")

PosixPath('checkpoint')