In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

import json
import re

import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig

import sys
from eval_utils import get_parser, compute_test_metrics
from utils import Format, get_to_string_processor



In [2]:
# Load the experiment config (model name, output format, train/eval settings).
# json.load accepts a binary handle and detects the text encoding itself.
CONFIG_PATH = "configs/config_ruT5-base-st.json"
with open(CONFIG_PATH, "rb") as config_file:
    params = json.load(config_file)
params

{'format': 'SpecTokens',
 'max_bundles': 5,
 'model': 'ai-forever/ruT5-base',
 'add_nl_token': False,
 'add_eos_token': False,
 'change_pad_to_eos': False,
 'shuffle_bundles': True,
 'save_folder': 'ruT5-base',
 'train': {'n_epochs': 10,
  'lr': 5e-05,
  'batch_size': 16,
  'weight_decay': 0.01,
  'scheduler': 'cosine',
  'warmup_steps': 500,
  'fp16': True},
 'eval': {'batch_size': 16, 'show': 5}}

In [3]:
# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
device

'cuda:0'

In [4]:
# Map the config's "format" string onto the Format enum.
# NOTE(review): every value other than "SpecTokens" silently becomes JustJson,
# so a typo in the config would not be caught here — confirm that is intended.
if params["format"] == "SpecTokens":
    out_format = Format.SpecTokens
else:
    out_format = Format.JustJson
out_format

<Format.SpecTokens: 0>

In [5]:
# Test set of ads. pandas expands the leading "~" in the path.
# The file carries a saved index as an "Unnamed: 0" column; it is kept as-is so
# it round-trips into the predictions CSV written at the end.
TEST_DATA_PATH = "~/work/resources/data/ads_test_1000.csv"
data = pd.read_csv(TEST_DATA_PATH)
data.head()

Unnamed: 0,Text,bundles,n_bundles
0,"самокат hudora, в отличном состоянии, от 5+ и ...","[{""Title"": ""\u0441\u0430\u043c\u043e\u043a\u04...",1
1,2 мяча и корзина 5€ лимассол,"[{""Title"": ""\u043d\u0430\u0431\u043e\u0440 \u0...",1
2,принимаются предзаказы на 100% органическое ма...,"[{""Title"": ""100% \u043e\u0440\u0433\u0430\u043...",1
3,"колонки, в рабочем состоянии! использовались р...","[{""Title"": ""\u043a\u043e\u043b\u043e\u043d\u04...",1
4,гироскутер 100 евро с зарядным,"[{""Title"": ""\u0433\u0438\u0440\u043e\u0441\u04...",1


In [6]:
# Frequency of each value in the `n_bundles` column (bundles per ad), most common first.
bundle_count_freq = data["n_bundles"].value_counts()
bundle_count_freq

1     735
0      89
2      54
3      34
4      25
5      18
6      16
7      10
10      7
9       7
8       2
36      1
14      1
13      1
Name: n_bundles, dtype: int64

In [7]:
# Checkpoint folder name taken from the config. It is used to locate the model under
# ../good_checkpoints/ and to name the predictions CSV written at the end.
ckpt = params["save_folder"]

In [8]:
# Load the tokenizer and the seq2seq model from the same local checkpoint folder
# (tokenizer first, then model).
model_checkpoint = f"../good_checkpoints/{ckpt}"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_checkpoint),
    AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint),
)

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# Format-specific helpers from eval_utils / utils. Judging by their names they parse
# generated text and convert it to strings — TODO confirm; neither is used in this section.
parser, to_string_processor = (
    get_parser(tokenizer, out_format),
    get_to_string_processor(out_format),
)

In [12]:
# Generate one model response per row of `data["Text"]`, in batches.
#
# Settings come from the config where available instead of duplicated literals.
bs = params.get("eval", {}).get("batch_size", 16)
MAX_SOURCE_LEN = 512   # encoder inputs are truncated to this many tokens
MAX_TARGET_LEN = 512   # upper bound on the generated sequence length
NUM_BEAMS = 4

# Optionally append EOS to every input, mirroring the add_eos_token config flag.
eos_suffix = tokenizer.eos_token if params.get("add_eos_token", False) else ""
texts = [text + eos_suffix for text in data["Text"]]

model.to(device)
model.eval()  # from_pretrained already returns eval mode; made explicit for clarity

responses = []
for start in tqdm(range(0, len(texts), bs), desc="generating"):
    # Keep the full encoding so the attention mask reaches generate(); with
    # padding=True, the padded positions must be masked out of the encoder.
    batch = tokenizer(
        texts[start:start + bs],
        max_length=MAX_SOURCE_LEN,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    preds = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch.get("attention_mask"),
        max_length=MAX_TARGET_LEN,
        num_beams=NUM_BEAMS,
        early_stopping=True,
        eos_token_id=tokenizer.eos_token_id,
    ).cpu()

    # Defensive: map any -100 (label-padding id) to EOS so it can be decoded.
    preds = torch.where(preds == -100, tokenizer.eos_token_id, preds)
    # Special tokens are deliberately KEPT: the SpecTokens output format relies on
    # markers such as <BOB>/<BOT>. (The previous `ignore_special_tokens=True` was not
    # a real batch_decode parameter and was silently ignored.)
    decoded = tokenizer.batch_decode(preds, skip_special_tokens=False)
    # Strip the pad token literally: re.sub would treat it as a regex pattern
    # (e.g. "[PAD]" would become a character class).
    pad = tokenizer.pad_token
    responses.extend(pred.replace(pad, "") if pad else pred for pred in decoded)

assert len(responses) == len(data), "expected exactly one response per input row"

  0%|          | 0/63 [00:00<?, ?it/s]

In [13]:
# Attach the generated text as a "Responses" column, one entry per input row.
# pandas raises a ValueError if the list length does not match the row count.
data = data.assign(Responses=responses)
data.head()

Unnamed: 0,Text,bundles,n_bundles,Responses
0,"самокат hudora, в отличном состоянии, от 5+ и ...","[{""Title"": ""\u0441\u0430\u043c\u043e\u043a\u04...",1,<BOB> <BOT> самокат hudora <EOT> <BOP> 65 <EOP...
1,2 мяча и корзина 5€ лимассол,"[{""Title"": ""\u043d\u0430\u0431\u043e\u0440 \u0...",1,<BOB> <BOT> 2 мяча и корзина <EOT> <BOP> 5 <EO...
2,принимаются предзаказы на 100% органическое ма...,"[{""Title"": ""100% \u043e\u0440\u0433\u0430\u043...",1,<BOB> <BOT> Органическое масло миндаля 100 мл ...
3,"колонки, в рабочем состоянии! использовались р...","[{""Title"": ""\u043a\u043e\u043b\u043e\u043d\u04...",1,"<BOB> <BOT> колонки, в рабочем состоянии <EOT>..."
4,гироскутер 100 евро с зарядным,"[{""Title"": ""\u0433\u0438\u0440\u043e\u0441\u04...",1,<BOB> <BOT> гироскутер с зарядным устройством ...


In [14]:
# Save the input rows together with the generated "Responses" column.
# The DataFrame index is not written; pandas expands the leading "~".
preds_path = f"~/work/resources/bench_results/{ckpt}_preds.csv"
data.to_csv(preds_path, index=False)