In [None]:
# Experimenting with depth-aware gradient filtering on the MNIST dataset
# Omada (Team) 2 -- Grokfast experiments

In [None]:
from google.colab import drive
# Mount Google Drive so the shared project folder is reachable under /content/drive.
_drive_mount_point = '/content/drive'
drive.mount(_drive_mount_point)


Mounted at /content/drive


In [None]:
# from google.colab import files
# files.upload()


In [None]:
!ls /content/drive/MyDrive/PatRec_Project_Shared_Folder/

Algorithmic_code  MNIST_code	  __pycache__  requirements.txt
grokfast.py	  _Presentations  QM9_code     results


In [None]:
from google.colab import drive
# Mounting again is harmless: Colab just reports that the drive is already mounted.
_drive_mount_point = '/content/drive'
drive.mount(_drive_mount_point)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import sys
# Make the shared project folder importable (it holds grokfast.py). The guard keeps
# re-runs of this cell from stacking duplicate entries onto sys.path.
_project_root = '/content/drive/MyDrive/PatRec_Project_Shared_Folder'
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [None]:
!pip install -r /content/drive/MyDrive/PatRec_Project_Shared_Folder/requirements.txt

Collecting torch_geometric (from -r /content/drive/MyDrive/PatRec_Project_Shared_Folder/requirements.txt (line 4))
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m30.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import random
import time
import math
import argparse
from argparse import ArgumentParser
from collections import defaultdict
from itertools import islice
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torchvision
from typing import List, Optional, Dict, Literal
from collections import deque


In [None]:
from grokfast import gradfilter_ma, gradfilter_ema

In [None]:
def cycle(iterable):
    """Yield the items of `iterable` endlessly, starting a fresh pass each time it runs out.

    Nothing is cached (unlike ``itertools.cycle``): every pass re-iterates
    `iterable` itself, so an iterable that reorders on each iteration is
    re-iterated each time.
    """
    while True:
        yield from iterable

In [None]:
def compute_accuracy(network, dataset, device, N=2000, batch_size=50):
    """Computes accuracy of `network` on `dataset`.

    Scores exactly ``min(len(dataset), N)`` points, drawn in random order
    (the DataLoader shuffles), so with ``N < len(dataset)`` this is a random
    subset. A final partial batch is kept and trimmed, so points are not
    silently dropped when N is not a multiple of `batch_size`.

    Args:
        network: callable mapping a batch of inputs to per-class scores; the
            prediction is the ``argmax`` over dim 1.
        dataset: map-style dataset yielding ``(x, label)`` pairs.
        device: device inputs and labels are moved to before scoring.
        N: maximum number of points to evaluate.
        batch_size: evaluation batch size (capped at N).

    Returns:
        float: fraction of correctly classified points, or ``nan`` when there
        is nothing to evaluate (empty dataset or ``N <= 0``).
    """
    N = min(len(dataset), N)
    if N <= 0:
        # A DataLoader with batch_size=0 raises and 0/0 is undefined; report "no data".
        return float('nan')
    batch_size = min(batch_size, N)
    with torch.no_grad():
        dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        correct = 0
        total = 0
        for x, labels in dataset_loader:
            # Trim the batch that would overshoot N so exactly N points are scored.
            remaining = N - total
            x, labels = x[:remaining], labels[:remaining]
            logits = network(x.to(device))
            predicted_labels = torch.argmax(logits, dim=1)
            correct += (predicted_labels == labels.to(device)).sum().item()
            total += labels.size(0)
            if total >= N:
                break
        return correct / total


