In [33]:
import json
import os
import warnings

import torch
from IPython.display import Image
from scipy.io import savemat, loadmat

from experiment import ResultsViewer

In [34]:
# Experiment output directory to export; the viewer is iterated below as result dicts.
RESULTS_DIR = "experiments/experiment_1_server"
results = ResultsViewer(RESULTS_DIR)

def lookup(results, *filters):
  """Return the records of ``results`` that satisfy every predicate in ``filters``.

  Predicates are tried in order for each record, and evaluation stops at the
  first one returning a falsy value. With no filters, every record is returned
  in a new list.
  """
  return [res for res in results if all(f(res) for f in filters)]

def model_to_matlab(res):
  """Convert one trained model's per-block weights into a MATLAB-ready dict.

  Loads the state dict at ``res['state_dict_path']`` and, for each of the
  ``res['model']['layers']`` blocks, extracts the ``W``, ``V`` and ``skip_l``
  weight tensors as NumPy arrays.

  Args:
    res: one results record with keys ``'state_dict_path'``,
      ``'model'`` (containing ``'layers'``, ``'regularization_method'``,
      ``'regularization_lambda'``) and ``'data'`` (containing ``'D'``).

  Returns:
    ``(fname, out)``. ``fname`` is a bare ``.mat`` filename built from D, the
    regularization method and lambda. ``out`` maps ``'W_i'``, ``'V_i'`` and
    ``'skip_i'`` to NumPy arrays, suitable for ``scipy.io.savemat``.
  """
  # Load onto CPU so a checkpoint saved from a GPU run can still be converted
  # on a CPU-only machine (.numpy() also refuses CUDA tensors).
  sd = torch.load(res['state_dict_path'], map_location='cpu')
  key = 'blocks.{}.{}.weight'
  # (state-dict sub-module name, MATLAB variable prefix)
  parts = (('W', 'W'), ('V', 'V'), ('skip_l', 'skip'))
  out = {}
  for i in range(res['model']['layers']):
    for param, prefix in parts:
      out[f'{prefix}_{i}'] = sd[key.format(i, param)].cpu().numpy()
  # NOTE(review): runs that share D, method and lambda but differ elsewhere
  # (e.g. width or layers) get the same filename and will overwrite each other.
  fname = f"D={res['data']['D']}_Term={res['model']['regularization_method']}_Lam={res['model']['regularization_lambda']}.mat"
  return fname, out
  
def save_matlab(results):
  """Export every run in ``results`` to a .mat file and record where it went.

  Each record's weights are converted with ``model_to_matlab`` and written into
  ``results.results_dir`` with ``savemat``. A shallow copy of the record gains a
  ``"matlab_model_path"`` entry. All augmented records are then written as one
  JSON list to ``<results_dir>/results_with_matlab.json``.

  Args:
    results: iterable of result dicts that also exposes ``results_dir``
      (as the ResultsViewer used in this notebook does).

  Returns:
    Path of the JSON file that was written.

  Warns:
    UserWarning if two runs map to the same .mat filename, so the later run
    overwrites the earlier one's export.
  """
  new_results = []
  written = set()
  pardir = results.results_dir
  for res in results:
    res = res.copy()  # shallow copy is enough: only a top-level key is added
    fname, matlab_out = model_to_matlab(res)
    fname = os.path.join(pardir, fname)
    if fname in written:
      # The filename only encodes D, method and lambda; make the overwrite visible.
      warnings.warn(
          f"{fname} was already written by another run in this export; overwriting it",
          stacklevel=2)
    savemat(fname, matlab_out)
    written.add(fname)
    res["matlab_model_path"] = fname
    new_results.append(res)
  new_results_path = os.path.join(pardir, "results_with_matlab.json")
  # Serialise before opening so a non-JSON-serialisable value cannot leave a
  # truncated results file behind.
  payload = json.dumps(new_results)
  with open(new_results_path, "w") as fp:
    fp.write(payload)
  return new_results_path

In [35]:
# Write one .mat per run plus results_with_matlab.json; nr holds the JSON's path.
nr = save_matlab(results=results)

In [32]:
# NOTE(review): save_matlab returns the path of results_with_matlab.json, but the
# output below comes from In [32], which ran before In [35], and shows a list of
# result dicts instead. It is stale; Restart & Run All to refresh it.
nr

[{'data': {'D': 4, 'N': 81, 'n': 81},
  'model': {'relu_width': 1296,
   'linear_width': 16,
   'layers': 3,
   'epochs': 50000,
   'learning_rate': 0.001,
   'weight_decay': 0,
   'regularization_lambda': 0.001,
   'regularization_method': 1},
  'report': {'Eval. MSE': 0.0,
   'Eval. Acc': 1.0,
   'Sparsity': [0.1979166716337204,
    0.14298804104328156,
    0.03284143656492233,
    0.032310955226421356,
    0.06114969030022621,
    0.014660493470728397],
   'Sparsity by Epoch': [[0,
     [0.9988425970077515,
      0.9986497163772583,
      0.9986014366149902,
      0.9987461566925049,
      0.9992284178733826,
      0.9984567761421204]],
    [5000,
     [0.7245370149612427,
      0.7119502425193787,
      0.6262056231498718,
      0.6348379850387573,
      0.6119309663772583,
      0.6905864477157593]],
    [10000,
     [0.5842978358268738,
      0.5730613470077515,
      0.3358410596847534,
      0.3424479067325592,
      0.4617091119289398,
      0.44675925374031067]],
    [15000,
