In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import random
from uuid import uuid4
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import getopt
import sys
import os
import math
import time
import argparse
from visdom import Visdom
from tqdm import tqdm
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

sys.path.insert(0, os.path.join('..', '..'))

import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
import torch.optim as optim

from torch.nn.utils import clip_grad_norm_

from dnc.dnc import DNC
from dnc.sdnc import SDNC
from dnc.sam import SAM
from dnc.util import *

from dnc.lib import exp_loss, InputStorage, mse, criterion, ENDSYM, tensor2string

T.autograd.set_detect_anomaly(True)

import copy
from dnc.lib import STEPBYSTEPOBJ
import pickle

import os

batch_size = 100
sequence_length = 3
sequence_max_length = 7
iterations = 1*10**4 #200000
summarize_freq = int(iterations/100)
check_freq = int(iterations/20)
curriculum_freq = int(iterations/10)


  # input_size = output_size = args.input_size
mem_slot = 32
mem_size = 1
read_heads = 1
curriculum_increment = 1
input_size = 3*sequence_max_length + 2
output_size = 64

replaceWithWrong = True

num_layers = 4

In [5]:
st = InputStorage()

def genSeq(sizeTpl, min=None, max=None, default=True):
  if default:
    return np.random.binomial(1, 0.5, sizeTpl)
  assert min is not None and max is not None
  return np.random.randint(min, max, sizeTpl)


def generate_data(batch_size, length, maxlength, testoccurance=True, transposeInput=False, transposeOutput=False):
  minSeq = 0
  maxSeq = 4 # 2= binary, 10=decimal
  input_data = np.zeros((batch_size, maxlength, maxlength), dtype=np.float32)
  target_output = np.zeros((batch_size, maxlength, maxlength), dtype=np.float32)
  sequence1 = genSeq((batch_size, length, 1), min=minSeq, max=maxSeq, default=False)
  sequence2 = genSeq((batch_size, length, 1), min=minSeq, max=maxSeq, default=False)

  if testoccurance: # test if the sequence is in the test data, replace if so
    for i in range(batch_size):
      input_test_data = np.zeros((1, maxlength, maxlength), dtype=np.float32)
      input_test_data[0, 0:length, 0:1] = sequence1[i] #first sequence
      input_test_data[0, length, 1] = ENDSYM  #pause
      input_test_data[0, length+1:length*2+1, 2:3] = sequence2[i] #second sequence
      input_test_data[0, length*2+1, 3] = ENDSYM  #pause
      while st.isSaved(input_test_data[0], flag="testData"):
        if np.random.binomial(1, 0.5, 1) == 1: # replace first sequence
          sequence1[i] = genSeq((length, 1), min=minSeq, max=maxSeq, default=False)
          input_test_data[0, 0:length, 0:1] = sequence1[i]
        else: # replace second sequence
          sequence2[i] = genSeq((length, 1), min=minSeq, max=maxSeq, default=False)
          input_test_data[0, length+1:length*2+1, 2:3] = sequence2[i]

  input_data[:, 0:length, 0:1] = sequence1 #first sequence
  input_data[:, length, 1] = ENDSYM  #pause
  input_data[:, length+1:length*2+1, 2:3] = sequence2 #second sequence
  input_data[:, length*2+1, 3] = ENDSYM  #pause
  if transposeInput:
    for i in range(batch_size):
      input_data[i] = input_data[i].T

  def calcsum(sequenceA, sequenceB, maxval=(maxSeq)): #calculate sum of two binary numbers
    sumsequence = np.zeros((batch_size, length + 1, length +1))
    assert len(sequenceA) == len(sequenceB)
    for k in range(len(sequenceA)):
      carry = 0 # carry bit
      for j in reversed(range(len(sequenceA[k]))):
          sumsequence[k][j+1][-1] = (sequenceA[k][j][0] + sequenceB[k][j][0] + carry) % maxval
          carry = (sequenceA[k][j][0] + sequenceB[k][j][0] + carry) // maxval
        
            # if sequenceA[k][j][0] == 1 and sequenceB[k][j][0] == 1: #1+1=10
            #     sumsequence[k][j+1][-1] = 0+carry
            #     carry = 1
            # elif (sequenceA[k][j][0] == 1 and sequenceB[k][j][0] == 0) or (sequenceA[k][j][0] == 0 and sequenceB[k][j][0] == 1): #1+0=1 and 0+1=1
            #     if carry == 1:
            #         sumsequence[k][j+1][-1] = 0
            #         carry = 1
            #     else:
            #         sumsequence[k][j+1][-1] = 1
            #         carry = 0
            # else:
            #     sumsequence[k][j+1][-1] = 0+carry #0+0=0
            #     carry = 0
      sumsequence[k][0][-1] = carry
    return sumsequence
  
  cs = calcsum(sequence1, sequence2)
  for i in range(batch_size):
    target_output[i, -(length+1):, -(length+1):] = cs[i] #write sum to target output
    if transposeOutput:
      target_output[i] = target_output[i].T

  return input_data, target_output

