<a href="https://colab.research.google.com/github/learnerwcl/colab/blob/main/BiRNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!cd /content
!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -zxvf aclImdb_v1.tar.gz 2>&1 > /dev/null

--2025-01-22 15:53:53--  http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10
Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84125825 (80M) [application/x-gzip]
Saving to: ‘aclImdb_v1.tar.gz’


2025-01-22 15:53:56 (23.8 MB/s) - ‘aclImdb_v1.tar.gz’ saved [84125825/84125825]



In [9]:
import glob
from collections import Counter
import re
import os

import torch
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score  # 计算 AUC

from tqdm import tqdm  # 可选，用于显示进度条

def grad_clipping(net, theta):
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm

def clean_text(text):
    text = re.sub(r"[^\w\s]", "", text)
    text = text.lower().strip()
    return text

def build_movie_vocab_chuncked(root_dir, min_freq=20):
    counter = Counter()
    all_file = glob.glob(os.path.join(root_dir,"**/*.txt"), recursive=True)
    for fn in all_file:
        with open(fn, 'r') as file:
            text = file.read()
            text = clean_text(text)
            words = text.split(" ")
            counter.update(words)

    counter = {word:freq for word,freq in counter.items() if freq>=min_freq}
    vocab = {word: idx for idx, (word, freq) in enumerate(counter.items(), start=2)}
    vocab["<PAD>"] = 0
    vocab["<UNK>"] = 1
    return vocab

In [10]:
# LazyLoader
import torch
from torch.utils.data import Dataset, DataLoader
import os
import glob


class ImbdDataSet(Dataset):
  def __init__(self, root_path, vocab, max_length=128, data_type='trian', transform=None):
    self.data_path_list = []
    self.label_list = []
    self.transform = transform
    self.vocab = vocab
    self.max_length = max_length

    self.root_path = root_path

    pos_path = os.path.join(root_path, data_type, 'pos')
    neg_path = os.path.join(root_path, data_type, 'neg')

    for item in glob.glob(os.path.join(pos_path,"*.txt")):
      self.label_list.append(1)
      self.data_path_list.append(item)

    for item in glob.glob(os.path.join(neg_path,"*.txt")):
      self.label_list.append(0)
      self.data_path_list.append(item)


  def __len__(self):
    return len(self.data_path_list)

  def __getitem__(self, idx):
    label_ = self.label_list[idx]
    path_ = self.data_path_list[idx]
    with open(path_,'r') as f:
      data_ = f.read()

    data_ = clean_text(data_)

    words = data_.split(" ")
    data_ = [self.vocab.get(word, self.vocab['<UNK>']) for word in words]

    # 将数据处理为定长
    if len(data_) > self.max_length:  # 截断
        data_ = data_[:self.max_length]
    else:  # 填充
        data_ = data_ + [self.vocab['<PAD>']] * (self.max_length - len(data_))

    if self.transform:
        data_ = self.transform(data_)

    return torch.tensor(data_, dtype=torch.long), torch.tensor(label_, dtype=torch.long)



In [6]:
vocab = build_movie_vocab_chuncked("/content/aclImdb", 512)
print(f"vocab size: {len(vocab)}")
train_data = ImbdDataSet("/content/aclImdb", vocab, max_length=128, data_type='train')
test_data = ImbdDataSet("/content/aclImdb", vocab, max_length=128, data_type='test')
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

vocab size: 3396


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleRNNLayer(nn.Module):
  def __init__(self, input_size, hidden_size):
    super(SimpleRNNLayer, self).__init__()

    self.Wxh = nn.Parameter(torch.Tensor(hidden_size, input_size))
    self.Whh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
    self.bh = nn.Parameter(torch.Tensor(hidden_size, 1))

    self.reset_parameters()

  def reset_parameters(self):
    nn.init.xavier_uniform_(self.Wxh)
    nn.init.xavier_uniform_(self.Whh)
    nn.init.zeros_(self.bh)

  def forward(self, inputs, h_):
    # inputs: (batch_size, input_size)

    h_ = torch.tanh(
        inputs @ self.Wxh.T +
        h_ @ self.Whh.T  +
        self.bh.T
    )

    # outputs: (batch_size, hidden_size)

    return h_

In [6]:
class BiRNN(nn.Module):
  def __init__(self, input_size, hidden_size, output_size, embedding_size):
    super(BiRNN, self).__init__()

    self.hidden_size = hidden_size
    self.embed = nn.Embedding(input_size, embedding_size)

    self.front_rnn = SimpleRNNLayer(embedding_size, hidden_size)
    self.back_rnn = SimpleRNNLayer(embedding_size, hidden_size)

    self.fc = nn.Linear(2*hidden_size, output_size)


  def forward(self, inputs):
    hidden_front = torch.zeros(inputs.shape[0], self.hidden_size, device=inputs.device)
    hidden_back = torch.zeros(inputs.shape[0], self.hidden_size, device=inputs.device)

    # 1. embedding: (batch_size, seq_len, embed_dim)
    inputs = self.embed(inputs)

    # 2. transpose: (seq_len, batch_size, embed_dim)
    inputs = torch.transpose(inputs, 0, 1)

    front_outputs = []
    back_outputs = []

    # 3. x: (batch_size, embed_dim)
    for x in inputs:
      hidden_front = self.front_rnn(x, hidden_front)
      # batch_size, hidden_size
      front_outputs.append(hidden_front)

    for x in reversed(inputs):
      hidden_back = self.back_rnn(x, hidden_back)
      # batch_size, hidden_size
      back_outputs.append(hidden_back)

    back_outputs = back_outputs[::-1]

    # 4. concat

    hidden = torch.cat((hidden_front, hidden_back), dim=1)

    hidden = self.fc(hidden)

    return hidden, (front_outputs, back_outputs)

In [7]:
model = BiRNN(len(vocab), 128, 2, 128 )  # 定义的网络结构
criterion = torch.nn.CrossEntropyLoss()  # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 设置设备
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
print(device)

cuda


