### Batch the characters

wavenet (like) - we want to progressively form longer joins

In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import math

## more about broadcasting etc

In [2]:
n = torch.randn(3, 2)
m = torch.rand(2, 3)
print(n @ m)

n = torch.tril(torch.ones((3, 3), dtype=torch.int))
m = torch.tensor([1, 2, 3])
print(n.shape, m.shape)
n * m

tensor([[-0.8776, -0.9791, -0.8851],
        [ 1.7752,  2.2412,  2.0700],
        [ 0.4356,  0.6080,  0.5702]])
torch.Size([3, 3]) torch.Size([3])


tensor([[1, 0, 0],
        [1, 2, 0],
        [1, 2, 3]])

In [3]:
n = torch.tril(torch.ones((3, 3, 3), dtype=torch.int))
m = torch.arange(1, 10).view(3, 3)
print(n * m)

tensor([[[1, 0, 0],
         [4, 5, 0],
         [7, 8, 9]],

        [[1, 0, 0],
         [4, 5, 0],
         [7, 8, 9]],

        [[1, 0, 0],
         [4, 5, 0],
         [7, 8, 9]]])


In [4]:
a = torch.arange(1,10).view(3,3)
cp = torch.cat((a, a, a), 1).view(3,3,3)

cp * torch.tril(torch.ones((3, 3, 3), dtype=torch.int))

tensor([[[1, 0, 0],
         [1, 2, 0],
         [1, 2, 3]],

        [[4, 0, 0],
         [4, 5, 0],
         [4, 5, 6]],

        [[7, 0, 0],
         [7, 8, 0],
         [7, 8, 9]]])

In [5]:
a = torch.arange(1,10).view(3,3)
a.repeat(1, 1, 3).view(3, 3, 3) 

tensor([[[1, 2, 3],
         [1, 2, 3],
         [1, 2, 3]],

        [[4, 5, 6],
         [4, 5, 6],
         [4, 5, 6]],

        [[7, 8, 9],
         [7, 8, 9],
         [7, 8, 9]]])

In [6]:
a = torch.arange(1,10).view(3,3)
a.repeat(1, 1, 3).view(3, 3, 3) * torch.tril(torch.ones((3, 3, 3), dtype=torch.int))

tensor([[[1, 0, 0],
         [1, 2, 0],
         [1, 2, 3]],

        [[4, 0, 0],
         [4, 5, 0],
         [4, 5, 6]],

        [[7, 0, 0],
         [7, 8, 0],
         [7, 8, 9]]])

In [43]:
ctx = 4
samples = 2
a = torch.arange(1,samples*ctx+1).view(samples,ctx)
out = a.repeat(1, 1, ctx).view(-1, ctx, ctx) * torch.tril(torch.ones((samples, ctx, ctx), dtype=torch.int))
print(out.shape)
out

torch.Size([2, 4, 4])


tensor([[[1, 0, 0, 0],
         [1, 2, 0, 0],
         [1, 2, 3, 0],
         [1, 2, 3, 4]],

        [[5, 0, 0, 0],
         [5, 6, 0, 0],
         [5, 6, 7, 0],
         [5, 6, 7, 8]]])

In [46]:
emb_t = torch.randn(27, 3)
print(emb_t[0])
emb_t[out[0]]

tensor([-1.5188, -2.1504, -2.7174])


tensor([[[ 0.9355,  2.2526,  0.6165],
         [-1.5188, -2.1504, -2.7174],
         [-1.5188, -2.1504, -2.7174],
         [-1.5188, -2.1504, -2.7174]],

        [[ 0.9355,  2.2526,  0.6165],
         [-0.8971, -1.4147,  2.1098],
         [-1.5188, -2.1504, -2.7174],
         [-1.5188, -2.1504, -2.7174]],

        [[ 0.9355,  2.2526,  0.6165],
         [-0.8971, -1.4147,  2.1098],
         [ 0.2606,  2.3956,  0.6378],
         [-1.5188, -2.1504, -2.7174]],

        [[ 0.9355,  2.2526,  0.6165],
         [-0.8971, -1.4147,  2.1098],
         [ 0.2606,  2.3956,  0.6378],
         [-0.2348, -0.4804, -0.2756]]])

In [8]:
vocab_size = 27

