-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_ner_debug.py
78 lines (69 loc) · 2.56 KB
/
main_ner_debug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import pytorch_lightning as pl
from problem import DataModule
from problem.lit_recurrent_ner import LightningRecurrent_NER
from problem.ner_problem import NERProblem
from evolution import Optimizer
import warnings
warnings.filterwarnings("ignore")
import os
import pickle
from datetime import datetime
import numpy as np
# Absolute directory that contains this script. Other code concatenates
# file names with a leading "/" onto it, so no trailing separator is added.
path = os.path.split(os.path.abspath(__file__))[0]
# Current local date as YYYY-MM-DD; stamps the default --save_path.
today = f"{datetime.today():%Y-%m-%d}"
def input_chromosome(args):
    """Load the chromosome to evaluate, trying three sources in order.

    1. ``args.checkpoint_file`` as a pickle; returns
       ``np.array(d['population'][0])``.
    2. ``args.file_name`` as a pickle (the loaded object is printed);
       returns ``np.array(d['population'][-1])``.
    3. ``path + args.file_name`` as text of whitespace-separated tokens,
       returned as an integer array. If any token is not an integer
       literal, the raw string tokens are returned instead.

    A failure in source 1 or 2 (missing file, non-pickle content, missing
    ``'population'`` key, ...) falls through to the next source. A failure
    to open the text file in source 3 propagates to the caller.

    SECURITY: sources 1 and 2 use ``pickle.load``, which can execute
    arbitrary code. Only point these arguments at files you trust.

    Args:
        args: namespace providing ``checkpoint_file`` and ``file_name``.

    Returns:
        numpy.ndarray: the chromosome.
    """
    # ``except Exception`` (instead of a bare ``except``) so Ctrl-C and
    # SystemExit are no longer swallowed. It stays broad on purpose because
    # pickle.load on arbitrary bytes can raise many different exception types.
    try:
        with open(args.checkpoint_file, 'rb') as f:
            d = pickle.load(f)
        return np.array(d['population'][0])
    except Exception:
        print('Read from txt fle')

    try:
        with open(args.file_name, 'rb') as f:
            d = pickle.load(f)
        print(d)
        return np.array(d['population'][-1])
    except Exception:
        pass  # deliberate: not a pickled checkpoint, fall back to plain text

    # Text fallback. file_name is concatenated onto ``path`` directly, so it
    # is expected to start with a separator (the default is '/chromosome.txt').
    with open(path + args.file_name, 'r') as f:
        tokens = f.read().split()
    try:
        return np.array([int(tok) for tok in tokens])
    except ValueError:
        # Non-integer tokens are returned unconverted, as a string array.
        return np.array(tokens)
def parse_args():
    """Build the CLI parser, parse the command line and derive chromosome sizes.

    Arguments are contributed by ``NERProblem``, ``pl.Trainer``,
    ``DataModule`` (data and cache options), ``LightningRecurrent_NER``
    (model and learning options) and ``Optimizer``. The script-local
    options ``--file_name``, ``--checkpoint_file``, ``--save_path`` and
    ``--seed`` are added as well.

    After parsing, these derived attributes are set on the namespace:

    * ``num_terminal = num_main + 1``
    * ``l_main = h_main * (max_arity - 1) + 1`` and
      ``l_adf = h_adf * (max_arity - 1) + 1``
    * ``main_length = h_main + l_main`` and ``adf_length = h_adf + l_adf``
    * ``chromosome_length = num_main * main_length + num_adf * adf_length``,
      also stored as ``D``
    * ``mutation_rate = adf_length / chromosome_length``

    Returns:
        argparse.Namespace: the parsed and augmented arguments.
    """
    parser = argparse.ArgumentParser()
    parser = NERProblem.add_arguments(parser)
    # NOTE(review): Trainer.add_argparse_args was removed in Lightning 2.0;
    # presumably this script targets pytorch_lightning < 2.0 — confirm.
    parser = pl.Trainer.add_argparse_args(parser)
    parser = DataModule.add_argparse_args(parser)
    parser = DataModule.add_cache_arguments(parser)
    # input_chromosome reads the text fallback from ``path + file_name``, so
    # the default keeps its leading "/".
    parser.add_argument("--file_name", default='/chromosome.txt', type=str)
    # NOTE(review): input_chromosome opens checkpoint_file exactly as given,
    # so this default refers to the filesystem root — confirm that is intended.
    parser.add_argument("--checkpoint_file", default='/checkpoint.pkl', type=str)
    # os.path.join supplies the separator that ``path + "..."`` was missing
    # (previously this produced ".../<dir>chromosome_trained_weights...").
    parser.add_argument(
        "--save_path",
        default=os.path.join(
            path, f"chromosome_trained_weights.gene_nas.{today}.pkl"
        ),
        type=str,
    )
    parser = LightningRecurrent_NER.add_model_specific_args(parser)
    parser = LightningRecurrent_NER.add_learning_specific_args(parser)
    parser = Optimizer.add_optimizer_specific_args(parser)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Derived chromosome geometry (formulas listed in the docstring).
    args.num_terminal = args.num_main + 1
    args.l_main = args.h_main * (args.max_arity - 1) + 1
    args.l_adf = args.h_adf * (args.max_arity - 1) + 1
    args.main_length = args.h_main + args.l_main
    args.adf_length = args.h_adf + args.l_adf
    args.chromosome_length = (
        args.num_main * args.main_length + args.num_adf * args.adf_length
    )
    args.D = args.chromosome_length
    args.mutation_rate = args.adf_length / args.chromosome_length
    return args
def main():
    """Evaluate one stored chromosome on the NER problem.

    Parses the command line, constructs ``NERProblem`` from the parsed
    arguments, loads the chromosome with ``input_chromosome`` and passes it
    to ``NERProblem.evaluate``. The result of ``evaluate`` is not used.
    """
    cli_args = parse_args()
    ner_problem = NERProblem(cli_args)
    genome = input_chromosome(cli_args)
    ner_problem.evaluate(chromosome=genome)


if __name__ == "__main__":
    main()