In [None]:
# Experimenting with depth-dependent gradient filtering on the algorithmic dataset
#   Omada 2 -- Grokfast experiments

In [None]:
# Mount Google Drive so the shared project folder is reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
# from google.colab import files
# files.upload()


In [None]:
!ls /content/drive/MyDrive/PatRec_Project_Shared_Folder/

 algorithmic				      Grokking_mnist_v1.ipynb   __pycache__
 grokfast.py				      Grokking_qm9_v1.ipynb     requirements.txt
'Grokking and how to accelerate it.gslides'   Grokking_qm9_v2.ipynb     results


In [None]:
# Put the shared Drive folder on the import path so that modules stored there
# (the folder listing shows grokfast.py) can be imported by later cells.
import sys
sys.path.append('/content/drive/MyDrive/PatRec_Project_Shared_Folder')

In [None]:
!pip install -r /content/drive/MyDrive/PatRec_Project_Shared_Folder/requirements.txt

Collecting torch_geometric (from -r /content/drive/MyDrive/PatRec_Project_Shared_Folder/requirements.txt (line 4))
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m65.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import copy
import itertools
import math
import os
import re
from argparse import ArgumentParser
from collections import deque
from itertools import permutations
from typing import Dict, List, Literal, Optional

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


In [None]:
from grokfast import gradfilter_ma, gradfilter_ema, gradfilter_with_depth_scaling

In [None]:

def gradfilter_with_depth_scaling(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    window_size: int = 100,
    alpha: float = 0.98,
    lamb_max: float = 3.0,
    lamb_min: float = 1.0,
    d_max: int = 12,  # Total number of transformer layers
    filter_type: Literal['mean', 'sum'] = 'mean',
    warmup: bool = True,
    trigger: bool = False,
    embedding_layer_name: str = "embedding",
    final_and_output_layer_names: List[str] = ["ln_f", "head"],  # never mutated here
) -> Dict[str, torch.Tensor]:
    """
    EMA gradient filter whose amplification factor decreases with layer depth.

    For every trainable parameter that currently has a gradient, an exponential
    moving average (EMA) of its gradient is maintained in ``grads`` and
    ``lambda_d * EMA`` is added to ``p.grad`` in place, where

        lambda_d = lamb_max - depth / (d_max + 1) * (lamb_max - lamb_min)

    Depth is derived from the parameter name:
      * name contains ``embedding_layer_name``          -> depth 0 (lambda = lamb_max)
      * name has a ``layers.<k>`` component              -> depth k + 1
      * name contains any of the final/output substrings -> depth d_max + 1 (lambda = lamb_min)
      * anything else                                    -> depth d_max
    The depth is clamped to ``[0, d_max + 1]`` so lambda always stays between
    ``lamb_min`` and ``lamb_max``, even if ``d_max`` is smaller than the real
    number of layers.

    Args:
        m (nn.Module): The model containing the parameters.
        grads (Optional[Dict[str, Tensor]]): EMA state returned by the previous
            call (one tensor per parameter name); ``None`` on the first call.
        window_size (int): Unused; accepted for signature compatibility.
        alpha (float): EMA momentum; larger values give a smoother, slower average.
        lamb_max (float): Lambda applied at depth 0.
        lamb_min (float): Lambda applied at depth ``d_max + 1``.
        d_max (int): Total depth (number of transformer layers).
        filter_type (Literal['mean', 'sum']): Unused; accepted for signature compatibility.
        warmup (bool): Unused; accepted for signature compatibility.
        trigger (bool): Unused; accepted for signature compatibility.
        embedding_layer_name (str): Substring identifying embedding layer parameters.
        final_and_output_layer_names (List[str]): Substrings identifying final/output
            layer parameters. ``None`` or an empty list disables this category.

    Returns:
        Dict[str, Tensor]: Updated EMA state; pass it back in on the next call.

    Raises:
        ValueError: If ``d_max`` is negative.
    """
    if d_max < 0:
        raise ValueError(f"d_max must be non-negative, got {d_max}")

    if grads is None:
        grads = {}

    final_names = final_and_output_layer_names or ()

    for n, p in m.named_parameters():
        if not p.requires_grad or p.grad is None:
            continue

        # Find a "layers.<index>" component anywhere in the dotted name, so that
        # wrapped models ("module.layers.0.attn...") are handled as well.
        layer_match = re.search(r"(?:^|\.)layers\.(\d+)(?:\.|$)", n)

        if embedding_layer_name in n:
            depth = 0  # Embedding layers are assigned depth 0
        elif layer_match is not None:
            depth = int(layer_match.group(1)) + 1  # Transformer layers start at depth 1
        elif any(layer_name in n for layer_name in final_names):
            depth = d_max + 1  # Final and output layers are d_max + 1
        else:
            depth = d_max  # Default depth for unclassified layers

        # Keep lambda inside [lamb_min, lamb_max] when d_max underestimates the depth.
        depth = min(max(depth, 0), d_max + 1)
        lambda_d = lamb_max - (depth / (d_max + 1)) * (lamb_max - lamb_min)

        grad = p.grad.data
        if n not in grads:
            grads[n] = grad.clone()  # First observation seeds the EMA
        else:
            grads[n] = grads[n] * alpha + grad * (1 - alpha)

        # Amplify the slow (low-frequency) gradient component, scaled by depth.
        p.grad.data = grad + grads[n] * lambda_d

    return grads


