In [1]:
def set_gpu_memory_growth_mode(gpu_id=0):
    import tensorflow as tf
    try:
        gpus = tf.config.experimental.list_physical_devices('GPU')
        tf.config.experimental.set_memory_growth(gpus[gpu_id], True)
    except RuntimeError as e:
        print(e)

In [2]:
set_gpu_memory_growth_mode()

In [3]:
import os
import sys
import math
import pickle
from tqdm import tqdm

import data as ds
from model_shuffle import *
from matrix import *
from context import *

# Context

In [4]:
config = {
    'seed' : 1234,
    'use_64bits': True,
    
    # Loss
    'loss_type': 'mse',
    
    # LEO & target parameter theta dim
    'num_latents': 16*4,
    'num_latents_final': 16,
    'gen_theta_dim': 80*32 + 32,
    # 'gen_theta_dim': 80*16 + 16 + 16*2 + 2,
    # 'gen_theta_dim': 80*32 + 32 + 32*2 + 2,
    'gen_final_theta_dim': 32*8 + 8,
    'num_k_shot': 5,
    'num_valid_shot': 5,
    
    # Batch & Step size
    'batch_size' : 3,
    'first_decay_steps': 40,
    'meta_lr': 1e-4,
    'latent_lr': 1e-7,
    'finetuning_lr': 1e-7,
    'num_latent_grad_steps' : 5,
    'num_finetune_grad_steps' : 5,
    'num_meta_grad_steps' : 5,
    
    'gradient_threshold': 0.1,
    'gradient_norm_threshold': 0.1,
 
    # Regularizer Term
    'dropout_rate': 0.5,
    'kl_weight': 1e-3,
    'l2_penalty_weight': 1e-8,
    'encoder_penalty_weight': 1e-9,
    'orthogonality_penalty_weight': 1e-3,
    
    'num_epochs' : 20,
    'num_workers' : 8,
    'resource_path': '',
    'npy_root_path': '/home/elvin/banner/mnt/ssd3/nps/',
}

ctx = Context.create(config)

In [5]:
np.random.seed(ctx.seed)
tf.random.set_seed(ctx.seed)

# Data Provider

In [6]:
# meta_path = os.path.join(ctx.npy_root_path, "meta.pkl")
# with open(meta_path, 'rb') as f:
#     metas = pickle.load(f)

pads = ['01726', '01727', '01805', '01809', '01819', '01843', '01859', '01862', '01876', '01896', '01901', '01905', '01939', '01941', '02002', '02027', '02032', '02033', '02048', '02087', '02117', '02118', '02119', '02155', '02165', '02174', '02190', '02193', '02194', '02198', '02204', '02223', '02243', '02267', '02272', '02298', '02342', '02349', '02353', '02364', '02368', '02370', '02413', '02414', '02417', '02436', '02450', '02456', '02480', '02522', '02526', '02533', '02542', '02551', '02613', '02622', '02700', '02734', '02739', '02761', '02805', '02840', '02857', '02869', '02878', '02879', '02883', '02902', '02961', '02967', '02976', '02984', '03007', '03027', '03039', '03059', '03060', '03117', '03212', '03224', '03239', '03266', '03277', '03380', '03389', '03413', '03442', '03474', '00117', '00124', '00133', '00135', '00138', '00198', '00207', '00208', '00210', '00229', '00241', '00251', '00252', '00258', '00295', '00317', '00322', '00325', '00326', '00358', '00383', '00463', '00500', '00509', '00521', '00546', '00549', '00010', '00028', '00578', '00595', '00597', '00613', '00619', '00653', '00666', '00696', '00718', '00740', '00741', '00748', '00756', '00757', '00779', '00791', '00796', '00801', '00804', '00806', '00808', '00825', '00827', '00828', '00842', '00850', '00853', '00861', '00876', '00880', '00890', '00891', '00894', '00921', '00927', '00930', '00932', '00934', '00960', '00976', '00980', '00981', '00986', '00998', '01001', '01022', '01025', '01029', '01030', '01038', '01039', '01052', '01060', '01066', '01085', '01088', '01089', '01090', '01091', '01099', '01104', '01109', '01122', '01126', '01128', '01134', '01151', '01152', '01158', '01173', '01183', '01200', '01206', '01221', '01224', '01225', '01232', '01233', '01243', '01266', '01267', '01269', '01282', '01283', '01295', '01350', '01352', '01366', '01367', '01373', '01383', '01384', '01390', '01392', '01432', '01443', '01474', '01508', '01524', '01544', '01556', '01582', '01618', '01627', '01661', '01702', '01717']
i6s = ['01729', '01734', '01738', '01748', '01755', '01760', '01762', '01763', '01768', '01770', '01773', '01778', '01782', '01783', '01786', '01789', '01794', '01816', '01817', '01818', '01825', '01849', '01858', '01866', '01869', '01887', '01906', '01907', '01908', '01922', '01924', '01927', '01966', '01984', '01995', '01997', '02015', '02020', '02038', '02051', '02058', '02059', '02064', '02077', '02078', '02084', '02090', '02099', '02112', '02115', '02131', '02152', '02156', '02159', '02162', '02168', '02173', '02186', '02203', '02207', '02213', '02219', '02229', '02232', '02244', '02264', '02265', '02281', '02293', '02297', '02300', '02301', '02348', '02352', '02359', '02375', '02419', '02434', '02447', '02448', '02452', '02457', '02459', '02462', '02465', '02478', '02525', '02534', '02559', '02571', '02575', '02581', '02587', '02601', '02611', '02669', '02681', '02718', '02749', '02773', '02819', '02832', '02846', '02873', '02885', '02898', '02899', '02920', '02942', '02955', '02964', '02966', '02979', '02989', '02991', '02998', '03004', '03012', '03093', '03122', '03125', '03177', '03179', '03190', '03193', '03199', '03205', '03211', '03253', '03302', '03312', '03314', '03315', '03326', '03327', '03328', '03340', '03348', '03358', '03377', '03384', '03397', '03467', '03469', '03501', '03523', '00033', '00097', '00099', '00104', '00121', '00126', '00130', '00145', '00149', '00150', '00153', '00156', '00164', '00194', '00200', '00209', '00222', '00225', '00227', '00236', '00237', '00239', '00266', '00267', '00268', '00288', '00299', '00351', '00377', '00459', '00480', '00491', '00493', '00503', '00505', '00507', '00513', '00514', '00540', '00553', '00554', '00002', '00005', '00563', '00566', '00569', '00574', '00580', '00581', '00588', '00606', '00607', '00611', '00616', '00626', '00638', '00643', '00644', '00649', '00650', '00658', '00679', '00691', '00700', '00712', '00729', '00733', '00755', '00789', '00798', '00831', '00837', '00840', '00852', '00868', '00869', '00872', '00888', '00889', '00899', '00900', '00914', '00923', '00924', '00938', '00939', '00944', '00945', '00947', '00948', '00949', '00953', '00956', '00961', '00963', '00971', '00974', '00999', '01003', '01009', '01010', '01015', '01021', '01024', '01031', '01034', '01046', '01054', '01055', '01057', '01058', '01059', '01064', '01069', '01077', '01086', '01087', '01095', '01102', '01120', '01127', '01139', '01147', '01149', '01156', '01157', '01171', '01172', '01186', '01188', '01191', '01231', '01244', '01256', '01275', '01276', '01281', '01285', '01293', '01298', '01300', '01301', '01316', '01319', '01326', '01327', '01328', '01330', '01361', '01362', '01380', '01382', '01388', '01430', '01434', '01445', '01446', '01448', '01451', '01467', '01477', '01478', '01482', '01486', '01492', '01496', '01511', '01525', '01528', '01533', '01534', '01543', '01546', '01553', '01569', '01636', '01637', '01651', '01658', '01669', '01671', '01678', '01684', '01690', '01703', '01705', '01710', '01713', '01718', '01719']