def compute_loss(network, dataset, loss_function, device, N=2000, batch_size=50):
    """Computes mean loss of `network` on `dataset`.

    Scores exactly ``min(len(dataset), N)`` points in random (shuffled) order,
    keeping and trimming a final partial batch. The loss is summed over points
    (``reduction='sum'``) and divided by the number of points.
    For 'MSE' the target is the one-hot encoding of the label over 10 classes.

    Args:
        network: callable mapping a batch of inputs to 10 output scores.
        dataset: map-style dataset yielding ``(x, label)`` pairs.
        loss_function: key of ``loss_function_dict``; 'CrossEntropy' or 'MSE'.
        device: device inputs/targets are moved to before scoring.
        N: maximum number of points to evaluate.
        batch_size: evaluation batch size (capped at N).

    Returns:
        float: mean per-point loss, or ``nan`` when there is nothing to evaluate.

    Raises:
        KeyError: if `loss_function` is not a key of ``loss_function_dict``.
        ValueError: if `loss_function` is in the dict but not handled here.
    """
    N = min(len(dataset), N)
    if N <= 0:
        return float('nan')
    batch_size = min(batch_size, N)
    with torch.no_grad():
        dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        loss_fn = loss_function_dict[loss_function](reduction='sum')
        if loss_function not in ('CrossEntropy', 'MSE'):
            # Without this, an unhandled loss would silently contribute 0 and report loss 0.
            raise ValueError(f"Unsupported loss function: {loss_function}")
        one_hots = torch.eye(10, 10).to(device)
        total = 0.0
        points = 0
        for x, labels in dataset_loader:
            # Trim the batch that would overshoot N so exactly N points are scored.
            remaining = N - points
            x, labels = x[:remaining], labels[:remaining]
            y = network(x.to(device))
            if loss_function == 'CrossEntropy':
                total += loss_fn(y, labels.to(device)).item()
            else:  # 'MSE'
                total += loss_fn(y, one_hots[labels]).item()
            points += len(labels)
            if points >= N:
                break
        return total / points



In [None]:
# String -> class lookup tables, so experiment arguments can select components by name.
optimizer_dict = dict(
    AdamW=torch.optim.AdamW,
    Adam=torch.optim.Adam,
    SGD=torch.optim.SGD,
)

activation_dict = dict(
    ReLU=nn.ReLU,
    Tanh=nn.Tanh,
    Sigmoid=nn.Sigmoid,
    GELU=nn.GELU,
)

loss_function_dict = dict(
    MSE=nn.MSELoss,
    CrossEntropy=nn.CrossEntropyLoss,
)

In [None]:
import os
# Directory for MNIST plots and result files. Create it (and any missing parents)
# so that saving figures/results into it does not fail on a fresh Drive folder.
results_dir = "/content/drive/MyDrive/PatRec_Project_Shared_Folder/results/MNIST"
os.makedirs(results_dir, exist_ok=True)


In [None]:

