# Test model reload
## Author: G. Erlebacher
We perform the following experiment in `--test` mode:

0. Initialize the model and save its initial state as `model0`.
1. Run `model0` for a single epoch and save the result as `model1`.
2. Run `model0` for two epochs and save the result as `model2a`.
3. Load `model1`, run it for one more epoch, and save the result as `model2b`.
4. Compare `model2a` and `model2b`. They should be identical, since both start from `model0` and have been trained for two epochs in total.

In [1]:
# Set up auto-reloading of modules (IPython magics: this cell only runs inside
# IPython/Jupyter). `%autoreload 2` reloads modules before executing code, so
# edits to the `src` package are picked up without restarting the kernel.
%load_ext autoreload 
%autoreload 2

In [2]:
# Optional: the `src.*` imports below need the project root (the directory
# that contains `src/`) to be the working directory or to be on sys.path.
# If this cell fails with `ModuleNotFoundError: No module named 'src'`,
# uncomment the two lines below and point os.chdir at the project root.
# import os
# os.chdir('/path/to/root/folder')

from src.wandb_wrapper import WandbWrapper
from src.dataset import ConnTextULDataset
from src.model import Model
from src.main import hardcoded_args
from src.train_impl import create_data_slices
import src.train_impl as train_impl
import torch
from attrdict import AttrDict
import tqdm
from typing import List, Tuple, Dict, Any, Union
import pandas as pd

# NOTE(review): WandbWrapper appears to be a project wrapper around wandb;
# logging is switched off further down via `wandb.is_wandb_on = False`.
wandb = WandbWrapper()
# Single-threaded CPU math, presumably so that repeated runs are reproducible
# and the reloaded model can be compared with the original one.
torch.set_num_threads(1)

ModuleNotFoundError: No module named 'src'

In [None]:
def reset_context():
    """Build and return a brand-new test dataset.

    Every call constructs a new ConnTextULDataset (test=True,
    which_dataset=100, nb_rows=-1, presumably meaning "all rows"), so the
    caller never receives an object that earlier cells may have modified.
    """
    fresh_dataset = ConnTextULDataset(
        None,
        test=True,
        which_dataset=100,
        nb_rows=-1,
    )
    return fresh_dataset

# Build two independent dataset instances and print the phonology of their
# first two items, so the two can be compared by eye.
ds = reset_context()
ds1 = reset_context()
# NOTE: this `c` is not used before it is overwritten by
# `c = AttrDict(...)` in the config cell further down.
c = hardcoded_args()

dd = ds[:2]
print("dd: ", dd['phonology'])

dd1 = ds1[:2]
print("dd1: ", dd1['phonology'])

# NOTE(review): if revived, this loop needs fixing: `d` is an integer index,
# so `d.items()` would fail; iterate over `ds[d].items()` instead.
# for d in range(len(ds)):
    # print(ds[d])
    # for k,v in d.items():
    #     print(k,v)

In [None]:
num_layers = 4

# Use the same number of layers for every sub-network of the model.
num_layers_dict = {
    "phon_dec": num_layers,
    "phon_enc": num_layers,
    "orth_dec": num_layers,
    "orth_enc": num_layers,
    "mixing_enc": num_layers,
}

# Standalone model used only to inspect its parameters; `model` is rebound by
# train_impl.setup_model(...) in a later cell.
model = Model(
    len(ds.character_tokenizer),
    len(ds.phonology_tokenizer),
    d_model=128,
    nhead=4,
    max_orth_seq_len=ds.max_orth_seq_len,
    max_phon_seq_len=ds.max_phon_seq_len,
    num_layers_dict=num_layers_dict,
    d_embedding=1,
)

# `state_dict()` detaches its tensors by default (keep_vars=False), so reading
# `.requires_grad` from it always yields False, whatever the parameter's real
# setting is. That was the source of the earlier "why is requires_grad False?"
# confusion: it never meant the weights were frozen. keep_vars=True returns
# the live (non-detached) objects, so the flag printed here is the real one.
print("load, global embedding.requires_grad: ", model.state_dict(keep_vars=True)['global_embedding'].requires_grad)

