In [25]:
import polars as pl
import re
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np

# 学習データからテキストのみをリストで抽出
FILE_PATH = "../第6章/news+aggregator/train.txt"
df = pl.read_csv(FILE_PATH, separator="\t", new_columns=["text", "categoory"])
text_list = df["text"].to_list()

#　記号を取り除く
cleaned_text_list = [re.sub(r'[^a-zA-Z\s]', "", text) for text in text_list]
word_list = [text.split() for text in cleaned_text_list]

word_frequency_dict = {}
# 単語リストをループして出現頻度をカウント
for words in word_list:
    for word in words:
        if word in word_frequency_dict:
            word_frequency_dict[word] += 1
        else:
            word_frequency_dict[word] = 1

sorted_word_frequency_dict = dict(sorted(word_frequency_dict.items(), key=lambda item: item[1], reverse=True)) #降順にソート


word_id_map = {}    #単語とidの辞書型
id = 1
for key, value in sorted_word_frequency_dict.items():
    if value == 1: #出現頻度が1のidは0
        word_id_map[key] = 0
    else:
        word_id_map[key] = id
        id += 1
# wordのidは0~9509の計9510個のidが存在する
vocab_size = max(word_id_map.values()) + 1

def get_index_vector(words, word_id_map):
    # まず、必要な形状のテンソルを初期化
    vectors = torch.zeros(len(words))
    # 各単語の位置に1をセット
    for i, word in enumerate(words):
        vectors[i] = word_id_map[word]
    return vectors.long()

In [26]:
from torch.nn.utils.rnn import pad_sequence

class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        return text, label


X_train = [get_index_vector(words, word_id_map) for words in word_list] #文章をidに変換したリスト
Y_train = np.load("../第8章/matrix/y_train.npy")    #ラベルのロード(onehot vector)
Y_train =torch.from_numpy(Y_train)  #tensorに変換
datasets = TextDataset(X_train, Y_train)    #データセットとする
train_dataloader = DataLoader(datasets, shuffle=True, batch_size=1)    #データローダーの定義

In [39]:
from torch.nn.utils.rnn import pack_padded_sequence
from gensim.models.keyedvectors import KeyedVectors

#　最終層にsoftmaxは不要(クロスエントロピーの内部でsoftmaxをかけてくれるので)
class LSTMModel(nn.Module):
    def __init__(self,vocab_size=vocab_size ,embedding_dim = 300, hidden_dim = 50, output_dim = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.emb = nn.Embedding(vocab_size, embedding_dim)
        self.LSTM = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, num_layers=2, bias=True)
        self.linear = nn.Linear(hidden_dim, output_dim, bias=True)

    def forward(self, x, h=None):
        x = self.emb(x)
        y, (h, c)= self.LSTM(x, h)
        y = self.linear(h[-1])
        return y

In [45]:
model = LSTMModel()
learning_rate = 1e-2
epochs = 20
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [46]:
from tqdm import tqdm

for t in tqdm(range(epochs)):
    size = len(train_dataloader.dataset)
    correct = 0
    for batch, (X, y) in enumerate(train_dataloader):
        # 予測と損失の計算
        y = y.argmax(dim=1) #loss_fnのyにはクラスインデックスが期待されている
        pred = model.forward(X)
        loss = loss_fn(pred, y)

        # バックプロパゲーション
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        correct += (pred.argmax(dim=1) == y).sum().item()

    loss = loss.item()
    print(f"epoch:{t+1}, loss: {loss:>7f}, accuracy: {correct/size}")

  5%|▌         | 1/20 [00:30<09:30, 30.05s/it]

epoch:1, loss: 0.823250, accuracy: 0.41302717900656044


 10%|█         | 2/20 [01:01<09:12, 30.72s/it]

epoch:2, loss: 0.986841, accuracy: 0.4202436738519213


 15%|█▌        | 3/20 [01:31<08:35, 30.33s/it]

epoch:3, loss: 2.296457, accuracy: 0.43467666354264295


 20%|██        | 4/20 [02:01<08:06, 30.40s/it]

epoch:4, loss: 0.802792, accuracy: 0.4591377694470478


 25%|██▌       | 5/20 [02:30<07:30, 30.04s/it]

epoch:5, loss: 0.794821, accuracy: 0.49194001874414245


 25%|██▌       | 5/20 [02:37<07:53, 31.54s/it]


KeyboardInterrupt: 