<a href="https://colab.research.google.com/github/chriskwon9/deeplearning/blob/main/kyungheenlp_0923_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
import torch
import torch.nn as nn
import torch.optim as optim

from collections import defaultdict
import time
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

In [21]:
# Functions to read in the corpus
w2i = defaultdict(lambda: len(w2i))
t2i = defaultdict(lambda: len(t2i))
UNK = w2i["<unk>"]

In [22]:
def read_dataset(filename):
  with open(filename, "r") as f:
    for line in f:
      tag, words = line.lower().strip().split(" ||| ")
      yield ([w2i[x] for x in words.split(" ")], t2i[tag])

In [23]:
from google.colab import files
uploaded = files.upload()

Saving train.txt to train (1).txt


In [24]:
from google.colab import files
uploaded = files.upload()

Saving test.txt to test (1).txt


In [25]:
# Read in the data
train_data = list(read_dataset("train.txt"))
w2i = defaultdict(lambda: UNK, w2i)  # Now UNK is used for unknown words
dev_data = list(read_dataset("test.txt"))
nwords = len(w2i)
ntags = len(t2i)

In [26]:
# Custom dataset class to work with DataLoader
class TextDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [27]:
# Collate function for padding sentences in each batch
def collate_fn(batch):
    sentences, tags = zip(*batch)
    sentences = [torch.tensor(sent) for sent in sentences]
    tags = torch.tensor(tags)
    padded_sentences = pad_sequence(sentences, batch_first=True, padding_value=UNK)
    return padded_sentences, tags

In [28]:
# Define DataLoader
BATCH_SIZE = 32
train_loader = DataLoader(TextDataset(train_data), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(TextDataset(dev_data), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [29]:
# Define the model
EMB_SIZE = 64
HID_SIZE = 64

In [30]:
class RNNModel(nn.Module):
    def __init__(self, nwords, ntags, emb_size, hidden_size):
        super(RNNModel, self).__init__()
        self.embedding = nn.Embedding(nwords, emb_size)
        self.rnn = nn.RNN(emb_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, ntags)

        # Initialize weights
        self.init_weights()

    def init_weights(self):
        # Apply Xavier initialization to all linear and recurrent layers
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param.data)
            elif 'bias' in name:
                param.data.fill_(0)  # Initialize biases to zero

    def forward(self, sentences):
        embeds = self.embedding(sentences)  # [batch_size x len(sentences[1]) x emb_size]
        rnn_out, _ = self.rnn(embeds)       # [batch_size x len(sentences[1]) x hidden_size]
        logits = self.fc(rnn_out[:, -1, :]) # Use the last hidden state for classification
        return logits

model = RNNModel(nwords, ntags, EMB_SIZE, HID_SIZE)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [33]:
# Training loop
max_test_accuracy = 0.0
for ITER in range(100):
    # Perform training
    train_loss = 0.0
    start = time.time()
    model.train()

    for sentences, tags in tqdm(train_loader):
        logits = model(sentences)
        loss = criterion(logits, tags)
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"iter {ITER}: train loss/sent={train_loss / len(train_loader):.4f}, time={time.time() - start:.2f}s")

    # Perform evaluation
    model.eval()
    test_correct = 0.0
    with torch.no_grad():
        for sentences, tags in tqdm(dev_loader):
            logits = model(sentences)
            predict = torch.argmax(logits, dim=1)
            test_correct += (predict == tags).sum().item()

    test_accuracy = test_correct / len(dev_data)
    if max_test_accuracy < test_accuracy:
        max_test_accuracy = test_accuracy
    print(f"iter {ITER}: test acc={test_accuracy:.4f}")

print("max test acc=%.4f" % (max_test_accuracy))

100%|██████████| 267/267 [00:06<00:00, 43.16it/s]


iter 0: train loss/sent=1.5762, time=6.19s


100%|██████████| 70/70 [00:00<00:00, 321.54it/s]


iter 0: test acc=0.2443


100%|██████████| 267/267 [00:04<00:00, 57.70it/s]


iter 1: train loss/sent=1.5684, time=4.64s


