In [68]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

In [69]:
import argparse
import torch
from torch import nn
import matplotlib.pyplot as plt
from model_550m import STU
import time
import random
from torch.nn import functional as F

In [70]:
from lds import LDS

In [71]:
# --- Experiment configuration -------------------------------------------------
layer_i = 0        # index of the saved STU layer loaded below
state_dim = 10000  # first argument passed to LDS(...)
batch_size = 2     # number of random sequences drawn per training step
epochs = 4000      # number of optimisation steps in the training loop
seq_len = 512      # length of each random input sequence
kx = 5             # last argument passed to LDS(...)
lr = 0.0001        # Adam learning rate

# Device setup: use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the layer i weights.
# The checkpoint is used directly as a callable module (.eval() here, a forward
# call in the training cell), so it is a fully pickled object rather than a
# state_dict. `weights_only=False` is passed explicitly because newer PyTorch
# releases default to `weights_only=True`, which refuses such checkpoints.
# NOTE(review): full unpickling executes arbitrary code -- only load
# checkpoint files that come from a trusted source.
stu_layer_full = torch.load(
    f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt",
    map_location=device,
    weights_only=False,
)
stu_layer_full.eval()

# Initialize LDS model.
# 896 is also the last dimension of the random inputs fed to both models in the
# training cell.
lds = LDS(state_dim, 896, 896, kx).to(device)
optimizer = torch.optim.Adam(lds.parameters(), lr=lr)

# Training bookkeeping
lds_loss_values = []  # per-step loss history, appended to by the training loop

best_loss = float('inf')

  stu_layer_full = torch.load(f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt", map_location=device)


In [75]:
# Fit `lds` to reproduce the outputs of the frozen STU layer on random inputs.
for epoch in range(epochs):
    # Sample the batch directly on the target device (no host -> device copy),
    # then cast to bfloat16 exactly as before.
    inputs = torch.randn(batch_size, seq_len, 896, device=device).to(torch.bfloat16)

    # The STU layer only provides regression targets and is not optimised here,
    # so run it without autograd instead of building a graph and detaching it.
    with torch.no_grad():
        stu_outputs = stu_layer_full(inputs).to(device)

    optimizer.zero_grad()
    loss = lds.compute_loss(inputs, stu_outputs.to(torch.float))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(lds.parameters(), max_norm=1)
    optimizer.step()

    # Read the scalar once and reuse it for both the history and the log line.
    loss_value = loss.item()
    lds_loss_values.append(loss_value)

    # Project A back into [-1, 1] after every update. An in-place op under
    # no_grad is the supported way to do this; `.data` bypasses autograd's
    # version tracking.
    with torch.no_grad():
        lds.A.clamp_(min=-1, max=1)

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value}")

Epoch 0, Loss: 0.20948727428913116
Epoch 10, Loss: 0.19298724830150604
Epoch 20, Loss: 0.17641928791999817
Epoch 30, Loss: 0.15829849243164062
Epoch 40, Loss: 0.13620591163635254
Epoch 50, Loss: 0.10767078399658203
Epoch 60, Loss: 0.0900392085313797
Epoch 70, Loss: 0.07257521152496338
Epoch 80, Loss: 0.060955245047807693
Epoch 90, Loss: 0.04956171661615372
Epoch 100, Loss: 0.044229187071323395
Epoch 110, Loss: 0.037870265543460846
Epoch 120, Loss: 0.03266969323158264
Epoch 130, Loss: 0.02808004803955555
Epoch 140, Loss: 0.025284023955464363
Epoch 150, Loss: 0.022732891142368317
Epoch 160, Loss: 0.019783174619078636
Epoch 170, Loss: 0.017606545239686966
Epoch 180, Loss: 0.01574600860476494
Epoch 190, Loss: 0.014383433386683464
Epoch 200, Loss: 0.012187255546450615
Epoch 210, Loss: 0.011149908415973186
Epoch 220, Loss: 0.01042140182107687
Epoch 230, Loss: 0.009279215708374977
Epoch 240, Loss: 0.008433567360043526
Epoch 250, Loss: 0.007872800342738628
Epoch 260, Loss: 0.007039038930088282