In [None]:
# Minimal training configuration for this reload test; replaces the `c`
# created earlier from hardcoded_args().
c = AttrDict({"batch_size_train": 1, "batch_size_val": 1, "train_test_split": 0.8})
c.continue_training = False
# NOTE(review): d_model/nhead here (16/1) differ from the standalone Model
# built in the previous cell (128/4); presumably train_impl.setup_model reads
# these values from `c`. Confirm.
c.d_model = 16
c.nhead = 1
c.d_embedding = 1
# c.pathway = 'o2p'
c.pathway = 'op2op'
c.learning_rate = 1.e-3

# Number of training rows implied by the train/test split fraction.
num_train = int(len(ds) * c.train_test_split)
# train_dataset_slices, val = create_data_slices(num_train, c, ds)
# val_cutpoint = val[0].start
# datum = ds[0:2]

In [None]:
# Build the model and optimizer from config `c`. MODEL_PATH is also passed to
# train_impl.save/load_model in the next cell (presumably a directory).
MODEL_PATH = "./models_test"
model, opt = train_impl.setup_model(MODEL_PATH, c, ds, num_layers_dict)
print("opt: ", opt)
print("opt.state_dict(): ", opt.state_dict())
train_impl.print_weight_norms(model, "norms of model")

In [None]:
# Save the initial model (step 0 of the experiment), reload it, and verify
# that the config, optimizer state and model weights round-trip unchanged.
epoch = 0
# NOTE(review): an earlier note said "epoch_num not currently used"
# (presumably inside save). It is still what get_model_file_name and
# load_model receive below, so pass the same variable to save() rather than a
# separate hard-coded 0, keeping the save and load calls in sync.
model_id = 3
epoch_num = 0

print("==> before train_impl.save")
train_impl.save(epoch, c, model, opt, MODEL_PATH, model_id, epoch_num=epoch_num)
print("==> after train_impl.save")

# Show requires_grad of the first parameter only.
for m in model.parameters():
    print(f"model: {m.requires_grad=}")
    break

model_file_name = train_impl.get_model_file_name(model_id, epoch_num)
print("==> before train_impl.load_model")
model1, opt1, c1 = train_impl.load_model(MODEL_PATH, model_id, epoch_num)
train_impl.print_weight_norms(model1, "norms of model1")
# print("opt1: ", opt1)

for m in model1.parameters():
    print(f"model1: {m.requires_grad=}")  # observed: True
    break

assert c1 == c, "config dictionaries are not the same"
# Plain `==` presumably works here only because no optimizer step has run yet,
# so the optimizer state holds no tensors. Once it contains multi-element
# tensors, dict equality becomes ambiguous and a tensor-aware comparison is
# needed.
assert opt.state_dict() == opt1.state_dict(), "opt.state_dict are not the same"

# Compare the saved model (`model`) with the reloaded one (`model1`).
state_dict  = model.state_dict()
state_dict1 = model1.state_dict()
assert train_impl.compare_state_dicts(state_dict, state_dict1), "Model State dicts are not equal."

# Report success only after every check above has passed.
print("norms of weights are the same. Use in pytest")


In [None]:
def reset_dataset_slices(ds, c):
    """Recompute the train/validation batch slices from a freshly built dataset.

    The `ds` argument is not used: a new dataset is always created with
    `reset_context()`. This should only matter if `ds` had been modified in
    place. As a side effect, `c.n_steps_per_epoch` is set to the number of
    training slices.

    Parameters:
    - ds: ignored; kept so existing calls keep working
    - c: config supplying `train_test_split`, also forwarded to
      `train_impl.create_data_slices`; updated in place

    Returns:
    - (train_dataset_slices, val_dataset_slices)
    """
    # SHOULD NOT BE REQUIRED, unless ds is modified in place
    dataset = reset_context()
    split_index = int(len(dataset) * c.train_test_split)
    train_slices, val_slices = train_impl.create_data_slices(split_index, c, dataset)
    c.n_steps_per_epoch = len(train_slices)
    return train_slices, val_slices

In [None]:
# Notebook globals read by the training closures built by setup_closures.
# Here c.num_epochs only sizes the progress-bar range; with 0 that range is
# empty. NOTE(review): train_impl may also read c.num_epochs. Confirm.
c.num_epochs = 0
pbar = tqdm.tqdm(range(epoch_num, epoch_num + c.num_epochs), position=0)
device = 'cpu'
# NOTE: setup_closures creates its own local `example_ct`, so this global is
# not the counter that the training steps receive.
example_ct = [0]
# Step cap, presumably read by train_impl to keep this test short. Confirm.
c.max_nb_steps = 4
wandb.is_wandb_on = False  # presumably disables W&B logging for this test
generated_text_table = wandb.Table(columns=["Step", "Generated Output"])

