# Libraries

In [10]:
%%capture
! pip install transformers
! pip install hazm
from hazm import *
import copy
import transformers
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, recall_score, roc_auc_score, precision_score
import math
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
import json
from copy import deepcopy
import numpy as np
import random
import re
import string
import codecs
from shutil import copyfile
random.seed(12345)
label_encoder = preprocessing.LabelEncoder()

# downloading persian stopwords
# ######removed

# Preprocessing class
class Preprocessing:
  """Stateless text-cleaning helpers.

  Every public method takes a string and returns a new string. Removed
  characters or tokens are replaced by a space (not deleted), so neighbouring
  words never get glued together; `remove_extra_space` collapses the
  resulting whitespace runs afterwards.
  """

  # Characters blanked out by remove_punctuations: ASCII punctuation plus the
  # guillemets (U+00AB, U+00BB) and the Arabic comma (U+060C), semicolon
  # (U+061B) and question mark (U+061F).
  _PUNCTUATION_CHARS = string.punctuation + '\u00AB' + '\u00BB' + '\u060C' + '\u061B' + '\u061F'

  # Characters blanked out by remove_numbers.
  _DIGIT_CHARS = '0123456789۰۱۲۳۴۵۶۷۸۹'

  # Normalized stopwords, loaded lazily on first use and then reused, so the
  # file is not re-read and re-normalized for every text.
  _stopwords_cache = None

  @staticmethod
  def remove_punctuations(text):
    """Return `text` with every punctuation character replaced by a space."""
    return ''.join(
      ' ' if ch in Preprocessing._PUNCTUATION_CHARS else ch for ch in text)

  @staticmethod
  def remove_numbers(text):
    """Return `text` with every digit in `_DIGIT_CHARS` replaced by a space."""
    return ''.join(
      ' ' if ch in Preprocessing._DIGIT_CHARS else ch for ch in text)

  @staticmethod
  def _load_stopwords():
    """Read 'stopwords.txt' (UTF-8, one entry per line) into a cached frozenset.

    Each line is stripped and passed through `Normalizer().normalize`.
    The file handle is closed deterministically by the `with` block.
    """
    if Preprocessing._stopwords_cache is None:
      normalizer = Normalizer()
      with codecs.open('stopwords.txt', 'r', 'utf-8') as handle:
        lines = handle.readlines()
      Preprocessing._stopwords_cache = frozenset(
        normalizer.normalize(line.strip()) for line in lines)
    return Preprocessing._stopwords_cache

  @staticmethod
  def remove_stopwords(text):
    """Tokenize `text` and replace every stopword token with a space.

    Tokens come from `word_tokenize`; the result is the tokens joined by a
    single space, with a ' ' placeholder wherever a stopword was dropped.

    Raises whatever `codecs.open` raises if 'stopwords.txt' is missing on
    the first call.
    """
    stopwords = Preprocessing._load_stopwords()
    return ' '.join(
      ' ' if token in stopwords else token for token in word_tokenize(text))

  @staticmethod
  def remove_extra_space(text):
    """Collapse every run of whitespace in `text` into one space."""
    return re.sub(r'\s+', ' ', text)

# a class to hold our data structure
class Data:
  """One sample: cleaned text plus its category.

  `data` must be indexable with the keys 'text' and 'category'.
  """

  def __init__(self, data):
    # Cleaning steps in application order: punctuation -> digits ->
    # stopwords -> whitespace collapse.
    pipeline = (
      Preprocessing.remove_punctuations,
      Preprocessing.remove_numbers,
      Preprocessing.remove_stopwords,
      Preprocessing.remove_extra_space,
    )
    cleaned = data['text']
    for step in pipeline:
      cleaned = step(cleaned)
    self.text = cleaned
    self.category = data['category']

# loading pars roberta and tokenizer
from transformers import AutoConfig, AutoTokenizer, AutoModel, TFAutoModel
# v3.0
# Hugging Face Hub identifier of the pretrained Persian RoBERTa checkpoint.
model_name_or_path = "HooshvareLab/roberta-fa-zwnj-base"
config = AutoConfig.from_pretrained(model_name_or_path)
# use_fast is a boolean flag; the string 'True' only worked because any
# non-empty string is truthy (the string 'False' would have enabled it too).
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
# TensorFlow alternative: TFAutoModel.from_pretrained(model_name_or_path)
# NOTE: the variable is named `parsbert`, but the checkpoint loaded is the
# RoBERTa model identified above.
parsbert = AutoModel.from_pretrained(model_name_or_path)


# defining our transformer model
class TransformerModel(nn.Module):
  """Sequence classifier: a pretrained encoder followed by one linear layer.

  The output width equals the number of classes known to the module-level
  `label_encoder`, which therefore has to be fitted before construction.
  """

  def __init__(self, roberta):
    super().__init__()
    # Pretrained encoder; its output must expose `pooler_output`.
    self.roberta = roberta
    # Single classification head on top of the pooled encoder output;
    # the input width is fixed at 768.
    num_classes = len(label_encoder.classes_)
    self.linear_head = nn.Linear(768, num_classes)

  def forward(self, x):
    """Return class logits for a batch.

    `x` is a mapping holding 'input_ids' and 'attention_mask', which are
    passed to the encoder positionally in that order.
    """
    encoder_out = self.roberta(x['input_ids'], x['attention_mask'])
    return self.linear_head(encoder_out.pooler_output)

# load model
## removed
! pip install -U --no-cache-dir gdown --pre
## removed
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a source you trust.
# NOTE(review): on PyTorch releases where torch.load defaults to
# weights_only=True, loading a whole pickled module additionally needs
# weights_only=False; confirm against the installed version.
# map_location='cpu' lets a checkpoint saved from a GPU load on CPU-only
# hosts; test() moves the model to the target device afterwards.
model = torch.load('project_roberta_final_category.pth', map_location='cpu')

# Test The Result

In [17]:
def test(data):
  """Classify one sample and return its label as a Persian string.

  Args:
    data: mapping with the keys 'text' and 'category'; it is wrapped in
      `Data`, which cleans the text.

  Returns:
    'مهم' when the predicted class index is 1, otherwise 'غیرمهم'.

  Side effects:
    Moves the global `model` to the selected device and switches it to
    evaluation mode.
  """
  data = Data(data)
  # Fall back to the CPU instead of crashing when no GPU is present.
  device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
  func_model = model.to(device)
  # Inference must run in eval mode: otherwise dropout layers stay active
  # and the same input can yield different predictions.
  func_model.eval()
  # Category and cleaned text are encoded as a sentence pair, padded or
  # truncated to exactly 512 tokens. `padding='max_length'` replaces the
  # deprecated `pad_to_max_length=True`.
  text_tokens = tokenizer(
    str(data.category),
    str(data.text),
    add_special_tokens=True,
    max_length=512,
    padding='max_length',
    truncation=True,
    return_tensors="pt")
  feed_dict = {
    'input_ids': text_tokens["input_ids"].view(1, -1).to(device),
    'attention_mask': text_tokens["attention_mask"].view(1, -1).to(device)}
  # No gradients are needed for prediction; skip building the autograd graph.
  with torch.no_grad():
    output = func_model(feed_dict)
  pred = output.argmax(dim=1, keepdim=True)
  if pred.item() == 1:
    return 'مهم'
  return 'غیرمهم'

In [20]:
# Smoke test of the classifier on a single sample.
# NOTE(review): test() is declared with one required argument `data` (indexed
# with 'text' and 'category'), but this call passes none and would raise a
# TypeError as written; presumably the argument was stripped from this
# export. Confirm and restore the sample.
test()



'مهم'