In [15]:
def train_model_new(model, dataloader, evalloader, criterion, optimizer, device, scheduler=None, epochs=10):
    """
    训练模型的通用函数。

    参数：
    - model: 定义好的神经网络模型。
    - dataloader: 数据加载器（训练集）。
    - criterion: 损失函数。
    - optimizer: 优化器。
    - device: 训练设备（"cuda" 或 "cpu"）。
    - epochs: 训练轮数。

    返回：
    - model: 训练后的模型。
    - metrics: 包含训练过程中的损失和其他指标。
    """
    model.to(device)  # 将模型加载到设备
    metrics = {"loss": [], "auc": [], 'eval_loss': [], 'eval_auc': []}  # 记录每个 epoch 的损失

    for epoch in range(epochs):
        model.train()  # 设置模型为训练模式
        if scheduler:
          scheduler.step()
        epoch_loss = 0.0
        all_labels = []  # 存储真实标签
        all_probs = []  # 存储预测概率
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress_bar:
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            # 前向传播
            outputs, _  = model(inputs)

            # outputs = outputs.squeeze(-1)
            loss = criterion(outputs, labels)

            # 反向传播
            optimizer.zero_grad()
            loss.backward()

            grad_clipping(model, 1)

            optimizer.step()

            # 累加损失
            epoch_loss += loss.item()

            probs = torch.softmax(outputs, dim=1)[:, 1].detach().cpu().numpy()  # 假设二分类，取第二类概率
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy())


        # 记录每个 epoch 的平均损失
        avg_loss = epoch_loss / len(dataloader)
        metrics["loss"].append(avg_loss)
        epoch_auc = roc_auc_score(all_labels, all_probs)
        metrics["auc"].append(epoch_auc)

        model.eval()
        eval_loss = 0.0
        eval_labels = []  # 存储真实标签
        eval_probs = []  # 存储预测概率

        for batch_eval in evalloader:
          inputs_eval, labels_eval = batch_eval
          inputs_eval, labels_eval = inputs_eval.to(device), labels_eval.to(device)
          outputs_eval, _  = model(inputs_eval)
          loss_eval = criterion(outputs_eval, labels_eval)
          eval_loss += loss_eval.item()
          probs = torch.softmax(outputs_eval, dim=1)[:,1].detach().cpu().numpy()
          eval_probs.extend(probs)
          eval_labels.extend(labels_eval.cpu().numpy())
        eval_loss_avg = eval_loss / len(evalloader)
        metrics['eval_loss'].append(eval_loss_avg)
        eval_auc = roc_auc_score(eval_labels, eval_probs)
        metrics['eval_auc'].append(eval_auc)
#

        # print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, AUC: {epoch_auc:.4f}")
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, AUC: {epoch_auc:.4f}, Eval Loss: {eval_loss_avg:.4f}, Eval AUC: {eval_auc:.4f}")

    return model, metrics

In [9]:
trained_model, metrics = train_model_new(
    model=model,
    dataloader=train_loader,
    evalloader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=30
)

Epoch 1/30: 100%|██████████| 391/391 [00:42<00:00,  9.16it/s]


Epoch [1/30], Loss: 0.6880, AUC: 0.5803, Eval Loss: 0.6764, Eval AUC: 0.6214


Epoch 2/30: 100%|██████████| 391/391 [00:40<00:00,  9.71it/s]


Epoch [2/30], Loss: 0.6160, AUC: 0.7204, Eval Loss: 0.6198, Eval AUC: 0.7300


Epoch 3/30: 100%|██████████| 391/391 [00:40<00:00,  9.72it/s]


Epoch [3/30], Loss: 0.5447, AUC: 0.7991, Eval Loss: 0.5839, Eval AUC: 0.7676


Epoch 4/30: 100%|██████████| 391/391 [00:40<00:00,  9.68it/s]


Epoch [4/30], Loss: 0.4777, AUC: 0.8527, Eval Loss: 0.5617, Eval AUC: 0.7999


Epoch 5/30: 100%|██████████| 391/391 [00:41<00:00,  9.41it/s]


Epoch [5/30], Loss: 0.4129, AUC: 0.8931, Eval Loss: 0.5809, Eval AUC: 0.7948


Epoch 6/30: 100%|██████████| 391/391 [00:41<00:00,  9.52it/s]


Epoch [6/30], Loss: 0.3538, AUC: 0.9228, Eval Loss: 0.5789, Eval AUC: 0.8130


Epoch 7/30: 100%|██████████| 391/391 [00:42<00:00,  9.17it/s]


Epoch [7/30], Loss: 0.2976, AUC: 0.9458, Eval Loss: 0.6409, Eval AUC: 0.7986


Epoch 8/30: 100%|██████████| 391/391 [00:44<00:00,  8.79it/s]


Epoch [8/30], Loss: 0.2401, AUC: 0.9649, Eval Loss: 0.6590, Eval AUC: 0.8186


Epoch 9/30: 100%|██████████| 391/391 [00:42<00:00,  9.16it/s]


Epoch [9/30], Loss: 0.1854, AUC: 0.9791, Eval Loss: 0.7367, Eval AUC: 0.8176


Epoch 10/30: 100%|██████████| 391/391 [00:40<00:00,  9.58it/s]


Epoch [10/30], Loss: 0.1515, AUC: 0.9858, Eval Loss: 0.8348, Eval AUC: 0.8241


Epoch 11/30: 100%|██████████| 391/391 [00:40<00:00,  9.66it/s]


Epoch [11/30], Loss: 0.1145, AUC: 0.9917, Eval Loss: 0.9192, Eval AUC: 0.8080


Epoch 12/30: 100%|██████████| 391/391 [00:40<00:00,  9.59it/s]


Epoch [12/30], Loss: 0.0864, AUC: 0.9952, Eval Loss: 1.0015, Eval AUC: 0.8149


Epoch 13/30: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s]


Epoch [13/30], Loss: 0.0698, AUC: 0.9968, Eval Loss: 1.0644, Eval AUC: 0.8184


Epoch 14/30: 100%|██████████| 391/391 [00:40<00:00,  9.54it/s]


Epoch [14/30], Loss: 0.0621, AUC: 0.9974, Eval Loss: 1.1766, Eval AUC: 0.7946


Epoch 15/30: 100%|██████████| 391/391 [00:39<00:00,  9.85it/s]


Epoch [15/30], Loss: 0.0632, AUC: 0.9972, Eval Loss: 1.2067, Eval AUC: 0.8072


Epoch 16/30: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s]


Epoch [16/30], Loss: 0.0525, AUC: 0.9980, Eval Loss: 1.3603, Eval AUC: 0.7811


Epoch 17/30: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s]


Epoch [17/30], Loss: 0.0499, AUC: 0.9980, Eval Loss: 1.3471, Eval AUC: 0.8238


Epoch 18/30: 100%|██████████| 391/391 [00:40<00:00,  9.65it/s]


Epoch [18/30], Loss: 0.0411, AUC: 0.9988, Eval Loss: 1.4263, Eval AUC: 0.8029


Epoch 19/30: 100%|██████████| 391/391 [00:39<00:00,  9.81it/s]


Epoch [19/30], Loss: 0.0368, AUC: 0.9989, Eval Loss: 1.4661, Eval AUC: 0.8121


Epoch 20/30: 100%|██████████| 391/391 [00:39<00:00,  9.85it/s]


Epoch [20/30], Loss: 0.0350, AUC: 0.9991, Eval Loss: 1.5012, Eval AUC: 0.8112


Epoch 21/30: 100%|██████████| 391/391 [00:40<00:00,  9.69it/s]


Epoch [21/30], Loss: 0.0331, AUC: 0.9991, Eval Loss: 1.5595, Eval AUC: 0.8204


Epoch 22/30: 100%|██████████| 391/391 [00:39<00:00,  9.80it/s]


Epoch [22/30], Loss: 0.0374, AUC: 0.9989, Eval Loss: 1.5556, Eval AUC: 0.7976


Epoch 23/30: 100%|██████████| 391/391 [00:41<00:00,  9.40it/s]