100%|██████████| 70/70 [00:00<00:00, 433.41it/s]


iter 1: test acc=0.2724


100%|██████████| 267/267 [00:04<00:00, 60.24it/s]


iter 2: train loss/sent=1.5651, time=4.44s


100%|██████████| 70/70 [00:00<00:00, 426.94it/s]


iter 2: test acc=0.2303


100%|██████████| 267/267 [00:05<00:00, 49.38it/s]


iter 3: train loss/sent=1.5600, time=5.42s


100%|██████████| 70/70 [00:00<00:00, 454.88it/s]


iter 3: test acc=0.2312


100%|██████████| 267/267 [00:04<00:00, 55.91it/s]


iter 4: train loss/sent=1.5540, time=4.79s


100%|██████████| 70/70 [00:00<00:00, 458.23it/s]


iter 4: test acc=0.2674


100%|██████████| 267/267 [00:06<00:00, 43.95it/s]


iter 5: train loss/sent=1.5495, time=6.08s


100%|██████████| 70/70 [00:00<00:00, 207.00it/s]


iter 5: test acc=0.2525


100%|██████████| 267/267 [00:04<00:00, 57.61it/s]


iter 6: train loss/sent=1.5476, time=4.65s


100%|██████████| 70/70 [00:00<00:00, 443.65it/s]


iter 6: test acc=0.2747


100%|██████████| 267/267 [00:05<00:00, 51.85it/s]


iter 7: train loss/sent=1.5444, time=5.16s


100%|██████████| 70/70 [00:00<00:00, 348.58it/s]


iter 7: test acc=0.2398


100%|██████████| 267/267 [00:05<00:00, 50.22it/s]


iter 8: train loss/sent=1.5416, time=5.32s


100%|██████████| 70/70 [00:00<00:00, 467.21it/s]


iter 8: test acc=0.2339


100%|██████████| 267/267 [00:04<00:00, 54.12it/s]


iter 9: train loss/sent=1.5430, time=4.94s


100%|██████████| 70/70 [00:00<00:00, 461.96it/s]


iter 9: test acc=0.2376


100%|██████████| 267/267 [00:05<00:00, 47.68it/s]


iter 10: train loss/sent=1.5397, time=5.61s


100%|██████████| 70/70 [00:00<00:00, 475.14it/s]


iter 10: test acc=0.2471


100%|██████████| 267/267 [00:04<00:00, 54.92it/s]


iter 11: train loss/sent=1.5363, time=4.87s


100%|██████████| 70/70 [00:00<00:00, 441.45it/s]


iter 11: test acc=0.2348


100%|██████████| 267/267 [00:05<00:00, 47.65it/s]


iter 12: train loss/sent=1.5348, time=5.61s


100%|██████████| 70/70 [00:00<00:00, 425.03it/s]


iter 12: test acc=0.2742


100%|██████████| 267/267 [00:04<00:00, 58.51it/s]


iter 13: train loss/sent=1.5355, time=4.57s


100%|██████████| 70/70 [00:00<00:00, 490.56it/s]


iter 13: test acc=0.2787


100%|██████████| 267/267 [00:04<00:00, 58.69it/s]


iter 14: train loss/sent=1.5357, time=4.56s


100%|██████████| 70/70 [00:00<00:00, 478.02it/s]


iter 14: test acc=0.2317


100%|██████████| 267/267 [00:05<00:00, 49.44it/s]


iter 15: train loss/sent=1.5365, time=5.41s


100%|██████████| 70/70 [00:00<00:00, 454.17it/s]


iter 15: test acc=0.2769


100%|██████████| 267/267 [00:04<00:00, 58.69it/s]


iter 16: train loss/sent=1.5339, time=4.56s


100%|██████████| 70/70 [00:00<00:00, 486.68it/s]


iter 16: test acc=0.2801


100%|██████████| 267/267 [00:05<00:00, 50.02it/s]


iter 17: train loss/sent=1.5294, time=5.35s


100%|██████████| 70/70 [00:00<00:00, 447.50it/s]


