In [19]:
import sys
sys.path.append('./pytorch-maml')
import warnings
warnings.filterwarnings('ignore')
import torch
import math
import os
import time
import json
import logging
import numpy as np
from scipy import stats
from torchmeta.utils.data import BatchMetaDataLoader
from maml.metalearners.maml import ModelAgnosticMetaLearning
from maml.datasets import get_benchmark_by_name
from maml.stiefelmanifold_update_parameters import stiefelmanifold_update_parameters, kernel_loss

In [20]:
import time
from functools import wraps
def estimate_execution_time(func):
    """Decorator that prints how long each call to ``func`` takes.

    The wrapped function's arguments and return value pass through unchanged.
    ``functools.wraps`` keeps the original ``__name__`` and docstring.

    Timing uses ``time.perf_counter()``, a monotonic high-resolution clock.
    ``time.time()`` follows the system clock, which can be adjusted mid-call,
    so it is not suitable for measuring intervals.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start
        print(f"Function '{func.__name__}' executed in: {elapsed_time:.6f} seconds")
        return result
    return wrapper
def seed_everything(seed: int):
    """Seed every RNG used by this notebook and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU and every
    CUDA device). It also configures cuDNN for run-to-run reproducibility.

    Args:
        seed: integer seed applied to every generator.
    """
    import random, os
    import numpy as np
    import torch
    random.seed(seed)
    # Only influences processes started later (e.g. spawned workers).
    # Hash randomisation of the current interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not just the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune per input shape and possibly choose
    # non-deterministic algorithms, which would defeat deterministic=True above.
    torch.backends.cudnn.benchmark = False

In [21]:
# Experiment settings. They are written to config.json and reloaded by the next
# cell, so every later cell reads them through a single `config` object.
# NOTE(review): "folder" and "save_path" look like placeholders; point them at
# the real dataset / results directories before running.
config_dict = dict(
    dataset="omniglot",  # [omniglot, miniimagenet]
    folder="/datapath",
    output_folder="models",
    num_ways=5,
    num_shots=5,
    num_shots_test=15,
    hidden_size=64,
    batch_size=2,
    num_steps=5,  # passed to the metalearner as num_adaptation_steps
    num_epochs=600,  # 50,
    num_batches=3,  # 100,
    step_size=0.01,
    first_order=False,
    meta_lr=0.001,
    num_workers=0,
    verbose=True,
    use_cuda=True,
    seed=42,
    save_path="/resultspath",
)

with open("config.json", "w") as config_file:
    json.dump(config_dict, config_file)

In [22]:
import json
class Config:
    """Attribute-style view over a flat configuration mapping.

    Every key of ``config_dict`` becomes an instance attribute, so
    ``Config({"seed": 42}).seed == 42``. A missing key raises ``AttributeError``,
    which keeps ``hasattr``/``getattr(obj, name, default)`` working.
    """
    def __init__(self, config_dict):
        self.__dict__.update(config_dict)
    def __getattr__(self, name):
        # Python calls __getattr__ only after normal lookup, which already checks
        # the instance __dict__, has failed. This lookup is therefore just a fallback.
        # `from None` drops the internal KeyError from the traceback so users see
        # only the AttributeError.
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(f"'Config' object has no attribute '{name}'") from None
# Reload the settings persisted by the config cell and expose them as
# attributes (config.seed, config.num_ways, ...).
with open("config.json", "r") as config_file:
    config_json = json.load(config_file)
config = Config(config_json)

In [23]:
# Restrict this process to GPU 0. CUDA reads CUDA_VISIBLE_DEVICES only when it is
# first initialised, so set it before seeding or querying torch.cuda below.
# It has no effect if CUDA was already initialised earlier in this kernel session.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
seed_everything(config.seed)
device = torch.device("cuda" if config.use_cuda and torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())

True


In [24]:
# Globally disable cuDNN for the rest of the session; PyTorch then uses its
# non-cuDNN kernels. This also makes the cudnn.deterministic / cudnn.benchmark
# flags set in seed_everything irrelevant for these ops.
# NOTE(review): the motivation is not stated. Presumably it works around cuDNN
# limitations with second-order MAML gradients (first_order=False) or adds
# determinism. Confirm, and document it here.
torch.backends.cudnn.enabled = False

In [25]:
# Dataset splits, model and loss for the configured benchmark.
benchmark = get_benchmark_by_name(
    config.dataset,
    config.folder,
    config.num_ways,
    config.num_shots,
    config.num_shots_test,
    hidden_size=config.hidden_size,
)

# Task-batch loaders for meta-training and meta-validation (identical settings).
meta_train_loader = BatchMetaDataLoader(
    benchmark.meta_train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)
meta_val_dataloader = BatchMetaDataLoader(
    benchmark.meta_val_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)

# Outer-loop optimiser over the model's (meta-)parameters.
meta_optimizer = torch.optim.Adam(benchmark.model.parameters(), lr=config.meta_lr)

# MAML wrapper: `num_steps` inner adaptation steps of size `step_size` per task.
metalearner = ModelAgnosticMetaLearning(
    benchmark.model,
    meta_optimizer,
    device=device,
    loss_function=benchmark.loss_function,
    first_order=config.first_order,
    step_size=config.step_size,
    num_adaptation_steps=config.num_steps,
)

# NOTE(review): not read by the training loop visible here (it tracks best_accuracy).
best_value = None

In [26]:
def calculate_confidence_interval(data, confidence=0.95):
    """Two-sided Student-t confidence interval for the mean of ``data``.

    Args:
        data: 1-D sequence of samples (here: per-batch mean accuracies).
        confidence: coverage level of the interval, e.g. 0.95.

    Returns:
        ``(mean, lower, upper)``, where the bounds are
        ``mean -/+ SEM * t_{(1 + confidence) / 2, n - 1}``.
        With a single sample, the SEM and t quantile are undefined and the
        bounds come out as NaN.
    """
    sample_mean = np.mean(data)
    degrees_of_freedom = len(data) - 1
    standard_error = stats.sem(data)
    half_width = standard_error * stats.t.ppf((1 + confidence) / 2, degrees_of_freedom)
    return sample_mean, sample_mean - half_width, sample_mean + half_width

In [27]:
all_epoch_results = []
all_grad_norms = []
best_accuracy = 0
best_result = None


def _to_jsonable(value):
    """`json.dump` fallback that converts NumPy scalars/arrays to built-in types.

    ``json`` cannot serialise values such as ``np.float32``. They become plain
    Python numbers or lists; any other unsupported object still raises TypeError.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _dump_json(obj, path):
    """Overwrite `path` with `obj` as 4-space-indented JSON, tolerating NumPy values."""
    with open(path, 'w') as f:
        json.dump(obj, f, indent=4, default=_to_jsonable)