embedding_dims = 3
context_length = 4

In [9]:
%run names.py

ns = Names(context_length)

names = ns.get_names("names.txt")

i_names = [ns.stoi[s] for name in names for s in name]

ts = torch.tensor(i_names[:context_length])

print("the name encoded to its characters, ts")
print(ts)

emb = torch.randn(vocab_size, embedding_dims)

embedding = emb[ts]
print("those characters in our 3D embedding")
print(embedding)

tril = torch.tril(torch.ones(context_length, context_length, dtype=torch.int))
print("a triangluar matric of 1s")
print(tril)

trilled = ts * tril
print("a triangluarized version of the name")
print(trilled)

the name encoded to its characters, ts
tensor([ 0,  5, 13, 13])
those characters in our 3D embedding
tensor([[-0.5356,  0.2102,  0.9141],
        [-0.5295, -0.6447, -0.6345],
        [ 0.3357, -0.1848, -1.0286],
        [ 0.3357, -0.1848, -1.0286]])
a triangluar matric of 1s
tensor([[1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1]], dtype=torch.int32)
a triangluarized version of the name
tensor([[ 0,  0,  0,  0],
        [ 0,  5,  0,  0],
        [ 0,  5, 13,  0],
        [ 0,  5, 13, 13]])


## A 3D version

we want triangluarized version to hold the 3D embeddings,

In [10]:
embedded_sequence = emb[trilled]

print(embedded_sequence.shape)
embedded_sequence[1]

torch.Size([4, 4, 3])


tensor([[-0.5356,  0.2102,  0.9141],
        [-0.5295, -0.6447, -0.6345],
        [-0.5356,  0.2102,  0.9141],
        [-0.5356,  0.2102,  0.9141]])

In [117]:
class Embedding:

    def __init__(self, vocab_size, embedding_dims):
        self.weight = torch.randn(vocab_size, embedding_dims)
        self.type = 'embedding'

    def __call__(self, IX):
        self.out = self.weight[IX]
        return self.out

    def parameters(self):
        return [self.weight]

In [13]:
em_layer = Embedding(27, 3)
em_layer.weight = emb
sq = em_layer(trilled)
sq[1] == embedded_sequence[1]

tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True]])

In [74]:
ctx = 4
samples = 2
a = torch.arange(1,samples*ctx+1).view(samples,ctx)
out = a.repeat(1, 1, ctx).view(-1, ctx, ctx) * torch.tril(torch.ones((samples, ctx, ctx), dtype=torch.int))

em_layer = Embedding(27, 3)
em_out = em_layer(out)
em_layer.weight[0], em_out[0]

(tensor([-1.1056, -0.2679,  0.5158]),
 tensor([[[ 2.0601, -0.8392, -1.4414],
          [-1.1056, -0.2679,  0.5158],
          [-1.1056, -0.2679,  0.5158],
          [-1.1056, -0.2679,  0.5158]],
 
         [[ 2.0601, -0.8392, -1.4414],
          [ 1.0647, -2.1135,  2.1291],
          [-1.1056, -0.2679,  0.5158],
          [-1.1056, -0.2679,  0.5158]],
 
         [[ 2.0601, -0.8392, -1.4414],
          [ 1.0647, -2.1135,  2.1291],
          [ 0.8030, -0.5467,  0.2414],
          [-1.1056, -0.2679,  0.5158]],
 
         [[ 2.0601, -0.8392, -1.4414],
          [ 1.0647, -2.1135,  2.1291],
          [ 0.8030, -0.5467,  0.2414],
          [ 0.0137, -0.9042, -1.5613]]]))

## our embed still works

so our embed layer works with the new format

how about Flatten?