KeyboardInterrupt: 

In [78]:
# Persist only the LDS parameters (state_dict), not the whole module object.
torch.save(obj=lds.state_dict(), f="lds_10k_5.pth")

In [None]:
import torch

def train_lds(config):
    """Train one LDS per listed layer to imitate that layer's saved STU module.

    For every layer index in ``config["layers"]`` this loads the pickled STU
    module from ``../stu_layers/``, builds a fresh ``LDS``, and optimises it
    with Adam for ``config["epochs"]`` steps on random Gaussian inputs, using
    ``lds.compute_loss(inputs, stu_outputs)`` as the objective. After each
    step ``lds.A`` is clamped to ``[-1, 1]``. The resulting ``state_dict`` is
    written to ``lds_layer_{layer}_{size}_{kx}.pth`` in the working directory.

    Relies on the module-level names ``device`` and ``LDS``.

    Args:
        config: Mapping with the keys ``"layers"`` (iterable of layer
            indices), ``"state_dim"``, ``"kx"``, ``"lr"``, ``"epochs"``,
            ``"batch_size"`` and ``"seq_len"``.

    Returns:
        None. Progress is printed every 10 steps and one checkpoint file is
        written per layer.
    """
    for layer_i in config["layers"]:
        print(f"Training Layer {layer_i}...")

        # Load the layer i weights.
        # The checkpoint is a fully pickled module (it is called and .eval()'d
        # below), so `weights_only=False` is required; newer PyTorch releases
        # default to `weights_only=True` and reject such files.
        # NOTE(review): full unpickling executes arbitrary code -- only load
        # checkpoint files that come from a trusted source.
        stu_layer_full = torch.load(
            f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt",
            map_location=device,
            weights_only=False,
        )
        stu_layer_full.eval()

        # Initialize LDS model
        state_dim, kx = config["state_dim"], config["kx"]
        lds = LDS(state_dim, 896, 896, kx).to(device)
        optimizer = torch.optim.Adam(lds.parameters(), lr=config["lr"])

        # Training
        for epoch in range(config["epochs"]):
            # Sample directly on the target device, then cast to bfloat16.
            inputs = torch.randn(
                config["batch_size"], config["seq_len"], 896, device=device
            ).to(torch.bfloat16)

            # The STU layer only supplies targets; skip autograd bookkeeping
            # for it instead of building a graph and detaching afterwards.
            with torch.no_grad():
                stu_outputs = stu_layer_full(inputs).to(device)

            optimizer.zero_grad()
            loss = lds.compute_loss(inputs, stu_outputs.to(torch.float))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(lds.parameters(), max_norm=1)
            optimizer.step()

            # Project A back into [-1, 1] after every update.
            with torch.no_grad():
                lds.A.clamp_(min=-1, max=1)

            if epoch % 10 == 0:
                print(f"Layer {layer_i}, Epoch {epoch}, Loss: {loss.item()}")

        # Save the trained model.
        # The file name is derived from the configuration so that runs with a
        # different state_dim / kx do not overwrite each other under a
        # misleading label. state_dim=10000, kx=5 still yields "..._10k_5.pth".
        if state_dim >= 1000 and state_dim % 1000 == 0:
            size_tag = f"{state_dim // 1000}k"
        else:
            size_tag = str(state_dim)
        torch.save(lds.state_dict(), f"lds_layer_{layer_i}_{size_tag}_{kx}.pth")

        print(f"Finished training Layer {layer_i}.")

# Hyperparameters for the multi-layer run; every key is read by train_lds.
config = dict(
    layers=list(range(0, 12, 2)),  # [0, 2, 4, 6, 8, 10]
    state_dim=10_000,
    batch_size=2,
    epochs=4_000,
    seq_len=512,
    kx=5,
    lr=1e-4,
)

# Kick off training for all configured layers.
train_lds(config)


Training Layer 0...


  stu_layer_full = torch.load(f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt", map_location=device)


Layer 0, Epoch 0, Loss: 0.2048506885766983
Layer 0, Epoch 10, Loss: 0.19542242586612701