def gradfilter_with_depth_scaling(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    window_size: int = 100,
    alpha: float = 0.98,
    lamb_max: float = 3.0,
    lamb_min: float = 1.0,
    d_max: int = 12,  # Total number of stacked `layers.<i>` blocks
    filter_type: Literal['mean', 'sum'] = 'mean',
    warmup: bool = True,
    trigger: bool = False,
    embedding_layer_name: str = "embedding",
    final_and_output_layer_names: List[str] = ["ln_f", "head"],  # Default final and output layer names (never mutated)
) -> Dict[str, torch.Tensor]:
    """
    EMA gradient filter whose amplification factor depends on layer depth.

    For every trainable parameter that has a gradient:

        ema      <- alpha * ema + (1 - alpha) * grad    (ema = grad on first sight)
        grad     <- grad + lambda_d * ema
        lambda_d  = lamb_max - depth / (d_max + 1) * (lamb_max - lamb_min)

    Depth is derived from the parameter name:
      * contains `embedding_layer_name`                   -> 0
      * has a ``layers.<i>`` path component (any nesting)  -> i + 1
      * contains any of `final_and_output_layer_names`    -> d_max + 1
      * anything else                                     -> d_max
    Depth is clamped to ``[0, d_max + 1]`` so lambda_d always stays between
    `lamb_min` and `lamb_max`.

    Args:
        m (nn.Module): Model whose ``.grad`` tensors are modified in place.
        grads (Optional[Dict[str, Tensor]]): EMA state returned by the previous
            call, keyed by parameter name; None on the first call.
        window_size (int): Accepted for interface compatibility; unused.
        alpha (float): EMA momentum.
        lamb_max (float): Lambda at depth 0 (embedding parameters).
        lamb_min (float): Lambda at depth d_max + 1 (final/output parameters).
        d_max (int): Number of stacked ``layers.<i>`` blocks.
        filter_type (Literal['mean', 'sum']): Accepted for interface compatibility; unused.
        warmup (bool): Accepted for interface compatibility; unused.
        trigger (bool): Accepted for interface compatibility; unused.
        embedding_layer_name (str): Substring identifying embedding parameters.
        final_and_output_layer_names (List[str]): Substrings identifying final/output
            parameters; an empty/falsy value disables this rule.

    Returns:
        Dict[str, Tensor]: Updated EMA state; pass it back in on the next call.
    """
    if grads is None:
        grads = {}

    def _depth_of(name: str) -> int:
        """Map a parameter name to its depth using the rules in the docstring."""
        if embedding_layer_name in name:
            return 0
        parts = name.split(".")
        if "layers" in parts:
            idx = parts.index("layers")
            # Accept the block index wherever `layers` sits in the module path,
            # e.g. "layers.0.w" or "encoder.layers.3.w".
            if idx + 1 < len(parts) and parts[idx + 1].isdigit():
                return int(parts[idx + 1]) + 1
        if final_and_output_layer_names and any(layer_name in name for layer_name in final_and_output_layer_names):
            return d_max + 1
        return d_max

    for n, p in m.named_parameters():
        if not p.requires_grad or p.grad is None:
            continue

        # Clamp so an out-of-range block index cannot push lambda outside [lamb_min, lamb_max].
        depth = max(0, min(_depth_of(n), d_max + 1))
        lambda_d = lamb_max - (depth / (d_max + 1)) * (lamb_max - lamb_min)

        # EMA update (initialised with the current gradient on first sight).
        if n not in grads:
            grads[n] = p.grad.data.clone()
        else:
            grads[n] = grads[n] * alpha + p.grad.data * (1 - alpha)

        # Amplify the gradient by the depth-scaled EMA.
        p.grad.data = p.grad.data + grads[n] * lambda_d

    return grads

In [None]:
# Download MNIST into the shared Drive folder.
# Keep this cell enabled: `dataset_path` is used by `main()` below. Re-running it is
# cheap, since torchvision skips the download when the files are already present.

import torchvision

dataset_path = '/content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data'
train_dataset = torchvision.datasets.MNIST(
    root=dataset_path, train=True, transform=torchvision.transforms.ToTensor(), download=True
)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 14.7MB/s]


Extracting /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 484kB/s]


Extracting /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.19MB/s]


Extracting /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 5.70MB/s]

Extracting /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/drive/MyDrive/PatRec_Project_Shared_Folder/MNIST_code/MNIST_data/MNIST/raw






In [None]:
from collections import Counter
from sklearn.model_selection import train_test_split


