In [68]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

In [69]:
import argparse
import torch
from torch import nn
import matplotlib.pyplot as plt
from model_550m import STU
import time
import random
from torch.nn import functional as F

In [70]:
from lds import LDS

In [80]:
layer_i = 2
state_dim = 10000
batch_size = 2
epochs = 4000
seq_len = 512
kx = 5
lr = 0.0001

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the layer i weights
stu_layer_full = torch.load(f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt", map_location=device)
stu_layer_full.eval()

# Initialize LDS model
lds = LDS(state_dim, 896, 896, kx).to(device)
optimizer = torch.optim.Adam(lds.parameters(), lr=lr)

# Training
lds_loss_values = []

best_loss = float('inf')

  stu_layer_full = torch.load(f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt", map_location=device)


In [81]:
for epoch in range(epochs):
    inputs = torch.randn(batch_size, seq_len, 896).to(device).to(torch.bfloat16)
    stu_outputs = stu_layer_full(inputs).to(device)

    optimizer.zero_grad()
    loss = lds.compute_loss(inputs, stu_outputs.to(torch.float).detach())
    loss.backward()
    torch.nn.utils.clip_grad_norm_(lds.parameters(), max_norm=1)
    lds_loss_values.append(loss.item())
    optimizer.step()

    with torch.no_grad():
        lds.A.data.clamp_(max=1, min=-1)

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

Epoch 0, Loss: 0.7881022095680237
Epoch 10, Loss: 0.7615760564804077
Epoch 20, Loss: 0.7112730741500854
Epoch 30, Loss: 0.6704589128494263
Epoch 40, Loss: 0.6195156574249268
Epoch 50, Loss: 0.5625367760658264
Epoch 60, Loss: 0.5113675594329834
Epoch 70, Loss: 0.4532823860645294
Epoch 80, Loss: 0.4126841127872467
Epoch 90, Loss: 0.37279218435287476
Epoch 100, Loss: 0.33213090896606445
Epoch 110, Loss: 0.300788551568985
Epoch 120, Loss: 0.2772143483161926
Epoch 130, Loss: 0.2515886127948761
Epoch 140, Loss: 0.23406247794628143
Epoch 150, Loss: 0.21455128490924835
Epoch 160, Loss: 0.1989036202430725
Epoch 170, Loss: 0.18765760958194733
Epoch 180, Loss: 0.1771417111158371
Epoch 190, Loss: 0.16260582208633423
Epoch 200, Loss: 0.15747031569480896
Epoch 210, Loss: 0.14411146938800812
Epoch 220, Loss: 0.13924729824066162
Epoch 230, Loss: 0.1332768201828003
Epoch 240, Loss: 0.12550587952136993
Epoch 250, Loss: 0.11985667794942856
Epoch 260, Loss: 0.11149496585130692
Epoch 270, Loss: 0.106300599

KeyboardInterrupt: 

In [78]:
torch.save(lds.state_dict(), "lds_10k_5.pth")

In [None]:
import torch

def train_lds(config):
    
    for layer_i in config["layers"]:
        print(f"Training Layer {layer_i}...")

        # Load the layer i weights
        stu_layer_full = torch.load(f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt", map_location=device)
        stu_layer_full.eval()

        # Initialize LDS model
        lds = LDS(config["state_dim"], 896, 896, config["kx"]).to(device)
        optimizer = torch.optim.Adam(lds.parameters(), lr=config["lr"])

        # Training
        lds_loss_values = []
        best_loss = float('inf')

        for epoch in range(config["epochs"]):
            inputs = torch.randn(config["batch_size"], config["seq_len"], 896).to(device).to(torch.bfloat16)
            stu_outputs = stu_layer_full(inputs).to(device)

            optimizer.zero_grad()
            loss = lds.compute_loss(inputs, stu_outputs.to(torch.float).detach())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(lds.parameters(), max_norm=1)
            lds_loss_values.append(loss.item())
            optimizer.step()

            with torch.no_grad():
                lds.A.data.clamp_(max=1, min=-1)

            if epoch % 10 == 0:
                print(f"Layer {layer_i}, Epoch {epoch}, Loss: {loss.item()}")

        # Save the trained model
        torch.save(lds.state_dict(), f"lds_layer_{layer_i}_10k_5.pth")

        print(f"Finished training Layer {layer_i}.")

config = {
    "layers": [0, 2, 4, 6, 8, 10],
    "state_dim": 10000,
    "batch_size": 2,
    "epochs": 4000,
    "seq_len": 512,
    "kx": 5,
    "lr": 0.0001
}

train_lds(config)


Training Layer 0...


  stu_layer_full = torch.load(f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt", map_location=device)


Layer 0, Epoch 0, Loss: 0.2048506885766983
Layer 0, Epoch 10, Loss: 0.19542242586612701
Layer 0, Epoch 20, Loss: 0.18141958117485046
Layer 0, Epoch 30, Loss: 0.15748588740825653
Layer 0, Epoch 40, Loss: 0.13437943160533905
Layer 0, Epoch 50, Loss: 0.10905499756336212
Layer 0, Epoch 60, Loss: 0.08817411959171295
Layer 0, Epoch 70, Loss: 0.07206659764051437
Layer 0, Epoch 80, Loss: 0.059815146028995514
Layer 0, Epoch 90, Loss: 0.05120417848229408
Layer 0, Epoch 100, Loss: 0.043458081781864166
Layer 0, Epoch 110, Loss: 0.03683347627520561
Layer 0, Epoch 120, Loss: 0.03205214440822601
Layer 0, Epoch 130, Loss: 0.027714872732758522
Layer 0, Epoch 140, Loss: 0.024229789152741432
Layer 0, Epoch 150, Loss: 0.021853851154446602
Layer 0, Epoch 160, Loss: 0.01916656270623207
Layer 0, Epoch 170, Loss: 0.01711888238787651
Layer 0, Epoch 180, Loss: 0.01490543968975544
Layer 0, Epoch 190, Loss: 0.01391309592872858
Layer 0, Epoch 200, Loss: 0.011846772395074368
Layer 0, Epoch 210, Loss: 0.010839668102