Epoch [23/30], Loss: 0.0381, AUC: 0.9989, Eval Loss: 1.6340, Eval AUC: 0.7936


Epoch 24/30: 100%|██████████| 391/391 [00:39<00:00,  9.83it/s]


Epoch [24/30], Loss: 0.0388, AUC: 0.9988, Eval Loss: 1.6127, Eval AUC: 0.7974


Epoch 25/30: 100%|██████████| 391/391 [00:39<00:00,  9.80it/s]


Epoch [25/30], Loss: 0.0358, AUC: 0.9990, Eval Loss: 1.6460, Eval AUC: 0.8096


Epoch 26/30: 100%|██████████| 391/391 [00:39<00:00,  9.84it/s]


Epoch [26/30], Loss: 0.0321, AUC: 0.9992, Eval Loss: 1.7284, Eval AUC: 0.8063


Epoch 27/30: 100%|██████████| 391/391 [00:41<00:00,  9.48it/s]


Epoch [27/30], Loss: 0.0449, AUC: 0.9985, Eval Loss: 1.7123, Eval AUC: 0.8103


Epoch 28/30: 100%|██████████| 391/391 [00:41<00:00,  9.50it/s]


Epoch [28/30], Loss: 0.0283, AUC: 0.9994, Eval Loss: 1.7569, Eval AUC: 0.8172


Epoch 29/30: 100%|██████████| 391/391 [00:40<00:00,  9.76it/s]


Epoch [29/30], Loss: 0.0254, AUC: 0.9995, Eval Loss: 1.8060, Eval AUC: 0.8139


Epoch 30/30: 100%|██████████| 391/391 [00:39<00:00,  9.82it/s]


Epoch [30/30], Loss: 0.0299, AUC: 0.9993, Eval Loss: 1.8659, Eval AUC: 0.8207


In [11]:
class SimpleGRULayer(nn.Module):
  def __init__(self, input_size, hidden_size):
    super(SimpleGRULayer, self).__init__()

    self.Wxr = nn.Parameter(torch.Tensor(hidden_size, input_size))
    self.Whr = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
    self.br = nn.Parameter(torch.Tensor(hidden_size, 1))

    self.Wxz = nn.Parameter(torch.Tensor(hidden_size, input_size))
    self.Whz = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
    self.bz = nn.Parameter(torch.Tensor(hidden_size, 1))

    self.Wxh = nn.Parameter(torch.Tensor(hidden_size, input_size))
    self.Whh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
    self.bh = nn.Parameter(torch.Tensor(hidden_size, 1))

    self.reset_parameters()

  def reset_parameters(self):
    nn.init.xavier_uniform_(self.Wxr)
    nn.init.xavier_uniform_(self.Whr)
    nn.init.zeros_(self.br)

    nn.init.xavier_uniform_(self.Wxz)
    nn.init.xavier_uniform_(self.Whz)
    nn.init.zeros_(self.bz)

    nn.init.xavier_uniform_(self.Wxh)
    nn.init.xavier_uniform_(self.Whh)
    nn.init.zeros_(self.bh)

  def forward(self, inputs, h_):
    # inputs: (batch_size, input_size)

    r_ = torch.sigmoid(
       inputs @ self.Wxr.T +
       h_ @ self.Whr.T +
       self.br.T
    )

    z_ = torch.sigmoid(
       inputs @ self.Wxz.T +
       h_ @ self.Whz.T +
       self.bz.T
    )

    h_hat = torch.tanh(
        inputs @ self.Wxh.T +
        (r_ * h_) @ self.Whh.T  +
        self.bh.T
    )

    h_ = z_ * h_ + (1 - z_) * h_hat

    # outputs: (batch_size, hidden_size)

    return h_

In [14]:
 class BiGRU(nn.Module):
  def __init__(self, input_size, hidden_size, output_size, embedding_size):
    super(BiGRU, self).__init__()

    self.hidden_size = hidden_size
    self.embed = nn.Embedding(input_size, embedding_size)

    self.front_rnn = SimpleGRULayer(embedding_size, hidden_size)
    self.back_rnn = SimpleGRULayer(embedding_size, hidden_size)

    self.fc = nn.Linear(2*hidden_size, output_size)


  def forward(self, inputs):
    hidden_front = torch.zeros(inputs.shape[0], self.hidden_size, device=inputs.device)
    hidden_back = torch.zeros(inputs.shape[0], self.hidden_size, device=inputs.device)

    # 1. embedding: (batch_size, seq_len, embed_dim)
    inputs = self.embed(inputs)

    # 2. transpose: (seq_len, batch_size, embed_dim)
    inputs = torch.transpose(inputs, 0, 1)

    front_outputs = []
    back_outputs = []

    # 3. x: (batch_size, embed_dim)
    for x in inputs:
      hidden_front = self.front_rnn(x, hidden_front)
      # batch_size, hidden_size
      front_outputs.append(hidden_front)

    for x in reversed(inputs):
      hidden_back = self.back_rnn(x, hidden_back)
      # batch_size, hidden_size
      back_outputs.append(hidden_back)

    back_outputs = back_outputs[::-1]

    # 4. concat

    hidden = torch.cat((hidden_front, hidden_back), dim=1)

    hidden = self.fc(hidden)

    return hidden, (front_outputs, back_outputs)

In [17]:
vocab = build_movie_vocab_chuncked("/content/aclImdb", 512)
print(f"vocab size: {len(vocab)}")
train_data = ImbdDataSet("/content/aclImdb", vocab, max_length=256, data_type='train')
test_data = ImbdDataSet("/content/aclImdb", vocab, max_length=256, data_type='test')
train_loader = DataLoader(train_data, batch_size=512, shuffle=True)
test_loader = DataLoader(test_data, batch_size=512, shuffle=False)

vocab size: 3396


In [18]:
model = BiGRU(len(vocab), 128, 2, 128 )  # 定义的网络结构
criterion = torch.nn.CrossEntropyLoss()  # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 设置设备
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
print(device)

cuda


In [None]:
trained_model, metrics = train_model_new(
    model=model,
    dataloader=train_loader,
    evalloader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=10
)

Epoch 1/10: 100%|██████████| 49/49 [00:38<00:00,  1.26it/s]


Epoch [1/10], Loss: 0.6794, AUC: 0.5929, Eval Loss: 0.6580, Eval AUC: 0.6481


Epoch 2/10: 100%|██████████| 49/49 [00:39<00:00,  1.26it/s]


Epoch [2/10], Loss: 0.5720, AUC: 0.7701, Eval Loss: 0.5480, Eval AUC: 0.8236


Epoch 3/10: 100%|██████████| 49/49 [00:39<00:00,  1.24it/s]


Epoch [3/10], Loss: 0.4453, AUC: 0.8749, Eval Loss: 0.4825, Eval AUC: 0.8523


Epoch 4/10: 100%|██████████| 49/49 [00:38<00:00,  1.26it/s]


Epoch [4/10], Loss: 0.3677, AUC: 0.9163, Eval Loss: 0.4191, Eval AUC: 0.8911


