In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [2]:
# Load the dataset into a DataFrame; fields in this CSV are separated by ';'.
data = pd.read_csv(
    filepath_or_buffer='/kaggle/input/meatinfo/meatinfo.csv',
    sep=';',
)

In [5]:
# Drop every `mtype` class that has fewer than 500 rows.
data = data.groupby('mtype').filter(lambda x: len(x) >= 500)

# Hold out 20% of the rows for testing. The shuffle is seeded because the
# vocabulary is built from `train_data`: a different split assigns different
# ids to words, so without a fixed seed the run cannot be reproduced.
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

In [6]:
# Word -> integer id lookup built from the training texts only.
# `dict.fromkeys` de-duplicates the whitespace-separated words while keeping
# first-appearance order, so ids are assigned 0, 1, 2, ... in the order the
# words are first seen.
vocab = {
    word: idx
    for idx, word in enumerate(
        dict.fromkeys(
            token
            for sentence in train_data['text']
            for token in sentence.split()
        )
    )
}

In [7]:
class TextClassifier(nn.Module):
    """Averaged-embedding text classifier.

    Token ids are embedded, the embeddings are averaged over dimension 1
    (the token dimension), and the average is passed through
    Linear -> ReLU -> Linear to produce one raw score per output class.

    Args:
        vocab_size: Number of rows in the embedding table.
        output_size: Number of scores produced per input sequence.
        embed_dim: Width of each embedding vector. Defaults to 100.
        hidden_size: Width of the hidden linear layer. Defaults to 64.
    """

    def __init__(self, vocab_size, output_size, embed_dim=100, hidden_size=64):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them stable.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(embed_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Compute class scores for a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids with shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, output_size) holding unnormalised scores.
            If seq_len is 0 the mean below is taken over no elements and the
            result is NaN.
        """
        x = self.embedding(x)      # (batch, seq_len, embed_dim)
        x = torch.mean(x, dim=1)   # average over tokens -> (batch, embed_dim)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

def text_to_numbers(text, vocabulary=None):
    """Convert a text into a 1-D tensor of vocabulary ids.

    The text is split on whitespace; words missing from the vocabulary are
    dropped silently.

    Args:
        text: The string to encode.
        vocabulary: Optional word -> id mapping. When omitted, the
            module-level `vocab` is used.

    Returns:
        A `torch.long` tensor with one id per known word, in text order. The
        tensor has shape (0,) when no word is known.
    """
    if vocabulary is None:
        vocabulary = vocab
    ids = [vocabulary[word] for word in text.split() if word in vocabulary]
    # Force an integer dtype: torch.tensor([]) would otherwise be float32,
    # which nn.Embedding does not accept as indices.
    return torch.tensor(ids, dtype=torch.long)

In [8]:
# Seed PyTorch's RNG before any layer is created so the initial weights are
# the same on every run.
torch.manual_seed(42)

# Class name -> integer index, in order of first appearance in `data`.
# These indices are the targets handed to the loss function.
class_to_idx = {mtype: i for i, mtype in enumerate(data['mtype'].unique())}

# One output score per class; sizing the model from the mapping itself keeps
# the output layer and the target indices in agreement.
model = TextClassifier(len(vocab), len(class_to_idx))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [9]:
# Encode every training example once, up front. The token ids and the target
# index are the same in every epoch, so redoing the text processing on each
# pass is wasted work.
train_samples = []
for text, label in zip(train_data['text'], train_data['mtype']):
    numbers = text_to_numbers(text)
    if numbers.numel() == 0:
        # A sample with no tokens cannot be trained on: the model averages
        # the token embeddings, and the mean over zero tokens is NaN.
        continue
    target = torch.tensor([class_to_idx[label]])  # class name -> class index
    # unsqueeze(0) adds the batch dimension: (seq_len,) -> (1, seq_len)
    train_samples.append((numbers.unsqueeze(0), target))

# Plain per-sample training: one optimizer step per training example.
for epoch in range(10):
    running_loss = 0.0
    for inputs, target in train_samples:
        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Mean loss over the samples actually trained on; max(..., 1) avoids a
    # division by zero if every sample was skipped.
    print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / max(len(train_samples), 1)))

Epoch 1 loss: 0.414
Epoch 2 loss: 0.137
Epoch 3 loss: 0.072
Epoch 4 loss: 0.042
Epoch 5 loss: 0.024
Epoch 6 loss: 0.016
Epoch 7 loss: 0.011
Epoch 8 loss: 0.009
Epoch 9 loss: 0.008
Epoch 10 loss: 0.007


In [10]:
# Persist the trained weights (same file and format as before).
torch.save(model.state_dict(), "text_classifier.pt")

# The weights alone are not enough to use the model later: the embedding rows
# are indexed by `vocab` and the output units by `class_to_idx`. Save both
# mappings next to the weights so inference code can rebuild the encoding.
torch.save(
    {"vocab": vocab, "class_to_idx": class_to_idx},
    "text_classifier_meta.pt",
)