# Write results *into* config.save_path. The former
# os.path.dirname(config_dict['save_path']) turned "/resultspath" into "/",
# i.e. the filesystem root.
dirname = config.save_path
os.makedirs(dirname, exist_ok=True)
# File names follow the configured run, e.g. omniglot_test_5w5s.json.
# This also fixes the former "omiglot" typo.
run_tag = f"{config.num_ways}w{config.num_shots}s"
history_path = os.path.join(dirname, f"{config.dataset}_test_{run_tag}.json")
best_path = os.path.join(dirname, f"{config.dataset}_best_{run_tag}.json")

for epoch in range(config.num_epochs):
    print(f"Epoch {epoch+1}/{config.num_epochs}")
    metalearner.train(meta_train_loader,
                      max_batches=config.num_batches,
                      verbose=config.verbose,
                      desc='Training',
                      leave=False)
    # NOTE(review): the recorded run of this cell failed right after evaluation
    # with "TypeError: 'numpy.float32' object is not iterable". Iterating
    # `batch_results` raises exactly that if it is a NumPy scalar rather than a
    # list of per-batch dicts. Confirm that the order of evaluate()'s return
    # values matches this unpacking.
    mean_results, results, batch_results = metalearner.evaluate(meta_val_dataloader,
                                   max_batches=config.num_batches,
                                   verbose=config.verbose,
                                   desc=f'Epoch {epoch+1}')

    # One mean post-adaptation accuracy per evaluation batch -> CI over batches.
    accuracies = [result["accuracies_after"].mean() for result in batch_results]
    mean, ci_lower, ci_upper = calculate_confidence_interval(accuracies)

    # Gradient norms of this epoch's evaluation batches (replaced every epoch).
    all_grad_norms = [result['grad_norm'] for result in batch_results]
    print(all_grad_norms)

    results['ci_lower'] = ci_lower
    results['ci_upper'] = ci_upper
    results['epoch'] = epoch + 1
    all_epoch_results.append(results)

    # Rewrite the whole history every epoch so an interrupted run is still saved.
    _dump_json(all_epoch_results, history_path)

    if results['accuracies_after'] > best_accuracy:
        best_accuracy = results['accuracies_after']
        best_result = results
        _dump_json(best_result, best_path)
            

Epoch 1/600


Epoch 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.39it/s, accuracy=0.8511, loss=0.8659]


TypeError: 'numpy.float32' object is not iterable