In [1]:
import math
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from accelerate import Accelerator

from preprocessor import load_and_preprocess, decoding, process_data
from qwen import load_qwen

import numpy as np

import matplotlib.pyplot as plt

from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from preprocessor import get_dataset

import wandb
import joblib

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt

import torchtune

import gc

  from .autonotebook import tqdm as notebook_tqdm
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [2]:
# Global matplotlib font sizes (each is a base size bumped by 5 points).
SMALL_SIZE = 15 + 5
MEDIUM_SIZE = 20 + 5
BIGGER_SIZE = 25 + 5

# Default text size.
plt.rc('font', size=SMALL_SIZE)
# Axes title and axis-label sizes.
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=MEDIUM_SIZE)
# Tick-label size on both axes.
plt.rc(('xtick', 'ytick'), labelsize=SMALL_SIZE)
# Legend text size.
plt.rc('legend', fontsize=SMALL_SIZE)
# Figure-level title size.
plt.rc('figure', titlesize=BIGGER_SIZE)

In [3]:
# --- Run configuration -------------------------------------------------------
batch_size = 4          # batch size used for the train / validation DataLoaders
test_size = 0.2         # forwarded as `test_size` to load_and_preprocess
max_steps = 200         # cap on the number of optimiser steps
max_ctx_length = 512    # token window length for process_sequences (768 was also tried)
weight_decay = 0.01     # only referenced by a commented-out optimizer argument
points = 80             # number of ';'-separated time steps kept in each test prompt

# --- Hyperparameter grid -----------------------------------------------------
lora_ranks = [2, 4, 8]
learning_rates = [1e-5, 5e-5, 1e-4]

# Grid point selected for this run of the notebook.
rank, lr = lora_ranks[2], learning_rates[2]

In [4]:
class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank (LoRA) update.

    The forward pass returns ``original_linear(x) + (alpha / r) * x @ A.T @ B.T``
    where ``A`` has shape ``(r, in_features)`` and ``B`` has shape
    ``(out_features, r)``. ``B`` starts at zero, so a freshly wrapped layer
    produces exactly the same output as the layer it wraps.

    Args:
        original_linear: The layer to wrap. Its weight (and bias, if present)
            get ``requires_grad = False``.
        r: Rank of the update; must be a positive integer.
        alpha: Scaling numerator. ``None`` (the default) means "use ``r``",
            i.e. a scale of 1. An explicit ``0`` is honoured and disables the
            LoRA contribution.

    Raises:
        AssertionError: If ``original_linear`` is not an ``nn.Linear``.
        ValueError: If ``r`` is not positive.
    """

    def __init__(self, original_linear: nn.Linear, r: int, alpha: int = None):
        super().__init__()
        assert isinstance(original_linear, nn.Linear)
        if r <= 0:
            # r == 0 would otherwise only fail later, as a division by zero in forward().
            raise ValueError(f"LoRA rank r must be a positive integer, got {r!r}")

        # Freeze the wrapped layer; only A and B are meant to receive gradients.
        self.original_linear = original_linear
        self.original_linear.weight.requires_grad = False
        if self.original_linear.bias is not None:
            self.original_linear.bias.requires_grad = False

        in_dim = original_linear.in_features
        out_dim = original_linear.out_features
        self.r = r
        # Test against None rather than truthiness so that alpha=0 is not
        # silently replaced by r.
        self.alpha = r if alpha is None else alpha

        # Create the adapters on the same device as the wrapped weight.
        device = original_linear.weight.device
        self.A = nn.Parameter(torch.empty(r, in_dim, device=device))
        self.B = nn.Parameter(torch.zeros(out_dim, r, device=device))

        # Initialise A with He (Kaiming normal) initialisation; B stays zero.
        nn.init.kaiming_normal_(self.A, nonlinearity="linear")

    def forward(self, x):
        """Return the frozen layer's output plus the scaled low-rank update."""
        base_out = self.original_linear(x)
        lora_out = (x @ self.A.T) @ self.B.T
        return base_out + lora_out * (self.alpha / self.r)

In [5]:
# Bind the selected grid point to the names used by the training cell.
lora_rank, learning_rate = rank, lr
# LoRA alpha is tied to the rank (alpha = 2 * r).
lora_alpha = lora_rank * 2

# Load the Qwen model together with its tokenizer.
model, tokenizer = load_qwen()

In [6]:
# Read the Lotka-Volterra data and turn it into text sequences, split three ways.
train_texts, val_texts, test_texts = load_and_preprocess(
    "lotka_volterra_data.h5",
    test_size=test_size,
)

