In [1]:
# !pip3 install torch torchvision torchaudio
# !pip install pandas
import ast

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

In [2]:
# Load the semicolon-separated keyword/subtitle table, store the subtitle
# column with pandas' dedicated string dtype and discard the IMDb identifier.
df = (
    pd.read_csv('df_kw_subs.csv', sep=';')
    .astype({'subs_text': 'string'})
    .drop(columns='imdb_id')
)
df

Unnamed: 0,keylist_fifteen,subs_text
0,"['android', 'spaceopera', 'rebellion', 'planet...","['hear', 'shut', 'main', 'reactor', 'destroy',..."
1,"['basedonnovel', 'love', 'friendship', 'flashb...","['hello', 'name', 'forrest', 'forrest', 'gump'..."
2,"['nudity', 'femalenudity', 'malenudity', 'comi...","['need', 'father', 'role', 'model', 'horny', '..."
3,"['murder', 'friendship', 'smalltown', 'robbery...","['sweat', 'know', 'im', 'excite', 'though', 's..."
4,"['love', 'alien', 'newyorkcity', 'future', 'sh...","['come', 'come', 'please', 'aziz', 'aziz', 'az..."
...,...,...
1310,"['murder', 'sex', 'nudity', 'suspense', 'femal...","['hi', 'adrian', 'viktor', 'remember', 'cave',..."
1311,"['artist', 'painting', 'art', 'resistance', 'c...","['present', 'afterimage', 'star', 'come', 'lat..."
1312,"['chicago', 'coma', 'immigrant', 'arrangedmarr...","['keep', 'go', 'next', 'performer', 'man', 'mr..."
1313,"['violence', 'dog', 'postapocalyptic', 'forest...","['music', 'play', 'heavy', 'breathe', 'sarah',..."


15

In [6]:
# Build the encoder (subtitle words) and decoder (keyword) vocabularies.
#
# Every cell of `subs_text` / `keylist_fifteen` is parsed from text into a
# Python list of words. NOTE(security): this used to be done with eval(),
# which executes whatever code the CSV contains; ast.literal_eval only accepts
# plain literals and raises on anything else.
all_sub_words = set()
for sub in df.subs_text:
    all_sub_words.update(ast.literal_eval(sub))

all_kw_words = set()
for keylist in df.keylist_fifteen:
    all_kw_words.update(ast.literal_eval(keylist))

# Sort so that the word -> index mapping is the same on every run.
input_words = sorted(all_sub_words)
target_words = sorted(all_kw_words)
num_encoder_tokens = len(input_words)
num_decoder_tokens = len(target_words)

# Lookup tables: word -> position in the sorted vocabulary.
input_token_index = {word: i for i, word in enumerate(input_words)}
target_token_index = {word: i for i, word in enumerate(target_words)}

In [5]:
# Toy parallel corpus: each entry is a [source word, target word] pair.
seq_data = [
    ['man', 'women'],
    ['black', 'white'],
    ['king', 'queen'],
    ['girl', 'boy'],
    ['up', 'down'],
    ['high', 'low'],
]

In [6]:
# Symbol inventory: the three upper-case control symbols 'S', 'E' and 'P'
# followed by the lower-case alphabet. A symbol's position in `char_arr` is
# its class index; `num_dic` is the reverse lookup (symbol -> index).
char_arr = list('SEPabcdefghijklmnopqrstuvwxyz')
num_dic = dict(zip(char_arr, range(len(char_arr))))

In [7]:
# Model dimensions.
n_class = len(num_dic)  # vocabulary size = width of every one-hot vector
# Number of time steps = length of the longest word in the corpus. This was
# previously set to n_class, which padded every word out to the vocabulary
# size instead of to the longest word.
n_step = max(len(word) for pair in seq_data for word in pair)
n_hidden = 128  # number of hidden units in one cell

NameError: name 'word_dict' is not defined

In [None]:
def make_batch(pairs=None, max_len=None, vocab=None):
    """Encode word pairs as padded tensors for teacher-forced seq2seq training.

    Every word is right-padded with 'P' to ``max_len`` characters. The source
    word becomes the encoder input, 'S' + target word the decoder input, and
    target word + 'E' the expected decoder output.

    Args:
        pairs: sequence of ``[source, target]`` strings. Defaults to the
            notebook-level ``seq_data``. The sequence is not modified.
        max_len: padded word length. Defaults to the notebook-level ``n_step``.
        vocab: mapping from character to class index. Defaults to the
            notebook-level ``num_dic``; its size is the one-hot width.

    Returns:
        Tuple ``(input_batch, output_batch, target_batch)``:
        a float tensor ``[batch, max_len, n_classes]`` (one-hot),
        a float tensor ``[batch, max_len + 1, n_classes]`` (one-hot) and
        a long tensor ``[batch, max_len + 1]`` of class indices (not one-hot).

    Raises:
        ValueError: if a word is longer than ``max_len``.
        KeyError: if a word contains a character that is not in ``vocab``.
    """
    pairs = seq_data if pairs is None else pairs
    max_len = n_step if max_len is None else max_len
    vocab = num_dic if vocab is None else vocab

    # Row i of the identity matrix is the one-hot vector of class i.
    one_hot = np.eye(len(vocab), dtype=np.float32)

    enc_rows, dec_rows, target_rows = [], [], []
    for pair in pairs:
        padded = []
        for word in (pair[0], pair[1]):
            if len(word) > max_len:
                raise ValueError(
                    f"word {word!r} has {len(word)} characters, "
                    f"more than the maximum of {max_len}"
                )
            # Pad a local copy so the caller's data is left untouched.
            padded.append(word + 'P' * (max_len - len(word)))
        source, target = padded

        enc_rows.append(one_hot[[vocab[ch] for ch in source]])
        dec_rows.append(one_hot[[vocab[ch] for ch in 'S' + target]])
        target_rows.append([vocab[ch] for ch in target + 'E'])  # not one-hot

    # Convert through a single ndarray: building a tensor straight from a
    # Python list of ndarrays is a slow path in PyTorch.
    return (
        torch.tensor(np.array(enc_rows), dtype=torch.float32),
        torch.tensor(np.array(dec_rows), dtype=torch.float32),
        torch.tensor(target_rows, dtype=torch.long),
    )