iter 17: test acc=0.2330


100%|██████████| 267/267 [00:04<00:00, 56.18it/s]


iter 18: train loss/sent=1.5270, time=4.76s


100%|██████████| 70/70 [00:00<00:00, 493.41it/s]


iter 18: test acc=0.2480


100%|██████████| 267/267 [00:05<00:00, 53.39it/s]


iter 19: train loss/sent=1.5315, time=5.01s


100%|██████████| 70/70 [00:00<00:00, 322.41it/s]


iter 19: test acc=0.2525


100%|██████████| 267/267 [00:05<00:00, 49.52it/s]


iter 20: train loss/sent=1.5252, time=5.40s


100%|██████████| 70/70 [00:00<00:00, 480.18it/s]


iter 20: test acc=0.2389


100%|██████████| 267/267 [00:04<00:00, 54.53it/s]


iter 21: train loss/sent=1.5233, time=4.90s


100%|██████████| 70/70 [00:00<00:00, 440.47it/s]


iter 21: test acc=0.2330


100%|██████████| 267/267 [00:05<00:00, 48.31it/s]


iter 22: train loss/sent=1.5193, time=5.54s


100%|██████████| 70/70 [00:00<00:00, 450.41it/s]


iter 22: test acc=0.2480


100%|██████████| 267/267 [00:04<00:00, 56.83it/s]


iter 23: train loss/sent=1.5105, time=4.70s


100%|██████████| 70/70 [00:00<00:00, 486.94it/s]


iter 23: test acc=0.2706


100%|██████████| 267/267 [00:05<00:00, 52.11it/s]


iter 24: train loss/sent=1.5063, time=5.13s


100%|██████████| 70/70 [00:00<00:00, 337.35it/s]


iter 24: test acc=0.2353


100%|██████████| 267/267 [00:04<00:00, 54.88it/s]


iter 25: train loss/sent=1.5168, time=4.87s


100%|██████████| 70/70 [00:00<00:00, 452.60it/s]


iter 25: test acc=0.2362


100%|██████████| 267/267 [00:04<00:00, 56.52it/s]


iter 26: train loss/sent=1.4904, time=4.73s


100%|██████████| 70/70 [00:00<00:00, 446.55it/s]


iter 26: test acc=0.2534


100%|██████████| 267/267 [00:05<00:00, 49.78it/s]


iter 27: train loss/sent=1.4491, time=5.37s


100%|██████████| 70/70 [00:00<00:00, 483.35it/s]


iter 27: test acc=0.2579


100%|██████████| 267/267 [00:04<00:00, 55.74it/s]


iter 28: train loss/sent=1.4183, time=4.80s


100%|██████████| 70/70 [00:00<00:00, 407.09it/s]


iter 28: test acc=0.2443


100%|██████████| 267/267 [00:05<00:00, 50.32it/s]


iter 29: train loss/sent=1.3616, time=5.31s


100%|██████████| 70/70 [00:00<00:00, 319.76it/s]


iter 29: test acc=0.2308


100%|██████████| 267/267 [00:04<00:00, 57.48it/s]


iter 30: train loss/sent=1.3248, time=4.65s


100%|██████████| 70/70 [00:00<00:00, 452.93it/s]


iter 30: test acc=0.2529


100%|██████████| 267/267 [00:04<00:00, 56.22it/s]


iter 31: train loss/sent=1.2655, time=4.76s


100%|██████████| 70/70 [00:00<00:00, 446.51it/s]


iter 31: test acc=0.2629


100%|██████████| 267/267 [00:05<00:00, 47.99it/s]


iter 32: train loss/sent=1.2001, time=5.57s


100%|██████████| 70/70 [00:00<00:00, 448.96it/s]


iter 32: test acc=0.2462


100%|██████████| 267/267 [00:04<00:00, 56.85it/s]


iter 33: train loss/sent=1.1311, time=4.70s


100%|██████████| 70/70 [00:00<00:00, 426.70it/s]


iter 33: test acc=0.2362


100%|██████████| 267/267 [00:05<00:00, 48.52it/s]