# Closures to simplify function calls. Rebuild them right before each training
# run so that both runs use identical training slices.
def setup_closures(model, opt):
    """Create the per-step and per-epoch training callables bound to `model`/`opt`.

    Each call recomputes the training slices with `reset_dataset_slices(ds, c)`
    and starts a new example counter, so repeated calls begin from the same
    state. The returned callables look up the notebook globals `c`, `pbar`,
    `ds`, `device` and `generated_text_table` at the moment they run. That is
    why this helper lives in the notebook rather than in `train_impl`.

    Returns:
    - single_step_fct(batch_slice, step, epoch, mode)
    - train_single_epoch_fct(epoch)
    """
    # A one-element list, presumably mutated in place by train_impl.single_step.
    example_counter = [0]
    train_slices, _unused_val = reset_dataset_slices(ds, c)

    def single_step_fct(batch_slice, step, epoch, mode):
        # Positional order must match train_impl.single_step's signature.
        return train_impl.single_step(
            c, pbar, model, train_slices, batch_slice, ds, device, opt,
            epoch, step, generated_text_table, example_counter, mode,
        )

    def train_single_epoch_fct(epoch):
        return train_impl.train_single_epoch(
            c, model, train_slices, epoch, single_step_fct,
        )

    return single_step_fct, train_single_epoch_fct

In [None]:
# Run model and model1 on the same initial data. Compare metrics

def run_train(model, opt, num_epochs):
    """Train `model` with `opt` for `num_epochs` epochs from a fixed seed.

    Sets `c.seed = 100` and seeds torch on CPU and CUDA, then rebuilds the
    training closures via setup_closures. Two calls on identically initialised
    models therefore start from the same RNG state.

    Parameters:
    - model, opt: the model and its optimizer (trained in place)
    - num_epochs: number of epochs to run

    Returns:
    - the metrics returned by the last epoch, or None if num_epochs <= 0
    """
    c.seed = 100
    torch.manual_seed(c.seed)
    torch.cuda.manual_seed_all(c.seed)
    _, train_single_epoch_fct = setup_closures(model, opt)
    last_metrics = None
    for epoch in range(num_epochs):
        print("************* epoch: ", epoch, " *******************88")
        last_metrics = train_single_epoch_fct(epoch)
    # Return after the loop, not inside it, so every requested epoch runs.
    return last_metrics

In [None]:
# Disabled diagnostic kept as a string literal. When enabled, it prints
# requires_grad of the first parameter of `model` and of `model1`. Note that,
# as the last expression of the cell, the string itself is echoed by the
# notebook.
"""
for m in model.parameters():
    print(f"model: {m.requires_grad=}")  # it is True
    break
for m in model1.parameters():
    print(f"model1: {m.requires_grad=}")  # it is True
    break
"""

In [None]:
# Train `model` (the instance that was saved above) via run_train with
# num_epochs=2. The trailing `;` suppresses the notebook's echo of the result.
run_train(model, opt, 2);

In [None]:
# Same run for `model1`, which was reloaded from the saved checkpoint.
# Nothing here compares the two runs automatically: inspect the printed
# metrics and optimizer state in the cells below.
run_train(model1, opt1, 2);

In [None]:
# Inspect the reloaded model's optimizer (state and hyperparameters).
print(opt1.state_dict())
print(opt1)

In [None]:
# Inspect the original model's optimizer, for comparison with opt1.
print(opt.state_dict())
print(opt)

In [None]:
# Print opt.state_dict() one top-level key per line.
for k,v in opt.state_dict().items():
    print(k, v)

In [None]:
# Print opt1.state_dict() one top-level key per line.
for k,v in opt1.state_dict().items():
    print(k, v)

In [None]:
# NOTE: prints opt1.state_dict() one key per line; this cell duplicates the
# preceding opt1 inspection cell.
for k,v in opt1.state_dict().items():
    print(k, v)

In [None]:
# Print both optimizers' reprs on one line for a quick side-by-side look.
print(opt, opt1)