Epoch 5/10: 100%|██████████| 49/49 [00:39<00:00,  1.25it/s]


Epoch [5/10], Loss: 0.3112, AUC: 0.9405, Eval Loss: 0.4068, Eval AUC: 0.9042


Epoch 6/10:  92%|█████████▏| 45/49 [00:36<00:03,  1.22it/s]

In [12]:
class SimpleLSTMLayer(nn.Module):
  def __init__(self, input_size, hidden_size):
    super(SimpleLSTMLayer, self).__init__()

    self.Wxi = nn.Parameter(torch.Tensor(hidden_size, input_size))
    self.Whi = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
    self.bi = nn.Parameter(torch.Tensor(hidden_size, 1))

    self.Wxf = nn.Parameter(torch.Tensor(hidden_size, input_size))
    self.Whf = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
    self.bf = nn.Parameter(torch.Tensor(hidden_size, 1))

    self.Wxo = nn.Parameter(torch.Tensor(hidden_size, input_size))
    self.Who = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
    self.bo = nn.Parameter(torch.Tensor(hidden_size, 1))

    self.Wxc = nn.Parameter(torch.Tensor(hidden_size, input_size))
    self.Whc = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
    self.bc = nn.Parameter(torch.Tensor(hidden_size, 1))

    self.reset_parameters()

  def reset_parameters(self):
    nn.init.xavier_uniform_(self.Wxi)
    nn.init.xavier_uniform_(self.Whi)
    nn.init.zeros_(self.bi)

    nn.init.xavier_uniform_(self.Wxf)
    nn.init.xavier_uniform_(self.Whf)
    nn.init.zeros_(self.bf)

    nn.init.xavier_uniform_(self.Wxo)
    nn.init.xavier_uniform_(self.Who)
    nn.init.zeros_(self.bo)

    nn.init.xavier_uniform_(self.Wxc)
    nn.init.xavier_uniform_(self.Whc)
    nn.init.zeros_(self.bc)

  def forward(self, inputs, h_, c_):
    # inputs: (batch_size, input_size)

    i_ = torch.sigmoid(
       inputs @ self.Wxi.T +
       h_ @ self.Whi.T +
       self.bi.T
    )

    f_ = torch.sigmoid(
       inputs @ self.Wxf.T +
       h_ @ self.Whf.T +
       self.bf.T
    )

    o_ = torch.sigmoid(
       inputs @ self.Wxo.T +
       h_ @ self.Who.T +
       self.bo.T
    )

    c_hat = torch.tanh(
        inputs @ self.Wxc.T +
        h_ @ self.Whc.T +
        self.bc.T
    )

    c_ = f_ * c_ + i_ * c_hat

    h_ = o_ * torch.tanh(c_)

    # outputs: (batch_size, hidden_size)

    return h_, c_

In [13]:
 class BiLSTM(nn.Module):
  def __init__(self, input_size, hidden_size, output_size, embedding_size):
    super(BiLSTM, self).__init__()

    self.hidden_size = hidden_size
    self.embed = nn.Embedding(input_size, embedding_size)

    self.front_rnn = SimpleLSTMLayer(embedding_size, hidden_size)
    self.back_rnn = SimpleLSTMLayer(embedding_size, hidden_size)

    self.fc = nn.Linear(2*hidden_size, output_size)


  def forward(self, inputs):
    hidden_front = torch.zeros(inputs.shape[0], self.hidden_size, device=inputs.device)
    hidden_back = torch.zeros(inputs.shape[0], self.hidden_size, device=inputs.device)

    cell_front = torch.zeros(inputs.shape[0], self.hidden_size, device=inputs.device)
    cell_back = torch.zeros(inputs.shape[0], self.hidden_size, device=inputs.device)

    # 1. embedding: (batch_size, seq_len, embed_dim)
    inputs = self.embed(inputs)

    # 2. transpose: (seq_len, batch_size, embed_dim)
    inputs = torch.transpose(inputs, 0, 1)

    front_outputs = []
    back_outputs = []

    front_cells = []
    back_cells = []

    # 3. x: (batch_size, embed_dim)
    for x in inputs:
      hidden_front, cell_front = self.front_rnn(x, hidden_front, cell_front)
      # batch_size, hidden_size
      front_outputs.append(hidden_front)
      front_cells.append(cell_front)

    for x in reversed(inputs):
      hidden_back, cell_back = self.back_rnn(x, hidden_back, cell_back)
      # batch_size, hidden_size
      back_outputs.append(hidden_back)
      back_cells.append(cell_back)

    back_outputs = back_outputs[::-1]

    # 4. concat

    hidden = torch.cat((hidden_front, hidden_back), dim=1)

    hidden = self.fc(hidden)

    return hidden, (front_outputs, back_outputs)

In [14]:
vocab = build_movie_vocab_chuncked("/content/aclImdb", 256)
print(f"vocab size: {len(vocab)}")
train_data = ImbdDataSet("/content/aclImdb", vocab, max_length=256, data_type='train')
test_data = ImbdDataSet("/content/aclImdb", vocab, max_length=256, data_type='test')
train_loader = DataLoader(train_data, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1024, shuffle=False)

vocab size: 5817


In [16]:
model = BiLSTM(len(vocab), 128, 2, 128 )  # 定义的网络结构
criterion = torch.nn.CrossEntropyLoss()  # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 设置设备
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
print(device)

cuda


In [17]:
trained_model, metrics = train_model_new(
    model=model,
    dataloader=train_loader,
    evalloader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=30
)

Epoch 1/30: 100%|██████████| 25/25 [00:25<00:00,  1.03s/it]


Epoch [1/30], Loss: 0.6884, AUC: 0.5525, Eval Loss: 0.6796, Eval AUC: 0.5959


Epoch 2/30: 100%|██████████| 25/25 [00:25<00:00,  1.01s/it]


Epoch [2/30], Loss: 0.6551, AUC: 0.6573, Eval Loss: 0.6221, Eval AUC: 0.7072


Epoch 3/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [3/30], Loss: 0.5521, AUC: 0.7926, Eval Loss: 0.5249, Eval AUC: 0.8167


Epoch 4/30: 100%|██████████| 25/25 [00:25<00:00,  1.02s/it]


Epoch [4/30], Loss: 0.4817, AUC: 0.8508, Eval Loss: 0.4937, Eval AUC: 0.8488


Epoch 5/30: 100%|██████████| 25/25 [00:25<00:00,  1.02s/it]


Epoch [5/30], Loss: 0.4070, AUC: 0.8970, Eval Loss: 0.4731, Eval AUC: 0.8729


Epoch 6/30: 100%|██████████| 25/25 [00:24<00:00,  1.00it/s]


Epoch [6/30], Loss: 0.3668, AUC: 0.9165, Eval Loss: 0.4292, Eval AUC: 0.8855


Epoch 7/30: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s]


Epoch [7/30], Loss: 0.3363, AUC: 0.9296, Eval Loss: 0.4212, Eval AUC: 0.8927


