In [2]:
import os
import torch 
import torch.nn.functional as F
import numpy as np
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf
from transformers import AutoTokenizer
from torch.distributions import Categorical
from tqdm import tqdm

from src.utils import registry
import src.utils as utils
from S4Models import S4LMModel, S4GlobalRanker

In [5]:
# Compose the experiment configuration from the "configs" directory.
# Only the compose() call needs Hydra's initialize() context to be active.
with initialize(version_base=None, config_path="configs"):
    config = compose(config_name="wiki_noncausal_lm_config.yaml")

# Dump the composed config as YAML for inspection, then turn off OmegaConf's
# struct flag on it (with the flag off, keys missing from the composed config
# can be read or assigned without raising).
print(OmegaConf.to_yaml(config))
OmegaConf.set_struct(config, False)

trainer:
  device: cuda
  accumulate_grad_batches: 16
  max_epochs: 1
  gradient_clip_val: 0.0
  log_every_n_steps: 1
  evaluation_step: 200
  fp16: true
  task: mlm
checkpoint:
  dirpath: checkpoints
  verbose: true
  path: null
dataset:
  _name_: wiki
  data: wikipedia
  subset: 20220301.en
  cache_dir: ../cache
  test_size: 0.0001
  tokenizer: bert-base-uncased
  batch_size: 2
  l_max: 4096
optimizer:
  _name_: adamw
  lr: 0.0001
  weight_decay: 0.0001
scheduler:
  _name_: cosine_warmup
  num_warmup_steps: 1000
  num_training_steps: 800000
embedding:
  rescale: true
  d_model: ${model.d_model}
  n_tokens: 30522
decoder:
  tied: false
  d_output: ${model.d_model}
model:
  layer:
  - _name_: s4
    d_state: 64
    l_max: ${dataset.l_max}
    postact: glu
    dropout: ${...dropout}
    lr: ${optimizer.lr}
    n_ssm: 128
    bidirectional: true
  - _name_: s4
    d_state: 64
    l_max: ${dataset.l_max}
    postact: glu
    dropout: ${...dropout}
    lr: ${optimizer.lr}
    n_ssm: 128
  

In [7]:
# Build the S4 language model from the composed config.
model = S4LMModel(config)

# Read the checkpoint with its tensors mapped onto the CPU.
# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source (consider weights_only=True if the
# installed torch version supports it).
state_dict = torch.load(
    'pytorch_model.bin',
    map_location='cpu',
)

# strict=True makes the load fail if the checkpoint keys and the model's own
# state_dict keys do not match exactly.
model.load_state_dict(state_dict, strict=True)
def count_parameters(model):
    """Return the number of trainable scalar values in ``model``.

    Iterates over ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is set; parameters with the flag
    cleared are left out of the total.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total
# Trainable parameter count expressed in millions (last expression, so the
# notebook displays it).
count_parameters(model) / 1_000_000

105.023232