In [1]:
import torch 
import torchtext
from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER

In [2]:
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def prepare_model():
    num_classes = 5
    input_dim = 768

    classifier_head = RobertaClassificationHead(num_classes=num_classes, input_dim=input_dim)
    model = XLMR_BASE_ENCODER.get_model(head=classifier_head)
    
    DEMO_MODEL_PATH = '/kaggle/input/ds-official-models-500-filtered-max-weighted-f1/model_max_weighted_f1.pth'
    model.load_state_dict(torch.load(DEMO_MODEL_PATH))
    model.to(DEVICE)
    
    print(f'Loaded model to [{DEVICE}] in [{DEMO_MODEL_PATH}]')
    
    return model

In [4]:
def prepare_text_transform():
    text_transform = torchtext.models.XLMR_LARGE_ENCODER.transform()
    return text_transform

In [5]:
def predict(sentence, model, text_transform, label_map):
    transformed_text = text_transform(sentence)
    out = model(torch.tensor([transformed_text]).to(DEVICE))
    return label_map[torch.argmax(out).item()]

In [6]:
label_map = {
    0: 'insult',
    1: 'neutral',
    2: 'politics',
    3: 'religion',
    4: 'terrorism'
}

In [7]:
model = prepare_model()
text_transform = prepare_text_transform()

Downloading: "https://download.pytorch.org/models/text/xlmr.base.encoder.pt" to /root/.cache/torch/hub/checkpoints/xlmr.base.encoder.pt


  0%|          | 0.00/1.03G [00:00<?, ?B/s]

Loaded model to [cuda] in [/kaggle/input/ds-official-models-500-filtered-max-weighted-f1/model_max_weighted_f1.pth]


100%|██████████| 5.07M/5.07M [00:01<00:00, 2.74MB/s]
Downloading: "https://download.pytorch.org/models/text/xlmr.vocab.pt" to /root/.cache/torch/hub/checkpoints/xlmr.vocab.pt


  0%|          | 0.00/4.85M [00:00<?, ?B/s]

In [13]:
predict(input(), model, text_transform, label_map)

 Our constant need for entertainment has blurred the line between fiction and reality—on television, in American politics, and in our everyday lives.


'neutral'