In [19]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn.functional as F
from torch import Tensor
from datetime import datetime
import math
import random
from mynn import *
import numpy as NP
import matplotlib.pyplot as plt
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
# Notebook-wide settings; every later cell reads these globals.
initLogging("Common init")

dtype = torch.float          # floating-point dtype handed to makeNetwork2
dvc = torch.device("cpu")    # device used for the data tensors, the generators and the networks
contextSize = 3              # context length passed to buildDataSet and the network builders
newNet = True                # True -> makeNetwork2 code path, False -> makeNetwork + calibrateBatchNorm path

# Echo every setting to the log, in a fixed order.
for _label, _value in (
    ("dtype", dtype),
    ("device", dvc.type),
    ("contextSize", contextSize),
    ("newNet", newNet),
):
    log(_label, _value)


Common init:         -------------------------- 2023-07-13 18:05:48
dtype:               torch.float32
device:              cpu
contextSize:         3
newNet:              True


In [21]:
logSection("Prepare data")

# Source file, read with readFileSplitByLine below.
filePath = "names.txt"
log("filePath", filePath)

# Split boundaries are *cumulative* fractions of the shuffled word list
# (see the "Prepare dataset" cell): words[:trRatio * N] is the training split,
# words[trRatio * N : devRatio * N] the validation split, the rest the test split.
trRatio = 0.8
log("trRatio", trRatio)

devRatio = 0.9
log("devRatio", devRatio)

wordShufflingSeed = 42
log("wordShufflingSeed", wordShufflingSeed)

words = readFileSplitByLine(filePath)
# Seed the global `random` module so the shuffle, and therefore the splits, are reproducible.
random.seed(wordShufflingSeed)
random.shuffle(words)
log("first few words", words[:5])

lenWords = len(words)
log("lenWords", lenWords)

# sorted() accepts any iterable, so no intermediate list() is needed.
allPossibleChars = sorted(set("".join(words)))
log("allPossibleChars", allPossibleChars)

# Character <-> index lookup tables built by the mynn helpers.
stoi = sToI(allPossibleChars)
log("stoi", stoi)

itos = iToS(stoi)
log("itos", itos)

vocabularyLength = len(itos)
log("vocabularyLength", vocabularyLength)

# Baseline for the training loss: the negative log-likelihood of guessing
# uniformly over the vocabulary, -ln(1 / V) = ln(V). This is a loss value, not a
# probability, so it is labelled accordingly.
log("uniform guess loss", f"{math.log(vocabularyLength):.4f}")