In [14]:
class FlattenConsecutive:

    def __init__(self, n):
        self.n = n
        self.type = 'flatten'

    def __call__(self, x):
        B, T, C = x.shape
        x = x.view(B, T//self.n, self.n * C)
        if x.shape[1] == 1:
            x = x.squeeze(1)
        self.out = x
        return self.out

    def parameters(self):
        return []

In [15]:
fc = FlattenConsecutive(2)
outb = fc(sq)

print(outb.shape)
outb[1]

torch.Size([4, 2, 6])


tensor([[-0.5356,  0.2102,  0.9141, -0.5295, -0.6447, -0.6345],
        [-0.5356,  0.2102,  0.9141, -0.5356,  0.2102,  0.9141]])

## OK for a single sample

but for our batch now - it's shape looks like

In [75]:
print(em_out.shape)

torch.Size([2, 4, 4, 3])


and FlattenConsecutive expects B, T, C = x.shape

but do we need flatten now?

we expect an answer to our batch all at once

## rework sampling

In [84]:
full_length = len(i_names)

offset = math.floor(full_length*0.1)

sm = {
    'train': (0, offset * 8),
    'dev': (offset * 8, offset * 9),
    'test': (offset * 9, offset * 10),
}

def samples(set, num_samples, ctx=context_length):
    fr = sm[set]
    sm_fr = i_names[fr[0]: fr[1]]
    fr_l = len(sm_fr)-ctx
    sx = []
    ys = []
    for n in random.sample(range(fr_l), num_samples):
        sx += [i_names[n:n+ctx]]
        ys += [i_names[n+1:n+ctx+1]]

    xs = torch.tensor(sx).repeat(1, 1, ctx).view(num_samples, ctx, ctx) * torch.tril(torch.ones((num_samples, ctx, ctx), dtype=torch.int))

    return xs, ys
        
xp, yp = samples('train', 2)
print(xp, yp)

tensor([[[20,  0,  0,  0],
         [20, 15,  0,  0],
         [20, 15, 14,  0],
         [20, 15, 14,  0]],

        [[ 0,  0,  0,  0],
         [ 0,  0,  0,  0],
         [ 0,  0, 10,  0],
         [ 0,  0, 10,  1]]]) [[15, 14, 0, 0], [0, 10, 1, 9]]


## repeat the layers from before

In [20]:
class Linear:
  
  def __init__(self, fan_in, fan_out, bias=True):
    self.weight = torch.randn((fan_in, fan_out))
    self.bias = torch.zeros(fan_out) if bias else None
    self.type = 'linear'
  
  def __call__(self, x):
    self.out = x @ self.weight
    if self.bias is not None:
      self.out += self.bias
    return self.out

  def kaiming(self, nonlin):
    nn.init.kaiming_normal_(self.weight, nonlinearity=nonlin)
  
  def parameters(self):
    return [self.weight] + ([] if self.bias is None else [self.bias])

In [21]:
class Tanh:
  def __init__(self):
    self.type = 'non-linearity'
  def __call__(self, x):
    self.out = torch.tanh(x)
    return self.out
  def parameters(self):
    return []

class Relu:
  def __init__(self):
    self.type = 'non-linearity'
  def __call__(self, x):
    self.out = torch.relu(x)
    return self.out
  def parameters(self):
    return []

class Gelu:
  def __init__(self):
    self.type = 'non-linearity'
  def __call__(self, x):
    self.out = torch.gelu(x)
    return self.out
  def parameters(self):
    return []

In [22]:
class Sequential:
    
    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        self.out = x
        return self.out

    def parameters(self):
        return [p for layer in self.layers for p in layer.parameters()]

In [23]:
class BatchNorm1d:
  
  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with a running 'momentum update')
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)
    self.type = 'batch_norm_1d'
  
  def __call__(self, x):
    # calculate the forward pass
    if x.ndim == 2: 
      dim = 0
    if x.ndim == 3:
      dim = (0, 1)
    if self.training:
      xmean = x.mean(dim, keepdim=True) # batch mean
      xvar = x.var(dim, keepdim=True) # batch variance
    else:
      xmean = self.running_mean
      xvar = self.running_var
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    # update the buffers
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
    return self.out
  
  def parameters(self):
    return [self.gamma, self.beta]

In [24]:
embedding_size = 3 # the dimensionality of the character embedding vectors

vocab_size = 27
nonlin='relu'

Ok lets try a simpler version without batchnorm (which kills my laptop)

