In [1]:
using BERT

# Small smoke-test configuration. BertConfig takes its settings positionally; the
# meaning of each slot is defined in the BERT package, not in this notebook.
# NOTE(review): the second argument is 30022 here but 30522 in the full-size config
# further down (30522 is also the size of the loaded vocabulary) — confirm that the
# difference is intentional and not a typo.
config = BertConfig(128, 30022, 256, 512, 4, 2, 8, 2, 3, Array{Float32}, 0.1, 0.1, "relu")

# Build a pre-training model from the configuration above.
model = BertPreTraining(config)

# Toy inputs: three integer matrices of identical 4x3 shape and a 3-element vector.
# presumably rows are sequence positions and columns are batch instances — TODO confirm
x = [213 234 7789; 712 9182 8912; 7812 12 432; 12389 1823 8483] # 4x3
segment_ids = [1 1 1;1 2 1;1 2 1;1 1 1]
# Every entry other than -1 equals the entry of `x` at the same position.
# presumably -1 marks positions that are excluded from the masked-LM loss — verify.
mlm_labels = [-1 234 -1; -1 -1 8912; -1 -1 -1; 12389 -1 -1]
nsp_labels = [1, 2, 1]

# Call the model with the inputs and both label sets and print the returned value.
loss = model(x, segment_ids, mlm_labels, nsp_labels)
println(loss)

10.87536912086468


In [19]:
using BERT
using Knet
import Base: length, iterate
using Random
using CSV
using PyCall
using Dates

┌ Info: Precompiling CSV [336ed68f-0bac-5ca0-87d4-7b16caf5d00b]
└ @ Base loading.jl:1273


In [16]:
# Path of the vocabulary file; the next cell reads it line by line to build the
# token -> id table.
VOCABFILE = "../../bert-base-uncased-vocab.txt"
# NOTE(review): the four settings below are not referenced by any code visible in
# this part of the notebook — presumably they are consumed by later training cells.
NUM_CLASSES = 2
LEARNING_RATE = 2e-5
NUM_OF_EPOCHS = 30
TRAIN = true

true

In [20]:
# Build the vocabulary lookup tables from VOCABFILE.
#   token2int : each line of the file (a token) -> its 1-based line number
#   int2token : the inverse mapping, id -> token
#   VOCABSIZE : number of distinct tokens
# The dictionary is given concrete key/value types so lookups are type-stable
# (a bare `Dict()` is a `Dict{Any,Any}`).
token2int = Dict{String,Int}()
open(VOCABFILE) do file
    # Stream the file; there is no need to hold every line in memory first.
    for (i, line) in enumerate(eachline(file))
        token2int[line] = i
    end
end
int2token = Dict(id => token for (token, id) in token2int)
VOCABSIZE = length(token2int)

30522

In [21]:
"""
    convert_to_int_array(text, dict; lower_case=true)

Tokenize `text` with `bert_tokenize` and return a `Vector{Int}` containing
`dict[token]` for every produced token, in order. A token that is not a key of
`dict` is replaced by `dict["[UNK]"]`.
"""
function convert_to_int_array(text, dict; lower_case=true)
    pieces = bert_tokenize(text, dict, lower_case=lower_case)
    # The "[UNK]" entry is only looked up for pieces that are missing from `dict`.
    return Int[haskey(dict, piece) ? dict[piece] : dict["[UNK]"] for piece in pieces]
end

convert_to_int_array (generic function with 1 method)

"&\\''"

In [36]:
"""
    read_and_process(filename, dict; lower_case=true)

Read the tab-delimited file `filename` with `CSV.File` and return a tuple `(x, y)`.

`x[k]` is the id vector obtained by passing row `k`'s `sentence` field through
`convert_to_int_array`; `y[k]` is that row's `label` plus one, stored as `Int8`.
The sequences are returned at their natural length — no padding is applied here.
"""
function read_and_process(filename, dict; lower_case=true)
    rows = CSV.File(filename, delim="\t")
    sentences = Vector{Int}[]
    targets = Int8[]
    for row in rows
        ids = convert_to_int_array(row.sentence, dict, lower_case=lower_case)
        push!(sentences, ids)
        # Shift the label up by one: negative becomes 1, positive becomes 2.
        push!(targets, Int8(row.label + 1))
    end
    return (sentences, targets)
end

read_and_process (generic function with 1 method)

In [37]:
# Dataset container that is consumed batch by batch through the `length` and
# `iterate` methods defined for it in this notebook. All fields are untyped.
mutable struct ClassificationData
    # Collection with one token-id vector per instance (indexed by instance number).
    input_ids
    # Collection with one mask vector per instance; the keyword constructor fills it
    # with 1 for real tokens and 0 for padding positions.
    input_mask
    # Collection with one segment-id vector per instance (all ones when built by the
    # keyword constructor).
    segment_ids
    # Per-instance labels, indexed the same way as `input_ids`.
    labels
    # Maximum number of instances returned by one iteration step.
    batchsize
    # Total number of instances; used to build the index sequence for iteration.
    ninstances
    # When true, iteration starts from a random permutation of the instances.
    shuffled
end