Prepare data:        -------------------------- 2023-07-13 18:05:48
filePath:            names.txt
trRatio:             0.8
devRatio:            0.9
wordShufflingSeed:   42
first few words:     ['yuheng', 'diondre', 'xavien', 'jori', 'juanluis']
lenWords:            32033
allPossibleChars:    ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
stoi:                {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0}
itos:                {1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
vocabularyLength:    27
data random probability: 3.2958


In [22]:
logSection("Prepare dataset")

dataDtype = torch.int64
log("data dtype", dataDtype)

# Split boundaries over the shuffled word list:
#   [0, lenTrain) -> training, [lenTrain, endVal) -> validation, [endVal, lenWords) -> test.
lenTrain = int(trRatio * lenWords)
endVal = int(devRatio * lenWords)
lenTest = lenWords - endVal

trWords = words[:lenTrain]
valWords = words[lenTrain:endVal]
tstWords = words[endVal:]

# Build one (X, Y) pair per split; each log line shows the number of words in the
# split, the shapes of X and Y, and the first three words.
trX, trY = buildDataSet(trWords, contextSize, stoi, itos, dataDtype, dvc)
log("data set training", lenTrain, trX.shape, trY.shape, trWords[:3])

valX, valY = buildDataSet(valWords, contextSize, stoi, itos, dataDtype, dvc)
log("data set validation", endVal - lenTrain, valX.shape, valY.shape, valWords[:3])

tstX, tstY = buildDataSet(tstWords, contextSize, stoi, itos, dataDtype, dvc)
log("data set test", lenTest, tstX.shape, tstY.shape, tstWords[:3])

Prepare dataset:     -------------------------- 2023-07-13 18:05:48
data dtype:          torch.int64
data set training:   25626 torch.Size([182625, 3]) torch.Size([182625]) ['yuheng', 'diondre', 'xavien']
data set validation: 3203 torch.Size([22655, 3]) torch.Size([22655]) ['amay', 'aytana', 'jenevi']
data set test:       3204 torch.Size([22866, 3]) torch.Size([22866]) ['mustafa', 'reuben', 'kahlel']


In [23]:
logSection("Build network")

embeddingDims = 10
log("embeddingDims", embeddingDims)

hiddenLayerSize = 100
log("hiddenLayerSize", hiddenLayerSize)

learningSeed = 2147483647
log("learningSeed", learningSeed)
# Generator shared by network construction and mini-batch sampling.
g = torch.Generator(device=dvc).manual_seed(learningSeed)

if newNet:
    np2 = makeNetwork2(g, vocabularyLength, embeddingDims, contextSize, hiddenLayerSize, dtype, dvc)
    printNetwork(np2)

    @torch.no_grad()
    def trLoss2():
        """Evaluate np2 on the whole training split with gradient tracking disabled."""
        return getLoss2(np2, np2.C[trX], trY)

    @torch.no_grad()
    def valLoss2():
        """Evaluate np2 on the whole validation split with gradient tracking disabled."""
        return getLoss2(np2, np2.C[valX], valY)

    @torch.no_grad()
    def tstLoss2():
        """Evaluate np2 on the whole test split with gradient tracking disabled."""
        return getLoss2(np2, np2.C[tstX], tstY)

    def getLosses2() -> Losses2:
        """Collect the training, validation and test results into one Losses2 object."""
        bundle = Losses2()
        bundle.tr, bundle.val, bundle.tst = trLoss2(), valLoss2(), tstLoss2()
        return bundle

    def logLosses2():
        """Log the three split losses on one line, each right-aligned to 10 chars with 4 decimals."""
        bundle = getLosses2()
        columns = [f"{part.loss.item():>10.4f}" for part in (bundle.tr, bundle.val, bundle.tst)]
        logSimple(" ".join(columns))
else:
    # NOTE: `np` here is the network object, not numpy (numpy is imported as `NP`).
    np = makeNetwork(g, vocabularyLength, embeddingDims, contextSize, hiddenLayerSize, dvc)
    log("parametersCount", sum(p.nelement() for p in np.all))

    cal = calibrateBatchNorm(np, trX)

    # One histogram (100 bins) for each calibration statistic.
    for _stat, _title in (
        (cal.mean, "Calibration: mean of pre-activations"),
        (cal.std, "Calibration: standard deviation of pre-activations"),
    ):
        plt.figure()
        plt.hist(_stat.tolist(), 100)
        plt.title(_title)

    @torch.no_grad()
    def trLoss():
        """Evaluate np (with calibration `cal`) on the whole training split, gradients disabled."""
        return getLoss(np, cal, np.C[trX], trY)

    @torch.no_grad()
    def valLoss():
        """Evaluate np (with calibration `cal`) on the whole validation split, gradients disabled."""
        return getLoss(np, cal, np.C[valX], valY)

    @torch.no_grad()
    def tstLoss():
        """Evaluate np (with calibration `cal`) on the whole test split, gradients disabled."""
        return getLoss(np, cal, np.C[tstX], tstY)

    def getLosses() -> Losses:
        """Collect the training, validation and test results into one Losses object."""
        bundle = Losses()
        bundle.tr, bundle.val, bundle.tst = trLoss(), valLoss(), tstLoss()
        return bundle

    def logLosses():
        """Log the three split losses on one line, each right-aligned to 10 chars with 4 decimals."""
        bundle = getLosses()
        columns = [f"{part.loss.item():>10.4f}" for part in (bundle.tr, bundle.val, bundle.tst)]
        logSimple(" ".join(columns))

Build network:       -------------------------- 2023-07-13 18:05:49
embeddingDims:       10
hiddenLayerSize:     100
learningSeed:        2147483647
Network Structure:  
Layer LinearWithBias 4: torch.Size([30, 100]), torch.Size([100]), 
Layer Tanh 5: 
Layer LinearWithBias 6: torch.Size([100, 100]), torch.Size([100]), 
Parameters Count:    13470


In [24]:
logSection("Learning")

trainingBatchSize = 32
log("trainingBatchSize", trainingBatchSize)

trXLength = trX.shape[0]
# Number of steps in one pass of the inner loop below. ceil() is used instead of
# floor + 1 so the count is not one too high when trXLength is an exact multiple
# of the batch size, and so it matches len(range(0, trXLength, trainingBatchSize)).
trainingSteps = math.ceil(trXLength / trainingBatchSize)
log("trainingSteps", trainingSteps)

maxIteration = 200_000
log("maxIteration", maxIteration)

repeats = 36
log("repeats", repeats)

maxLr = 0.14
log("maxLr", maxLr)

minLr = 0.0001
log("minLr", minLr)

# Total number of optimisation steps that will run: `repeats` passes, capped at
# maxIteration. Passed to updateNet together with maxLr / minLr
# (presumably as the length of the learning-rate schedule - confirm in mynn).
actualIterations = min(maxIteration, repeats * trainingSteps)
log("actualIterations", actualIterations)

# Per-step history, read by the plotting cells further down.
lrAtIx: list[float] = []
stepIx: list[int] = []
lossAtIx: list[float] = []
logLossAtIx: list[float] = []
up = UpdateNetResult()
fr = ForwardPassResult()
# Placeholder so the first progress line has a loss to print before any step has run.
fr.loss = torch.tensor(0)
i = 0  # global step counter across all repeats

if newNet:
    for repeat in range(repeats):

        if i >= maxIteration:
            break

        # Progress line: repeat, step, last mini-batch loss, then train / val / test losses.
        logSimple(f"{repeat:>3}, {i:>6} losses: {fr.loss.item():>10.4f}   ", end="")
        logLosses2()

        # `start` only counts the steps of this pass: every mini-batch is drawn at
        # random (with replacement) from the whole training set, not sliced at `start`.
        for start in range(0, trXLength, trainingBatchSize):

            if i >= maxIteration:
                log("Break at max iteration")
                break

            miniBatchIxs = torch.randint(0, trXLength, (trainingBatchSize,), generator=g, device=dvc)
            fr = forwardPass2(np2, trX, trY, miniBatchIxs)
            backwardPass2(np2.layers, np2.parameters, fr.loss)

            up = updateNet(np2.parameters, i, actualIterations, maxLr, minLr)
            lrAtIx.append(up.learningRate)

            stepIx.append(i)
            lossAtIx.append(fr.loss.item())
            logLossAtIx.append(fr.loss.log10().item())

            i += 1

    logSimple(f" final losses: {fr.loss.item():>15.4f}   ", end="")
    logLosses2()
else:
    # Learning-rate sweep scaffolding (10**-3 .. 10**0); only referenced by the
    # commented-out lines below.
    lr = 0.1
    lre = torch.linspace(-3, 0, trainingSteps)
    lrs = 10 ** lre

    for repeat in range(repeats):

        if i >= maxIteration:
            break

        # Progress line: repeat, step, last mini-batch loss, then train / val / test losses.
        logSimple(f"{repeat:>3}, {i:>6} losses: {fr.loss.item():>10.4f}   ", end="")
        logLosses()

        for start in range(0, trXLength, trainingBatchSize):

            if i >= maxIteration:
                log("Break at max iteration")
                break

            # `end` is only needed by the commented-out sequential batching alternative.
            end = min(start + trainingBatchSize, trXLength)
            #miniBatchIxs = torch.arange(start, end)
            miniBatchIxs = torch.randint(0, trXLength, (trainingBatchSize,), generator=g, device=dvc)
            fr = forwardPass(np, cal, trX, trY, miniBatchIxs)
            backwardPass(np.all, fr.loss)

            # One-off diagnostics taken on the second step (i == 1): a map of
            # |h| > 0.99 entries plus histograms of h and of the pre-activations.
            if i == 1:
                plt.figure(figsize=(20, 10))
                plt.imshow(fr.h.abs() > 0.99, cmap="gray", interpolation="nearest")

                plt.figure()
                plt.hist(fr.h.view(-1).tolist(), 100)
                plt.title('Histogram of h')

                plt.figure()
                plt.hist(fr.hPreActivations.view(-1).tolist(), 100)
                plt.title('Histogram of hPreActivations')

                plt.figure()
                plt.hist(fr.hPreActivations.mean(0, keepdim=True).view(-1).tolist(), 100)
                plt.title('Histogram of Mean of hPreActivations')

                plt.figure()
                plt.hist(fr.hPreActivations.std(0, keepdim=True).view(-1).tolist(), 100)
                plt.title('Histogram of Std of hPreActivations')

            up = updateNet(np.all, i, actualIterations, maxLr, minLr)
            lrAtIx.append(up.learningRate)

            stepIx.append(i)
            lossAtIx.append(fr.loss.item())
            logLossAtIx.append(fr.loss.log10().item())

            #lr = lrs[i].item()
            #lrAtIx.append(lrs[i].item())

            i += 1

    logSimple(f" final losses: {fr.loss.item():>15.4f}   ", end="")
    logLosses()

#bestLr = lrs[findLowestIndex(lossAtIx)].item()
#log("best learning rate", bestLr)
# Shapes taken from the last forward pass (requires at least one training step to have run).
log("emb.shape", fr.emb.shape)
#log("h.shape", fr.h.shape)
log("logits.shape", fr.logits.shape)

Learning:            -------------------------- 2023-07-13 18:05:49
trainingBatchSize:   32
trainingSteps: :     5708
maxIteration:        200000
repeats:             36
maxLr:               0.14
minLr:               0.0001
actualIterations:    200000
  0,      0 losses:     0.0000       4.5749     4.5745     4.5742
  1,   5708 losses:     2.4943       2.3476     2.3441     2.3473
  2,  11416 losses:     2.3216       2.2889     2.2911     2.2899
  3,  17124 losses:     2.1395       2.2628     2.2645     2.2653
  4,  22832 losses:     2.5993       2.2338     2.2428     2.2413
  5,  28540 losses:     2.1791       2.2145     2.2250     2.2245
  6,  34248 losses:     2.1214       2.2146     2.2204     2.2234
  7,  39956 losses:     2.1114       2.1987     2.2123     2.2098
  8,  45664 losses:     1.9537       2.1914     2.2064     2.2045
  9,  51372 losses:     2.1421       2.1940     2.2042     2.2127
 10,  57080 losses:     2.1321       2.1771     2.1920     2.1953
 11,  62788 losses:   

In [25]:
logSection("Sampling")

# Sampling gets its own generator, seeded at a fixed offset from the training seed.
samplingSeed = learningSeed + 10
gSampling = torch.Generator(device=dvc)
gSampling.manual_seed(samplingSeed)
log("samplingSeed", samplingSeed)

# Upper bound handed to sampleMany / sampleMany2.
maxSampleLength = 50
log("maxSampleLength", maxSampleLength)

Sampling:            -------------------------- 2023-07-13 18:08:25
samplingSeed:        2147483657
maxSampleLength:     50


In [26]:
# Draw samples from whichever network was built, then print one line per sample:
# the sampled text, its probability x 10000, and every entry of s.probs expressed
# relative to a uniform guess over the vocabulary (x 10, so 10 == chance level).
sampleCount = 20

if newNet:
    samples = sampleMany2(np2, gSampling, contextSize, itos, sampleCount, maxSampleLength)
else:
    samples = sampleMany(np, cal, gSampling, contextSize, itos, sampleCount, maxSampleLength)

# The printing is identical for both network variants, so it lives outside the branch.
# The chance level is derived from vocabularyLength instead of a hard-coded 27.
for s in samples:
    logSimple(f"{''.join(s.values):<21}{(s.prob * 10000):>4.0f}: ", end="")
    for p in s.probs:
        logSimple(f"{(p / (1 / vocabularyLength) * 10):.0f} ", end="")
    logSimple()

mona.                 193: 20 19 86 20 171 
kayah.                328: 25 112 44 35 66 253 
seel.                 108: 18 24 5 42 125 
ndyn.                  41: 10 1 12 66 216 
alee.                 506: 38 40 75 59 180 
threttedraegus.        51: 11 23 18 58 15 106 86 1 46 56 19 8 57 111 83 
ched.                  78: 12 61 39 4 79 
elin.                 198: 13 85 94 52 126 
shi.                  161: 18 53 30 43 
jenrene.              142: 21 46 77 6 54 59 13 126 
susopharleiyah.       130: 18 18 23 12 18 209 43 29 58 108 30 8 237 113 253 
hotelin.               40: 7 26 2 22 44 51 52 126 
shub.                  76: 18 53 7 2 155 
ridhiraes.             55: 14 36 28 4 30 24 117 4 20 44 
kindreelynn.          230: 25 24 104 19 43 75 40 34 34 175 117 225 
novana.               106: 10 38 54 106 146 18 155 
ubrence.                6: 1 13 117 78 72 10 193 139 
ryamili.               97: 14 20 64 16 96 64 57 23 
eli.                  153: 13 85 94 22 
kay.                  228: 25 112 

In [36]:
# Probability probes; only available for the makeNetwork (newNet == False) path.
if not newNet:
    def printProb(txt: str):
        """Log the probabilities the network assigns to `txt`.

        Prints the text, the combined value returned by calcOneProb x 10000, and
        every entry returned by calcProb relative to a uniform guess over the
        vocabulary (x 10, so 10 == chance level).
        """
        ps = calcProb(np, cal, txt, contextSize, stoi)
        op = calcOneProb(ps)
        logSimple(f"{txt:<21}{(op * 10000):<7.0f}: ", end="")
        for p in ps:
            # Chance level comes from vocabularyLength rather than a hard-coded 27.
            logSimple(f"{(p / (1 / vocabularyLength) * 10):.0f} ", end="")
        logSimple()

    # Each probe ends with '.', the end-of-word token (index 0 in stoi).
    for probe in (
        '.',
        'm.',
        'mi.',
        'mic.',
        'mich.',
        'micha.',
        'michal.',
        'michael.',
        'michaela.',
        'michaella.',
        'michel.',
        'michalx.',
        'michalxx.',
        'michalxxx.',
        'martin.',
        'andrej.',
        'andrey.',
        'joey.',
        'james.',
        'xin.',
        'maxim.',
        'alex.',
        'alexa.',
    ):
        printProb(probe)

In [28]:
# Learning rate recorded at every training step (makeNetwork / newNet == False path only).
if not newNet:
    lowestLr, highestLr = min(lrAtIx), max(lrAtIx)
    plt.plot(range(len(lrAtIx)), lrAtIx, "black")
    plt.ylim(lowestLr, highestLr)
    plt.grid(True)
    plt.show()
    # Values are printed in the order the label announces: min first, then max.
    print("Actual min max LR", lowestLr, highestLr)

In [29]:
# Map of the entries of the last forward pass's `h` whose magnitude exceeds 0.99
# (True renders white with the gray colormap). Only for the newNet == False path.
if not newNet:
    plt.figure(figsize=(20, 10))
    saturatedMask = fr.h.abs() > 0.99
    plt.imshow(saturatedMask, cmap="gray", interpolation="nearest")

In [30]:
# Dark-themed axes for a loss-vs-learning-rate curve; the curve itself is
# currently disabled, so only the empty axes are shown (newNet == False path).
if not newNet:
    fig, ax = plt.subplots(facecolor="#777777")
    ax.set_facecolor("#222222")
    # ax.plot(lrAtIx, lossAtIx)
    plt.show()

In [31]:
# Mini-batch loss per training step on dark-themed axes (newNet == False path only).
if not newNet:
    fig, ax = plt.subplots(facecolor="#777777")
    ax.set_facecolor("#222222")
    ax.plot(stepIx, lossAtIx)
    plt.show()

In [32]:
# log10 of the mini-batch loss per training step on dark-themed axes (newNet == False path only).
if not newNet:
    fig, ax = plt.subplots(facecolor="#777777")
    ax.set_facecolor("#222222")
    ax.plot(stepIx, logLossAtIx)
    plt.show()

In [33]:
# Scatter of two adjacent columns (dim, dim + 1) of the embedding table np.C,
# with every point labelled by the character it belongs to (newNet == False path only).
if not newNet:
    dim = 0
    fig = plt.figure(figsize=(8, 8), facecolor="#777777")
    sc = plt.scatter(np.C[:, dim].data, np.C[:, dim + 1].data, s=200)
    for i in range(np.C.shape[0]):
        xPos = np.C[i, dim].item()
        yPos = np.C[i, dim + 1].item()
        plt.text(xPos, yPos, itos[i], ha="center", va="center", color="white")
    plt.grid()

In [34]:
#np.C.shape, trX.shape, np.C[trX].shape, np.C[:5], trX[:5], np.C[trX][:5]

In [35]:
# Pure-Python illustration of what the tensor indexing `C[trX]` does: every index
# in a context row is replaced by the matching row of the embedding table.
#
# The sample contexts are stored as `demoTrX` (not `trX`) so that running this
# cell does not overwrite the real training tensor used by the earlier cells.

# Let's suppose this is the embedding table: one 2-d vector per character index.
C = [
    [ 0.8774, -0.6801],
    [ 0.1651, -0.5025],
    [ 0.2769, -0.3570],
    [-0.8820,  0.3902],
    [-0.4824,  0.8744],
    [-0.3190,  0.7807],
    [-0.0100, -0.3401],
    [ 0.9975,  2.8280],
    [ 0.9623, -1.3172],
    [ 0.2180, -0.3820],
    [ 0.6139, -0.4287],
    [-0.7386,  0.5880],
    [-0.3088,  0.9816],
    [ 0.3907, -0.4174],
    [-0.7380,  0.5205],
    [-0.5288,  0.7074],
    [-0.3956,  0.9625],
    [-0.3802, -0.3504],
    [-0.2861,  0.7589],
    [ 0.5309, -0.5105],
    [-0.0922, -0.6410],
    [-0.3823,  0.9899],
    [ 0.0965, -0.5708],
    [-0.8582, -1.3429],
    [-0.4960,  0.2842],
    [-0.6105,  0.1336],
    [-0.2623,  0.2942]
]

# ... and these are the first few context rows (three character indices each).
demoTrX = [
    [0, 0, 0],
    [0, 0, 22],
    [0, 22, 9],
    [22, 9, 15],
    [9, 15, 12],
    # ... More values
]

# For every context row, look up each index in C. The comprehension keeps the
# loop variables local, so the notebook-level step counter `i` is left untouched.
result = [[C[ix] for ix in context] for context in demoTrX]

# Let's print the first 5 rows of the result (fewer if the demo list is shorter).
for row in result[:5]:
    print(row)

[[0.8774, -0.6801], [0.8774, -0.6801], [0.8774, -0.6801]]
[[0.8774, -0.6801], [0.8774, -0.6801], [0.0965, -0.5708]]
[[0.8774, -0.6801], [0.0965, -0.5708], [0.218, -0.382]]
[[0.0965, -0.5708], [0.218, -0.382], [-0.5288, 0.7074]]
[[0.218, -0.382], [-0.5288, 0.7074], [-0.3088, 0.9816]]
