In [None]:
from __future__ import annotations

import json
from pprint import pprint
from copy import deepcopy
from pathlib import Path

import numpy as np
import pandas as pd
from dataclasses import dataclass
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.utils.data import DataLoader
from torch.optim import AdamW

import evaluate
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler

In [None]:
def seed_everything(seed: int) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, PyTorch's
    default generator and the RNGs of all CUDA devices, so that a
    Restart & Run All reproduces the same results.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU explicitly; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

seed_everything(22)

In [None]:
pd.options.display.float_format = '{:.3f}'.format

# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [None]:
@dataclass
class IntentClassificationConfig:
    """Label schema of the intent classifier.

    None of the attributes below are annotated, so ``@dataclass`` generates no
    fields: they are plain class attributes shared by every instance.
    """

    # Ordered intent names; a label's index in this list is its integer class id.
    labels = [
        'product_search',
        'product_info',
        'order_status',
        'order_return',
        'operator',
        'payment',
        'authenticity',
    ]
    # Class id -> intent name.
    label_id_to_name = {idx: name for idx, name in enumerate(labels)}
    # Intent name -> class id (inverse of label_id_to_name).
    label_name_to_id = {name: idx for idx, name in enumerate(labels)}


config = IntentClassificationConfig()

In [None]:
# Pretrained checkpoint used for the tokenizer (and, further down, the classifier).
checkpoint = 'google-bert/bert-base-multilingual-cased'

# Shared tokenizer options: return PyTorch tensors, truncate and pad every sequence to exactly 64 tokens.
tokenizer_kwargs = {
    'return_tensors': 'pt',
    'max_length': 64,
    'truncation': True,
    'padding': 'max_length',
}
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Read with an explicit UTF-8 encoding: the queries may contain non-ASCII text
# (the model is multilingual), and the platform default encoding varies by OS.
with open('data/intent_classification_dataset.json', 'r', encoding='utf-8') as file:
    intent_classification_dataset: list = json.load(file)

In [None]:
# Merge queries from raw_data/<label>.txt (one query per line) into the loaded dataset
# for every label that does not start with 'product_'.
# The merge is idempotent: (text, label) pairs that are already present are skipped.
# This matters because the merged list is written back to the same JSON file afterwards,
# so a plain append would duplicate examples on every re-run.
# NOTE(review): a JSON file produced by earlier non-idempotent runs may already hold
# duplicates; this cell does not remove them.
seen_examples = {(example['text'], example['label']) for example in intent_classification_dataset}
for label in config.labels:
    if label.startswith('product_'):
        # No raw_data file is read for product_* intents.
        continue
    raw_text = Path(f'raw_data/{label}.txt').read_text(encoding='utf-8')
    # splitlines() + strip() avoid the empty "query" that split('\n') yields for a
    # trailing newline, and drop surrounding whitespace; blank lines are skipped.
    for text in (line.strip() for line in raw_text.splitlines()):
        if not text or (text, label) in seen_examples:
            continue
        seen_examples.add((text, label))
        intent_classification_dataset.append({
            'text': text,
            'label': label,
        })

In [None]:
# Overwrite the dataset file that was loaded earlier with the in-memory list,
# which now includes the queries merged from raw_data. json's default
# ensure_ascii=True keeps the written text pure ASCII.
Path('data/intent_classification_dataset.json').write_text(json.dumps(intent_classification_dataset))

In [None]:
dataset = Dataset.from_list(intent_classification_dataset)
# Hold out 5% of the examples for evaluation. The explicit seed (same value passed to
# seed_everything) makes the split reproducible even when this cell is re-run on its own;
# without it the split depends on whatever RNG state exists at call time.
dataset = dataset.train_test_split(test_size=0.05, seed=22)

def preprocess_text(batch: list[str]):
    """Tokenize a batch of raw query strings using the shared ``tokenizer_kwargs``."""
    return tokenizer(batch, **tokenizer_kwargs)

dataset = dataset.map(lambda item: preprocess_text(item['text']), remove_columns=['text'], batched=True)

def preprocess_label(batch: list[str]):
    """Convert a batch of intent names to integer class ids.

    Raises:
        KeyError: if a label is not listed in ``config.labels``.
    """
    return {'label': torch.tensor([config.label_name_to_id[label] for label in batch])}

# The function returns a new 'label' column, which replaces the removed string column.
dataset = dataset.map(lambda item: preprocess_label(item['label']), remove_columns=['label'], batched=True)

In [None]:
# Make both splits yield torch tensors when indexed (in place, on the DatasetDict).
dataset.set_format(type='torch')

In [None]:
# Only the training split is shuffled; evaluation batches keep the split's order.
train_dataloader = DataLoader(dataset['train'], batch_size=8, shuffle=True)
eval_dataloader = DataLoader(dataset['test'], batch_size=8, shuffle=False)

In [None]:
# Pass the label mappings so that the config written by save_pretrained maps class ids
# back to intent names (without them, generic LABEL_<i> names are stored).
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=len(config.labels),
    id2label=config.label_id_to_name,
    label2id=config.label_name_to_id,
)
model = model.to(device)

In [None]:
# AdamW over every model parameter with learning rate 5e-5; other hyperparameters use torch defaults.
optimizer = AdamW(params=model.parameters(), lr=5e-5)

In [None]:
num_epochs = 5
# One optimizer step per training batch, so the schedule spans batches-per-epoch * epochs steps.
num_training_steps = len(train_dataloader) * num_epochs
# Linear learning-rate decay over the whole run, with no warmup.
lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
cross_entropy = nn.CrossEntropyLoss()

model.train()
# NOTE: re-running this cell continues from the current weights and the already-advanced
# lr_scheduler (whose linear schedule will be at or near its end); re-run the model,
# optimizer and scheduler cells first for a fresh run.
# The context manager closes the progress bar even if training is interrupted.
with tqdm(total=num_training_steps) as progress_bar:
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # The model is called without labels; the loss is computed explicitly below.
            batch_fwd = {k: v for k, v in batch.items() if k != 'label'}
            outputs = model(**batch_fwd)

            logits = outputs.logits
            labels = batch['label']

            loss = cross_entropy(logits, labels)
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            # Show the current batch loss so divergence or stalls are visible while training.
            progress_bar.set_postfix(epoch=epoch, loss=f'{loss.item():.4f}')

In [None]:
metric = evaluate.load('accuracy')
model.eval()
# Gradients are never needed during evaluation, so the whole loop runs without autograd.
with torch.no_grad():
    for batch in eval_dataloader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        batch_fwd = {name: tensor for name, tensor in batch.items() if name != 'label'}
        outputs = model(**batch_fwd)

        logits = outputs.logits
        # Predicted class id = index of the largest logit.
        batch_predictions = logits.argmax(dim=-1)

        metric.add_batch(
            predictions=batch_predictions.view(-1),
            references=batch['label'].view(-1),
        )

metric.compute()

In [None]:
# Save the tokenizer next to the weights so the directory is self-contained and can be
# reloaded with from_pretrained for both the tokenizer and the model.
output_dir = 'models/intent_classification'
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)