# Imports

In [None]:
import wandb
from llm_non_identifiability.runner import LightningGrammarModule
from llm_non_identifiability.datamodule import GrammarDataModule

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from os.path import abspath, dirname, join
import torch

from copy import deepcopy

# Load model

## Download artifact from wandb

In [None]:
# Start a wandb run; the returned handle is used by the cells below to fetch artifacts.
run = wandb.init()

### Get model

In [None]:
# Previously used checkpoint, kept for reference:
# model_name = 'model-sw5tu72o:best'
# model = run.use_artifact(f'causal-representation-learning/llm-non-identifiability/{model_name}', type='model')
# artifact_dir = model.download()


import wandb

# Reuse the run that is already active instead of unconditionally calling
# wandb.init() a second time; a new run is only started when none exists
# (e.g. when this cell is executed on its own).
run = wandb.run if wandb.run is not None else wandb.init()

# Resolve the model artifact through the run and download its files.
# `artifact_dir` is later joined with the checkpoint file name.
artifact = run.use_artifact('causal-representation-learning/rule_extrapolation/model-lnv4oun6:v0', type='model')
artifact_dir = artifact.download()

In [None]:
# Display the value returned by `artifact.download()` for the model artifact.
artifact_dir

In [None]:
# Display the artifact handle returned by `run.use_artifact`.
artifact

### Get prompt completion tables

In [None]:
# Fetch the `run_table` artifact named by `table_name` and download it.
# NOTE(review): the project segment here is `llm_non_identifiability` (underscores),
# while other references in this notebook spell it `llm-non-identifiability`
# (hyphens) — confirm which spelling is correct.
table_name = "run-xgl66xch-id_prompt_completions:v0"
id_prompt_completions_artifact = run.use_artifact(f'causal-representation-learning/llm_non_identifiability/{table_name}', type='run_table')
id_prompt_completions_dir = id_prompt_completions_artifact.download()

In [None]:
# Look up the entry called "id_prompt_completions" in the artifact and keep its `.data` attribute.
id_prompt_completions: list = id_prompt_completions_artifact.get("id_prompt_completions").data


# Finetuning

## Constants

In [None]:
# Notebook-level settings shared by the cells below.
MAX_LENGTH = 256  # forwarded as `max_length` to GrammarDataModule
BATCH_SIZE = 128  # forwarded as `batch_size` to GrammarDataModule
GRAMMAR = "aNbN"  # grammar identifier given to both the loaded model and the datamodule
MAX_EPOCHS = 100  # forwarded as `max_epochs` to the Trainers

## Load model from checkpoint

In [None]:
# source: https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing.html#checkpoint-loading
# Checkpoint file name looked up inside the downloaded artifact directory.
checkpoint = "model.ckpt"
PATH = join(artifact_dir, checkpoint)

# To load a model along with its weights and hyperparameters use the following method.
# `device` and `grammar` are passed as extra keyword arguments; presumably Lightning
# forwards them to the module's constructor, overriding the saved hyperparameters —
# TODO confirm that LightningGrammarModule accepts a `device` argument.
model: LightningGrammarModule = LightningGrammarModule.load_from_checkpoint(
    PATH, device="cuda" if torch.cuda.is_available() else "cpu", grammar=GRAMMAR
)

## Define datamodule

In [None]:
# Data module configured from the notebook-level constants.
datamodule = GrammarDataModule(
    grammar=GRAMMAR,
    max_length=MAX_LENGTH,
    batch_size=BATCH_SIZE,
)

In [None]:
# Call prepare_data() by hand before requesting a dataloader in the next cell.
# NOTE(review): presumably this is what builds the datasets — verify in GrammarDataModule.
datamodule.prepare_data()

In [None]:
# Pull a single batch from the training dataloader for the manual checks below.
batch = next(iter(datamodule.train_dataloader()))

In [None]:
# Shift by one position along dim 1: the input drops the last column,
# the expected output drops the first.
X_input = batch[:, :-1]
X_expected = batch[:, 1:]

In [None]:
# Run the wrapped network (`model.model`) on the input slice.
pred = model.model(X_input)

In [None]:
# Keep only the last column of the picked tokens and reshape it to (-1, 1).
next_items = model._pick_next_tokens(pred)[:, -1].view(-1, 1)

