In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import hydra
from hydra import initialize, compose
from typing import Dict, List
from nn_core.common import PROJECT_ROOT

# Clear whatever Hydra global state an earlier execution of this cell left
# behind, then initialise Hydra against the ../conf directory.
# NOTE(review): presumably initialize() refuses to run twice without the
# clear() call — confirm against the Hydra docs.
global_hydra = hydra.core.global_hydra.GlobalHydra.instance()
global_hydra.clear()
initialize(version_base=None, config_path="../conf", job_name="finetune")

hydra.initialize()

In [10]:
# Compose the "finetune" configuration with an empty override list.
cfg = compose(
    config_name="finetune",
    overrides=[],
)

In [11]:
import logging
import os
import time
from typing import Dict, List, Union

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
import torch.nn as nn
import wandb
from omegaconf import DictConfig, ListConfig
from pytorch_lightning import Callback, LightningModule
from tqdm import tqdm

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO

from tvp.data.datasets.registry import get_text_dataset
from tvp.modules.text_encoder import TextEncoder
from tvp.modules.text_heads import get_classification_head
from tvp.pl_module.text_classifier import TextClassifier
from tvp.utils.io_utils import get_class, load_model_from_artifact
from tvp.utils.utils import LabelSmoothing, build_callbacks

# Select the "high" float32 matmul precision setting before any model is built.
torch.set_float32_matmul_precision("high")

# Module-level logger used by the cells below.
pylogger = logging.getLogger(__name__)

In [12]:
# Seed first, so that everything constructed afterwards starts from the seed
# selected by the config (the cell output reports the global seed that was set).
seed_index_everything(cfg)

# "restore" is optional in the train config; None is used when it is absent.
restore_cfg = cfg.train.get("restore", None)
template_core: NNTemplateCore = NNTemplateCore(restore_cfg=restore_cfg)

# The logger receives the resume id exposed by the template core.
logger: NNLogger = NNLogger(logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id)

# Shorthand handles for the two config nodes read repeatedly in this stage.
encoder_cfg = cfg.nn.module.model
dataset_cfg = cfg.nn.data.dataset

# Artifact name used to look up an existing head when it is not reset below.
classification_head_identifier = f"{cfg.nn.module.model.model_name}_{cfg.nn.data.dataset.dataset_name}_head"

text_encoder: TextEncoder = hydra.utils.instantiate(encoder_cfg, keep_lang=False)

# Default bookkeeping describes the encoder; it is replaced further down when
# a fresh classification head is created.
model_class = get_class(text_encoder)
metadata = dict(model_name=encoder_cfg.model_name, model_class=model_class)

if not cfg.reset_classification_head:
    # Reuse the latest head stored as an artifact of the logger's experiment.
    classification_head = load_model_from_artifact(
        artifact_path=f"{classification_head_identifier}:latest",
        run=logger.experiment,
    )
else:
    # Build a new head sized from the encoder width and the number of classes.
    classification_head = get_classification_head(
        input_size=encoder_cfg.hidden_size,
        num_classes=dataset_cfg.num_classes,
    )

    model_class = get_class(classification_head)

    metadata = dict(
        model_name=encoder_cfg.model_name,
        model_class=model_class,
        num_classes=dataset_cfg.num_classes,
        input_size=encoder_cfg.hidden_size,
    )

# Assemble the classifier from the encoder and head built above; the extra
# keyword arguments are forwarded through Hydra's instantiate.
model: TextClassifier = hydra.utils.instantiate(
    cfg.nn.module,
    encoder=text_encoder,
    classifier=classification_head,
    _recursive_=False,
    save_grad_norms=cfg.train.save_grad_norms,
)

# All dataset settings live under cfg.nn.data; the tokenizer name is taken
# from the encoder's model name.
data_cfg = cfg.nn.data
dataset = get_text_dataset(
    dataset_name=data_cfg.train_dataset,
    tokenizer_name=cfg.nn.module.model.model_name,
    train_split_ratio_for_val=data_cfg.splits_pct.val,
    max_seq_length=data_cfg.max_seq_length,
    batch_size=data_cfg.batch_size.train,
    num_workers=data_cfg.num_workers.train,
)

# The head is frozen before training (the model summary in the cell output
# lists its parameters as non-trainable).
model.freeze_head()

callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

# Root directory for trainer outputs; also reused later for saved checkpoints.
storage_dir: str = cfg.core.storage_dir

# Step budget for this notebook run. With this cap in place, fitting stops
# after FT_MAX_STEPS optimisation steps even if max_epochs has not been
# reached (the log of this run shows it stopping at step 100 of 10234), so
# the resulting weights are NOT a fully fine-tuned model.
# NOTE(review): to train for the configured number of epochs, disable the cap
# (Lightning documents -1 as "no step limit" — confirm for the installed version).
FT_MAX_STEPS = 100

