In [7]:
# transformer plus a simple equations: learning some wave rules
import torch
import random
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn 
from torch.nn import functional as F
from wavefn import WaveFunction

batch_size = 8
block_size = 32
max_iter = 1000
epochs = 10
eval_interval = 500
learning_rate = 5e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_embed = 64
n_head = 16
n_layer = 16
dropout = 0.05
text = []

# tokens set as integers
tokens = ['F1', 'F2', 'F3', 'F4', '_', 'M', 'H', 'B']

# Let's make some waves! Not strictly needed, but repurposing the wave fn so keeping it for now
c = 1.0  # Wave speed
dx = 0.1  # Spatial step size
dt = 0.1  # Time step size
wave_fn = WaveFunction(c, dx, dt)

vocab_size=len(tokens)
stoi = { ch:i for i, ch in enumerate(tokens)}
itos = { i:ch for i, ch in enumerate(tokens)}
def enc(s, pad_length):
    encoded = [stoi[ch] for ch in s if ch in stoi]
    padding = [stoi['_']] * (pad_length - len(encoded))  # Padding token is '_'
    return encoded + padding[:max(0, pad_length - len(encoded))]  # Ensures the sequence is exactly pad_length long
dec = lambda l: ''.join([itos[i] for i in l[1:-1]])  # Skipping the first and last items ('s' and 'e')
# Define an appropriate size for your validation batch
val_batch_size = batch_size  

def generate_operation_sequence_with_objective(block_size):
    objective = random.choice(['M', 'H', 'B'])
    if objective == 'M':
        sequence = ['F1', '_', '_', 'F4']
    elif objective == 'H':
        sequence = ['_', 'F2', 'F3', '_']
    else:  # B includes all operations
        sequence = ['F1', 'F2', 'F3', 'F4']

    # Ensure the sequence length does not exceed block size - 2 for start/end tokens
    sequence = sequence[:block_size - 2]
    return [objective] + sequence

def apply_operations_sequence_to_wave(sequence, initial_wave):
    wave_fn = WaveFunction()
    current_profile = np.array(initial_wave, dtype=int)
    operations_map = {
        'F1': wave_fn.F1, 'F2': wave_fn.F2, 'F3': wave_fn.F3, 'F4': wave_fn.F4,
        '_': lambda x: x  # No-operation function returns the input as is
    }
    
    # Start applying transformations after the objective; skip '_'
    for op in sequence[1:]:  # Skip the 'objective' token
        if op in operations_map:  # Check if operation is defined in the map
            current_profile = operations_map[op](current_profile)
    
    return current_profile

def get_batch(batch_size, block_size):
    sequences = [generate_operation_sequence_with_objective(block_size) for _ in range(batch_size)]
    X, Y = [], []
    for seq in sequences:
        X.append(torch.tensor([stoi[seq[0]]], dtype=torch.long).unsqueeze(0).to(device))
        Y.append(torch.tensor([stoi[ch] for ch in seq[1:]], dtype=torch.long).to(device))
    X = torch.cat(X, dim=0)  # [batch_size, 1]
    Y = pad_sequence(Y, batch_first=True, padding_value=stoi['_'])  # Pad Y
    # print("X shape:", X.shape, "Y shape:", Y.shape)  # Debugging statement
    return X, Y

# single head attention
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed,head_size,bias=False)
        self.query = nn.Linear(n_embed,head_size,bias=False)
        self.value = nn.Linear(n_embed,head_size,bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size,block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2,-1) *C**-0.5 # scaled attention
        # wei = wei.masked_fill(self.tril[:T,:T]==0,float('-inf')) # decoder block
        wei = F.softmax(wei,dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei@v
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed,n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        out =  torch.cat([h(x) for h in self.heads], dim = -1)
        out = self.proj(out) # Projection si the linear transformation of the outcome of prev layer
        return out

class SinusoidalActivation(nn.Module):
    def forward(self, x):
        # return torch.sin(x)
        return x + torch.sin(x) ** 2

class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed,4* n_embed), 
            nn.GELU(),
            # SinusoidalActivation(),
            nn.Linear(4* n_embed, n_embed),
            nn.Dropout(dropout),
            )
        self

    def forward(self,x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed //n_head
        self.sa = MultiHeadAttention(n_head,head_size)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
    
    def forward(self,x):
        attn_output = self.sa(self.ln1(x))
        x = x + attn_output  # add & norm for attention
        ffwd_output = self.ffwd(self.ln2(x))
        x = x + ffwd_output  # add & norm for feedforward
        return x

# bigram language model
class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size  # Save block_size as an instance variable
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size,n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed,n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed,vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape  # T should be set to the number of tokens per sequence (currently it seems to be 1)
        idx = idx.repeat(1, 4)  # Assuming you want to handle 4 tokens per sequence
        T = 4  # Manually setting T to 4 tokens per sequence

        tok_emb = self.token_embedding_table(idx.squeeze(-1))  # [B, T, n_embed]
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)).unsqueeze(0).repeat(B, 1, 1)  # Repeat for batch

        x = tok_emb + pos_emb  # Combine embeddings

        x = self.blocks(x)  # Process through transformer blocks
        x = self.ln_f(x)  # Apply final layer normalization
        logits = self.lm_head(x)  # [B, T, vocab_size]
        logits = logits.view(B * T, self.vocab_size)  # Reshape to [batch_size * sequence_length, vocab_size]

        if targets is not None:
            targets = targets.view(-1)  # Flatten targets to [batch_size * sequence_length]
            loss = F.cross_entropy(logits, targets)
            return logits, loss
        return logits, None

    def generate(self, idx, max_new_tokens):
        outputs = idx
        for _ in range(max_new_tokens):
            idx_cond = outputs[:, -block_size:] if outputs.size(1) > block_size else outputs
            logits, loss = self(idx_cond)
            # Check if logits is two-dimensional and adjust accordingly
            if logits.dim() == 2:  # If only batch_size and num_classes, no sequence_length
                logits = logits.unsqueeze(1)  # Unsqueeze to simulate sequence_length of 1
            logits = logits[:, -1, :]  # Now this should work
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            outputs = torch.cat([outputs, idx_next], dim=1)
        return outputs


