In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sentences = ["i like dog", "i love coffee", "i hate milk", "you like cat", "you love milk", "you hate coffee"]
dtype = torch.float


In [3]:
"""
Word Processing
"""
word_list = list(set(" ".join(sentences).split()))
word_dict = {w: i for i, w in enumerate(word_list)}
number_dict = {i: w for i, w in enumerate(word_list)}
n_class = len(word_dict)



In [4]:

"""
TextRNN Parameter
"""
batch_size = len(sentences)
n_step = 2  # 학습 하려고 하는 문장의 길이 - 1
n_hidden = 5  # 은닉층 사이즈

def make_batch(sentences):
  input_batch = []
  target_batch = []

  for sen in sentences:
    word = sen.split()
    input = [word_dict[n] for n in word[:-1]]
    target = word_dict[word[-1]]

    input_batch.append(np.eye(n_class)[input])  # One-Hot Encoding
    target_batch.append(target)
  
  return input_batch, target_batch

input_batch, target_batch = make_batch(sentences)
input_batch = torch.tensor(input_batch, dtype=torch.float32, requires_grad=True)
target_batch = torch.tensor(target_batch, dtype=torch.int64)



  input_batch = torch.tensor(input_batch, dtype=torch.float32, requires_grad=True)


In [5]:

"""
TextLSTM
"""
class TextLSTM(nn.Module):
  def __init__(self):
    super(TextLSTM, self).__init__()

    self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, dropout=0.3)
    self.fc = nn.Linear(n_hidden, n_class)

  def forward(self, hidden_and_cell, X):
    X = X.transpose(0, 1)
    outputs, (h_n,c_n) = self.lstm(X, hidden_and_cell)
    outputs = h_n[-1]  # 최종 예측 Hidden Layer

    model = self.fc(outputs)  # 최종 예측 최종 출력 층
    return model
	


In [9]:
prunFreq = 1

"""
Training
"""
model = TextLSTM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(500):
  hidden = torch.zeros(1, batch_size, n_hidden, requires_grad=True)
  cell = torch.zeros(1, batch_size, n_hidden, requires_grad=True)


  ### Prune
  if epoch % prunFreq == 0:   
    
    parameters_to_prune = (
    (model.lstm, 'weight_ih_l0'),
    (model.lstm, 'weight_hh_l0'),
    (model.lstm, 'bias_ih_l0'),
    (model.lstm, 'bias_hh_l0'),
    (model.fc, 'weight'),
    (model.fc, 'bias')
    )

    prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.6
    )
    prune.remove(model.lstm, 'weight_ih_l0')
    prune.remove(model.lstm, 'weight_hh_l0')
    prune.remove(model.lstm, 'bias_ih_l0')
    prune.remove(model.lstm, 'bias_hh_l0')
    prune.remove(model.fc, 'weight')
    prune.remove(model.fc, 'bias')

    print("before")
    peter_print()

  output = model((hidden, cell), input_batch)
  loss = criterion(output, target_batch)

  if (epoch + 1) % 100 == 0:
    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
  
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  print("After")
  peter_print()


before
Sparsity in lstm.weight_ih_l0: 62.78%
Sparsity in lstm.weight_hh_l0: 62.78%
Sparsity in lstm.bias_ih_l0: 50.00%
Sparsity in lstm.bias_hh_l0: 55.00%
Sparsity in fc.weight: 60.00%
Sparsity in fc.bias: 44.44%
Global sparsity: 59.89%
After
Sparsity in lstm.weight_ih_l0: 32.78%
Sparsity in lstm.weight_hh_l0: 32.78%
Sparsity in lstm.bias_ih_l0: 0.00%
Sparsity in lstm.bias_hh_l0: 0.00%
Sparsity in fc.weight: 0.00%
Sparsity in fc.bias: 0.00%
Global sparsity: 16.31%
before
Sparsity in lstm.weight_ih_l0: 62.78%
Sparsity in lstm.weight_hh_l0: 62.78%
Sparsity in lstm.bias_ih_l0: 50.00%
Sparsity in lstm.bias_hh_l0: 55.00%
Sparsity in fc.weight: 60.00%
Sparsity in fc.bias: 44.44%
Global sparsity: 59.89%
After
Sparsity in lstm.weight_ih_l0: 32.78%
Sparsity in lstm.weight_hh_l0: 32.78%
Sparsity in lstm.bias_ih_l0: 0.00%
Sparsity in lstm.bias_hh_l0: 0.00%
Sparsity in fc.weight: 0.00%
Sparsity in fc.bias: 0.00%
Global sparsity: 16.31%
before
Sparsity in lstm.weight_ih_l0: 62.78%
Sparsity in lstm.

