# Converting saved checkpoints to state-dict files for Kaggle upload

In [41]:
from forgebox.imports import *
from transformers import AutoModel, AutoTokenizer, AutoConfig
from tqdm.notebook import tqdm

In [42]:
# Root folder that contains one sub-folder per saved run.
# NOTE(review): machine-specific absolute path; adjust it before running elsewhere.
WEIGHTS = Path("/nvme/GCI/public/lit/weights")

In [43]:
# Pretrained model identifier; save_configs() uses it to load the tokenizer and config.
tag = "roberta-base"

In [44]:
WEIGHTS.ls()

['rblg_att_head_0725_212112',
 'pre_rbt_bs',
 'pre_rbtlg',
 'MeanPooler_aug_28_171839',
 'bert_cased_0724_193954',
 'CLSMeanLater_26_155044',
 'rbt_lg_0725_155121',
 'MeanPooler_3e_28_154350',
 'rblg_mean_0725_234256',
 'MeanPooler_3e_28_141004',
 'MeanPooler_base2_28_235651',
 'WithAttnHead_base2_28_234832',
 'CLSReg_base2_29_225830',
 'pre_rbt_bs2',
 'CLSReg_bs2pre_29_235936',
 'rbt_combined_v1',
 'rbt_lg_finer_0725_160941',
 'MeanLater_26_155703',
 'MeanPooler_ne_27_170538',
 'CLSReg_26_185501',
 'MeanPooler_base2_28_233227',
 'WithAttnHead_bs2pre_29_150559',
 'rblg_mean_0725_214714',
 'bert_buc_0724_194410',
 'MeanPooler_cyclr_27_232944',
 'MeanPooler_3e_28_161158',
 'pre_rbt_lg',
 'PairCLS_bs2pre_30_113040',
 'bert_buc_0725_002308',
 'MeanPooler_ne_27_151237',
 'pre_rbt_lg2',
 'CLSReg_ft_27_111118',
 'CLSReg_bs2pre_29_170744',
 'WithAttnHead_base2_29_140328',
 'MeanPooler_bs2pre_upk_30_105624',
 'CLSReg_base2_29_232842',
 'MeanPooler_3e_28_180243',
 'pre_rbtbs']

In [45]:
# Run folder processed by the rest of this notebook (conversion, cleanup, upload).
SAVE = WEIGHTS/"PairCLS_bs2pre_30_113040"

!ls {SAVE}

epoch=4-val_loss=0.25fd2.ckpt  epoch=6-val_loss=0.27fd0.ckpt
epoch=4-val_loss=0.31fd4.ckpt  epoch=7-val_loss=0.32fd3.ckpt
epoch=5-val_loss=0.26fd1.ckpt


In [46]:
def save_configs(tag, save):
    """
    Load the pretrained config and tokenizer for ``tag`` and store both under ``save``.

    The tokenizer is written to ``save/"tokenizer"`` and the config to
    ``save/"config"``.

    Parameters
    ----------
    tag : identifier handed to ``AutoConfig.from_pretrained`` and
        ``AutoTokenizer.from_pretrained``.
    save : path-like supporting the ``/`` operator (e.g. ``pathlib.Path``).

    Returns
    -------
    tuple
        ``(tokenizer, config)``, in that order.
    """
    loaded_config = AutoConfig.from_pretrained(tag)
    loaded_tokenizer = AutoTokenizer.from_pretrained(tag)

    # Each artifact goes into its own sub-folder; tokenizer first, then config.
    for artifact, folder in ((loaded_tokenizer, "tokenizer"), (loaded_config, "config")):
        artifact.save_pretrained(save/folder)

    return loaded_tokenizer, loaded_config

In [47]:
tokenizer, config = save_configs(tag, SAVE)

