In [127]:
from gensim.models.fasttext import FastText
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from torchmetrics.classification import MulticlassAccuracy
from collections import Counter

In [117]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        
        self.hidden_size = hidden_size
        # RNN layer
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True, nonlinearity='tanh')
        # Output layer
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, embedded_input):
        # Initialize hidden state with zeros
        hidden_state = torch.zeros(1, 1, self.hidden_size).to(device)
        # Pass the input sequence through the RNN layer
        rnn_output, hidden_state = self.rnn(embedded_input, hidden_state)
        # Reshape the output to be of shape (batch_size * sequence_length, hidden_size)
        rnn_output = rnn_output.contiguous().view(-1, self.hidden_size)
        # Pass the RNN output through the fully connected layer to get the predicted tags
        predicted_tags = self.fc(rnn_output)
        
        return predicted_tags

In [17]:
class NERDataset(Dataset):
    def __init__(self, embedded_sentences):
        self.embedded_sentences = embedded_sentences
    def __len__(self):
        return len(self.embedded_sentences)
    def __getitem__(self, idx):
        return self.embedded_sentences[idx]

In [18]:
data = []
labels = []
with open("rus.tsv", "r", encoding="utf-8") as file:
    for line in file:
        tokens = list(line.strip().split("\t"))
        data.append(tokens[0])
        labels.append(tokens[-1])

In [19]:
label_unique = list(set(labels))
label_unique.remove('')

In [20]:
model = FastText(sentences=data, window=5, min_count=1, workers=4, sg=1)

In [21]:
embedded_input = [model.wv.get_vector(word) for word in data]

In [22]:
targets = torch.Tensor(np.array(embedded_input))
labels = LabelEncoder.fit_transform(targets, labels)

In [23]:
targets = list(zip(targets, labels))

In [147]:
counts = Counter(labels)
class_weights = []
values = dict(counts).values()

for val in values:
    class_weights.append((val/sum(values)))
class_weights = torch.Tensor(class_weights).to(device)

In [151]:
input_size = 100
hidden_size = 8
output_size = len(label_unique)+1
num_epochs = 1500
bs = 64
lr = 8e-3
wd = 6e-3
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [152]:
nn_model = RNN(input_size, hidden_size, output_size)
nn_model = nn_model.to(device)
dataloader = DataLoader(targets, batch_size=bs, shuffle=True)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.AdamW(nn_model.parameters(), lr=lr, weight_decay=wd)
acc = MulticlassAccuracy(num_classes=output_size).to(device)

In [153]:
for epoch in range(num_epochs):
    total_loss = 0
    total_acc = 0
    for batch, gt in dataloader:
        batch, gt = batch.to(device), gt.to(device)
        optimizer.zero_grad()
        batch = batch.reshape(1,-1,100)
        outputs = nn_model(batch)
        loss = criterion(outputs.reshape(-1, output_size), gt.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_acc += acc(outputs, gt).item()
    print(f'Epoch {epoch+1}: Loss {total_loss/len(dataloader):.4f}')
    print(total_acc/len(dataloader))

Epoch 1: Loss 0.7250
0.1065501601252097
Epoch 2: Loss 0.5200
0.10742568218433707
Epoch 3: Loss 0.6651
0.1035907184683219
Epoch 4: Loss 0.5634
0.10976530858179141
Epoch 5: Loss 0.5701
0.12064612838322741
Epoch 6: Loss 0.5105
0.12547682309094582
Epoch 7: Loss 0.5663
0.129019239247157
Epoch 8: Loss 0.5528
0.12806593609494488
Epoch 9: Loss 0.5071
0.1341438369463085
Epoch 10: Loss 0.5659
0.13456476488731273
Epoch 11: Loss 0.5156
0.1289905665304013
Epoch 12: Loss 0.5770
0.132885534591022
Epoch 13: Loss 0.5245
0.1351421860194761
Epoch 14: Loss 0.5346
0.1391786895935895
Epoch 15: Loss 0.5984
0.13587902692267398
Epoch 16: Loss 0.5321
0.1387361732338824
Epoch 17: Loss 0.5948
0.13714356851378245
Epoch 18: Loss 0.5394
0.1393515727147616
Epoch 19: Loss 0.6094
0.13598029633641057
Epoch 20: Loss 0.5346
0.13841744755476587
Epoch 21: Loss 0.5190
0.133018735498112
Epoch 22: Loss 0.5164
0.13621432588252563
Epoch 23: Loss 0.5941
0.13723670307032484
Epoch 24: Loss 0.5129
0.1408210347161071
Epoch 25: Loss 0

KeyboardInterrupt: 