In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder built from two single-layer RNNs and a linear head.

    The encoder reads the one-hot source sequence; its final hidden state is
    used as the initial hidden state of the decoder, whose per-step outputs
    are projected to class scores.

    Args:
        num_classes: width of the one-hot inputs and of the output scores.
            Defaults to the notebook-level ``n_class``.
        hidden_size: number of hidden units of both RNNs. Defaults to the
            notebook-level ``n_hidden``.
    """

    def __init__(self, num_classes=None, hidden_size=None):
        super().__init__()
        num_classes = n_class if num_classes is None else num_classes
        hidden_size = n_hidden if hidden_size is None else hidden_size

        # No `dropout` argument: nn.RNN only applies dropout between stacked
        # layers, so on these single-layer RNNs the former dropout=0.5 had no
        # effect and merely triggered a UserWarning.
        self.enc_cell = nn.RNN(input_size=num_classes, hidden_size=hidden_size)
        self.dec_cell = nn.RNN(input_size=num_classes, hidden_size=hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, enc_input, enc_hidden, dec_input):
        """Return class scores for every decoder time step.

        Args:
            enc_input: one-hot source batch, [batch_size, enc_steps, num_classes].
            enc_hidden: initial encoder state, [1, batch_size, hidden_size].
            dec_input: one-hot decoder input, [batch_size, dec_steps, num_classes].

        Returns:
            Tensor of scores shaped [dec_steps, batch_size, num_classes]
            (time-major).
        """
        # nn.RNN is time-major by default: [steps, batch_size, features].
        enc_input = enc_input.transpose(0, 1)
        dec_input = dec_input.transpose(0, 1)

        # enc_states: [num_layers(=1) * num_directions(=1), batch_size, hidden_size]
        _, enc_states = self.enc_cell(enc_input, enc_hidden)
        # outputs: [dec_steps, batch_size, hidden_size]
        outputs, _ = self.dec_cell(dec_input, enc_states)

        return self.fc(outputs)

In [None]:
# Seed PyTorch's RNG before building the model so that the random weight
# initialisation is the same on every Restart & Run All.
torch.manual_seed(42)
model = Seq2Seq()

In [None]:
# Cross-entropy over the character classes, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Encoder inputs, decoder inputs and target indices for the whole corpus.
input_batch, output_batch, target_batch = make_batch()

In [None]:
# Take the batch size from the data instead of hard-coding it, so the loop
# keeps working when the number of training pairs changes.
batch_size = input_batch.size(0)

# Initial encoder hidden state: [num_layers * num_directions, batch_size, n_hidden].
# It is all zeros and never modified, so one tensor serves every epoch.
hidden = torch.zeros(1, batch_size, n_hidden)

for epoch in range(5000):
    optimizer.zero_grad()
    # input_batch : [batch_size, n_step, n_class]
    # output_batch: [batch_size, n_step + 1, n_class] (extra step for 'S')
    # target_batch: [batch_size, n_step + 1], class indices (not one-hot)
    # The model returns time-major scores; transpose to batch-major:
    # [batch_size, n_step + 1, n_class]
    output = model(input_batch, hidden, output_batch).transpose(0, 1)

    # Sum of the per-sample losses (each one a mean over that sample's steps).
    loss = sum(criterion(pred, target) for pred, target in zip(output, target_batch))

    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))
    loss.backward()
    optimizer.step()

In [None]:
input_batch, output_batch, _ = make_batch()

# NOTE(review): the decoder is fed output_batch, i.e. 'S' + the ground-truth
# target word, so these predictions are teacher-forced rather than a free
# translation of the input alone.

# Initial hidden state: [num_layers * num_directions, batch_size, n_hidden],
# sized from the data rather than hard-coded.
hidden = torch.zeros(1, input_batch.size(0), n_hidden)

# One forward pass covers the whole batch; it does not depend on the sample
# being printed, so it runs once, outside the loop, and without autograd.
model.eval()
with torch.no_grad():
    output = model(input_batch, hidden, output_batch)  # [n_step + 1, batch_size, n_class]

predicted = output.transpose(0, 1).argmax(-1)  # [batch_size, n_step + 1]
source_ids = input_batch.argmax(-1)            # [batch_size, n_step]

for i in range(len(input_batch)):
    decode = [char_arr[c] for c in predicted[i].tolist()]

    # Cut at the first end marker; keep everything if the model produced no
    # 'E' (list.index would raise ValueError in that case).
    end = decode.index('E') if 'E' in decode else len(decode)
    translated = ''.join(decode[:end])

    # Show the complete source word: it must not be truncated at the
    # position where the *predicted* word happened to end.
    input_word = ''.join(char_arr[w] for w in source_ids[i].tolist())
    print(f"{input_word.replace('P', '')} -> {translated.replace('P', '')}")