# Load Model in PyTorch

- The following file is a simple example of how to load a model in PyTorch.

---

### 1. Load `Bigram-model-with-architecture` 

In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(
            idx
        )  # (B,T,C) - Batch, tokens, channels (embedding size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(
                B * T
            )  # we could have also used the -1 argument to the view function
            loss = F.cross_entropy(
                logits, targets
            )  # we reshape the logits to be a 2D tensor and the targets to be 1D for the loss function as it expects **(N, C) and (N, 1)** respectively

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """_summary_
        Use only the last token in the context to generate new tokens
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


model = torch.load("./bigram-model/bigram-model-with-architecture.pth")

model

BigramLanguageModel(
  (token_embedding_table): Embedding(65, 65)
)

---

### 2. Load `Bigram-model-with-state-dict`

In [6]:
bigram_model_with_state_dict = BigramLanguageModel(65) # initialize the model with the same architecture

model_state_dict = torch.load("./bigram-model/bigram-model-state-dict.pth")
model_state_dict

OrderedDict([('token_embedding_table.weight',
              tensor([[  2.1663,  -8.9142,  -7.7765,  ...,  -9.5859,  -2.7291, -10.1730],
                      [ -5.9427,  -4.5666,  -7.7469,  ...,  -9.2596,   0.6270,  -4.7146],
                      [  1.9087,   1.5847, -10.8409,  ...,  -9.4011,  -9.1053,  -8.9093],
                      ...,
                      [ -3.1896,  -0.0140,  -8.8260,  ...,  -6.7737,  -2.1301,  -7.2833],
                      [ -0.2109,   2.7150,  -1.2015,  ...,  -9.5673,  -8.3921,  -9.5525],
                      [ -7.7607,  -1.5148,  -7.2700,  ...,  -6.4779,  -0.6521,  -0.7968]],
                     device='mps:0'))])

In [7]:
bigram_model_with_state_dict.load_state_dict(model_state_dict)
bigram_model_with_state_dict

BigramLanguageModel(
  (token_embedding_table): Embedding(65, 65)
)

---

## Difference between the two methods

- The first method loads the model with the architecture and the weights. This method requires `only the class` of the model to be defined in the file.

- The second method loads the model with the state dictionary. This method requires `the class and the object initialized with same parameters`, and then the state dictionary is loaded into the object.

---

## Size difference between the two methods

- The first method is larger in size as it contains the architecture and the weights.

- The second method is smaller in size as it contains only the weights.