In [1]:
using Flux, PyCall

In [2]:
using Conda
Conda.add("pytorch")
pytorch = pyimport("torch") # very simple to import a Python module

┌ Info: Running `conda install -y pytorch` in root environment
└ @ Conda /Users/khyatt/.julia/packages/Conda/3rPhK/src/Conda.jl:113


Collecting package metadata (current_repodata.json): ...working... done
Solving environment: ...working... done

# All requested packages already installed.



PyObject <module 'torch' from '/Users/khyatt/.julia/conda/3/lib/python3.7/site-packages/torch/__init__.py'>

In [3]:
# make a tokenized dictionary of the WikiData
struct Corpus
    word_dict::Dict{String, Int}
    test_data
    train_data 
    valid_data
end

function tokenize!(word_dict, path)
    # add the unique words to the dictionary
    for el in eachline(path)
        words = vcat(split(el, " "), "<eos>")
        for word in filter(word -> !haskey(word_dict, word), words)
            word_dict[word] = length(word_dict)
        end
    end
    # turn the words in the file into tokens
    ids = Vector{Int}(undef, 0)
    for el in eachline(path)
        words = vcat(split(el, " "), "<eos>")
        ids_  = zeros(Int, length(words))
        for wi in 1:length(ids_)
            ids_[wi] = get(word_dict, words[wi], 0)
        end
        append!(ids, ids_)
    end
    return word_dict, ids
end

function Corpus(data_path::String)
    @assert isdir(data_path) "Data filepath $data_path does not exist"
    word_dict = Dict{String, Int}()
    word_dict, test_data  = tokenize!(word_dict, joinpath(data_path, "test.txt"))
    word_dict, train_data = tokenize!(word_dict, joinpath(data_path, "train.txt"))
    word_dict, valid_data = tokenize!(word_dict, joinpath(data_path, "valid.txt"))
    Corpus(word_dict, test_data, train_data, valid_data)
end

Corpus

In [4]:
wiki_data = Corpus(joinpath(pwd(), "wikitext2"));

In [5]:
# batchify the data
batch_size    = 20
eval_batch    = 10
import Flux: chunk
train_batches = chunk(wiki_data.train_data, div(length(wiki_data.train_data), batch_size));
test_batches  = chunk(wiki_data.test_data, div(length(wiki_data.test_data), eval_batch));
valid_batches = chunk(wiki_data.valid_data, div(length(wiki_data.valid_data), eval_batch));
# data is ready, let's make a model!

In [6]:
hidden  = 200
nlayers = 2
Ntoken  = 200
batch_size    = 20
eval_batch    = 20
seq_len       = 35
throttler     = 20
lr            = 1e-5
epochs        = 2
import Flux: chunk, onehot, onecold, onehotbatch, crossentropy, throttle
using Statistics, Random
using BSON: @save

# use Flux's inbuilt RNN stuff
function train(; kws...)
    @info("Loading data....")
    wiki_data = Corpus(joinpath(pwd(), "wikitext2"))
    train_batches = chunk(wiki_data.train_data, div(length(wiki_data.train_data), batch_size))
    test_batches  = chunk(wiki_data.test_data,  div(length(wiki_data.test_data),  eval_batch))
    valid_batches = chunk(wiki_data.valid_data, div(length(wiki_data.valid_data), eval_batch))

    test_data = Vector{Tuple{Flux.OneHotMatrix, Vector{Int}}}(undef, length(test_batches))
    for (ii, td) in enumerate(test_batches)
        split = min(seq_len, length(td)-1)
        test_data[ii] = (onehotbatch(td[1:split], 1:length(wiki_data.word_dict)), td[split+1:end])
    end
    train_data = Vector{Tuple{Flux.OneHotMatrix, Vector{Int}}}(undef, length(train_batches))
    for (ii, td) in enumerate(train_batches)
        split = min(seq_len, length(td)-1)
        train_data[ii] = (onehotbatch(td[1:split], 1:length(wiki_data.word_dict)), td[split+1:end])
    end
    valid_data = Vector{Tuple{Flux.OneHotMatrix, Vector{Int}}}(undef, length(valid_batches))
    for (ii, td) in enumerate(valid_batches)
        split = min(seq_len, length(td)-1)
        valid_data[ii] = (onehotbatch(td[1:split], 1:length(wiki_data.word_dict)), td[split+1:end])
    end

    @info("Constructing model....")
    encoding = Dense(length(wiki_data.word_dict), Ntoken)
    decoding = Dense(Ntoken, length(wiki_data.word_dict))

    model      = Chain(encoding, Dropout(0.5), LSTM(hidden, hidden), LSTM(Ntoken, hidden), Dropout(0.5), decoding, softmax)
    loss(x, y) = crossentropy(model(x), onehotbatch(y, 1:length(wiki_data.word_dict)))
    opt        = ADAM(lr)
    evalcb     = () -> @show mean([loss(data[1], data[2]) for data in valid_data[1:10]])

    @info("Training Model...")
    for epoch in 1:epochs
        Flux.train!(loss, params(model), shuffle!(train_data[1:10]), opt, cb = throttle(evalcb, throttler))
    end

    @save "wikimodel.bson" model
end

train (generic function with 1 method)

In [7]:
train()

┌ Info: Loading data....
└ @ Main In[6]:16
┌ Info: Constructing model....
└ @ Main In[6]:38
┌ Info: Training Model...
└ @ Main In[6]:47


mean([loss(data[1], data[2]) for data = valid_data[1:10]]) = 208.25589f0
mean([loss(data[1], data[2]) for data = valid_data[1:10]]) = 208.2609f0