In [None]:

class Block(nn.Module):
    """One causal transformer block: LayerNorm -> self-attention -> MLP.

    The input is indexed sequence-first: the causal mask is sized from the
    first dimension of ``x`` and ``nn.MultiheadAttention`` is used with its
    default (non batch-first) layout.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        # Keep the registration order (ln_1, ln_2, attn, mlp): it fixes both the
        # state_dict layout and the order in which the RNG initialises weights.
        self.ln_1 = nn.LayerNorm(dim)
        self.ln_2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads)
        hidden_dim = dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        steps = x.shape[0]

        # Additive mask: -inf strictly above the diagonal, 0 elsewhere, so a
        # position can only attend to itself and to earlier positions.
        causal_mask = x.new_full((steps, steps), -float("Inf")).triu(diagonal=1)
        # fixes all 'nan' on 'mps' device
        causal_mask = causal_mask.masked_fill(torch.isnan(causal_mask), 0.0)

        normed = self.ln_1(x)
        attended, _ = self.attn(
            normed, normed, normed, attn_mask=causal_mask, need_weights=False
        )
        # NOTE: the residual is taken from the normalised input, not from the raw x.
        hidden = normed + attended
        return hidden + self.mlp(self.ln_2(hidden))


class Decoder(nn.Module):
    """Causal transformer decoder over token ids.

    ``forward`` takes integer token ids whose first dimension is the sequence
    position and returns logits with a trailing ``num_tokens`` dimension.
    """

    def __init__(self, dim=128, num_layers=2, num_heads=4, num_tokens=97, seq_len=5):
        super().__init__()
        self.token_embeddings = nn.Embedding(num_tokens, dim)
        self.position_embeddings = nn.Embedding(seq_len, dim)
        self.layers = nn.ModuleList(
            [Block(dim, num_heads) for _ in range(num_layers)]
        )
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_tokens, bias=False)

    def forward(self, x):
        # One position index per sequence step, broadcast over the remaining dims.
        positions = torch.arange(x.shape[0], device=x.device).unsqueeze(-1)
        tokens = self.token_embeddings(x)
        h = tokens + self.position_embeddings(positions).expand_as(tokens)

        for block in self.layers:
            h = block(h)

        return self.head(self.ln_f(h))




In [None]:
def multiplication_mod_p_data(p, eq_token, op_token):
    """Enumerate all equations x◦y = x*y (mod p) for 0 ≤ x < p, 0 < y < p.

    Returns a tensor with five rows -- x, op_token, y, eq_token, result -- and
    one column per (x, y) pair, i.e. p * (p - 1) columns.
    """
    operand_pairs = torch.cartesian_prod(torch.arange(p), torch.arange(1, p))
    x, y = operand_pairs.unbind(dim=1)

    # Constant rows carrying the operator and the equals-sign token ids.
    op = op_token * torch.ones_like(x)
    eq = eq_token * torch.ones_like(x)
    result = (x * y) % p

    # Each equation a◦b = c is spelled as five separate tokens: a, ◦, b, =, c.
    return torch.stack([x, op, y, eq, result])



In [None]:
import os

# Root folder on Google Drive for this notebook's outputs; created on first use.
# main() writes into the "pt", "acc" and "loss" sub-folders below it.
results_dir = "/content/drive/MyDrive/PatRec_Project_Shared_Folder/results/Algorithmic"
os.makedirs(results_dir, exist_ok=True)

In [None]:
# Demo to inspect the lambda assigned to each layer. For it to work, the filtering
# function must also store extra fields (lambda, depth, ...) in `grads`.

# # Dummy training step, just to showcase the depth-dependent lambdas of the network

# def main(args):
#     import torch
#     import torch.nn as nn
#     from collections import deque
#     from typing import Optional, Dict, Literal

#     torch.manual_seed(args.seed)

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     # Mock tokens for the task
#     eq_token = args.p
#     op_token = args.p + 1
#     depth=args.depth

#     # Initialize the model
#     model = Decoder(
#         dim=128, num_layers=depth, num_heads=4, num_tokens=args.p + 2, seq_len=5
#     ).to(device)
#     print(model)
#     print(f"Total number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

#     # Create dummy input
#     dummy_input = torch.randint(0, args.p + 2, (5, 16)).to(device)  # Sequence of length 5, batch size 16

#     # Create dummy target (last token prediction task)
#     dummy_target = torch.randint(0, args.p + 2, (16,)).to(device)  # Batch size 16

#     # Forward pass
#     logits = model(dummy_input[:-1])  # Skip the last input token for causal prediction

#     # Loss computation
#     loss_fn = nn.CrossEntropyLoss()
#     loss = loss_fn(logits[-1], dummy_target)  # Compute loss for the last token

#     # Backward pass to generate gradients
#     loss.backward()

#     # Inspect parameter names and gradients
#     # print("\nParameter Names and Gradients:")
#     # for name, param in model.named_parameters():
#     #     if param.requires_grad:
#     #         grad_status = "Gradient computed" if param.grad is not None else "No gradient"
#     #         print(f"{name}: {grad_status}, Shape: {param.shape}")

#     # Pass gradients to the gradfilter_with_depth_scaling function for inspection
#     grads = gradfilter_with_depth_scaling(
#         model,
#         grads=None,  # Start with no previous gradients
#         window_size=args.window_size,
#         lamb_max=args.lamb_max,
#         lamb_min=args.lamb_min,
#         d_max=depth,  # Total number of transformer layers
#         filter_type='mean',
#         warmup=False,  # Skip warmup to apply the filtering immediately
#         trigger=False,
#         embedding_layer_name="embedding",
#         final_and_output_layer_names=["ln_f", "head"]
#     )

#     for name, grad_metadata in grads.items():
#       depth = grad_metadata.get("depth", "Unknown")
#       lambda_d = grad_metadata.get("lambda", "Unknown")
#       print(f"Layer: {name}, Depth: {depth}, Lambda: {lambda_d}, Queue Length: {len(grad_metadata['queue'])}")


#     # Print the filtered gradients for inspection
#     # print("\nFiltered Gradients:")
#     # for name, grad_queue in grads.items():
#     #     print(f"{name}: Queue length = {len(grad_queue)}")


In [None]:
# if __name__ == "__main__":
#     parser = ArgumentParser()
#     parser.add_argument("--label", default="")
#     parser.add_argument("--seed", type=int, default=0)
#     parser.add_argument("--depth", type=int, default=2)
#     parser.add_argument("--p", type=int, default=97)
#     parser.add_argument("--budget", type=int, default=3e5)
#     parser.add_argument("--batch_size", type=int, default=256)
#     parser.add_argument("--lr", type=float, default=1e-3)
#     parser.add_argument("--beta1", type=float, default=0.9)
#     parser.add_argument("--beta2", type=float, default=0.98)
#     parser.add_argument("--weight_decay", type=float, default=0)
#     parser.add_argument("--optimizer", default="Adam")
#     parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none")
#     parser.add_argument("--alpha", type=float, default=0.99)
#     parser.add_argument("--window_size", type=int, default=50)
#     parser.add_argument("--lamb", type=float, default=5.0)
#     parser.add_argument("--lamb_max", type=float, default=5.0, help="Maximum lambda for depth scaling")
#     parser.add_argument("--lamb_min", type=float, default=1.0, help="Minimum lambda for depth scaling")
#     parser.add_argument("--two_stage", action='store_true')
#     parser.add_argument("--save_weights", action='store_true')
#     parser.add_argument("--dataset_fraction", type=float, default=1.0, help="Fraction of the dataset to use (0.0 to 1.0)")


#     # Parse known arguments to handle Jupyter conflicts
#     args, unknown = parser.parse_known_args()

#     # Modify label dynamically
#     filter_str = ('_' if args.label != '' else '') + args.filter
#     window_size_str = f'_w{args.window_size}'
#     alpha_str = f'_a{args.alpha:.3f}'.replace('.', '')
#     lamb_str = f'_l{int(args.lamb)}'

#     if args.filter == "none":
#         filter_suffix = ""
#     elif args.filter == "ma":
#         filter_suffix = window_size_str + lamb_str
#     elif args.filter == "ema":
#         filter_suffix = alpha_str + lamb_str
#     else:
#         raise ValueError(f"Unrecognized filter type {args.filter}")

#     optim_suffix = ""
#     if args.weight_decay != 0:
#         optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '')
#     if args.lr != 1e-3:
#         optim_suffix = optim_suffix + f'_lrx{int(args.lr / 1e-3)}'

#     args.label = args.label + filter_str + filter_suffix + optim_suffix

#    # Call main
#     main(args)

Decoder(
  (token_embeddings): Embedding(99, 128)
  (position_embeddings): Embedding(5, 128)
  (layers): ModuleList(
    (0-1): 2 x Block(
      (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
    )
  )
  (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=128, out_features=99, bias=False)
)
Total number of parameters: 422784
Layer: token_embeddings.weight, Depth: 0, Lambda: 5.0, Queue Length: 1
Layer: position_embeddings.weight, Depth: 0, Lambda: 5.0, Queue Length: 1
Layer: layers.0.ln_1.weight, Depth: 1, Lambda: 3.66666666666666

In [None]:
import copy
#           Proper Training now

def _save_curve(steps, train_values, val_values, ylabel, title, path):
    """Plot train/val curves against optimization steps (log x-axis) and save to ``path``."""
    fig, ax = plt.subplots()
    ax.plot(steps, train_values, label="train")
    ax.plot(steps, val_values, label="val")
    ax.legend()
    ax.set_title(title)
    ax.set_xlabel("Optimization Steps")
    ax.set_ylabel(ylabel)
    ax.set_xscale("log", base=10)
    ax.grid(True)
    fig.savefig(path, dpi=150)
    plt.close(fig)  # release the figure; many are created over a long run


def main(args):
    """Grid-search the filter strength (lambda) and EMA momentum (alpha).

    For every (lambda, alpha) pair a fresh ``Decoder`` is trained on modular
    multiplication with the gradient filter selected by ``args.filter``.
    Curves and metrics are written below ``results_dir`` and the pair with the
    highest final validation accuracy is printed at the end.

    Every combination is re-seeded with ``args.seed`` so all runs share the same
    initial weights and the same train/validation split.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``p``, ``depth``,
            ``budget``, ``batch_size``, ``lr``, ``beta1``, ``beta2``,
            ``weight_decay``, ``optimizer``, ``filter``, ``window_size``,
            ``two_stage``, ``save_weights`` and ``label``. ``args.lamb`` and
            ``args.alpha`` only affect the label; the values actually used for
            training come from the grid below.

    Raises:
        ValueError: If ``args.filter`` is not a supported filter type.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # tokens for <op> and <=>. It's not clear why <=> is needed at all since it
    # has no effect on the output, but we'll leave it in to best follow the
    # paper.
    eq_token = args.p
    op_token = args.p + 1
    depth = args.depth

    # Values to validate on
    lambdas = [1.0, 2.0, 5.0]
    alphas = [0.8, 0.9, 0.98]

    stop_threshold = 0.95
    stop_patience = 10000  # Number of consecutive epochs at >= stop_threshold validation accuracy

    # torch.save / savefig do not create missing directories.
    for subdir in ("pt", "acc", "loss"):
        os.makedirs(os.path.join(results_dir, subdir), exist_ok=True)

    # Tracked across the whole grid, so they must live outside the loop.
    best_val_acc = 0
    best_params = {}

    # Loop over all combinations of lambdas and alphas
    for lamb, alpha in itertools.product(lambdas, alphas):
        print(f"Testing lambda={lamb}, alpha={alpha}")
        run_tag = f"{args.label}_lambda_{lamb}_alpha_{alpha}"

        # Same init and same data split for every combination -> comparable runs.
        torch.manual_seed(args.seed)

        # Initialize the model
        model = Decoder(
            dim=128, num_layers=depth, num_heads=4, num_tokens=args.p + 2, seq_len=5
        ).to(device)

        nparams = sum([p.numel() for p in model.parameters() if p.requires_grad])
        print(model)
        print(f"Total number of parameters: {nparams}")

        # Split data into training and validation sets
        dataset = multiplication_mod_p_data(args.p, eq_token, op_token)
        split_idx = dataset.shape[1] // 2
        perm = torch.randperm(dataset.shape[1])
        train_data = dataset[:, perm[:split_idx]]
        valid_data = dataset[:, perm[split_idx:]]

        # Initialize optimizer and scheduler (linear warm-up over the first 10 updates)
        optimizer = getattr(torch.optim, args.optimizer)(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
            betas=(args.beta1, args.beta2),
        )
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda update: 1 if update > 10 else update / 10
        )

        steps_per_epoch = math.ceil(train_data.shape[1] / args.batch_size)
        num_epochs = int(args.budget) // steps_per_epoch
        its, train_acc, val_acc, train_loss, val_loss = [], [], [], [], []
        grads = None
        i = 0
        net_its, nets = [], []
        patience_counter = 0
        stop_early = False

        for e in tqdm(range(num_epochs)):
            # Shuffle training data
            train_data = train_data[:, torch.randperm(train_data.shape[1])]

            for split, is_train in [(train_data, True), (valid_data, False)]:
                model.train(is_train)
                total_loss = 0
                total_acc = 0

                for batch in torch.split(split, args.batch_size, dim=1):
                    batch = batch.to(device)
                    with torch.set_grad_enabled(is_train):
                        # Predict the last token (the result) from the first four.
                        logits = model(batch[:-1])
                        loss = F.cross_entropy(logits[-1], batch[-1])
                        total_loss += loss.item() * batch.shape[-1]

                    if is_train:
                        model.zero_grad()
                        loss.backward()

                        # Gradient filtering, driven by the grid values (lamb, alpha)
                        trigger = i < 500 if args.two_stage else False
                        if args.filter == "none":
                            pass
                        elif args.filter == "ma":
                            grads = gradfilter_ma(
                                model, grads=grads, window_size=args.window_size, lamb=lamb, trigger=trigger
                            )
                        elif args.filter == "ema":
                            grads = gradfilter_ema(
                                model, grads=grads, alpha=alpha, lamb=lamb
                            )
                        elif args.filter == "ema_depth":
                            grads = gradfilter_with_depth_scaling(
                                model,
                                grads=grads,
                                alpha=alpha,
                                lamb_min=0.5 * lamb,
                                lamb_max=2 * lamb,
                                d_max=depth,  # must match the model, not the default of 12
                            )
                        else:
                            raise ValueError(f"Invalid gradient filter type `{args.filter}`")

                        optimizer.step()
                        scheduler.step()
                        i += 1

                    acc = (logits[-1].argmax(-1) == batch[-1]).float().mean()
                    total_acc += acc.item() * batch.shape[-1]

                if is_train:
                    train_acc.append(total_acc / train_data.shape[-1])
                    train_loss.append(total_loss / train_data.shape[-1])
                    its.append(i)
                else:
                    val_acc.append(total_acc / valid_data.shape[-1])
                    val_loss.append(total_loss / valid_data.shape[-1])

            # Early stopping: validation accuracy must stay above the threshold
            # for `stop_patience` consecutive epochs.
            if len(val_acc) > 0:
                if val_acc[-1] >= stop_threshold:
                    patience_counter += 1
                    if patience_counter >= stop_patience:
                        print(f"Stopping early for lambda={lamb}, alpha={alpha} at epoch {e + 1}.")
                        stop_early = True
                else:
                    patience_counter = 0

            # Save periodically, and always on the final (or early-stopped) epoch.
            is_final_epoch = stop_early or e == num_epochs - 1
            if args.save_weights:
                do_save = e <= 500 or (e + 1) % 100 == 0 or is_final_epoch
            else:
                do_save = (e + 1) % 100 == 0 or is_final_epoch

            if do_save:
                results = {
                    "its": its,
                    "train_acc": train_acc,
                    "train_loss": train_loss,
                    "val_acc": val_acc,
                    "val_loss": val_loss,
                }
                # Snapshots are only kept (and written) when explicitly requested;
                # otherwise they would just pile up in memory.
                if args.save_weights:
                    net_its.append(e)
                    nets.append(copy.deepcopy(model.state_dict()))
                    results["net_its"] = net_its
                    results["net"] = nets
                # File name carries the grid values so runs do not overwrite each other.
                torch.save(results, f"{results_dir}/pt/res_{run_tag}.pt")

                # Rendering figures is slow; do it only when saving, not every epoch.
                steps = torch.arange(len(train_acc)).numpy() * steps_per_epoch + 1  # Add 1 to avoid zero
                _save_curve(
                    steps, train_acc, val_acc, "Accuracy",
                    f"Accuracy for Modular Multiplication (lambda={lamb}, alpha={alpha})",
                    f"{results_dir}/acc/acc_{run_tag}.png",
                )
                _save_curve(
                    steps, train_loss, val_loss, "Loss",
                    f"Loss for Modular Multiplication (lambda={lamb}, alpha={alpha})",
                    f"{results_dir}/loss/loss_{run_tag}.png",
                )

            if stop_early:
                break

        # Update best parameters
        if len(val_acc) > 0 and val_acc[-1] > best_val_acc:
            best_val_acc = val_acc[-1]
            best_params = {"lambda": lamb, "alpha": alpha}

    print(f"Best parameters: {best_params} with validation accuracy {best_val_acc}")