# Each split above is a `list[str]`: every string is a contiguous part of the
# time series rendered as text with the LLMTIME scheme.

# Modified tokenization with chunking
def process_sequences(texts, tokenizer, max_length=512, stride=256):
    """Tokenise each text and cut it into fixed-length, left-padded windows.

    Windows start every ``stride`` tokens and hold at most ``max_length``
    tokens. A window shorter than ``max_length`` is padded on the left with
    ``tokenizer.pad_token_id``.

    Returns:
        A 2-D tensor with one row of ``max_length`` token ids per window.
    """
    windows = []
    for text in texts:
        # Tokenise without special tokens; the tokenizer returns a batch of one.
        token_ids = tokenizer(
            text, return_tensors="pt", add_special_tokens=False, padding_side='left'
        ).input_ids[0]

        for start in range(0, len(token_ids), stride):
            window = token_ids[start : start + max_length]
            missing = max_length - len(window)
            if missing > 0:
                # Left-pad the short trailing window up to max_length.
                left_pad = torch.full((missing,), tokenizer.pad_token_id)
                window = torch.cat([left_pad, window])
            windows.append(window)
    return torch.stack(windows)


def process_data(texts, tokenizer, points=80, max_length=1200):
    """Build left-padded prompts from the first ``points`` time steps of each text.

    NOTE: this definition shadows the ``process_data`` imported from
    ``preprocessor`` at the top of the notebook.

    Args:
        texts: Sequence of strings whose time steps are separated by ``';'``.
        tokenizer: Callable tokenizer; it is asked for ``"pt"`` tensors padded
            on the left to ``max_length``.
        points: Number of leading ``';'``-separated time steps kept in the prompt.
        max_length: Length every prompt is padded to (previously hard-coded
            to 1200, which remains the default).

    Returns:
        ``(full_texts, prompt_ids)``: a NumPy array holding the unmodified
        texts, and a 2-D tensor with one row of prompt token ids per text.

    NOTE(review): no truncation is requested from the tokenizer, so a prompt
    longer than ``max_length`` tokens would presumably come back longer than
    the others and make ``torch.stack`` fail — confirm against the data.
    """
    given_input_ids = []
    for text in texts:
        # Keep only the first `points` time steps (none if points <= 0).
        given_text = ';'.join(text.split(';')[:max(points, 0)])
        encoding_given = tokenizer(
            given_text,
            return_tensors="pt",
            padding='max_length',
            padding_side='left',
            max_length=max_length,
        )
        given_input_ids.append(encoding_given.input_ids[0])
    return np.stack([text for text in texts]), torch.stack(given_input_ids)

def running_mse(prediction, actual):
    """Return the mean squared error over every prefix of ``prediction``.

    Element ``i`` of the result is the MSE between ``prediction[:i + 1]`` and
    ``actual[:i + 1]``. It is computed with a single cumulative sum instead of
    re-evaluating the MSE from scratch for every prefix.

    Args:
        prediction: Array-like of predictions, 1-D or ``(n_steps, n_outputs)``.
        actual: Array-like of targets. It may be longer than ``prediction``;
            only its first ``len(prediction)`` entries are used.

    Returns:
        A list of ``len(prediction)`` Python floats (empty for empty input).

    Raises:
        ValueError: If ``actual`` is shorter than ``prediction`` or the
            trailing dimensions differ.
    """
    pred = np.asarray(prediction, dtype=float)
    n_steps = len(pred)
    if n_steps == 0:
        return []

    # Only the first n_steps targets are ever compared against the prediction.
    target = np.asarray(actual, dtype=float)[:n_steps]
    if target.shape != pred.shape:
        raise ValueError(
            f"prediction has shape {pred.shape} but actual provides shape {target.shape}"
        )

    # Per-step squared error, averaged uniformly over any output dimension.
    per_step = ((pred - target) ** 2).reshape(n_steps, -1).mean(axis=1)
    return (np.cumsum(per_step) / np.arange(1, n_steps + 1)).tolist()