iter 34: train loss/sent=1.1153, time=5.51s


100%|██████████| 70/70 [00:00<00:00, 465.71it/s]


iter 34: test acc=0.2434


100%|██████████| 267/267 [00:04<00:00, 57.69it/s]


iter 35: train loss/sent=1.0607, time=4.63s


100%|██████████| 70/70 [00:00<00:00, 473.44it/s]


iter 35: test acc=0.2452


100%|██████████| 267/267 [00:04<00:00, 57.25it/s]


iter 36: train loss/sent=0.9640, time=4.68s


100%|██████████| 70/70 [00:00<00:00, 352.46it/s]


iter 36: test acc=0.2738


100%|██████████| 267/267 [00:05<00:00, 49.31it/s]


iter 37: train loss/sent=0.9625, time=5.42s


100%|██████████| 70/70 [00:00<00:00, 466.89it/s]


iter 37: test acc=0.2566


100%|██████████| 267/267 [00:04<00:00, 56.91it/s]


iter 38: train loss/sent=0.9197, time=4.70s


100%|██████████| 70/70 [00:00<00:00, 426.72it/s]


iter 38: test acc=0.2751


100%|██████████| 267/267 [00:05<00:00, 49.50it/s]


iter 39: train loss/sent=0.8433, time=5.40s


100%|██████████| 70/70 [00:00<00:00, 461.54it/s]


iter 39: test acc=0.2738


100%|██████████| 267/267 [00:04<00:00, 55.39it/s]


iter 40: train loss/sent=0.7570, time=4.83s


100%|██████████| 70/70 [00:00<00:00, 425.82it/s]


iter 40: test acc=0.2801


100%|██████████| 267/267 [00:05<00:00, 51.11it/s]


iter 41: train loss/sent=0.6985, time=5.23s


100%|██████████| 70/70 [00:00<00:00, 325.87it/s]


iter 41: test acc=0.2819


100%|██████████| 267/267 [00:05<00:00, 47.35it/s]


iter 42: train loss/sent=0.6893, time=5.64s


100%|██████████| 70/70 [00:00<00:00, 468.14it/s]


iter 42: test acc=0.2860


100%|██████████| 267/267 [00:05<00:00, 50.50it/s]


iter 43: train loss/sent=0.6887, time=5.29s


100%|██████████| 70/70 [00:00<00:00, 430.95it/s]


iter 43: test acc=0.2959


100%|██████████| 267/267 [00:06<00:00, 43.68it/s]


iter 44: train loss/sent=0.6867, time=6.12s


100%|██████████| 70/70 [00:00<00:00, 441.32it/s]


iter 44: test acc=0.2851


100%|██████████| 267/267 [00:05<00:00, 52.26it/s]


iter 45: train loss/sent=0.6131, time=5.12s


100%|██████████| 70/70 [00:00<00:00, 434.46it/s]


iter 45: test acc=0.2941


100%|██████████| 267/267 [00:05<00:00, 46.00it/s]


iter 46: train loss/sent=0.5536, time=5.81s


100%|██████████| 70/70 [00:00<00:00, 434.83it/s]


iter 46: test acc=0.2896


100%|██████████| 267/267 [00:04<00:00, 55.33it/s]


iter 47: train loss/sent=0.5113, time=4.83s


100%|██████████| 70/70 [00:00<00:00, 474.31it/s]


iter 47: test acc=0.2710


100%|██████████| 267/267 [00:05<00:00, 49.38it/s]


iter 48: train loss/sent=0.8337, time=5.42s


100%|██████████| 70/70 [00:00<00:00, 312.49it/s]


iter 48: test acc=0.3104


100%|██████████| 267/267 [00:04<00:00, 55.66it/s]


iter 49: train loss/sent=0.9096, time=4.80s


100%|██████████| 70/70 [00:00<00:00, 429.37it/s]


iter 49: test acc=0.2765


100%|██████████| 267/267 [00:04<00:00, 57.47it/s]


iter 50: train loss/sent=0.6708, time=4.65s


100%|██████████| 70/70 [00:00<00:00, 449.74it/s]


