### Setup & Definition

In [1]:
import torch
import random
import torch.optim as optim
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Release cached, currently unused CUDA memory held by PyTorch's allocator.
torch.cuda.empty_cache()

# Show how many bytes of CUDA memory are currently occupied by tensors.
print(torch.cuda.memory_allocated())

cuda
0


In [2]:
class AUNNModel(nn.Module):
    """Residual MLP that maps integer position indices to output logits.

    Each index is turned into a fixed-width binary code by ``encode``, projected
    to ``hidden_dim`` by a linear input layer, passed through ``num_layers - 2``
    residual ``Linear -> SiLU`` blocks, and finally projected to ``output_dim``.

    Args:
        embedding_dim: Number of low-order bits of each index fed to the network.
        output_dim: Size of the output vector produced for each index.
        num_layers: Total layer count, including input and output layers (>= 2).
        hidden_dim: Width of the input projection and of every hidden block.
    """

    def __init__(
        self,
        embedding_dim:int,
        output_dim:int,
        num_layers:int,
        hidden_dim:int):

        assert num_layers >= 2, "Number of layers must be at least 2"

        super(AUNNModel, self).__init__()

        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # Input Layer
        self.input_layer = nn.Linear(self.embedding_dim, self.hidden_dim)

        # Hidden Layers (the input and output layers are counted in num_layers)
        self.layers = nn.ModuleList()
        for _ in range(self.num_layers - 2):
            self.layers.append(nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.SiLU()
            ))

        # Output Layer
        self.output_layer = nn.Linear(self.hidden_dim, self.output_dim)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Give every linear layer Kaiming-normal weights and zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def count_params(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def encode(self, x: torch.Tensor):
        """Binary-encode integer indices, least-significant bit first.

        Args:
            x: Integer tensor of any shape (a scalar tensor is accepted too).

        Returns:
            float32 tensor of shape ``x.shape + (embedding_dim,)``; entry ``k``
            on the last axis is bit ``k`` of the corresponding index.
        """
        bit_positions = torch.arange(self.embedding_dim, device=x.device)
        # The bit axis is appended last so the shift broadcasts for any input
        # rank; for 1-D input this equals the former unsqueeze(1).
        return ((x.unsqueeze(-1) >> bit_positions) & 1).to(torch.float32)

    def forward(self, indices):
        """Compute output logits for each index.

        Args:
            indices: Integer tensor of indices, any shape.

        Returns:
            Tensor of shape ``indices.shape + (output_dim,)``.
        """
        x = self.encode(indices)
        x = self.input_layer(x)
        # Functional SiLU avoids constructing a new module on every call.
        x = x + nn.functional.silu(x)

        for layer in self.layers:
            x = x + layer(x)  # MLP output with skip connection

        return self.output_layer(x)

In [3]:
# Define a function to save the model checkpoint
def save_checkpoint(model, params, optimizer, losses, filename="checkpoint.pth"):
    """Write model/optimizer state, the loss history and model hyperparameters to disk.

    Args:
        model: Module whose ``state_dict()`` is stored.
        params: Mapping that must contain 'embedding_dim', 'output_dim',
            'num_layers' and 'hidden_dim'; only those four entries are stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        losses: Loss history, stored as-is; may be empty.
        filename: Destination passed to ``torch.save``. When it is a path,
            missing parent directories are created first.

    Raises:
        AssertionError: If ``params`` lacks one of the required keys.
    """
    import os  # local import so this cell does not depend on the import cell

    keys = ['embedding_dim', 'output_dim', 'num_layers', 'hidden_dim']
    missing = [k for k in keys if k not in params]
    assert not missing, f"params is missing required keys: {missing}"

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses': losses,
        'params': {k: params[k] for k in keys},
    }

    # torch.save does not create directories; do it here so that paths such as
    # "sequence/checkpoint.pth" work on a fresh checkout.
    if isinstance(filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(checkpoint, filename)

    # An empty history must not raise after the file has already been written.
    if len(losses) > 0:
        print(f"Checkpoint saved with loss {losses[-1]:.4f}")
    else:
        print("Checkpoint saved (no losses recorded)")

In [4]:
def load_checkpoint(filename="checkpoint.pth"):
    """Rebuild the model and optimizer stored in a checkpoint file.

    Args:
        filename: Path of a checkpoint containing the keys 'params',
            'model_state_dict', 'optimizer_state_dict' and 'losses'.

    Returns:
        Tuple ``(model, optimizer, losses)``: an ``AUNNModel`` built from the
        stored hyperparameters and moved to the global ``device``, an ``AdamW``
        optimizer with the stored state loaded, and the stored loss history.
    """
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint written on a GPU can be restored on a CPU-only machine.
    checkpoint = torch.load(filename, map_location=device, weights_only=True)

    params = checkpoint['params']
    model = AUNNModel(**params)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    # The optimizer is created after the model is on its final device, then its
    # saved state (including param-group settings) is restored.
    optimizer = optim.AdamW(model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    losses = checkpoint['losses']

    # A checkpoint may carry an empty loss history; do not index into it then.
    if len(losses) > 0:
        print(f"Checkpoint loaded: loss {losses[-1]:.4f}")
    else:
        print("Checkpoint loaded (no losses recorded)")

    return model, optimizer, losses

### Train

In [5]:
#really long repeated pattern
# text = "abc" * 10_000  # Repeat the sequence to create a long string

# Corpus: a long, randomly ordered stream of 8-character segments,
# each one "|" followed by seven copies of a single letter.
num_sequences = 15_000
options = ['|' + letter * 7 for letter in 'abc']
num_repeats = num_sequences * len(options)
# NOTE: the three base segments are tiled num_repeats (= 45,000) times, so the
# corpus holds 135,000 segments / 1,080,000 characters.
options *= num_repeats
random.seed(42)  # fixed seed keeps the shuffled order reproducible
random.shuffle(options)
text = ''.join(options)
print(len(text))

1080000


In [6]:
# Character-level vocabulary: every distinct character of `text`, in sorted order.
vocab = sorted(set(text))
token_to_id = dict(zip(vocab, range(len(vocab))))
id_to_token = dict(enumerate(vocab))

# Encode the whole corpus as one long tensor of token ids on the training device.
token_ids = torch.tensor([token_to_id[char] for char in text], dtype=torch.long).to(device)
print(f'Training on {len(token_ids)} tokens.')

Training on 1080000 tokens.


In [7]:
# Hyperparameters
embedd_dim = 32    # number of index bits fed to the model
num_layers = 8     # total layers incl. input and output; AUNNModel asserts >= 2
hidden_dim = 256   # Size of hidden layers

# Build the model on the selected device and report its size
model = AUNNModel(
    embedding_dim=embedd_dim,
    output_dim=len(vocab),
    num_layers=num_layers,
    hidden_dim=hidden_dim,
).to(device)
print(f"Model has {model.count_params():,} parameters")

Model has 404,228 parameters


In [8]:
lr = 0.001
criterion = nn.CrossEntropyLoss()
# Pass lr explicitly: it used to be defined but never handed to the optimizer,
# so editing it had no effect (AdamW silently used its own default).
optimizer = optim.AdamW(model.parameters(), lr=lr)

In [9]:
from tqdm.auto import tqdm

model.train()

context_len = 1024
loss_thresh = 0.0001  # a window counts as fitted once its loss is below this
num_windows = len(token_ids) - context_len

# Slide a context_len-wide window over the corpus one position at a time and
# keep optimizing on each window until it is fitted.
for i in tqdm(range(num_windows), total=num_windows):

    start, end = i, i + context_len
    targets = token_ids[start:end]
    indices = torch.arange(start, end).to(device)

    attempt = 0
    while True:

        outputs = model(indices)
        loss = criterion(outputs, targets)
        loss_val = loss.item()

        # Decode the target and the prediction at the last position of the window
        predicted = outputs.argmax(dim=1)
        target_token = id_to_token[targets[-1].item()]
        predicted_token = id_to_token[predicted[-1].item()]
        last_token_correct = target_token == predicted_token

        # Log only the first evaluation of each window
        if attempt == 0:
            print(target_token, predicted_token, 'T' if last_token_correct else 'F', f"{loss_val:f}")
        attempt += 1

        # Move on once the newest token is right and the window loss is small enough
        if last_token_correct and loss_val < loss_thresh:
            break

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

  0%|          | 0/1078976 [00:00<?, ?it/s]

c | F 25.366972
| a F 0.002966
b c F 0.006612
b b T 0.000098
b b T 0.000098
b b T 0.000096
b b T 0.000095
b b T 0.000095
b b T 0.000094
| | T 0.000094
c a F 0.016525
c c T 0.000100
c c T 0.000100
c c T 0.000099
c c T 0.000099
c c T 0.000101
c c T 0.000105
| | T 0.000100
c c T 0.000127
c c T 0.000100
c c T 0.000099
c c T 0.000149
c c T 0.000099
c c T 0.000099
c c T 0.000099
| | T 0.000099
c c T 0.000253
c c T 0.000111
c c T 0.000100
c a F 0.001296
c c T 0.000099
c c T 0.000098
c c T 0.000102
| | T 0.000100
b b T 0.000099
b b T 0.000100
b b T 0.000099
b b T 0.000167
b b T 0.000199
b b T 0.000098
b b T 0.000098
| | T 0.000098
c c T 0.000553
c c T 0.000192
c c T 0.000631
c c T 0.000097
c c T 0.000097
c c T 0.000097
c c T 0.000098
| | T 0.000098
a b F 0.011699
a a T 0.000155
a a T 0.000100
a a T 0.000100
a a T 0.000100
a a T 0.000145
a a T 0.000099
| | T 0.000099
a c F 0.010244
a a T 0.000172
a a T 0.000328
a a T 0.000098
a a T 0.000098
a a T 0.000161
a a T 0.000109
| | T 0.000100
c b F 0.0

KeyboardInterrupt: 

In [12]:
# Hyperparameters stored with the checkpoint so the model can be rebuilt on load
params = dict(
    embedding_dim=embedd_dim,
    output_dim=len(vocab),
    num_layers=num_layers,
    hidden_dim=hidden_dim,
)
# The training loop keeps no loss history, so [0] is stored as a placeholder
save_checkpoint(model, params, optimizer, [0], filename="sequence/checkpoint_sequential.pth")

Checkpoint saved with loss 0.0000


In [19]:
# Rebuild model and optimizer from the saved checkpoint; this rebinds the
# in-memory `model` and `optimizer` names to the restored objects.
model, optimizer, losses = load_checkpoint(filename="sequence/checkpoint_sequential.pth")

Checkpoint loaded: loss 0.0000


In [17]:
# Generate Text Function
def generate_text(model, start_index, length):
    """Decode the model's highest-scoring token for a run of consecutive indices.

    Args:
        model: Network called on a 1-D tensor of indices; returns per-index scores.
        start_index: First index to query.
        length: Number of consecutive indices to query.

    Returns:
        The predicted tokens joined into one string; ids missing from
        ``id_to_token`` are rendered as "<UNK>".

    Note:
        Switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    positions = torch.arange(start_index, start_index + length).to(device)
    with torch.no_grad():
        predicted_ids = torch.max(model(positions), 1).indices
    return ''.join(
        id_to_token.get(token_id.item(), "<UNK>") for token_id in predicted_ids
    )

# Query the first 100 positions of the text
start_index = 0
generated_text = generate_text(model, start_index=start_index, length=100)
print("Generated Text:", generated_text, sep="\n")

# Straddle the end of the text: the last 100 positions inside it, followed by
# the first 100 positions past its end ("~~" marks the boundary in the output)
start_index = len(text) - 100
generated_text = generate_text(model, start_index=start_index, length=200)
before_end, after_end = generated_text[:100], generated_text[100:]
print("Generated Text:", f"{before_end}~~{after_end}", sep="\n")

Generated Text:
|aaaaaaababaaabaaaaaaaaaaaaaaaaabbbbbbaabbbbbbbbaaaaaaaaaaaaaaaaccccccccbabbbabbcbbbbbbbcccccccccccc
Generated Text:
aaaa|aaaaaaa|bbbbbbb|ccccccc|aababcb|aaaaaaa|bbbbbbb|aaaaaaa|ccccccc|bbbbbbb|ccccccc|aaaaaaa|bbbbbbb~~|bbbbbbb|bbbbbbb|bbbbbbb|aaaaaaa|bbbbbbb|bbbbbbb|caaaaac|bbbbbbb|ccccccc|ccccccc|bbbbbbb|aaaaaaa|ccc


In [20]:
N = len(text) - 1  # last index inside the text; N+1 is the first index past it

#see value before conditioning
generated_text = generate_text(model, start_index=N+1, length=150)
print("Original Text:")
print(generated_text) # the model's continuation past the end of the text, before conditioning
print("")

#conditioning the model
conditioning_targets = ['|','c']  # Desired tokens at N+1 and N+2, this also works if you use other letters as conditioning_targets
conditioning_positions = [N+1, N+2]
# conditioning_targets = ['c','a','b'] #show conditioning works on discontinuous tokens
# conditioning_positions = [N+2, N+7, N+12]

conditioning_target_indices = [token_to_id[token] for token in conditioning_targets]
targets_tensor = torch.tensor(conditioning_target_indices, dtype=torch.long).to(device)
positions_tensor = torch.tensor(conditioning_positions).to(device)

criterion = nn.CrossEntropyLoss()
conditioning_optimizer = optim.SGD(model.parameters(), lr=1e-3)
model.train()
step = 0

# Take gradient steps on the conditioning positions only, until the model
# reproduces exactly the requested tokens there.
while True:
    conditioning_optimizer.zero_grad()
    outputs = model(positions_tensor)
    loss = criterion(outputs, targets_tensor)
    loss.backward()
    conditioning_optimizer.step()
    # Count the step before logging; the counter used to stay at 0, so every
    # iteration was reported as "Conditioning Step 1".
    step += 1
    print(f"Conditioning Step {step}, Loss: {loss.item():.6f}")
    #get new outputs (evaluation only, so no autograd graph is needed)
    with torch.no_grad():
        outputs = model(positions_tensor)
    predicted = torch.argmax(outputs, dim=1)
    generated_tokens = [id_to_token.get(token_id.item(), "<UNK>") for token_id in predicted]
    print("".join(generated_tokens))
    if generated_tokens == conditioning_targets:
        break

#see value after conditioning
generated_text = generate_text(model, start_index=N+1, length=150)
print("")
print("Text after conditioning:")
print(generated_text) # positions after the conditioned ones were not targeted; compare them with the text printed before conditioning

Original Text:
|bbbbbbb|bbbbbbb|bbbbbbb|aaaaaaa|bbbbbbb|bbbbbbb|caaaaac|bbbbbbb|ccccccc|ccccccc|bbbbbbb|aaaaaaa|ccccccc|ccccccc|ccccccc|ccccccc|bbbbbbb|bbbbbbb|aaaaa

Conditioning Step 1, Loss: 2.787506
|c

Text after conditioning:
|ccccccc|ccccccc|bbbbbbb|aaaaaaa|bbbcbcc|bbbbbbb|cccaccc|bbbbbbb|ccccccc|ccccccc|bbbbbbb|aaaaaaa|ccccccc|ccccccc|ccccccc|ccccccc|bbbbbbb|bbbbbbb|aaaaa
