In [1]:
import shutil
import numpy as np
import torch
import torch.nn as nn
import time

from collections import defaultdict
from utils.data.load_data import create_data_loaders
from utils.common.utils import save_reconstructions, ssim_loss
from utils.common.loss_function import SSIMLoss
# from utils.model.unet import Unet
# from utils.model.unet import Unet2

# unet = Unet(in_chans = 1, out_chans = 1)
# unet2 = Unet2(in_chans = 1, out_chans = 1)

  from .autonotebook import tqdm as notebook_tqdm


In [36]:
!ls

 Code				   logs		       run_Unet_JB.py   utils
 brain_leaderboard_state_dict.pt  'model test.ipynb'   run_Unet_SJ.py
 evaluate.py			   plot.py	       test_run.py
 leaderboard_eval.py		   result	       train.py


In [33]:
!mv policy.py utils/model

# Testing model's input & output shape

In [7]:
dummy_input = torch.Tensor(1, 384, 384)
print(dummy_input.shape)

torch.Size([1, 384, 384])


In [8]:
output1 = unet(dummy_input)
output2 = unet2(dummy_input)
print(output1.shape, output2.shape)

torch.Size([1, 384, 384]) torch.Size([1, 384, 384])


# Testing model load & save

In [93]:
for root, dirs, files in os.walk('./result'):
    if root.endswith('/checkpoints'):
        print(root)
        print(files)
        print()

./result/SJ/test_Unet/checkpoints
['best_model.pt', 'model.pt']

./result/JB/Unet/checkpoints
[]

./result/JB/newUnet/checkpoints
['best_model.pt', 'model.pt', 'newUnet_test02_epoch30.pt', 'newUnet_test02_best.pt']

./result/JB/test_Unet/checkpoints
['best_model.pt', 'model.pt']

./result/test_Unet/checkpoints
[]



In [71]:
checkpoint = torch.load(f='/root/SNU_fastMRI/result/JB/newUnet/checkpoints/newUnet_test02_epoch30.pt')

In [68]:
from utils.model.unet_advanced import Unet as newUnet
net = newUnet(1,1)

In [72]:
net.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [114]:
checkpoint['optimizer']['param_groups'][0]

{'lr': 0.0005,
 'betas': (0.9, 0.999),
 'eps': 1e-08,
 'weight_decay': 0,
 'amsgrad': False,
 'params': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23]}

In [103]:
type(c['epoch'])

int

In [73]:
net

