<a href="https://colab.research.google.com/github/iiacobac/lstm-reduplication-grokking/blob/main/LSTM_Reduplication_by_Grokking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(1)

# Following Prickett 2017 
# The toy languages has eight different consonants and a single vowel.
char_to_ix = {"b": 0, "p": 1, "d": 2, "t":3, "v":4, "f":5, "z":6, "s":7, "A":8,"E":9,"I":10,"O":11,"U":12,"SOS":13, "EOS":14}  
ix_to_char = {b:a for a, b in char_to_ix.items()}
# The model is exposed to CV stems to reduplicated forms of the shape CVCV, 
#where both C’s in the reduplicated form were identical to the original C in the stem

EMBEDDING_DIM = 24
HIDDEN_DIM = 24

def prepare_sequence(seq, to_ix):
    idxs = [[to_ix[w]] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

l = list(char_to_ix.keys())
training_data = []
test_data = []

for i in range(0,7):
  for j in range(8,13):
    # Skipping 60% of the training set
    if random.random() < 0.5:
      training_data.append(([l[i],l[j]],[l[i],l[j],l[i],l[j]]))
    else:
      test_data.append(([l[i],l[j]],[l[i],l[j],l[i],l[j]]))
#training_data.append(([l[6],l[12]],[l[6],l[12],l[6],l[12]]))

#for j in range(8,12):
  # Skipping 40% of the training set
  #if random.random() < 0.5:
#  test_data.append(([l[6],l[j]],[l[6],l[j],l[6],l[j]]))
  #else:

print(training_data, len(training_data))

class EncoderLSTM(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(EncoderLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.char_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

    def forward(self, input):
        embedded = self.char_embeddings(input)                                                                      
        output, hidden = self.lstm(embedded)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_dim, device=device)

class DecoderLSTM(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(DecoderLSTM, self).__init__()
        self.hidden_dim = hidden_dim

        self.char_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.char_embeddings(input).view(1, 1, -1)
        output, hidden = self.lstm(output, hidden)
        output = self.out(output[0])
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_dim, device=device)


encoder = EncoderLSTM(EMBEDDING_DIM, HIDDEN_DIM, len(char_to_ix))
decoder = DecoderLSTM(EMBEDDING_DIM, HIDDEN_DIM, len(char_to_ix))
loss_function = nn.CrossEntropyLoss()

encoder_optimizer = optim.SGD(encoder.parameters(), lr=0.001)
decoder_optimizer = optim.SGD(decoder.parameters(), lr=0.001)

[(['b', 'E'], ['b', 'E', 'b', 'E']), (['p', 'A'], ['p', 'A', 'p', 'A']), (['p', 'E'], ['p', 'E', 'p', 'E']), (['p', 'O'], ['p', 'O', 'p', 'O']), (['d', 'E'], ['d', 'E', 'd', 'E']), (['d', 'I'], ['d', 'I', 'd', 'I']), (['d', 'U'], ['d', 'U', 'd', 'U']), (['t', 'A'], ['t', 'A', 't', 'A']), (['t', 'E'], ['t', 'E', 't', 'E']), (['t', 'U'], ['t', 'U', 't', 'U']), (['v', 'E'], ['v', 'E', 'v', 'E']), (['f', 'A'], ['f', 'A', 'f', 'A']), (['f', 'E'], ['f', 'E', 'f', 'E']), (['f', 'I'], ['f', 'I', 'f', 'I']), (['z', 'I'], ['z', 'I', 'z', 'I'])] 15


In [47]:
import random

plot_losses = []
print_loss_total = 0  # Reset every print_every
plot_loss_total = 0  # Reset every plot_every
for j in range(10000): 
  e = training_data[random.randint(0,len(training_data)-1)]
  loss = 0
  encoder_optimizer.zero_grad()
  decoder_optimizer.zero_grad()
  sss = e[0] + ['EOS']
  input1 = prepare_sequence(sss,char_to_ix)
  _, h = encoder(input1)
  sss = ['SOS'] + e[1]
  input2 = prepare_sequence(sss,char_to_ix)
  sss = e[1] + ['EOS']
  out = prepare_sequence(sss,char_to_ix)
  for di, s in enumerate(input2):
    o, h = decoder(s, h)
    loss += loss_function(o, out[di])
  loss.backward()
  encoder_optimizer.step()
  decoder_optimizer.step()

  print_loss_total += loss
  plot_loss_total += loss

  if j % 10 == 0:
    print_loss_avg = print_loss_total / 10
    print_loss_total = 0
    print(loss.item() / (len(out) * 10))

  if j % 10 == 0:
    plot_loss_avg = plot_loss_total / 10
    plot_losses.append(plot_loss_avg)
    plot_loss_total = 0

showPlot(plot_losses)
plt.show()

0.2656193923950195
0.14883520126342772
0.21945417404174805
0.19986928939819337
0.15545934677124024
0.11119263648986816
0.10271351814270019
0.13539013862609864
0.11945405006408691
0.06336585998535156
0.04308511734008789
0.034285833835601805
0.06845999717712402
0.06295431137084961
0.027781481742858886
0.03931223392486572
0.04406032085418701
0.08308915138244628
0.08456903457641601
0.01555109143257141
0.02854184150695801
0.030587749481201174
0.033676128387451175
0.04130974292755127
0.05318488597869873
0.019484426975250244
0.034035179615020755
0.020032687187194823
0.004948491156101227
0.017217416763305664
0.011906979084014892
0.013779305219650269
0.026781864166259765
0.00823083460330963
0.0060244971513748165
0.0196507728099823
0.005962597131729126
0.005831497311592102
0.022717463970184325
0.012331329584121704
0.008107249736785888
0.009345716834068298
0.005548759698867798
0.024917914867401122
0.007572717666625976
0.01507989525794983
0.0026789307594299316
0.002001504302024841
0.00233050957322

In [1]:
with torch.no_grad():
  for e in range((len(test_data))):
    sss = test_data[e][0] + ['EOS']
    input1 = prepare_sequence(sss,char_to_ix)
    _, h = encoder(input1)
    sss = ['SOS']
    o = prepare_sequence(sss,char_to_ix)
    s = ""
    while not (o[0][0] == char_to_ix['EOS']): 
      #print(o, o.shape)
      o, h = decoder(o, h)
      #print(o, o.shape)
      #print(F.softmax(o, dim=1))
      t = torch.argmax(F.softmax(o, dim=1))
      #print(t)
      if not int(t.item()) == char_to_ix['EOS']:
        s = s + ix_to_char[int(t.item())]
      o = t.view(1,1)
    print("".join(test_data[e][0]), s)

NameError: ignored

In [13]:
#TRAIN

with torch.no_grad():
  for e in range((len(training_data))):
    sss = training_data[e][0] + ['EOS']
    input1 = prepare_sequence(sss,char_to_ix)
    _, h = encoder(input1)
    sss = ['SOS']
    o = prepare_sequence(sss,char_to_ix)
    s = ""
    while not (o[0][0] == char_to_ix['EOS']): 
      #print(o, o.shape)
      o, h = decoder(o, h)
      #print(o, o.shape)
      #print(F.softmax(o, dim=1))
      t = torch.argmax(F.softmax(o, dim=1))
      #print(t)
      if not int(t.item()) == char_to_ix['EOS']:
        s = s + ix_to_char[int(t.item())]
      o = t.view(1,1)
    print("".join(training_data[e][0]), s)

bA bAbA
pI pIpI
pU pUpU
dA dAdA
dE dEdE
dO dOdO
tE tEtE
tO tOtO
tU tUtU
vA vAvA
vE vEvE
vO vOvO
fE fEfE
fO fOfO
fU fUfU
zE zEzE


In [None]:
import random

plot_losses = []
print_loss_total = 0  # Reset every print_every
plot_loss_total = 0  # Reset every plot_every
for j in range(500000): 
  e = training_data[random.randint(0,len(training_data)-1)]
  loss = 0
  encoder_optimizer.zero_grad()
  decoder_optimizer.zero_grad()
  sss = e[0] + ['EOS']
  input1 = prepare_sequence(sss,char_to_ix)
  _, h = encoder(input1)
  sss = ['SOS'] + e[1]
  input2 = prepare_sequence(sss,char_to_ix)
  sss = e[1] + ['EOS']
  out = prepare_sequence(sss,char_to_ix)
  for di, s in enumerate(input2):
    o, h = decoder(s, h)
    loss += loss_function(o, out[di])
  loss.backward()
  encoder_optimizer.step()
  decoder_optimizer.step()

  print_loss_total += loss
  plot_loss_total += loss

  if j % 1000 == 0:
    print_loss_avg = print_loss_total / 1000
    print_loss_total = 0

    with torch.no_grad():
      for e in test_data: 
        loss_t = 0
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        sss = e[0] + ['EOS']
        input1 = prepare_sequence(sss,char_to_ix)
        _, h = encoder(input1)
        sss = ['SOS'] + e[1]
        input2 = prepare_sequence(sss,char_to_ix)
        sss = e[1] + ['EOS']
        out = prepare_sequence(sss,char_to_ix)
        for di, s in enumerate(input2):
          o, h = decoder(s, h)
          loss_t += loss_function(o, out[di])
    print(loss.item() / (len(out) * 100), loss_t.item())



0.02862919235229492 13.837512969970703
0.026744983673095704 13.803336143493652
0.023419965744018556 13.977001190185547
0.02098562812805176 13.92673110961914
0.01900792694091797 13.838058471679688
0.018826776504516602 13.849437713623047
0.013538766860961914 13.768921852111816
0.01095130729675293 13.462804794311523
0.014445327758789062 13.186999320983887
0.006737404823303222 12.350343704223633
0.00495824670791626 11.986157417297363
0.004104662895202637 11.60407829284668
0.008158884048461913 11.422038078308105
0.003722425699234009 10.728235244750977
0.002207855463027954 10.316939353942871
0.002236244201660156 9.919365882873535
0.0026362009048461915 9.44663143157959
0.0022215027809143065 9.18999195098877
0.002614692211151123 8.894307136535645
0.0008582835197448731 8.695313453674316
0.0005579338073730468 8.720525741577148
0.0016231527328491211 8.543256759643555
0.0004378083646297455 8.732290267944336
0.0009029150009155273 8.589290618896484
0.0006245371103286744 8.531414031982422
0.000957837