pylogger.info("Instantiating the <Trainer>")
# cfg.train.trainer is splatted last: it must not define any keyword that is
# also passed explicitly here, otherwise Python raises a TypeError for the
# duplicated keyword argument.
trainer = pl.Trainer(
    default_root_dir=storage_dir,
    plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
    max_epochs=cfg.nn.data.dataset.ft_epochs,
    max_steps=FT_MAX_STEPS,
    logger=logger,
    callbacks=callbacks,
    **cfg.train.trainer,
)

pylogger.info(f"Starting fine-tuning on {cfg.ft_on_data_split} data split!")

# Only the "train" and "val" splits are supported; reject anything else
# before touching the loaders.
if cfg.ft_on_data_split not in ("train", "val"):
    raise ValueError(f"Unknown data split to fine-tune on: {cfg.ft_on_data_split}. Possible values: \"train\" or \"val\"")

# Pick the loader for the requested split (only that loader is accessed).
ft_dataloader = dataset.train_loader if cfg.ft_on_data_split == "train" else dataset.val_loader

pylogger.info("Starting training!")
# ckpt_path comes from the template core, which was built from the optional
# restore config.
trainer.fit(model=model, train_dataloaders=ft_dataloader, ckpt_path=template_core.trainer_ckpt_path)

pylogger.info("Starting testing!")
# Evaluation always runs on the test loader, whichever split was trained on.
trainer.test(model=model, dataloaders=dataset.test_loader)

# Bookkeeping describing the fine-tuned encoder.
model_class = get_class(text_encoder)
metadata = dict(model_name=cfg.nn.module.model.model_name, model_class=model_class)

# Close the logger's experiment if a logger exists.
if logger is not None:
    logger.experiment.finish()

# Write the trainer checkpoint into the storage directory.
final_ckpt_path = os.path.join(storage_dir, "final_model.ckpt")
trainer.save_checkpoint(final_ckpt_path)

Global seed set to 1608637542


Loading ViT-B-16 pre-trained weights.


  rank_zero_warn(


  rank_zero_warn(
  rank_zero_warn(
Map: 100%|██████████| 40430/40430 [00:03<00:00, 12114.53 examples/s]
Map: 100%|██████████| 390965/390965 [00:29<00:00, 13307.55 examples/s]
Map: 100%|██████████| 40430/40430 [00:01<00:00, 31212.66 examples/s]
Map: 100%|██████████| 390965/390965 [00:11<00:00, 33809.86 examples/s]


INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type                   | Params
---------------------------------------------------------------
0 | train_acc           | MulticlassAccuracy     | 0     
1 | val_acc             | MulticlassAccuracy     | 0     
2 | test_acc            | MulticlassAccuracy     | 0     
3 | encoder             | ClipTextEncoder        | 149 M 
4 | classification_head | TextClassificationHead | 1.0 K 
---------------------------------------------------------------
149 M     Trainable params
1.0 K     Non-trainable params
149 M     Total params
598.487   Total estimated model params size (MB)


Epoch 0:   1%|          | 100/10234 [00:06<10:54, 15.48it/s, v_num=vo1w]

INFO: `Trainer.fit` stopped: `max_steps=100` reached.


Epoch 0:   1%|          | 100/10234 [01:56<3:15:58,  1.16s/it, v_num=vo1w]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 1264/1264 [00:21<00:00, 58.56it/s]


0,1
acc/test,▁
acc/train_epoch,▁
epoch,▁▁▁█
grad_norm_batch,▁█
grad_norm_epoch,▁
loss/test,▁
loss/train_epoch,▁
loss/train_step,▁█
lr-SGD,▁▁
trainer/global_step,▁▁████

0,1
acc/test,0.64972
acc/train_epoch,0.61969
epoch,1.0
grad_norm_batch,5.49674
grad_norm_epoch,5.54338
loss/test,0.63923
loss/train_epoch,0.66009
loss/train_step,0.71494
lr-SGD,0.001
trainer/global_step,100.0


In [30]:
# Store the encoder and the classification head as two separate state-dict
# files inside the storage directory.
encoder_ckpt, head_ckpt = (
    os.path.join(storage_dir, file_name) for file_name in ("encoder.pt", "head.pt")
)

torch.save(model.encoder.state_dict(), encoder_ckpt)
torch.save(model.classification_head.state_dict(), head_ckpt)

In [31]:
# Load the head weights onto the CPU. If the state dict was saved from a
# CUDA-resident model, torch.load would by default restore the tensors to
# that device, which fails on a CPU-only machine and allocates GPU memory
# for what is only an inspection of the keys.
head_state_dict = torch.load(head_ckpt, map_location="cpu")

In [32]:
from pprint import pprint

# Iterating a state dict yields its parameter names, so this prints the keys.
pprint(list(head_state_dict))

['classification_head.weight', 'classification_head.bias']