In [6]:
rnn = DNC(
        input_size=input_size,
        hidden_size=output_size,
        rnn_type='rnn',
        #rnn_type='lstm',
        num_layers=num_layers,
        num_hidden_layers=1,
        dropout=0,
        nr_cells=mem_slot,
        cell_size=mem_size,
        read_heads=read_heads,
        gpu_id=-1,
        debug='store_true',
        batch_first=True,
        independent_linears=True,
        nonlinearity='tanh',
    )

name = 'add_a88'
lastcp = f'{name}/checkpoint_10000.pth'
print(name)

with open(f"{name}/output_2.txt", "w") as f:
  batch_size=1
  rnn.load_state_dict(T.load(lastcp, weights_only=True))
  rnn.eval()
  
  stepByStep = copy.deepcopy(STEPBYSTEPOBJ)

  i=0
  #llprint("\nIteration %d/%d" % (i, iterations))
  # We test now the learned generalization using sequence_max_length examples
  random_length = np.random.randint(2, sequence_length  + 1)
  input_data, target_output = generate_data(1, random_length, input_size)

  #print (input_data, target_output)

  
  input_data = var(T.from_numpy(input_data))
  target_output = var(T.from_numpy(target_output))

  stepByStep["CurrI"] = i
  stepByStep["currentObj"] = copy.deepcopy(stepByStep["defObj"])
  stepByStep["currentObj"]["i"] = i 
  stepByStep["input"] = input_data.detach().numpy()
  stepByStep["target"] = target_output.detach().numpy()
  stepByStep["MEMORYCOLUMNS"] = mem_slot
  stepByStep["INPUTSIZE"] = input_size
  stepByStep["OUTPUTSIZE"] = output_size
  stepByStep["read_heads"] = read_heads
    
  if rnn.debug:
    output, (chx, mhx, rv), v = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True, stepByStep=stepByStep)
  else:
    output, (chx, mhx, rv) = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True, stepByStep=stepByStep)

  stepByStep["output"] = output.detach().numpy()
  stepByStep["objects"].append(copy.deepcopy(stepByStep["currentObj"]))
  stepByStep['loss'] = str(mse(output, target_output).item())
  #output = output[:, -1, :].sum().data.cpu().numpy()
  #target_output = target_output.sum().data.cpu().numpy()
  print("loss", mse(output, target_output).item())
  print(stepByStep["input"].shape)
  print(stepByStep["output"].shape)
  print(stepByStep["target"].shape)
  #raise Exception("STOP")

  print(stepByStep)

  pickle.dump(stepByStep, open(f"{name}/stepByStep.pkl", "wb"))

  print("\n\n")
  print("Input: ", tensor2string(input_data[0]), file=f)
  print("Output: ", tensor2string(torch.round(output[0], decimals=1)), file=f)
  print("Target: ", tensor2string(target_output[0]), file=f)
  print("CE Loss: ", str(mse(output[0], target_output[0]).item()), file=f)
  print("Log Loss: ", str(criterion(output[0], target_output[0]).item()), file=f)
  print("Exp Loss: ", str(exp_loss(output[0], target_output[0]).item()), file=f)
  print("\n\n")
  print("CE Loss: ", str(mse(output, target_output).item()), file=f)
  print("Log Loss: ", str(criterion(output, target_output).item()), file=f)
  print("Exp Loss: ", str(exp_loss(output, target_output).item()), file=f)
  print("\n\n")

  try:
    print("\nReal value: ", ' = ' + str(int(target_output[0])))
    print("Predicted:  ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
  except Exception as e:
    pass

add_a88
loss 0.04404061660170555
(1, 23, 23)
(1, 23, 23)
(1, 23, 23)
{'stepByStep': True, 'CurrI': 0, 'time': 22, 'layer': 3, 'currentObj': {'i': 0, 'time': 22, 'layer': 3, 'inputs': array([[ 1.5622276e-01,  2.2774661e-01,  3.4945603e-02,  1.6473579e-01,
        -1.3058425e-01,  8.1347190e-02,  4.1489896e-01,  6.0271293e-02,
        -4.7627918e-02, -2.8921509e-02,  7.3277861e-02,  8.5580461e-02,
         5.7516418e-02,  2.0328239e-02, -4.3857012e-02,  9.8535880e-02,
        -6.4706601e-02, -2.0074643e-01,  2.0877582e-01,  2.1842621e-01,
         9.9345982e-02,  3.0441117e-03,  9.3908243e-02,  2.3797552e-01,
         1.6305017e-01, -4.3792999e-01, -1.5484407e-02, -2.0734428e-01,
         2.4520056e-02, -3.3372197e-02, -9.9631593e-02,  2.8424390e-02,
         4.4808317e-02,  1.3398661e-01, -1.5305999e-01,  1.0083845e-01,
        -2.8261018e-01,  4.5448677e-03,  3.6318311e-01,  8.5084341e-02,
        -4.1496590e-02,  2.1506369e-01, -2.2469056e-01, -1.5302478e-01,
         9.1678657e-02, -