# Freezing and Unfreezing Embeddings

In [0]:
N_EPOCHS = 10
FREEZE_FOR = 5

# Any finite validation loss will beat this starting value.
best_valid_loss = float('inf')

# Begin with the embedding weights excluded from gradient updates; `unfrozen`
# mirrors that flag so the per-epoch report can show the current state.
model.embedding.weight.requires_grad = False
unfrozen = False

for epoch in range(N_EPOCHS):

    # Time one full pass: training followed by validation.
    start_time = time.time()
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # The "Frozen?" field describes the epoch that just ran, because the flag
    # is only flipped at the bottom of the loop body.
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s | Frozen? {not unfrozen}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

    # Checkpoint whenever the validation loss improves on the best seen so far.
    if best_valid_loss > valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tutC-model.pt')

    # Once FREEZE_FOR epochs have completed, let the embeddings train for the
    # remaining epochs.
    if epoch >= FREEZE_FOR - 1:
        model.embedding.weight.requires_grad = True
        unfrozen = True

In [0]:
# Either record a new best validation loss (and checkpoint the model), or,
# when the loss did not improve, switch the embedding weights to trainable.
if best_valid_loss > valid_loss:
    best_valid_loss = valid_loss
    torch.save(model.state_dict(), 'tutC-model.pt')
else:
    model.embedding.weight.requires_grad = True
    unfrozen = True

# Saving Embeddings

In [0]:
from tqdm import tqdm

def write_embeddings(path, embeddings, vocab):
    """Dump embedding rows to a text file, one ``word v1 v2 ...`` line each.

    Args:
        path: Destination file. It is opened in text write mode, so an
            existing file at that location is overwritten.
        embeddings: Iterable of rows that each provide ``tolist()``; row ``i``
            is paired with ``vocab.itos[i]``.
        vocab: Object whose ``itos`` maps an integer index to its token string.

    Tokens whose encoded byte length differs from their character count
    (i.e. tokens containing multi-byte characters) are left out of the file.
    """
    with open(path, 'w') as out_file:
        for row_idx, row in enumerate(tqdm(embeddings)):
            word = vocab.itos[row_idx]
            # Equal lengths before and after encoding means every character
            # fits in a single byte; only those tokens are written.
            if len(word.encode()) == len(word):
                vector = ' '.join(map(str, row.tolist()))
                out_file.write(f'{word} {vector}\n')

We'll write our embeddings to `custom_embeddings/trained_embeddings.txt`.

In [0]:
# Export the model's current embedding matrix, pairing each row with the token
# at the same index in TEXT.vocab.
# NOTE(review): open() does not create directories, so 'custom_embeddings/'
# presumably already exists at this point -- confirm.
write_embeddings('custom_embeddings/trained_embeddings.txt', 
                 model.embedding.weight.data, 
                 TEXT.vocab)

To double-check that they've been written correctly, we can load them as `Vectors`.

In [0]:

# Read the file written above back in as a Vectors object.
# NOTE(review): `vocab` here is presumably the torchtext.vocab module imported
# earlier in the notebook (not a vocabulary instance) -- confirm.
# NOTE(review): `cache` looks like the directory used for the cached copy of
# the vectors, and `unk_init` the initialiser for tokens missing from the
# file -- verify against the torchtext documentation.
trained_embeddings = vocab.Vectors(name = 'custom_embeddings/trained_embeddings.txt',
                                   cache = 'custom_embeddings',
                                   unk_init = torch.Tensor.normal_)

In [0]:
# Show the first five rows of the vectors that were loaded back from disk.
print(trained_embeddings.vectors[:5])

In [0]:
# Show the first five rows of the model's embedding matrix for a visual
# comparison with the reloaded vectors.
# NOTE(review): rows only line up if none of the first five tokens were
# skipped when the file was written -- confirm.
print(model.embedding.weight.data[:5])