Epoch 8/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [8/30], Loss: 0.3060, AUC: 0.9416, Eval Loss: 0.4249, Eval AUC: 0.8980


Epoch 9/30: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s]


Epoch [9/30], Loss: 0.2594, AUC: 0.9581, Eval Loss: 0.4088, Eval AUC: 0.9031


Epoch 10/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [10/30], Loss: 0.2283, AUC: 0.9674, Eval Loss: 0.4164, Eval AUC: 0.9023


Epoch 11/30: 100%|██████████| 25/25 [00:24<00:00,  1.03it/s]


Epoch [11/30], Loss: 0.2172, AUC: 0.9707, Eval Loss: 0.4339, Eval AUC: 0.9038


Epoch 12/30: 100%|██████████| 25/25 [00:25<00:00,  1.03s/it]


Epoch [12/30], Loss: 0.1733, AUC: 0.9805, Eval Loss: 0.4465, Eval AUC: 0.9022


Epoch 13/30: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s]


Epoch [13/30], Loss: 0.1642, AUC: 0.9822, Eval Loss: 0.4675, Eval AUC: 0.9002


Epoch 14/30: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s]


Epoch [14/30], Loss: 0.1281, AUC: 0.9883, Eval Loss: 0.4690, Eval AUC: 0.9031


Epoch 15/30: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s]


Epoch [15/30], Loss: 0.1180, AUC: 0.9900, Eval Loss: 0.5142, Eval AUC: 0.9011


Epoch 16/30: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s]


Epoch [16/30], Loss: 0.1078, AUC: 0.9914, Eval Loss: 0.5767, Eval AUC: 0.8951


Epoch 17/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [17/30], Loss: 0.1049, AUC: 0.9921, Eval Loss: 0.5287, Eval AUC: 0.9000


Epoch 18/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [18/30], Loss: 0.0935, AUC: 0.9932, Eval Loss: 0.6235, Eval AUC: 0.8944


Epoch 19/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [19/30], Loss: 0.0829, AUC: 0.9944, Eval Loss: 0.6128, Eval AUC: 0.8975


Epoch 20/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [20/30], Loss: 0.0613, AUC: 0.9963, Eval Loss: 0.6159, Eval AUC: 0.8980


Epoch 21/30: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s]


Epoch [21/30], Loss: 0.0679, AUC: 0.9960, Eval Loss: 0.6419, Eval AUC: 0.8953


Epoch 22/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [22/30], Loss: 0.0618, AUC: 0.9966, Eval Loss: 0.6494, Eval AUC: 0.8960


Epoch 23/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [23/30], Loss: 0.0527, AUC: 0.9974, Eval Loss: 0.6949, Eval AUC: 0.8904


Epoch 24/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [24/30], Loss: 0.0500, AUC: 0.9975, Eval Loss: 0.6648, Eval AUC: 0.8906


Epoch 25/30: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s]


Epoch [25/30], Loss: 0.0382, AUC: 0.9981, Eval Loss: 0.8405, Eval AUC: 0.8867


Epoch 26/30: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s]


Epoch [26/30], Loss: 0.0391, AUC: 0.9982, Eval Loss: 0.6427, Eval AUC: 0.8985


Epoch 27/30: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s]


Epoch [27/30], Loss: 0.0355, AUC: 0.9984, Eval Loss: 0.7479, Eval AUC: 0.8893


Epoch 28/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [28/30], Loss: 0.0282, AUC: 0.9988, Eval Loss: 0.8050, Eval AUC: 0.8944


Epoch 29/30: 100%|██████████| 25/25 [00:24<00:00,  1.02it/s]


Epoch [29/30], Loss: 0.0219, AUC: 0.9991, Eval Loss: 0.8323, Eval AUC: 0.8915


Epoch 30/30: 100%|██████████| 25/25 [00:24<00:00,  1.03it/s]


Epoch [30/30], Loss: 0.0260, AUC: 0.9990, Eval Loss: 0.7653, Eval AUC: 0.8890


In [20]:
def train_model_pkg(model, dataloader, evalloader, criterion, optimizer, device, scheduler=None, epochs=10):
    """
    训练模型的通用函数。

    参数：
    - model: 定义好的神经网络模型。
    - dataloader: 数据加载器（训练集）。
    - criterion: 损失函数。
    - optimizer: 优化器。
    - device: 训练设备（"cuda" 或 "cpu"）。
    - epochs: 训练轮数。

    返回：
    - model: 训练后的模型。
    - metrics: 包含训练过程中的损失和其他指标。
    """
    model.to(device)  # 将模型加载到设备
    metrics = {"loss": [], "auc": [], 'eval_loss': [], 'eval_auc': []}  # 记录每个 epoch 的损失

    for epoch in range(epochs):
        model.train()  # 设置模型为训练模式
        if scheduler:
          scheduler.step()
        epoch_loss = 0.0
        all_labels = []  # 存储真实标签
        all_probs = []  # 存储预测概率
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        idx = 0
        for batch in progress_bar:
            idx += 1
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            # 前向传播
            outputs  = model(inputs)

            # outputs = outputs.squeeze(-1)
            loss = criterion(outputs, labels)

            # 反向传播
            optimizer.zero_grad()
            loss.backward()

            grad_clipping(model, 1)

            optimizer.step()

            # 累加损失
            epoch_loss += loss.item()

            probs = torch.softmax(outputs, dim=1)[:, 1].detach().cpu().numpy()  # 假设二分类，取第二类概率
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy())


        # 记录每个 epoch 的平均损失
        avg_loss = epoch_loss / len(dataloader)
        metrics["loss"].append(avg_loss)
        epoch_auc = roc_auc_score(all_labels, all_probs)
        metrics["auc"].append(epoch_auc)

        model.eval()
        eval_loss = 0.0
        eval_labels = []  # 存储真实标签
        eval_probs = []  # 存储预测概率

        for batch_eval in evalloader:
          inputs_eval, labels_eval = batch_eval
          inputs_eval, labels_eval = inputs_eval.to(device), labels_eval.to(device)
          outputs_eval  = model(inputs_eval)
          loss_eval = criterion(outputs_eval, labels_eval)
          eval_loss += loss_eval.item()
          probs = torch.softmax(outputs_eval, dim=1)[:,1].detach().cpu().numpy()
          eval_probs.extend(probs)
          eval_labels.extend(labels_eval.cpu().numpy())
        eval_loss_avg = eval_loss / len(evalloader)
        metrics['eval_loss'].append(eval_loss_avg)
        eval_auc = roc_auc_score(eval_labels, eval_probs)
        metrics['eval_auc'].append(eval_auc)