iter 50: test acc=0.3000


100%|██████████| 267/267 [00:05<00:00, 49.07it/s]


iter 51: train loss/sent=0.5555, time=5.45s


100%|██████████| 70/70 [00:00<00:00, 464.34it/s]


iter 51: test acc=0.2928


100%|██████████| 267/267 [00:04<00:00, 56.61it/s]


iter 52: train loss/sent=0.5161, time=4.72s


100%|██████████| 70/70 [00:00<00:00, 491.75it/s]


iter 52: test acc=0.2959


100%|██████████| 267/267 [00:05<00:00, 48.87it/s]


iter 53: train loss/sent=0.4493, time=5.47s


100%|██████████| 70/70 [00:00<00:00, 318.93it/s]


iter 53: test acc=0.3090


100%|██████████| 267/267 [00:04<00:00, 56.42it/s]


iter 54: train loss/sent=0.4428, time=4.74s


100%|██████████| 70/70 [00:00<00:00, 443.17it/s]


iter 54: test acc=0.3014


100%|██████████| 267/267 [00:04<00:00, 54.36it/s]


iter 55: train loss/sent=0.4144, time=4.92s


100%|██████████| 70/70 [00:00<00:00, 438.16it/s]


iter 55: test acc=0.3023


100%|██████████| 267/267 [00:05<00:00, 47.80it/s]


iter 56: train loss/sent=0.3432, time=5.59s


100%|██████████| 70/70 [00:00<00:00, 442.63it/s]


iter 56: test acc=0.2959


100%|██████████| 267/267 [00:04<00:00, 54.78it/s]


iter 57: train loss/sent=0.3125, time=4.88s


100%|██████████| 70/70 [00:00<00:00, 449.61it/s]


iter 57: test acc=0.2860


100%|██████████| 267/267 [00:05<00:00, 48.03it/s]


iter 58: train loss/sent=0.3777, time=5.57s


100%|██████████| 70/70 [00:00<00:00, 399.63it/s]


iter 58: test acc=0.3018


100%|██████████| 267/267 [00:04<00:00, 55.56it/s]


iter 59: train loss/sent=0.2907, time=4.82s


100%|██████████| 70/70 [00:00<00:00, 409.43it/s]


iter 59: test acc=0.2801


100%|██████████| 267/267 [00:05<00:00, 51.24it/s]


iter 60: train loss/sent=0.2692, time=5.22s


100%|██████████| 70/70 [00:00<00:00, 327.75it/s]


iter 60: test acc=0.3032


100%|██████████| 267/267 [00:05<00:00, 52.16it/s]


iter 61: train loss/sent=0.2838, time=5.12s


100%|██████████| 70/70 [00:00<00:00, 424.07it/s]


iter 61: test acc=0.3045


100%|██████████| 267/267 [00:04<00:00, 56.08it/s]


iter 62: train loss/sent=0.2374, time=4.77s


100%|██████████| 70/70 [00:00<00:00, 466.07it/s]


iter 62: test acc=0.3041


100%|██████████| 267/267 [00:05<00:00, 48.31it/s]


iter 63: train loss/sent=0.2599, time=5.53s


100%|██████████| 70/70 [00:00<00:00, 460.24it/s]


iter 63: test acc=0.3032


100%|██████████| 267/267 [00:04<00:00, 57.02it/s]


iter 64: train loss/sent=0.2327, time=4.69s


100%|██████████| 70/70 [00:00<00:00, 431.79it/s]


iter 64: test acc=0.3018


100%|██████████| 267/267 [00:05<00:00, 50.91it/s]


iter 65: train loss/sent=0.3406, time=5.25s


100%|██████████| 70/70 [00:00<00:00, 311.22it/s]


iter 65: test acc=0.3000


100%|██████████| 267/267 [00:04<00:00, 56.49it/s]


iter 66: train loss/sent=0.2661, time=4.73s


100%|██████████| 70/70 [00:00<00:00, 416.34it/s]


iter 66: test acc=0.2914


100%|██████████| 267/267 [00:04<00:00, 58.25it/s]


