In [1]:
import sys
import os

# Extra site-packages directory from the author's virtualenv. The default is a
# machine-specific absolute path; set NCP_SITE_PACKAGES to point elsewhere.
# The directory is only appended when it exists and is not already on
# sys.path, so re-running this cell does not keep growing sys.path.
SITE_PACKAGES = os.environ.get(
    "NCP_SITE_PACKAGES",
    "/Users/tunadorable/local-repos/next-concept-predictor/venv/lib/python3.11/site-packages",
)
if os.path.isdir(SITE_PACKAGES) and SITE_PACKAGES not in sys.path:
    sys.path.append(SITE_PACKAGES)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Example walkthrough (full model)
Here I'll start with the first token of each of these two sequences and we'll walk through the calculations step by step.

string1: " I think therefore I am\<endoftext>"

string2: "Every cloud has a silver lining\<endoftext>"

Later we'll turn this into visual training and inference loops so that you can see how the ordering plays out.

In [3]:
# setting hyperparameters:
#   b = number of sequences in the batch, n = tokens per sequence,
#   d = embedding dimension,             v = vocabulary size
b, n, d, v = 2, 7, 3, 12

In [4]:
# vocabulary: token string -> row index into the embedding matrix E
E_dict = {
    token: index
    for index, token in enumerate([
        " I", " think", " there", "fore", " am", "Every",
        " cloud", " has", " a", " silver", " lining", "<endoftext>",
    ])
}
E_dict

{' I': 0,
 ' think': 1,
 ' there': 2,
 'fore': 3,
 ' am': 4,
 'Every': 5,
 ' cloud': 6,
 ' has': 7,
 ' a': 8,
 ' silver': 9,
 ' lining': 10,
 '<endoftext>': 11}

In [5]:
# create sequence of tokens
S_full_text = [[' I', ' think', ' there', 'fore', ' I', ' am', '<endoftext>'],
               ['Every', ' cloud', ' has', ' a', ' silver', ' lining', '<endoftext>']]
# turn into indices
S_full_indices = [[E_dict[word] for word in sentence] for sentence in S_full_text]
# turn into a tensor of shape (number of sequences, tokens per sequence)
S_full = torch.tensor(S_full_indices)
print("S_full: ", S_full.dtype, S_full.shape, S_full)

# starting off with the first token for each sequence.
# S_i holds the ground-truth prefix of tokens 0..i for each sequence. Slicing
# with a range keeps the sequence dimension, so S_i is always (b, i+1); indexing
# with the bare integer i would drop that dimension and give a 1-D (b,) tensor,
# which no longer matches the (b, seq, d) shapes used by the following cells.
i=0
S_i = S_full[:, :i+1]
print("S_i: ", S_i.dtype, S_i.shape, S_i)

S_full:  torch.int64 torch.Size([2, 7]) tensor([[ 0,  1,  2,  3,  0,  4, 11],
        [ 5,  6,  7,  8,  9, 10, 11]])
S_i:  torch.int64 torch.Size([2, 1]) tensor([[0],
        [5]])


In [6]:
# embedding matrix: v random d-dimensional token vectors (seeded for reproducibility),
# each rescaled to unit L2 length ("cosine norm") so dot products are cosine similarities
torch.manual_seed(420)
E = torch.randn(v, d)
E = E / E.norm(p=2, dim=1, keepdim=True)
print("E: ", E.shape, E)

# embed every token id in S_i: (b, seq) -> (b, seq, d)
X = F.embedding(S_i, E)
print("X: ", X.shape, X)

E:  torch.Size([12, 3]) tensor([[-0.0084,  0.6013,  0.7990],
        [-0.7062,  0.0558,  0.7058],
        [-0.2553,  0.3534,  0.9000],
        [ 0.1616,  0.9819,  0.0985],
        [ 0.4082, -0.1934,  0.8921],
        [ 0.9826, -0.1837, -0.0258],
        [-0.0743, -0.6807, -0.7288],
        [ 0.7308,  0.3404,  0.5917],
        [-0.1637,  0.9282,  0.3341],
        [ 0.5391, -0.4316, -0.7232],
        [-0.0080,  0.9751, -0.2215],
        [-0.3469, -0.9187,  0.1890]])
X:  torch.Size([2, 1, 3]) tensor([[[-0.0084,  0.6013,  0.7990]],

        [[ 0.9826, -0.1837, -0.0258]]])


Pretend the activation function and cosine normalization below are being applied inside the transformer layers.

We don't touch the actual transformer layers here other than these two pieces, so there's no need to code them.

In [7]:
# activation function: elementwise sin(pi * x)
Xi = torch.sin(X * torch.pi)
print("Xi: ", Xi.shape, Xi)

# sin is admittedly an unusual choice -- ask me about how badly the traditional options behave here

Xi:  torch.Size([2, 1, 3]) tensor([[[-0.0264,  0.9498,  0.5903]],

        [[ 0.0545, -0.5455, -0.0811]]])


In [8]:
# Normalize each vector to have a unit length (cosine normalization along dim 2).
# The norm is clamped to at least 1e-12 (the same floor F.normalize uses): if
# sin(pi*x) vanishes in every component of a vector (e.g. an all-zero input
# vector), dividing by a zero norm would fill that row of Xf with NaNs.
norms = torch.norm(Xi, p=2, dim=2, keepdim=True).clamp_min(1e-12)
Xf = Xi / norms
print("norms: ", norms.shape, norms)
print("Xf: ", Xf.shape, Xf)

norms:  torch.Size([2, 1, 1]) tensor([[[1.1186]],

        [[0.5542]]])
Xf:  torch.Size([2, 1, 3]) tensor([[[-0.0236,  0.8491,  0.5277]],

        [[ 0.0983, -0.9843, -0.1463]]])


In [9]:
# select final row: keep only the last sequence position, (b, seq, d) -> (b, d)
Y = Xf[:, -1, :]
print("Y: ", Y.shape, Y)
# in a normal GPT we'd first multiply by E.T and then select the row, because normal GPTs want
# to immediately select from a (b, v) tensor, whereas we need a (b, d) tensor to give us concept vectors

Y:  torch.Size([2, 3]) tensor([[-0.0236,  0.8491,  0.5277],
        [ 0.0983, -0.9843, -0.1463]])


In [10]:
# Y and every row of E have unit length, so cosine similarity reduces to a dot product.
# E.T has shape (d, v), so Z = Y @ E.T has shape (b, v).
Z = Y @ E.T
print("E.T: ", E.T.shape, E.T)
print("Z: ", Z.shape, Z)

# sanity check: every entry of Z should lie within [-1, 1]
print("Z Max:", Z.max(dim=1).values)
print("Z Min:", Z.min(dim=1).values)

E.T:  torch.Size([3, 12]) tensor([[-0.0084, -0.7062, -0.2553,  0.1616,  0.4082,  0.9826, -0.0743,  0.7308,
         -0.1637,  0.5391, -0.0080, -0.3469],
        [ 0.6013,  0.0558,  0.3534,  0.9819, -0.1934, -0.1837, -0.6807,  0.3404,
          0.9282, -0.4316,  0.9751, -0.9187],
        [ 0.7990,  0.7058,  0.9000,  0.0985,  0.8921, -0.0258, -0.7288,  0.5917,
          0.3341, -0.7232, -0.2215,  0.1890]])
Z:  torch.Size([2, 12]) tensor([[ 0.9324,  0.4364,  0.7810,  0.8819,  0.2969, -0.1928, -0.9608,  0.5840,
          0.9683, -0.7608,  0.7113, -0.6721],
        [-0.7096, -0.2276, -0.5047, -0.9651,  0.1000,  0.2812,  0.7694, -0.3498,
         -0.9787,  0.5837, -0.9282,  0.8425]])
Z Max: tensor([0.9683, 0.8425])
Z Min: tensor([-0.9608, -0.9787])


### The conditional parts

In [11]:
# G: highest similarity per sequence; H: index of the vocabulary token achieving it
G, H = Z.max(dim=1)
print("G: ", G.shape, G)
print("H: ", H.shape, H)

G:  torch.Size([2]) tensor([0.9683, 0.8425])
H:  torch.Size([2]) tensor([ 8, 11])


In [12]:
# similarity threshold a concept must exceed to become a token; valid range $-1\leq \gamma < 1$
gamma = 0.9

In [13]:
# A[k] = 1.0 when sequence k's best similarity is strictly above gamma (becomes a token), else 0.0
A = torch.gt(G, gamma).to(torch.float32)
print("A: ", A.dtype, A.shape, A)

A:  torch.float32 torch.Size([2]) tensor([1., 0.])


In [14]:
# broadcast the per-sequence mask across the vocabulary: (b,) -> (b, 1) -> (b, v)
A_unsqueeze = A[:, None]
print("A_unsqueeze: ", A_unsqueeze.shape, A_unsqueeze)
A_expand = A_unsqueeze.expand(-1, v)
print("A_expand: ", A_expand.shape, A_expand)

A_unsqueeze:  torch.Size([2, 1]) tensor([[1.],
        [0.]])
A_expand:  torch.Size([2, 12]) tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])


In [15]:
# complement of A_expand: 1 where the output stays a concept, 0 where it became a token
I = torch.ones(b, v, dtype=torch.float32)
print("I: ", I.dtype, I.shape, I)
A_prime_expand = I - A_expand
print("A_prime_expand: ", A_prime_expand.dtype, A_prime_expand.shape, A_prime_expand)

I:  torch.float32 torch.Size([2, 12]) tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
A_prime_expand:  torch.float32 torch.Size([2, 12]) tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])


### Only used for training

In [16]:
# the ideal next token of each sequence is the ground-truth token at position i+1
print("S_full: ", S_full.shape, S_full)
S_next_ideal = S_full[:, i + 1]
print("S_next_ideal: ", S_next_ideal.shape, S_next_ideal)

# one-hot matrix over the vocabulary, shape (b, v)
Q = F.one_hot(S_next_ideal, num_classes=v).float()
print("Q: ", Q.dtype, Q.shape, Q)

S_full:  torch.Size([2, 7]) tensor([[ 0,  1,  2,  3,  0,  4, 11],
        [ 5,  6,  7,  8,  9, 10, 11]])
S_next_ideal:  torch.Size([2]) tensor([1, 6])
Q:  torch.float32 torch.Size([2, 12]) tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])


In [17]:
# boolean mask of every non-target entry of the one-hot matrix
zero_mask_Q = Q == 0

# turn the 0/1 one-hot into a -1/+1 target: +1 at the true token, -1 everywhere else
Q = Q.masked_fill(zero_mask_Q, -1)
print("Q_prime: ", Q.shape, Q)

Q_prime:  torch.Size([2, 12]) tensor([[-1.,  1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
        [-1., -1., -1., -1., -1., -1.,  1., -1., -1., -1., -1., -1.]])


In [18]:
print("A_expand: ", A_expand.shape, A_expand)
# target for rows that became tokens: the signed one-hot; rows that stayed concepts are zeroed
Z_ideal = A_expand * Q
print("Z_ideal: ", Z_ideal.shape, Z_ideal)

A_expand:  torch.Size([2, 12]) tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
Z_ideal:  torch.Size([2, 12]) tensor([[-1.,  1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
        [-0., -0., -0., -0., -0., -0.,  0., -0., -0., -0., -0., -0.]])


In [19]:
# for rows that stayed concepts, keep the model's own similarities; token rows are zeroed
Z_concepts = torch.mul(Z, A_prime_expand)
print("Z_concepts: ", Z_concepts.shape, Z_concepts)

Z_concepts:  torch.Size([2, 12]) tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0000, -0.0000,  0.0000,
          0.0000, -0.0000,  0.0000, -0.0000],
        [-0.7096, -0.2276, -0.5047, -0.9651,  0.1000,  0.2812,  0.7694, -0.3498,
         -0.9787,  0.5837, -0.9282,  0.8425]])


In [20]:
# combined training target: signed one-hot rows for tokens, raw similarities for concepts
Z_train = torch.add(Z_concepts, Z_ideal)
print("Z_train: ", Z_train.shape, Z_train)

Z_train:  torch.Size([2, 12]) tensor([[-1.0000,  1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
         -1.0000, -1.0000, -1.0000, -1.0000],
        [-0.7096, -0.2276, -0.5047, -0.9651,  0.1000,  0.2812,  0.7694, -0.3498,
         -0.9787,  0.5837, -0.9282,  0.8425]])


### Only used for inference

In [21]:
# turn similarities into a probability distribution over the vocabulary (last dim, v)
P = Z.softmax(dim=-1)
print("P: ", P.shape, P)
# sanity check: sum of the first row of Z, and of P (should be 1)
print(Z[0].sum(dim=-1), P[0].sum(dim=-1))

P:  torch.Size([2, 12]) tensor([[0.1360, 0.0828, 0.1169, 0.1293, 0.0720, 0.0441, 0.0205, 0.0960, 0.1410,
         0.0250, 0.1090, 0.0273],
        [0.0395, 0.0640, 0.0485, 0.0306, 0.0888, 0.1064, 0.1733, 0.0566, 0.0302,
         0.1440, 0.0317, 0.1865]])
tensor(3.0058) tensor(1.)


In [22]:
# greedy decoding: index of the most probable vocabulary entry for each sequence
S_iplus1 = P.argmax(dim=1)
print("S_iplus1: ", S_iplus1.shape, S_iplus1)
# if you were only going to do greedy decoding you could skip making P and just create S_iplus1 from Z
# (softmax preserves the ordering); P is needed for probabilistic decoding, so we build it here anyway

S_iplus1:  torch.Size([2]) tensor([ 8, 11])


In [23]:
# Mask of the sequences whose output stayed a concept (A == 0)
zero_mask_A = A == 0

# Replace those predictions with the out-of-vocabulary index v. Index v is the
# all-zero epsilon row appended to E to form E_prime further down.
# clone() is required: a plain `S_iplus1_prime = S_iplus1` would alias the same
# tensor, and the masked assignment would silently overwrite S_iplus1 as well.
S_iplus1_prime = S_iplus1.clone()
S_iplus1_prime[zero_mask_A] = v

print("S_iplus1_prime: ", S_iplus1_prime.shape, S_iplus1_prime)
# so at the v value we're essentially going to be ignoring that prediction 

S_iplus1_prime:  torch.Size([2]) tensor([ 8, 12])


In [25]:
# advance the walkthrough to the next position.
# Re-running this cell advances i again; set i=0 manually to restart.
print("old i: ", i)
i = i + 1
print("new i: ", i)

old i:  0
new i:  1


In [26]:
# append an all-zero "epsilon" embedding as row v, so index v embeds to the zero vector
epsilon = torch.zeros(1, d)
print("epsilon: ", epsilon.shape, epsilon)
E_prime = torch.cat([E, epsilon])
print("E_prime: ", E_prime.shape, E_prime)
# notice this new empty embedding vector at the end

epsilon:  torch.Size([1, 3]) tensor([[0., 0., 0.]])
E_prime:  torch.Size([13, 3]) tensor([[-0.0084,  0.6013,  0.7990],
        [-0.7062,  0.0558,  0.7058],
        [-0.2553,  0.3534,  0.9000],
        [ 0.1616,  0.9819,  0.0985],
        [ 0.4082, -0.1934,  0.8921],
        [ 0.9826, -0.1837, -0.0258],
        [-0.0743, -0.6807, -0.7288],
        [ 0.7308,  0.3404,  0.5917],
        [-0.1637,  0.9282,  0.3341],
        [ 0.5391, -0.4316, -0.7232],
        [-0.0080,  0.9751, -0.2215],
        [-0.3469, -0.9187,  0.1890],
        [ 0.0000,  0.0000,  0.0000]])


In [30]:
# Embed each predicted token using E_prime; rows set to index v (sequences that
# stayed concepts) pick up the all-zero epsilon row.
Y_token = F.embedding(S_iplus1_prime,E_prime)
print("Y_token: ", Y_token.shape, Y_token)

Y_token:  torch.Size([2, 3]) tensor([[-0.1637,  0.9282,  0.3341],
        [ 0.0000,  0.0000,  0.0000]])


In [31]:
# Keep the raw concept vector only for sequences that did NOT become tokens.
# Every column of A_prime_expand (b, v) holds the same per-row value, so take a
# single column and broadcast it over d; slicing [:, 0:d] only works when d <= v.
Y_concept = Y*A_prime_expand[:, :1]
print("Y_concept: ", Y_concept.shape, Y_concept)

Y_concept:  torch.Size([2, 3]) tensor([[-0.0000,  0.0000,  0.0000],
        [ 0.0983, -0.9843, -0.1463]])


In [32]:
# next input vector: the token embedding where the output became a token, otherwise the raw concept
Y_inference = torch.add(Y_token, Y_concept)
print("Y_inference: ", Y_inference.shape, Y_inference)

Y_inference:  torch.Size([2, 3]) tensor([[-0.1637,  0.9282,  0.3341],
        [ 0.0983, -0.9843, -0.1463]])


In [33]:
# append the new vector as one more position along the sequence dimension (dim 1).
# Not idempotent: each re-run appends another position.
print("old X0: ", X.shape, X)
X = torch.cat([X, Y_inference[:, None, :]], dim=1)
print("new X0: ", X.shape, X)

old X0:  torch.Size([2, 1, 3]) tensor([[[-0.0084,  0.6013,  0.7990]],

        [[ 0.9826, -0.1837, -0.0258]]])
new X0:  torch.Size([2, 2, 3]) tensor([[[-0.0084,  0.6013,  0.7990],
         [-0.1637,  0.9282,  0.3341]],

        [[ 0.9826, -0.1837, -0.0258],
         [ 0.0983, -0.9843, -0.1463]]])


In [34]:
# NOTE(review): this cell is fully commented out. When enabled, it appends each
# sequence's predicted token id (S_iplus1) to the running token sequence S_i.
# The output shown below it is stale, from an earlier run. The repeated last
# column suggests the cell was executed more than once. The 12s (== v)
# presumably appear because S_iplus1 had been modified in place at that point.
#print("the old S_i: ", S_i.shape, S_i)
#S_iplus1_unsqueeze = S_iplus1.unsqueeze(dim=1)
#S_i = torch.concat((S_i,S_iplus1_unsqueeze), dim=1)
#print("the new S_i: ", S_i.shape, S_i)

the old S_i:  torch.Size([2, 2]) tensor([[ 0,  8],
        [ 5, 12]])
the new S_i:  torch.Size([2, 3]) tensor([[ 0,  8,  8],
        [ 5, 12, 12]])


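## Change in gamma over time

Here $\gamma_{j,e}$ is the threshold at token position $j$ during epoch $e$ ("0" = first, "f" = final). The starting threshold $\gamma_{0,e}$ moves linearly across epochs from $\gamma_{0,0}$ toward $\gamma_{0,f}$. Within a sequence, $\gamma$ moves linearly toward $\gamma_{f,f}$ and is clamped so it never drops below $-1$.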

In [103]:
# gamma_{position,epoch} endpoints, where "0" = first and "f" = final
gamma_00, gamma_f0, gamma_0f, gamma_ff = -1, -1, 0.8, -1

In [104]:
# schedule demo sizes (note: this overrides the earlier n = 7)
n, m = 6, 9  # sequence length, number of epochs

In [105]:
# per-epoch step of the starting threshold gamma_{0,epoch}
delta_gamma_m = (gamma_0f - gamma_00) / m
print(delta_gamma_m)

0.2


In [117]:
def _ordinal(k):
    """Return k with its English ordinal suffix, e.g. 1 -> '1st', 2 -> '2nd', 12 -> '12th'."""
    if 10 <= k % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(k % 10, "th")
    return f"{k}{suffix}"

# Print the gamma schedule. Across epochs the starting threshold moves from
# gamma_00 in steps of delta_gamma_m. Within a sequence it moves linearly toward
# gamma_ff over n positions, clamped at -1 (the minimum cosine similarity).
# `epoch` is used instead of `i` so the walkthrough's position index i is not overwritten.
for epoch in range(m):
    gamma_epoch = gamma_00 + epoch*delta_gamma_m
    print(f"This epoch, to become a token a concept needs cos similarity at least gamma_0,{epoch} = {gamma_epoch:.2f}")

    delta_gamma_n = (gamma_epoch - gamma_ff)/n
    # 2*n positions are printed, presumably to show the clamp at -1 beyond position n
    for j in range(2*n):
        gamma_j_epoch = max(-1, gamma_epoch - j*delta_gamma_n)
        print(f"\tAt the {_ordinal(j+1)} token in this sequence, gamma_{j},{epoch} = {gamma_j_epoch:.2f}")

This epoch, to become a token a concept needs cos similarity at least gamma_0,0 = -1.00
	At the 1'th token in this sequence, gamma_0,0 = -1.00
	At the 2'th token in this sequence, gamma_1,0 = -1.00
	At the 3'th token in this sequence, gamma_2,0 = -1.00
	At the 4'th token in this sequence, gamma_3,0 = -1.00
	At the 5'th token in this sequence, gamma_4,0 = -1.00
	At the 6'th token in this sequence, gamma_5,0 = -1.00
	At the 7'th token in this sequence, gamma_6,0 = -1.00
	At the 8'th token in this sequence, gamma_7,0 = -1.00
	At the 9'th token in this sequence, gamma_8,0 = -1.00
	At the 10'th token in this sequence, gamma_9,0 = -1.00
	At the 11'th token in this sequence, gamma_10,0 = -1.00
	At the 12'th token in this sequence, gamma_11,0 = -1.00
This epoch, to become a token a concept needs cos similarity at least gamma_0,1 = -0.80
	At the 1'th token in this sequence, gamma_0,1 = -0.80
	At the 2'th token in this sequence, gamma_1,1 = -0.83
	At the 3'th token in this sequence, gamma_2,1 = 

# Now making everything into loops

First we'll define a bunch of functions.

Then we'll put those functions into training & inference loops.

In [None]:
# setting hyperparameters: batch size, sequence length, embedding dim, vocab size
b, n, d, v = 2, 7, 3, 12

# defining the vocabulary (token string -> embedding row index)
E_dict = {
    token: index
    for index, token in enumerate([
        " I", " think", " there", "fore", " am", "Every",
        " cloud", " has", " a", " silver", " lining", "<endoftext>",
    ])
}

# and our sequence of tokens
S_full_text = [
    [' I', ' think', ' there', 'fore', ' I', ' am', '<endoftext>'],
    ['Every', ' cloud', ' has', ' a', ' silver', ' lining', '<endoftext>'],
]

For now I'll stick to writing the walkthrough in the Overleaf PDF, and I'll revisit a code version later if you don't think the walkthrough is clear.

In [24]:
# toy ±1 target matrix, converted to a float tensor
t = [[-1,1,-1,-1],[1,-1,-1,-1]]
tt = torch.tensor(t, dtype=torch.float32)
tt

tensor([[-1.,  1., -1., -1.],
        [ 1., -1., -1., -1.]])

In [26]:
# softmax of each ±1 row
tts = tt.softmax(dim=1)
tts

tensor([[0.0963, 0.7112, 0.0963, 0.0963],
        [0.7112, 0.0963, 0.0963, 0.0963]])

In [28]:
# same comparison for a plain 0/1 one-hot row: the resulting distribution is flatter
z = [[0,1,0,0],[1,0,0,0]]
zt = torch.tensor(z, dtype=torch.float32)
zts = zt.softmax(dim=1)
zts

tensor([[0.1749, 0.4754, 0.1749, 0.1749],
        [0.4754, 0.1749, 0.1749, 0.1749]])

In [30]:
# Example logits (2 rows, 3 classes) and the integer target class of each row
logits = torch.tensor([[0.92, 0.2, -0.1], [-0.56, 0.02, 0.3]])
target = torch.tensor([0, 2])

# cross-entropy == LogSoftmax followed by NLLLoss; compute it both ways
log_softmax_output = logits.log_softmax(dim=1)
nll_loss_output = F.nll_loss(log_softmax_output, target)
cross_entropy_loss = F.cross_entropy(logits, target)

print("LogSoftmax Output:", log_softmax_output)
print("NLLLoss Output:", nll_loss_output)
print("CrossEntropy Loss:", cross_entropy_loss)

LogSoftmax Output: tensor([[-0.6138, -1.3338, -1.6338],
        [-1.6388, -1.0588, -0.7788]])
NLLLoss Output: tensor(0.6963)
CrossEntropy Loss: tensor(0.6963)


In [32]:
def custom_loss(Z, S, A):
    """Masked mean cross-entropy.

    Args:
        Z: 2-D logits, one row per example (log-softmax is taken over dim 1).
        S: 1-D integer target class per row.
        A: per-row weights, typically a 0/1 mask; rows with A == 0 contribute nothing.

    Returns:
        Scalar tensor equal to -sum_k A[k] * log_softmax(Z)[k, S[k]] divided by
        max(sum(A), 1), so an all-zero mask yields 0 instead of a division by zero.
    """
    log_probs = Z.log_softmax(dim=1)
    # log-probability of each row's target class (equivalent to gather along dim 1)
    rows = torch.arange(S.shape[0], device=S.device)
    target_log_probs = log_probs[rows, S]
    weighted = target_log_probs * A
    return -weighted.sum() / A.sum().clamp(min=1)

# Example usage: mask out the second row so only row 0's loss counts
A = torch.tensor([1, 0], dtype=torch.int32) # Second entry is ignored

loss = custom_loss(Z=logits, S=target, A=A)
print("Custom Loss:", loss)

Custom Loss: tensor(0.6138)


In [34]:
# turn A into a (b, 1) column so it broadcasts across the class dimension of logits
A_prime = A.reshape(-1, 1)

# zero out every logit of the masked rows
masked_logits = torch.mul(logits, A_prime)

print(A_prime, masked_logits)

tensor([[1],
        [0]], dtype=torch.int32) tensor([[ 0.9200,  0.2000, -0.1000],
        [-0.0000,  0.0000,  0.0000]])