In [48]:
class AttentionHead(nn.Module):
    """
    Attention pooling: collapses dimension 1 of the input into a weighted sum.

    A score is computed for every position along dim 1 as
    ``V(tanh(W(features)))``, the scores are soft-maxed over dim 1, and the
    input is summed over dim 1 using those weights.

    Parameters
    ----------
    in_features : size of the last dimension of the input.
    hidden_dim : width of the intermediate projection ``W``.
    """

    def __init__(self, in_features, hidden_dim):
        super().__init__()
        self.in_features, self.middle_features = in_features, hidden_dim

        # Creation order (W, then V) is kept so state_dict keys and seeded init stay the same.
        self.W = nn.Linear(in_features, hidden_dim)
        self.V = nn.Linear(hidden_dim, 1)
        # NOTE(review): recorded as hidden_dim, although forward() returns vectors
        # with the feature size of its input.
        self.out_features = hidden_dim

    def forward(self, features):
        """Return ``features`` pooled over dim 1 with learned attention weights."""
        # One scalar score per position (last dimension has size 1).
        scores = self.V(torch.tanh(self.W(features)))
        # Normalise across dim 1 (presumably the sequence axis; confirm with caller).
        weights = torch.softmax(scores, dim=1)
        return (weights * features).sum(dim=1)

class LitLM(nn.Module):
    """
    Regression model: a base encoder followed by a single linear output unit.

    Parameters
    ----------
    base : module exposing ``config.hidden_size`` and returning, when called,
        an object with a ``last_hidden_state`` tensor.
    """

    def __init__(self, base):
        super().__init__()
        self.base = base
        # Not used inside this class; presumably read by training code elsewhere.
        self.learning_rate = 1e-3
        self.config = self.base.config
        self.reg = nn.Linear(base.config.hidden_size, 1)
        # Not used inside this class; presumably the training loss.
        self.crit = nn.MSELoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the regression layer to position 0 (along dim 1) of the base output."""
        # Position 0 is presumably the CLS/<s> token; confirm against the tokenizer.
        vec = self.base(x).last_hidden_state[:, 0, :]
        return self.reg(vec)

    def inference(
        lit_model, model_weights: Path, data_loader: DataLoader, filename: str
    ) -> None:
        """
        Load weights, predict every batch of ``data_loader`` and write a CSV.

        ``lit_model`` plays the role of ``self`` (the name is kept for
        backward compatibility with keyword callers).

        Parameters
        ----------
        model_weights : path of a file readable by ``torch.load`` that holds a
            state dict for this model.
        data_loader : iterable of batches; each batch is indexed with ``"id"``
            and ``"excerpt"``, the latter being a tensor fed to the model.
        filename : destination CSV, written with columns ``id`` and ``target``
            and without the index.

        Notes
        -----
        * Runs on CUDA when available and falls back to CPU otherwise.
        * Weights are loaded with ``strict=False``, so missing or unexpected
          keys are ignored silently.
        * The model is moved back to CPU afterwards, even if a batch fails.
        * An empty ``data_loader`` yields a CSV that contains only the header.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load onto CPU first so weights saved from a GPU also open on CPU-only hosts.
        state = torch.load(str(model_weights), map_location="cpu")
        lit_model.load_state_dict(state, strict=False)
        lit_model = lit_model.eval().to(device)

        frames = []
        try:
            with torch.no_grad():
                # Iterating the loader directly lets tqdm pick up its length, if it has one.
                for batch in tqdm(data_loader, leave=False):
                    ids = batch["id"]
                    x = batch["excerpt"].to(device)
                    # Column 0 of the output is the single regression value per row.
                    y_ = lit_model(x)[:, 0].detach().cpu().numpy()
                    frames.append(pd.DataFrame({"id": ids, "target": y_}))
        finally:
            # Always release the accelerator, including on failure.
            lit_model.cpu()

        if frames:
            submission = pd.concat(frames)
        else:
            # pd.concat([]) raises; emit a header-only file instead.
            submission = pd.DataFrame({"id": [], "target": []})
        submission.to_csv(filename, index=False)
        
