In [16]:
import pickle
import sys
from functools import partial
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from transformers import (AutoConfig, AutoModelForCausalLM,
                          PreTrainedTokenizerFast)

sys.path.append('..')

from utils import get_slice, load_external_module, train_collate_fn

load_external_module('lming', '/main/draft-v2/lming/__init__.py')

from lming.utils import from_tensor

In [8]:
model_config_dir = '/main/draft-v2/pavel-tikhomirov-runs/gpt-2-fdim-2:v3/'
model_config = AutoConfig.from_pretrained(model_config_dir + 'config.json')
device = 'cuda'
model_weights_dir = '/main/draft-v2/pavel-tikhomirov-runs/wandb/latest-run/checkpoints/test/model_40500_completion_ratio_intersection_5=0.4269.pt'

model_weights = torch.load(model_weights_dir)

model = AutoModelForCausalLM.from_config(model_config).to(device)
model = torch.compile(model)
model.load_state_dict(model_weights)
model.eval()
None

  model_weights = torch.load(model_weights_dir)


In [10]:
with open(Path("/main/draft-v2/pavel-tikhomirov-runs/fdim-2-whitehead:v0") / "test.pkl", 'rb') as file:
    test_dataset = get_slice(pickle.load(file), None)

tokenizer = PreTrainedTokenizerFast.from_pretrained(model_config_dir)
inference_loader = DataLoader(
    test_dataset, batch_size=1,
    shuffle=False, collate_fn=partial(train_collate_fn, tokenizer=tokenizer, fdim=2)
    )

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [11]:
batch = next(iter(inference_loader))

In [12]:
generation_config = {
    "max_length": 50,
    "suppress_tokens": ["y", "n", ":", "<s>"],
    "num_return_sequences": 5,
    "do_sample": True
    }
generation_config['suppress_tokens'] = tokenizer.convert_tokens_to_ids(generation_config['suppress_tokens'])

In [13]:
result = model.generate(**batch.to(device), **generation_config)

In [17]:
decodes = from_tensor(result.cpu(), tokenizer=tokenizer)

In [18]:
decodes

[[1,
  -2,
  1,
  2,
  -1,
  -1,
  -2,
  1,
  1,
  1,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  -2,
  -1,
  2,
  -1,
  2,
  -1,
  -1,
  -1,
  -2,
  -1,
  -2,
  -1,
  -1,
  -1,
  2,
  1],
 [1,
  -2,
  1,
  2,
  -1,
  -1,
  2,
  1,
  2,
  -1,
  -1,
  -1,
  -2,
  -1,
  -2,
  1,
  -2,
  -1,
  2,
  -1,
  -1,
  -1,
  -2,
  -1,
  2,
  1,
  1,
  1,
  -2,
  -1,
  -1,
  2,
  1,
  1,
  1,
  -2,
  1,
  1,
  2,
  -1],
 [1,
  -2,
  1,
  2,
  -1,
  -2,
  -1,
  -1,
  -1,
  2,
  2,
  1,
  1,
  1,
  -2,
  1,
  -2,
  -1,
  2,
  -1,
  -2,
  1,
  1,
  1,
  2,
  -1,
  -1,
  -1,
  -1],
 [1,
  -2,
  1,
  2,
  -1,
  -1,
  -1,
  -2,
  1,
  -2,
  -1,
  -1,
  -1,
  2,
  -1,
  2,
  1,
  1,
  1,
  1,
  -2,
  -1,
  2,
  1,
  1,
  1,
  -2,
  -1,
  -1,
  2,
  1,
  -2,
  1,
  1,
  2,
  -1,
  -1,
  -1,
  -2,
  -1,
  -1,
  -1,
  -1,
  -1],
 [1,
  -2,
  1,
  2,
  1,
  2,
  2,
  -1,
  -1,
  -1,
  -2,
  -2,
  1,
  -2,
  -1,
  2,
  -1,
  2,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  -2,
  -2]]