# Uploading the Model

In [None]:
from typing import Any

import torch
import torch.nn as nn
import wandb
from huggingface_hub import PyTorchModelHubMixin

from simple_stories_train.models.llama import Llama, LlamaConfig
from simple_stories_train.models.model_configs import MODEL_CONFIGS_DICT

In [None]:
# Fetch the trained checkpoint file from the Weights & Biases run.
# `replace=True` asks wandb to overwrite a local copy if one already exists.
weights = wandb.restore(
    name="model_step_4824.pt",
    replace=True,
    run_path="dbra/simple-stories/runs/w66ikhs3",
)

In [None]:
# Hugging Face Hub wrapper: `PyTorchModelHubMixin` supplies `save_pretrained`,
# `push_to_hub` and `from_pretrained` for the wrapped Llama model.
class LlamaTransformer(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/danbraunai/simple_stories_train",
    language=["en"],
    pipeline_tag="text-generation",
):
    """Thin `nn.Module` around `Llama` that can be saved to / loaded from the Hub.

    The class keyword arguments (`repo_url`, `language`, `pipeline_tag`) are
    metadata handed to `PyTorchModelHubMixin`.
    """

    def __init__(self, **config: Any):
        """Build the wrapped model.

        Args:
            **config: Keyword arguments forwarded unchanged to `LlamaConfig`.
        """
        super().__init__()
        llama_config = LlamaConfig(**config)
        self.llama = Llama(llama_config)

    def forward(self, x: torch.Tensor):
        """Run `x` through the wrapped `Llama` and return its output as-is."""
        output = self.llama(x)
        return output


# Create the model from the "d12" preset in MODEL_CONFIGS_DICT. The preset's
# entries are passed as keyword arguments and end up in `LlamaConfig`.
# The trained weights are loaded into `model.llama` in a later cell.
config = MODEL_CONFIGS_DICT["d12"]
model = LlamaTransformer(**config)

In [None]:
# We load the model weights obtained from wandb.
# `weights_only=True` restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot execute arbitrary code while loading.
state_dict = torch.load(
    weights.name,
    map_location=torch.device("cpu"),
    weights_only=True,
)

# Strip `_orig_mod.` from the keys. The prefix is added by `torch.compile`,
# which wraps the original module under the attribute `_orig_mod`; a checkpoint
# saved from the compiled model therefore carries it on every parameter name.
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

# The checkpoint holds the bare Llama parameters, so load them into the wrapped
# module rather than into the Hub wrapper itself.
model.llama.load_state_dict(new_state_dict)

In [None]:
# We perform a sanity check to see if the model is working: feed a few real
# samples through it and print the decoded prediction after each context.

from simple_stories_train.dataloaders import DatasetConfig, create_data_loader

config = DatasetConfig(
    tokenizer_file_path="tokenizer/stories-3072.json",
    column_name="story",
    is_tokenized=False,
)
loader, tokenizer = create_data_loader(dataset_config=config, batch_size=1, buffer_size=1000)

# Keep one iterator around so that later cells calling `next(loader)` continue
# from where this loop stopped.
loader = iter(loader)

# Inference only: put the model in eval mode and skip autograd bookkeeping so
# no computation graph is built for these forward passes.
model.eval()
with torch.no_grad():
    for _ in range(10):
        # Named `input_ids` (not `input`) so the builtin `input` is not shadowed.
        input_ids = next(loader)["input_ids"].to(torch.int)
        out = model(input_ids)
        assert out[0].shape == torch.Size([1, 1, 50257])

        # Show the last 20 context tokens and the argmax token of `out[0]`.
        print(
            f"""...{tokenizer.decode(input_ids.tolist()[0][-20:])}
          -> {tokenizer.decode(out[0].argmax(-1).tolist()[0][-10:])}
           """.replace(" ##", "").replace(" .", ".")
        )

...timid girl into a brave explorer. a small puppet hung on a wall , its
          -> bright
           
...over by the fence , a boy named samuel watched kids play baseball. he wanted
          -> to
           
.... " are you sure you want to follow that ? " it chattered. " many have
          -> tried
           
...my porch. i watched the clouds drift by , feeling alone. my old friend samuel
          -> had
           
...was a boy named jose , who often felt alone. one day , as he wandered , a bright
          -> light
           
..., a girl made a wish for a true companion. little did she know , a wise
          -> old
           
...and see what it could do. as anne polished the lantern , a small light appeared.
          -> it
           
...to meet other kids. the camp was amazing ! they had rockets , space suits
          -> ,
           
.... it reminded her of her own happiness. together , they watched the lights flicker and glow ,
          -> and
           
...##aptor 

In [None]:
# Write the model to a local directory first. The directory path reuses the
# Hub repository id, so the files land under ./lennart-finke/SimpleStories-125M.
HUB_REPO_NAME = "lennart-finke/SimpleStories-125M"
model.save_pretrained(save_directory=HUB_REPO_NAME)

In [None]:
# Finally, publish the model to the Hugging Face Hub under this repository id.
# The id is set again here so the cell can be run on its own.
HUB_REPO_NAME = "lennart-finke/SimpleStories-125M"
model.push_to_hub(repo_id=HUB_REPO_NAME)

# Downloading and Using the Model

In [None]:
# Check if the model is available on the hub
from typing import Any

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from simple_stories_train.models.model_configs import MODEL_CONFIGS_DICT

# NOTE: `LlamaTransformer` is not imported in this section; the wrapper class
# from the upload section must already be defined in the session.
# `config` is kept for any code that still reads it, but it is no longer needed
# to load the model.
config = MODEL_CONFIGS_DICT["d12"]
HUB_REPO_NAME = "lennart-finke/SimpleStories-125M"

# `from_pretrained` is a classmethod that constructs the model itself, so call
# it on the class. Building a throwaway instance first only to call it on that
# instance allocates a second full model for nothing.
model = LlamaTransformer.from_pretrained(HUB_REPO_NAME)

In [None]:
# Checking model output of the model downloaded from the hub.
# Relies on `loader` (an iterator) and `tokenizer` created by the sanity-check
# cell in the upload section.

# Inference only: put the model in eval mode and skip autograd bookkeeping so
# no computation graph is built for these forward passes.
model.eval()
with torch.no_grad():
    for _ in range(10):
        # Named `input_ids` (not `input`) so the builtin `input` is not shadowed.
        input_ids = next(loader)["input_ids"].to(torch.int)
        out = model(input_ids)
        # Raw argmax token ids of `out[0]`, printed before decoding for inspection.
        print(out[0].argmax(-1).tolist()[0])
        assert out[0].shape == torch.Size([1, 1, 50257])

        # Show the last 20 context tokens and the decoded argmax token.
        print(
            f"""...{tokenizer.decode(input_ids.tolist()[0][-20:])}
          -> {tokenizer.decode(out[0].argmax(-1).tolist()[0][-10:])}
           """.replace(" ##", "").replace(" .", ".")
        )