#

        # print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, AUC: {epoch_auc:.4f}")
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, AUC: {epoch_auc:.4f}, Eval Loss: {eval_loss_avg:.4f}, Eval AUC: {eval_auc:.4f}")

    return model, metrics

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleGRUPkg(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, output_size, batch_size,
                 dropout_prob=0.2, bidirectional=False):
        super(SimpleGRUPkg, self).__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embed_dim,
                                      padding_idx=0)

        self.rnn = nn.GRU(batch_size, hidden_size, batch_first=True, dropout=dropout_prob,
                          bidirectional=bidirectional)


        if bidirectional:
          self.fc = nn.Linear(2*hidden_size, output_size)
        else:
          self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, inputs):
        """
        inputs shape: (batch_size, seq_len)
        返回:
          - logits: (batch_size, output_size) 只在最后时间步输出
          - h: (batch_size, hidden_size) 最后时间步的隐藏状态
        """
        # 0. inputs: (batch_size, seq_len)
        device = inputs.device

        # 1. embedding: (batch_size, seq_len, embed_dim)
        inputs = self.embedding(inputs)

        # 2. (batch_size, seq_len, hidden_size)
        inputs, _ = self.rnn(inputs)

        # 3. (batch_size, 1, hidden_size)
        outputs = self.fc(inputs[:,-1,:])

        return outputs

In [19]:
model = SimpleGRUPkg(len(vocab), 128, 64, 2, batch_size=128, bidirectional=True)  # 定义的网络结构
criterion = torch.nn.CrossEntropyLoss()  # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 设置设备
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
print(device)

cuda




In [21]:
trained_model, metrics = train_model_pkg(
    model=model,
    dataloader=train_loader,
    evalloader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=30
)

Epoch 1/30: 100%|██████████| 25/25 [00:06<00:00,  4.10it/s]


Epoch [1/30], Loss: 0.6938, AUC: 0.5116, Eval Loss: 0.6933, Eval AUC: 0.5120


Epoch 2/30: 100%|██████████| 25/25 [00:05<00:00,  4.25it/s]


Epoch [2/30], Loss: 0.6879, AUC: 0.5471, Eval Loss: 0.6934, Eval AUC: 0.5221


Epoch 3/30: 100%|██████████| 25/25 [00:05<00:00,  4.26it/s]


Epoch [3/30], Loss: 0.6831, AUC: 0.5644, Eval Loss: 0.6933, Eval AUC: 0.5337


Epoch 4/30: 100%|██████████| 25/25 [00:06<00:00,  3.96it/s]


Epoch [4/30], Loss: 0.6757, AUC: 0.5847, Eval Loss: 0.6975, Eval AUC: 0.5654


Epoch 5/30: 100%|██████████| 25/25 [00:06<00:00,  3.71it/s]


Epoch [5/30], Loss: 0.6634, AUC: 0.6209, Eval Loss: 0.7205, Eval AUC: 0.6340


Epoch 6/30: 100%|██████████| 25/25 [00:06<00:00,  3.65it/s]


Epoch [6/30], Loss: 0.6146, AUC: 0.7218, Eval Loss: 0.6160, Eval AUC: 0.7329


Epoch 7/30: 100%|██████████| 25/25 [00:06<00:00,  3.63it/s]


Epoch [7/30], Loss: 0.5415, AUC: 0.8004, Eval Loss: 0.5784, Eval AUC: 0.7972


Epoch 8/30: 100%|██████████| 25/25 [00:06<00:00,  3.79it/s]


Epoch [8/30], Loss: 0.5202, AUC: 0.8129, Eval Loss: 0.5328, Eval AUC: 0.8242


Epoch 9/30: 100%|██████████| 25/25 [00:05<00:00,  4.21it/s]


Epoch [9/30], Loss: 0.5131, AUC: 0.8166, Eval Loss: 0.6515, Eval AUC: 0.7868


Epoch 10/30: 100%|██████████| 25/25 [00:05<00:00,  4.24it/s]


Epoch [10/30], Loss: 0.5102, AUC: 0.8200, Eval Loss: 0.5215, Eval AUC: 0.8335


Epoch 11/30: 100%|██████████| 25/25 [00:05<00:00,  4.26it/s]


Epoch [11/30], Loss: 0.4680, AUC: 0.8456, Eval Loss: 0.5104, Eval AUC: 0.8416


Epoch 12/30: 100%|██████████| 25/25 [00:05<00:00,  4.18it/s]


Epoch [12/30], Loss: 0.4644, AUC: 0.8488, Eval Loss: 0.5707, Eval AUC: 0.8386


Epoch 13/30: 100%|██████████| 25/25 [00:06<00:00,  3.58it/s]


Epoch [13/30], Loss: 0.4521, AUC: 0.8544, Eval Loss: 0.4852, Eval AUC: 0.8644


Epoch 14/30: 100%|██████████| 25/25 [00:06<00:00,  3.69it/s]


Epoch [14/30], Loss: 0.4030, AUC: 0.8899, Eval Loss: 0.5224, Eval AUC: 0.8497


Epoch 15/30: 100%|██████████| 25/25 [00:06<00:00,  3.65it/s]


Epoch [15/30], Loss: 0.3768, AUC: 0.9056, Eval Loss: 0.4495, Eval AUC: 0.8893


Epoch 16/30: 100%|██████████| 25/25 [00:06<00:00,  3.71it/s]


Epoch [16/30], Loss: 0.3332, AUC: 0.9289, Eval Loss: 0.4159, Eval AUC: 0.9022


Epoch 17/30: 100%|██████████| 25/25 [00:05<00:00,  4.25it/s]


Epoch [17/30], Loss: 0.3042, AUC: 0.9394, Eval Loss: 0.4174, Eval AUC: 0.9059


Epoch 18/30: 100%|██████████| 25/25 [00:05<00:00,  4.23it/s]


Epoch [18/30], Loss: 0.2905, AUC: 0.9435, Eval Loss: 0.4254, Eval AUC: 0.9070


Epoch 19/30: 100%|██████████| 25/25 [00:05<00:00,  4.20it/s]


Epoch [19/30], Loss: 0.2615, AUC: 0.9540, Eval Loss: 0.3830, Eval AUC: 0.9148


Epoch 20/30: 100%|██████████| 25/25 [00:06<00:00,  4.10it/s]


Epoch [20/30], Loss: 0.2433, AUC: 0.9591, Eval Loss: 0.3923, Eval AUC: 0.9173


Epoch 21/30: 100%|██████████| 25/25 [00:06<00:00,  3.72it/s]


Epoch [21/30], Loss: 0.2287, AUC: 0.9633, Eval Loss: 0.4385, Eval AUC: 0.9153


Epoch 22/30: 100%|██████████| 25/25 [00:07<00:00,  3.54it/s]


Epoch [22/30], Loss: 0.2199, AUC: 0.9657, Eval Loss: 0.3822, Eval AUC: 0.9167


Epoch 23/30: 100%|██████████| 25/25 [00:06<00:00,  3.71it/s]


Epoch [23/30], Loss: 0.2051, AUC: 0.9694, Eval Loss: 0.4010, Eval AUC: 0.9189


Epoch 24/30: 100%|██████████| 25/25 [00:06<00:00,  3.80it/s]


Epoch [24/30], Loss: 0.1883, AUC: 0.9730, Eval Loss: 0.4075, Eval AUC: 0.9197


Epoch 25/30: 100%|██████████| 25/25 [00:06<00:00,  4.11it/s]