iter 67: train loss/sent=0.3074, time=4.59s


100%|██████████| 70/70 [00:00<00:00, 497.96it/s]


iter 67: test acc=0.2986


100%|██████████| 267/267 [00:05<00:00, 48.88it/s]


iter 68: train loss/sent=0.4982, time=5.47s


100%|██████████| 70/70 [00:00<00:00, 453.12it/s]


iter 68: test acc=0.3195


100%|██████████| 267/267 [00:04<00:00, 58.10it/s]


iter 69: train loss/sent=0.3538, time=4.60s


100%|██████████| 70/70 [00:00<00:00, 447.67it/s]


iter 69: test acc=0.3009


100%|██████████| 267/267 [00:05<00:00, 49.92it/s]


iter 70: train loss/sent=0.2554, time=5.35s


100%|██████████| 70/70 [00:00<00:00, 333.58it/s]


iter 70: test acc=0.2923


100%|██████████| 267/267 [00:04<00:00, 56.84it/s]


iter 71: train loss/sent=0.2151, time=4.71s


100%|██████████| 70/70 [00:00<00:00, 463.76it/s]


iter 71: test acc=0.2873


100%|██████████| 267/267 [00:04<00:00, 57.75it/s]


iter 72: train loss/sent=0.2537, time=4.63s


100%|██████████| 70/70 [00:00<00:00, 411.13it/s]


iter 72: test acc=0.2955


100%|██████████| 267/267 [00:05<00:00, 48.88it/s]


iter 73: train loss/sent=1.1378, time=5.47s


100%|██████████| 70/70 [00:00<00:00, 473.64it/s]


iter 73: test acc=0.2484


100%|██████████| 267/267 [00:04<00:00, 57.67it/s]


iter 74: train loss/sent=1.5472, time=4.64s


100%|██████████| 70/70 [00:00<00:00, 475.52it/s]


iter 74: test acc=0.2570


100%|██████████| 267/267 [00:05<00:00, 49.01it/s]


iter 75: train loss/sent=1.5337, time=5.45s


100%|██████████| 70/70 [00:00<00:00, 450.03it/s]


iter 75: test acc=0.2452


100%|██████████| 267/267 [00:04<00:00, 57.47it/s]


iter 76: train loss/sent=1.5038, time=4.65s


100%|██████████| 70/70 [00:00<00:00, 471.12it/s]


iter 76: test acc=0.2638


100%|██████████| 267/267 [00:04<00:00, 56.90it/s]


iter 77: train loss/sent=1.2983, time=4.70s


100%|██████████| 70/70 [00:00<00:00, 422.42it/s]


iter 77: test acc=0.2878


100%|██████████| 267/267 [00:05<00:00, 48.39it/s]


iter 78: train loss/sent=1.1324, time=5.53s


100%|██████████| 70/70 [00:00<00:00, 476.18it/s]


iter 78: test acc=0.2611


100%|██████████| 267/267 [00:04<00:00, 56.93it/s]


iter 79: train loss/sent=0.8766, time=4.70s


100%|██████████| 70/70 [00:00<00:00, 464.85it/s]


iter 79: test acc=0.2814


100%|██████████| 267/267 [00:05<00:00, 49.51it/s]


iter 80: train loss/sent=0.7979, time=5.40s


100%|██████████| 70/70 [00:00<00:00, 465.15it/s]


iter 80: test acc=0.2688


100%|██████████| 267/267 [00:04<00:00, 58.89it/s]


iter 81: train loss/sent=0.9675, time=4.54s


100%|██████████| 70/70 [00:00<00:00, 434.73it/s]


iter 81: test acc=0.2778


100%|██████████| 267/267 [00:05<00:00, 51.93it/s]


iter 82: train loss/sent=0.6292, time=5.15s


100%|██████████| 70/70 [00:00<00:00, 355.01it/s]


iter 82: test acc=0.2910


100%|██████████| 267/267 [00:05<00:00, 51.08it/s]


iter 83: train loss/sent=0.4963, time=5.24s