Unet(
  (down_sample_layers): ModuleList(
    (0): ConvBlock(
      (layers): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Dropout2d(p=0.0, inplace=False)
        (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (6): LeakyReLU(negative_slope=0.2, inplace=True)
        (7): Dropout2d(p=0.0, inplace=False)
      )
    )
    (1): ConvBlock(
      (layers): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Dropout2

In [74]:
a = Path('/root')
b = 'hi'
print(a/b)

/root/hi


In [2]:
import os
os.path.exists('result/JB/newUnet/checkpoints/newUnet_test02_best.pt')

True

# testing json load and save

In [26]:
import json
with open('result/JB/newUnet/jsons/newUnet_test02.json','r') as f:
    res = json.load(f)

In [14]:
tlo = res['train_losses'][:31]
vlo = res['val_losses'][:31]

In [19]:
tln = res['train_losses'][-60:]
vln = res['val_losses'][-60:]

In [23]:
res['train_losses'] = tlo + tln
res['val_losses'] = vlo + vln

In [25]:
with open('result/JB/newUnet/jsons/newUnet_test02.json', 'w') as f:
        json.dump(res, f)

In [1]:
import os, sys
if os.getcwd() + '/utils/model/' not in sys.path:
    sys.path.insert(1, os.getcwd() + '/utils/model/')

import argparse
import shutil
import hashlib

from utils.learning.train_part import train
from utils.common.utils import save_exp_result
from pathlib import Path

def parse(c):
    parser = argparse.ArgumentParser(description='Train Unet on FastMRI challenge Images',
                                    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-g', '--GPU-NUM', type=int, default=0, help='GPU number to allocate')
    parser.add_argument('-b', '--batch-size', type=int, default=256, help='Batch size')
    parser.add_argument('-e', '--num-epochs', type=int, default=3, help='Number of epochs')
    parser.add_argument('-l', '--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('-r', '--report-interval', type=int, default=10, help='Report interval')
    parser.add_argument('-n', '--net-name', type=Path, required=True, help='Name of network')
    parser.add_argument('-o', '--optim', type=str, default='Adam', help='Name of optimizer')
    parser.add_argument('-s', '--scheduler', type=str, default='Plateau', help='Name of lr scheduler')
    
    parser.add_argument('-t', '--data-path-train', type=Path, default='/data/fastMRI/train/', help='Directory of train data')
    parser.add_argument('-v', '--data-path-val', type=Path, default='/data/fastMRI/val/', help='Directory of validation data')
    parser.add_argument('--cascade', type=int, default=1, help='Number of cascades | Should be less than 12')
    
    parser.add_argument('--in-chans', type=int, default=1, help='Size of input channels for network')
    parser.add_argument('--out-chans', type=int, default=1, help='Size of output channels for network')
    parser.add_argument('--input-key', type=str, default='image_input', help='Name of input key')
    parser.add_argument('--target-key', type=str, default='image_label', help='Name of target key')
    parser.add_argument('--max-key', type=str, default='max', help='Name of max key in attributes')
    
    parser.add_argument('--load', type=str, default='', help='Name of saved model that will be loaded')
    parser.add_argument('-u', '--user', type=str, choices=['SJ','JB'], required=True, help='User name')
    parser.add_argument('-x', '--exp-name', type=str, default='test', help='Name of an experiment')
      
    args = parser.parse_args(c)
#     tot_iter = 5164
#     args.report_interval = int((tot_iter/args.batch_size)/10)
    
    return args

In [2]:
command = '-u JB -x AdaptiveVarNet1 -n AdaptiveVarNet -b 1 -e 1 -l 1e-5 -s P --input-key kspace --cascade 1'.split()

In [3]:
args = parse(command)

In [4]:
args

Namespace(GPU_NUM=0, batch_size=1, cascade=1, data_path_train=PosixPath('/data/fastMRI/train'), data_path_val=PosixPath('/data/fastMRI/val'), exp_name='AdaptiveVarNet1', in_chans=1, input_key='kspace', load='', lr=1e-05, max_key='max', net_name=PosixPath('AdaptiveVarNet'), num_epochs=1, optim='Adam', out_chans=1, report_interval=10, scheduler='P', target_key='image_label', user='JB')

In [20]:
import torch
from utils.model.adaptive_varnet import AdaptiveVarNet

try:
    del model
except:
    pass
torch.cuda.empty_cache()
model = AdaptiveVarNet(num_cascades=64).to('cuda')

In [47]:
import os, sys
if os.getcwd() + '/utils/model/' not in sys.path:
    sys.path.insert(1, os.getcwd() + '/utils/model/')
import requests    
from tqdm import tqdm
from utils.model.adaptive_varnet import AdaptiveVarNet

def download_model(url, fname):
    response = requests.get(url, timeout=10, stream=True)

    chunk_size = 8 * 1024 * 1024  # 8 MB chunks
    total_size_in_bytes = int(response.headers.get("content-length", 0))
    progress_bar = tqdm(
        desc="Downloading state_dict",
        total=total_size_in_bytes,
        unit="iB",
        unit_scale=True,
    )

    with open(fname, "wb") as fh:
        for chunk in response.iter_content(chunk_size):
            progress_bar.update(len(chunk))
            fh.write(chunk)

model = AdaptiveVarNet(num_cascades=args.cascade)
AdaptiveVarNet_FOLDER = "https://dl.fbaipublicfiles.com/active-mri-acquisition/midl_models/adaptive_4x.ckpt/"
MODEL_FNAMES = "adaptive_4x.ckpt"

url_root = AdaptiveVarNet_FOLDER
download_model('https://dl.fbaipublicfiles.com/active-mri-acquisition/midl_models/adaptive_4x.ckpt', MODEL_FNAMES)

pretrained = torch.load(MODEL_FNAMES)
pretrained_copy = copy.deepcopy(pretrained)


Downloading state_dict: 100%|███████████████████████████████████████████████████| 651M/651M [01:41<00:00, 6.40MiB/s]


ModuleNotFoundError: No module named 'pytorch_lightning'

In [74]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-0.9.3-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 7.2 MB/s eta 0:00:01
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.9.3


In [81]:
ckpt = torch.load("adaptive_4x.ckpt")

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.70 GiB total capacity; 1.59 GiB already allocated; 4.19 MiB free; 1.59 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [78]:
ckpt.

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'callbacks', 'optimizer_states', 'lr_schedulers', 'state_dict', 'hparams_name', 'hyper_parameters'])

In [80]:
for l in ckpt['state_dict']:
    print(l)


varnet.sens_net.norm_unet.unet.down_sample_layers.0.layers.0.weight
varnet.sens_net.norm_unet.unet.down_sample_layers.0.layers.4.weight
varnet.sens_net.norm_unet.unet.down_sample_layers.1.layers.0.weight
varnet.sens_net.norm_unet.unet.down_sample_layers.1.layers.4.weight
varnet.sens_net.norm_unet.unet.down_sample_layers.2.layers.0.weight
varnet.sens_net.norm_unet.unet.down_sample_layers.2.layers.4.weight
varnet.sens_net.norm_unet.unet.down_sample_layers.3.layers.0.weight
varnet.sens_net.norm_unet.unet.down_sample_layers.3.layers.4.weight
varnet.sens_net.norm_unet.unet.conv.layers.0.weight
varnet.sens_net.norm_unet.unet.conv.layers.4.weight
varnet.sens_net.norm_unet.unet.up_conv.0.layers.0.weight
varnet.sens_net.norm_unet.unet.up_conv.0.layers.4.weight
varnet.sens_net.norm_unet.unet.up_conv.1.layers.0.weight
varnet.sens_net.norm_unet.unet.up_conv.1.layers.4.weight
varnet.sens_net.norm_unet.unet.up_conv.2.layers.0.weight
varnet.sens_net.norm_unet.unet.up_conv.2.layers.4.weight
varnet.sen

In [77]:
from utils.model.adaptive_varnet_module import AdaptiveVarNetModule
pretrained = AdaptiveVarNetModule(**ckpt["hyper_parameters"])

                not been set for this class (DistributedMetricSum). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_full_state_property`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                


TypeError: Can't instantiate abstract class DistributedMetricSum with abstract method forward

In [67]:
for layer in pretrained_copy.keys():
    if layer.split('.',2)[1].isdigit() and (args.cascade <= int(layer.split('.',2)[1]) <=11):
        del pretrained[layer]

RuntimeError: Error(s) in loading state_dict for AdaptiveVarNet:
	Missing key(s) in state_dict: "sens_net.norm_unet.unet.down_sample_layers.0.layers.0.weight", "sens_net.norm_unet.unet.down_sample_layers.0.layers.4.weight", "sens_net.norm_unet.unet.down_sample_layers.1.layers.0.weight", "sens_net.norm_unet.unet.down_sample_layers.1.layers.4.weight", "sens_net.norm_unet.unet.down_sample_layers.2.layers.0.weight", "sens_net.norm_unet.unet.down_sample_layers.2.layers.4.weight", "sens_net.norm_unet.unet.down_sample_layers.3.layers.0.weight", "sens_net.norm_unet.unet.down_sample_layers.3.layers.4.weight", "sens_net.norm_unet.unet.conv.layers.0.weight", "sens_net.norm_unet.unet.conv.layers.4.weight", "sens_net.norm_unet.unet.up_conv.0.layers.0.weight", "sens_net.norm_unet.unet.up_conv.0.layers.4.weight", "sens_net.norm_unet.unet.up_conv.1.layers.0.weight", "sens_net.norm_unet.unet.up_conv.1.layers.4.weight", "sens_net.norm_unet.unet.up_conv.2.layers.0.weight", "sens_net.norm_unet.unet.up_conv.2.layers.4.weight", "sens_net.norm_unet.unet.up_conv.3.0.layers.0.weight", "sens_net.norm_unet.unet.up_conv.3.0.layers.4.weight", "sens_net.norm_unet.unet.up_conv.3.1.weight", "sens_net.norm_unet.unet.up_conv.3.1.bias", "sens_net.norm_unet.unet.up_transpose_conv.0.layers.0.weight", "sens_net.norm_unet.unet.up_transpose_conv.1.layers.0.weight", "sens_net.norm_unet.unet.up_transpose_conv.2.layers.0.weight", "sens_net.norm_unet.unet.up_transpose_conv.3.layers.0.weight", "cascades.0.dc_weight", "cascades.0.model.unet.down_sample_layers.0.layers.0.weight", "cascades.0.model.unet.down_sample_layers.0.layers.4.weight", "cascades.0.model.unet.down_sample_layers.1.layers.0.weight", "cascades.0.model.unet.down_sample_layers.1.layers.4.weight", "cascades.0.model.unet.down_sample_layers.2.layers.0.weight", "cascades.0.model.unet.down_sample_layers.2.layers.4.weight", "cascades.0.model.unet.down_sample_layers.3.layers.0.weight", "cascades.0.model.unet.down_sample_layers.3.layers.4.weight", "cascades.0.model.unet.conv.layers.0.weight", "cascades.0.model.unet.conv.layers.4.weight", "cascades.0.model.unet.up_conv.0.layers.0.weight", "cascades.0.model.unet.up_conv.0.layers.4.weight", "cascades.0.model.unet.up_conv.1.layers.0.weight", "cascades.0.model.unet.up_conv.1.layers.4.weight", "cascades.0.model.unet.up_conv.2.layers.0.weight", "cascades.0.model.unet.up_conv.2.layers.4.weight", "cascades.0.model.unet.up_conv.3.0.layers.0.weight", "cascades.0.model.unet.up_conv.3.0.layers.4.weight", "cascades.0.model.unet.up_conv.3.1.weight", "cascades.0.model.unet.up_conv.3.1.bias", "cascades.0.model.unet.up_transpose_conv.0.layers.0.weight", "cascades.0.model.unet.up_transpose_conv.1.layers.0.weight", "cascades.0.model.unet.up_transpose_conv.2.layers.0.weight", "cascades.0.model.unet.up_transpose_conv.3.layers.0.weight", "policies.0.sampler.channel_layer.layers.0.weight", "policies.0.sampler.channel_layer.layers.0.bias", "policies.0.sampler.down_sample_layers.0.layers.0.weight", "policies.0.sampler.down_sample_layers.0.layers.0.bias", "policies.0.sampler.down_sample_layers.1.layers.0.weight", "policies.0.sampler.down_sample_layers.1.layers.0.bias", "policies.0.sampler.down_sample_layers.2.layers.0.weight", "policies.0.sampler.down_sample_layers.2.layers.0.bias", "policies.0.sampler.down_sample_layers.3.layers.0.weight", "policies.0.sampler.down_sample_layers.3.layers.0.bias", "policies.0.sampler.feature_extractor.0.layers.0.weight", "policies.0.sampler.feature_extractor.0.layers.0.bias", "policies.0.sampler.feature_extractor.1.layers.0.weight", "policies.0.sampler.feature_extractor.1.layers.0.bias", "policies.0.sampler.feature_extractor.2.layers.0.weight", "policies.0.sampler.feature_extractor.2.layers.0.bias", "policies.0.sampler.feature_extractor.3.layers.0.weight", "policies.0.sampler.feature_extractor.3.layers.0.bias", "policies.0.sampler.feature_extractor.4.layers.0.weight", "policies.0.sampler.feature_extractor.4.layers.0.bias", "policies.0.sampler.fc_out.0.weight", "policies.0.sampler.fc_out.0.bias", "policies.0.sampler.fc_out.2.weight", "policies.0.sampler.fc_out.2.bias", "policies.0.sampler.fc_out.4.weight", "policies.0.sampler.fc_out.4.bias". 
	Unexpected key(s) in state_dict: "varnet.sens_net.norm_unet.unet.down_sample_layers.0.layers.0.weight", "varnet.sens_net.norm_unet.unet.down_sample_layers.0.layers.4.weight", "varnet.sens_net.norm_unet.unet.down_sample_layers.1.layers.0.weight", "varnet.sens_net.norm_unet.unet.down_sample_layers.1.layers.4.weight", "varnet.sens_net.norm_unet.unet.down_sample_layers.2.layers.0.weight", "varnet.sens_net.norm_unet.unet.down_sample_layers.2.layers.4.weight", "varnet.sens_net.norm_unet.unet.down_sample_layers.3.layers.0.weight", "varnet.sens_net.norm_unet.unet.down_sample_layers.3.layers.4.weight", "varnet.sens_net.norm_unet.unet.conv.layers.0.weight", "varnet.sens_net.norm_unet.unet.conv.layers.4.weight", "varnet.sens_net.norm_unet.unet.up_conv.0.layers.0.weight", "varnet.sens_net.norm_unet.unet.up_conv.0.layers.4.weight", "varnet.sens_net.norm_unet.unet.up_conv.1.layers.0.weight", "varnet.sens_net.norm_unet.unet.up_conv.1.layers.4.weight", "varnet.sens_net.norm_unet.unet.up_conv.2.layers.0.weight", "varnet.sens_net.norm_unet.unet.up_conv.2.layers.4.weight", "varnet.sens_net.norm_unet.unet.up_conv.3.0.layers.0.weight", "varnet.sens_net.norm_unet.unet.up_conv.3.0.layers.4.weight", "varnet.sens_net.norm_unet.unet.up_conv.3.1.weight", "varnet.sens_net.norm_unet.unet.up_conv.3.1.bias", "varnet.sens_net.norm_unet.unet.up_transpose_conv.0.layers.0.weight", "varnet.sens_net.norm_unet.unet.up_transpose_conv.1.layers.0.weight", "varnet.sens_net.norm_unet.unet.up_transpose_conv.2.layers.0.weight", "varnet.sens_net.norm_unet.unet.up_transpose_conv.3.layers.0.weight", "varnet.cascades.0.model.unet.down_sample_layers.0.layers.0.weight", "varnet.cascades.0.model.unet.down_sample_layers.0.layers.4.weight", "varnet.cascades.0.model.unet.down_sample_layers.1.layers.0.weight", "varnet.cascades.0.model.unet.down_sample_layers.1.layers.4.weight", "varnet.cascades.0.model.unet.down_sample_layers.2.layers.0.weight", "varnet.cascades.0.model.unet.down_sample_layers.2.layers.4.weight", "varnet.cascades.0.model.unet.down_sample_layers.3.layers.0.weight", "varnet.cascades.0.model.unet.down_sample_layers.3.layers.4.weight", "varnet.cascades.0.model.unet.conv.layers.0.weight", "varnet.cascades.0.model.unet.conv.layers.4.weight", "varnet.cascades.0.model.unet.up_conv.0.layers.0.weight", "varnet.cascades.0.model.unet.up_conv.0.layers.4.weight", "varnet.cascades.0.model.unet.up_conv.1.layers.0.weight", "varnet.cascades.0.model.unet.up_conv.1.layers.4.weight", "varnet.cascades.0.model.unet.up_conv.2.layers.0.weight", "varnet.cascades.0.model.unet.up_conv.2.layers.4.weight", "varnet.cascades.0.model.unet.up_conv.3.0.layers.0.weight", "varnet.cascades.0.model.unet.up_conv.3.0.layers.4.weight", "varnet.cascades.0.model.unet.up_conv.3.1.weight", "varnet.cascades.0.model.unet.up_conv.3.1.bias", "varnet.cascades.0.model.unet.up_transpose_conv.0.layers.0.weight", "varnet.cascades.0.model.unet.up_transpose_conv.1.layers.0.weight", "varnet.cascades.0.model.unet.up_transpose_conv.2.layers.0.weight", "varnet.cascades.0.model.unet.up_transpose_conv.3.layers.0.weight", "varnet.cascades.1.model.unet.down_sample_layers.0.layers.0.weight", "varnet.cascades.1.model.unet.down_sample_layers.0.layers.4.weight", "varnet.cascades.1.model.unet.down_sample_layers.1.layers.0.weight", "varnet.cascades.1.model.unet.down_sample_layers.1.layers.4.weight", "varnet.cascades.1.model.unet.down_sample_layers.2.layers.0.weight", "varnet.cascades.1.model.unet.down_sample_layers.2.layers.4.weight", "varnet.cascades.1.model.unet.down_sample_layers.3.layers.0.weight", "varnet.cascades.1.model.unet.down_sample_layers.3.layers.4.weight", "varnet.cascades.1.model.unet.conv.layers.0.weight", "varnet.cascades.1.model.unet.conv.layers.4.weight", "varnet.cascades.1.model.unet.up_conv.0.layers.0.weight", "varnet.cascades.1.model.unet.up_conv.0.layers.4.weight", "varnet.cascades.1.model.unet.up_conv.1.layers.0.weight", "varnet.cascades.1.model.unet.up_conv.1.layers.4.weight", "varnet.cascades.1.model.unet.up_conv.2.layers.0.weight", "varnet.cascades.1.model.unet.up_conv.2.layers.4.weight", "varnet.cascades.1.model.unet.up_conv.3.0.layers.0.weight", "varnet.cascades.1.model.unet.up_conv.3.0.layers.4.weight", "varnet.cascades.1.model.unet.up_conv.3.1.weight", "varnet.cascades.1.model.unet.up_conv.3.1.bias", "varnet.cascades.1.model.unet.up_transpose_conv.0.layers.0.weight", "varnet.cascades.1.model.unet.up_transpose_conv.1.layers.0.weight", "varnet.cascades.1.model.unet.up_transpose_conv.2.layers.0.weight", "varnet.cascades.1.model.unet.up_transpose_conv.3.layers.0.weight", "varnet.cascades.2.model.unet.down_sample_layers.0.layers.0.weight", "varnet.cascades.2.model.unet.down_sample_layers.0.layers.4.weight", "varnet.cascades.2.model.unet.down_sample_layers.1.layers.0.weight", "varnet.cascades.2.model.unet.down_sample_layers.1.layers.4.weight", "varnet.cascades.2.model.unet.down_sample_layers.2.layers.0.weight", "varnet.cascades.2.model.unet.down_sample_layers.2.layers.4.weight", "varnet.cascades.2.model.unet.down_sample_layers.3.layers.0.weight", "varnet.cascades.2.model.unet.down_sample_layers.3.layers.4.weight", "varnet.cascades.2.model.unet.conv.layers.0.weight", "varnet.cascades.2.model.unet.conv.layers.4.weight", "varnet.cascades.2.model.unet.up_conv.0.layers.0.weight", "varnet.cascades.2.model.unet.up_conv.0.layers.4.weight", "varnet.cascades.2.model.unet.up_conv.1.layers.0.weight", "varnet.cascades.2.model.unet.up_conv.1.layers.4.weight", "varnet.cascades.2.model.unet.up_conv.2.layers.0.weight", "varnet.cascades.2.model.unet.up_conv.2.layers.4.weight", "varnet.cascades.2.model.unet.up_conv.3.0.layers.0.weight", "varnet.cascades.2.model.unet.up_conv.3.0.layers.4.weight", "varnet.cascades.2.model.unet.up_conv.3.1.weight", "varnet.cascades.2.model.unet.up_conv.3.1.bias", "varnet.cascades.2.model.unet.up_transpose_conv.0.layers.0.weight", "varnet.cascades.2.model.unet.up_transpose_conv.1.layers.0.weight", "varnet.cascades.2.model.unet.up_transpose_conv.2.layers.0.weight", "varnet.cascades.2.model.unet.up_transpose_conv.3.layers.0.weight", "varnet.cascades.3.model.unet.down_sample_layers.0.layers.0.weight", "varnet.cascades.3.model.unet.down_sample_layers.0.layers.4.weight", "varnet.cascades.3.model.unet.down_sample_layers.1.layers.0.weight", "varnet.cascades.3.model.unet.down_sample_layers.1.layers.4.weight", "varnet.cascades.3.model.unet.down_sample_layers.2.layers.0.weight", "varnet.cascades.3.model.unet.down_sample_layers.2.layers.4.weight", "varnet.cascades.3.model.unet.down_sample_layers.3.layers.0.weight", "varnet.cascades.3.model.unet.down_sample_layers.3.layers.4.weight", "varnet.cascades.3.model.unet.conv.layers.0.weight", "varnet.cascades.3.model.unet.conv.layers.4.weight", "varnet.cascades.3.model.unet.up_conv.0.layers.0.weight", "varnet.cascades.3.model.unet.up_conv.0.layers.4.weight", "varnet.cascades.3.model.unet.up_conv.1.layers.0.weight", "varnet.cascades.3.model.unet.up_conv.1.layers.4.weight", "varnet.cascades.3.model.unet.up_conv.2.layers.0.weight", "varnet.cascades.3.model.unet.up_conv.2.layers.4.weight", "varnet.cascades.3.model.unet.up_conv.3.0.layers.0.weight", "varnet.cascades.3.model.unet.up_conv.3.0.layers.4.weight", "varnet.cascades.3.model.unet.up_conv.3.1.weight", "varnet.cascades.3.model.unet.up_conv.3.1.bias", "varnet.cascades.3.model.unet.up_transpose_conv.0.layers.0.weight", "varnet.cascades.3.model.unet.up_transpose_conv.1.layers.0.weight", "varnet.cascades.3.model.unet.up_transpose_conv.2.layers.0.weight", "varnet.cascades.3.model.unet.up_transpose_conv.3.layers.0.weight", "varnet.cascades.4.model.unet.down_sample_layers.0.layers.0.weight", "varnet.cascades.4.model.unet.down_sample_layers.0.layers.4.weight", "varnet.cascades.4.model.unet.down_sample_layers.1.layers.0.weight", "varnet.cascades.4.model.unet.down_sample_layers.1.layers.4.weight", "varnet.cascades.4.model.unet.down_sample_layers.2.layers.0.weight", "varnet.cascades.4.model.unet.down_sample_layers.2.layers.4.weight", "varnet.cascades.4.model.unet.down_sample_layers.3.layers.0.weight", "varnet.cascades.4.model.unet.down_sample_layers.3.layers.4.weight", "varnet.cascades.4.model.unet.conv.layers.0.weight", "varnet.cascades.4.model.unet.conv.layers.4.weight", "varnet.cascades.4.model.unet.up_conv.0.layers.0.weight", "varnet.cascades.4.model.unet.up_conv.0.layers.4.weight", "varnet.cascades.4.model.unet.up_conv.1.layers.0.weight", "varnet.cascades.4.model.unet.up_conv.1.layers.4.weight", "varnet.cascades.4.model.unet.up_conv.2.layers.0.weight", "varnet.cascades.4.model.unet.up_conv.2.layers.4.weight", "varnet.cascades.4.model.unet.up_conv.3.0.layers.0.weight", "varnet.cascades.4.model.unet.up_conv.3.0.layers.4.weight", "varnet.cascades.4.model.unet.up_conv.3.1.weight", "varnet.cascades.4.model.unet.up_conv.3.1.bias", "varnet.cascades.4.model.unet.up_transpose_conv.0.layers.0.weight", "varnet.cascades.4.model.unet.up_transpose_conv.1.layers.0.weight", "varnet.cascades.4.model.unet.up_transpose_conv.2.layers.0.weight", "varnet.cascades.4.model.unet.up_transpose_conv.3.layers.0.weight", "varnet.policies.0.sampler.channel_layer.layers.0.weight", "varnet.policies.0.sampler.channel_layer.layers.0.bias", "varnet.policies.0.sampler.down_sample_layers.0.layers.0.weight", "varnet.policies.0.sampler.down_sample_layers.0.layers.0.bias", "varnet.policies.0.sampler.down_sample_layers.1.layers.0.weight", "varnet.policies.0.sampler.down_sample_layers.1.layers.0.bias", "varnet.policies.0.sampler.down_sample_layers.2.layers.0.weight", "varnet.policies.0.sampler.down_sample_layers.2.layers.0.bias", "varnet.policies.0.sampler.down_sample_layers.3.layers.0.weight", "varnet.policies.0.sampler.down_sample_layers.3.layers.0.bias", "varnet.policies.0.sampler.feature_extractor.0.layers.0.weight", "varnet.policies.0.sampler.feature_extractor.0.layers.0.bias", "varnet.policies.0.sampler.feature_extractor.1.layers.0.weight", "varnet.policies.0.sampler.feature_extractor.1.layers.0.bias", "varnet.policies.0.sampler.feature_extractor.2.layers.0.weight", "varnet.policies.0.sampler.feature_extractor.2.layers.0.bias", "varnet.policies.0.sampler.feature_extractor.3.layers.0.weight", "varnet.policies.0.sampler.feature_extractor.3.layers.0.bias", "varnet.policies.0.sampler.feature_extractor.4.layers.0.weight", "varnet.policies.0.sampler.feature_extractor.4.layers.0.bias", "varnet.policies.0.sampler.fc_out.0.weight", "varnet.policies.0.sampler.fc_out.0.bias", "varnet.policies.0.sampler.fc_out.2.weight", "varnet.policies.0.sampler.fc_out.2.bias", "varnet.policies.0.sampler.fc_out.4.weight", "varnet.policies.0.sampler.fc_out.4.bias", "loss.w". 

In [None]:
for layer in pretrained_copy.keys():
    if layer.split('.',2)[1].isdigit() and (args.cascade <= int(layer.split('.',2)[1]) <=11):
        del pretrained[layer]
model.load_state_dict(pretrained)

In [None]:
args.exp_dir = './result' / Path(args.user) / args.net_name / 'checkpoints'
args.val_dir = './result' / Path(args.user) / args.net_name / 'reconstructions_val'
args.json_dir = './result' / Path(args.user) / args.net_name / 'jsons'
args.main_dir = './result' / Path(args.user) / args.net_name / __file__

args.exp_dir.mkdir(parents=True, exist_ok=True)
args.val_dir.mkdir(parents=True, exist_ok=True)
print(f"*** Experiment <{args.exp_name}> with model <{args.net_name}> starts ***")
# train(args)

In [52]:
def save_exp_result(save_dir, setting, result, load=''):
    for key in setting.copy():
        if isinstance(setting[key],Path):
            setting[key]=str(setting[key])
    for key in setting:
        print(key, setting[key])
            
    exp_name = setting['exp_name']
    filename = save_dir / '{}.json'.format(exp_name)
    
    if load != '':
        with open(filename, 'r') as f:
            prev_result = json.load(f)
        result['train_losses'] = prev_result['train_losses'] + [result['train_losses'][-1]]
        result['val_losses'] = prev_result['val_losses'] + [result['val_losses'][-1]]

    result.update(setting)
    print(result)

In [53]:
import os
os.listdir('/root/SNU_fastMRI/result/JB/newUnet/jsons')

['.ipynb_checkpoints', 'newUnet_test02.json']

In [54]:
prev_result = {"train_losses": [0.1126745434482247, 0.08656777790693562, 0.07871177827122593, 0.07378690729561649, 0.07049268290504225, 0.06808781471717108, 0.06613185317223144, 0.06456057557926931, 0.06343760484885727, 0.062393130393149955, 0.06151985347685847, 0.06069883476000651, 0.0599703875882443, 0.059316725576006896, 0.05878266602270565, 0.05808181663123748, 0.057610175034286254, 0.05716629564900409, 0.056687517813352864, 0.05626321737838179, 0.055850942007626844, 0.05545629135693862, 0.05508866318417259, 0.05489134027897067, 0.05446442069697546, 0.05420566421097227, 0.05379068007878527, 0.05342896272300844, 0.05311635323300992, 0.05290453547391427, 0.05290313621684736, 0.05290313621684736, 0.052437487583425924, 0.05226639777491098, 0.052061733402951414, 0.051850627857128594, 0.05167068184943321, 0.05147498514148863, 0.05166849390260028, 0.05113865452410173, 0.05105413693563014, 0.05096610848145806, 0.05081353801585682, 0.05095774336370128, 0.050551524848229924, 0.0507162285238297, 0.050575832203203455, 0.050446935845085195, 0.05012047498242761, 0.04999625129655453, 0.04996588680970973, 0.04974741200006755, 0.049968050278533085, 0.04860320224009644, 0.04815446210294754, 0.04800150095725004, 0.047890832833513586, 0.04779776343060203, 0.04771393993891032, 0.047636117846672606, 0.04756143101406761, 0.0474912318165507, 0.04742370611553679, 0.04735887271898805, 0.0472961769303147, 0.04723558514411377, 0.047176470889292846, 0.04709132959560562, 0.04703813155123363, 0.047020926691263294, 0.047008423821832075, 0.04699766940141222, 0.046987873772180826, 0.04697866506200379, 0.0469698192348613, 0.04696129743018449, 0.04695303630939492, 0.046944922609838145, 0.04693701247881294, 0.04692924243944704, 0.04692163199117178, 0.04691414033730456, 0.04690673484049927, 0.04689943693630103, 0.04689220887049168, 0.04688507470619927, 0.04687801522058013, 0.04687102059477167, 0.046864085158726455, 0.04685723850732613, 0.04685045809867487, 0.046823266487387105, 0.046821408647948794, 0.046820097898939095, 0.04681904866357967, 0.04681801560859946, 0.04681709816726897, 0.04681625153238425, 0.04681116010639341, 0.04681095238904665, 0.046810798053121896, 0.046810669024969755, 0.0468105543793769, 0.04681043987207789, 0.046810334907053795, 0.04681023879283542, 0.04681013963615258, 0.04680947140032067, 0.04680946545368558, 0.046809458538993605, 0.0468094507945386, 0.046809445124491186, 0.04680944263520208, 0.046809435582216265, 0.04680942977387501, 0.0468094202316001, 0.04680941663596027, 0.046809411657382054, 0.04680940335975169, 0.04680940405122089, 0.046809393402595253, 0.046809388285723195, 0.046809381509325065, 0.04680937708392221, 0.04680936920117336, 0.046809364499182826, 0.04680935689302166, 0.046809348733685134, 0.046809345691220666, 0.04680933863823486, 0.046809335734064234, 0.04680932923425378, 0.046809321628092614, 0.04680931623463288, 0.046809311670936174, 0.04680930876676555, 0.04680930337330581], "val_losses": [0.07701320268789814, 0.06585551620907976, 0.061017491086850666, 0.05846954715903606, 0.05406928298481534, 0.05278321527169123, 0.051271362544789804, 0.0500909573440621, 0.04946734130218826, 0.048750690719284215, 0.0483173459917648, 0.04764226000752933, 0.04702367416726294, 0.04635298759710643, 0.0459594466834779, 0.04548600833345286, 0.04518114024015082, 0.044764938422265366, 0.044362111699176955, 0.04434877992217861, 0.04410303113657932, 0.04373351353738377, 0.04349284792311298, 0.04333410508663855, 0.043107767185128286, 0.04304487510956441, 0.04289999282622632, 0.042629402882912844, 0.042556315126087006, 0.042459442890655404, 0.042235005832180816, 0.042235005832180816, 0.04221791078540079, 0.04216124558345087, 0.04207138159420431, 0.04197790294883636, 0.04197377291905489, 0.04179987619726175, 0.041921154199984025, 0.04184543466425959, 0.04193928585824301, 0.041715172190197734, 0.04171321947668402, 0.04158482951727707, 0.0418606002927049, 0.0416446015800981, 0.04149748928447408, 0.041770541248421854, 0.041778352928221026, 0.04200054122886407, 0.04196268720814386, 0.04257735269278013, 0.04172716872220217, 0.03973842909735647, 0.03965190089933748, 0.03961216693755756, 0.03959025181301756, 0.03957630606006218, 0.039568538750724136, 0.03956297619832298, 0.03955894439978465, 0.03955689453882156, 0.0395551410592463, 0.03955509466330273, 0.03955510925184342, 0.03955561622201829, 0.039555090148746676, 0.039388655159594964, 0.03937280789612834, 0.03936580836583552, 0.03936187736752048, 0.039359117303746936, 0.03935706307151727, 0.03935548986890743, 0.039354236039784635, 0.039353068318736334, 0.039351875729441134, 0.03935082971798344, 0.03934979994007595, 0.0393489188481803, 0.03934824217847903, 0.03934753002820568, 0.03934685509100581, 0.03934613287460026, 0.039345416581820516, 0.03934478669329759, 0.03934422496842617, 0.03934367351497681, 0.03934307040528762, 0.03934260542203555, 0.03934202428155458, 0.039257600411805496, 0.03925669721769365, 0.03925626580566328, 0.03925496995290252, 0.03925509665857566, 0.039255431210932785, 0.03925460798325049, 0.03925408992729895, 0.03925368091413328, 0.03925338135551504, 0.03925323773044325, 0.039252990334401154, 0.03925297073089265, 0.039252921026669226, 0.03925283456611146, 0.03925280343893089, 0.039252789415031936, 0.03925277656089387, 0.039252763882364164, 0.03925275111368668, 0.039252737404680584, 0.03925272545838465, 0.039252713214662595, 0.039252699307674485, 0.03925269073650476, 0.03925267902714311, 0.03925266759613818, 0.03925265647905403, 0.03925264662687492, 0.039252635527930284, 0.03925262517115569, 0.03925261507435313, 0.03925260493523185, 0.039252595036429735, 0.039252585007752104, 0.03925257529036179, 0.03925256544851727, 0.03925255586906909, 0.03925254635221903, 0.039252537006831935, 0.03925252821987165, 0.03925251992033688, 0.03925251046251668, 0.03925250160188047, 0.039252492898538396, 0.03925248634948134], "GPU_NUM": 0, "batch_size": 12, "num_epochs": 60, "lr": 1e-05}

In [58]:
len(prev_result['val_losses'])

137

In [59]:
    with open(path, 'w') as f:
        json.dump(prev_result, f)

In [55]:
from pathlib import Path
path = Path('/root/SNU_fastMRI/result/JB/newUnet/jsons/newUnet_test02.json')

In [56]:
from copy import deepcopy
save_exp_result(path, deepcopy(vars(args)), deepcopy(prev_result))

GPU_NUM 0
batch_size 256
num_epochs 3
lr 0.001
report_interval 2
net_name test_Unet
optim Adam
scheduler Plateau
data_path_train /root/input/train/image
data_path_val /root/input/val/image
in_chans 1
out_chans 1
input_key image_input
target_key image_label
max_key max
load 
exp_name test
{'train_losses': [0.1126745434482247, 0.08656777790693562, 0.07871177827122593, 0.07378690729561649, 0.07049268290504225, 0.06808781471717108, 0.06613185317223144, 0.06456057557926931, 0.06343760484885727, 0.062393130393149955, 0.06151985347685847, 0.06069883476000651, 0.0599703875882443, 0.059316725576006896, 0.05878266602270565, 0.05808181663123748, 0.057610175034286254, 0.05716629564900409, 0.056687517813352864, 0.05626321737838179, 0.055850942007626844, 0.05545629135693862, 0.05508866318417259, 0.05489134027897067, 0.05446442069697546, 0.05420566421097227, 0.05379068007878527, 0.05342896272300844, 0.05311635323300992, 0.05290453547391427, 0.05290313621684736, 0.05290313621684736, 0.0524374875834259