Epoch [25/30], Loss: 0.1747, AUC: 0.9755, Eval Loss: 0.4216, Eval AUC: 0.9155


Epoch 26/30: 100%|██████████| 25/25 [00:05<00:00,  4.24it/s]


Epoch [26/30], Loss: 0.1603, AUC: 0.9780, Eval Loss: 0.4769, Eval AUC: 0.9160


Epoch 27/30: 100%|██████████| 25/25 [00:05<00:00,  4.26it/s]


Epoch [27/30], Loss: 0.1520, AUC: 0.9792, Eval Loss: 0.4687, Eval AUC: 0.9148


Epoch 28/30: 100%|██████████| 25/25 [00:06<00:00,  3.96it/s]


Epoch [28/30], Loss: 0.1406, AUC: 0.9807, Eval Loss: 0.4608, Eval AUC: 0.9107


Epoch 29/30: 100%|██████████| 25/25 [00:06<00:00,  3.73it/s]


Epoch [29/30], Loss: 0.1327, AUC: 0.9821, Eval Loss: 0.4678, Eval AUC: 0.9111


Epoch 30/30: 100%|██████████| 25/25 [00:06<00:00,  3.68it/s]


Epoch [30/30], Loss: 0.1311, AUC: 0.9828, Eval Loss: 0.4522, Eval AUC: 0.9084


In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleRNNPkg(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, output_size, batch_size,
                 dropout_prob=0.2, bidirectional=False):
        super(SimpleRNNPkg, self).__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embed_dim,
                                      padding_idx=0)

        self.rnn = nn.RNN(batch_size, hidden_size, batch_first=True, dropout=dropout_prob,
                          bidirectional=bidirectional)


        if bidirectional:
          self.fc = nn.Linear(2*hidden_size, output_size)
        else:
          self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, inputs):
        """
        inputs shape: (batch_size, seq_len)
        返回:
          - logits: (batch_size, output_size) 只在最后时间步输出
          - h: (batch_size, hidden_size) 最后时间步的隐藏状态
        """
        # 0. inputs: (batch_size, seq_len)
        device = inputs.device

        # 1. embedding: (batch_size, seq_len, embed_dim)
        inputs = self.embedding(inputs)

        # 2. (batch_size, seq_len, hidden_size)
        inputs, _ = self.rnn(inputs)

        # 3. (batch_size, 1, hidden_size)
        outputs = self.fc(inputs[:,-1,:])

        return outputs

In [34]:
vocab = build_movie_vocab_chuncked("/content/aclImdb", 256)
print(f"vocab size: {len(vocab)}")
train_data = ImbdDataSet("/content/aclImdb", vocab, max_length=128, data_type='train')
test_data = ImbdDataSet("/content/aclImdb", vocab, max_length=128, data_type='test')
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

vocab size: 5817


In [35]:
model = SimpleRNNPkg(len(vocab), 128, 64, 2, batch_size=128, bidirectional=True)  # 定义的网络结构
criterion = torch.nn.CrossEntropyLoss()  # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 设置设备
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
print(device)

cuda




In [36]:
trained_model, metrics = train_model_pkg(
    model=model,
    dataloader=train_loader,
    evalloader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=30
)

Epoch 1/30: 100%|██████████| 391/391 [00:06<00:00, 64.07it/s]


Epoch [1/30], Loss: 0.6970, AUC: 0.5210, Eval Loss: 0.6926, Eval AUC: 0.5558


Epoch 2/30: 100%|██████████| 391/391 [00:05<00:00, 70.09it/s]


Epoch [2/30], Loss: 0.6844, AUC: 0.5763, Eval Loss: 0.6936, Eval AUC: 0.5340


Epoch 3/30: 100%|██████████| 391/391 [00:05<00:00, 76.49it/s]


Epoch [3/30], Loss: 0.6788, AUC: 0.5956, Eval Loss: 0.6963, Eval AUC: 0.5511


Epoch 4/30: 100%|██████████| 391/391 [00:05<00:00, 76.13it/s]


Epoch [4/30], Loss: 0.6591, AUC: 0.6453, Eval Loss: 0.6699, Eval AUC: 0.6492


Epoch 5/30: 100%|██████████| 391/391 [00:06<00:00, 63.94it/s]


Epoch [5/30], Loss: 0.6182, AUC: 0.7157, Eval Loss: 0.6598, Eval AUC: 0.6797


Epoch 6/30: 100%|██████████| 391/391 [00:06<00:00, 65.10it/s]


Epoch [6/30], Loss: 0.5878, AUC: 0.7541, Eval Loss: 0.6566, Eval AUC: 0.6732


Epoch 7/30: 100%|██████████| 391/391 [00:05<00:00, 75.11it/s]


Epoch [7/30], Loss: 0.5601, AUC: 0.7826, Eval Loss: 0.6634, Eval AUC: 0.6862


Epoch 8/30: 100%|██████████| 391/391 [00:05<00:00, 75.59it/s]


Epoch [8/30], Loss: 0.5237, AUC: 0.8152, Eval Loss: 0.6782, Eval AUC: 0.6915


Epoch 9/30: 100%|██████████| 391/391 [00:05<00:00, 65.46it/s]


Epoch [9/30], Loss: 0.4943, AUC: 0.8381, Eval Loss: 0.6916, Eval AUC: 0.7005


Epoch 10/30: 100%|██████████| 391/391 [00:06<00:00, 64.42it/s]


Epoch [10/30], Loss: 0.4671, AUC: 0.8579, Eval Loss: 0.7056, Eval AUC: 0.6944


Epoch 11/30: 100%|██████████| 391/391 [00:05<00:00, 73.59it/s]


Epoch [11/30], Loss: 0.4418, AUC: 0.8746, Eval Loss: 0.7240, Eval AUC: 0.6824


Epoch 12/30: 100%|██████████| 391/391 [00:05<00:00, 74.29it/s]


Epoch [12/30], Loss: 0.4131, AUC: 0.8915, Eval Loss: 0.7842, Eval AUC: 0.6757


Epoch 13/30: 100%|██████████| 391/391 [00:06<00:00, 61.16it/s]


Epoch [13/30], Loss: 0.3827, AUC: 0.9075, Eval Loss: 0.8005, Eval AUC: 0.6955


Epoch 14/30: 100%|██████████| 391/391 [00:06<00:00, 62.26it/s]


Epoch [14/30], Loss: 0.3561, AUC: 0.9205, Eval Loss: 0.8450, Eval AUC: 0.6930


Epoch 15/30: 100%|██████████| 391/391 [00:06<00:00, 63.02it/s]


Epoch [15/30], Loss: 0.3306, AUC: 0.9318, Eval Loss: 0.8792, Eval AUC: 0.7036


Epoch 16/30: 100%|██████████| 391/391 [00:05<00:00, 73.75it/s]


Epoch [16/30], Loss: 0.3086, AUC: 0.9408, Eval Loss: 0.9395, Eval AUC: 0.6835


Epoch 17/30: 100%|██████████| 391/391 [00:05<00:00, 74.45it/s]


