<a href="https://colab.research.google.com/github/iiacobac/lstm-reduplication-grokking/blob/main/LSTM_Reduplication_by_Grokking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [57]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(1)

# Following Prickett 2017 
# The toy languages has eight different consonants and a single vowel.
char_to_ix = {"b": 0, "p": 1, "d": 2, "t":3, "v":4, "f":5, "z":6, "s":7, "V":8, "SOS":9, "EOS":10}  
ix_to_char = {b:a for a, b in char_to_ix.items()}
# The model is exposed to CV stems to reduplicated forms of the shape CVCV, 
#where both C’s in the reduplicated form were identical to the original C in the stem

EMBEDDING_DIM = 18
HIDDEN_DIM = 18

def prepare_sequence(seq, to_ix):
    idxs = [[to_ix[w]] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

l = list(char_to_ix.keys())
training_data = []
for i in range(0,7):
    training_data.append(([l[i],l[8]],[l[i],l[8],l[i],l[8]]))

class EncoderLSTM(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(EncoderLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.char_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

    def forward(self, input):
        embedded = self.char_embeddings(input)                                                                      
        output, hidden = self.lstm(embedded)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_dim, device=device)

class DecoderLSTM(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(DecoderLSTM, self).__init__()
        self.hidden_dim = hidden_dim

        self.char_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.char_embeddings(input).view(1, 1, -1)
        output, hidden = self.lstm(output, hidden)
        output = self.out(output[0])
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_dim, device=device)


encoder = EncoderLSTM(EMBEDDING_DIM, HIDDEN_DIM, len(char_to_ix))
decoder = DecoderLSTM(EMBEDDING_DIM, HIDDEN_DIM, len(char_to_ix))
loss_function = nn.CrossEntropyLoss()

encoder_optimizer = optim.SGD(encoder.parameters(), lr=0.1)
decoder_optimizer = optim.SGD(decoder.parameters(), lr=0.1)





In [11]:

char_embeddings = nn.Embedding(len(char_to_ix), EMBEDDING_DIM)
lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM)
out = nn.Linear(EMBEDDING_DIM, len(char_to_ix))
softmax = nn.LogSoftmax(dim=1)

o, h = lstm(char_embeddings(input))
o.shape
h[0].shape
h[1].shape




torch.Size([1, 1, 6])

In [6]:
sss = training_data[0][0] + ['EOS']
input = prepare_sequence(sss,char_to_ix)
input

tensor([[ 0],
        [ 8],
        [10]])

In [12]:
sss = ['SOS'] + training_data[0][1]
input = prepare_sequence(sss,char_to_ix)

for s in char_embeddings(input):
  o, h = lstm(s.view(1, 1, -1),h)
  
  #print(o)
  #print(out(o))
  print(softmax(out(o)))
#c = char_embeddings(input)
#c[0].view(1, 1, -1)
#lstm(char_embeddings(input),h)

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
       grad_fn=<LogSoftmaxBackward0>)
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
       grad_fn=<LogSoftmaxBackward0>)
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
       grad_fn=<LogSoftmaxBackward0>)
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
       grad_fn=<LogSoftmaxBackward0>)
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
       grad_fn=<LogSoftmaxBackward0>)


In [58]:
for j in range(1000): 
  for e in training_data[:6]:
    loss = 0
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    sss = e[0] + ['EOS']
    input1 = prepare_sequence(sss,char_to_ix)
    _, h = encoder(input1)
    sss = ['SOS'] + e[1]
    input2 = prepare_sequence(sss,char_to_ix)
    sss = e[1] + ['EOS']
    out = prepare_sequence(sss,char_to_ix)
    for di, s in enumerate(input2):
      o, h = decoder(s, h)
      loss += loss_function(o, out[di])
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    print(loss.item() / len(out))


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
0.0033068776130676268
0.003146623820066452
0.0037771120667457582
0.004460082203149795
0.004521599039435387
0.004770161956548691
0.003266872838139534
0.003119629621505737
0.0037487588822841646
0.004414046928286552
0.004478077962994575
0.004720157384872437
0.003227761760354042
0.0030930351465940477
0.003720617666840553
0.004368905350565911
0.004435238987207412
0.004671140015125275
0.00318942591547966
0.0030669834464788436
0.003692973405122757
0.004324561730027199
0.00439329594373703
0.004623014479875565
0.003151914104819298
0.003041236288845539
0.003665659949183464
0.004280993342399597
0.004351987689733505
0.004575854167342186
0.0031152492389082908
0.0030159609392285346
0.0036386530846357346
0.004238294810056687
0.004311386123299598
0.004529610276222229
0.0030792895704507828
0.0029911570250988006
0.003612048923969269
0.004196206480264664
0.0042715385556221005
0.004484117776155472
0.0030441310256719587
0.002966730296611786
0

