In [2]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import sys
import jsonlines
import joblib
import csv
import operator
from collections import Counter

# Maximum number of tokens fed to the encoder; texts are padded/truncated to it.
TOKENIZER_MAX_LENGTH = 512
# Number of output classes (CEFR levels) the classifier head predicts.
CLASSIFIER_NUM_CEFR_LEVELS = 6
# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing on the first .to(DEVICE).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class CEFRClassifier(nn.Module):
    """CEFR level classifier: frozen DistilBERT encoder plus an auxiliary-feature branch.

    The first-token embedding of the encoder output is passed through two
    fully connected layers (768 -> 768 -> 128). A separate 28-dimensional
    auxiliary feature vector goes through one layer (28 -> 28). The two
    branches are concatenated and projected to ``num_cefr_levels`` logits.
    """

    def __init__(self, num_cefr_levels):
        """Build the encoder and classification heads.

        Args:
            num_cefr_levels: Number of output classes (size of the logit vector).
        """
        super(CEFRClassifier, self).__init__()

        self.bert = AutoModel.from_pretrained("distilbert/distilbert-base-uncased")

        # Freeze distilBERT params: only the linear heads below are trained.
        for param in self.bert.parameters():
            param.requires_grad = False

        self.fc1 = nn.Linear(768, 768)
        self.fc2 = nn.Linear(768, 128)
        self.fc3 = nn.Linear(28, 28)  # auxiliary-feature branch
        self.output = nn.Linear(128 + 28, num_cefr_levels)

        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask, aux_features):
        """Compute class logits.

        Args:
            input_ids: Token ids for the encoder; presumably (batch, seq_len).
            attention_mask: Attention mask matching ``input_ids``.
            aux_features: Auxiliary features, either batched (batch, 28) or a
                single unbatched (28,) vector.

        Returns:
            Logits of shape (batch, num_cefr_levels), or (num_cefr_levels,)
            when ``aux_features`` is unbatched (1-D).
        """
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_output.last_hidden_state
        # One first-token embedding per example. A bare .squeeze() here would
        # also drop the batch dim when batch size is 1, so torch.cat below
        # would fail against batched (1, 28) aux features.
        pooled_output = sequence_output[:, 0]
        if aux_features.dim() == 1:
            # Unbatched single example: make both branches 1-D so the result
            # is a flat logit vector, as before.
            pooled_output = pooled_output.squeeze(0)

        # Dropout-then-ReLU here vs ReLU-then-dropout in the aux branch is
        # equivalent: dropout only scales by a non-negative mask.
        x = self.fc1(pooled_output)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        y = self.fc3(aux_features)
        y = self.relu(y)
        y = self.dropout(y)

        return self.output(torch.cat((x, y), -1))

def classify(text, features, classifier, tokenizer):
    """Predict the class index for a single text.

    Args:
        text: Raw input text; padded/truncated to TOKENIZER_MAX_LENGTH tokens.
        features: Flat sequence of auxiliary feature values for this text;
            converted to a float32 tensor.
        classifier: The classifier module, presumably already on DEVICE and
            in eval mode (verify at the call site).
        tokenizer: Hugging Face tokenizer matching the classifier's encoder.

    Returns:
        int: Index of the highest-scoring class.
    """
    tokenized = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        max_length=TOKENIZER_MAX_LENGTH,
        truncation=True,
    )

    # Keep the (1, seq_len) batch dimension returned by the tokenizer; the
    # encoder's documented input shape is (batch, seq_len), not 1-D.
    input_ids = tokenized["input_ids"].to(DEVICE)
    attention_mask = tokenized["attention_mask"].to(DEVICE)
    # Explicit float32: an all-integer feature list would otherwise become an
    # int64 tensor, which nn.Linear cannot multiply with float weights.
    aux_features = torch.tensor(features, dtype=torch.float32, device=DEVICE)

    # Inference only: don't build an autograd graph for the trainable heads.
    with torch.no_grad():
        logits = classifier(input_ids, attention_mask, aux_features)

    # softmax is monotonic, so argmax over raw logits picks the same class.
    return torch.argmax(logits).item()


# Load the fine-tuned checkpoint and the matching tokenizer.
load_path = "runs/model_20240712_092923_9"
load_path_dataset = "simplewiki_preprocessed.jsonl"
classifier = CEFRClassifier(CLASSIFIER_NUM_CEFR_LEVELS)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
classifier.load_state_dict(torch.load(load_path, map_location=torch.device(DEVICE)))
classifier.to(DEVICE)
classifier.eval()  # disable dropout for inference

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")

# Expected number of records in the dataset; used only for progress messages.
n_total = 253830

output_path = "simplewiki_classified.jsonl"

# Open the output once for the whole run instead of once per record.
# Append mode is kept: re-running this cell appends another full copy of the
# results to an existing output file.
with torch.no_grad(), jsonlines.open(load_path_dataset) as reader:
    with jsonlines.open(output_path, mode='a') as writer:
        for idx, obj in enumerate(reader):
            if idx % 1000 == 0:
                print(f"{idx} / {n_total}")

            label = classify(obj["text"], obj["features"], classifier, tokenizer)

            writer.write({"page_id": obj["page_id"], "title": obj["title"], "text_len": obj["text_length"], "text": obj["text"], "label": label})

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

0 / 253830
1000 / 253830
2000 / 253830
3000 / 253830
4000 / 253830
5000 / 253830
6000 / 253830
7000 / 253830
8000 / 253830
9000 / 253830
10000 / 253830
11000 / 253830
12000 / 253830
13000 / 253830
14000 / 253830
15000 / 253830
16000 / 253830
17000 / 253830
18000 / 253830
19000 / 253830
20000 / 253830
21000 / 253830
22000 / 253830
23000 / 253830
24000 / 253830
25000 / 253830
26000 / 253830
27000 / 253830
28000 / 253830
29000 / 253830
30000 / 253830
31000 / 253830
32000 / 253830
33000 / 253830
34000 / 253830
35000 / 253830
36000 / 253830
37000 / 253830
38000 / 253830
39000 / 253830
40000 / 253830
41000 / 253830
42000 / 253830
43000 / 253830
44000 / 253830
45000 / 253830
46000 / 253830
47000 / 253830
48000 / 253830
49000 / 253830
50000 / 253830
51000 / 253830
52000 / 253830
53000 / 253830
54000 / 253830
55000 / 253830
56000 / 253830
57000 / 253830
58000 / 253830
59000 / 253830
60000 / 253830
61000 / 253830
62000 / 253830
63000 / 253830
64000 / 253830
65000 / 253830
66000 / 253830
67000 / 