In [182]:
def model_init(e_s):
    ln1 = Relu()
    ln2 = Relu()
    n_hidden = 100
    
    model = Sequential([
        Embedding(vocab_size, e_s),
        nn.Flatten(2),
        Linear(e_s * context_length, n_hidden, bias=False), 
        ln1,
        nn.LayerNorm(n_hidden),
        Linear(n_hidden, n_hidden, bias=True), 
        ln2,
        nn.LayerNorm(n_hidden),
        Linear(n_hidden, vocab_size, bias=True),
    ])
    
    parameters = model.parameters()
    print(sum(p.nelement() for p in parameters)) # number of parameters in total
    # should probably be a function on the model?
    for p in parameters:
      p.requires_grad = True
    
    print(vocab_size, embedding_size)

In [166]:
epochs = 10
batch_size = 16
sample_loops = 4000

learning_rate = .01

running_loss = []
running_lr = []

ud_ratio = []

lr_step = {
    0: { 0: .1},
    3: { 0: .1},
    9: { 0: .1}
}

In [163]:
lrs = lr_step[embedding_size]
if lrs == None:
    lrs = lr_step[0]

for ep in range(epochs):
    epoch_loss = 0
    for s in range(sample_loops):
        x, y = samples('train', batch_size)
        Y = torch.tensor(y).view(-1)

        logits = model(x)

        logits = logits.view(-1, 27)

        loss = F.cross_entropy(logits, Y) # loss function
        
        with torch.no_grad():
            epoch_loss += loss

        # again stuff on parameters should probably be in model?
        for p in parameters:
          p.grad = None
        loss.backward()

        for p in parameters:
            p.data -= learning_rate * p.grad

    #just keep any epoch stuff in a no grad block
    with torch.no_grad():
        if ep in lrs:
            learning_rate = lrs[ep]

        running_loss.append(epoch_loss.item())
        running_lr.append(learning_rate)

        ud_ratio.append([ (learning_rate*p.grad.std()/ p.data.std()).log10().item() for p in parameters ])
    
        if ep % 2 == 0:
            print(epoch_loss/sample_loops)
            learning_rate *= .92
            print(ep, learning_rate)

tensor(3.2463)
0 0.09200000000000001
tensor(2.5708)
2 0.08464000000000002
tensor(2.4679)
4 0.07786880000000002
tensor(2.4349)
6 0.07163929600000002
tensor(2.4181)
8 0.06590815232000002
tensor(2.4000)
10 0.06063550013440003
tensor(2.3886)
12 0.05578466012364803
tensor(2.3829)
14 0.05132188731375619
tensor(2.3740)
16 0.0472161363286557
tensor(2.3704)
18 0.043438845422363245
tensor(2.3677)
20 0.039963737788574184
tensor(2.3575)
22 0.03676663876548825
tensor(2.3531)
24 0.033825307664249196
tensor(2.3514)
26 0.03111928305110926
tensor(2.3485)
28 0.028629740407020522
tensor(2.3527)
30 0.02633936117445888
tensor(2.3464)
32 0.02423221228050217
tensor(2.3450)
34 0.022293635298061998
tensor(2.3379)
36 0.020510144474217038
tensor(2.3367)
38 0.018869332916279676


In [164]:
print(epoch_loss/sample_loops)

tensor(2.3379)


In [253]:
# def model_init(e_s, h_s):
#     ln1 = Relu()
#     ln2 = Relu()
#     n_hidden = h_s
    
#     model = Sequential([
#         Embedding(vocab_size, e_s),
#         nn.Flatten(2),
#         Linear(e_s * context_length, n_hidden, bias=False), 
#         nn.LayerNorm(n_hidden),
#         ln1,
#         Linear(n_hidden, n_hidden, bias=False),
#         nn.LayerNorm(n_hidden),
#         ln2,
#         Linear(n_hidden, vocab_size, bias=True),
#     ])
    
#     parameters = model.parameters()
#     print(sum(p.nelement() for p in parameters)) # number of parameters in total
#     # should probably be a function on the model?
#     for p in parameters:
#       p.requires_grad = True
    
#     print(vocab_size, embedding_size)

In [273]:
def model_init(emb_s, h_s):
    ln1 = Relu()
    n_hidden = h_s
    
    model = Sequential([
        Embedding(vocab_size, emb_s),
        nn.Flatten(2),
        Linear(emb_s * context_length, n_hidden, bias=False), 
        ln1,
        Linear(n_hidden, vocab_size, bias=True),
        nn.Flatten(0,1)
    ])
    
    parameters = model.parameters()
    print(sum(p.nelement() for p in parameters)) # number of parameters in total
    # should probably be a function on the model?
    for p in parameters:
      p.requires_grad = True
    
    print("flatten", vocab_size, embedding_size)
    return model