In [None]:
# Display the full output of `_pick_next_tokens` for the prediction.
model._pick_next_tokens(pred)

In [None]:
# Inspect the shape of the input slice.
X_input.shape

In [None]:
# Show the prediction shape next to the shape of the full (unshifted) batch.
pred.shape, batch.shape

In [None]:
# Count elementwise matches and mismatches between the picked tokens and the *input* slice.
(model._pick_next_tokens(pred)== X_input).sum(), (model._pick_next_tokens(pred)!= X_input).sum()

In [None]:
# Count elementwise matches and mismatches between the picked tokens and the shifted target slice.
(model._pick_next_tokens(pred)== X_expected).sum(), (model._pick_next_tokens(pred)!= X_expected).sum()

In [None]:
# Evaluate the loss function stored in the module's hparams, asking for reduction="none".
model.hparams.loss_fn(pred, X_expected,reduction="none")

In [None]:
# Cross-entropy between prediction and shifted target with mean reduction.
torch.nn.CrossEntropyLoss(reduction="mean")(pred, X_expected)

In [None]:
# Same with targets equal to 2 excluded via ignore_index
# (presumably 2 is the padding token id — confirm).
torch.nn.CrossEntropyLoss(reduction="mean", ignore_index=2)(pred, X_expected)

In [None]:
# Cross-entropy between prediction and shifted target with sum reduction.
torch.nn.CrossEntropyLoss(reduction="sum")(pred, X_expected)

In [None]:
# Sum reduction with targets equal to 2 excluded via ignore_index
# (presumably 2 is the padding token id — confirm).
torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=2)(pred, X_expected)

## Copy model and change alpha

In [None]:
# Deep copy, so modifying the copy below leaves `model` untouched.
bad_model = deepcopy(model)

In [None]:
# Replace `relu_rescale` on the copy with a scalar parameter fixed at 100.0;
# requires_grad=False keeps it out of gradient updates.
bad_model.model.relu_rescale = torch.nn.Parameter(torch.tensor(100.0), requires_grad=False)

## Define logger

In [None]:
# Weights & Biases logger for fine-tuning the original model.
# NOTE(review): Lightning's WandbLogger is documented to reuse a wandb run that is
# still active instead of starting one with this `name` — confirm whether
# wandb.finish() should be called before this cell.
logger = WandbLogger(
    project="llm-non-identifiability",
    entity="causal-representation-learning",
    name="finetune-good",
    log_model="all",
)

## Checkpoint callback

In [None]:
# Checkpoint callback: track 'Val/loss', lower is better, keep only the single best checkpoint.
model_checkpoint = ModelCheckpoint(monitor='Val/loss', mode='min', save_top_k=1)

## Train original model

In [None]:
# Fine-tune the original model.
# The `model_checkpoint` callback configured earlier in this notebook was never
# handed to the Trainer, so its 'Val/loss' / save_top_k settings had no effect;
# register it explicitly.
trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    logger=logger,
    callbacks=[model_checkpoint],
)
trainer.fit(model, datamodule=datamodule)


## Train bad model

In [None]:
# Sanity check before training the modified copy: the rescale parameter that was
# overwritten on `bad_model` must stay frozen. The original cell inspected `model`
# instead of `bad_model`, so it did not verify the model that is trained next.
assert not bad_model.model.relu_rescale.requires_grad

In [None]:
# WandbLogger attaches to whatever wandb run is still active and then ignores
# `name`; without closing the current run the bad model would be logged into the
# same run as the previous training. Finish it so "finetune-bad" gets its own run.
# wandb.finish() is a no-op when no run is active.
wandb.finish()
logger = WandbLogger(
    entity="causal-representation-learning",
    project="llm-non-identifiability",
    name="finetune-bad",
    log_model="all",
)

In [None]:
# Fine-tune the modified copy.
# Register a checkpoint callback (none was passed originally). A fresh instance is
# created rather than reusing one from an earlier fit, because ModelCheckpoint
# objects keep state such as the best score seen so far.
trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    logger=logger,
    callbacks=[ModelCheckpoint(monitor='Val/loss', save_top_k=1, mode='min')],
)
trainer.fit(bad_model, datamodule=datamodule)