def apply_predicted_operations(predicted_operations, initial_wave):
    wave_fn = WaveFunction()
    current_profile = np.array(initial_wave, dtype=int)
    operations_map = {'F1': wave_fn.F1, 'F2': wave_fn.F2, 'F3': wave_fn.F3, 'F4': wave_fn.F4}
    
    for op in predicted_operations:
        if op in operations_map:
            current_profile = operations_map[op](current_profile)
    
    return current_profile

model = LanguageModel()
m = model.to(device)
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, alpha=0.99, eps=1e-08, weight_decay=0.01, momentum=0.5, centered=False)
# optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate, rho=0.9, eps=1e-06, weight_decay=0.01)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2, verbose=True)
loss = None  # Initialize loss variable outside the loop

for epoch in range(epochs):
    for iter in range(max_iter // epochs):  # Distribute iterations across epochs
        model.train()
        xb, yb = get_batch(batch_size, block_size)
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        max_norm = 1
        # Clip gradients to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        if iter % eval_interval == 0 and loss is not None:  # Validation logic
            model.eval()
            with torch.no_grad():
                xv, yv = get_batch(val_batch_size, block_size)
                val_logits, val_loss = model(xv, yv)
                print(f"Epoch {epoch}, Iteration {iter}: Training Loss = {loss.item()}, Validation Loss = {val_loss.item()}")
            model.train()

    scheduler.step(val_loss)  # Update the learning rate at the end of each epoch

torch.save(model, 'models/cat_wavefn_model.pth')

Epoch 0, Iteration 0: Training Loss = 2.08620285987854, Validation Loss = 1.720586895942688
Epoch 1, Iteration 0: Training Loss = 0.06775912642478943, Validation Loss = 0.06311294436454773
Epoch 2, Iteration 0: Training Loss = 0.018831947818398476, Validation Loss = 0.01729976199567318
Epoch 3, Iteration 0: Training Loss = 0.008490938693284988, Validation Loss = 0.00882628932595253
Epoch 4, Iteration 0: Training Loss = 0.005634819623082876, Validation Loss = 0.005472458899021149
Epoch 5, Iteration 0: Training Loss = 0.003730582073330879, Validation Loss = 0.0032127031590789557
Epoch 6, Iteration 0: Training Loss = 0.002684080507606268, Validation Loss = 0.0025460761971771717


In [None]:
import re
import numpy as np
import torch
import matplotlib.pyplot as plt
from wavefn import WaveFunction

# Setting up the environment
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pre-trained model and set it to evaluation mode
model = torch.load('models/cat_wavefn_model.pth', map_location=device)
model.eval()
np.random.seed(42)  # Seed for reproducibility
# Initialize the WaveFunction class
wave_fn = WaveFunction()
# Generate a random initial wave profile
input_length = 32
initial_wave = (np.random.rand(input_length) * 10).astype(int)
print(f"Initial Wave Profile: {initial_wave}")

# Prepare the model input (assuming the model takes numerical input)
initial_conditions = ['M']
input_tensor = torch.tensor([[stoi[ch] for ch in initial_conditions]], dtype=torch.long).to(device)

# Predict operations
with torch.no_grad():
    output = model.generate(input_tensor, max_new_tokens=5)
    transformer_output = dec(output[0].tolist())

print(f"Transformer output is: {transformer_output}")
def parse_operations(input_string):
    # Use a regular expression to find a continuous sequence of 'F' followed by a digit until another letter is encountered
    match = re.match(r'(F\d)+', input_string)
    if match:
        # Extract all 'F' followed by a digit from the matched group
        return re.findall(r'F\d', match.group())
    else:
        return []

predicted_operations = parse_operations(transformer_output)

# Compare with direct simulation from WaveFunction using an objective
objective_wave_output = wave_fn.simulate_wave_equation(initial_wave, objective='M')
print(f"Objective wave output is: {objective_wave_output}")
print(f"Predicted operations are: {predicted_operations}")
transformed_wave = wave_fn.custom_transform(initial_wave, predicted_operations)
print(f"Transformed wave is: {transformed_wave}")

# Plotting results
plt.figure(figsize=(14, 7))
plt.subplot(1, 3, 1)
plt.plot(initial_wave, label='Initial Wave', marker='o')
plt.title('Initial Wave')
plt.xlabel('Position')
plt.ylabel('Amplitude')
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(objective_wave_output, label='Objective-Based Transformation', marker='x')
plt.title('Objective-Based Transformation')
plt.xlabel('Position')
plt.ylabel('Amplitude')
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(transformed_wave, label='Transformer Predicted Transformation', marker='x')
plt.title('Transformer Predicted Transformation')
plt.xlabel('Position')
plt.ylabel('Amplitude')
plt.legend()

plt.tight_layout()
plt.show()


Initial Wave Profile: [3 9 7 5 1 1 0 8 6 7 0 9 8 2 1 1 3 5 4 2 6 1 2 3 4 7 1 5 5 0 6 1]


IndexError: too many indices for tensor of dimension 2