# pids = [p for p in list(metas.keys()) if p not in differs and p not in pads and p in i6s]
pids = i6s
pids = [p for p in pids if os.path.exists(os.path.join(ctx.npy_root_path, 'metas', 'meta-{}.pkl'.format(p)))]

np.random.shuffle(pids)

In [7]:
len(pids)

334

In [8]:
print(len(pids))
for_train = pids[:30]
for_test = pids[314:]
print(len(for_train), len(for_test))

334
30 20


In [9]:
def make_mixed_tasks(data_providers):
    all_tasks = []
    for dp in data_providers:
        for i, tasks in enumerate(dp):
            if tasks is None:
                continue
            all_tasks.append(tasks)
    np.random.shuffle(all_tasks)
    return all_tasks

In [10]:
mat = Matrix(ctx)
t_mat = Matrix(ctx)

In [11]:
net = Leo.create(ctx)

In [None]:
num_profiles = 10

for out_e in range(0, 20):
    for i in range(0, len(for_train), num_profiles):

        # train
        profile_ids = []
        data_providers = []
        for p in pids[i:i+num_profiles]:
            profile_ids.append(p)
            dp = ds.DataProvider(ctx, p)
            data_providers.append(dp)
        tasks = make_mixed_tasks(data_providers)
        test_task_ids = np.random.choice(len(tasks), 2)
        print(profile_ids)
        net.run_with_test(20, tasks, mat, test_task_ids, prior_task_id=0)

    # test
    profile_ids = []
    data_providers_test =[]
    test_idxs = np.random.choice(len(for_test), 20)
    for i in test_idxs:
        tpid = for_test[i]
        profile_ids.append(tpid)
        tdp = ds.DataProvider(ctx, tpid)
        data_providers_test.append(tdp)
    t_tasks = make_mixed_tasks(data_providers_test)
    print(profile_ids)
    for tasks in t_tasks:
        t_mat.add_test(net.test(tasks))

['01275', '02162', '02885', '01478', '01719', '02020', '01763', '02265', '02832', '01281']
e[000] test: 2.3839 |3 / 누적: 3.328000 |                                                           
e[000] test: 3.4125 |
e[001] test: 2.4383 |7 / 누적: 3.095000 |                                                           
e[001] test: 3.2821 |
e[002] test: 2.3698 |3 / 누적: 3.061000 |                                                           
e[002] test: 3.3499 |
e[003] test: 2.4415 |0 / 누적: 3.031000 |                                                           
e[003] test: 3.026 |
e[004] test: 2.3679 |7 / 누적: 3.037000 |                                                           
e[004] test: 3.1609 |
e[005] test: 2.3885 |0 / 누적: 3.012000 |                                                           
e[005] test: 3.2569 |
e[006] test: 2.4142 |3 / 누적: 2.986000 |                                                           
e[006] test: 3.2213 |
e[007] test: 2.3176 |7 / 누적: 3.014000 |                        

In [7]:
40 * np.tan(np.deg2rad(1)) 

0.6982025971287034

In [2]:
import numpy as np