In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np


class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        return text, label

# 70で作成した行列をロード
X_train = np.load("./matrix/x_train.npy")
Y_train = np.load("./matrix/y_train.npy")

X_train_tensor = torch.from_numpy(X_train)
Y_train_tensor =torch.from_numpy(Y_train)
datasets = TextDataset(X_train, Y_train)
train_dataloader = DataLoader(datasets, shuffle=True, batch_size=64)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(300, 4),
            nn.Softmax()
        )

    def forward(self, x):
        logits = self.layer(x)
        return logits

model = Model()
learning_rate = 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [2]:
import time

batch_list = [1,2,4,8,16,32,64,128,256]

for batch_size in batch_list:
    train_dataloader = DataLoader(datasets, shuffle=True, batch_size=batch_size)

    start_time = time.time()

    size = len(datasets)
    epoch_loss = 0
    epoch_correct = 0
    total_samples = 0

    for batch, (X, y) in enumerate(train_dataloader):
        # 予測と損失の計算
        X = X.float()
        pred = model(X)
        loss = loss_fn(pred, y.to(torch.float64))

        # バックプロパゲーション
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 損失の累計
        epoch_loss += loss.item()

        # 正答数の累計
        predicted = pred.argmax(dim=1)
        correct = (predicted == y.argmax(dim=1)).sum().item()
        epoch_correct += correct
        total_samples += len(y)

    end_time = time.time()
    epoch_time = end_time - start_time
    # エポックごとの平均損失と正答率を計算
    avg_loss = epoch_loss / len(train_dataloader)
    avg_accuracy = epoch_correct / total_samples
    print(f"batch_size:{batch_size}, time:{epoch_time}, avg_loss:{avg_loss}, avg_accuracy:{avg_accuracy}")


  return self._call_impl(*args, **kwargs)


batch_size:1, time:1.4311411380767822, avg_loss:1.3176846547556196, avg_accuracy:0.7042451504076469
batch_size:2, time:1.0566000938415527, avg_loss:1.232331269550538, avg_accuracy:0.7735919782588323
batch_size:4, time:0.5227220058441162, avg_loss:1.1998176696090803, avg_accuracy:0.7757473526379908
batch_size:8, time:0.28849291801452637, avg_loss:1.1857348444268607, avg_accuracy:0.7765907600037485
batch_size:16, time:0.16674208641052246, avg_loss:1.1792018226877055, avg_accuracy:0.7764033361446913
batch_size:32, time:0.11226010322570801, avg_loss:1.1761334280042859, avg_accuracy:0.7765907600037485
batch_size:64, time:0.06622600555419922, avg_loss:1.174546332708074, avg_accuracy:0.7764033361446913
batch_size:128, time:0.053218841552734375, avg_loss:1.173900033266278, avg_accuracy:0.7764970480742198
batch_size:256, time:0.03944587707519531, avg_loss:1.173220598719298, avg_accuracy:0.7764970480742198
