In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
vs = 20001
emb_dim = 200
seq_len = 31
lstm_units = 128
l2_reg = 1e-4

In [6]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam

## Simulate base model

In [7]:
model = Sequential()
model.add(Embedding(vs, emb_dim, input_length=seq_len, name='word_embedding'))
model.add(LSTM(lstm_units, 
               kernel_regularizer=l2(l2_reg),
               recurrent_regularizer=l2(l2_reg),
               bias_regularizer=l2(l2_reg),
               name='lstm'))
model.add(Dense(vs, activation='softmax', kernel_regularizer=l2(l2_reg), bias_regularizer=l2(l2_reg), name='prediction'))

In [8]:
opt = Adam(beta_1=0.8, beta_2=0.99)

In [9]:
model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

In [10]:
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
word_embedding (Embedding)   (None, 31, 200)           4000200   
_________________________________________________________________
lstm (LSTM)                  (None, 128)               168448    
_________________________________________________________________
prediction (Dense)           (None, 20001)             2580129   
Total params: 6,748,777
Trainable params: 6,748,777
Non-trainable params: 0
_________________________________________________________________


## Build transfer model

In [11]:
lstm = model.get_layer(name='lstm')

In [14]:
lstm.get_weights()

[array([[ 0.05784441,  0.00556013, -0.05876821, ..., -0.08893563,
         -0.0185671 , -0.04207173],
        [ 0.06946001,  0.06603055, -0.01034956, ..., -0.05820586,
         -0.06237112, -0.05804583],
        [ 0.06381498,  0.02139571, -0.07674293, ...,  0.0504185 ,
         -0.07229254, -0.01836178],
        ...,
        [ 0.04364493,  0.05540826,  0.03366634, ..., -0.02832787,
         -0.00213686, -0.00531023],
        [-0.0451461 ,  0.05915882,  0.046291  , ..., -0.04769681,
         -0.04013777, -0.03035577],
        [-0.04148764, -0.00959321, -0.02567694, ...,  0.08164992,
         -0.08284491, -0.02456603]], dtype=float32),
 array([[-0.03836513, -0.01535209, -0.01214978, ...,  0.02098113,
         -0.09337205, -0.0319512 ],
        [-0.00545963,  0.01140785,  0.04630283, ...,  0.01380163,
         -0.01360659, -0.01689173],
        [-0.00977908, -0.02002477, -0.05896986, ..., -0.07822406,
          0.07078375,  0.07281852],
        ...,
        [ 0.04565244, -0.01714043, -0.0

In [15]:
model.pop()

In [16]:
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
word_embedding (Embedding)   (None, 31, 200)           4000200   
_________________________________________________________________
lstm (LSTM)                  (None, 128)               168448    
Total params: 6,748,777
Trainable params: 6,748,777
Non-trainable params: 0
_________________________________________________________________


In [24]:
model.add(Dense(6, activation='softmax', name='new_prediction'))
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
word_embedding (Embedding)   (None, 31, 200)           4000200   
_________________________________________________________________
lstm (LSTM)                  (None, 128)               168448    
_________________________________________________________________
new_prediction (Dense)       (None, 6)                 774       
Total params: 4,169,422
Trainable params: 4,169,422
Non-trainable params: 0
_________________________________________________________________


In [18]:
model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

In [19]:
lstm = model.get_layer(name='lstm')
lstm.get_weights()

[array([[ 0.05784441,  0.00556013, -0.05876821, ..., -0.08893563,
         -0.0185671 , -0.04207173],
        [ 0.06946001,  0.06603055, -0.01034956, ..., -0.05820586,
         -0.06237112, -0.05804583],
        [ 0.06381498,  0.02139571, -0.07674293, ...,  0.0504185 ,
         -0.07229254, -0.01836178],
        ...,
        [ 0.04364493,  0.05540826,  0.03366634, ..., -0.02832787,
         -0.00213686, -0.00531023],
        [-0.0451461 ,  0.05915882,  0.046291  , ..., -0.04769681,
         -0.04013777, -0.03035577],
        [-0.04148764, -0.00959321, -0.02567694, ...,  0.08164992,
         -0.08284491, -0.02456603]], dtype=float32),
 array([[-0.03836513, -0.01535209, -0.01214978, ...,  0.02098113,
         -0.09337205, -0.0319512 ],
        [-0.00545963,  0.01140785,  0.04630283, ...,  0.01380163,
         -0.01360659, -0.01689173],
        [-0.00977908, -0.02002477, -0.05896986, ..., -0.07822406,
          0.07078375,  0.07281852],
        ...,
        [ 0.04565244, -0.01714043, -0.0

In [25]:
model.layers

[<tensorflow.python.keras.layers.embeddings.Embedding at 0x7fdf7d856668>,
 <tensorflow.python.keras.layers.recurrent.LSTM at 0x7fdf7d856f28>,
 <tensorflow.python.keras.layers.core.Dense at 0x7fdf7c1eda90>]

In [28]:
for layer in model.layers[:-1]:
    layer.trainable = False
model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

In [29]:
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
word_embedding (Embedding)   (None, 31, 200)           4000200   
_________________________________________________________________
lstm (LSTM)                  (None, 128)               168448    
_________________________________________________________________
new_prediction (Dense)       (None, 6)                 774       
Total params: 4,169,422
Trainable params: 774
Non-trainable params: 4,168,648
_________________________________________________________________


In [30]:
lstm = model.get_layer(name='lstm')
lstm.get_weights()[0]

array([[ 0.05784441,  0.00556013, -0.05876821, ..., -0.08893563,
        -0.0185671 , -0.04207173],
       [ 0.06946001,  0.06603055, -0.01034956, ..., -0.05820586,
        -0.06237112, -0.05804583],
       [ 0.06381498,  0.02139571, -0.07674293, ...,  0.0504185 ,
        -0.07229254, -0.01836178],
       ...,
       [ 0.04364493,  0.05540826,  0.03366634, ..., -0.02832787,
        -0.00213686, -0.00531023],
       [-0.0451461 ,  0.05915882,  0.046291  , ..., -0.04769681,
        -0.04013777, -0.03035577],
       [-0.04148764, -0.00959321, -0.02567694, ...,  0.08164992,
        -0.08284491, -0.02456603]], dtype=float32)