class LitLMAttn(nn.Module):
    """
    Regression model: base encoder -> attention pooling -> dropout -> linear unit.

    Parameters
    ----------
    base : module exposing ``config.hidden_size`` and returning, when called,
        an object with a ``last_hidden_state`` tensor.
    """

    def __init__(self, base):
        super().__init__()
        self.base = base
        self.config = base.config
        # Not used inside this class; presumably read by training code elsewhere.
        self.learning_rate = 1e-4

        width = self.config.hidden_size
        # Sub-modules are registered in the order head, dout, reg, crit.
        self.head = AttentionHead(width, width)
        self.dout = nn.Dropout(.1)
        self.reg = nn.Linear(width, 1)
        # Not used inside this class; presumably the training loss.
        self.crit = nn.MSELoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool the base model's full ``last_hidden_state`` and regress one value per row."""
        hidden_states = self.base(x).last_hidden_state
        pooled = self.head(hidden_states)
        return self.reg(self.dout(pooled))

In [49]:
def recur_dir(path:Path):
    """
    Collect every non-directory entry found under ``path``, descending into sub-directories.

    Entries come out in ``iterdir()`` order; a sub-directory's contents appear
    at the position where that sub-directory was encountered. Anything for
    which ``is_dir()`` is true is descended into rather than returned.

    Returns
    -------
    list
        Paths of all non-directory entries (empty list for an empty folder).
    """
    def _walk(folder):
        # Depth-first traversal that yields leaves as they are met.
        for entry in folder.iterdir():
            if not entry.is_dir():
                yield entry
            else:
                yield from _walk(entry)

    return list(_walk(path))

In [50]:
# Re-save the 'state_dict' entry of every .ckpt file under SAVE as a stand-alone file.
ckpt_files = [found for found in recur_dir(SAVE) if found.name.endswith(".ckpt")]
for ckpt in tqdm(ckpt_files):
    # Load onto CPU and keep only the weights stored under 'state_dict'.
    state = torch.load(str(ckpt), map_location='cpu')['state_dict']
    # Written next to the source file, with '=' -> '-' and '.ckpt' -> '.h5' in the name.
    h5_name = ckpt.name.replace("=","-").replace(".ckpt",".h5")
    torch.save(state, str(ckpt.parent/h5_name))

  0%|          | 0/5 [00:00<?, ?it/s]

In [34]:
# LitLMAttn(AutoModel.from_config(config)).load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])

In [35]:
# Sanity check: the checkpoint's 'state_dict' loads into a freshly built LitLM.
# Relies on `ckpt` left over from the conversion loop and on that .ckpt file
# still being on disk, so run it before the cell that deletes the .ckpt files.
LitLM(AutoModel.from_config(config)).load_state_dict(torch.load(str(ckpt), map_location='cpu')['state_dict'])

<All keys matched successfully>

In [51]:
# Delete the original .ckpt files, but only those whose converted .h5 copy
# already exists, so a skipped or failed conversion never costs the weights.
# Uses Path.unlink() instead of shelling out to `rm`, which breaks on paths
# containing spaces or shell metacharacters.
for i in recur_dir(SAVE):
    if i.name.endswith(".ckpt"):
        # Same naming rule as the conversion step: '=' -> '-' and '.ckpt' -> '.h5'.
        converted = i.parent/(i.name.replace("=","-").replace(".ckpt",".h5"))
        if converted.is_file():
            i.unlink()
        else:
            print(f"keeping {i}: converted file {converted} not found")
        
!ls {SAVE}

config			     epoch-5-val_loss-0.26fd1.h5  tokenizer
epoch-4-val_loss-0.25fd2.h5  epoch-6-val_loss-0.27fd0.h5
epoch-4-val_loss-0.31fd4.h5  epoch-7-val_loss-0.32fd3.h5