In [8]:

input = [sen.split()[:2] for sen in sentences]

hidden = torch.zeros(1, batch_size, n_hidden, requires_grad=True)
cell = torch.zeros(1, batch_size, n_hidden, requires_grad=True)


print("before")
peter_print()
predict = model((hidden, cell), input_batch).data.max(1, keepdim=True)[1]

print("after")
peter_print()



print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])

before
Sparsity in lstm.weight_ih_l0: 31.67%
Sparsity in lstm.weight_hh_l0: 31.67%
Sparsity in lstm.bias_ih_l0: 0.00%
Sparsity in lstm.bias_hh_l0: 0.00%
Sparsity in fc.weight: 0.00%
Sparsity in fc.bias: 0.00%
Global sparsity: 15.24%
after
Sparsity in lstm.weight_ih_l0: 31.67%
Sparsity in lstm.weight_hh_l0: 31.67%
Sparsity in lstm.bias_ih_l0: 0.00%
Sparsity in lstm.bias_hh_l0: 0.00%
Sparsity in fc.weight: 0.00%
Sparsity in fc.bias: 0.00%
Global sparsity: 15.24%
[['i', 'like'], ['i', 'love'], ['i', 'hate'], ['you', 'like'], ['you', 'love'], ['you', 'hate']] -> ['dog', 'coffee', 'milk', 'cat', 'milk', 'coffee']


In [None]:
list(model.named_parameters())


In [None]:
model.lstm.weight_ih_l0

In [6]:
def peter_print():
    print(
        "Sparsity in lstm.weight_ih_l0: {:.2f}%".format(
            100. * float(torch.sum(model.lstm.weight_ih_l0 == 0))
            / float(model.lstm.weight_ih_l0.nelement())
        )
    )

    print(
        "Sparsity in lstm.weight_hh_l0: {:.2f}%".format(
            100. * float(torch.sum(model.lstm.weight_ih_l0 == 0))
            / float(model.lstm.weight_ih_l0.nelement())
        )
    )

    print(
        "Sparsity in lstm.bias_ih_l0: {:.2f}%".format(
            100. * float(torch.sum(model.lstm.bias_ih_l0 == 0))
            / float(model.lstm.bias_ih_l0.nelement())
        )
    )

    print(
        "Sparsity in lstm.bias_hh_l0: {:.2f}%".format(
            100. * float(torch.sum(model.lstm.bias_hh_l0 == 0))
            / float(model.lstm.bias_hh_l0.nelement())
        )
    )

    print(
        "Sparsity in fc.weight: {:.2f}%".format(
            100. * float(torch.sum(model.fc.weight == 0))
            / float(model.fc.weight.nelement())
        )
    )
    print(
        "Sparsity in fc.bias: {:.2f}%".format(
            100. * float(torch.sum(model.fc.bias == 0))
            / float(model.fc.bias.nelement())
        )
    )


    print(
        "Global sparsity: {:.2f}%".format(
            100. * float(
                torch.sum(model.lstm.weight_ih_l0 == 0)
                + torch.sum(model.lstm.weight_hh_l0 == 0)
                + torch.sum(model.lstm.bias_hh_l0 == 0)
                + torch.sum(model.lstm.bias_ih_l0 == 0)
                + torch.sum(model.fc.weight == 0)
                + torch.sum(model.fc.bias == 0)
            )
            / float(
                model.lstm.weight_ih_l0.nelement()
                + model.lstm.weight_hh_l0.nelement()
                + model.lstm.bias_hh_l0.nelement()
                + model.lstm.bias_ih_l0.nelement()
                + model.fc.weight.nelement()
                + model.fc.bias.nelement()
            )
        )
    )

In [None]:
peter_print()