def _build_mlp(depth, width, activation_fn):
    """Build ``Flatten -> depth x Linear`` mapping 784 inputs to 10 outputs.

    Hidden layers have `width` units and each is followed by ``activation_fn()``;
    the last Linear layer has no activation. Layers are constructed input-to-output.

    Raises:
        ValueError: if ``depth < 1`` (at least the output layer is required).
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    layers = [nn.Flatten()]
    for i in range(depth):
        is_first, is_last = i == 0, i == depth - 1
        # With depth == 1 the single layer is both first and last: 784 -> 10.
        layers.append(nn.Linear(784 if is_first else width, 10 if is_last else width))
        if not is_last:
            layers.append(activation_fn())
    return nn.Sequential(*layers)


def _save_progress(args, log_steps, train_accuracies, test_accuracies, train_losses, test_losses):
    """Rewrite the accuracy/loss plots and the ``mnist_<label>.pt`` results file in `results_dir`."""
    title = "MNIST Image Classification"

    plt.plot(log_steps, train_accuracies, label="train")
    plt.plot(log_steps, test_accuracies, label="val")
    plt.legend()
    plt.title(title)
    plt.xlabel("Optimization Steps")
    plt.ylabel("Accuracy")
    plt.xscale("log", base=10)
    plt.grid()
    plt.savefig(f"{results_dir}/mnist_acc_{args.label}.png", dpi=150)
    plt.close()

    plt.plot(log_steps, train_losses, label="train")
    plt.plot(log_steps, test_losses, label="val")
    plt.legend()
    plt.title(title)
    plt.xlabel("Optimization Steps")
    plt.ylabel(f"{args.loss_function} Loss")
    plt.xscale("log", base=10)
    plt.yscale("log", base=10)
    plt.grid()
    plt.savefig(f"{results_dir}/mnist_loss_{args.label}.png", dpi=150)
    plt.close()

    results_filename = os.path.join(results_dir, f"mnist_{args.label}.pt")
    torch.save({
        'its': log_steps,
        'train_acc': train_accuracies,
        'train_loss': train_losses,
        'val_acc': test_accuracies,
        'val_loss': test_losses,
    }, results_filename)


def main(args):
    """Train an MLP on a stratified MNIST subset, optionally with a Grokfast gradient filter.

    Train/test loss and accuracy are logged densely for the first steps and then
    every ``ceil(optimization_steps / 150)`` steps; after each logged step the
    accuracy/loss plots and the ``mnist_<label>.pt`` results file in `results_dir`
    are rewritten.

    Args:
        args: argparse namespace. Reads seed, train_points, optimization_steps,
            batch_size, loss_function, optimizer, weight_decay, lr,
            initialization_scale, download_directory, depth, width, activation,
            filter ('none' | 'ma' | 'ema'), alpha, window_size, lamb and label.

    Raises:
        ValueError: for an unsupported ``args.filter`` / ``args.loss_function``
            or ``args.depth < 1``.
    """
    # Fail fast on configuration errors, before any data is loaded.
    if args.filter not in ("none", "ma", "ema"):
        raise ValueError(f"Invalid gradient filter type `{args.filter}`")
    assert args.activation in activation_dict, f"Unsupported activation function: {args.activation}"
    assert args.optimizer in optimizer_dict, f"Unsupported optimizer choice: {args.optimizer}"
    assert args.loss_function in loss_function_dict

    log_freq = math.ceil(args.optimization_steps / 150)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    torch.set_default_dtype(dtype)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # load dataset
    train_dataset = torchvision.datasets.MNIST(
        root=dataset_path, train=True, transform=torchvision.transforms.ToTensor(), download=True
    )

    # Stratified random subset of `args.train_points` training images, keeping the
    # class proportions of the full training set.
    train_indices = list(range(len(train_dataset)))
    train_labels = train_dataset.targets.tolist()
    stratified_indices, _ = train_test_split(
        train_indices,
        train_size=args.train_points,
        stratify=train_labels,
        random_state=args.seed,
    )

    train_subset = torch.utils.data.Subset(train_dataset, stratified_indices)
    train_loader = torch.utils.data.DataLoader(train_subset, batch_size=args.batch_size, shuffle=True)

    # NOTE(review): the test split is stored under args.download_directory while the
    # train split uses dataset_path — confirm this is intended.
    test = torchvision.datasets.MNIST(root=args.download_directory, train=False,
        transform=torchvision.transforms.ToTensor(), download=True)

    activation_fn = activation_dict[args.activation]

    print(f"Stratified train subset created with {len(stratified_indices)} samples.")

    # Create model and scale its default initialisation.
    mlp = _build_mlp(args.depth, args.width, activation_fn).to(device)
    with torch.no_grad():
        for p in mlp.parameters():
            p.mul_(args.initialization_scale)
    nparams = sum(p.numel() for p in mlp.parameters() if p.requires_grad)
    print(f'Number of parameters: {nparams}')

    optimizer = optimizer_dict[args.optimizer](mlp.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    loss_fn = loss_function_dict[args.loss_function]()

    # Plots and results are written here at every logged step.
    os.makedirs(results_dir, exist_ok=True)

    train_losses, test_losses, train_accuracies, test_accuracies = [], [], [], []
    log_steps = []
    grads = None  # gradient-filter state carried across steps

    steps = 0
    one_hots = torch.eye(10, 10).to(device)
    with tqdm(total=args.optimization_steps, dynamic_ncols=True) as pbar:
        for x, labels in islice(cycle(train_loader), args.optimization_steps):
            # Log every step at first, every 10th until step 150, then every `log_freq` steps.
            do_log = (steps < 30) or (steps < 150 and steps % 10 == 0) or steps % log_freq == 0
            if do_log:
                train_losses.append(compute_loss(mlp, train_subset, args.loss_function, device, N=len(train_subset)))
                train_accuracies.append(compute_accuracy(mlp, train_subset, device, N=len(train_subset)))
                test_losses.append(compute_loss(mlp, test, args.loss_function, device, N=len(test)))
                test_accuracies.append(compute_accuracy(mlp, test, device, N=len(test)))
                log_steps.append(steps)

                pbar.set_description(
                    "L: {0:1.1e}|{1:1.1e}. A: {2:2.1f}%|{3:2.1f}%".format(
                        train_losses[-1],
                        test_losses[-1],
                        train_accuracies[-1] * 100,
                        test_accuracies[-1] * 100,
                    )
                )

            y = mlp(x.to(device))
            if args.loss_function == 'CrossEntropy':
                loss = loss_fn(y, labels.to(device))
            elif args.loss_function == 'MSE':
                loss = loss_fn(y, one_hots[labels])
            else:
                raise ValueError(f"Unsupported loss function: {args.loss_function}")

            optimizer.zero_grad()
            loss.backward()

            # Optional Grokfast gradient filter; `grads` carries its state between steps.
            if args.filter == "ma":
                grads = gradfilter_ma(mlp, grads=grads, window_size=args.window_size, lamb=args.lamb, trigger=False)
            elif args.filter == "ema":
                grads = gradfilter_ema(mlp, grads=grads, alpha=args.alpha, lamb=args.lamb)

            optimizer.step()

            steps += 1
            pbar.update(1)

            if do_log:
                _save_progress(args, log_steps, train_accuracies, test_accuracies, train_losses, test_losses)


In [None]:
import sys
from argparse import ArgumentParser

# The kernel process is started with Jupyter's own command-line flags; replace them
# with a lone, empty program name so argparse never tries to interpret them.
_empty_prog_name = ''
sys.argv = [_empty_prog_name]

In [None]:
# Load a saved results file in order to reuse it for the plots below.

results_filename = os.path.join(results_dir, "mnist_none_wd10e-02.pt")  # Replace with your actual filename
if not os.path.exists(results_filename):
    # Stop here with a clear error instead of a NameError on `data` in the plotting cells.
    raise FileNotFoundError(f"File {results_filename} not found!")
# The results file holds only a dict of plain number lists, so the restricted
# weights-only unpickler is sufficient and avoids running arbitrary pickled code.
data = torch.load(results_filename, weights_only=True)


  data = torch.load(results_filename)


In [None]:
# Output files for the accuracy and loss figures of the loaded (filter = none) run.
accuracy_plot_path, loss_plot_path = (
    os.path.join(results_dir, f"mnist_{metric}_none.png") for metric in ("accuracy", "loss")
)

In [None]:
# Plot Accuracy: train vs. validation accuracy of the loaded run on a log step axis.
fig, ax = plt.subplots()
for key, curve_label in (('train_acc', "Train Accuracy"), ('val_acc', "Validation Accuracy")):
    ax.plot(data['its'], data[key], label=curve_label)
ax.set_title("Accuracy over Optimization Steps")
ax.set_xlabel("Optimization Steps")
ax.set_ylabel("Accuracy")
ax.set_xscale("log")
ax.legend()
ax.grid()
fig.savefig(accuracy_plot_path, dpi=150)
plt.close(fig)  # release the figure so the next plot starts clean
print(f"Accuracy plot saved to {accuracy_plot_path}")

Accuracy plot saved to /content/drive/MyDrive/PatRec_Project_Shared_Folder/results/mnist_accuracy_none.png


In [None]:
# Plot Loss: train vs. validation loss of the loaded run on log-log axes.
fig, ax = plt.subplots()
for key, curve_label in (('train_loss', "Train Loss"), ('val_loss', "Validation Loss")):
    ax.plot(data['its'], data[key], label=curve_label)
ax.set_title("Loss over Optimization Steps")
ax.set_xlabel("Optimization Steps")
ax.set_ylabel("Loss")
ax.set_xscale("log")
ax.set_yscale("log")
ax.legend()
ax.grid()
fig.savefig(loss_plot_path, dpi=150)
plt.close(fig)
print(f"Loss plot saved to {loss_plot_path}")

Loss plot saved to /content/drive/MyDrive/PatRec_Project_Shared_Folder/results/mnist_loss_none.png


In [None]:
# Running the MA filter with lamb = 0.1 and weight_decay = 2.0, as recommended in the Grokfast GitHub repository (alpha = 0.98 is also passed but only affects the EMA filter)

In [None]:
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--label", default="")
    parser.add_argument("--seed", type=int, default=0)

    parser.add_argument("--train_points", type=int, default=1000)
    parser.add_argument("--optimization_steps", type=int, default=100000)
    parser.add_argument("--batch_size", type=int, default=200)
    parser.add_argument("--loss_function", type=str, default="MSE")
    parser.add_argument("--optimizer", type=str, default="AdamW")
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--initialization_scale", type=float, default=8.0)
    parser.add_argument("--download_directory", type=str, default=".")
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--width", type=int, default=200)
    parser.add_argument("--activation", type=str, default="ReLU")

    # Grokfast (only the filters that main() implements are offered)
    parser.add_argument("--filter", type=str, choices=["none", "ma", "ema"], default="ma")
    parser.add_argument("--alpha", type=float, default=0.99)
    parser.add_argument("--window_size", type=int, default=100)
    parser.add_argument("--lamb", type=float, default=5.0)
    # Overrides recommended in the Grokfast repository (alpha only affects the EMA filter).
    args = parser.parse_args([
        "--alpha", "0.98",
        "--weight_decay", "2.0",
        "--lamb", "0.1",
    ])

    # Result label: filter name + filter hyper-parameters + non-default optimiser settings.
    filter_str = ('_' if args.label != '' else '') + args.filter
    window_size_str = f'_w{args.window_size}'
    alpha_str = f'_a{args.alpha:.3f}'.replace('.', '')
    lamb_str = f'_l{args.lamb:.2f}'.replace('.', '')

    if args.filter == 'none':
        filter_suffix = ''
    elif args.filter == 'ma':
        filter_suffix = window_size_str + lamb_str
    elif args.filter == 'ema':
        filter_suffix = alpha_str + lamb_str
    else:
        raise ValueError(f"Unrecognized filter type {args.filter}")

    optim_suffix = ''
    if args.weight_decay != 0:
        optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '')
    if args.lr != 1e-3:
        lr_ratio = args.lr / 1e-3
        if lr_ratio >= 1 and abs(lr_ratio - round(lr_ratio)) < 1e-9:
            # Integer multiples of the base lr keep the "_lrx<k>" tag.
            optim_suffix = optim_suffix + f'_lrx{round(lr_ratio)}'
        else:
            # int() would truncate every lr < 1e-3 to "_lrx0" and let runs overwrite
            # each other's results; encode the exact ratio instead (e.g. "_lrx0p5").
            optim_suffix = optim_suffix + '_lrx' + f'{lr_ratio:g}'.replace('.', 'p')

    args.label = args.label + filter_str + filter_suffix + optim_suffix
    print(f'Experiment results saved under name: {args.label}')

    main(args)

Experiment results saved under name: ma_w100_l010_wd20e+00
Stratified train subset created with 1000 samples.
Number of parameters: 199210


  0%|          | 0/100000 [00:00<?, ?it/s]

  plt.xscale("log", base=10)
  plt.xscale("log", base=10)


In [None]:
# Running the EMA filter with the same arguments as before

In [None]:
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--label", default="")
    parser.add_argument("--seed", type=int, default=0)

    parser.add_argument("--train_points", type=int, default=1000)
    parser.add_argument("--optimization_steps", type=int, default=100000)
    parser.add_argument("--batch_size", type=int, default=200)
    parser.add_argument("--loss_function", type=str, default="MSE")
    parser.add_argument("--optimizer", type=str, default="AdamW")
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--initialization_scale", type=float, default=8.0)
    parser.add_argument("--download_directory", type=str, default=".")
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--width", type=int, default=200)
    parser.add_argument("--activation", type=str, default="ReLU")

    # Grokfast (only the filters that main() implements are offered)
    parser.add_argument("--filter", type=str, choices=["none", "ma", "ema"], default="ema")
    parser.add_argument("--alpha", type=float, default=0.99)
    parser.add_argument("--window_size", type=int, default=100)
    parser.add_argument("--lamb", type=float, default=5.0)
    # Same overrides as the MA run, so the two filters are directly comparable.
    args = parser.parse_args([
        "--alpha", "0.98",
        "--weight_decay", "2.0",
        "--lamb", "0.1",
    ])

    # Result label: filter name + filter hyper-parameters + non-default optimiser settings.
    filter_str = ('_' if args.label != '' else '') + args.filter
    window_size_str = f'_w{args.window_size}'
    alpha_str = f'_a{args.alpha:.3f}'.replace('.', '')
    lamb_str = f'_l{args.lamb:.2f}'.replace('.', '')

    if args.filter == 'none':
        filter_suffix = ''
    elif args.filter == 'ma':
        filter_suffix = window_size_str + lamb_str
    elif args.filter == 'ema':
        filter_suffix = alpha_str + lamb_str
    else:
        raise ValueError(f"Unrecognized filter type {args.filter}")

    optim_suffix = ''
    if args.weight_decay != 0:
        optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '')
    if args.lr != 1e-3:
        lr_ratio = args.lr / 1e-3
        if lr_ratio >= 1 and abs(lr_ratio - round(lr_ratio)) < 1e-9:
            # Integer multiples of the base lr keep the "_lrx<k>" tag.
            optim_suffix = optim_suffix + f'_lrx{round(lr_ratio)}'
        else:
            # int() would truncate every lr < 1e-3 to "_lrx0" and let runs overwrite
            # each other's results; encode the exact ratio instead (e.g. "_lrx0p5").
            optim_suffix = optim_suffix + '_lrx' + f'{lr_ratio:g}'.replace('.', 'p')

    args.label = args.label + filter_str + filter_suffix + optim_suffix
    print(f'Experiment results saved under name: {args.label}')

    main(args)

Experiment results saved under name: ema_a0980_l010_wd20e+00
Stratified train subset created with 1000 samples.
Number of parameters: 199210


  0%|          | 0/100000 [00:00<?, ?it/s]

  plt.xscale("log", base=10)
  plt.xscale("log", base=10)