In [None]:
# Report which device is available: CUDA if torch can see a GPU, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
if __name__ == "__main__":
    # Command-line style configuration. `parser` stays at module level because
    # later cells reuse it to build other argument sets.
    parser = ArgumentParser()
    parser.add_argument("--label", default="depth_algo_trial")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--p", type=int, default=97)
    parser.add_argument("--depth", type=int, default=2)
    # argparse does not run `type` on non-string defaults, so the default must
    # already be an int (3e5 would leave args.budget as a float).
    parser.add_argument("--budget", type=int, default=300_000)
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.98)
    parser.add_argument("--weight_decay", type=float, default=0.005)
    parser.add_argument("--optimizer", default="Adam")
    parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir","ema_depth"], default="ema_depth")
    parser.add_argument("--alpha", type=float, default=0.98)
    parser.add_argument("--window_size", type=int, default=100)
    parser.add_argument("--lamb", type=float, default=2.0)
    parser.add_argument("--two_stage", action='store_true')
    parser.add_argument("--save_weights", action='store_true')
    parser.add_argument("--dataset_fraction", type=float, default=1.0, help="Fraction of the dataset to use (0.0 to 1.0)")

    # Parse known arguments to handle Jupyter conflicts (the kernel puts its own
    # options on sys.argv; they end up in `unknown` and are ignored).
    args, unknown = parser.parse_known_args()

    # Build a descriptive label: <label>_<filter><filter params><optimizer params>
    filter_str = ('_' if args.label != '' else '') + args.filter
    window_size_str = f'_w{args.window_size}'
    alpha_str = f'_a{args.alpha:.3f}'.replace('.', '')
    lamb_str = f'_l{int(args.lamb)}'  # note: int() truncates fractional lambdas

    if args.filter == "none":
        filter_suffix = ""
    elif args.filter == "ma":
        filter_suffix = window_size_str + lamb_str
    elif args.filter == "ema":
        filter_suffix = alpha_str + lamb_str
    elif args.filter == "ema_depth":
        filter_suffix = alpha_str + lamb_str
    else:
        # Reached for "fir", which the parser accepts but has no suffix rule here.
        raise ValueError(f"Unrecognized filter type {args.filter}")

    # Only non-default optimizer settings are encoded in the label.
    optim_suffix = ""
    if args.weight_decay != 0:
        optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '')
    if args.lr != 1e-3:
        optim_suffix = optim_suffix + f'_lrx{int(args.lr / 1e-3)}'

    args.label = args.label + filter_str + filter_suffix + optim_suffix
    print(f"Experiment results saved under name: {args.label}")

    # Call main
    main(args)

Experiment results saved under name: depth_algo_trial_ema_depth_a0980_l2_wd50e-03
Testing lambda=1.0, alpha=0.8
Decoder(
  (token_embeddings): Embedding(99, 128)
  (position_embeddings): Embedding(5, 128)
  (layers): ModuleList(
    (0-1): 2 x Block(
      (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
    )
  )
  (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=128, out_features=99, bias=False)
)
Total number of parameters: 422784


  1%|          | 242/30000 [03:55<8:03:18,  1.03it/s]


KeyboardInterrupt: 

In [None]:
# Second experiment: the plain EMA filter. Reuses the `parser` built in the
# previous cell; here the label is used verbatim (no filter/optimizer suffix).
ema_argv = [
    "--window_size", "100",
    "--lamb", "2.0",
    "--alpha", "0.98",
    "--weight_decay", "0.005",
    "--filter", "ema",
    "--batch_size", "512",
    "--label", "ema_same_as_paper",
]
args = parser.parse_args(ema_argv)

print(f"Experiment results saved under name: {args.label}")

print(f"Alg: {args.filter}")

# Call main
main(args)