In [52]:
!kaggle datasets init -p {SAVE}

Data package template written to: /nvme/GCI/public/lit/weights/PairCLS_bs2pre_30_113040/dataset-metadata.json


In [53]:
!cat {SAVE/"dataset-metadata.json"}

{
  "title": "INSERT_TITLE_HERE",
  "id": "raynardj/INSERT_SLUG_HERE",
  "licenses": [
    {
      "name": "CC0-1.0"
    }
  ]
}

In [54]:
%%writefile {SAVE/"dataset-metadata.json"}
{
  "title": "PR_bs2pre_30_113040",
  "id": "raynardj/pr-bs2pre-30-113040",
  "licenses": [
    {
      "name": "CC0-1.0"
    }
  ]
}

Overwriting /nvme/GCI/public/lit/weights/PairCLS_bs2pre_30_113040/dataset-metadata.json


In [55]:
!kaggle datasets create -r tar -u -p {SAVE}

Starting upload for file epoch-4-val_loss-0.31fd4.h5
100%|████████████████████████████████████████| 476M/476M [07:07<00:00, 1.17MB/s]
Upload successful: epoch-4-val_loss-0.31fd4.h5 (476MB)
Starting upload for file tokenizer.tar
100%|███████████████████████████████████████| 2.50M/2.50M [00:05<00:00, 445kB/s]
Upload successful: tokenizer.tar (2MB)
Starting upload for file config.tar
100%|██████████████████████████████████████| 10.0k/10.0k [00:05<00:00, 1.94kB/s]
Upload successful: config.tar (10KB)
Starting upload for file epoch-5-val_loss-0.26fd1.h5
100%|████████████████████████████████████████| 476M/476M [07:10<00:00, 1.16MB/s]
Upload successful: epoch-5-val_loss-0.26fd1.h5 (476MB)
Starting upload for file epoch-4-val_loss-0.25fd2.h5
100%|████████████████████████████████████████| 476M/476M [07:20<00:00, 1.13MB/s]
Upload successful: epoch-4-val_loss-0.25fd2.h5 (476MB)
Starting upload for file epoch-7-val_loss-0.32fd3.h5
100%|████████████████████████████████████████| 476M/476M [07:14<00:

In [62]:
!ls -R {SAVE}

/nvme/GCI/public/lit/weights/bert_0724_231826:
config			      epoch-15-val_loss-0.30fd2.h5
dataset-metadata.json	      epoch-16-val_loss-0.28fd0.h5
epoch-10-val_loss-0.27fd4.h5  epoch-28-val_loss-0.29fd3.h5
epoch-10-val_loss-0.31fd1.h5  tokenizer

/nvme/GCI/public/lit/weights/bert_0724_231826/config:
config.json

/nvme/GCI/public/lit/weights/bert_0724_231826/tokenizer:
special_tokens_map.json  tokenizer_config.json	tokenizer.json	vocab.txt


In [12]:
def make_sub(SAVE):
    """
    Copy ``SAVE/dataset-metadata.json`` into every symlinked sub-directory of ``SAVE``.

    Parameters
    ----------
    SAVE : ``pathlib.Path`` of the folder that holds ``dataset-metadata.json``
        and the symlinks.

    Notes
    -----
    * Only entries that are symlinks resolving to directories receive a copy;
      regular sub-directories, files and symlinks to files are skipped.
    * The copy is done in Python rather than through a ``cp`` shell command, so
      paths containing spaces work and a failed copy raises instead of being
      ignored.
    """
    import shutil  # local import: no new file-level dependency for this cell

    metadata = SAVE/"dataset-metadata.json"
    for di in SAVE.iterdir():
        # is_dir() follows the link, so this keeps only links that point at folders.
        if di.is_symlink() and di.is_dir():
            shutil.copyfile(metadata, di/"dataset-metadata.json")

In [13]:
make_sub(SAVE)