In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import random
from uuid import uuid4
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import getopt
import sys
import os
import math
import time
import argparse
from visdom import Visdom
from tqdm import tqdm
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

sys.path.insert(0, os.path.join('..', '..'))

import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
import torch.optim as optim

from torch.nn.utils import clip_grad_norm_

from dnc.dnc import DNC
from dnc.sdnc import SDNC
from dnc.sam import SAM
from dnc.util import *

from dnc.lib import exp_loss, InputStorage, mse, criterion, CELoss, L1loss, ENDSYM, tensor2string, LEARNABLEOBJECTIVES, LEARNTHISOBJECTIVES, RETURNOTHEROBJ

T.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7746c1b25ea0>

# Todo:
- generate_data, InputStorage to Torch objects

In [2]:
viz = Visdom()
# assert viz.check_connection()


dataoutputformat = 2 # binary / mod 2

NoneClassOutput = False

outputformat = dataoutputformat
if NoneClassOutput:
  outputformat = dataoutputformat +1


def llprint(message):
  sys.stdout.write(message)
  sys.stdout.flush()



st = InputStorage()

def genSeq(sizeTpl, min=None, max=None, default=True):
  if default:
    return np.random.binomial(1, 0.5, sizeTpl)
  assert min is not None and max is not None
  return np.random.randint(min, max, sizeTpl)

def calcsum(sequenceA, sequenceB, maxval=dataoutputformat, batch_size=100, length=6): #calculate sum of two binary numbers
    sumsequence = np.zeros((batch_size, length +1))
    assert len(sequenceA) == len(sequenceB)
    for k in range(len(sequenceA)):
      carry = 0 # carry bit
      for j in reversed(range(len(sequenceA[k]))):
          sumsequence[k][j+1] = (sequenceA[k][j][0] + sequenceB[k][j][0] + carry) % maxval
          carry = (sequenceA[k][j][0] + sequenceB[k][j][0] + carry) // maxval
      sumsequence[k][0] = carry
    return sumsequence

def generate_data(batch_size, length, maxlength, testoccurance=True, transposeInput=False):
  minSeq = 0
  maxSeq = dataoutputformat # 2= binary, 10=decimal

  input_data = np.zeros((batch_size, maxlength, maxlength), dtype=np.float32)
  target_output = np.zeros((batch_size, maxlength, outputformat), dtype=np.float32)
  sequence1 = genSeq((batch_size, length, 1), min=minSeq, max=maxSeq, default=False)
  sequence2 = genSeq((batch_size, length, 1), min=minSeq, max=maxSeq, default=False)

  if testoccurance: # test if the sequence is in the test data, replace if so
    for i in range(batch_size):
      input_test_data = np.zeros((1, maxlength, maxlength), dtype=np.float32)
      input_test_data[0, 0:length, 0:1] = sequence1[i] #first sequence
      input_test_data[0, length, 1] = ENDSYM  #pause
      input_test_data[0, length+1:length*2+1, 2:3] = sequence2[i] #second sequence
      input_test_data[0, length*2+1, 3] = ENDSYM  #pause
      while st.isSaved(input_test_data[0], flag="testData"):
        if np.random.binomial(1, 0.5, 1) == 1: # replace first sequence
          sequence1[i] = genSeq((length, 1), min=minSeq, max=maxSeq, default=False)
          input_test_data[0, 0:length, 0:1] = sequence1[i]
        else: # replace second sequence
          sequence2[i] = genSeq((length, 1), min=minSeq, max=maxSeq, default=False)
          input_test_data[0, length+1:length*2+1, 2:3] = sequence2[i]

  input_data[:, 0:length, 0:1] = sequence1 #first sequence
  input_data[:, length, 1] = ENDSYM  #pause
  input_data[:, length+1:length*2+1, 2:3] = sequence2 #second sequence
  input_data[:, length*2+1, 3] = ENDSYM  #pause
  if transposeInput:
    for i in range(batch_size):
      input_data[i] = input_data[i].T

  
  
  cs = calcsum(sequence1, sequence2, maxval=dataoutputformat, batch_size=batch_size, length=length)
  offset = 0
  if NoneClassOutput:
    offset = 1

  for i in range(batch_size):
    for j in reversed(range(1, cs.shape[1]+1)):
      target_output[i, -j, offset+int(cs[i,-j])] = 1
      
    if NoneClassOutput:
      for j in range(target_output.shape[1]):
        if np.sum(target_output[i, j]) == 0:
          target_output[i, j, 0] = 1


  return input_data, target_output




def combLoss(prediction, target):
  sumlos = 0
  for i in range(prediction.shape[0]):
    labels = target[i].argmax(dim=1)
    sumlos += CELoss(prediction[i], labels)
  return sumlos/prediction.shape[0]

def incrementCurriculum(trainError, epoch, sequence_length, maxsequence_length, curriculum_fre):
  return epoch != 0 and sequence_length < maxsequence_length and epoch % curriculum_fre == 0

def calcAccuracy(prediction, target, total=False):
  if not isinstance(prediction, T.Tensor):
    raise Exception("prediction is not a tensor")
  if not isinstance(target, T.Tensor):
    raise Exception("target is not a tensor")
  
  firstone = 0
  for i in range(target.shape[1]):
    if target[:,i].sum() != 0:
      firstone = i
      break
  prediction = prediction[:,firstone:]
  target = target[:, firstone:]
  labels = target.argmax(dim=2)
  accuracy = (prediction.argmax(dim=2) == labels).int().to(T.float32).mean().item()
  #accuracy = T.sum(T.isclose(prediction, target, atol=0.25).int().to(T.float32), dim=2).mean().item()/outputformat
      
  if total:
    return 1 if accuracy == 1 else 0
  return accuracy

#calcAccuracy(T.tensor([[[0,1],[0,0]]]), T.tensor([[[0,1],[0,0]]]))
  

Setting up a new session...


In [3]:
#d = generate_data(1, 3, 9, testoccurance=False)
#print(d)

In [4]:
import copy
from dnc.lib import STEPBYSTEPOBJ
import pickle

import os

batch_size = 50#int(1360*(1-0.15))-1
sequence_length = 5
sequence_max_length = 5
iterations = int(1*10**3) #200000
summarize_freq = int(iterations/100)
check_freq = int(iterations/10)
curriculum_freq = int(iterations/10)

print("batch_size", batch_size)
print("sequence_length", sequence_length)
print("iterations", iterations)
print("summarize_freq", summarize_freq)
print("check_freq", check_freq)
print("curriculum_freq", curriculum_freq)


  # input_size = output_size = args.input_size
mem_slot = 20#48 #112
mem_size = 1
read_heads = 1
curriculum_increment = 1
input_size = 3*sequence_max_length + 2
output_size = 64

replaceWithWrong = True

num_layers = 2 #5 

print("input_size", input_size)
print("output_size", output_size)
print("mem_slot", mem_slot)
print("mem_size", mem_size)
print("read_heads", read_heads)
print("num_layers", num_layers)


# mem operations = input_size*num_layers

batch_size 50
sequence_length 5
iterations 1000
summarize_freq 10
check_freq 100
curriculum_freq 100
input_size 17
output_size 64
mem_slot 20
mem_size 1
read_heads 1
num_layers 2


In [5]:
def create_directory_if_not_exists(directory_path):
    if not os.path.exists(directory_path):
        os.makedirs(directory_path)
        print("Directory created successfully!")
    else:
        print("Directory already exists.")

name = 'add_' + str(uuid4().hex)[:3] + '_comp_'

lastcp = None

create_directory_if_not_exists(name)

datas = []

print(name)

loadcp = False#'add_868/checkpoint_500.pth'#'add_dc0/checkpoint_1500.pth'#False#'add_e9a/checkpoint_1000.pth'#'add_587/checkpoint_1000.pth' #= 'checkpoint_add_46_242000.pth

print(input_size, output_size)
print(name)

Directory created successfully!
add_5b6_comp_
17 64
add_5b6_comp_


In [6]:

