In [None]:
import sys
import os

def running_in_colab():
    """Return True when the notebook appears to be executing inside Google Colab.

    Two signals are accepted: the ``google.colab`` module is already imported, or a
    ``/content`` path exists on the filesystem.
    """
    # NOTE(review): the '/content' fallback would also fire on any non-Colab machine
    # that happens to have that path — confirm it is still needed.
    if 'google.colab' in sys.modules:
        return True
    return os.path.exists('/content')

# GitHub coordinates of the repository that is cloned when running on Colab.
username = "giovanna-brod-zamojska"
repo = "federated-learning-project"
branch = "main"  # branch to check out

# When True, cloning prompts for a GitHub token (authenticated HTTPS clone).
is_private = True

def clone_repo_if_needed(exists_ok: bool, username: str, repository: str, is_private: bool, branch: str = None):
    """Clone ``github.com/<username>/<repository>`` into ``/content`` and install its Colab requirements.

    Does nothing but print a notice when not running in Google Colab.

    Args:
        exists_ok: If True and ``/content/<repository>/`` already exists, keep it and
            return immediately (requirements are NOT reinstalled in that case). If False,
            any existing copy is deleted and the repository is cloned again.
        username: GitHub account that owns the repository; also used as the HTTPS user
            name for private clones.
        repository: Repository name; also the name of the target folder under ``/content``.
        is_private: If True, prompt for a GitHub token via ``getpass`` and clone over
            authenticated HTTPS.
        branch: Branch to check out; ``None`` clones the repository's default branch.

    Raises:
        RuntimeError: If ``git clone`` or ``pip install`` exits with a non-zero status.
    """
    import shutil
    import subprocess

    colab_repo_path = f'/content/{repository}/'

    if not running_in_colab():
        print("Not running in Google Colab. Skipping repository cloning.")
        return

    if exists_ok and os.path.exists(colab_repo_path):
        print(f"Repository already exists at {colab_repo_path}")
        return

    def _run(args, label, secret=None):
        """Run ``args`` without a shell, print its combined output, raise on non-zero exit."""
        result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
        output = result.stdout or ""
        if secret:
            # git error messages may echo the clone URL; never print the token.
            output = output.replace(secret, "***")
        # Re-emit through print() so it shows up in the cell output.
        print(output, end="")
        if result.returncode != 0:
            raise RuntimeError(f"{label} failed with exit code {result.returncode}")

    # Start from a clean target folder (no shell involved, so the path cannot inject commands).
    print(f"Removing content of {colab_repo_path}")
    shutil.rmtree(colab_repo_path, ignore_errors=True)
    print("Current directory files and folders:", sorted(os.listdir(".")))

    print("Cloning GitHub repo...")
    token = None
    if is_private:
        from getpass import getpass

        # Prompt for a GitHub token (it must have read access to the repository).
        token = getpass('Enter GitHub token: ')
        # NOTE(review): git records the clone URL (token included) as the `origin`
        # remote in .git/config of the clone.
        url = f"https://{username}:{token}@github.com/{username}/{repository}.git"
    else:
        url = f"https://github.com/{username}/{repository}.git"

    clone_cmd = ["git", "clone"]
    if branch:
        clone_cmd += ["--branch", branch]
    # Clone into the exact folder checked above instead of whatever the cwd happens to be.
    clone_cmd += [url, colab_repo_path.rstrip("/")]
    _run(clone_cmd, "git clone", secret=token)

    # `sys.executable -m pip` installs into the interpreter running this kernel.
    requirements_path = os.path.join(colab_repo_path, "colab-requirements.txt")
    _run([sys.executable, "-m", "pip", "install", "-r", requirements_path], "pip install")



def setup_notebook(repo_root_name: str = "federated-learning-project"):
    """Prepend the project root to ``sys.path`` so project packages (e.g. ``src``) import.

    On Colab the root is ``/content/<repo_root_name>/`` and the current ``sys.path`` is
    printed first. Elsewhere the root is taken to be two directories above the current
    working directory. Nothing is inserted if the root is already on ``sys.path``.

    Args:
        repo_root_name: Folder name of the cloned repository under ``/content`` (Colab only).
    """
    import sys
    from pathlib import Path

    if not running_in_colab():
        # Assumes the notebook is launched from <project_root>/<dir>/<dir>/ — TODO confirm layout.
        root = Path().absolute().parent.parent
        if str(root) in sys.path:
            return
        sys.path.insert(0, str(root))
        print(f"Added {root} to Python path")
        return

    print("Sys.path: ", sys.path)
    root_dir = f'/content/{repo_root_name}/'
    if root_dir in sys.path:
        return
    sys.path.insert(0, root_dir)
    print(f"Added {root_dir} to sys.path")

        