In [38]:
"""
    ClassificationData(input_file, token2int; batchsize=8, shuffled=true, seq_len=64)

Load `input_file` through `read_and_process` and bring every sequence to exactly
`seq_len` ids: longer sequences are truncated, shorter ones are right-padded
with id 1.

# Keywords
- `batchsize`: stored as-is; number of instances per iteration step.
- `shuffled`: stored as-is; whether iteration visits instances in random order.
- `seq_len`: fixed length of every id, mask and segment vector.

# Returns
A `ClassificationData` whose `input_ids`, `input_mask` and `segment_ids` each hold
one length-`seq_len` integer vector per instance. The mask is 1 on real tokens and
0 on padding; segment ids are all ones.
"""
function ClassificationData(input_file, token2int; batchsize=8, shuffled=true, seq_len=64)
    (x, labels) = read_and_process(input_file, token2int)

    # Concretely typed containers instead of `[]` (which would be `Vector{Any}`).
    input_ids = eltype(x)[]
    input_mask = Vector{Int64}[]
    segment_ids = Vector{Int64}[]

    for i in eachindex(x)
        n = length(x[i])
        if n >= seq_len
            x[i] = x[i][1:seq_len]
            mask = ones(Int64, seq_len)
        else
            mask = ones(Int64, n)
            append!(x[i], fill(1, seq_len - n)) # 1 is for "[PAD]"
            append!(mask, fill(0, seq_len - n)) # 0's vanish with masking operation
        end
        push!(input_ids, x[i])
        push!(input_mask, mask)
        push!(segment_ids, ones(Int64, seq_len))
    end

    ninstances = length(input_ids)
    return ClassificationData(input_ids, input_mask, segment_ids, labels, batchsize, ninstances, shuffled)
end

ClassificationData

In [39]:
"""
    length(d::ClassificationData)

Return the number of batches one pass over `d` produces: the number of full
batches of `d.batchsize` instances, plus one if a smaller final batch remains.
"""
function length(d::ClassificationData)
    nfull, leftover = divrem(d.ninstances, d.batchsize)
    if leftover == 0
        return nfull
    end
    return nfull + 1
end

length (generic function with 178 methods)

In [40]:
"""
    iterate(d::ClassificationData[, state])

Return the next batch of `d` together with the state for the following call, or
`nothing` when the data is exhausted.

`state` is the collection of instance indices that have not been served yet. On
the first call it defaults to a random permutation of `1:d.ninstances` when
`d.shuffled` is true and to the range `1:d.ninstances` otherwise. The permutation
is only generated when shuffling is requested, so an unshuffled pass does not
touch the global RNG.

Each batch is a tuple `(input_ids, input_mask, segment_ids, labels)` obtained by
`hcat`-ing the entries of up to `d.batchsize` instances. The last batch may be
smaller. An empty dataset yields no batches, consistent with `length(d) == 0`.
"""
function iterate(
    d::ClassificationData,
    state=(d.shuffled ? randperm(d.ninstances) : (1:d.ninstances))
)
    # `nothing` marks exhaustion; an empty index list means there is nothing to batch.
    (state === nothing || isempty(state)) && return nothing

    # Take at most `batchsize` indices; whatever is left becomes the next state.
    ntake = min(length(state), d.batchsize)
    batch = state[1:ntake]
    new_state = ntake < length(state) ? state[ntake+1:end] : nothing

    # Splatting is bounded by the batch size. `hcat` is kept (rather than
    # `reduce(hcat, ...)`) so single-instance batches still come out as matrices.
    input_ids = hcat(d.input_ids[batch]...)
    input_mask = hcat(d.input_mask[batch]...)
    segment_ids = hcat(d.segment_ids[batch]...)
    labels = hcat(d.labels[batch]...)

    return ((input_ids, input_mask, segment_ids, labels), new_state)
end

iterate (generic function with 354 methods)

In [41]:
# Full-size configuration; this rebinds `config` from the toy example at the top.
# It passes KnetArray{Float32} as the array type and "gelu" as the final argument.
# presumably 30522 is the vocabulary size (it equals the number of vocabulary lines
# loaded above) and KnetArray selects GPU storage — TODO confirm against BertConfig.
config = BertConfig(768, 30522, 3072, 512, 64, 2, 12, 12, 8, KnetArray{Float32}, 0.1, 0.1, "gelu")

BertConfig(768, 30522, 3072, 512, 64, 2, 12, 12, 8, KnetArray{Float32,N} where N, 0.1, 0.1, "gelu")

In [43]:
# Display the id -> token table built from the vocabulary file (inspection only).
int2token

Dict{Int64,String} with 30522 entries:
  15769 => "stocks"
  13575 => "prohibition"
  10094 => "tel"
  30216 => "##ろ"
  22035 => "mcgee"
  6265  => "murray"
  9934  => "malta"
  21807 => "##tenberg"
  29201 => "substitutes"
  8805  => "quinn"
  1333  => "य"
  13120 => "motorway"
  3120  => "trade"
  9911  => "##ree"
  11942 => "mare"
  20368 => "##ester"
  19698 => "kellan"
  29981 => "##თ"
  16429 => "rd"
  10458 => "backs"
  25568 => "luminous"
  7237  => "categories"
  28907 => "352"
  25786 => "##oint"
  29728 => "##λ"
  ⋮     => ⋮