In [1]:
import math

from mindspore import nn
from mindspore import ops
from mindspore.common.initializer import initializer

from mindnlp import load_dataset, process, Vocab
from mindnlp._legacy.abc import Seq2vecModel
from mindnlp.engine import Trainer
from mindnlp.metrics import Accuracy
from mindnlp.modules import Glove, RNNEncoder, StaticLSTM
from mindnlp.transforms import BasicTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the IMDB train / test splits.
# NOTE(review): shuffle=True while no random seed is set anywhere in this
# notebook, so sample order presumably differs between runs — consider seeding.
imdb_train, imdb_test = load_dataset(
    'imdb',
    shuffle=True,
)
# Show which columns each sample carries (used later: 'text' in, 'label' as target).
print(imdb_train.get_col_names())

['text', 'label']


In [3]:
# Tokenizer and the vocabulary of the pretrained GloVe 6B / 100-d vectors.
# NOTE(review): the positional `True` is presumably `lower_case` — confirm
# against the mindnlp BasicTokenizer signature.
tokenizer = BasicTokenizer(True)
vocab = Vocab.from_pretrained(name="glove.6B.100d")


def _process_imdb(dataset, drop_remainder):
    """Apply the shared IMDB preprocessing pipeline to one split.

    Both splits use the same tokenizer, vocabulary, length buckets and maximum
    sequence length; only ``drop_remainder`` differs between them.

    Args:
        dataset: a raw IMDB split as returned by ``load_dataset``.
        drop_remainder: forwarded to ``process``; whether to drop the final
            incomplete batch.

    Returns:
        Whatever ``process`` returns for this split.
    """
    # A fresh boundaries list per call so the two splits never share a
    # mutable argument.
    return process('imdb', dataset, tokenizer=tokenizer, vocab=vocab,
                   bucket_boundaries=[400, 500], max_len=600,
                   drop_remainder=drop_remainder)


# Drop the incomplete last batch for training only; keep every test sample
# for evaluation.
imdb_train = _process_imdb(imdb_train, drop_remainder=True)
imdb_test = _process_imdb(imdb_test, drop_remainder=False)

In [4]:
class SentimentClassification(Seq2vecModel):
    """Sentiment classifier: RNN encoder followed by a classification head.

    Built as ``SentimentClassification(encoder, head)``; ``Seq2vecModel``
    presumably stores these as ``self.encoder`` / ``self.head``, which
    ``construct`` uses.
    """

    def construct(self, text):
        # The encoder yields a 3-tuple whose middle item is a pair — presumably
        # (hidden, cell) final states of the LSTM. Only `hidden` is used.
        _, (final_hidden, _), _ = self.encoder(text)
        # The last two entries along axis 0 are, for a bidirectional LSTM,
        # presumably the top layer's forward and backward states (layout
        # assumed: (num_layers * num_directions, batch, hidden) — TODO confirm).
        forward_state = final_hidden[-2, :, :]
        backward_state = final_hidden[-1, :, :]
        # Join them feature-wise; the head therefore expects 2 * hidden_size inputs.
        sentence_vector = ops.concat((forward_state, backward_state), axis=1)
        return self.head(sentence_vector)

In [5]:
# Hyper-parameters
# -- LSTM encoder --
hidden_size = 256       # hidden units per direction; the head takes hidden_size * 2
num_layers = 2          # stacked LSTM layers
bidirectional = True
dropout = 0.5           # used inside the LSTM stack and before the final Dense layer

# -- classifier / optimisation --
output_size = 2         # number of output classes (logits of the final Dense layer)
lr = 5e-4               # Adam learning rate

In [6]:
# Dimension of the pretrained GloVe vectors (matches the "glove.6B.100d"
# vocabulary). The LSTM input size must equal it, so define it once.
embedding_dim = 100

# build encoder: pretrained GloVe embedding feeding a (bi)LSTM stack
embedding = Glove.from_pretrained('6B', embedding_dim, special_tokens=["<unk>", "<pad>"])
lstm_layer = StaticLSTM(embedding_dim, hidden_size, num_layers=num_layers, batch_first=True,
                        dropout=dropout, bidirectional=bidirectional)
encoder = RNNEncoder(embedding, lstm_layer)

# build head: SentimentClassification concatenates two final hidden states,
# so the classifier input width is hidden_size * 2.
head = nn.Sequential([
    nn.Dropout(p=dropout),
    nn.Dense(hidden_size * 2, output_size),
])

# build network, loss and optimizer
network = SentimentClassification(encoder, head)
loss = nn.CrossEntropyLoss()
optimizer = nn.Adam(network.trainable_params(), learning_rate=lr)