Epoch [17/30], Loss: 0.2867, AUC: 0.9488, Eval Loss: 1.0179, Eval AUC: 0.6814


Epoch 18/30: 100%|██████████| 391/391 [00:05<00:00, 65.86it/s]


Epoch [18/30], Loss: 0.2643, AUC: 0.9564, Eval Loss: 1.0996, Eval AUC: 0.6661


Epoch 19/30: 100%|██████████| 391/391 [00:06<00:00, 62.54it/s]


Epoch [19/30], Loss: 0.2509, AUC: 0.9606, Eval Loss: 1.1437, Eval AUC: 0.6597


Epoch 20/30: 100%|██████████| 391/391 [00:05<00:00, 71.73it/s]


Epoch [20/30], Loss: 0.2304, AUC: 0.9668, Eval Loss: 1.2106, Eval AUC: 0.6672


Epoch 21/30: 100%|██████████| 391/391 [00:05<00:00, 73.71it/s]


Epoch [21/30], Loss: 0.2146, AUC: 0.9711, Eval Loss: 1.2633, Eval AUC: 0.6724


Epoch 22/30: 100%|██████████| 391/391 [00:05<00:00, 73.03it/s]


Epoch [22/30], Loss: 0.2020, AUC: 0.9747, Eval Loss: 1.3929, Eval AUC: 0.6556


Epoch 23/30: 100%|██████████| 391/391 [00:06<00:00, 63.51it/s]


Epoch [23/30], Loss: 0.1923, AUC: 0.9770, Eval Loss: 1.4547, Eval AUC: 0.6583


Epoch 24/30: 100%|██████████| 391/391 [00:06<00:00, 64.54it/s]


Epoch [24/30], Loss: 0.1917, AUC: 0.9769, Eval Loss: 1.5511, Eval AUC: 0.6461


Epoch 25/30: 100%|██████████| 391/391 [00:05<00:00, 73.09it/s]


Epoch [25/30], Loss: 0.1787, AUC: 0.9798, Eval Loss: 1.5874, Eval AUC: 0.6399


Epoch 26/30: 100%|██████████| 391/391 [00:05<00:00, 74.26it/s]


Epoch [26/30], Loss: 0.1714, AUC: 0.9816, Eval Loss: 1.6534, Eval AUC: 0.6426


Epoch 27/30: 100%|██████████| 391/391 [00:05<00:00, 65.20it/s]


Epoch [27/30], Loss: 0.1677, AUC: 0.9822, Eval Loss: 1.7119, Eval AUC: 0.6428


Epoch 28/30: 100%|██████████| 391/391 [00:06<00:00, 62.63it/s]


Epoch [28/30], Loss: 0.1564, AUC: 0.9845, Eval Loss: 1.7466, Eval AUC: 0.6544


Epoch 29/30: 100%|██████████| 391/391 [00:05<00:00, 72.63it/s]


Epoch [29/30], Loss: 0.1528, AUC: 0.9852, Eval Loss: 1.8289, Eval AUC: 0.6568


Epoch 30/30: 100%|██████████| 391/391 [00:05<00:00, 75.16it/s]


Epoch [30/30], Loss: 0.1509, AUC: 0.9856, Eval Loss: 1.8930, Eval AUC: 0.6507


In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleLSTMPkg(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, output_size, batch_size,
                 dropout_prob=0.2, bidirectional=False):
        super(SimpleLSTMPkg, self).__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embed_dim,
                                      padding_idx=0)

        self.rnn = nn.LSTM(batch_size, hidden_size, batch_first=True, dropout=dropout_prob, bidirectional=bidirectional)


        if bidirectional:
          self.fc = nn.Linear(2*hidden_size, output_size)
        else:
          self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, inputs):
        """
        inputs shape: (batch_size, seq_len)
        返回:
          - logits: (batch_size, output_size) 只在最后时间步输出
          - h: (batch_size, hidden_size) 最后时间步的隐藏状态
        """
        # 0. inputs: (batch_size, seq_len)
        device = inputs.device

        # 1. embedding: (batch_size, seq_len, embed_dim)
        inputs = self.embedding(inputs)

        # 2. (batch_size, seq_len, hidden_size)
        inputs, _ = self.rnn(inputs)

        # 3. (batch_size, 1, hidden_size)
        outputs = self.fc(inputs[:,-1,:])

        return outputs

In [27]:
model = SimpleLSTMPkg(len(vocab), 128, 64, 2, batch_size=128, bidirectional=True)  # 定义的网络结构
criterion = torch.nn.CrossEntropyLoss()  # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 设置设备
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
print(device)

cuda




In [28]:
trained_model, metrics = train_model_pkg(
    model=model,
    dataloader=train_loader,
    evalloader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=10
)

Epoch 1/10: 100%|██████████| 25/25 [00:06<00:00,  3.77it/s]


Epoch [1/10], Loss: 0.6933, AUC: 0.5072, Eval Loss: 0.6936, Eval AUC: 0.5061


Epoch 2/10: 100%|██████████| 25/25 [00:06<00:00,  3.71it/s]


Epoch [2/10], Loss: 0.6889, AUC: 0.5404, Eval Loss: 0.6931, Eval AUC: 0.5135


Epoch 3/10: 100%|██████████| 25/25 [00:06<00:00,  3.67it/s]


Epoch [3/10], Loss: 0.6846, AUC: 0.5615, Eval Loss: 0.6933, Eval AUC: 0.5309


Epoch 4/10: 100%|██████████| 25/25 [00:06<00:00,  3.92it/s]


Epoch [4/10], Loss: 0.6765, AUC: 0.5856, Eval Loss: 0.6995, Eval AUC: 0.5973


Epoch 5/10: 100%|██████████| 25/25 [00:06<00:00,  4.15it/s]


Epoch [5/10], Loss: 0.6649, AUC: 0.6167, Eval Loss: 0.6949, Eval AUC: 0.5932


Epoch 6/10: 100%|██████████| 25/25 [00:05<00:00,  4.26it/s]


Epoch [6/10], Loss: 0.6557, AUC: 0.6550, Eval Loss: 0.6701, Eval AUC: 0.6425


Epoch 7/10: 100%|██████████| 25/25 [00:05<00:00,  4.26it/s]


Epoch [7/10], Loss: 0.6225, AUC: 0.7128, Eval Loss: 0.6535, Eval AUC: 0.6923


Epoch 8/10: 100%|██████████| 25/25 [00:06<00:00,  4.06it/s]


Epoch [8/10], Loss: 0.6374, AUC: 0.6818, Eval Loss: 0.6391, Eval AUC: 0.6874


Epoch 9/10: 100%|██████████| 25/25 [00:06<00:00,  3.93it/s]


Epoch [9/10], Loss: 0.5836, AUC: 0.7456, Eval Loss: 0.6323, Eval AUC: 0.7003


Epoch 10/10: 100%|██████████| 25/25 [00:06<00:00,  3.69it/s]


Epoch [10/10], Loss: 0.5726, AUC: 0.7565, Eval Loss: 0.6265, Eval AUC: 0.7040