In [272]:
x, y = samples('train', 2, 4)
print(x.shape)
print(x)
em = Embedding(27, 3)
x = em(x)
print(x.shape)
print(x)
fl = nn.Flatten(2)
x = fl(x)
print(x.shape)
print(x)
l1 = nn.Linear(12, 60)
x = l1(x)
print(x.shape)
ru = Relu()
x = ru(x)

l2 = Linear(60, vocab_size, bias=True)
x = l2(x)
print(x.shape)

fl2 = nn.Flatten(0,1)
x = fl2(x)
print("flatten")
print(x.shape)

print(y)
y = torch.tensor(y).view(-1)
ls = F.cross_entropy(x, y)

torch.Size([2, 4, 4])
tensor([[[ 8,  0,  0,  0],
         [ 8, 13,  0,  0],
         [ 8, 13,  5,  0],
         [ 8, 13,  5,  5]],

        [[ 7,  0,  0,  0],
         [ 7,  5,  0,  0],
         [ 7,  5, 12,  0],
         [ 7,  5, 12,  0]]])
torch.Size([2, 4, 4, 3])
tensor([[[[ 0.5390, -0.4193,  0.5890],
          [ 0.1797, -0.8898, -2.0754],
          [ 0.1797, -0.8898, -2.0754],
          [ 0.1797, -0.8898, -2.0754]],

         [[ 0.5390, -0.4193,  0.5890],
          [ 1.5444,  0.1385, -0.1019],
          [ 0.1797, -0.8898, -2.0754],
          [ 0.1797, -0.8898, -2.0754]],

         [[ 0.5390, -0.4193,  0.5890],
          [ 1.5444,  0.1385, -0.1019],
          [-0.2783,  0.8016, -0.5159],
          [ 0.1797, -0.8898, -2.0754]],

         [[ 0.5390, -0.4193,  0.5890],
          [ 1.5444,  0.1385, -0.1019],
          [-0.2783,  0.8016, -0.5159],
          [-0.2783,  0.8016, -0.5159]]],


        [[[ 0.4684,  0.2388,  0.1303],
          [ 0.1797, -0.8898, -2.0754],
          [ 0.1797, -

## 2 sequences
from each we create 4 lower left trianglur sequences



In [216]:
x, y = samples('train', 1, 3)
print(x)
x = em(x)
print(x)
x = fl(x)
x

tensor([[[ 7,  0,  0],
         [ 7,  5,  0],
         [ 7,  5, 12]]])
tensor([[[[ 0.5761, -0.4086, -2.1589],
          [-0.8405, -0.1164,  1.3651],
          [-0.8405, -0.1164,  1.3651]],

         [[ 0.5761, -0.4086, -2.1589],
          [-0.0053,  0.5236, -0.7010],
          [-0.8405, -0.1164,  1.3651]],

         [[ 0.5761, -0.4086, -2.1589],
          [-0.0053,  0.5236, -0.7010],
          [ 1.0390,  0.5107, -0.9023]]]])


tensor([[[ 0.5761, -0.4086, -2.1589, -0.8405, -0.1164,  1.3651, -0.8405,
          -0.1164,  1.3651],
         [ 0.5761, -0.4086, -2.1589, -0.0053,  0.5236, -0.7010, -0.8405,
          -0.1164,  1.3651],
         [ 0.5761, -0.4086, -2.1589, -0.0053,  0.5236, -0.7010,  1.0390,
           0.5107, -0.9023]]])

In [301]:
embed = 9
epochs = 80
batch_size = 32
sample_loops = 8000
hidden_size = 200

md = model_init(embed, hidden_size)
parameters = md.parameters()

lr_step = {
    0: { 0: .1},
    3: { 0: .1},
    9: { 0: .04, 60: 0.01}
}

lrs = lr_step[embed]

m_setup = {
    "embed": embed,
    "epochs": epochs,
    "batch_size": batch_size,
    "sample_loops": sample_loops,
    "hidden_size": hidden_size,
    "learning_rates": lrs
}

learning_rate = lrs[0]

running_loss = []
running_lr = []

ud_ratio = []
learning_rate

12870
flatten 27 3


0.04

In [302]:
for ep in range(epochs):
    epoch_loss = 0
    for s in range(sample_loops):
        x, y = samples('train', batch_size)

        logits = md(x)
        Y = torch.tensor(y).view(-1)
       
        loss = F.cross_entropy(logits, Y) # loss function
        with torch.no_grad():
            epoch_loss += loss

        # again stuff on parameters should probably be in model?
        for p in parameters:
          p.grad = None
        loss.backward()

        for p in parameters:
            p.data -= learning_rate * p.grad
    #just keep any epoch stuff in a no grad block
    with torch.no_grad():
        if ep in lrs:
            learning_rate = lrs[ep]

        running_loss.append(epoch_loss.item())
        running_lr.append(learning_rate)

        ud_ratio.append([ (learning_rate*p.grad.std()/ p.data.std()).log10().item() for p in parameters ])
    
        if ep % 4 == 0:
            print(epoch_loss/sample_loops)
            learning_rate *= .92
            print(ep, learning_rate)

tensor(3.1719)
0 0.0368
tensor(2.4397)
4 0.033856000000000004
tensor(2.3878)
8 0.031147520000000005
tensor(2.3617)
12 0.028655718400000006
tensor(2.3441)
16 0.026363260928000006
tensor(2.3328)
20 0.024254200053760007
tensor(2.3228)
24 0.022313864049459207
tensor(2.3167)
28 0.02052875492550247
tensor(2.3114)
32 0.018886454531462274
tensor(2.3051)
36 0.01737553816894529
tensor(2.3045)
40 0.015985495115429668
tensor(2.3003)
44 0.014706655506195295
tensor(2.2969)
48 0.013530123065699671
tensor(2.2944)
52 0.012447713220443699
tensor(2.2907)
56 0.011451896162808204
tensor(2.2901)
60 0.0092
tensor(2.2871)
64 0.008464000000000001
tensor(2.2851)
68 0.007786880000000001
tensor(2.2840)
72 0.007163929600000001
tensor(2.2821)
76 0.0065908152320000015


In [303]:
print(epoch_loss/sample_loops, m_setup, learning_rate)

tensor(2.2811) {'embed': 9, 'epochs': 80, 'batch_size': 32, 'sample_loops': 8000, 'hidden_size': 200, 'learning_rates': {0: 0.04, 60: 0.01}} 0.0065908152320000015


In [300]:
print(epoch_loss/sample_loops, m_setup, learning_rate)

tensor(2.3809) {'embed': 9, 'epochs': 40, 'batch_size': 32, 'sample_loops': 4000, 'hidden_size': 200, 'learning_rates': {0: 0.02}} 0.008687769084472646


In [201]:
print(epoch_loss/sample_loops, m_setup, learning_rate)

tensor(2.3494) {'embed': 9, 'epochs': 12, 'batch_size': 32, 'sample_loops': 2000, 'hidden_size': 100, 'learning_rates': {0: 0.04}} 0.024254200053760007


In [None]:
tensor(2.3738) {'d': 9, 'e': 12, 'b_s': 32, 's_l': 2000}

tensor(2.4077) 3dim
10 0.06063550013440003
tensor(2.4335) 3dim
10 0.06063550013440003

In [None]:
eps = [i for i in range(len(running_loss))]
ls = [l for l in running_loss]

plt.plot(eps, ls)

t = layers[2].weight.detach()
hy, hx = torch.histogram(t, density=True)
plt.figure(figsize=(20,4))
plt.plot(hx.detach()[1:], hy.detach())
t2 = layers[-2].weight.detach()
h2y, h2x = torch.histogram(t2, density=True)
plt.plot(h2x.detach()[1:], h2y.detach())

# look at our gradient distributions

t = layers[2].weight.grad.detach()
hy, hx = torch.histogram(t, density=True)
plt.figure(figsize=(20,4))
plt.plot(hx.detach()[1:], hy.detach())
t2 = layers[-2].weight.grad.detach()
h2y, h2x = torch.histogram(t2, density=True)
plt.plot(h2x.detach()[1:], h2y.detach())


plt.figure(figsize=(20,4))
plt.plot(ud_ratio)