In [None]:
from __future__ import annotations

import json
import random
import re
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.utils.data import DataLoader
from torch.optim import AdamW

import evaluate
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, get_scheduler

In [None]:
def seed_everything(seed: int):
    """
    Seed every random number generator this notebook can draw from.

    Args:
        seed: value used to seed Python's `random`, NumPy's global generator
            and PyTorch (CPU and all CUDA devices)
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # `manual_seed_all` covers every CUDA device, not only the current one;
    # it is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

seed_everything(22)

In [None]:
# Show floats in DataFrames with three decimal places.
pd.set_option('display.float_format', '{:.3f}'.format)

# Prefer the Apple-silicon MPS backend when present, otherwise stay on CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [None]:
@dataclass
class SlotFillingConfig:
    """
    Label vocabulary for the slot-filling (token classification) task.

    All values are plain class attributes derived from `name_to_tag`:
    slot names map to short tags, and each tag yields a `B-`/`I-` label pair
    next to the single null label.
    """
    # Human-readable slot name -> short tag used in annotations and labels.
    name_to_tag = dict(category='CAT', brand='BRAND', model='MODEL', price='PRICE', rating='RAT')
    # Reverse lookup: short tag -> slot name.
    tag_to_name = dict((tag, name) for name, tag in name_to_tag.items())
    tags = [*tag_to_name]
    # Label assigned to tokens outside every slot.
    null_label = 'O'
    # The null label sorts first (key '1'); the rest are ordered by the first
    # letter of the tag and then by the B/I prefix, so each B-X is directly
    # followed by its I-X.
    labels = sorted(
        [null_label, *(f'B-{tag}' for tag in tags), *(f'I-{tag}' for tag in tags)],
        key=lambda label: '1' if len(label) <= 2 else label[2] + label[0],
    )
    id_to_label = {index: label for index, label in enumerate(labels)}
    label_to_id = {label: index for index, label in id_to_label.items()}
    num_labels = len(id_to_label)

config = SlotFillingConfig()

In [None]:
# Pretrained checkpoint shared by the tokenizer and the model.
checkpoint = 'google-bert/bert-base-multilingual-cased'

# Every text is padded/truncated to exactly 64 tokens and returned as PyTorch tensors.
tokenizer_kwargs = {
    'return_tensors': 'pt',
    'max_length': 64,
    'truncation': True,
    'padding': 'max_length',
}
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Alternation with one capture group per slot tag: '(CAT)|(BRAND)|...'.
pattern = r'|'.join(fr'({tag})' for tag in config.tags)


def _read_queries(path):
    """
    Read a raw query file and return its non-blank lines.

    Blank lines (typically the empty string left after a trailing newline)
    are dropped because the parsing loops cannot match them.
    """
    # NOTE(review): assumes the raw files are UTF-8 — confirm; the locale default is not portable.
    raw_text = Path(path).read_text(encoding='utf-8')
    return [line for line in raw_text.split('\n') if line.strip()]


queries_product_search = _read_queries('raw_data/product_search.txt')
queries_product_info = _read_queries('raw_data/product_info.txt')

slot_filling_dataset = []           # text, tokens, labels
retrieval_dataset = []              # text, product
intent_classification_dataset = []  # text, intent

def label_data(data, verbose=True):
    """
    Create one labeled, tokenized example from an annotated query.

    Slots are annotated inline as ``(value)TAG`` where ``TAG`` is one of
    `config.tags`. The annotation markup is removed to obtain the plain text,
    the text is tokenized, and the annotated string is then walked in parallel
    with the tokens so every token inside an annotation receives its tag.
    Finally flat tags are converted to BIO labels.

    Args:
        data: annotated query string
        verbose: when True (default) print the step-by-step alignment trace
    Returns:
        tuple (text, tokens, labels, slots) where `text` is the query without
        annotations, `tokens` and `labels` are equally long lists (padding
        tokens included) and `slots` maps slot names to their annotated values
    Raises:
        ValueError: if an annotation is malformed (no closing parenthesis or
            no known tag after it) or a token does not match the annotated text
    """
    log = print if verbose else (lambda *args, **kwargs: None)

    def surface(token):
        # Only a leading '##' is WordPiece continuation markup; any other '#'
        # is a real character of the text and must be counted.
        return token[2:] if len(token) > 2 and token.startswith('##') else token

    text = re.sub(r'\((.+?)\)[A-Z]+', r'\1', data)
    token_ids = tokenizer(text, **tokenizer_kwargs)['input_ids'][0]
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    labels = [config.null_label for _ in range(len(tokens))]
    slots = {}
    matched_tag = config.null_label
    i_data = 0
    i_tokens = 0
    while i_data < len(data) and i_tokens < len(tokens):
        data_char = data[i_data]
        token = tokens[i_tokens]
        if token in ('[CLS]', '[SEP]'):
            log(f'Technical token {token!r}')
            i_tokens += 1
        elif data_char == ' ':
            log(f'Space')
            i_data += 1
        elif data_char == '(':
            log(f'Opening parenthesis. Matching tag')
            i_data += 1
            data_rem = data[i_data:]
            i_clos_par = data_rem.find(')')
            if i_clos_par == -1:
                raise ValueError(f'Annotation without closing parenthesis in {data!r}')
            value = data_rem[:i_clos_par]
            # The tag name directly follows ')'; the longest tag has 5 characters.
            found = re.findall(pattern, data_rem[i_clos_par + 1 : i_clos_par + 6])
            if not found:
                raise ValueError(f'No known tag after annotation {value!r} in {data!r}')
            matched_tag = [tag for tag in found[0] if tag != ''][0]
            slots[config.tag_to_name[matched_tag]] = value
            log(f'Found tag {matched_tag!r} with value {value!r}')
        elif data_char == ')':
            log(f'Closing parenthesis. Resetting tag to {config.null_label!r}')
            # Skip the ')' itself and the tag name that follows it.
            i_data += len(matched_tag)
            i_data += 1
            matched_tag = config.null_label
        elif matched_tag != config.null_label:
            log(f'Label token {token!r} as {matched_tag!r}')
            labels[i_tokens] = matched_tag
            i_data += len(surface(token))
            i_tokens += 1
        else:
            piece = surface(token)
            log(f'Skipping chars \'', end='')
            # Echo the continuation prefix (if any) so the trace shows the full token.
            log(token[:len(token) - len(piece)], end='')
            for token_char in piece:
                data_char = data[i_data]
                if token_char != data_char:
                    raise ValueError(f'Token char {token_char!r} not equal to data char {data_char!r}')
                log(f'{token_char}', end='')
                i_data += 1
            log(f'\'\nToken end')
            i_tokens += 1
    # Convert flat tags to BIO: a tagged token starts a span ('B-') when the
    # previous token carried a different flat tag, otherwise it continues it ('I-').
    # Position 0 is left as is.
    flat_labels = list(labels)
    for i_label in range(1, len(labels)):
        current = flat_labels[i_label]
        if current == config.null_label:
            continue
        prefix = 'I' if flat_labels[i_label - 1] == current else 'B'
        labels[i_label] = f'{prefix}-{current}'
    return text, tokens, labels, slots

def _add_query(data, price, intent, slot_names, product=None):
    """
    Label one annotated query and append it to the three datasets.

    Args:
        data: annotated query string passed to `label_data`
        price: integer price parsed from the raw line
        intent: intent label stored in the intent classification dataset
        slot_names: slot names copied into the retrieval item when present
        product: product string for the retrieval item; omitted when None
    """
    text, tokens, labels, slots = label_data(data)
    slot_filling_dataset.append({
        'raw_data': data,
        'text': text,
        'tokens': tokens,
        'labels': labels,
    })
    retrieval_item = {'raw_data': data, 'text': text}
    if product is not None:
        retrieval_item['product'] = product
    retrieval_item['price'] = price
    for tag_name in slot_names:
        if tag_name in slots:
            retrieval_item[tag_name] = slots[tag_name]
    retrieval_dataset.append(retrieval_item)
    intent_classification_dataset.append({
        'text': text,
        'label': intent,
    })


# product_search lines look like '<annotated query>; <product>; <price>'.
for query in queries_product_search:
    if not query.strip():
        continue  # blank line, e.g. left over from a trailing newline
    print(f'{query=!r}')
    found = re.findall(r'(.+?); (.+?); (\d+)', query)
    if not found:
        raise ValueError(f'Malformed product_search line: {query!r}')
    data, product, price = found[0]
    _add_query(data, int(price), 'product_search', ['category', 'brand', 'model'], product=product)

# product_info lines look like '<annotated query>; <price>'.
for query in queries_product_info:
    if not query.strip():
        continue  # blank line, e.g. left over from a trailing newline
    print(f'{query=}')
    found = re.findall(r'(.+?); (\d+)', query)
    if not found:
        raise ValueError(f'Malformed product_info line: {query!r}')
    data, price = found[0]
    _add_query(data, int(price), 'product_info', ['brand', 'model'])

In [None]:
# Persist the slot-filling examples as a single JSON document.
Path('data/slot_filling_dataset.json').write_text(json.dumps(slot_filling_dataset))

In [None]:
# Persist the retrieval examples as a single JSON document.
Path('data/retrieval_dataset.json').write_text(json.dumps(retrieval_dataset))

In [None]:
# Persist the intent classification examples as a single JSON document.
Path('data/intent_classification_dataset.json').write_text(json.dumps(intent_classification_dataset))

In [None]:
# Keep only 'text' and 'labels', then hold out 5% of the examples for evaluation.
dataset = (
    Dataset.from_list(slot_filling_dataset)
    .remove_columns(['raw_data', 'tokens'])
    .train_test_split(test_size=0.05)
)

def preprocess_text(batch: list[str]):
    """Tokenize a batch of texts with the shared `tokenizer_kwargs`."""
    encoded = tokenizer(batch, **tokenizer_kwargs)
    return encoded

# Replace the raw 'text' column with the tokenizer outputs.
dataset = dataset.map(lambda rows: preprocess_text(rows['text']), batched=True, remove_columns=['text'])

def preprocess_labels(batch: list[list[str]]):
    """Convert a batch of label-string sequences into an integer id tensor."""
    label_ids = []
    for sequence in batch:
        label_ids.append([config.label_to_id[label] for label in sequence])
    return {'labels': torch.tensor(label_ids, dtype=torch.int)}

# Replace the string 'labels' column with the integer ids.
dataset = dataset.map(lambda rows: preprocess_labels(rows['labels']), batched=True, remove_columns=['labels'])
dataset

In [None]:
# Have the dataset return PyTorch tensors when indexed.
dataset.set_format(type='torch')

In [None]:
# Only the training split is shuffled; both splits use batches of 8.
train_dataloader = DataLoader(dataset['train'], batch_size=8, shuffle=True)
eval_dataloader = DataLoader(dataset['test'], batch_size=8)

In [None]:
# Token-classification head sized to the BIO label set, moved to the selected device.
model = AutoModelForTokenClassification.from_pretrained(checkpoint, num_labels=config.num_labels).to(device)

In [None]:
# Sanity check: list every parameter that will NOT receive gradients.
frozen_parameter_names = [name for name, param in model.named_parameters() if not param.requires_grad]
for name in frozen_parameter_names:
    print(name)

In [None]:
# Optimize every model parameter with AdamW.
optimizer = AdamW(params=model.parameters(), lr=5e-5)

In [None]:
# Every label weighs 1.0 except the null label, which is down-weighted to 0.1.
class_weight = torch.ones(config.num_labels, dtype=torch.float).to(device)
null_label_id = config.label_to_id[config.null_label]
class_weight[null_label_id] = 0.1

cross_entropy = nn.CrossEntropyLoss(weight=class_weight)

In [None]:
num_epochs = 5
# One scheduler step per optimizer step, i.e. per training batch.
num_training_steps = len(train_dataloader) * num_epochs
# Linear schedule without warmup.
lr_scheduler = get_scheduler(
    name='linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Move every tensor of the batch to the training device.
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        outputs = model(**batch)

        logits = outputs['logits']
        labels = batch['labels']

        # CrossEntropyLoss wants the class dimension second, so the last two
        # logits axes are swapped before computing the weighted loss.
        class_first_logits = logits.transpose(-1, -2)
        loss = cross_entropy(class_first_logits, labels)
        loss.backward()

        # Apply the update, advance the schedule, then clear gradients for the next batch.
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

In [None]:
# Token-level accuracy on the held-out split, with a per-example table of
# token / gold label / predicted label for manual inspection.
metric = evaluate.load('accuracy')
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    batch_predictions = torch.argmax(logits, dim=-1)
    for token_ids, labels, predictions in zip(batch['input_ids'], batch['labels'], batch_predictions):
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        # Padding is dropped here; `zip` below then truncates labels and
        # predictions to the remaining tokens.
        tokens = [token for token in tokens if token != '[PAD]']
        labels = [config.id_to_label[label] for label in labels.tolist()]
        predictions = [config.id_to_label[label] for label in predictions.tolist()]
        result = pd.DataFrame(zip(tokens, labels, predictions), columns=['token', 'label', 'prediction'])
        pprint(result)

    # Score only real tokens: every sequence is padded to a fixed length, and
    # counting the padded positions would inflate accuracy.
    # NOTE(review): assumes attention_mask is 1 for real tokens and 0 for padding — confirm.
    real_positions = batch['attention_mask'].bool()
    metric.add_batch(
        predictions=batch_predictions[real_positions],
        references=batch['labels'][real_positions],
    )

metric.compute()

In [None]:
# Save the fine-tuned model so the inference cells can reload it.
model.save_pretrained(save_directory='models/slot_filling')

In [None]:
# Inference setup: same tokenizer settings as used for training.
tokenizer_kwargs = {
    'return_tensors': 'pt',
    'max_length': 64,
    'truncation': True,
    'padding': 'max_length',
}
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-multilingual-cased')

def preprocess_text(batch: list[str]):
    """Tokenize a batch of texts and move every resulting tensor to `device`."""
    encoded = tokenizer(batch, **tokenizer_kwargs)
    return {name: tensor.to(device) for name, tensor in encoded.items()}

In [None]:
# Reload the fine-tuned weights from disk and place them on the selected device.
model = AutoModelForTokenClassification.from_pretrained('models/slot_filling')
model = model.to(device)

In [None]:
# query = 'Что включает Microsoft Surface Laptop 2?'
query = 'Мне нужен смартфон Xiaomi до 15000, который имеет рейтинг 4.5.'

# Tokenize the query as a batch of one.
tokens_ids = preprocess_text([query])
# Token strings of that single row, padding included.
tokens = tokenizer.convert_ids_to_tokens(tokens_ids['input_ids'][0])

In [None]:
# Predict a label for every token of the query.
with torch.no_grad():
    outputs = model(**tokens_ids)
logits = outputs.logits[0]
predictions = torch.argmax(logits, dim=-1)
predictions = [config.id_to_label[label] for label in predictions.tolist()]
tokens_and_labels = [(token, label) for token, label in zip(tokens, predictions) if token != '[PAD]']

# Group tokens into slots. Only the first span of each tag is kept: a span
# starts at 'B-<tag>' and is extended by directly following 'I-<tag>' tokens.
# A single pass always terminates, also on repeated 'B-' tags, on 'I-' labels
# without a preceding 'B-', and on spans that reach the last token.
tags = {}
open_tag = None  # tag of the span currently being extended, if any
for token, label in tokens_and_labels:
    if label.startswith('B-'):
        tag = label[2:]
        if tag in tags:
            # Later occurrence of an already extracted tag: ignore the whole span.
            open_tag = None
        else:
            tags[tag] = [token]
            open_tag = tag
    elif open_tag is not None and label == f'I-{open_tag}':
        tags[open_tag].append(token)
    else:
        # Null label, or an 'I-' label that does not continue the open span.
        open_tag = None

# Turn each slot's tokens back into a plain string.
tags = {tag: tokenizer.decode(tokenizer.convert_tokens_to_ids(tokens)) for tag, tokens in tags.items()}
tags