100%|██████████| 70/70 [00:00<00:00, 454.66it/s]


iter 83: test acc=0.2851


100%|██████████| 267/267 [00:04<00:00, 59.15it/s]


iter 84: train loss/sent=0.4848, time=4.52s


100%|██████████| 70/70 [00:00<00:00, 420.24it/s]


iter 84: test acc=0.2828


100%|██████████| 267/267 [00:05<00:00, 48.98it/s]


iter 85: train loss/sent=0.6527, time=5.46s


100%|██████████| 70/70 [00:00<00:00, 465.32it/s]


iter 85: test acc=0.3014


100%|██████████| 267/267 [00:04<00:00, 59.57it/s]


iter 86: train loss/sent=0.7340, time=4.49s


100%|██████████| 70/70 [00:00<00:00, 452.09it/s]


iter 86: test acc=0.2729


100%|██████████| 267/267 [00:04<00:00, 54.22it/s]


iter 87: train loss/sent=0.9535, time=4.93s


100%|██████████| 70/70 [00:00<00:00, 327.24it/s]


iter 87: test acc=0.2633


100%|██████████| 267/267 [00:05<00:00, 50.98it/s]


iter 88: train loss/sent=1.2746, time=5.24s


100%|██████████| 70/70 [00:00<00:00, 445.46it/s]


iter 88: test acc=0.2502


100%|██████████| 267/267 [00:04<00:00, 58.31it/s]


iter 89: train loss/sent=1.2717, time=4.58s


100%|██████████| 70/70 [00:00<00:00, 484.16it/s]


iter 89: test acc=0.2362


100%|██████████| 267/267 [00:05<00:00, 48.68it/s]


iter 90: train loss/sent=1.3216, time=5.49s


100%|██████████| 70/70 [00:00<00:00, 430.00it/s]


iter 90: test acc=0.2511


100%|██████████| 267/267 [00:04<00:00, 56.54it/s]


iter 91: train loss/sent=1.3137, time=4.73s


100%|██████████| 70/70 [00:00<00:00, 458.50it/s]


iter 91: test acc=0.2452


100%|██████████| 267/267 [00:05<00:00, 52.15it/s]


iter 92: train loss/sent=1.1165, time=5.13s


100%|██████████| 70/70 [00:00<00:00, 307.39it/s]


iter 92: test acc=0.2801


100%|██████████| 267/267 [00:05<00:00, 53.25it/s]


iter 93: train loss/sent=0.8565, time=5.02s


100%|██████████| 70/70 [00:00<00:00, 474.36it/s]


iter 93: test acc=0.2887


100%|██████████| 267/267 [00:04<00:00, 56.19it/s]


iter 94: train loss/sent=0.7910, time=4.76s


100%|██████████| 70/70 [00:00<00:00, 447.08it/s]


iter 94: test acc=0.2502


100%|██████████| 267/267 [00:05<00:00, 48.67it/s]


iter 95: train loss/sent=0.9001, time=5.49s


100%|██████████| 70/70 [00:00<00:00, 406.91it/s]


iter 95: test acc=0.2783


100%|██████████| 267/267 [00:04<00:00, 55.45it/s]


iter 96: train loss/sent=1.3534, time=4.82s


100%|██████████| 70/70 [00:00<00:00, 448.60it/s]


iter 96: test acc=0.3136


100%|██████████| 267/267 [00:05<00:00, 49.36it/s]


iter 97: train loss/sent=1.4809, time=5.42s


100%|██████████| 70/70 [00:00<00:00, 278.75it/s]


iter 97: test acc=0.3167


100%|██████████| 267/267 [00:04<00:00, 58.35it/s]


iter 98: train loss/sent=1.4893, time=4.58s


100%|██████████| 70/70 [00:00<00:00, 421.40it/s]


iter 98: test acc=0.3163


100%|██████████| 267/267 [00:04<00:00, 56.32it/s]


iter 99: train loss/sent=1.4960, time=4.75s


100%|██████████| 70/70 [00:00<00:00, 479.84it/s]

iter 99: test acc=0.2579
max test acc=0.3195