In [109]:
encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()
e = training_data[0]
sss = e[0] + ['EOS']
input1 = prepare_sequence(sss,char_to_ix)
_, h = encoder(input1)


In [110]:
sss = ['SOS'] + e[1]
input2 = prepare_sequence(sss,char_to_ix)
sss = e[1] + ['EOS']
out = prepare_sequence(sss,char_to_ix)


In [111]:
loss = 0
for di, s in enumerate(input2):
  o, h = decoder(s, h)
  print(o)
  loss += loss_function(o, out[di])
loss

tensor([[-0.1469, -0.1878,  0.4044, -0.2718, -0.0234,  0.0778, -0.2826,  0.1836,
         -0.1210,  0.2494,  0.3337]], grad_fn=<AddmmBackward0>)
tensor([[-0.1299, -0.1860,  0.4735, -0.2535,  0.0615,  0.1418, -0.2286,  0.2050,
         -0.0954,  0.2242,  0.3689]], grad_fn=<AddmmBackward0>)
tensor([[-0.1352, -0.2391,  0.4671, -0.2139, -0.0081,  0.0462, -0.3311,  0.1794,
         -0.1894,  0.2246,  0.3631]], grad_fn=<AddmmBackward0>)
tensor([[-0.1258, -0.1966,  0.4940, -0.2210,  0.0552,  0.1274, -0.2483,  0.2117,
         -0.1289,  0.2086,  0.3748]], grad_fn=<AddmmBackward0>)
tensor([[-0.1310, -0.2501,  0.4955, -0.1867,  0.0014,  0.0414, -0.3388,  0.1840,
         -0.2125,  0.2117,  0.3753]], grad_fn=<AddmmBackward0>)


tensor(12.4288, grad_fn=<AddBackward0>)

In [112]:
loss.backward()

   

In [113]:
encoder_optimizer.step()
decoder_optimizer.step()

EncoderLSTM(
  (char_embeddings): Embedding(11, 6)
  (lstm): LSTM(6, 6)
)

In [31]:
with torch.no_grad():
  sss = training_data[0][0] + ['EOS']
  input1 = prepare_sequence(sss,char_to_ix)
  _, h = encoder(input1)
  sss = ['SOS']
  o = prepare_sequence(sss,char_to_ix)
  print(o, o.shape)
  o, h = decoder(o, h)
  print(o.shape)
  print(F.softmax(o, dim=1))
  t = torch.argmax(F.softmax(o, dim=1))
  print(t)
  o = t.view(1,1)
  o, h = decoder(o, h)
  print(o.shape)
  print(F.softmax(o, dim=1))
  t = torch.argmax(F.softmax(o, dim=1))
  print(t)





tensor([[9]]) torch.Size([1, 1])
torch.Size([1, 11])
tensor([[9.9568e-01, 5.4534e-08, 2.8756e-03, 3.0122e-08, 4.9959e-04, 1.1459e-05,
         9.2903e-04, 4.6928e-07, 1.5566e-06, 2.4349e-07, 5.2906e-10]])
tensor(0)
torch.Size([1, 11])
tensor([[2.3418e-04, 5.1701e-07, 7.3907e-05, 7.5694e-07, 4.1083e-06, 5.9297e-05,
         8.3627e-05, 2.0363e-06, 9.9953e-01, 4.0578e-06, 8.1728e-06]])
tensor(8)


In [59]:
with torch.no_grad():
  for e in range(7):
    sss = training_data[e][0] + ['EOS']
    input1 = prepare_sequence(sss,char_to_ix)
    _, h = encoder(input1)
    sss = ['SOS']
    o = prepare_sequence(sss,char_to_ix)
    s = ""
    while not (o[0][0] == 10): 
      #print(o, o.shape)
      o, h = decoder(o, h)
      #print(o, o.shape)
      #print(F.softmax(o, dim=1))
      t = torch.argmax(F.softmax(o, dim=1))
      #print(t)
      if not int(t.item()) == 10:
        s = s + ix_to_char[int(t.item())]
      o = t.view(1,1)
    print("".join(training_data[e][0]), s)





bV bVbV
pV pVpV
dV dVdV
tV tVtV
vV vVvV
fV fVfV
zV fV
