# Autoencoder

In [None]:
import os
import glob
import logging
from datetime import datetime
from tqdm import tqdm
tqdm.monitor_interval = 0
import pandas as pd

logging.basicConfig(filename='logs/autoencoder.log', filemode='w', level=logging.INFO, 
                        format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device('cpu')

OUT_DIR = 'representations'

In [2]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.lstm(output, hidden)
        return output, hidden

    def init_hidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.lstm(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def init_hidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

class Autoencoder:
    
    def __init__(self, dim, word2id):
        self.dim = dim
        self.word2id = word2id.copy()
        self.word2id['SOS'] = len(self.word2id)
        self.word2id['EOS'] = len(self.word2id)
        self.SOS = self.word2id['SOS']
        self.EOS = self.word2id['EOS']
        self.vocab_size = len(self.word2id)
        self.encoder = Encoder(self.vocab_size, dim)
        self.decoder = Decoder(dim, self.vocab_size) 
    
    def indices_from_sentence(self, sentence):
        return [self.word2id[word] for word in sentence.split(' ') if word]
    
    def tensor_from_sentence(self, sentence):
        indices = self.indices_from_sentence(sentence)
        indices.append(self.EOS)
        return torch.tensor(indices, dtype=torch.long, device=device).view(-1, 1)
    
    def encode(self, s):
        sent_tensor = self.tensor_from_sentence(s)
        sent_length = sent_tensor.size(0)
        encoder_hidden = self.encoder.init_hidden()
        encoder_cell = self.encoder.init_hidden()
        for i in range(sent_length):
            encoder_output, (encoder_hidden, encoder_cell) = self.encoder(sent_tensor[i], 
                                                                         (encoder_hidden, encoder_cell))
        return encoder_hidden.data.numpy()
    
    def save(self, dirname):
        fname = os.path.join(dirname, 'autoencoder-{}.params'.format(self.dim))
        torch.save(self.encoder.state_dict(), fname)
    
    def load(self, fname):
        self.encoder.load_state_dict(torch.load(fname))
    
    def save_embeddings(self, dirname):
        embeddings = {}
        for word, index in self.word2id.items():
            t = torch.tensor(index, dtype=torch.long, device=device).view(-1, 1)
            vector = self.encoder.embedding(t).data.numpy().flatten()
            embeddings[word] = vector
        df = pd.DataFrame.from_dict(embeddings, orient='index')
        df.index.name = 'lx_obj'
        fname = os.path.join(dirname, 'autoencoder-{}-embeddings.csv'.format(self.dim))
        df.to_csv(fname)
        
    def train(self, sentences, epochs=1, lr=0.01):
        encoder_optimizer = optim.SGD(self.encoder.parameters(), lr=lr)
        decoder_optimizer = optim.SGD(self.decoder.parameters(), lr=lr)
        criterion = nn.NLLLoss()
        for _ in range(epochs):
            for n, sent in tqdm(enumerate(sentences)):
                sent_tensor = self.tensor_from_sentence(sent)
                sent_length = sent_tensor.size(0)
                encoder_hidden = self.encoder.init_hidden()
                encoder_cell = self.encoder.init_hidden()
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()
                loss = 0
                for i in range(sent_length):
                    encoder_output, (encoder_hidden, encoder_cell) = self.encoder(sent_tensor[i], 
                                                                                 (encoder_hidden, encoder_cell))
                decoder_input = torch.tensor([[self.SOS]], device=device)
                decoder_hidden = encoder_hidden
                decoder_cell = self.decoder.init_hidden()
                for i in range(sent_length):
                    decoder_output, (decoder_hidden, decoder_cell) = self.decoder(decoder_input, 
                                                                                     (decoder_hidden, decoder_cell))
                    loss += criterion(decoder_output, sent_tensor[i])
                    decoder_input = sent_tensor[i]
                loss.backward()
                encoder_optimizer.step()
                decoder_optimizer.step()

In [None]:
class Trainer:
    
    DIMS = [100, 300, 500]
    
    def __init__(self, lg):
        self.lg = lg
        self.out_dir = os.path.join(OUT_DIR, lg)
        self.word2id = self.read_word2id()
    
#     def read_word2id(self):
#         fname = os.path.join(self.lg, 'metadata.pkl')
#         with open(fname, 'rb') as f:
#             obj = pickle.load(f)
#         return obj['word2index']
    
    def read_word2id(self):
        fnames = glob.iglob(os.path.join('wikipedia', self.lg, 'articles/*.txt'))
        vocab = set()
        for fname in fnames:
            with open(fname, encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    tokens = line.split()
                    vocab.update(tokens)
        return {w:i for (w, i) in zip(vocab, range(len(vocab)))}
    
    def read(self):
        fnames = glob.iglob(os.path.join('wikipedia', self.lg, 'articles/*.txt'))
        for fname in fnames:
            with open(fname, encoding='utf-8') as f:
                for line in f:
                    yield line.strip()
    
    def train(self):
        for d in self.DIMS:
            start = datetime.now()
            autoencoder = Autoencoder(d, self.word2id)
            autoencoder.train(self.read())
            autoencoder.save(self.out_dir)
            autoencoder.save_embeddings(self.out_dir)
            end = datetime.now()
            msg = 'Training {} autoencoder with {} dimensions took {}'.format(self.lg, d, end-start)
            logging.info(msg)

In [7]:
d = 100
lg = 'en'
fnames = glob.iglob(os.path.join('wikipedia', lg, 'articles/*.txt'))
vocab = set()
for fname in fnames:
    with open(fname, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            tokens = line.split()
            vocab.update(tokens)
word2id = {w:i for (w, i) in zip(vocab, range(len(vocab)))}
def read():
    for fname in fnames:
        with open(fname, encoding='utf-8') as f:
            for line in f:
                yield line.strip()

sentences = read()
start = datetime.now()
autoencoder = Autoencoder(d, word2id)
autoencoder.train(sentences)
out_dir = os.path.join('representations', lg)
autoencoder.save(out_dir)
autoencoder.save_embeddings(out_dir)
end = datetime.now()
msg = 'Training {} autoencoder with {} dimensions took {}'.format(lg, d, end-start)
logging.info(msg)

0it [00:00, ?it/s]