def lossfnwithReturnOther(output, target, otherReturn, OneAndZero=False, writegateThreshold=0.0, printlosses=False, returnTuple=False, learnthisobjective=LEARNTHISOBJECTIVES, alwaysone=True, epoch=1, lossfunction=None, speciallossfn=None, mnum_layers=num_layers):
  if lossfunction is None:
    lossfunction = mse

  if epoch % 50 == 0:
    with open(f'{name}/output_lossfn.txt', 'a') as lossfile:
      print("EPOCH: ", epoch, file=lossfile)
      for key in otherReturn.keys():
        if isinstance(otherReturn[key], T.Tensor):
          print(key, otherReturn[key].shape, file=lossfile)
          for i in range(otherReturn[key].shape[1] // mnum_layers):
            print(key, " Input ", i, file=lossfile)
            print(otherReturn[key][0, i*mnum_layers:(i+1)*mnum_layers], file=lossfile)

      
  sf = 100 # scale factor for softmax
  base = 0
  if learnthisobjective["general_loss"]:
    if speciallossfn is None:
      smoutput = output[:, -(sequence_max_length+1):, :]
      smtarget = target[:, -(sequence_max_length+1):, :]

      for i in range(smoutput.shape[0]):
        labels = smtarget[i].argmax(dim=1)
        base += speciallossfn(smoutput[i], labels)
    else:
      base = speciallossfn(output, target)

    #base += criterion(smoutput, smtarget)#mse(smoutput, smtarget) + exp_loss(smoutput, smtarget) #T.sum(T.abs(smoutput - smtarget)) / batch_size
    #(exp_loss(smoutput, smtarget) +  mse(smoutput, smtarget))*target.numel()/batch_size

  allocation_weight_loss = 0
  if learnthisobjective["allocation_weights"]  and isinstance(otherReturn["allocation_weights"], T.Tensor):
    zarr = T.zeros_like(otherReturn["allocation_weights"])
    for b in range(otherReturn["allocation_weights"].shape[0]):
      for i in range(otherReturn["allocation_weights"].shape[1]):
        #zarr[b, i, :, (i % input_size) % zarr.shape[3]] = 1
        zarr[b, i, :, (i // mnum_layers) % zarr.shape[3]] = 1

    if epoch % 50 == 0:
      with open(f'{name}/output_lossfn.txt', 'a') as lossfile:
        print("Allocation Weights (target): ", zarr.shape, file=lossfile)
        for i in range(zarr.shape[1] // mnum_layers):
            print("Allocation Weights (target) Input ", i, file=lossfile)
            print(zarr[0, i*mnum_layers:(i+1)*mnum_layers], file=lossfile)
    allocation_weight_loss += lossfunction(otherReturn["allocation_weights"], zarr) #T.sum(T.abs(otherReturn["allocation_weights"] - zarr)) / batch_size

  allocation_gate_loss = 0
  if learnthisobjective["allocation_gate"]  and isinstance(otherReturn["allocation_gate"], T.Tensor):
    if OneAndZero:
      allocation_gate_loss += lossfunction(T.pow(((otherReturn["allocation_gate"]-0.5)*2),2), T.ones(otherReturn["allocation_gate"].shape))
      #T.sum(T.abs(T.pow(((otherReturn["allocation_gate"]-0.5)*2),2) - T.ones(otherReturn["allocation_gate"].shape))) / batch_size
    else: # no writing by content similarity
      allocation_gate_loss += lossfunction(otherReturn["allocation_gate"], T.ones(otherReturn["allocation_gate"].shape))
      
      #T.sum(T.abs(otherReturn["allocation_gate"]-T.ones(otherReturn["allocation_gate"].shape))) / batch_size

  write_gate_loss = 0
  if learnthisobjective["write_gate"]  and isinstance(otherReturn["write_gate"], T.Tensor):
    write_gate_loss += lossfunction(otherReturn["write_gate"], T.ones(otherReturn["write_gate"].shape))
    #T.sum(T.abs(otherReturn["write_gate"] - T.ones(otherReturn["write_gate"].shape))) / batch_size

  read_modes_loss = 0
  if learnthisobjective["read_modes"]  and isinstance(otherReturn["read_modes"], T.Tensor):
    read_modes_loss += lossfunction(otherReturn["read_modes"], T.nn.functional.softmax(otherReturn["read_modes"].clone().detach()*sf, 3))
    #T.sum(T.abs(otherReturn["read_modes"] - T.nn.functional.softmax(otherReturn["read_modes"].clone().detach()*sf, 3))) / batch_size

  write_weights_loss = 0
  if learnthisobjective["write_weights"]  and isinstance(otherReturn["write_weights"], T.Tensor):
    #write_weights_loss = T.sum(T.abs(otherReturn["write_weights"] - T.nn.functional.softmax(otherReturn["write_weights"].clone().detach()*sf, 3))) / batch_size
    zarr = T.zeros_like(otherReturn["write_weights"])
    for b in range(otherReturn["write_weights"].shape[0]):
      for i in range(otherReturn["write_weights"].shape[1]):
        zarr[b, i, :, (i // mnum_layers) % zarr.shape[3]] = 1
    write_weights_loss += lossfunction(otherReturn["write_weights"], zarr)
    #T.sum(T.abs(otherReturn["write_weights"] - zarr)) / batch_size
    if epoch % 10 == 0:
      with open(f'{name}/output_lossfn.txt', 'a') as lossfile:
        print("Write Weights (target): ", zarr.shape, file=lossfile)
        for i in range(zarr.shape[1] // mnum_layers):
            print("Write Weights (target) Input ", i, file=lossfile)
            print(zarr[0, i*mnum_layers:(i+1)*mnum_layers], file=lossfile)
  
  usage_vector_loss = 0
  if learnthisobjective["usage_vector"]  and isinstance(otherReturn["usage_vector"], T.Tensor):
    uv = T.zeros_like(otherReturn["usage_vector"])
    for b in range(otherReturn["usage_vector"].shape[0]):
      for i in range(otherReturn["usage_vector"].shape[1]-1):
        uv[b, i+1:, (i // mnum_layers) % uv.shape[2]] = 1
    usage_vector_loss += lossfunction(otherReturn["usage_vector"], uv)
    #T.sum(T.abs(otherReturn["usage_vector"] - uv)) / batch_size
    #print("Usage Vector LOSS: ", usage_vector_loss)
    #print("Usage Vector (output): ", otherReturn["usage_vector"])
    #print("Usage Vector (target): ", uv)
    if epoch % 10 == 0:
      with open(f'{name}/output_lossfn.txt', 'a') as lossfile:
        print("Usage Vector (target): ", uv.shape, file=lossfile)
        for i in range(uv.shape[1] // mnum_layers):
            print("Usage Vector (target) Input ", i, file=lossfile)
            print(uv[0, i*mnum_layers:(i+1)*mnum_layers], file=lossfile)


  if printlosses:
    print("losses: ", base, "\n Allocation Weight ", allocation_weight_loss,  "\n allocation gate", allocation_gate_loss, "\n write gate", write_gate_loss, "\n read modes", read_modes_loss,"\n write weights", write_weights_loss, "\n usage vector", usage_vector_loss)
  
  
  if returnTuple:
    return base, allocation_weight_loss, usage_vector_loss, allocation_gate_loss, write_gate_loss, read_modes_loss, write_weights_loss
  return base + allocation_weight_loss + allocation_gate_loss + write_gate_loss + read_modes_loss + write_weights_loss + usage_vector_loss



otherkey:  allocation_weights
torch.Size([100, 80, 1, 48])
otherkey:  allocation_gate
torch.Size([100, 80, 1])
otherkey:  write_gate
torch.Size([100, 80, 1])
otherkey:  write_weights
torch.Size([100, 80, 1, 48])
otherkey:  read_modes
torch.Size([100, 80, 1, 48])
otherkey:  read_weights
torch.Size([100, 80, 1, 48])
otherkey:  free_gates
torch.Size([100, 80, 1])
otherkey:  erase_vector
torch.Size([100, 80, 1, 1])
otherkey:  write_vector
torch.Size([100, 80, 1, 1])
otherkey:  usage_vector
torch.Size([100, 80, 48])

In [7]:
class CaclulateFactors:
  def __init__(self, iterations, nofactors=7, justOnes=False, softmax=False, softmaxTemp=1, minimum=0.01, maximum=100, resetepoch=50, sizeadjust=False, factadjust=None):
    self.losses = T.zeros((iterations+1, nofactors))
    self.factors = T.ones(nofactors)
    self.noFactors = nofactors
    self.justOnes = justOnes
    self.softmax = softmax
    self.softmaxTemp = softmaxTemp
    self.minimum = minimum
    self.maximum = maximum
    self.sizeAdjust = T.ones(nofactors)
    self.performsizeadjust = sizeadjust
    if factadjust is not None and isinstance(factadjust, T.Tensor):
      self.factadjust = factadjust
      self.adjusttwice = True
      print("Adjusting twice")
      print(factadjust)
    else:
      self.adjusttwice = False

  def setFactors(self, factors):
    self.factors = factors
  
  def setJustOnes(self, justOnes):
    self.justOnes = justOnes

  def reset(self):
    self.factors = T.ones(self.noFactors)

  def __call__(self, epoch, currentlosses, lookback=5, resetinterval=int(iterations//50), rescaleby=2):
    if self.performsizeadjust and epoch > 1:
      lb = max(1, epoch-lookback)
      sumlosses = T.sum(self.losses[epoch-lb:epoch]) / lb
      devisors = T.sum(self.losses[epoch-lb:epoch], dim=0)/lb
      devisors = T.where(devisors < 10**-2, 10**-2, devisors)
      self.sizeAdjust = T.nan_to_num(sumlosses / devisors, nan=1, posinf=1, neginf=1)
      self.sizeAdjust = T.where(self.sizeAdjust > 10**2, 10**2, self.sizeAdjust)
      self.sizeAdjust = T.where(self.sizeAdjust < 10**-2, 10**-2, self.sizeAdjust)
      self.sizeAdjust = T.where(self.sizeAdjust == 0, 1, self.sizeAdjust)

      #print(T.sum(self.losses[epoch-lookback:epoch], dim=0))
      #print("sumlosses: ", sumlosses)
      #print("Size Adjust: ", self.sizeAdjust)
      
    for i in range(len(currentlosses)):
      if isinstance(currentlosses[i], T.Tensor):
        self.losses[epoch,i] = currentlosses[i].detach().item()

    if self.justOnes:
      ret = T.ones(self.noFactors)
      if self.softmax:
        ret =  T.nn.functional.softmax(ret*self.softmaxTemp, 0)
      if self.performsizeadjust:
        ret = ret * self.sizeAdjust
      if self.adjusttwice:
        ret = ret * self.factadjust
      return ret

    
    if T.nonzero(self.losses[epoch] > 1).squeeze().numel() <= 1:
      self.factors = T.ones(self.noFactors)
      ret = self.factors.clone()
      if self.softmax:
        ret = T.nn.functional.softmax(ret*self.softmaxTemp, 0)
      if self.performsizeadjust:
        ret = ret * self.sizeAdjust
      if self.adjusttwice:
        ret = ret * self.factadjust
      return ret

    
    if epoch <= 1 or (epoch % resetinterval) == 0 or self.justOnes:
      self.factors = T.ones(self.noFactors)
      ret = self.factors.clone()

      if self.softmax:
        ret = T.nn.functional.softmax(ret*self.softmaxTemp, 0)
      if self.performsizeadjust:
        ret = ret * self.sizeAdjust
      if self.adjusttwice:
        ret = ret * self.factadjust
      return ret

    oldlosses = self.losses[max(0,epoch-lookback):epoch]
    oldlosses = T.where(oldlosses == 0, self.losses[epoch].unsqueeze(1).T.expand(oldlosses.shape), oldlosses)
    omeandist = T.mean(oldlosses, dim=0)-self.losses[epoch]
    # signdiff = T.sign(self.losses[epoch]-T.mean(oldlosses, dim=0))
    firstsignificant = T.where(omeandist > 1.2, 0, T.floor(T.nan_to_num(T.log10(T.mean(T.abs(oldlosses), dim=0)), nan=0, posinf=0, neginf=0)))

    omeandist = omeandist / 10**firstsignificant
    meandist = T.where(omeandist > 1.2, 2.01, omeandist)
    meandist = T.where(omeandist < 0, 0.01, meandist)    
    meandist = (2-meandist)+1


    newfactors = self.factors*meandist
    if T.any((newfactors > self.maximum)):
      newfactors = newfactors / rescaleby
    newfactors[self.losses[epoch] < 10**-4] = 1 # if loss is zero, keep factor at 1 NEW

    newfactors[newfactors < self.minimum] = self.minimum

    self.factors = newfactors


    ret = self.factors.clone()
    if self.softmax:
      ret = T.nn.functional.softmax(ret*self.softmaxTemp, 0)
    if self.performsizeadjust:
      ret = ret * self.sizeAdjust
    if self.adjusttwice:
      ret = ret * self.factadjust
    return ret



In [8]:
def CalcLossonValidationData(st, rnn, mhx, batch_size, input_size):
  testset = st.getDataByFlag("testData") # get test data
  testlosses = []
  testaccuracy = []
  if len(testset) == 0:
    raise ValueError("No test data available")
  

  for k in range(int(len(testset) / batch_size)+1): # split testdata into batch_size chunks
    input_TEST_data = np.zeros((batch_size, input_size, input_size))
    target_TEST_output = np.zeros((batch_size, input_size, outputformat))
    for i in range(batch_size):
      if i + k * batch_size < len(testset):
        input_TEST_data[i] = testset[k*batch_size+i]["input"]
        target_TEST_output[i] = testset[k*batch_size+i]["output"]
      else: # if there is not enough test data fill the remaining slots with random entries
        robj = random.choice(testset)
        input_TEST_data[i] = robj["input"]
        target_TEST_output[i] = robj["output"]

    input_TEST_data = var(T.from_numpy(input_TEST_data)).type(T.float32)
    target_TEST_output = var(T.from_numpy(target_TEST_output)).type(T.float32)
    if rnn.debug:
      TEST_output, _, _, _ = rnn(input_TEST_data, (None, mhx, None), reset_experience=True, pass_through_memory=True, retOther=True)
    else:
      TEST_output, _, _ = rnn(input_TEST_data, (None, mhx, None), reset_experience=True, pass_through_memory=True, retOther=True)

    #print(TEST_output)
    MyTestloss = combLoss((TEST_output), target_TEST_output).item() # calculate test loss
    MyTestaccuracy = calcAccuracy(TEST_output, target_TEST_output)
    testlosses.append(MyTestloss)
    testaccuracy.append(MyTestaccuracy)
  Testloss = np.mean(testlosses) # calculate test loss mean
  Testaccuracy = np.mean(testaccuracy) # calculate test accuracy mean
  return Testloss, Testaccuracy
    

In [9]:
def generateTrainingData(sequenceMaxLen, st):
    Traindata = {}
    
    for seqLen in tqdm(range(1, sequenceMaxLen)):
        Traindata[seqLen] = []
        inputspace = outputformat*2**seqLen
        for i in range(inputspace):
            for j in range(inputspace):
                sequenceA = np.zeros((1,seqLen, 1))
                sequenceB = np.zeros((1,seqLen, 1))
                for k in range(seqLen):
                    sequenceA[0][k] = (i >> k) & 1
                    sequenceB[0][k] = (j >> k) & 1
                sumsequence = calcsum(sequenceA, sequenceB, maxval=outputformat, batch_size=1, length=seqLen)
                input_data = np.zeros((1, input_size, input_size))
                target_output = np.zeros((1, input_size, outputformat))
                input_data[0, 0:seqLen, 0:1] = sequenceA
                input_data[0, seqLen, 1] = ENDSYM
                input_data[0, seqLen+1:seqLen*2+1, 2:3] = sequenceB
                input_data[0, seqLen*2+1, 3] = ENDSYM

                if st.isSaved(input_data[0], flag="testData"):
                    continue
                

                offset = 0
                if NoneClassOutput:
                    offset = 1
                for j in reversed(range(1, sumsequence.shape[1]+1)):
                    target_output[0, -j, offset+int(sumsequence[0,-j])] = 1
                if NoneClassOutput:
                    for j in range(target_output.shape[1]):
                        if np.sum(target_output[0, j]) == 0:
                            target_output[0, j, 0] = 1
                Traindata[seqLen].append({"input": input_data, "target": target_output})
    return Traindata


def getTrainingData(storedTrainingData, sequenceMinLen, sequenceLen, shuffle=True, batchsize=100, outputbatchsize=100, factors=None):
    if not isinstance(factors, T.Tensor) and factors.size != (sequenceLen-sequenceMinLen) and any(factors < 0) and any(factors > 1):
        factors = T.ones(sequenceLen-sequenceMinLen)
        print("Invalid factors, using default factors", factors.size != (sequenceLen-sequenceMinLen), any(factors < 0), any(factors > 1))
    allTrainData = []
    for i in range(sequenceMinLen, sequenceLen):
        random.shuffle(storedTrainingData[i])
        choosesize = int(len(storedTrainingData[i]) * factors[i-sequenceMinLen].item())
        print("Sequence Length: ", i, " Choosesize: ", choosesize)
        allTrainData.extend(storedTrainingData[i][0:choosesize])

    print("Possible Training Data: ", len(allTrainData))
    if shuffle:
        random.shuffle(allTrainData)
    if batchsize is None:
        batchsize = len(allTrainData)
    input_data = np.zeros((outputbatchsize, input_size, input_size))
    target_output = np.zeros((outputbatchsize, input_size, outputformat))
    for i in range(batchsize):
        alltrainindex = i 
        if i >= outputbatchsize:
            break
        if i >= len(allTrainData):
            alltrainindex = random.randint(0, len(allTrainData)-1)
        input_data[i] = allTrainData[alltrainindex]["input"]
        target_output[i] = allTrainData[alltrainindex]["target"]
    return input_data, target_output

def zeroslike(arr):
    if isinstance(arr, T.Tensor):
        return T.zeros_like(arr)
    if isinstance(arr, np.ndarray):
        return np.zeros_like(arr)
    
def shuffleNPArrays(arr1, arr2):
    size0 = arr1.shape[0]
    usedindices = []
    unusedindices = list(range(size0))
    newarr1 = zeroslike(arr1)
    newarr2 = zeroslike(arr2)
    for i in range(size0):
        index = random.choice(unusedindices)
        #print(index)
        usedindices.append(index)
        #print(usedindices)
        unusedindices.remove(index)
        #print(unusedindices)
        newarr1[index] = arr1[i]
        newarr2[index] = arr2[i]
        #print(newarr1)
        #print(newarr2)
    return newarr1, newarr2






In [10]:
def testshufflenparrays():
    a = np.array([[1,2,3],[4,5,6],[7,8,9]])
    b = np.array([[10,20,30],[40,50,60],[70,80,90]])
    print(a)
    print(b)
    a, b = shuffleNPArrays(a, b)
    print(a)
    print(b)

#testshufflenparrays()

In [None]:
settings = { 
  "address_every_slot": False,
  "factadjust": T.Tensor([1, 0, 0, 0, 0, 0, 0]),
}
modifcations = False
if modifcations:
  settings["address_every_slot"] = True
  settings["factadjust"] = T.Tensor([10,1,1,1,1,1,1])

import datetime
import select

def mCELoss(output, target):
  return CELoss(output, target)

lossfn = [mCELoss]
nums_layers = [2]
nums_read_heads = [1,2,5]
comp_obj = []

for i, l in enumerate(lossfn):
  for j, mnum_layers in enumerate(nums_layers):
    for k, mread_heads in enumerate(nums_read_heads):
      rnn = DNC(
            input_size=input_size,
            hidden_size=output_size*3, #new *2
            output_size=outputformat, #new binary: 3 -> none, 0, 1
            rnn_type='lstm',
            num_layers=mnum_layers,
            num_hidden_layers=1, #1
            dropout=0,
            nr_cells=mem_slot,
            cell_size=mem_size,
            read_heads=mread_heads,
            gpu_id=-1,
            debug='store_true',
            batch_first=True,
            independent_linears=True,
            nonlinearity='celu', #tanh
            address_every_slot=settings["address_every_slot"]
        )

      testrnn = DNC(
        input_size=input_size,
        hidden_size=output_size*3, #new *2
        output_size=outputformat, #new binary: 3 -> none, 0, 1
        rnn_type='lstm',
        num_layers=mnum_layers,
        num_hidden_layers=1, #1
        dropout=0,
        nr_cells=mem_slot,
        cell_size=mem_size,
        read_heads=mread_heads,
        gpu_id=-1,
        debug='store_true',
        batch_first=True,
        independent_linears=True,
        nonlinearity='celu', #tanh
        address_every_slot=settings["address_every_slot"]
      )

      optimizer = optim.Adam(rnn.parameters(), lr=0.001, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
      comp_obj.append({"lossfn": l, 
                      "rnn": rnn, 
                      "testrnn": testrnn, 
                      "optimizer": optimizer, 
                      "chx": None, 
                      "mhx": None, 
                      "rv": None, 
                      "last_save_losses": [], 
                      "datas": [], 
                      "i": i+j*10+k*100,
                      "mnum_layers": mnum_layers,
                      "mread_heads": mread_heads,
                      })

for i in range(2, sequence_max_length,1): # generate test data
    inputdataspace = 2**i*dataoutputformat # 2 i bit sequences
    testdatasize = int(inputdataspace*0.05)+1 #5%
    input_data, target_output = generate_data(testdatasize, i, input_size)
    for i in range(testdatasize):
      st.saveInput(input_data[i], output=target_output[i], withoutIncrement=True, flag="testData") #saveData

storedTrainingData = generateTrainingData(sequence_max_length, st) # generate training data

NPinput_data, NPtarget_output = getTrainingData(storedTrainingData, 
                                                  1, sequence_length, 
                                                  shuffle=True, batchsize=None, 
                                                  outputbatchsize=batch_size, 
                                                  factors=T.Tensor([1,1,1,1,1,0.5,0.25])
                                                  )



for comp in comp_obj:
  curlossfn = comp["lossfn"]
  rnn = comp["rnn"]
  testrnn = comp["testrnn"]
  optimizer = comp["optimizer"]
  chx = comp["chx"]
  mhx = comp["mhx"]
  rv = comp["rv"]
  last_save_losses = comp["last_save_losses"]
  datas = comp["datas"]
  lossfni = comp["i"]
  mnum_layers = comp["mnum_layers"]
  mread_heads = comp["mread_heads"]
  with open(f'{name}/output.txt', 'a') as f:
    print(name)
    print(name, file=f)
    print("Loss Function: ", comp["lossfn"].__name__)
    print("Loss Function: ", comp["lossfn"].__name__, file=f)
    print("Number of Layers: ", mnum_layers)
    print("Number of Layers: ", mnum_layers, file=f)
    print("Number of Read Heads: ", mread_heads)
    print("Number of Read Heads: ", mread_heads, file=f)

    
    
    
    if loadcp != False:
      rnn.load_state_dict(T.load(loadcp, weights_only=True))
      rnn.eval()
      rnn.train()
    
    print(rnn)
    print(rnn, file=f)

    last_save_losses = []

    Testloss = 0 # loss of test data
    
    learnthisobjective = copy.deepcopy(LEARNTHISOBJECTIVES)
    for key in learnthisobjective.keys():
      learnthisobjective[key] = False #all False


    learnthisobjective["read_modes"] = False
    learnthisobjective["general_loss"] = True



    factors = CaclulateFactors(iterations, nofactors=7, justOnes=True, 
                              softmax=False, softmaxTemp=1*10**-3, sizeadjust=False, 
                              factadjust=settings["factadjust"]
                              )

    lastobjectivechange = 0
    optimizerdict = optimizer.state_dict()


    learnthisobjextivecounter = 0

    start_time = datetime.datetime.now()


    
    for epoch in tqdm(range(iterations + 1)):
      elapsed = datetime.datetime.now() - start_time
      rate = (elapsed.total_seconds() / (epoch + 1)) if epoch > 0 else 0
      remaining_time = datetime.timedelta(seconds=rate * (iterations + 1 - epoch - 1))
      finish_time = start_time + elapsed + remaining_time
      if epoch % 10 == 0:
        tqdm.write(f"Epoch {epoch}/{iterations+1} | ETA: {finish_time.strftime('%Y-%m-%d %H:%M:%S')}")



      summarize = (epoch % summarize_freq == 0)
      take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)


      input_data, target_output = shuffleNPArrays(NPinput_data, NPtarget_output)
      input_data = var(T.from_numpy(input_data)).type(T.float32)
      target_output = var(T.from_numpy(target_output)).type(T.float32)


      bucket = 2
      smallbatchsize = int(batch_size/bucket)

      loss = 0
      accuracy = 0
      currentlosses = (0,0,0,0,0,0,0)
      base = 0
      allocation_weight_loss = 0
      cwl = 0
      allocation_gate_loss = 0
      write_gate_loss = 0
      read_modes_loss = 0
      write_content_weights_loss = 0
      Rloss = 0
      partiallosses = T.zeros(7)
      currfactors = T.zeros(7)
      # lossmse = 0
      # lossce = 0
      # lossexp = 0
      # lossl1 = 0

      for i in range(bucket):
        optimizer.zero_grad()
        currentinput = input_data[i*smallbatchsize:(i+1)*smallbatchsize]
        currenttarget = target_output[i*smallbatchsize:(i+1)*smallbatchsize]
        if rnn.debug:
          currentoutput, (chx, mhx, rv), v, otherReturn = rnn(currentinput, (None, mhx, None), reset_experience=True, pass_through_memory=True, retOther=True)
        else:
          currentoutput, (chx, mhx, rv), otherReturn = rnn(currentinput, (None, mhx, None), reset_experience=True, pass_through_memory=True, retOther=True)
        currentloss = combLoss((currentoutput), currenttarget)
        currentaccuracy = calcAccuracy(currentoutput, currenttarget)
        loss += currentloss.item()
        accuracy += currentaccuracy

        # lossmse += mse(currentoutput, currenttarget).item()
        # lossce += mCELoss(currentoutput, currenttarget).item()
        # lossexp += exp_loss(currentoutput, currenttarget).item()
        # lossl1 += L1loss(currentoutput, currenttarget).item()


        currentlosses = lossfnwithReturnOther(currentoutput, currenttarget, otherReturn, 
                                              printlosses=False,#epoch % 13 == 0 and i == 0, 
                                              returnTuple=True, 
                                              learnthisobjective=learnthisobjective, 
                                              epoch=epoch, 
                                              speciallossfn=curlossfn,
                                              mnum_layers=mnum_layers
                                              )
        curbase, curallocation_weight_loss, curcwl, curallocation_gate_loss, curwrite_gate_loss, curread_modes_loss, curwrite_content_weights_loss = currentlosses
        base += curbase.detach() if isinstance(curbase, T.Tensor) else 0
        allocation_weight_loss += curallocation_weight_loss.detach() if isinstance(curallocation_weight_loss, T.Tensor) else 0
        cwl += curcwl.detach() if isinstance(curcwl, T.Tensor) else 0
        allocation_gate_loss += curallocation_gate_loss.detach() if isinstance(curallocation_gate_loss, T.Tensor) else 0
        write_gate_loss += curwrite_gate_loss.detach() if isinstance(curwrite_gate_loss, T.Tensor) else 0
        read_modes_loss += curread_modes_loss.detach() if isinstance(curread_modes_loss, T.Tensor) else 0
        write_content_weights_loss += curwrite_content_weights_loss.detach() if isinstance(curwrite_content_weights_loss, T.Tensor) else 0
        curcurrfactors = factors(epoch, currentlosses)
        curRloss = 0
        curpartiallosses = T.zeros(len(currentlosses))
        
        for i in range(curcurrfactors.shape[0]):
          if T.isclose(curcurrfactors[i],T.tensor(0.0)):
            continue
          curpartiallosses[i] = curcurrfactors[i] * currentlosses[i]
          curRloss += curcurrfactors[i] * currentlosses[i]

        currfactors += curcurrfactors
        Rloss += curRloss
        Rloss = Rloss.detach()
        partiallosses += curpartiallosses.detach()

        if np.isnan(curRloss.item()) or np.isinf(curRloss.item()) or np.isclose(curRloss.item(), 0):
          print("Loss is nan or inf or close to zero")
          continue
        curRloss.backward()
        T.nn.utils.clip_grad_norm_(rnn.parameters(), 30)
        optimizer.step()
        curRloss = curRloss.detach()
        mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }


      loss /= bucket
      accuracy /= bucket
      base /= bucket
      allocation_weight_loss /= bucket
      cwl /= bucket
      allocation_gate_loss /= bucket
      write_gate_loss /= bucket
      read_modes_loss /= bucket
      write_content_weights_loss /= bucket
      Rloss /= bucket
      partiallosses /= bucket
      currfactors /= bucket

      # lossmse /= bucket
      # lossce /= bucket
      # lossexp /= bucket
      # lossl1 /= bucket


      if epoch % summarize_freq == 0:
        currentweights= rnn.state_dict()
        testrnn.load_state_dict(currentweights)
        Testloss, Testaccuracy = CalcLossonValidationData(st, testrnn, None, batch_size, input_size) # mhx -> None
      
    
      

      if epoch == 10:
        for mybool in RETURNOTHEROBJ["bools"]:
          otherkey = [kc[1] for kc in RETURNOTHEROBJ["keycombs"] if kc[0] == mybool]
          #print("otherkey: ", otherkey)
          if len(otherkey) == 0:
            continue
          otherkey = otherkey[0]
          print("otherkey: ", otherkey)
          if otherReturn[otherkey] is not None and isinstance(otherReturn[otherkey], T.Tensor):
            print(otherReturn[otherkey].shape)


      
      
      partiallosses = partiallosses.detach().tolist()
      print(f"Epoch: {epoch}, Accuracy: {accuracy}, Loss: {loss}, weighted Loss: {Rloss.item()}, factors: {currfactors}, losses: {currentlosses}",file=f)

      #if summarize:
        #print("REAL Loss: ", Rloss.item())
        #print("Factors: ", currfactors)
        #print("losses: ", currentlosses)


      datas.append({
        "epoch": epoch, 
        "loss": loss, 
        "testloss": Testloss, 
        "sequencelength": sequence_length, 
        "accuracy": accuracy,
        "testaccuracy": Testaccuracy,
        "loss_base": base.item() if isinstance(base, T.Tensor) else 0,
        "loss_allocation_weight": allocation_weight_loss.item() if isinstance(allocation_weight_loss, T.Tensor) else 0,
        "loss_allocation_gate": allocation_gate_loss.item() if isinstance(allocation_gate_loss, T.Tensor) else 0,
        "loss_write_gate": write_gate_loss.item() if isinstance(write_gate_loss, T.Tensor) else 0,
        "loss_read_modes": read_modes_loss.item() if isinstance(read_modes_loss, T.Tensor) else 0,
        "loss_write_weights": write_content_weights_loss.item() if isinstance(write_content_weights_loss, T.Tensor) else 0,
        "loss_usage_vector": cwl.item() if isinstance(cwl, T.Tensor) else 0,
        "factor_base": currfactors[0].item(),
        "factor_allocation_weight": currfactors[1].item(),
        "factor_allocation_gate": currfactors[2].item(),
        "factor_write_gate": currfactors[3].item(),
        "factor_read_modes": currfactors[4].item(),
        "factor_write_weights": currfactors[5].item(),
        "factor_usage_vector": currfactors[6].item(),
        "factored_loss_base": partiallosses[0],
        "factored_loss_allocation_weight": partiallosses[1],
        "factored_loss_allocation_gate": partiallosses[2],
        "factored_loss_write_gate": partiallosses[3],
        "factored_loss_read_modes": partiallosses[4],
        "factored_loss_write_weights": partiallosses[5],
        "factored_loss_usage_vector": partiallosses[6],
        "factors": currfactors,
        "weighted_loss": Rloss.item(),
        # "loss_mse": lossmse,
        # "loss_ce": lossce,
        # "loss_exp": lossexp,
        # "loss_l1": lossl1,
        }) #append to the datas df



    
      loss_value = loss

      mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }

      last_save_losses.append(loss_value)
      loss = np.mean(last_save_losses)

      comp["chx"] = chx
      comp["mhx"] = mhx
      comp["rv"] = rv
      comp["last_save_losses"] = last_save_losses
      comp["datas"] = datas
      comp["rnn"] = rnn

      if take_checkpoint:
        cur_weights = rnn.state_dict()
        T.save(cur_weights, f'{name}/checkpoint_{lossfni}_{epoch}.pth')
        lastcp = f'{name}/checkpoint_{lossfni}_{epoch}.pth'
        df = pd.DataFrame(datas)
        pickle.dump(df, open(f"{name}/df__{lossfni}{epoch}.pkl", "wb"))


  




100%|██████████| 4/4 [00:01<00:00,  2.04it/s]


Sequence Length:  1  Choosesize:  16
Sequence Length:  2  Choosesize:  64
Sequence Length:  3  Choosesize:  256
Sequence Length:  4  Choosesize:  1024
Possible Training Data:  1360
add_5b6_comp_
Loss Function:  mCELoss
Number of Layers:  2
Number of Read Heads:  1

----------------------------------------
DNC(17, 192, num_layers=2, num_hidden_layers=1, nr_cells=20, read_heads=1, cell_size=1, nonlinearity=celu, independent_linears=True, debug=store_true)
DNC(
  (lstm_layer_0): LSTM(18, 192, batch_first=True)
  (lstm_layer_1): LSTM(193, 192, batch_first=True)
  (rnn_layer_memory_shared): Memory(
    (read_keys_transform): Linear(in_features=192, out_features=1, bias=True)
    (read_strengths_transform): Linear(in_features=192, out_features=1, bias=True)
    (write_key_transform): Linear(in_features=192, out_features=1, bias=True)
    (write_strength_transform): Linear(in_features=192, out_features=1, bias=True)
    (erase_vector_transform): Linear(in_features=192, out_features=1, bias=Tr

  0%|          | 0/1001 [00:00<?, ?it/s]

Epoch 0/1001 | ETA: 2025-02-27 15:56:54


  1%|          | 10/1001 [01:13<1:44:27,  6.32s/it]

Epoch 10/1001 | ETA: 2025-02-27 17:48:44


  1%|          | 11/1001 [01:22<1:55:37,  7.01s/it]

otherkey:  allocation_weights
torch.Size([50, 34, 1, 20])
otherkey:  allocation_gate
torch.Size([50, 34, 1])
otherkey:  write_gate
torch.Size([50, 34, 1])
otherkey:  write_weights
torch.Size([50, 34, 1, 20])
otherkey:  read_modes
torch.Size([50, 34, 1, 3])
otherkey:  read_weights
torch.Size([50, 34, 1, 20])
otherkey:  free_gates
torch.Size([50, 34, 1])
otherkey:  erase_vector
torch.Size([50, 34, 1, 1])
otherkey:  write_vector
torch.Size([50, 34, 1, 1])
otherkey:  usage_vector
torch.Size([50, 34, 20])


  2%|▏         | 20/1001 [02:19<1:46:12,  6.50s/it]

Epoch 20/1001 | ETA: 2025-02-27 17:47:41


  3%|▎         | 30/1001 [03:28<1:42:26,  6.33s/it]

Epoch 30/1001 | ETA: 2025-02-27 17:49:02


  4%|▍         | 40/1001 [04:31<1:38:05,  6.12s/it]

Epoch 40/1001 | ETA: 2025-02-27 17:47:20


  5%|▍         | 50/1001 [05:38<1:42:22,  6.46s/it]

Epoch 50/1001 | ETA: 2025-02-27 17:47:45


  6%|▌         | 60/1001 [06:48<1:38:44,  6.30s/it]

Epoch 60/1001 | ETA: 2025-02-27 17:48:34


  7%|▋         | 70/1001 [07:52<1:34:48,  6.11s/it]

Epoch 70/1001 | ETA: 2025-02-27 17:47:57


  8%|▊         | 80/1001 [08:56<1:33:34,  6.10s/it]

Epoch 80/1001 | ETA: 2025-02-27 17:47:22


  9%|▉         | 90/1001 [10:02<1:36:06,  6.33s/it]

Epoch 90/1001 | ETA: 2025-02-27 17:47:25


 10%|▉         | 100/1001 [11:10<1:40:33,  6.70s/it]

Epoch 100/1001 | ETA: 2025-02-27 17:47:36


 11%|█         | 110/1001 [12:18<1:36:18,  6.49s/it]

Epoch 110/1001 | ETA: 2025-02-27 17:47:53


 12%|█▏        | 120/1001 [13:22<1:29:27,  6.09s/it]

Epoch 120/1001 | ETA: 2025-02-27 17:47:33


 13%|█▎        | 130/1001 [14:25<1:29:07,  6.14s/it]

Epoch 130/1001 | ETA: 2025-02-27 17:47:10


 14%|█▍        | 140/1001 [15:30<1:32:27,  6.44s/it]

Epoch 140/1001 | ETA: 2025-02-27 17:47:00


 15%|█▍        | 150/1001 [16:48<1:41:23,  7.15s/it]

Epoch 150/1001 | ETA: 2025-02-27 17:48:22


 16%|█▌        | 160/1001 [17:59<1:29:22,  6.38s/it]

Epoch 160/1001 | ETA: 2025-02-27 17:48:45


 17%|█▋        | 170/1001 [19:01<1:21:19,  5.87s/it]

Epoch 170/1001 | ETA: 2025-02-27 17:48:16


 18%|█▊        | 180/1001 [20:06<1:21:51,  5.98s/it]

Epoch 180/1001 | ETA: 2025-02-27 17:48:05


 19%|█▉        | 190/1001 [21:15<1:28:00,  6.51s/it]

Epoch 190/1001 | ETA: 2025-02-27 17:48:18


 20%|█▉        | 200/1001 [22:18<1:19:40,  5.97s/it]

Epoch 200/1001 | ETA: 2025-02-27 17:48:00


 21%|██        | 210/1001 [23:25<1:20:46,  6.13s/it]

Epoch 210/1001 | ETA: 2025-02-27 17:48:01


 22%|██▏       | 220/1001 [24:29<1:21:02,  6.23s/it]

Epoch 220/1001 | ETA: 2025-02-27 17:47:50


 23%|██▎       | 230/1001 [25:35<1:22:28,  6.42s/it]

Epoch 230/1001 | ETA: 2025-02-27 17:47:47


 24%|██▍       | 240/1001 [26:44<1:21:59,  6.46s/it]

Epoch 240/1001 | ETA: 2025-02-27 17:47:59


 25%|██▍       | 250/1001 [27:58<1:25:14,  6.81s/it]

Epoch 250/1001 | ETA: 2025-02-27 17:48:29


 26%|██▌       | 260/1001 [29:04<1:16:51,  6.22s/it]

Epoch 260/1001 | ETA: 2025-02-27 17:48:25


 27%|██▋       | 270/1001 [30:09<1:17:15,  6.34s/it]

Epoch 270/1001 | ETA: 2025-02-27 17:48:19


 28%|██▊       | 280/1001 [31:16<1:15:16,  6.26s/it]

Epoch 280/1001 | ETA: 2025-02-27 17:48:18


 29%|██▉       | 290/1001 [32:20<1:12:54,  6.15s/it]

Epoch 290/1001 | ETA: 2025-02-27 17:48:09


 30%|██▉       | 300/1001 [33:25<1:11:11,  6.09s/it]

Epoch 300/1001 | ETA: 2025-02-27 17:48:05


 31%|███       | 310/1001 [34:32<1:12:34,  6.30s/it]

Epoch 310/1001 | ETA: 2025-02-27 17:48:04


 32%|███▏      | 320/1001 [35:37<1:12:41,  6.40s/it]

Epoch 320/1001 | ETA: 2025-02-27 17:47:59


 33%|███▎      | 330/1001 [36:43<1:10:08,  6.27s/it]

Epoch 330/1001 | ETA: 2025-02-27 17:47:58


 34%|███▍      | 340/1001 [37:48<1:09:29,  6.31s/it]

Epoch 340/1001 | ETA: 2025-02-27 17:47:53


 35%|███▍      | 350/1001 [38:52<1:05:24,  6.03s/it]

Epoch 350/1001 | ETA: 2025-02-27 17:47:46


 36%|███▌      | 360/1001 [39:57<1:04:54,  6.08s/it]

Epoch 360/1001 | ETA: 2025-02-27 17:47:42


 37%|███▋      | 370/1001 [41:05<1:07:28,  6.42s/it]

Epoch 370/1001 | ETA: 2025-02-27 17:47:47


 38%|███▊      | 380/1001 [42:10<1:06:49,  6.46s/it]

Epoch 380/1001 | ETA: 2025-02-27 17:47:43


 39%|███▉      | 390/1001 [43:15<1:02:16,  6.12s/it]

Epoch 390/1001 | ETA: 2025-02-27 17:47:38


 40%|███▉      | 400/1001 [44:19<1:03:10,  6.31s/it]

Epoch 400/1001 | ETA: 2025-02-27 17:47:34


 41%|████      | 410/1001 [45:28<1:06:27,  6.75s/it]

Epoch 410/1001 | ETA: 2025-02-27 17:47:39


 42%|████▏     | 420/1001 [46:32<58:40,  6.06s/it]  

Epoch 420/1001 | ETA: 2025-02-27 17:47:33


 43%|████▎     | 430/1001 [47:37<58:49,  6.18s/it]  

Epoch 430/1001 | ETA: 2025-02-27 17:47:30


 44%|████▍     | 440/1001 [48:41<56:57,  6.09s/it]  

Epoch 440/1001 | ETA: 2025-02-27 17:47:25


 45%|████▍     | 450/1001 [49:47<57:09,  6.22s/it]  

Epoch 450/1001 | ETA: 2025-02-27 17:47:25


 46%|████▌     | 460/1001 [50:55<55:44,  6.18s/it]  

Epoch 460/1001 | ETA: 2025-02-27 17:47:29


 47%|████▋     | 470/1001 [52:00<54:48,  6.19s/it]  

Epoch 470/1001 | ETA: 2025-02-27 17:47:26


 48%|████▊     | 480/1001 [53:04<54:46,  6.31s/it]  

Epoch 480/1001 | ETA: 2025-02-27 17:47:21


 49%|████▉     | 490/1001 [54:09<52:27,  6.16s/it]  

Epoch 490/1001 | ETA: 2025-02-27 17:47:18


 50%|████▉     | 500/1001 [55:12<52:19,  6.27s/it]

Epoch 500/1001 | ETA: 2025-02-27 17:47:13


 51%|█████     | 510/1001 [56:18<50:01,  6.11s/it]  

Epoch 510/1001 | ETA: 2025-02-27 17:47:12


 52%|█████▏    | 520/1001 [57:23<50:53,  6.35s/it]

Epoch 520/1001 | ETA: 2025-02-27 17:47:11


 53%|█████▎    | 530/1001 [58:27<47:04,  6.00s/it]

Epoch 530/1001 | ETA: 2025-02-27 17:47:05


 54%|█████▍    | 540/1001 [59:32<48:11,  6.27s/it]

Epoch 540/1001 | ETA: 2025-02-27 17:47:05


 55%|█████▍    | 550/1001 [1:00:42<47:33,  6.33s/it]

Epoch 550/1001 | ETA: 2025-02-27 17:47:12


 56%|█████▌    | 560/1001 [1:01:45<44:15,  6.02s/it]

Epoch 560/1001 | ETA: 2025-02-27 17:47:07


 57%|█████▋    | 570/1001 [1:02:42<39:12,  5.46s/it]

Epoch 570/1001 | ETA: 2025-02-27 17:46:51


 58%|█████▊    | 580/1001 [1:03:38<37:54,  5.40s/it]

Epoch 580/1001 | ETA: 2025-02-27 17:46:33


 59%|█████▉    | 590/1001 [1:04:35<37:41,  5.50s/it]

Epoch 590/1001 | ETA: 2025-02-27 17:46:18


 60%|█████▉    | 600/1001 [1:05:32<36:19,  5.44s/it]

Epoch 600/1001 | ETA: 2025-02-27 17:46:04


 61%|██████    | 610/1001 [1:06:28<34:57,  5.36s/it]

Epoch 610/1001 | ETA: 2025-02-27 17:45:48


 62%|██████▏   | 620/1001 [1:07:24<34:07,  5.37s/it]

Epoch 620/1001 | ETA: 2025-02-27 17:45:33


 63%|██████▎   | 630/1001 [1:08:19<32:51,  5.31s/it]

Epoch 630/1001 | ETA: 2025-02-27 17:45:18


 64%|██████▍   | 640/1001 [1:09:16<32:39,  5.43s/it]

Epoch 640/1001 | ETA: 2025-02-27 17:45:06


 65%|██████▍   | 650/1001 [1:10:13<31:48,  5.44s/it]

Epoch 650/1001 | ETA: 2025-02-27 17:44:52


 66%|██████▌   | 660/1001 [1:11:09<30:29,  5.36s/it]

Epoch 660/1001 | ETA: 2025-02-27 17:44:40


 67%|██████▋   | 670/1001 [1:12:05<29:21,  5.32s/it]

Epoch 670/1001 | ETA: 2025-02-27 17:44:26


 68%|██████▊   | 680/1001 [1:13:01<29:03,  5.43s/it]

Epoch 680/1001 | ETA: 2025-02-27 17:44:15


 69%|██████▉   | 690/1001 [1:13:57<27:44,  5.35s/it]

Epoch 690/1001 | ETA: 2025-02-27 17:44:03


 70%|██████▉   | 700/1001 [1:14:53<26:47,  5.34s/it]

Epoch 700/1001 | ETA: 2025-02-27 17:43:50


 71%|███████   | 710/1001 [1:15:52<27:34,  5.69s/it]

Epoch 710/1001 | ETA: 2025-02-27 17:43:43


 72%|███████▏  | 720/1001 [1:16:49<25:48,  5.51s/it]

Epoch 720/1001 | ETA: 2025-02-27 17:43:33


 73%|███████▎  | 730/1001 [1:17:46<24:31,  5.43s/it]

Epoch 730/1001 | ETA: 2025-02-27 17:43:24


 74%|███████▍  | 740/1001 [1:18:42<23:42,  5.45s/it]

Epoch 740/1001 | ETA: 2025-02-27 17:43:14


 75%|███████▍  | 750/1001 [1:19:39<22:33,  5.39s/it]

Epoch 750/1001 | ETA: 2025-02-27 17:43:04


 76%|███████▌  | 760/1001 [1:20:35<21:53,  5.45s/it]

Epoch 760/1001 | ETA: 2025-02-27 17:42:55


 77%|███████▋  | 770/1001 [1:21:31<20:43,  5.38s/it]

Epoch 770/1001 | ETA: 2025-02-27 17:42:45


 78%|███████▊  | 780/1001 [1:22:28<20:15,  5.50s/it]

Epoch 780/1001 | ETA: 2025-02-27 17:42:37


 79%|███████▉  | 790/1001 [1:23:24<18:51,  5.36s/it]

Epoch 790/1001 | ETA: 2025-02-27 17:42:27


 80%|███████▉  | 800/1001 [1:24:21<18:20,  5.48s/it]

Epoch 800/1001 | ETA: 2025-02-27 17:42:20


 81%|████████  | 810/1001 [1:25:18<18:11,  5.71s/it]

Epoch 810/1001 | ETA: 2025-02-27 17:42:12


 82%|████████▏ | 820/1001 [1:26:15<16:14,  5.38s/it]

Epoch 820/1001 | ETA: 2025-02-27 17:42:05


 83%|████████▎ | 830/1001 [1:27:13<15:33,  5.46s/it]

Epoch 830/1001 | ETA: 2025-02-27 17:41:58


 84%|████████▍ | 840/1001 [1:28:09<14:26,  5.38s/it]

Epoch 840/1001 | ETA: 2025-02-27 17:41:50


 85%|████████▍ | 850/1001 [1:29:06<13:32,  5.38s/it]

Epoch 850/1001 | ETA: 2025-02-27 17:41:42


 86%|████████▌ | 860/1001 [1:30:05<13:05,  5.57s/it]

Epoch 860/1001 | ETA: 2025-02-27 17:41:39


 87%|████████▋ | 870/1001 [1:31:01<11:53,  5.45s/it]

Epoch 870/1001 | ETA: 2025-02-27 17:41:31


 88%|████████▊ | 880/1001 [1:31:59<10:57,  5.44s/it]

Epoch 880/1001 | ETA: 2025-02-27 17:41:25


 89%|████████▉ | 890/1001 [1:32:57<10:09,  5.49s/it]

Epoch 890/1001 | ETA: 2025-02-27 17:41:20


 90%|████████▉ | 900/1001 [1:33:54<09:28,  5.63s/it]

Epoch 900/1001 | ETA: 2025-02-27 17:41:14


 91%|█████████ | 910/1001 [1:34:52<08:17,  5.47s/it]

Epoch 910/1001 | ETA: 2025-02-27 17:41:09


 92%|█████████▏| 920/1001 [1:35:49<07:24,  5.49s/it]

Epoch 920/1001 | ETA: 2025-02-27 17:41:03


 93%|█████████▎| 930/1001 [1:36:46<06:35,  5.58s/it]

Epoch 930/1001 | ETA: 2025-02-27 17:40:57


 94%|█████████▍| 940/1001 [1:37:44<05:43,  5.63s/it]

Epoch 940/1001 | ETA: 2025-02-27 17:40:52


 95%|█████████▍| 950/1001 [1:38:40<04:34,  5.38s/it]

Epoch 950/1001 | ETA: 2025-02-27 17:40:46


 96%|█████████▌| 960/1001 [1:39:38<03:51,  5.64s/it]

Epoch 960/1001 | ETA: 2025-02-27 17:40:42


 97%|█████████▋| 970/1001 [1:40:36<03:00,  5.82s/it]

Epoch 970/1001 | ETA: 2025-02-27 17:40:37


 98%|█████████▊| 980/1001 [1:41:41<02:05,  5.96s/it]

Epoch 980/1001 | ETA: 2025-02-27 17:40:40


 99%|█████████▉| 990/1001 [1:42:37<00:59,  5.40s/it]

Epoch 990/1001 | ETA: 2025-02-27 17:40:33


100%|█████████▉| 1000/1001 [1:43:31<00:05,  5.24s/it]

Epoch 1000/1001 | ETA: 2025-02-27 17:40:26


100%|██████████| 1001/1001 [1:43:39<00:00,  6.21s/it]


add_5b6_comp_
Loss Function:  mCELoss
Number of Layers:  3
Number of Read Heads:  1

----------------------------------------
DNC(17, 192, num_layers=3, num_hidden_layers=1, nr_cells=20, read_heads=1, cell_size=1, nonlinearity=celu, independent_linears=True, debug=store_true)
DNC(
  (lstm_layer_0): LSTM(18, 192, batch_first=True)
  (lstm_layer_1): LSTM(193, 192, batch_first=True)
  (lstm_layer_2): LSTM(193, 192, batch_first=True)
  (rnn_layer_memory_shared): Memory(
    (read_keys_transform): Linear(in_features=192, out_features=1, bias=True)
    (read_strengths_transform): Linear(in_features=192, out_features=1, bias=True)
    (write_key_transform): Linear(in_features=192, out_features=1, bias=True)
    (write_strength_transform): Linear(in_features=192, out_features=1, bias=True)
    (erase_vector_transform): Linear(in_features=192, out_features=1, bias=True)
    (write_vector_transform): Linear(in_features=192, out_features=1, bias=True)
    (free_gates_transform): Linear(in_feature

  0%|          | 0/1001 [00:00<?, ?it/s]

Epoch 0/1001 | ETA: 2025-02-27 17:40:34


  1%|          | 10/1001 [01:21<2:09:37,  7.85s/it]

Epoch 10/1001 | ETA: 2025-02-27 19:43:31


  1%|          | 11/1001 [01:32<2:27:04,  8.91s/it]

otherkey:  allocation_weights
torch.Size([50, 51, 1, 20])
otherkey:  allocation_gate
torch.Size([50, 51, 1])
otherkey:  write_gate
torch.Size([50, 51, 1])
otherkey:  write_weights
torch.Size([50, 51, 1, 20])
otherkey:  read_modes
torch.Size([50, 51, 1, 3])
otherkey:  read_weights
torch.Size([50, 51, 1, 20])
otherkey:  free_gates
torch.Size([50, 51, 1])
otherkey:  erase_vector
torch.Size([50, 51, 1, 1])
otherkey:  write_vector
torch.Size([50, 51, 1, 1])
otherkey:  usage_vector
torch.Size([50, 51, 20])


  2%|▏         | 20/1001 [02:40<2:06:21,  7.73s/it]

Epoch 20/1001 | ETA: 2025-02-27 19:48:27


  3%|▎         | 30/1001 [04:03<2:05:45,  7.77s/it]

Epoch 30/1001 | ETA: 2025-02-27 19:51:22


  4%|▍         | 40/1001 [05:23<2:03:57,  7.74s/it]

Epoch 40/1001 | ETA: 2025-02-27 19:52:17


  5%|▍         | 50/1001 [06:45<2:05:10,  7.90s/it]

Epoch 50/1001 | ETA: 2025-02-27 19:53:09


  6%|▌         | 60/1001 [08:06<2:01:48,  7.77s/it]

Epoch 60/1001 | ETA: 2025-02-27 19:53:43


  7%|▋         | 70/1001 [09:27<1:59:17,  7.69s/it]

Epoch 70/1001 | ETA: 2025-02-27 19:53:49


  8%|▊         | 80/1001 [10:47<2:00:29,  7.85s/it]

Epoch 80/1001 | ETA: 2025-02-27 19:54:00


  9%|▉         | 90/1001 [12:08<1:58:00,  7.77s/it]

Epoch 90/1001 | ETA: 2025-02-27 19:54:09


 10%|▉         | 100/1001 [13:30<1:58:27,  7.89s/it]

Epoch 100/1001 | ETA: 2025-02-27 19:54:30


 11%|█         | 110/1001 [14:52<1:56:56,  7.87s/it]

Epoch 110/1001 | ETA: 2025-02-27 19:54:43


 12%|█▏        | 120/1001 [16:14<1:53:38,  7.74s/it]

Epoch 120/1001 | ETA: 2025-02-27 19:54:54


 13%|█▎        | 130/1001 [17:36<1:55:54,  7.98s/it]

Epoch 130/1001 | ETA: 2025-02-27 19:55:07


 14%|█▍        | 140/1001 [19:00<1:56:50,  8.14s/it]

Epoch 140/1001 | ETA: 2025-02-27 19:55:31


 15%|█▍        | 150/1001 [20:22<1:50:58,  7.82s/it]

Epoch 150/1001 | ETA: 2025-02-27 19:55:36


 16%|█▌        | 160/1001 [21:44<1:52:55,  8.06s/it]

Epoch 160/1001 | ETA: 2025-02-27 19:55:46


 17%|█▋        | 170/1001 [23:07<1:51:24,  8.04s/it]

Epoch 170/1001 | ETA: 2025-02-27 19:55:56


 18%|█▊        | 180/1001 [24:29<1:49:36,  8.01s/it]

Epoch 180/1001 | ETA: 2025-02-27 19:56:02


 19%|█▉        | 190/1001 [25:51<1:46:38,  7.89s/it]

Epoch 190/1001 | ETA: 2025-02-27 19:56:05


 20%|█▉        | 200/1001 [27:13<1:46:24,  7.97s/it]

Epoch 200/1001 | ETA: 2025-02-27 19:56:09


 21%|██        | 210/1001 [28:35<1:40:54,  7.65s/it]

Epoch 210/1001 | ETA: 2025-02-27 19:56:10


 22%|██▏       | 220/1001 [29:56<1:40:56,  7.76s/it]

Epoch 220/1001 | ETA: 2025-02-27 19:56:10


 23%|██▎       | 230/1001 [31:17<1:42:09,  7.95s/it]

Epoch 230/1001 | ETA: 2025-02-27 19:56:12


 24%|██▍       | 240/1001 [32:39<1:37:49,  7.71s/it]

Epoch 240/1001 | ETA: 2025-02-27 19:56:12


 25%|██▍       | 250/1001 [34:02<1:39:44,  7.97s/it]

Epoch 250/1001 | ETA: 2025-02-27 19:56:17


 26%|██▌       | 260/1001 [35:23<1:38:35,  7.98s/it]

Epoch 260/1001 | ETA: 2025-02-27 19:56:19


 27%|██▋       | 270/1001 [36:45<1:36:58,  7.96s/it]

Epoch 270/1001 | ETA: 2025-02-27 19:56:21


 28%|██▊       | 280/1001 [38:07<1:34:50,  7.89s/it]

Epoch 280/1001 | ETA: 2025-02-27 19:56:22


 29%|██▉       | 290/1001 [39:29<1:35:17,  8.04s/it]

Epoch 290/1001 | ETA: 2025-02-27 19:56:25


 30%|██▉       | 300/1001 [40:51<1:31:58,  7.87s/it]

Epoch 300/1001 | ETA: 2025-02-27 19:56:26


 31%|███       | 310/1001 [42:12<1:29:47,  7.80s/it]

Epoch 310/1001 | ETA: 2025-02-27 19:56:26


 32%|███▏      | 320/1001 [43:35<1:29:47,  7.91s/it]

Epoch 320/1001 | ETA: 2025-02-27 19:56:29


 33%|███▎      | 330/1001 [44:56<1:27:11,  7.80s/it]

Epoch 330/1001 | ETA: 2025-02-27 19:56:29


 34%|███▍      | 340/1001 [46:18<1:26:06,  7.82s/it]

Epoch 340/1001 | ETA: 2025-02-27 19:56:30


 35%|███▍      | 350/1001 [47:41<1:25:56,  7.92s/it]

Epoch 350/1001 | ETA: 2025-02-27 19:56:33


 36%|███▌      | 360/1001 [49:02<1:23:10,  7.79s/it]

Epoch 360/1001 | ETA: 2025-02-27 19:56:33


 37%|███▋      | 370/1001 [50:23<1:21:32,  7.75s/it]

Epoch 370/1001 | ETA: 2025-02-27 19:56:31


 38%|███▊      | 380/1001 [51:45<1:22:39,  7.99s/it]

Epoch 380/1001 | ETA: 2025-02-27 19:56:34


 39%|███▉      | 390/1001 [53:06<1:17:53,  7.65s/it]

Epoch 390/1001 | ETA: 2025-02-27 19:56:31


 40%|███▉      | 400/1001 [54:27<1:17:45,  7.76s/it]

Epoch 400/1001 | ETA: 2025-02-27 19:56:30


 41%|████      | 410/1001 [55:50<1:19:37,  8.08s/it]

Epoch 410/1001 | ETA: 2025-02-27 19:56:33


 42%|████▏     | 420/1001 [57:10<1:14:18,  7.67s/it]

Epoch 420/1001 | ETA: 2025-02-27 19:56:30


 43%|████▎     | 430/1001 [58:31<1:15:31,  7.94s/it]

Epoch 430/1001 | ETA: 2025-02-27 19:56:29


 44%|████▍     | 440/1001 [59:53<1:14:07,  7.93s/it]

Epoch 440/1001 | ETA: 2025-02-27 19:56:30


 45%|████▍     | 450/1001 [1:01:14<1:11:16,  7.76s/it]

Epoch 450/1001 | ETA: 2025-02-27 19:56:29


 46%|████▌     | 460/1001 [1:02:37<1:11:24,  7.92s/it]

Epoch 460/1001 | ETA: 2025-02-27 19:56:32


 47%|████▋     | 470/1001 [1:03:59<1:10:44,  7.99s/it]

Epoch 470/1001 | ETA: 2025-02-27 19:56:34


 48%|████▊     | 480/1001 [1:05:21<1:07:58,  7.83s/it]

Epoch 480/1001 | ETA: 2025-02-27 19:56:35


 49%|████▉     | 490/1001 [1:06:44<1:07:55,  7.98s/it]

Epoch 490/1001 | ETA: 2025-02-27 19:56:38


 50%|████▉     | 500/1001 [1:08:08<1:06:47,  8.00s/it]

Epoch 500/1001 | ETA: 2025-02-27 19:56:42


 51%|█████     | 510/1001 [1:09:30<1:05:17,  7.98s/it]

Epoch 510/1001 | ETA: 2025-02-27 19:56:43


 52%|█████▏    | 520/1001 [1:10:51<1:03:27,  7.92s/it]

Epoch 520/1001 | ETA: 2025-02-27 19:56:43


 53%|█████▎    | 530/1001 [1:12:14<1:02:33,  7.97s/it]

Epoch 530/1001 | ETA: 2025-02-27 19:56:46


 54%|█████▍    | 540/1001 [1:13:38<1:00:55,  7.93s/it]

Epoch 540/1001 | ETA: 2025-02-27 19:56:49


 55%|█████▍    | 550/1001 [1:15:00<58:57,  7.84s/it]  

Epoch 550/1001 | ETA: 2025-02-27 19:56:51


 56%|█████▌    | 560/1001 [1:16:23<57:43,  7.85s/it]  

Epoch 560/1001 | ETA: 2025-02-27 19:56:52


 57%|█████▋    | 570/1001 [1:17:46<57:42,  8.03s/it]  

Epoch 570/1001 | ETA: 2025-02-27 19:56:55


 58%|█████▊    | 580/1001 [1:19:09<55:39,  7.93s/it]  

Epoch 580/1001 | ETA: 2025-02-27 19:56:56


 59%|█████▉    | 590/1001 [1:20:31<53:43,  7.84s/it]  

Epoch 590/1001 | ETA: 2025-02-27 19:56:56


 60%|█████▉    | 600/1001 [1:21:53<53:17,  7.97s/it]  

Epoch 600/1001 | ETA: 2025-02-27 19:56:57


 61%|██████    | 610/1001 [1:23:15<51:18,  7.87s/it]  

Epoch 610/1001 | ETA: 2025-02-27 19:56:58


 62%|██████▏   | 620/1001 [1:24:38<50:30,  7.96s/it]

Epoch 620/1001 | ETA: 2025-02-27 19:57:00


 63%|██████▎   | 630/1001 [1:26:00<48:29,  7.84s/it]

Epoch 630/1001 | ETA: 2025-02-27 19:57:00


 64%|██████▍   | 640/1001 [1:27:22<47:14,  7.85s/it]

Epoch 640/1001 | ETA: 2025-02-27 19:57:00


 65%|██████▍   | 650/1001 [1:28:44<45:38,  7.80s/it]

Epoch 650/1001 | ETA: 2025-02-27 19:57:01


 66%|██████▌   | 660/1001 [1:30:07<45:08,  7.94s/it]

Epoch 660/1001 | ETA: 2025-02-27 19:57:03


 67%|██████▋   | 670/1001 [1:31:29<44:13,  8.02s/it]

Epoch 670/1001 | ETA: 2025-02-27 19:57:04


 68%|██████▊   | 680/1001 [1:32:52<42:13,  7.89s/it]

Epoch 680/1001 | ETA: 2025-02-27 19:57:04


 69%|██████▉   | 690/1001 [1:34:13<41:00,  7.91s/it]

Epoch 690/1001 | ETA: 2025-02-27 19:57:04


 70%|██████▉   | 700/1001 [1:35:35<39:11,  7.81s/it]

Epoch 700/1001 | ETA: 2025-02-27 19:57:05


 71%|███████   | 710/1001 [1:36:58<37:54,  7.81s/it]

Epoch 710/1001 | ETA: 2025-02-27 19:57:05


 72%|███████▏  | 720/1001 [1:38:19<36:20,  7.76s/it]

Epoch 720/1001 | ETA: 2025-02-27 19:57:04


 73%|███████▎  | 730/1001 [1:39:41<35:30,  7.86s/it]

Epoch 730/1001 | ETA: 2025-02-27 19:57:04


 74%|███████▍  | 740/1001 [1:41:02<33:52,  7.79s/it]

Epoch 740/1001 | ETA: 2025-02-27 19:57:04


 75%|███████▍  | 750/1001 [1:42:24<32:40,  7.81s/it]

Epoch 750/1001 | ETA: 2025-02-27 19:57:04


 76%|███████▌  | 760/1001 [1:43:47<31:45,  7.91s/it]

Epoch 760/1001 | ETA: 2025-02-27 19:57:06


 77%|███████▋  | 770/1001 [1:45:10<30:22,  7.89s/it]

Epoch 770/1001 | ETA: 2025-02-27 19:57:07


 78%|███████▊  | 780/1001 [1:46:33<28:48,  7.82s/it]

Epoch 780/1001 | ETA: 2025-02-27 19:57:08


 79%|███████▉  | 790/1001 [1:47:55<27:47,  7.90s/it]

Epoch 790/1001 | ETA: 2025-02-27 19:57:09


 80%|███████▉  | 800/1001 [1:49:18<26:21,  7.87s/it]

Epoch 800/1001 | ETA: 2025-02-27 19:57:09


 81%|████████  | 810/1001 [1:50:40<25:19,  7.96s/it]

Epoch 810/1001 | ETA: 2025-02-27 19:57:10


 82%|████████▏ | 820/1001 [1:52:03<24:10,  8.01s/it]

Epoch 820/1001 | ETA: 2025-02-27 19:57:11


 83%|████████▎ | 830/1001 [1:53:27<22:23,  7.85s/it]

Epoch 830/1001 | ETA: 2025-02-27 19:57:14


 84%|████████▍ | 840/1001 [1:54:49<20:59,  7.83s/it]

Epoch 840/1001 | ETA: 2025-02-27 19:57:14


 85%|████████▍ | 850/1001 [1:56:10<19:23,  7.71s/it]

Epoch 850/1001 | ETA: 2025-02-27 19:57:13


 86%|████████▌ | 860/1001 [1:57:33<18:24,  7.83s/it]

Epoch 860/1001 | ETA: 2025-02-27 19:57:15


 87%|████████▋ | 870/1001 [1:58:55<17:11,  7.88s/it]

Epoch 870/1001 | ETA: 2025-02-27 19:57:15


 88%|████████▊ | 880/1001 [2:00:17<16:07,  7.99s/it]

Epoch 880/1001 | ETA: 2025-02-27 19:57:15


 89%|████████▉ | 890/1001 [2:01:38<14:15,  7.70s/it]

Epoch 890/1001 | ETA: 2025-02-27 19:57:14


 90%|████████▉ | 900/1001 [2:03:01<13:13,  7.85s/it]

Epoch 900/1001 | ETA: 2025-02-27 19:57:14


 91%|█████████ | 910/1001 [2:04:23<11:54,  7.86s/it]

Epoch 910/1001 | ETA: 2025-02-27 19:57:14


 92%|█████████▏| 920/1001 [2:05:45<10:58,  8.13s/it]

Epoch 920/1001 | ETA: 2025-02-27 19:57:15


 93%|█████████▎| 930/1001 [2:07:07<09:23,  7.94s/it]

Epoch 930/1001 | ETA: 2025-02-27 19:57:15


 94%|█████████▍| 940/1001 [2:08:28<07:56,  7.81s/it]

Epoch 940/1001 | ETA: 2025-02-27 19:57:14


 95%|█████████▍| 950/1001 [2:09:50<06:39,  7.83s/it]

Epoch 950/1001 | ETA: 2025-02-27 19:57:14


 96%|█████████▌| 960/1001 [2:11:12<05:20,  7.82s/it]

Epoch 960/1001 | ETA: 2025-02-27 19:57:14


 97%|█████████▋| 970/1001 [2:12:34<04:04,  7.90s/it]

Epoch 970/1001 | ETA: 2025-02-27 19:57:14


 98%|█████████▊| 980/1001 [2:13:55<02:43,  7.78s/it]

Epoch 980/1001 | ETA: 2025-02-27 19:57:13


 99%|█████████▉| 990/1001 [2:15:16<01:25,  7.75s/it]

Epoch 990/1001 | ETA: 2025-02-27 19:57:12


100%|█████████▉| 1000/1001 [2:16:37<00:07,  7.76s/it]

Epoch 1000/1001 | ETA: 2025-02-27 19:57:11


100%|██████████| 1001/1001 [2:16:49<00:00,  8.20s/it]


add_5b6_comp_
Loss Function:  mCELoss
Number of Layers:  5
Number of Read Heads:  1

----------------------------------------
DNC(17, 192, num_layers=5, num_hidden_layers=1, nr_cells=20, read_heads=1, cell_size=1, nonlinearity=celu, independent_linears=True, debug=store_true)
DNC(
  (lstm_layer_0): LSTM(18, 192, batch_first=True)
  (lstm_layer_1): LSTM(193, 192, batch_first=True)
  (lstm_layer_2): LSTM(193, 192, batch_first=True)
  (lstm_layer_3): LSTM(193, 192, batch_first=True)
  (lstm_layer_4): LSTM(193, 192, batch_first=True)
  (rnn_layer_memory_shared): Memory(
    (read_keys_transform): Linear(in_features=192, out_features=1, bias=True)
    (read_strengths_transform): Linear(in_features=192, out_features=1, bias=True)
    (write_key_transform): Linear(in_features=192, out_features=1, bias=True)
    (write_strength_transform): Linear(in_features=192, out_features=1, bias=True)
    (erase_vector_transform): Linear(in_features=192, out_features=1, bias=True)
    (write_vector_transf

  0%|          | 0/1001 [00:00<?, ?it/s]

Epoch 0/1001 | ETA: 2025-02-27 19:57:23


  1%|          | 10/1001 [02:13<3:33:09, 12.91s/it]

Epoch 10/1001 | ETA: 2025-02-27 23:20:06


  1%|          | 11/1001 [02:32<4:03:17, 14.75s/it]

otherkey:  allocation_weights
torch.Size([50, 85, 1, 20])
otherkey:  allocation_gate
torch.Size([50, 85, 1])
otherkey:  write_gate
torch.Size([50, 85, 1])
otherkey:  write_weights
torch.Size([50, 85, 1, 20])
otherkey:  read_modes
torch.Size([50, 85, 1, 3])
otherkey:  read_weights
torch.Size([50, 85, 1, 20])
otherkey:  free_gates
torch.Size([50, 85, 1])
otherkey:  erase_vector
torch.Size([50, 85, 1, 1])
otherkey:  write_vector
torch.Size([50, 85, 1, 1])
otherkey:  usage_vector
torch.Size([50, 85, 20])


  2%|▏         | 20/1001 [04:29<3:31:28, 12.93s/it]

Epoch 20/1001 | ETA: 2025-02-27 23:31:11


  3%|▎         | 30/1001 [06:42<3:27:51, 12.84s/it]

Epoch 30/1001 | ETA: 2025-02-27 23:34:15


  4%|▍         | 40/1001 [08:57<3:26:14, 12.88s/it]

Epoch 40/1001 | ETA: 2025-02-27 23:36:08


  5%|▍         | 50/1001 [11:12<3:23:46, 12.86s/it]

Epoch 50/1001 | ETA: 2025-02-27 23:37:23


  6%|▌         | 60/1001 [13:29<3:22:32, 12.91s/it]

Epoch 60/1001 | ETA: 2025-02-27 23:38:50


  7%|▋         | 70/1001 [15:45<3:22:29, 13.05s/it]

Epoch 70/1001 | ETA: 2025-02-27 23:39:29


  8%|▊         | 80/1001 [18:02<3:20:21, 13.05s/it]

Epoch 80/1001 | ETA: 2025-02-27 23:40:24


  9%|▉         | 90/1001 [20:19<3:19:50, 13.16s/it]

Epoch 90/1001 | ETA: 2025-02-27 23:41:00


 10%|▉         | 100/1001 [22:36<3:15:20, 13.01s/it]

Epoch 100/1001 | ETA: 2025-02-27 23:41:32


 11%|█         | 110/1001 [24:54<3:15:52, 13.19s/it]

Epoch 110/1001 | ETA: 2025-02-27 23:42:03


 12%|█▏        | 120/1001 [27:10<3:12:47, 13.13s/it]

Epoch 120/1001 | ETA: 2025-02-27 23:42:16


 13%|█▎        | 130/1001 [29:27<3:09:57, 13.09s/it]

Epoch 130/1001 | ETA: 2025-02-27 23:42:31


 14%|█▍        | 140/1001 [31:42<3:05:28, 12.92s/it]

Epoch 140/1001 | ETA: 2025-02-27 23:42:31


 15%|█▍        | 150/1001 [33:59<3:03:19, 12.93s/it]

Epoch 150/1001 | ETA: 2025-02-27 23:42:43


 16%|█▌        | 160/1001 [36:14<3:02:03, 12.99s/it]

Epoch 160/1001 | ETA: 2025-02-27 23:42:43


 17%|█▋        | 170/1001 [38:27<2:55:46, 12.69s/it]

Epoch 170/1001 | ETA: 2025-02-27 23:42:32


 18%|█▊        | 180/1001 [40:42<2:57:04, 12.94s/it]

Epoch 180/1001 | ETA: 2025-02-27 23:42:30


 19%|█▉        | 190/1001 [42:56<2:54:30, 12.91s/it]

Epoch 190/1001 | ETA: 2025-02-27 23:42:26


 20%|█▉        | 200/1001 [45:11<2:54:11, 13.05s/it]

Epoch 200/1001 | ETA: 2025-02-27 23:42:28


 21%|██        | 210/1001 [47:25<2:48:49, 12.81s/it]

Epoch 210/1001 | ETA: 2025-02-27 23:42:23


 22%|██▏       | 220/1001 [49:43<2:55:18, 13.47s/it]

Epoch 220/1001 | ETA: 2025-02-27 23:42:38


 23%|██▎       | 230/1001 [51:58<2:47:11, 13.01s/it]

Epoch 230/1001 | ETA: 2025-02-27 23:42:38


 24%|██▍       | 240/1001 [54:13<2:42:52, 12.84s/it]

Epoch 240/1001 | ETA: 2025-02-27 23:42:38


 25%|██▍       | 250/1001 [56:29<2:44:04, 13.11s/it]

Epoch 250/1001 | ETA: 2025-02-27 23:42:43


 26%|██▌       | 260/1001 [58:45<2:39:43, 12.93s/it]

Epoch 260/1001 | ETA: 2025-02-27 23:42:44


 27%|██▋       | 270/1001 [1:01:02<2:40:23, 13.17s/it]

Epoch 270/1001 | ETA: 2025-02-27 23:42:50


 28%|██▊       | 280/1001 [1:03:16<2:36:13, 13.00s/it]

Epoch 280/1001 | ETA: 2025-02-27 23:42:47


 29%|██▉       | 290/1001 [1:05:32<2:34:03, 13.00s/it]

Epoch 290/1001 | ETA: 2025-02-27 23:42:52


 30%|██▉       | 300/1001 [1:07:48<2:31:02, 12.93s/it]

Epoch 300/1001 | ETA: 2025-02-27 23:42:52


 31%|███       | 310/1001 [1:10:05<2:33:07, 13.30s/it]

Epoch 310/1001 | ETA: 2025-02-27 23:42:59


 32%|███▏      | 320/1001 [1:12:21<2:32:16, 13.42s/it]

Epoch 320/1001 | ETA: 2025-02-27 23:43:03


 33%|███▎      | 330/1001 [1:14:37<2:25:07, 12.98s/it]

Epoch 330/1001 | ETA: 2025-02-27 23:43:02


 34%|███▍      | 340/1001 [1:16:54<2:25:35, 13.22s/it]

Epoch 340/1001 | ETA: 2025-02-27 23:43:08


 35%|███▍      | 350/1001 [1:19:09<2:19:08, 12.82s/it]

Epoch 350/1001 | ETA: 2025-02-27 23:43:07


 36%|███▌      | 360/1001 [1:21:25<2:18:04, 12.92s/it]

Epoch 360/1001 | ETA: 2025-02-27 23:43:11


 37%|███▋      | 370/1001 [1:23:39<2:15:24, 12.88s/it]

Epoch 370/1001 | ETA: 2025-02-27 23:43:07


 38%|███▊      | 380/1001 [1:25:54<2:13:04, 12.86s/it]

Epoch 380/1001 | ETA: 2025-02-27 23:43:05


 39%|███▉      | 390/1001 [1:28:08<2:11:10, 12.88s/it]

Epoch 390/1001 | ETA: 2025-02-27 23:43:02


 40%|███▉      | 400/1001 [1:30:24<2:09:59, 12.98s/it]

Epoch 400/1001 | ETA: 2025-02-27 23:43:04


 41%|████      | 410/1001 [1:32:41<2:09:14, 13.12s/it]

Epoch 410/1001 | ETA: 2025-02-27 23:43:09


 42%|████▏     | 420/1001 [1:34:55<2:06:38, 13.08s/it]

Epoch 420/1001 | ETA: 2025-02-27 23:43:05


 43%|████▎     | 430/1001 [1:37:12<2:05:09, 13.15s/it]

Epoch 430/1001 | ETA: 2025-02-27 23:43:09


 44%|████▍     | 440/1001 [1:39:26<2:00:36, 12.90s/it]

Epoch 440/1001 | ETA: 2025-02-27 23:43:05


 45%|████▍     | 450/1001 [1:41:41<1:58:18, 12.88s/it]

Epoch 450/1001 | ETA: 2025-02-27 23:43:06


 46%|████▌     | 460/1001 [1:43:56<1:55:29, 12.81s/it]

Epoch 460/1001 | ETA: 2025-02-27 23:43:05


 47%|████▋     | 470/1001 [1:46:10<1:53:02, 12.77s/it]

Epoch 470/1001 | ETA: 2025-02-27 23:43:03


 48%|████▊     | 480/1001 [1:48:24<1:53:01, 13.02s/it]

Epoch 480/1001 | ETA: 2025-02-27 23:42:59


 49%|████▉     | 490/1001 [1:50:38<1:48:44, 12.77s/it]

Epoch 490/1001 | ETA: 2025-02-27 23:42:57


 50%|████▉     | 500/1001 [1:52:53<1:48:02, 12.94s/it]

Epoch 500/1001 | ETA: 2025-02-27 23:42:56


 51%|█████     | 510/1001 [1:55:08<1:46:11, 12.98s/it]

Epoch 510/1001 | ETA: 2025-02-27 23:42:56


 52%|█████▏    | 520/1001 [1:57:23<1:44:48, 13.07s/it]

Epoch 520/1001 | ETA: 2025-02-27 23:42:56


 53%|█████▎    | 530/1001 [1:59:39<1:41:13, 12.89s/it]

Epoch 530/1001 | ETA: 2025-02-27 23:42:58


 54%|█████▍    | 540/1001 [2:01:53<1:38:39, 12.84s/it]

Epoch 540/1001 | ETA: 2025-02-27 23:42:55


 55%|█████▍    | 550/1001 [2:04:07<1:36:31, 12.84s/it]

Epoch 550/1001 | ETA: 2025-02-27 23:42:53


 56%|█████▌    | 560/1001 [2:06:24<1:36:07, 13.08s/it]

Epoch 560/1001 | ETA: 2025-02-27 23:42:56


 57%|█████▋    | 570/1001 [2:08:40<1:34:42, 13.18s/it]

Epoch 570/1001 | ETA: 2025-02-27 23:42:58


 58%|█████▊    | 580/1001 [2:10:54<1:30:34, 12.91s/it]

Epoch 580/1001 | ETA: 2025-02-27 23:42:56


 59%|█████▉    | 590/1001 [2:13:10<1:29:10, 13.02s/it]

Epoch 590/1001 | ETA: 2025-02-27 23:42:57


 60%|█████▉    | 600/1001 [2:15:26<1:27:56, 13.16s/it]

Epoch 600/1001 | ETA: 2025-02-27 23:42:58


 61%|██████    | 610/1001 [2:17:43<1:24:44, 13.00s/it]

Epoch 610/1001 | ETA: 2025-02-27 23:43:00


 62%|██████▏   | 620/1001 [2:19:56<1:21:17, 12.80s/it]

Epoch 620/1001 | ETA: 2025-02-27 23:42:58


 63%|██████▎   | 630/1001 [2:22:11<1:19:24, 12.84s/it]

Epoch 630/1001 | ETA: 2025-02-27 23:42:58


 64%|██████▍   | 640/1001 [2:24:25<1:18:01, 12.97s/it]

Epoch 640/1001 | ETA: 2025-02-27 23:42:56


 65%|██████▍   | 650/1001 [2:26:39<1:15:20, 12.88s/it]

Epoch 650/1001 | ETA: 2025-02-27 23:42:54


 66%|██████▌   | 660/1001 [2:28:55<1:14:48, 13.16s/it]

Epoch 660/1001 | ETA: 2025-02-27 23:42:55


 67%|██████▋   | 670/1001 [2:31:11<1:11:30, 12.96s/it]

Epoch 670/1001 | ETA: 2025-02-27 23:42:56


 68%|██████▊   | 680/1001 [2:33:27<1:09:20, 12.96s/it]

Epoch 680/1001 | ETA: 2025-02-27 23:42:57


 69%|██████▉   | 690/1001 [2:35:42<1:06:18, 12.79s/it]

Epoch 690/1001 | ETA: 2025-02-27 23:42:57


 70%|██████▉   | 700/1001 [2:37:59<1:04:17, 12.81s/it]

Epoch 700/1001 | ETA: 2025-02-27 23:42:59


 71%|███████   | 710/1001 [2:40:16<1:03:23, 13.07s/it]

Epoch 710/1001 | ETA: 2025-02-27 23:43:01


 72%|███████▏  | 720/1001 [2:42:32<1:00:32, 12.93s/it]

Epoch 720/1001 | ETA: 2025-02-27 23:43:04


 73%|███████▎  | 730/1001 [2:44:48<59:18, 13.13s/it]  

Epoch 730/1001 | ETA: 2025-02-27 23:43:04


 74%|███████▍  | 740/1001 [2:47:03<56:31, 13.00s/it]  

Epoch 740/1001 | ETA: 2025-02-27 23:43:04


 75%|███████▍  | 750/1001 [2:49:18<54:44, 13.08s/it]  

Epoch 750/1001 | ETA: 2025-02-27 23:43:04


 76%|███████▌  | 760/1001 [2:51:31<50:54, 12.68s/it]  

Epoch 760/1001 | ETA: 2025-02-27 23:43:01


 77%|███████▋  | 770/1001 [2:53:45<49:28, 12.85s/it]

Epoch 770/1001 | ETA: 2025-02-27 23:42:59


 78%|███████▊  | 780/1001 [2:56:00<47:20, 12.85s/it]

Epoch 780/1001 | ETA: 2025-02-27 23:42:59


 79%|███████▉  | 790/1001 [2:58:16<45:21, 12.90s/it]

Epoch 790/1001 | ETA: 2025-02-27 23:42:59


 80%|███████▉  | 800/1001 [3:00:31<43:08, 12.88s/it]

Epoch 800/1001 | ETA: 2025-02-27 23:42:59


 81%|████████  | 810/1001 [3:02:47<40:51, 12.84s/it]

Epoch 810/1001 | ETA: 2025-02-27 23:43:00


 82%|████████▏ | 820/1001 [3:05:02<38:49, 12.87s/it]

Epoch 820/1001 | ETA: 2025-02-27 23:43:00


 83%|████████▎ | 830/1001 [3:07:18<37:03, 13.00s/it]

Epoch 830/1001 | ETA: 2025-02-27 23:43:00


 84%|████████▍ | 840/1001 [3:09:34<35:01, 13.05s/it]

Epoch 840/1001 | ETA: 2025-02-27 23:43:01


 85%|████████▍ | 850/1001 [3:11:48<32:29, 12.91s/it]

Epoch 850/1001 | ETA: 2025-02-27 23:43:00


 86%|████████▌ | 860/1001 [3:14:04<30:24, 12.94s/it]

Epoch 860/1001 | ETA: 2025-02-27 23:43:01


 87%|████████▋ | 870/1001 [3:16:18<28:10, 12.90s/it]

Epoch 870/1001 | ETA: 2025-02-27 23:43:00


 88%|████████▊ | 880/1001 [3:18:35<25:56, 12.87s/it]

Epoch 880/1001 | ETA: 2025-02-27 23:43:01


 89%|████████▉ | 890/1001 [3:20:51<23:56, 12.94s/it]

Epoch 890/1001 | ETA: 2025-02-27 23:43:02


 90%|████████▉ | 900/1001 [3:23:06<21:39, 12.86s/it]

Epoch 900/1001 | ETA: 2025-02-27 23:43:03


 91%|█████████ | 910/1001 [3:25:21<19:39, 12.96s/it]

Epoch 910/1001 | ETA: 2025-02-27 23:43:02


 92%|█████████▏| 920/1001 [3:27:37<17:25, 12.91s/it]

Epoch 920/1001 | ETA: 2025-02-27 23:43:03


 93%|█████████▎| 930/1001 [3:29:52<15:13, 12.87s/it]

Epoch 930/1001 | ETA: 2025-02-27 23:43:03


 94%|█████████▍| 940/1001 [3:32:06<13:03, 12.84s/it]

Epoch 940/1001 | ETA: 2025-02-27 23:43:01


 95%|█████████▍| 950/1001 [3:34:20<10:48, 12.72s/it]

Epoch 950/1001 | ETA: 2025-02-27 23:42:59


 96%|█████████▌| 960/1001 [3:36:35<09:00, 13.19s/it]

Epoch 960/1001 | ETA: 2025-02-27 23:42:59


 97%|█████████▋| 970/1001 [3:38:49<06:41, 12.94s/it]

Epoch 970/1001 | ETA: 2025-02-27 23:42:59


 98%|█████████▊| 980/1001 [3:41:03<04:31, 12.91s/it]

Epoch 980/1001 | ETA: 2025-02-27 23:42:57


 99%|█████████▉| 990/1001 [3:43:16<02:21, 12.86s/it]

Epoch 990/1001 | ETA: 2025-02-27 23:42:55


100%|█████████▉| 1000/1001 [3:45:31<00:12, 12.89s/it]

Epoch 1000/1001 | ETA: 2025-02-27 23:42:55


100%|██████████| 1001/1001 [3:45:50<00:00, 13.54s/it]


In [None]:
from itertools import product
lookat = "read_heads" # "layers" or "read_heads"
if lookat == "losses":
  colordict = {mse.__name__: '#007f4e', exp_loss.__name__: '#e12729', mCELoss.__name__: '#194a7a', "total": "black", L1loss.__name__: "#f37324"}  
  colordict2 = {mse.__name__: '#72b043', exp_loss.__name__: '#f37324', mCELoss.__name__: '#476f95', "total": "black", L1loss.__name__: "#f37324"}
else:
  colordict = {"2": '#007f4e', "3": '#e12729', "5": '#194a7a', "total": "black", "1": "#f37324"}
  colordict2 = {"2": '#72b043', "3": '#f37324', "5": '#476f95', "total": "black", "1": "#f37324"}

#lossfn = [mCELoss, mse, exp_loss, L1loss]
logy = [False, True]
for comp, l in product(comp_obj, logy):
      if comp["lossfn"].__name__ == L1loss.__name__:
         continue
      lossfni = comp["i"]
      datas = comp["datas"]
      mnum_layers = comp["mnum_layers"]
      mread_heads = comp["mread_heads"]

      if lookat == "losses":
        myname = comp["lossfn"].__name__
        postfix = ""
      elif lookat == "layers":
        myname = str(mnum_layers)
        postfix = " controller network layers"
      else:
        myname = str(mread_heads)
        postfix = " read heads"
      

      df = pd.DataFrame(datas) # plot loss 
      pickle.dump(df, open(f"{name}/df_lf{lossfni}_total.pkl", "wb"))

      fig = go.Figure()
      fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_base"], mode='lines', name='Train Data', line=dict(color=colordict[myname])))
      # if comp["lossfn"].__name__ != mse.__name__:
      #   fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_mse"], mode='lines', name='MSE Loss (train Data)', line=dict(color=colordict[mse.__name__])))
      # if comp["lossfn"].__name__ != mCELoss.__name__:
      #   fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_ce"], mode='lines', name='CrossEntropy Loss (train Data)', line=dict(color=colordict[mCELoss.__name__])))
      # if comp["lossfn"].__name__ != exp_loss.__name__:
      #   fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_exp"], mode='lines', name='Exp Loss (train Data)', line=dict(color=colordict[exp_loss.__name__])))

      fig.add_trace(go.Scatter(x=df["epoch"], y=df["testloss"], mode='lines', name='Test Data', line=dict(color="black")))
      
      if l:
        fig.update_yaxes(type="log")
      fig.update_layout(title=f'Losses trained with {myname} {postfix}', xaxis_title='Epoch', yaxis_title='Loss')
      fig.show()
      fig.write_image(f"{name}/_losses_lf{lossfni}_{myname}_{str(l)}.png")

      fig = go.Figure()
      fig.add_trace(go.Scatter(x=df["epoch"], y=df["accuracy"], mode='lines', name='Train Data', line=dict(color=colordict[myname])))
      fig.add_trace(go.Scatter(x=df["epoch"], y=df["testaccuracy"], mode='lines', name='Test Data', line=dict(color="black")))
      fig.update_layout(title=f'Accuracy trained with {myname} {postfix}', xaxis_title='Epoch', yaxis_title='Accuracy')
      fig.show()
      fig.write_image(f"{name}/_accuracy_lf{lossfni}_{myname}_{str(l)}.png")

fig = go.Figure()
fig2 = go.Figure()
for comp in comp_obj:
  lossfni = comp["i"]
  datas = comp["datas"]
  
  if lookat == "losses":
    myname = comp["lossfn"].__name__
    postfix = " different loss functions"
  elif lookat == "layers":
    myname = str(comp["mnum_layers"])
    postfix = " different numbers of layers"
  else:
    myname = str(comp["mread_heads"])
    postfix = " different numbers of read heads"
  df = pd.DataFrame(datas)
  fig.add_trace(go.Scatter(x=df["epoch"], y=df["accuracy"], mode='lines', name=f'{myname}', line=dict(color=colordict[myname])))
  fig2.add_trace(go.Scatter(x=df["epoch"], y=df["accuracy"], mode='lines', name=f'{myname}', line=dict(color=colordict[myname])))
  fig2.add_trace(go.Scatter(x=df["epoch"], y=df["testaccuracy"], mode='lines', name=f'{myname}', line=dict(color=colordict2[myname])))
fig.update_layout(title=f'Accuracy trained with {postfix}', xaxis_title='Epoch', yaxis_title='Accuracy')
fig.show()
fig.write_image(f"{name}/_accuracy_all.png")
fig2.update_layout(title=f'Accuracy trained with {postfix}', xaxis_title='Epoch', yaxis_title='Accuracy')
fig2.show()
fig2.write_image(f"{name}/_accuracy_all_test.png")

    

In [13]:
from itertools import product

colordict = {"1": '#007f4e', "2": '#e12729', "3": '#194a7a', "total": "black", "4": "#f37324"}
colordict2 = {"1": '#72b043', "2": '#f37324', "3": '#476f95', "total": "black", "4": "#f37324"}


#lossfn = [1, 3, 5]
logy = [False, True]
for comp, l in product(comp_obj, logy):
      if comp["lossfn"].__name__ == L1loss.__name__:
         continue
      lossfni = comp["i"]
      datas = comp["datas"]
      read_heads = comp["mread_heads"]
      df = pd.DataFrame(datas) # plot loss 
      pickle.dump(df, open(f"{name}/df_rh{read_heads}_total.pkl", "wb"))

      fig = go.Figure()
      fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_base"], mode='lines', name='Train Data', line=dict(color=colordict[comp["lossfn"].__name__])))
      if comp["lossfn"].__name__ != mse.__name__:
        fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_mse"], mode='lines', name='MSE Loss (train Data)', line=dict(color=colordict[mse.__name__])))
      if comp["lossfn"].__name__ != mCELoss.__name__:
        fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_ce"], mode='lines', name='CrossEntropy Loss (train Data)', line=dict(color=colordict[mCELoss.__name__])))
      if comp["lossfn"].__name__ != exp_loss.__name__:
        fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_exp"], mode='lines', name='Exp Loss (train Data)', line=dict(color=colordict[exp_loss.__name__])))

      fig.add_trace(go.Scatter(x=df["epoch"], y=df["testloss"], mode='lines', name='Test Data', line=dict(color="black")))
      
      if l:
        fig.update_yaxes(type="log")
      fig.update_layout(title=f'Losses trained with {comp["lossfn"].__name__}', xaxis_title='Epoch', yaxis_title='Loss')
      fig.show()
      fig.write_image(f"{name}/losses_lf{lossfni}_{str(l)}.png")

      fig = go.Figure()
      fig.add_trace(go.Scatter(x=df["epoch"], y=df["accuracy"], mode='lines', name='Train Data', line=dict(color=colordict[comp["lossfn"].__name__])))
      fig.add_trace(go.Scatter(x=df["epoch"], y=df["testaccuracy"], mode='lines', name='Test Data', line=dict(color="black")))
      fig.update_layout(title=f'Accuracy trained with {comp["lossfn"].__name__}', xaxis_title='Epoch', yaxis_title='Accuracy')
      fig.show()
      fig.write_image(f"{name}/accuracy_lf{lossfni}_{str(l)}.png")

fig = go.Figure()
fig2 = go.Figure()
for comp in comp_obj:
  lossfni = comp["i"]
  datas = comp["datas"]
  df = pd.DataFrame(datas)
  fig.add_trace(go.Scatter(x=df["epoch"], y=df["accuracy"], mode='lines', name=f'{comp["lossfn"].__name__}', line=dict(color=colordict[comp["lossfn"].__name__])))
  fig2.add_trace(go.Scatter(x=df["epoch"], y=df["accuracy"], mode='lines', name=f'{comp["lossfn"].__name__}', line=dict(color=colordict[comp["lossfn"].__name__])))
  fig2.add_trace(go.Scatter(x=df["epoch"], y=df["testaccuracy"], mode='lines', name=f'{comp["lossfn"].__name__}', line=dict(color=colordict2[comp["lossfn"].__name__])))
fig.update_layout(title=f'Accuracy trained with different loss functions', xaxis_title='Epoch', yaxis_title='Accuracy')
fig.show()
fig.write_image(f"{name}/accuracy_all.png")
fig2.update_layout(title=f'Accuracy trained with different loss functions', xaxis_title='Epoch', yaxis_title='Accuracy')
fig2.show()
fig2.write_image(f"{name}/accuracy_all_test.png")

    

KeyError: 'mCELoss'

In [None]:
df = pd.DataFrame(datas) # plot loss 
pickle.dump(df, open(f"{name}/df_total.pkl", "wb"))
print(df.columns)


fig = go.Figure()
fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss"], mode='lines', name='Train Data'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["testloss"], mode='lines', name='Test Data'))
fig.update_layout(title='Losses', xaxis_title='Epoch', yaxis_title='Loss')
fig.show()
fig.write_image(f"{name}/losses.png")

fig = go.Figure()
fig.add_trace(go.Scatter(x=df["epoch"], y=df["accuracy"], mode='lines', name='Train Data'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["testaccuracy"], mode='lines', name='Test Data'))
fig.update_layout(title='Accuracy', xaxis_title='Epoch', yaxis_title='Accuracy')
fig.show()
fig.write_image(f"{name}/accuracy.png")

if "factored_loss_base" in df.columns:
  fig = go.Figure()
  fig.add_trace(go.Scatter(x=df["epoch"], y=df["factored_loss_base"], mode='lines', name='Base Loss'))
  fig.add_trace(go.Scatter(x=df["epoch"], y=df["factored_loss_allocation_weight"], mode='lines', name='Allocation Weight Loss'))
  fig.add_trace(go.Scatter(x=df["epoch"], y=df["factored_loss_allocation_gate"], mode='lines', name='Allocation Gate Loss'))
  fig.add_trace(go.Scatter(x=df["epoch"], y=df["factored_loss_write_gate"], mode='lines', name='Write Gate Loss'))
  fig.add_trace(go.Scatter(x=df["epoch"], y=df["factored_loss_read_modes"], mode='lines', name='Read Modes Loss'))
  fig.add_trace(go.Scatter(x=df["epoch"], y=df["factored_loss_write_weights"], mode='lines', name='Write Weights Loss'))
  fig.add_trace(go.Scatter(x=df["epoch"], y=df["factored_loss_usage_vector"], mode='lines', name='Usage Vector Loss'))
  fig.update_layout(title='Factored Losses', xaxis_title='Epoch', yaxis_title='Loss', yaxis_type="log")
  fig.show()
  fig.write_image(f"{name}/factored_losses.png")

fig = go.Figure()
fig.add_trace(go.Scatter(x=df["epoch"], y=df["weighted_loss"], mode='lines', name='Weighted Loss'))
fig.update_layout(title='Weighted Loss', xaxis_title='Epoch', yaxis_title='Loss', yaxis_type="log")
fig.show()
fig.write_image(f"{name}/weighted_loss.png")

fig = go.Figure()
fig.add_trace(go.Scatter(x=df["epoch"], y=df["sequencelength"], mode='lines', name='Sequence Length'))
fig.update_layout(title='Sequence Length', xaxis_title='Epoch', yaxis_title='Length')
fig.show()
fig.write_image(f"{name}/sequencelength.png")

fig = go.Figure()
fig.add_trace(go.Scatter(x=df["epoch"], y=df["factor_base"], mode='lines', name='Base Factor'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["factor_allocation_weight"], mode='lines', name='Allocation Weight Factor'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["factor_allocation_gate"], mode='lines', name='Allocation Gate Factor'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["factor_write_gate"], mode='lines', name='Write Gate Factor'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["factor_read_modes"], mode='lines', name='Read Modes Factor'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["factor_write_weights"], mode='lines', name='Write Weights Factor'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["factor_usage_vector"], mode='lines', name='Usage Vector Factor'))
fig.update_layout(title='Factors', xaxis_title='Epoch', yaxis_title='Factor')
fig.show()
fig.write_image(f"{name}/factors.png")

fig = go.Figure()
fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_base"], mode='lines', name='Base Loss'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_allocation_weight"], mode='lines', name='Allocation Weight Loss'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_allocation_gate"], mode='lines', name='Allocation Gate Loss'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_write_gate"], mode='lines', name='Write Gate Loss'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_read_modes"], mode='lines', name='Read Modes Loss'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_write_weights"], mode='lines', name='Write Weights Loss'))
fig.add_trace(go.Scatter(x=df["epoch"], y=df["loss_usage_vector"], mode='lines', name='Usage Vector Loss'))
fig.update_layout(title='Losses', xaxis_title='Epoch', yaxis_title='Loss', yaxis_type="log")
fig.show()
fig.write_image(f"{name}/partial_losses.png")

for loss in ["loss_base", "loss_allocation_weight", "loss_allocation_gate", "loss_write_gate", "loss_read_modes", "loss_write_weights", "loss_usage_vector"]:
  fig = go.Figure()
  fig.add_trace(go.Scatter(x=df["epoch"], y=df[loss], mode='lines', name=loss))
  fig.update_layout(title=loss, xaxis_title='Epoch', yaxis_title='Loss', yaxis_type="log")
  fig.show()
  fig.write_image(f"{name}/loss_{loss}.png")
  
  newkey = "factor"+loss[4:]
  if not newkey in df.columns:
    continue
  fig = go.Figure()
  fig.add_trace(go.Scatter(x=df["epoch"], y=df[newkey], mode='lines', name=newkey))
  fig.update_layout(title=newkey, xaxis_title='Epoch', yaxis_title='Factor')
  fig.show()
  fig.write_image(f"{name}/{newkey}.png")

  if not "factored_"+loss  in df.columns:
    continue
  fig = go.Figure()
  fig.add_trace(go.Scatter(x=df["epoch"], y=df["factored_"+loss], mode='lines', name="Factored "+loss))
  fig.update_layout(title="Factored "+loss, xaxis_title='Epoch', yaxis_title='Loss', yaxis_type="log")
  fig.show()
  fig.write_image(f"{name}/factored_{loss}.png")



Index(['epoch', 'loss', 'testloss', 'sequencelength', 'accuracy',
       'testaccuracy', 'loss_base', 'loss_allocation_weight',
       'loss_allocation_gate', 'loss_write_gate', 'loss_read_modes',
       'loss_write_weights', 'loss_usage_vector', 'factor_base',
       'factor_allocation_weight', 'factor_allocation_gate',
       'factor_write_gate', 'factor_read_modes', 'factor_write_weights',
       'factor_usage_vector', 'factored_loss_base',
       'factored_loss_allocation_weight', 'factored_loss_allocation_gate',
       'factored_loss_write_gate', 'factored_loss_read_modes',
       'factored_loss_write_weights', 'factored_loss_usage_vector', 'factors',
       'weighted_loss'],
      dtype='object')


In [None]:
# #from dnc.dnc import DNC

# if 'rnn' in locals() or 'rnn' in globals():
#   del rnn

# rnn = DNC(
#         input_size=input_size,
#         hidden_size=output_size,
#         rnn_type='rnn',
#         #rnn_type='lstm',
#         num_layers=num_layers,
#         num_hidden_layers=1,
#         dropout=0,
#         nr_cells=mem_slot,
#         cell_size=mem_size,
#         read_heads=read_heads,
#         gpu_id=-1,
#         debug='store_true',
#         batch_first=True,
#         independent_linears=True,
#         nonlinearity='tanh',
#     )

# if not 'name' in locals() or not 'name' in globals():
#   name = 'add_9b2'
# if not 'lastcp' in locals() or not 'lastcp' in globals():
#   lastcp = f'{name}/checkpoint_1000.pth'
  
print(name)

with open(f"{name}/output_2.txt", "w") as f:
  batch_size=1
  rnn.load_state_dict(T.load(lastcp, weights_only=True))
  rnn.eval()
  
  stepByStep = copy.deepcopy(STEPBYSTEPOBJ)

  i=0
  llprint("\nIteration %d/%d" % (i, iterations))
  # We test now the learned generalization using sequence_max_length examples
  random_length = np.random.randint(2, sequence_length  + 1)
  input_data, target_output = generate_data(1, random_length, input_size)

  #print (input_data, target_output)

  

  
  input_data = var(T.from_numpy(input_data))
  target_output = var(T.from_numpy(target_output))

  labels = target_output.argmax(dim=2)

  stepByStep["CurrI"] = i
  stepByStep["currentObj"] = copy.deepcopy(stepByStep["defObj"])
  stepByStep["currentObj"]["i"] = i 
  stepByStep["input"] = input_data.detach().numpy()
  stepByStep["target"] = target_output.detach().numpy()
  stepByStep["MEMORYCOLUMNS"] = mem_slot
  stepByStep["INPUTSIZE"] = input_size
  stepByStep["OUTPUTSIZE"] = output_size # dataoutputsize?
  stepByStep["read_heads"] = read_heads

  stepByStep["INTERMEDIATEOUTPUT"] = output_size
  stepByStep["DNCOUPTPUT"] = output_size
    
  if rnn.debug:
    output, (chx, mhx, rv), v = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True, stepByStep=stepByStep)
  else:
    output, (chx, mhx, rv) = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True, stepByStep=stepByStep)

  print("input_data: ", input_data)
  print("output: ", output)
  print("target_output: ", target_output)
  print("labels: ", labels)
  print("accuracy: ", calcAccuracy(output.type(T.float32), target_output))

  stepByStep["output"] = output
  stepByStep["objects"].append(copy.deepcopy(stepByStep["currentObj"]))
  stepByStep['loss'] = str(mse(output.type(T.float32), target_output).item())
  stepByStep['accuracy'] = (output.type(T.float32).argmax(dim=2) == labels).int().to(T.float32).mean().item()
  #output = output[:, -1, :].sum().data.cpu().numpy()
  #target_output = target_output.sum().data.cpu().numpy()
  print("loss", mse(output.type(T.float32), target_output).item())
  print(stepByStep["input"].shape)
  print(stepByStep["output"].shape)
  print(stepByStep["target"].shape)
  #raise Exception("STOP")

  print(stepByStep)

  pickle.dump(stepByStep, open(f"{name}/stepByStep.pkl", "wb"))

  print("\n\n")
  print("Input: ", tensor2string(input_data[0]), file=f)
  print("Output: ", tensor2string(output[0]), file=f)
  print("Target: ", tensor2string(target_output[0]), file=f)
  print("CE Loss: ", str(mse(output[0].to(dtype=T.float32), target_output[0]).item()), file=f)
  print("Log Loss: ", str(criterion(output[0].to(dtype=T.float32), target_output[0]).item()), file=f)
  print("Exp Loss: ", str(exp_loss(output[0].to(dtype=T.float32), target_output[0]).item()), file=f)
  print("\n\n")
  print("CE Loss: ", str(mse(output.to(dtype=T.float32), target_output).item()), file=f)
  print("Log Loss: ", str(criterion(output.to(dtype=T.float32), target_output).item()), file=f)
  print("Exp Loss: ", str(exp_loss(output.to(dtype=T.float32), target_output).item()), file=f)
  print("\n\n")

  try:
    print("\nReal value: ", ' = ' + str(int(target_output[0])))
    print("Predicted:  ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
  except Exception as e:
    pass

  

add_ed2

Iteration 0/2000input_data:  tensor([[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
           0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
         [ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  