In [7]:
def initialize_weights(m):
    """Re-initialize the parameters of a single cell in place.

    Meant to be passed to ``network.apply``, which calls it per cell:

    * ``nn.Dense``: weight <- Xavier-normal; bias (only if the layer has one)
      <- zeros.
    * ``StaticLSTM``: each parameter whose name contains ``'bias'`` <- zeros;
      otherwise, if its name contains ``'weight'`` <- orthogonal. Parameters
      matching neither substring are left untouched.
    * Any other cell is ignored.

    Args:
        m: the cell to (possibly) re-initialize.
    """
    if isinstance(m, nn.Dense):
        m.weight.set_data(initializer('xavier_normal', m.weight.shape, m.weight.dtype))
        # A Dense built with has_bias=False has no bias Parameter; guard rather
        # than crash on None.set_data (getattr also covers a missing attribute).
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.set_data(initializer('zeros', bias.shape, bias.dtype))
    elif isinstance(m, StaticLSTM):
        for name, param in m.parameters_and_names():
            # 'bias' is tested first, so a name containing both substrings is zeroed.
            if 'bias' in name:
                param.set_data(initializer('zeros', param.shape, param.dtype))
            elif 'weight' in name:
                param.set_data(initializer('orthogonal', param.shape, param.dtype))

In [8]:
# Re-initialize weights: `apply` presumably calls `initialize_weights` on every
# sub-cell (as PyTorch's Module.apply does), which resets Dense and StaticLSTM
# parameters and ignores other cells. The value returned (shown below) is the
# network itself, so its repr doubles as an architecture summary.
network.apply(initialize_weights)

SentimentClassification<
  (encoder): RNNEncoder<
    (embedding): Glove<
      (dropout_layer): Dropout<p=0.0>
      >
    (rnn): StaticLSTM<
      (rnn): MultiLayerRNN<
        (cell_list): CellList<
          (0): SingleLSTMLayer_GPU<>
          (1): SingleLSTMLayer_GPU<>
          >
        (dropout): Dropout<p=0.5>
        >
      >
    >
  (head): Sequential<
    (0): Dropout<p=0.5>
    (1): Dense<input_channels=512, output_channels=2, has_bias=True>
    >
  >

In [9]:
# Accuracy on the eval split; the logs below show it is reported after every epoch.
metric = Accuracy()

# Wire model, objective and data into the trainer and run 5 epochs.
# `tgt_columns` names the dataset column that holds the targets.
trainer = Trainer(
    network=network,
    loss_fn=loss,
    optimizer=optimizer,
    train_dataset=imdb_train,
    eval_dataset=imdb_test,
    metrics=metric,
    epochs=5,
)
trainer.run(tgt_columns="label")

Epoch 0: 100%|█████████████████████████████████████████████████████████████████████| 390/390 [01:16<00:00,  5.12it/s, loss=0.648818]
Evaluate: 100%|███████████████████████████████████████████████████████████████████████████████████| 392/392 [00:16<00:00, 24.26it/s]


Evaluate Score: {'Accuracy': 0.54512}


Epoch 1: 100%|████████████████████████████████████████████████████████████████████| 390/390 [01:12<00:00,  5.37it/s, loss=0.6403419]
Evaluate: 100%|███████████████████████████████████████████████████████████████████████████████████| 392/392 [00:14<00:00, 27.31it/s]


Evaluate Score: {'Accuracy': 0.76188}


Epoch 2: 100%|████████████████████████████████████████████████████████████████████| 390/390 [01:13<00:00,  5.32it/s, loss=0.5273511]
Evaluate: 100%|███████████████████████████████████████████████████████████████████████████████████| 392/392 [00:15<00:00, 25.91it/s]


Evaluate Score: {'Accuracy': 0.82848}


Epoch 3: 100%|███████████████████████████████████████████████████████████████████| 390/390 [01:13<00:00,  5.29it/s, loss=0.36224687]
Evaluate: 100%|███████████████████████████████████████████████████████████████████████████████████| 392/392 [00:14<00:00, 27.32it/s]


Evaluate Score: {'Accuracy': 0.85832}


Epoch 4: 100%|███████████████████████████████████████████████████████████████████| 390/390 [01:13<00:00,  5.30it/s, loss=0.27380344]
Evaluate: 100%|███████████████████████████████████████████████████████████████████████████████████| 392/392 [00:14<00:00, 26.93it/s]

Evaluate Score: {'Accuracy': 0.86688}



