In [14]:
import argparse
import torch
from torch import nn
import matplotlib.pyplot as plt

# Define the LDS class
class LDS(nn.Module):
    """Linear dynamical system with an element-wise (diagonal) state transition.

    For every time step ``t`` the forward pass computes::

        h_t = A * h_{t-1} + u_t @ B
        y_t = h_t @ C

    where ``A`` is a vector applied element-wise to the state and ``h_{-1}``
    is the learned initial state ``h0``.

    Args:
        state_dim: Size of the hidden state ``h``.
        input_dim: Size of the last axis of the input sequence.
        output_dim: Size of the last axis of the produced sequence.
    """

    def __init__(self, state_dim, input_dim, output_dim):
        super().__init__()
        self.d_out = output_dim
        # NOTE: the order of the random draws below is kept stable so that a
        # seeded construction yields the same initial parameters as before.
        self.h0 = nn.Parameter(torch.randn(state_dim))
        init_A = torch.randn(state_dim)
        # Rescale so that every entry of A starts inside [-1, 1].
        self.A = nn.Parameter(init_A / torch.max(torch.abs(init_A)))
        self.B = nn.Parameter(torch.randn(input_dim, state_dim) / input_dim)
        self.C = nn.Parameter(torch.randn(state_dim, output_dim) / state_dim)
        # D and M are registered but not used by forward(); they are kept so
        # that state_dict keys stay compatible with existing checkpoints.
        self.D = nn.Parameter(torch.randn(input_dim, output_dim) / output_dim)
        self.M = nn.Parameter(torch.randn(output_dim, output_dim) / output_dim)

    def forward(self, inputs):
        """Run the recurrence over a batch of sequences.

        Args:
            inputs: Tensor of shape ``(batch, seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, output_dim)``. A zero-length
            sequence yields an empty ``(batch, 0, output_dim)`` tensor.
        """
        bsz, seq_len, _ = inputs.shape
        state_dim = self.h0.shape[0]
        h_t = self.h0.expand(bsz, state_dim).to(inputs.device)
        # A is loop-invariant, so flatten it once instead of once per step.
        a_diag = self.A.flatten()

        states = []
        for t in range(seq_len):
            u_t = inputs[:, t, :]
            h_t = a_diag * h_t + u_t @ self.B
            states.append(h_t)

        if states:
            all_h_t = torch.stack(states, dim=1)
        else:
            # torch.stack/cat reject an empty list; build the empty state
            # sequence explicitly so seq_len == 0 is handled gracefully.
            all_h_t = h_t.new_zeros((bsz, 0, state_dim))

        return torch.matmul(all_h_t, self.C)

    def compute_loss(self, inputs, targets):
        """Return the mean-squared error between ``self(inputs)`` and ``targets``."""
        mse_loss = nn.MSELoss()
        outputs = self(inputs)
        return mse_loss(outputs, targets)


# Command-line argument parsing (disabled in the notebook; the hard-coded Obj settings below are used instead)
# parser = argparse.ArgumentParser(description="Train LDS model")
# parser.add_argument("--layer_i", type=int, help="Layer index", default = 2)
# parser.add_argument("--state_dim", type=int, help="State dimension", default = 100)
# parser.add_argument("--batch_size", type=int,  help="Batch size", default = 5)
# parser.add_argument("--epochs", type=int,  help="Number of epochs", default = 100)
# parser.add_argument("--seq_len", type=int,help="Sequence length", default = 1000)
# parser.add_argument("--lr", type=float, help="Learning rate", default = 0.01)
# args = parser.parse_args()

class Obj:
    """Plain attribute container holding the run settings.

    Stands in for the ``argparse`` namespace of the commented-out parser
    above; each attribute mirrors one of its command-line options.
    """

    def __init__(self):
        # Index used to build the layer checkpoint filename.
        self.layer_i = 2
        # Hidden-state size of the LDS model.
        self.state_dim = 100
        # Number of random sequences per training step.
        self.batch_size = 5
        # Number of training iterations.
        self.epochs = 100
        # Length of each random input sequence.
        self.seq_len = 1000
        # Adam learning rate.
        self.lr = 0.02


args = Obj()


# Device setup: use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Feature width used for the random inputs and for the LDS input/output sizes.
MODEL_DIM = 768

# Load the layer i weights.
# NOTE(security): this file holds a fully pickled module, so it must be loaded
# with weights_only=False, which executes arbitrary pickle code. Only load
# checkpoints from a trusted source.
# map_location lets a checkpoint saved on GPU be opened on a CPU-only machine.
stu_layer_full = torch.load(
    f"./stu_layer_{args.layer_i}_500m_param_full.pt",
    map_location=device,
    weights_only=False,
)
stu_layer_full.to(device)

# Initialize LDS model
lds = LDS(args.state_dim, MODEL_DIM, MODEL_DIM).to(device)
optimizer = torch.optim.Adam(lds.parameters(), lr=args.lr)

# Training
lds_epochs = args.epochs
lds_loss_values = []

for epoch in range(lds_epochs):
    # Sample the batch directly on the target device (no host-to-device copy),
    # then cast to bfloat16 for the loaded layer.
    inputs = torch.randn(
        args.batch_size, args.seq_len, MODEL_DIM, device=device
    ).to(torch.bfloat16)

    # The loaded layer only provides regression targets. Running it without
    # autograd keeps loss.backward() from backpropagating through it and
    # from accumulating gradients on its weights every step.
    with torch.no_grad():
        stu_outputs = stu_layer_full(inputs).to(device)

    optimizer.zero_grad()
    loss = lds.compute_loss(inputs.to(stu_outputs.dtype), stu_outputs)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(lds.parameters(), max_norm=1)
    # Read the scalar once; each .item() call forces a device synchronisation.
    loss_value = loss.item()
    lds_loss_values.append(loss_value)
    optimizer.step()

    # Project the transition vector back into [-1, 1] after every update.
    with torch.no_grad():
        lds.A.clamp_(min=-1, max=1)

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value}")

# Save the loss progression to a file
with open("loss_progression.txt", "w") as f:
    for loss_value in lds_loss_values:
        f.write(f"{loss_value}\n")

# Plot the loss progression
fig, ax = plt.subplots()
ax.plot(range(len(lds_loss_values)), lds_loss_values, label="Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Loss Progression")
ax.legend()
fig.savefig("loss_progression.png")
plt.close(fig)

# Save the LDS model and optimizer state
torch.save({
    "model_state_dict": lds.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, "lds_model_and_optimizer.pt")

print("Training complete. Files saved.")

  stu_layer_full = torch.load(f"./stu_layer_{args.layer_i}_500m_param_full.pt")


Unable to import FlashFFTConv: No module named 'flashfftconv'. Falling back to PyTorch implementation.


  from .autonotebook import tqdm as notebook_tqdm


Epoch 0, Loss: 1.3745590448379517


KeyboardInterrupt: 

Available objects for config:
    AliasManager
    DisplayFormatter
    HistoryManager
    IPCompleter
    IPKernelApp
    LoggingMagics
    MagicsManager
    OSMagics
    PrefilterManager
    ScriptMagics
    StoreMagics
    ZMQInteractiveShell