# Clone the repository (Colab only) and make the project importable.
clone_repo_if_needed(branch=branch, exists_ok=True, username=username, repository=repo, is_private=is_private)

# Point sys.path at the same repository name that was just cloned, instead of relying on
# setup_notebook's hard-coded default matching `repo`.
setup_notebook(repo_root_name=repo)

    

In [None]:
import os
import json
import torch
import random
import numpy as np
from src.centralized_baseline.dataset import CIFAR100Dataset
from src.centralized_baseline.experiment_manager import ExperimentManager
from src.model_editing.centralized_baseline.trainer import ModelEditingTrainer

from itertools import product

# Local defaults for results and checkpoints; overridden below on Colab.
experiments_dir = "./output"
checkpoint_dir = "./checkpoints"

if running_in_colab():
    # On Colab, write results and checkpoints under the mounted Google Drive.
    from google.colab import drive
    drive.mount('/content/drive')

    experiments_dir = "/content/drive/MyDrive/-"  # define your Google Drive path here
    checkpoint_dir = os.path.join(experiments_dir, "checkpoints")


def set_seed(seed):
    """Seed PyTorch (CPU and all CUDA devices), ``random`` and NumPy, and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator.
    """
    print(f"Setting random seed to {seed}")

    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, random.seed, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _build_param_grid(grid: dict) -> list:
    """Expand ``{name: [values, ...]}`` into one dict per combination (Cartesian product, key order kept)."""
    names = list(grid)
    return [dict(zip(names, combo)) for combo in product(*grid.values())]


def _save_results(results, directory: str, filename: str) -> str:
    """Write ``results`` as indented JSON to ``directory/filename`` (creating the directory) and return the path."""
    os.makedirs(directory, exist_ok=True)
    file_path = os.path.join(directory, filename)
    with open(file_path, "w") as f:
        json.dump(results, f, indent=4)
    print(f"Results saved to {file_path}")
    return file_path


def run_experiments(seed: int):
    """Run the centralized model-editing grid search for one seed and save the results as JSON.

    The grid sweeps ``sparsity`` x ``rounds`` (3 x 5 = 15 configurations); every other
    hyper-parameter has a single value. Training is delegated to ``ExperimentManager`` with
    ``ModelEditingTrainer`` on ``CIFAR100Dataset``; checkpoints go to ``checkpoint_dir`` and
    the results JSON to ``experiments_dir``.

    NOTE(review): the output file name does not include ``seed``, so calling this for
    several seeds overwrites earlier results.

    Args:
        seed: Random seed applied globally and recorded in every configuration.
    """
    set_seed(seed)

    exp = "_EDIT_0"  # experiment version tag used in the wandb tags and the output file name

    grid_dict = {
        "batch_size": [128],
        "lr": [0.01],
        "weight_decay": [5e-4],
        "momentum": [0.9],
        "epochs": [20],
        "seed": [seed],
        "num_workers": [4],
        "accum_steps": [1],
        "optimizer_type": ["SparseSGD"],
        "augment": [None],
        "sparsity": [0.99, 0.95, 0.9],  # model editing
        "rounds": [1, 2, 3, 5, 10],  # model editing
        "num_batches": [None],  # model editing
        "strategy": ["train_least_important"],  # model editing
        "approximate_fisher": [True],  # model editing
    }
    param_grid = _build_param_grid(grid_dict)

    manager = ExperimentManager(
        param_grid=param_grid,
        use_wandb=False,
        project_name="federated-learning-project",  # wandb
        group_name="centralized-baseline-model-editing",  # wandb
        checkpoint_dir=checkpoint_dir,
    )
    _, _, results = manager.run(
        trainer_class=ModelEditingTrainer,
        dataset_class=CIFAR100Dataset,
        run_name="baseline-model-editing",  # wandb
        run_tags=[f"v{exp}", "baseline-model-editing"],  # wandb
        resume_training_from_config=None,
        model_editing=True,  # Enable model editing
    )
    print("Experiments completed.\n")

    filename = f"experiment_baseline_full_param_grid_search_v{exp}_MODEL_EDITING.json"
    _save_results(results, experiments_dir, filename)

try:
    run_experiments(seed=42)
except Exception:
    # Best-effort: show the traceback without failing the cell. Catching Exception (not a
    # bare except) lets KeyboardInterrupt / SystemExit propagate, so interrupting a long
    # training run really stops "Run All" instead of being reported and skipped past.
    import traceback
    print(traceback.format_exc())