def evaluate_model(model, val_loader, step, max_batches=None):
    """Compute the mean per-batch loss of ``model`` on ``val_loader``.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        model: Model called as ``model(batch, labels=batch)`` whose output
            exposes a ``.loss`` attribute.
        val_loader: Loader yielding 1-tuples ``(batch,)``.
        step: Training step, used only in the printed message.
        max_batches: If given, evaluate at most this many batches.

    Returns:
        The average of the per-batch losses over the batches actually processed.

    Raises:
        ValueError: If no batch was processed (empty loader or
            ``max_batches == 0``); previously this surfaced as a bare
            ``ZeroDivisionError``.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0  # counted directly rather than inferred from len(val_loader)

    with torch.no_grad():
        for batch_idx, (batch,) in enumerate(tqdm(val_loader, desc="val set")):
            # Exit loop after processing max_batches
            if max_batches is not None and batch_idx >= max_batches:
                break
            outputs = model(batch, labels=batch)
            total_loss += outputs.loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError(
            "evaluate_model processed no batches: val_loader is empty or max_batches is 0"
        )

    avg_loss = total_loss / num_batches

    print(f'Loss on validation subset ({num_batches}/{len(val_loader)} batches) at step {step}: {avg_loss:.4f}')
    return avg_loss

# Tokenise every split. `max_ctx_length` is the window length fed to the model.
# Training windows overlap by half a window; validation windows do not overlap.
train_input_ids = process_sequences(
    train_texts, tokenizer, max_length=max_ctx_length, stride=max_ctx_length // 2
)
val_input_ids = process_sequences(
    val_texts, tokenizer, max_length=max_ctx_length, stride=max_ctx_length
)

# Test set: keep the full texts plus prompts built from the first `points` steps.
test_texts_all, test_input_ids_some = process_data(test_texts, tokenizer, points=points)

In [7]:

# Training windows are reshuffled every epoch.
train_dataset = TensorDataset(train_input_ids)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation is read in a fixed order so that every evaluation sees the same
# batches; with shuffling, batch composition (and any `max_batches` subset used
# by evaluate_model) changed from one evaluation to the next.
val_dataset = TensorDataset(val_input_ids)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Test prompts are served in order with the DataLoader's default batch size.
test_dataset = TensorDataset(test_input_ids_some)
test_loader = DataLoader(test_dataset, shuffle=False)

In [None]:
# Results of this run, keyed by (lora_rank, learning_rate).
grid_results = {}

print(f"\n{'='*50}")
print(f"Training with lora_rank={lora_rank}, learning_rate={learning_rate}")
print(f"{'='*50}\n")

# Wrap the query and value projections of every layer with LoRA adapters.
# This mutates `model` in place, so the cell cannot be re-run on the same model
# (LoRALinear asserts that it wraps a plain nn.Linear).
for layer in model.model.layers:
    layer.self_attn.q_proj = LoRALinear(layer.self_attn.q_proj, r=lora_rank, alpha=lora_alpha)
    # NOTE(review): v_proj gets twice the alpha of q_proj. Kept as is because the
    # recorded outputs were produced with it — confirm the asymmetry is intended.
    layer.self_attn.v_proj = LoRALinear(layer.self_attn.v_proj, r=lora_rank, alpha=2*lora_alpha)

# Optimise every parameter that still requires gradients.
# NOTE(review): whether the non-LoRA base weights are frozen depends on
# load_qwen(), which is not visible here — confirm.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=learning_rate,
    # weight_decay=weight_decay,
)

# Let accelerate prepare the model, optimizer and loaders.
accelerator = Accelerator()
model, optimizer, train_loader_local, val_loader_local, test_loader_local = accelerator.prepare(
    model, optimizer, train_loader, val_loader, test_loader
)

# Train the model (shortened training for grid search)
steps = 0
train_losses = []  # [loss, step] pairs, one per optimiser step
val_losses = []    # [validation loss, step] pairs, one per periodic evaluation
early_stop_steps = min(max_steps, 500)  # Reduce training for grid search

# Without this guard an empty loader would make the while-loop below spin forever.
if early_stop_steps > 0 and len(train_loader_local) == 0:
    raise ValueError("train_loader yields no batches; cannot train")

while steps < early_stop_steps:
    progress_bar = tqdm(train_loader_local, desc=f"Steps {steps}")
    for (batch,) in progress_bar:
        model.train()
        optimizer.zero_grad()
        outputs = model(batch, labels=batch)
        loss = outputs.loss
        # Read the scalar once; it is reused for logging and the progress bar.
        loss_value = loss.item()
        train_losses.append([loss_value, steps])
        accelerator.backward(loss)
        optimizer.step()

        # Periodic validation. It runs after optimizer.step() but before `steps`
        # is incremented, so the "step 0" evaluation already includes one update.
        if (steps % 50) == 0:
            avg_loss = evaluate_model(model, val_loader_local, steps)
            val_losses.append([avg_loss, steps])
            model.train()  # evaluate_model leaves the model in eval mode

        steps += 1
        progress_bar.set_postfix(loss=loss_value)

        if steps >= early_stop_steps:
            break

# Final evaluation
final_val_loss = evaluate_model(model, val_loader_local, steps)

# Store results
grid_results[(lora_rank, learning_rate)] = {
    "final_val_loss": final_val_loss,
    "train_losses": train_losses,
    "val_losses": val_losses,
}

# Drop the loop leftovers first: `layer`, `batch`, `outputs`, `loss` and
# `progress_bar` still reference model tensors, so `del model` alone would not
# release them. Assigning None is safe even if a name was never bound.
layer = batch = outputs = loss = progress_bar = None

del model
del tokenizer
del optimizer
del train_loader_local
del val_loader_local
del test_loader_local
del accelerator
del train_losses
del val_losses

# Collect reference cycles and hand cached CUDA blocks back to the driver.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


Training with lora_rank=8, learning_rate=0.0001



val set: 100%|██████████| 75/75 [01:34<00:00,  1.26s/it]
Steps 0:   0%|          | 1/1000 [01:40<27:59:03, 100.84s/it, loss=2.84]

Loss on validation subset (75/75 batches) at step 0: 3.6430


val set: 100%|██████████| 75/75 [01:35<00:00,  1.28s/it]/it, loss=0.936]
Steps 0:   5%|▌         | 51/1000 [07:10<8:51:54, 33.63s/it, loss=0.959]

Loss on validation subset (75/75 batches) at step 50: 0.7479


val set: 100%|██████████| 75/75 [01:35<00:00,  1.28s/it]s/it, loss=0.808]
Steps 0:  10%|█         | 101/1000 [12:38<8:18:40, 33.28s/it, loss=0.854]

Loss on validation subset (75/75 batches) at step 100: 0.6952


val set: 100%|██████████| 75/75 [01:34<00:00,  1.26s/it]s/it, loss=0.81] 
Steps 0:  15%|█▌        | 151/1000 [18:04<7:45:42, 32.91s/it, loss=0.401]

Loss on validation subset (75/75 batches) at step 150: 0.6709


Steps 0:  20%|█▉        | 199/1000 [21:55<1:28:14,  6.61s/it, loss=0.723]
val set: 100%|██████████| 75/75 [01:14<00:00,  1.00it/s]

Loss on validation subset (75/75 batches) at step 200: 0.6561





In [None]:
# Persist this configuration's results; the file name encodes the rank and learning rate.
results_path = f"../results/grid_results_{rank}_{lr}.joblib"
joblib.dump(grid_results, results_path)

['../results/grid_history_8_0.0001.joblib']

In [None]:

# # Print and save results
# print("\n\nGrid Search Results:")
# print("=====================")
# for params, results in grid_results.items():
#     print(f"lora_rank={params[0]}, learning_rate={params[1]}: validation loss = {results['final_val_loss']:.6f}")

# print(f"\nBest parameters: lora_rank={best_params[0]}, learning_rate={best_params[1]}, validation loss = {best_val_loss:.6f}")

# # Save grid search results
# joblib.dump(grid_results, f"../results/grid_search_results.pkl")

# # Create visualization of grid search results
# plt.figure(figsize=(12, 8))
# for params, results in grid_results.items():
#     plt.plot(np.arange(len(results['train_losses'])), 
#              results['train_losses'], 
#              label=f"rank={params[0]}, lr={params[1]}", 
#              alpha=0.7)

# plt.xlabel("Steps")
# plt.ylabel("Training Loss")
# plt.title("Training Loss by Hyperparameter Configuration")
# plt.legend()
# plt.savefig("../plots/lora_lr_grid_search_training_losses.png")
# plt.show()

# # Plot validation losses
# val_loss_data = {params: results["final_val_loss"] for params, results in grid_results.items()}
# params_labels = [f"rank={p[0]}, lr={p[1]}" for p in val_loss_data.keys()]
# val_losses = list(val_loss_data.values())

# plt.figure(figsize=(14, 6))
# plt.bar(params_labels, val_losses)
# plt.xlabel("Hyperparameters")
# plt.ylabel("Final Validation Loss")
# plt.title("Validation Loss by Hyperparameter Configuration")
# plt.xticks(rotation=45)
# plt.tight_layout()
# plt.savefig("../plots/lora_lr_grid_search_validation_losses.png")
# plt.show()

: 