In [1]:
import sys
import os

# Make the project's ../src directory importable from this notebook.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path on every execution.
_src_dir = os.path.abspath("../src")
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

In [2]:
import argparse
import torch
from torch import nn
import matplotlib.pyplot as plt
from model_550m import STU
import time
import random
from torch.nn import functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from lds import LDS

In [4]:
# Experiment configuration for fitting an LDS to a single STU layer.
layer_i = 2        # index used to pick the STU layer checkpoint file below
state_dim = 10000  # first LDS constructor argument
batch_size = 2
epochs = 4000
seq_len = 512
kx = 5             # fourth LDS constructor argument
lr = 0.0001
seed = 0

# Seed torch so the LDS initialisation and the random training inputs are
# repeatable across Restart-&-Run-All executions.
torch.manual_seed(seed)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the layer i weights.
# The checkpoint is a pickled full module (it is .eval()'d and called directly),
# not a state dict, so weights_only=False must be stated explicitly: PyTorch 2.6
# changed the default to True, which refuses to unpickle arbitrary modules.
# SECURITY: unpickling executes arbitrary code -- only load trusted checkpoints.
stu_layer_full = torch.load(
    f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt",
    map_location=device,
    weights_only=False,
)
stu_layer_full.eval()

# Initialize LDS model
# NOTE(review): 896 matches the last dimension of the training inputs; presumably
# the two 896 arguments are the LDS input/output widths -- confirm in lds.py.
lds = LDS(state_dim, 896, 896, kx).to(device)
optimizer = torch.optim.Adam(lds.parameters(), lr=lr)

# Training
lds_loss_values = []

best_loss = float('inf')

  stu_layer_full = torch.load(f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt", map_location=device)


In [5]:
# Fit the LDS to reproduce the frozen STU layer's outputs on random inputs.
for epoch in range(epochs):
    # Fresh Gaussian batch each step, moved to the device and cast to bfloat16
    # in a single .to() call.
    inputs = torch.randn(batch_size, seq_len, 896).to(device=device, dtype=torch.bfloat16)

    # The STU layer is only a fixed target: run it without autograd so no graph
    # is built for it (previously the graph was built and then discarded via detach()).
    with torch.no_grad():
        stu_outputs = stu_layer_full(inputs).to(device=device, dtype=torch.float)

    optimizer.zero_grad()
    loss = lds.compute_loss(inputs, stu_outputs)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(lds.parameters(), max_norm=1)
    optimizer.step()

    # Clamp every entry of lds.A back into [-1, 1] after the update.
    # NOTE(review): presumably A is the LDS state-transition parameter and the
    # clamp keeps the recurrence bounded -- confirm in lds.py.
    with torch.no_grad():
        lds.A.clamp_(min=-1, max=1)

    # Read the scalar once; it is reused for both the history and the log line.
    loss_value = loss.item()
    lds_loss_values.append(loss_value)

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value}")

Epoch 0, Loss: 0.792000949382782
Epoch 10, Loss: 0.7540223002433777
Epoch 20, Loss: 0.7185772657394409
Epoch 30, Loss: 0.663846492767334
Epoch 40, Loss: 0.6181823015213013
Epoch 50, Loss: 0.5651823878288269


KeyboardInterrupt: 

In [None]:
# Persist only the LDS state dict (not the module object) in the working directory.
torch.save(obj=lds.state_dict(), f="lds_10k_5.pth")

In [None]:
import torch

def train_lds(config, device=None):
    """Fit one LDS per STU layer by regressing onto that layer's outputs.

    For every layer index in ``config["layers"]`` the pickled STU layer is loaded
    from ``../stu_layers/stu_layer_{i}_550m_param_full.pt``, a fresh ``LDS`` is
    trained with Adam on random Gaussian inputs to match the layer's outputs, and
    the resulting state dict is written to the current working directory.

    Args:
        config: Mapping with the keys ``layers`` (iterable of layer indices),
            ``state_dim``, ``kx``, ``lr``, ``epochs``, ``batch_size`` and
            ``seq_len``. An optional ``seed`` key seeds torch before training.
        device: Device to train on. When omitted, a notebook-level ``device``
            global is used if one exists; otherwise CUDA is chosen when
            available, else CPU.

    Returns:
        None. Progress is printed every 10 epochs and one checkpoint is saved
        per layer.
    """
    if device is None:
        # Keep the previous behaviour (use the notebook's global) when it is
        # defined, but no longer fail on a fresh kernel where it is not.
        device = globals().get("device")
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seed = config.get("seed")
    if seed is not None:
        torch.manual_seed(seed)

    state_dim = config["state_dim"]
    kx = config["kx"]
    feature_dim = 896  # last dimension of the random inputs and both LDS widths

    # Derive the checkpoint tag from the config instead of hardcoding "10k_5",
    # so runs with other sizes are labelled correctly and do not overwrite
    # each other. state_dim=10000, kx=5 still yields "10k_5".
    if state_dim >= 1000 and state_dim % 1000 == 0:
        size_tag = f"{state_dim // 1000}k"
    else:
        size_tag = str(state_dim)

    for layer_i in config["layers"]:
        print(f"Training Layer {layer_i}...")

        # Load the layer i weights. The file holds a pickled full module, so
        # weights_only=False is required (the default became True in PyTorch 2.6).
        # SECURITY: unpickling executes arbitrary code -- trusted files only.
        stu_layer_full = torch.load(
            f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt",
            map_location=device,
            weights_only=False,
        )
        stu_layer_full.eval()

        # Initialize LDS model
        lds = LDS(state_dim, feature_dim, feature_dim, kx).to(device)
        optimizer = torch.optim.Adam(lds.parameters(), lr=config["lr"])

        for epoch in range(config["epochs"]):
            inputs = torch.randn(
                config["batch_size"], config["seq_len"], feature_dim
            ).to(device=device, dtype=torch.bfloat16)

            # The STU layer is a fixed target; skip building an autograd graph for it.
            with torch.no_grad():
                stu_outputs = stu_layer_full(inputs).to(device=device, dtype=torch.float)

            optimizer.zero_grad()
            loss = lds.compute_loss(inputs, stu_outputs)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(lds.parameters(), max_norm=1)
            optimizer.step()

            # Clamp every entry of lds.A back into [-1, 1] after the update.
            with torch.no_grad():
                lds.A.clamp_(min=-1, max=1)

            if epoch % 10 == 0:
                print(f"Layer {layer_i}, Epoch {epoch}, Loss: {loss.item()}")

        # Save the trained model
        torch.save(lds.state_dict(), f"lds_layer_{layer_i}_{size_tag}_{kx}.pth")

        print(f"Finished training Layer {layer_i}.")

# Hyperparameters shared by every per-layer LDS fit; the even-numbered
# layers 0 through 10 are trained one after another.
config = dict(
    layers=list(range(0, 11, 2)),
    state_dim=10_000,
    batch_size=2,
    epochs=4_000,
    seq_len=512,
    kx=5,
    lr=1e-4,
)

train_lds(config)
