-
Notifications
You must be signed in to change notification settings - Fork 178
/
neural_networks.ipynb
1981 lines (1981 loc) · 222 KB
/
neural_networks.ipynb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Neural Network Potentials.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/google/jax-md/blob/main/notebooks/neural_networks.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b7RYjwcysgZI"
},
"source": [
"Copyright 2020 Google LLC\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
]
},
{
"cell_type": "code",
"metadata": {
"id": "KVIHud2diL51",
"cellView": "form"
},
"source": [
"#@title Imports & Utils\n",
"!pip install -q git+https://www.github.com/deepmind/haiku\n",
"!pip install -q git+https://www.github.com/deepmind/optax\n",
"!pip install -q --upgrade git+https://www.github.com/google/jax-md\n",
"\n",
"# Imports\n",
"\n",
"import os\n",
"import numpy as onp\n",
"import pickle\n",
"\n",
"import jax\n",
"from jax import lax\n",
"\n",
"from jax import jit, vmap, grad\n",
"\n",
"# TODO: Re-enable x64 mode after XLA bug fix.\n",
"# from jax.config import config ; config.update('jax_enable_x64', True)\n",
"import warnings\n",
"warnings.simplefilter('ignore')\n",
"import jax.numpy as np\n",
"\n",
"from jax import random\n",
"\n",
"import optax\n",
"\n",
"from jax_md import energy, space, simulate, quantity\n",
"\n",
"# Plotting.\n",
"\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"import pylab as pl\n",
"from IPython import display\n",
"from functools import partial\n",
" \n",
"sns.set_style(style='white')\n",
"sns.set(font_scale=1.6)\n",
"\n",
"def format_plot(x, y):\n",
" plt.xlabel(x, fontsize=20)\n",
" plt.ylabel(y, fontsize=20)\n",
" \n",
"def finalize_plot(shape=(1, 1)):\n",
" plt.gcf().set_facecolor('white')\n",
" plt.gcf().set_size_inches(\n",
" shape[0] * 1.5 * plt.gcf().get_size_inches()[1], \n",
" shape[1] * 1.5 * plt.gcf().get_size_inches()[1])\n",
" plt.tight_layout()\n",
"\n",
"def draw_training(params):\n",
" display.clear_output(wait=True)\n",
" display.display(plt.gcf())\n",
" plt.subplot(1, 2, 1)\n",
" plt.semilogy(train_energy_error)\n",
" plt.semilogy(test_energy_error)\n",
" plt.xlim([0, train_epochs])\n",
" format_plot('Epoch', '$L$')\n",
" plt.subplot(1, 2, 2)\n",
" predicted = vectorized_energy_fn(params, example_positions)\n",
" plt.plot(example_energies, predicted, 'o')\n",
" plt.plot(np.linspace(-400, -300, 10), np.linspace(-400, -300, 10), '--')\n",
" format_plot('$E_{label}$', '$E_{prediction}$')\n",
" finalize_plot((2, 1))\n",
" plt.show()\n",
"\n",
"# Data Loading.\n",
"\n",
"def MD_trajectory_reader(f, no_skip=20):\n",
"  \"\"\"Parses a DFT molecular-dynamics trajectory file `Supplementary/<f>`.\n",
"\n",
"  The file is split into records on the marker 'iter= ' (text before the\n",
"  first marker is discarded) and only every `no_skip`-th record is kept.\n",
"  Each kept record is expected to contain, line by line: the integer step,\n",
"  three lattice-vector lines, the atom count (asserted to be '64'), 64 atom\n",
"  rows whose columns after the first hold 3 positions followed by forces,\n",
"  and later lines containing 'Temp' and 'el-ion E' with a value after '='.\n",
"\n",
"  Args:\n",
"    f: File name inside the 'Supplementary/' directory.\n",
"    no_skip: Stride used to subsample the records.\n",
"\n",
"  Returns:\n",
"    Tuple `(positions, energies, forces)` of `jax.numpy` arrays stacked over\n",
"    the kept records; positions has shape (num_kept, 64, 3).\n",
"  \"\"\"\n",
"  filename = os.path.join('Supplementary/', f)\n",
"  # Context manager so the file handle is closed once the text is read.\n",
"  with open(filename, 'r') as fo:\n",
"    samples = fo.read().split('iter= ')[1:]\n",
"\n",
"  steps = []\n",
"  # lattice_vectors and temperatures are parsed (and thereby validated) but\n",
"  # are not returned.\n",
"  lattice_vectors = []\n",
"  positions = []\n",
"  forces = []\n",
"  temperatures = []\n",
"  energies = []\n",
"\n",
"  def first_value(lines, key):\n",
"    # Number after '=' on the first line containing `key`.\n",
"    return float([line for line in lines if key in line][0].split('=')[1].split()[0])\n",
"\n",
"  for sample in samples[::no_skip]:\n",
"    entries = sample.split('\\n')\n",
"    steps.append(int(entries[0]))\n",
"    lattice_vectors.append(\n",
"        onp.array([list(map(float, lv.split())) for lv in entries[1:4]]))\n",
"    assert entries[4] == '64'\n",
"    # Drop the first column of each atom row; the rest is positions + forces.\n",
"    atom_rows = onp.array(\n",
"        [list(map(float, lv.split()[1:])) for lv in entries[5:69]])\n",
"    positions.append(atom_rows[:, :3])\n",
"    forces.append(atom_rows[:, 3:])\n",
"    remaining_lines = entries[69:]\n",
"    temperatures.append(first_value(remaining_lines, 'Temp'))\n",
"    energies.append(first_value(remaining_lines, 'el-ion E'))\n",
"\n",
"  # Sanity check: the kept step numbers should be consistent with a\n",
"  # contiguous trajectory subsampled with stride `no_skip`.\n",
"  assert (len(set(steps)) - (steps[-1] - steps[0] + 1) / no_skip) < 1\n",
"  return np.array(positions), np.array(energies), np.array(forces)\n",
"\n",
"def build_dataset():\n",
" no_skip = 15\n",
" data300, energies300, forces300 = MD_trajectory_reader(\n",
" 'MD_DATA.cubic_300K', no_skip=no_skip)\n",
" data600, energies600, forces600 = MD_trajectory_reader(\n",
" 'MD_DATA.cubic_600K', no_skip=no_skip)\n",
" data900, energies900, forces900 = MD_trajectory_reader(\n",
" 'MD_DATA.cubic_900K', no_skip=no_skip)\n",
" dataliq, energiesliq, forcesliq = MD_trajectory_reader(\n",
" 'MD_DATA.liq_1', no_skip=no_skip)\n",
"\n",
" all_data = np.vstack((data300, data600, data900))\n",
" all_energies = np.hstack((energies300, energies600, energies900))\n",
" all_forces = np.vstack((forces300, forces600, forces900))\n",
" noTotal = all_data.shape[0]\n",
"\n",
" onp.random.seed(0)\n",
" II = onp.random.permutation(range(noTotal))\n",
" all_data = all_data[II]\n",
" all_energies = all_energies[II]\n",
" all_forces = all_forces[II]\n",
" noTr = int(noTotal * 0.65)\n",
" noTe = noTotal - noTr\n",
" train_data = all_data[:noTr]\n",
" test_data = all_data[noTr:]\n",
"\n",
" train_energies = all_energies[:noTr]\n",
" test_energies = all_energies[noTr:]\n",
"\n",
" train_forces = all_forces[:noTr]\n",
" test_forces = all_forces[noTr:]\n",
"\n",
" return ((train_data, train_energies, train_forces),\n",
" (test_data, test_energies, test_forces))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "muBsYny61Gbv"
},
"source": [
"# Neural Network Potentials"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3gjTwTSvvozm"
},
"source": [
"An area of significant recent interest is the use of neural networks to model quantum mechanics. Since directly (or even approximately) solving Schrödinger's equation is extremely expensive, these techniques offer the tantalizing possibility of conducting large-scale, high-fidelity simulations of materials as well as chemical and biochemical systems. \n",
"\n",
"\\\n",
"\n",
"Usually, neural networks are fit to energies computed from Density Functional Theory (DFT). DFT is a ubiquitous *ab initio* formalism for approximating solutions to Schrödinger's equation. It offers a balance between accuracy and speed: DFT is much faster than more precise treatments of quantum systems, and it is fast enough to use on systems of hundreds of atoms. Nonetheless, DFT calculations scale as $\\mathcal O(N^3)$, so they are prohibitively expensive to run on large systems or for long simulation trajectories. \n",
"\n",
"\\\n",
"\n",
"As with many areas of machine learning, early efforts to fit quantum mechanical interactions with neural networks relied on fixed-feature methods with shallow neural network potentials. Lately, however, these networks have been replaced by deeper graph neural network architectures that learn salient features. JAX MD includes both popular fixed-feature methods and graph neural networks. \n",
"\n",
"\\\n",
"\n",
"Here we will use JAX MD to fit a state-of-the-art graph neural network to open-source DFT data from a 64-atom silicon system that accompanied a [recent paper](https://aip.scitation.org/doi/10.1063/1.4990503). This silicon system was simulated at several different temperatures. We will sample data from these trajectories to construct training and test sets. Following modern best practices, we will fit to both the energies and the forces computed using DFT. We will then use this network to run a simulation using JAX MD's simulation environments. To start, we download the data. This might take a minute or two."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZgP6T5VV3wSn",
"cellView": "form"
},
"source": [
"#@title Download Data\n",
"\n",
"!wget https://aip.scitation.org/doi/suppl/10.1063/1.4990503/suppl_file/supplementary.zip\n",
"!wget https://raw.githubusercontent.com/google/jax-md/main/examples/models/si_gnn.pickle\n",
"!unzip supplementary.zip"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "pjbaj79a4Av4"
},
"source": [
"We will then load the data into training and test sets using a small utility function. Each split includes particle positions, whole-system energies, and per-particle forces. To assist in training, we compute the mean and standard deviation of the energies and use them to set the initial scale for our neural network."
]
},
{
"cell_type": "code",
"metadata": {
"id": "dNIar1gB3zig",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "bebe6c4b-021f-4262-f483-2bd987cd78d6"
},
"source": [
"train, test = build_dataset()\n",
"\n",
"positions, energies, forces = train\n",
"test_positions, test_energies, test_forces = test\n",
"\n",
"energy_mean = np.mean(energies)\n",
"energy_std = np.std(energies)\n",
"\n",
"print('positions.shape = {}'.format(positions.shape))\n",
"print('<E> = {}'.format(energy_mean))"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"positions.shape = (2416, 64, 3)\n",
"<E> = -368.9131164550781\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kib1B2Mm4yRp"
},
"source": [
"Next we create a space for our systems to live in using `periodic` boundary conditions."
]
},
{
"cell_type": "code",
"metadata": {
"id": "nMzMGNf0463Y"
},
"source": [
"box_size = 10.862 # The size of the simulation region.\n",
"displacement, shift = space.periodic(box_size)"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "WNs8v2745Mc3"
},
"source": [
"We can now instantiate a graph neural network using the `energy.graph_network_neighbor_list` function. This neural network is based on [recent work](https://www.nature.com/articles/s41567-020-0842-8) modelling defects in disordered solids. See that paper or the review by [Battaglia et al.](https://arxiv.org/abs/1806.01261) for details. We will add edges between all pairs of atoms that are separated by less than a cutoff of 3 Angstroms. In JAX MD, neural network potentials are specified by a triple of functions: a `neighbor_fn` that creates a list of neighbors that reside within the cutoff, an `init_fn` that initializes the parameters of the network, and an `energy_fn` that evaluates the model."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Fu4RvlXU5Cwb"
},
"source": [
"neighbor_fn, init_fn, energy_fn = energy.graph_network_neighbor_list(\n",
" displacement, box_size, r_cutoff=3.0, dr_threshold=0.0)"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "BLF5A-D96jDU"
},
"source": [
"To start with, we construct an initial neighbor list which will be used to estimate the maximum number of neighbors. This is necessary since XLA needs to have static shapes to enable JIT compilation. See [here](https://github.com/google/jax-md#spatial-partitioning-partitionpy) for details."
]
},
{
"cell_type": "code",
"metadata": {
"id": "FzcabwoM6iiN",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ab600785-58cc-4733-b56c-32b10488ae93"
},
"source": [
"neighbor = neighbor_fn.allocate(positions[0], extra_capacity=6)\n",
"\n",
"print('Allocating space for at most {} edges'.format(neighbor.idx.shape[1]))"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Allocating space for at most 804 edges\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t2pWF0jE_Ha4"
},
"source": [
"Using this neighbor prototype we will write a wrapper around our neural network energy function that will construct a neighbor list for a given state and then compute the energy. This is helpful because it allows us to use JAX's automatic vectorization via `vmap` along with our neighbor lists. Using JAX's automatic differentiation we can also write down a function that computes the force due to our neural network potential.\n",
"\n",
"Note that if we were running a simulation using this energy, we would only rebuild the neighbor list when necessary."
]
},
{
"cell_type": "code",
"metadata": {
"id": "CZVtIaMs_3IY"
},
"source": [
"@jit\n",
"def train_energy_fn(params, R):\n",
" _neighbor = neighbor.update(R)\n",
" return energy_fn(params, R, _neighbor)\n",
"\n",
"# Vectorize over states, not parameters.\n",
"vectorized_energy_fn = vmap(train_energy_fn, (None, 0))\n",
"\n",
"grad_fn = grad(train_energy_fn, argnums=1)\n",
"force_fn = lambda params, R, **kwargs: -grad_fn(params, R)\n",
"vectorized_force_fn = vmap(force_fn, (None, 0))"
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "R3MvxVtj-SUQ"
},
"source": [
"Next we will initialize the parameters of the graph network. This is done by providing the `init_fn` with a random key as well as an example input. As with the neighbor lists, this example input is used to deduce the shape of the various parameters."
]
},
{
"cell_type": "code",
"metadata": {
"id": "dhnKug1_-MPS"
},
"source": [
"key = random.PRNGKey(0)\n",
"\n",
"params = init_fn(key, positions[0], neighbor)"
],
"execution_count": 13,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "1H2wGxzg-tJv"
},
"source": [
"Now, we can use JAX's automatic vectorization via `vmap` to compute predicted energies for all of the states using the untrained network."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ns5oB-gN-n9P"
},
"source": [
"n_predictions = 500\n",
"example_positions = positions[:n_predictions]\n",
"example_energies = energies[:n_predictions]\n",
"example_forces = forces[:n_predictions]\n",
"\n",
"predicted = vmap(train_energy_fn, (None, 0))(params, example_positions)"
],
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BP6x1sEPAY_A",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 441
},
"outputId": "85cb3cd3-0a60-4abe-9eec-7f397184e69c"
},
"source": [
"plt.plot(example_energies, predicted, 'o')\n",
"\n",
"format_plot('$E_{label}$', '$E_{predicted}$')\n",
"\n",
"finalize_plot((1, 1))"
],
"execution_count": 15,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAGoCAYAAAATsnHAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXhU5dk/8O9MJhtZBLISQZRgQha2IlLUmkStVAyLUQNSQSSiFrQKirXVKuBata9WERXKZoSwWLAlKlpA4o8Yd9YsgEEwGEIgBJOQyTrz+yOdYZZzzpyZOTM5M/l+ruu9XnPmmbM1nDvPc+7nfjRGo9EIIiIildF29wkQEREJYYAiIiJVYoAiIiJVYoAiIiJVYoAiIiJVYoAiIiJV0nX3CfRU9fXnYTB4JsM/KiocdXVNHtm3P+N9cx7vmWt43y7QajXo0ydM8DMGqG5iMBg9FqBM+yfn8b45j/fMNbxvjnGIj4iIVIkBioiIVIkBioiIVIkBioiIVIkBioiIVIkBioiIVIkBioiIVIkBioiIVIkBioiIVIkBioiIVIkBioiIVIm1+HxESWkNNhdVoq6hFVGRwcjJSMTYtPjuPi0iIo9hgPIBJaU1WPNxBdo6DACAuoZWrPm4AgAYpIjIb3GIzwdsLqo0ByeTtg4DNhdVdtMZERF5HgOUD6hraHVqOxGRP2CA8gFRkcFObSci8gcMUD4gJyMRQTrr/6mCdFrkZCR20xkREXkekyR8gCkRgll8RNSTMED5iLFp8QxIRNSjcIiPiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUiQGKiIhUietBEfmhktIaLnBJPo8BisjPlJTWYM3HFWjrMAAA6hpasebjCgBgkCKfwiE+Ij+zuajSHJxM2joM2FxU2U1nROQaBigiP1PX0OrUdiK1YoAi8jNRkcFObSdSKwYoIj+Tk5GIIJ31P+0gnRY5GYnddEZErmGSBJGfMSVCMIuPfB0DFHkV05+9Y2xaPO8r+TwGKPIapj8TkTMYoMjMnd6NnO9KpT8zQPkn9pjJHQxQBMC93o3c7zL9uWfprh4zg6L/YBYfAXBvcqfc7zL9uWfpjgnDpqBo+qPHFBRLSms8dkzyHL8LULt27cJ9992HzMxMDBs2DGPGjEFubi62bNkCg8Fg1761tRWvvvoqsrKyMHToUIwbNw4rV64UbGswGLBy5UqMGzcO6enpyMrKwquvvorWVt/vAbjTu5H7XaY/9yzd0WNmFQ3/4ndDfIcPH0ZgYCCmTJmC6Oho6PV6FBUV4fHHH8fevXuxaNEiq/YPP/wwdu3ahalTpyI1NRVffvkl/va3v+HkyZN44oknrNo+//zzyM/PR3Z2Nu655x6UlZVh2bJlOHLkCJYuXerNy1RcVGSw4INDTu9G7neZ/tyzuPM75SoOI/sXvwtQ9957r922GTNm4L777sPGjRsxf/58XHTRRQCAoqIi7Ny5E/PmzcP9998PALj99tsRHByM/Px85Obm4vLLLwcAHDlyBO+99x5yc3PxzDPPmPcdHR2N119/HUVFRcjIyPDCFXpGTkai1fsCQH7vxpnvMv2553Dnd8pV3REUyXP8bohPTL9+/WAwGNDU1GTetnXrVuh0Otx5551WbWfOnAmj0YgPP/zQvK2wsBBGoxEzZ860ajt9+nTodDoUFhZ69Pw9bWxaPO66aYj5H3JUZDDuummIrGDiznfJf3XH7wWHkf2L3/WgTJqamtDW1obGxkYUFxdj8+bNGDx4MBISEsxtDhw4gMGDByM8PNzqu0lJSQgPD8fBgwfN2w4ePIiIiAgkJlr/okdGRmLQoEFWbX2VO70bb/eM8j+pwGd7qs0/BwcGYMbvku3OwV8yunz1Orz9e8FhZP/itw
HqoYcewu7duwEAGo0GV111FRYtWgSNRmNuU1tbi9GjRwt+Py4uDqdOnbJqGxcXJ9g2Pj4e3333nYJn3z1cfQh6++FpG5wAoLW9EysKywFceEjZtpOb5qy2YCCVrg24/jAWuk539qcWHEb2H6oNUA0NDVizZo2str169UJeXp7VtkcffRR5eXmora3Fzp07UV9fj+bmZqs2LS0tCAoKEtxncHCw1XCgXq9HRESEaNuWlhZZ52oSFRXuuJGFCY/826n2lvpGBOHc+XYYDEbZ36lraMXyrWVYvrXMqWPVNbRi1UflWL/jCJqa2xHdJxSjh8Tim4panK7XQ6vVwGAwIqJXIACY28y4KQWZowYAAHZ9V4V3Py7HmXq91We7vqvCsg8OoLG5XfDYBqMRH+z+EZERIaLt2joMWL/jCD7Y/aPd/gGg9KdzeHfbIbS2d5qv591thxAZEWJuYyJ2nkr7YHeJYGba+h1H0NZukHWutnZ9V2V3nSs/LIdGA3R0Gp3aX0yM8L8Lksb75piqA9SSJUtktY2OjrYLUCkpKeb/njx5Mp566ilMnz4d27ZtQ9++fQEAISEhaGtrE9xna2srQkJCzD+HhobKbitHXV2T7IAx68WdTu3b1tlG4fP2lI5Oozk4nK7X46OS4+bPTNdsGTxO1+vxxsa9aGjsCvKWvQXTZ9+X1+DzvdXodHDLTO1tH+iWGpvbrc7PdOyJmZdjdWGp+aFt0treidWFpUi7pLd5m22vxnI/Sv/1frpeL3odtoTOVYjQdXYK/D62tnfi1YLvRa8rJiYCp083Sh6L7PG+XaDVakT/YFdtgOrfvz8OHTqk2P6ys7OxYcMGbN++Hbm5uQCA2NhYq2E8S6dOncLIkSPNP8fGxmLfvn2CbWtqakSH/0gey7kqQr2For3VkBPPtRr778s99sTMy2WnKXuzbJNYZpoYd+auCTEYwZqJ1C16TBafaQjul19+MW8bOnQoKisrrYbygK6U8qamJqSlpZm3paeno7GxEZWV1hP+GhoacPToUau25Jq6hlbRB6e84KSR1U7s2ID8ahfenG8jlpkWFhIg2F7u3DVncLIrdQe/C1Bnzpyx22Y0GlFQUAAAGD58uHl7dnY22tvbsXbtWqv2q1atgkajwc0332zeNn78eGg0Grv3Yvn5+ejo6MCECROUvIweKSoyWPTBqdUIbraSl53i8MEbpBPekel7ctOUvVm2SSxde9pvk11OqRa6zgANoAsQv9Gc7EreptohPldlZ2dj9OjRSE1NRXR0NM6cOYNt27ahoqIC2dnZuPLKK81tMzMzkZWVhddeew2nTp1CSkoKvvzySxQWFuLOO+9EUlKSuW1ycjKmTZuGtWvXorm5GWPGjEF5eTkKCgqQlZXl05N0labRAEYnezKWD1ahyZ1XD40XfQelC9Dg7vEp5uEn2+8DQHiozpys0dbRIXps2zRl05Chqfdg+tzbk1ClMtNcyboTS8cGgBWFZYI9UU52JW/TGI3OPkrUbcmSJdi9ezeOHTuGxsZG9OrVC8nJyZg8eTJycnKg1Vr/1djS0oI333wTW7duxZkzZ5CQkIDc3FzcfffdCAiwHkLp7OzEqlWrsHHjRlRXVyM6OhoTJ07E3LlzERzs3D9eZ5IkAPcSJXqHBaKhud3l4S+5gnQa3HVTV3KK5YNvWGIU9lfWmR/4BiMQFhIAjUaDJn2HeZujVOeS0hqs++8hnG+58HI/PFSHO25Isnooi6VPCwWusJAATPtt1/wpyxfXtkkQXdentZpoqrZ0dKXIuXYTvux3De/bBVJJEn4XoHyFswHKGb7yy+/Mg9BdC5YWi5bAeXnO1QAu3LeS0hrJXoSpvRglA5cr+1Li+HL34Su/a2rD+3aBT2bxkf/zZiac3KQGU9AU+9vB0XsYJddAcmVfSh2fk11JDfwuSYJ8hzcz4eQmNQgFTTn7kfq+qxlwruxLrctNlJTWYMHSYsx6cScWLC3m+kwkC3tQ1G28WX
laLKlhWGKUefgvpk+oZHCUkwShZNB1ZV9qXG6iu1bWJd/HAEXdxlOZcFLvT2yTN4oP1FhVgxCj1UDWuzElg25YSIBVQojldm8cXyneHMol/8IARd3GE5WnHf21brnvBUuLZVWdcJS4YRkQu7ITrdPstRqNS0HXsrCxnO1A96zB5Igae3XkGxigqFsJvYx3JwvNmb/WpR6Qpp6Io+PbBkShHo/BaMQPJ845HXib9B1ObQfUudyEGnt15BsYoEhV3H1f4cxf61IPTkep5CaOkipMivZWY/q4IbL2Kef8pKgtA0+NvTryDcziI1VxNwvNmRJESqy+KneYymCE01ls/rI6LFdcJlexB0Wq4u77Cmf+WrcdDovpE4rJ11zm1IPTmUrjpnZye4VKDtd1d9ULtfXqyDcwQJGquPu+wtmHuuWD05XZ/UIBUQ65WWxKPNjlDJt2dwAjEsIARari6fcVSj+IhQJibJ9QHPrpHAxGmOsMCvFWFpujxBHOUyK1YoAiVXFlWMsy6ISH6qBv6TBXPTc9bH84cQ5fl5+yyrJT4kEsFvAst4sFKXey2JwJtI6GTTlPidSKAYpUx5lhLdu//oVSsNs6DPhsT7Xg9915EAv1PJZvLcO728rR0WE0B0mh4OROr1DouCsLy1Cw/TCa9B12AcvRsKnS85R64nBhT7xmb2AWH/k0uWneUlx9EIsdu7XdKLhulWnRRXez2ISO22m8EJxNPUNTpqCjbEAlF180BU/bhBB/rr3XE6/ZWxigyKcp8R5Hzmq9ShzbYARWPn4dXp5ztVt/Xcs5ru0ii1Jp3kqms6u1WK0n9cRr9hYO8ZFPcybNW4yry3I5e2ylKifIPa5lG6lhU9N20xAhAATqXIvaPbGsUU+8Zm9hD4p8mtBf/7oAjbmgalRkMLJGJti1seRq4BA6thilMhFLSmvQ2m5fTkmIs9fV1n6hF3C+pRMrC8vwx3987tTkYiWHC31FT7xmb2EPinya3Ky/wf172y0XD7gXOIR6Hia6AA2CA7U439Lp9EtzqcxAoTlXQToNDEagw+LFl7PXJfe91g8nzmF/ZZ353GZmpyHtkt7m7/TEskY98Zq9hUu+dxMu+d49pLKt3LlvzmRxSbUVC0JZIxPMgUFIVxV1jWAWnxyzXtwpu62l4MAAzPhdstWxemJGm7PXzH+jF0gt+c4A1U0YoNTHG/dNKABZLudhWjzRFY6WBZHiznGdKa5LXfhv9AKpAMV3UERe5CjjSypIOMo2dCdzzJn3abaYDECewgBF5EWOMr6kXqwbjHAYRFwNFrap6GEhAdAFyMvkYzIAeQqTJIi8yFFVh5yMRCzfWib63ZyMRPO7DrE2rrJNRbd9rzIsMQrFB2qseoDBgQF2yQBqeAelhnMg9zFAEXmRo4yvsWnx+OHEObvSTKY2piAi9i5LycwxoblTg/v3tnrwm7L4LAOCpe4oPMvit/6DAYrIhif/+naUFl9SWoP9lXVW39FqgKuHWgcLOen1nrgO26AVExOB/+w6IrnkiLcLz7L4rf9ggCKy4I2/vsWqOoilmBuMQPGBGgzu39suSEmlsrtzHc4ENzn1EL2ZSMHKDv6DSRJEFrqzrprUg97Zc3DnOpwtfirnwe/NRApWdvAfDFBEFrrzr29Hx3DmHNy5DmeDm6MHv7erKihZ/Ja6FwMUkYXu/Ovb0TGcOQd3rsPZ4CY1h8rdpUVc4ah6O/kOvoMisuDJumolpTVWdfvCQgIw7bcXygQJHdvVc3DnOsRS4cNCApD/SQWK9labl7P/3a8H4rb/7VNNad3OLHpJ6sVSR92EpY7Ux3TfPJH9VlJag1UflVsVdDUJD9XhjhuS7JaKt2Ub0OQc05XryP+kQnAFYo0GEHpaZI1MwPRxQ2SdE3Xhv9ELpEodsQdFZMMTf31vLqoUDE5AV8Vwyww7U6CyDWimJTAs92kZfIS2OVsjr6S0BsUHhJMhxP6ULdpbzQBFHsEAReQFjpITbOfpiA
W0TiOw7r+H0N5htEohX1lYBo1WY/5OXUMrlm8tww8nzkkGD9teVmt7p8OUcVseGgggYpIEkTc4m5wgFdDOt9gHkU6b9aBMPttTLZoeLpRObruulRyOitgSuYoBisgLcjISHRZftQxiSmYNiqWHy5lgK0fGiAS390EkhAGKyAvGpsXj7vEpCA8VHlW3zbATC2gBGojuQ4ySc7uCAzXmHpNWA4wfO5Dvn8hj+A6KyEssky8cZdgJLSdvyuIDIFn7zpbUnCixbMHzLZ2C32ltN2L2hFRFViEmcoQBiuh/dn1XhdWFpS6llzub0i0nU1Bq+YvgwABZ56UL0IjOfRKaKxWgATQa6aFIZ2sTcukLchUDFBG6HqLvbjuE1vaunoMzxVXdKcxaUlqDdf89ZO6xWM6JkjqG6Twd6eg0omD7YcFzMS3tYZp4CwBGaBwmSlhmHDoK6lz6gtzBd1BE6EoYsH3oyy2u6mph1pLSGqwsLLMaTmvSd2DVR+V2mXfuJDSY5lnZ7tM058kyTdwgc95+XUMrSkprsGTTPsmist1ZfJd8HwMUEdxLJHD1u5uLKiE0d7ej02j3AHe3WK1QUHAn6EVFBssK6lz6gtzBAEUE94qruvpdqYe07WdKpJ0LrXbrqmGJUQ6Dj9j8K4BLX5A8DFBE6EoYsE08kFtc1dXlHaQe0rafSVUMl8t2n3KCRJBOi5SBve22Fx+oEU13N+1XahiPS1+QHAxQROh6Yf/A7cNdWqLB1eUdhiVGCW7Xauwz72yPER6qQ1hIgPl4syekSgYcoYApJ+hdPTQetfV6u+1tHQYYjUbJoC7VQxNKAlmwtBizXtyJBUuLJXtf1HMwi49UzZspypmjBiDtEvveghyuFJjdX1knuD00WCu4LznHEJofpYH1uyHLfQQFaiXfQxUfqBH9/HxLJx6Z9ivRLD6tRrhOn21pJG8uT0++hQGKVMvfU5TFehhik2QtST2ULedKtbZ3whQjLO8fIG+yb1uHQTTQREUGSwZ1sSKyttulMv08meJP6scARarlzoPLF4hVcnD0bsjRQ9m0XMfyrWV237XsScnN4BMLNGJDlCZyr89Ty9P7w+9IT8d3UKRa/p6i7GpyhZy5RVIJCnUNrU7dQ7Fq5WJDlCZyr8+by9OTb2EPilTL1R6Gr7AdkpP7/kTOQ1nqAW26f3If4mI9KEffl3t9nlieXuh3hO+qfA8DFKmWOw8uX+FKcoWch7JYG6CrTNLoIbGSCRAmWSMTsL+yTnRfs579FJOvuUz0GuTWHAScD9SA/N8RvqvyTQxQpFruPLj8mZyHslAbkyZ9B4oP1ODqofHm4BMVGYxhiVFWP5v293X5KdFzOV2vV+RBLxTI5PR45P6O8F2Vb2KAIlX74cQ51Dd2/fVe39iKH06c6/EPFDkPZds2tto6DPimohavP3St6HFsex1iPPGgd6bHI6eXxndVvokBilQr/5MKfLan2vyzwQjzzz19kTyhtaWWby1zqpfZpO9ASWmNaFtnavXJedA78w5I6R6Pv7/P9FfM4iPVKtpb7dT2nsjU07CtKJ7/SYXVdjGOsv3kkpsaL1X5XM6xXe3xuJoxSd2LAYpUS+5Ez55MrKdRtLdaVu9HTrafI0qlxss5tqs9HlfLUVH34hAfqZbcUjk9mViAkRvEpR74UokWJjF9QiWz+Bydp9h2T2RwupIxSd2LAYpUK2NEgtU7KMvt1EXs3YpYcLcUoJGuKm56mK8oLBMtdbTyyRtx+nQjAOl3THLfAVnuIywkAEGBOjTpO5jB2UMFLFy4cGF3n0RPpNe3QebipU4LCwtGc3ObZ3buRcMHR6PhfCt+OtUII7oeupkjEzyWIOGL9y2iVxAOHq1Dp0UECdJp8Zvh/VB95rzVdktBOg1CgnX4svQUdu+vRkSvIAyIDbdrNyA2HDF9QgWPcccNSUi+NArNzW3md0ym5eL1rZ04eLQOUReFYEBsuOh53nFDkvm4tvto7+hqO+vmFM
y6OVXw/HyVL/6ueYpGo0GvXkGCn7EHRao2fdyQHp+xJ0Uq5Xxw/9522wFg3X8P4XxLJ9o6ugKBo0mrpm0F2w+bg0dbhwEF2w8jMiIEaZf0dph1Jyc1nnOVyBYDFJGPE3u3Yrtdal6TnECgb+mw+rlJ34F/bNiDu8enSL5jMqWyO3oHxLlKZIsBiqiHcDSvSSoQbC6qRKfAaGFHpxGbiyolSyst31qGH06cc9gTVstcJdbsUw+mmRP1EI56IlLZkVLfrWtodbg672d7qh2ukqvEXCV3V+Z1dr4WeRYDFFEP4agnIpX1J/XdqMhg8zwjKcu3lkkGDXfnKikRXJydr0WexSE+oh7C0bwmraarvJRtwdixafHIyUjEysIyu2E+XYAGORmJ5mExR+oaWrHqo3JzooblMdwdWlMiyYLvwdTF7wLUrl27UFBQgEOHDuHs2bMIDQ3FwIEDcccdd2DSpEnQai90Gr/66ivMmDFDcD9TpkzB4sWLrbYZDAasXr0aGzZswM8//4yYmBhMnDgRc+bMQXAwa3qR+uR/UoGivdUwGLsCUHzfUFTX6QXbWtY6BISz+0yBBQDCQ3W475Zh+L68RnC+mpiOTiM6OjutjvHDiXNWy3+4shyGEsFFLe/BqIvfBajDhw8jMDAQU6ZMQXR0NPR6PYqKivD4449j7969WLRokd13pkyZglGjRlltu+yyy+zaPf/888jPz0d2djbuuecelJWVYdmyZThy5AiWLl3qsWsiksOyBxIeqoO+tQOdFh0KgxGiwUlMW4cBy7eWYXNRJXIyEvHGwxlWx1myaS9a2+UVlJU6hlCAc7b3Ex6qM6fBW3ImuPSENch8id8FqHvvvddu24wZM3Dfffdh48aNmD9/Pi666CKrz0eMGIFJkyZJ7vfIkSN47733kJubi2eeeca8PTo6Gq+//jqKioqQkZGhzEUQOck2hVzoQe0OsZ6Ou8FJznHlKCmtsUuDBy4MQcrFNcjUxe8ClJh+/frBYDCgqanJLkABQHNzM3Q6HYKChGc0FxYWwmg0YubMmVbbp0+fjqVLl6KwsJABirqNM0tjuMpUhNabxXrl9H5KSmtEyzEFB2qdDi6s2acefpvF19TUhLNnz+L48eNYt24dNm/ejMGDByMhwb6O27PPPouRI0di6NChGD9+PDZt2mTX5uDBg4iIiEBiovVfY5GRkRg0aBAOHjzosWshcsRbL/HdDU4adNUAlEPO0Jqp5yh2Xqb3ZeSb/LYH9dBDD2H37t0Aumo9XXXVVVi0aBE0mgv/OnQ6Ha677jpkZGQgNjYWJ0+exPr16/Hkk0/ixIkTmDdvnrltbW0t4uLiBI8VHx+P7777zrMXRCRBaqKsM4IDAxAeqhPdl5witFKMAGZlp4r2eEz7lzu05qjnyOQG36baANXQ0IA1a9bIaturVy/k5eVZbXv00UeRl5eH2tpa7Ny5E/X19WhubrZqM2rUKLvkiNzcXEydOhXLly/HbbfdhgEDBgAA9Ho9IiIiBI8fHByMlpYWuZcGAIiK8mzhy5gY4XMlab5632Zmp2HJpn1obXe9xxAcGIAHbh+OzFEDsOu7Krv9BQcG4Por+mPHtyckj6MBEN4rEI3N7XafabUaREaEYN4dvxLcv+n4cp2VCMrBgQGYmZ2m2v9N1XpeaqIxGj1VU9s9J06cwPXXXy+rbXR0NIqLiyXbPPXUU9i2bRu2bduGvn37Srb96KOPMG/ePCxevBhTpkwBAEyYMAEGgwEffvihXfvZs2fju+++w/fffy/rfAGgrq4JBg8N5sfERJiXQCD5fP2+WWbXaTQwV8sP0mlgMHald4sR6rGIzUuy3C4mPFSHtnaDYO8mSKc1T+p1NxlhwdJi0eVG8rJTVfsuydd/15Sk1WpE/2BXbQ+qf//+OHTokGL7y87OxoYNG7B9+3bk5uY6PDYA1NfXm7fFxsZi3759gu1rampEh/+IvMX0MLafjKvBb4bFiyY4REUG4+U5Vw
vuzzIgWaabm9rnvbgTQmGvSd+B2ROEh/JM6eMvz7na7QAilhbeHavlsoaf8vw2ScKWaQjul19+cdj2+PHjAICoqCjztvT0dDQ2NqKy0nq2fENDA44ePYq0tDQFz5bINWLVFL4uPyX67kiqJ+SofFB0n1DB75nKH7lyTNNx5dTUU8tS7qzh5xl+F6DOnDljt81oNKKgoAAAMHz4cPN2yx6SSXNzM9555x0EBgbimmuuMW8fP348NBqN3Xux/Px8dHR0YMKECUpdApHLxB78UtlsUokEjmrTzbgpRbLAq9i+pY7p7MN+bFo8Xp5zNVY+fp0ivTJXsIafZ6h2iM9V2dnZGD16NFJTUxEdHY0zZ85g27ZtqKioQHZ2Nq688kpz29mzZyMuLg6pqanmLL4tW7aguroaCxYsQL9+/cxtk5OTMW3aNKxduxbNzc0YM2YMysvLUVBQgKysLM6BIlVwNpvPUSq3o/JBmaMGoKGxRXRoy5XKDN5euFCJoTnW8PMMvwtQd955J3bv3o01a9agsbERvXr1QnJyMp577jnk5ORYtb3xxhuxY8cO5Ofno7GxEWFhYUhPT8fTTz+NzMxMu30/8cQTSEhIwMaNG7Ft2zZER0dj9uzZmDt3rpeujkiao4KwthwNh8mpTSc0sbWktMZqBV4TrQa4eqh6Fi60rcDhSg1AgDX8PEW1WXz+jll86uMv9822R9Da3ilao04oOcJS/icVdnXyLJMQLO+ZnOw+k6yRCaILGIpl5sk5X2cpdSyh1YqlkjX85XdNCT6ZxUdErrHNvhMKTnKrNBQfsH/vo9EYzRl9M7PTkHZJb8nl5IWYgp5paY/wUB2MRiPOt3QiLCQAugCNVVq8pwq2KtVbYw0/z3A5QI0bNw4pKSlITU3FkCFDkJKSgpiYGCXPjYhcJBUwTA9P4EIPQuiBKlalobW9K3B0VTPfhxm/S3apFqBlz8wyiJ5v6USA5kJ1ck8+7JUcmmMNP+W5HKBmzpyJiooK7NixA2+99RZaWloQFRVlDlam/xNatoKIPEssYJiGruS8e5HTi2ht75Q9rOeMTmNXJYjXH7rW7X1JJUG4ksTB+U7e43KAuuOOO8z/vX//fjz88MO4+uqrodPp8PXXX8TtSKgAACAASURBVGP58uXQaDQICQnBnj17FDlZIpLH0dCVnEw5uRmBpge10kFKif05CsTODs0plVRB8ijyDurJJ5/E008/bZVq/e233+Kxxx6zy5wjIs9zNHQl592L3IxA00PdmXdQciiRAScnEDszNOftFPieTpGJusePH8egQYOstl1xxRV46qmnHNbIIyLl5WQkuj2B1lSlISwkQPQ4wYEB5h6HZUWH8FCd7GU1xAxLjHLcyAGlU9Y538m7FOlBDR8+HO+//77V8hQAkJiYiPLyciUOQUROcDR05cy7l/YO4ekQUZHB5iw+0zEtexH5n1Q4XODQsqitrf2VdZLXCDh+H6T0/CTOd/IuRQLUn//8Z0yfPh3V1dWYOXMmkpOT0dbWhuXLlzusHE5EniE1dCX33YtYsoVW0xXkMkcNEJzPY0pRdzTVT2oWppx6fY7eB7mSBCFF6f2RNEUCVEpKCt5//30888wzuPXWW6HT6dDZ2YnAwEC8+OKLShyCiBRmG8BMBVotA5ZYkDAYu6qmR0aEmHtQlpxJOzelk9ty1CuR+37J1FaJrLuxafH44cQ5c89QTmUMcp1iE3UvvfRSrFixAidPnkR5eTk0Gg3S09M5N4rIB4j1RsSCB9AVDN79uBx/u2+s3WfOvJMxGo0I0mmd7pXIfR+k5Pwk256hwQgUH6jB4P69GaQ8QPFKEv369bMqskpE6ifWGwnUaeyCh6Uz9XrB7c6knZ9v6cTsCalO93LCQgIEq7S78j5I7twmZvF5lyIBqq6uDgsXLsQXX3yBoKAgbNmyBfHx/B+LyFdILdMhtvAgAGi0Gsx6cSe0mq7ehOnh7kzauWntKGce8CWlNWhtt9+3Bl
2Th2e9uNOqYoZU8HFmbpNUr40TeJWnSJr5okWLcO7cObzxxhtobm5Ge3s7AGDx4sVYtmyZEocgIg8SSyU3BY+87FS7tHUA5oLHpuBl+XC3XUgwa2SCZOq7HKb3ZMu3lgkuYW/EhbJJdQ2tWFlYhlUflUuuLeXMWk5ivbPwUB0XLPQARXpQJSUlyM/Px5AhQ6DVXvgFvP766/Hyyy/j3nvvVeIwROQBYr2RgP9l6gH2yQamHpMQqSXdB/fv7bA3I/a5swVpga6SSbAJZLZDcs7MbRLL4jMajRz68wBFApRWq0VwsP1fFpdccgmqqqqUOAQRecjmokrB3khoiM7q4Wo5DDfrxZ2S+3Rl4qqjoTZXCtLKOT9n5jaJZQUu31rm8DjkPEWG+DIzM/HBBx/YbW9qakJAgPgsdCLqfmIPUbHsPcBxIoLQ546Wchcbalv330OS5+kKy/NzVHXDltAS864sbU+OKRKg5s+fj/fffx+vvvqqeZter8ebb76J1NRUJQ5BRB7iysNV6KFuIvZwd/SuRypRo6S0RvJ8xN5xBWgAnU3NJdvzsy3TFBUZ7HClYVvOBjmSR5Ehvri4OKxfvx4LFy6EXq9HTk4O9Ho9IiMjsXz5ciUOQUQe4kp1BNPD+4PdP+J0vd4ui0/o4e7oXY9UavrmokrR87QMJkLvuEzfl8quc3euFBcs9AxFAlR2djYKCgqwYsUKVFdXo6KiAjqdDiNGjEBkZKQShyAiD3H14To2LR4TMy+XvXS5o3c9wxKj7JaXN6lraJV1nmKBRupalEoP54KFylMkQP3www9oa2sDACQkJCAhIQEA0NjYiEWLFuHpp59W4jBE5CFKPlwti8RqNUDGiARMHzcEORmJWPVRuVVChi5AY+7lSBWHNQUxpYOAs+s7ca6Td7kVoO655x4MGzYMGo0GJ0+eRFSUdXl8vV6PDRs2MEARdTNvPVjzP6mw6gUZjBeWdh/cvzeMNrnplj9LJUF46l2OM5UhuFih97kVoJKSkvDNN9/AaDTi9ttvR1hYmHnJ9+TkZBw9epS1+Ii6mTcfrEV7hYfoivZWY39lne2UJHQaYQ4GYkOA4aEX0t2VDrTOzIFimSPvcytAPfbYYwCA9PR0bNiwAbW1tSgvL0dFRQWKiorQ0dGBBQsWKHKiROQabz5YxSbvGoyOg4FYEsQdNyQB8EygdWYOFBcr9D5F3kHt2bMHgYGBSEtLQ1ZWlhK7JCKFePPBKlZhQqsB+kRIBwNHSRCeCLTOZDBK9fBslylhj0oZigSob775BuHh4Rg2bJgSuyMiBXlzFdiMEQmCmXgZIxIwuH9vh8FAKglCKtBaFod1Jjg4k8EoVgD3vL7Dqv4f30spR5EA9cILLyAvL88uQB06dAjR0dF2yRNE5D3eXAV2+rghAGDO4tMACArU4LM9Xe+grh4aj/2VdXbBQM67JUdLeLgaHORmBpoWK7QNwLYdRr6XUo4iAer48eMYNWqU3fb9+/dj27ZtWLFihRKHISIXeHsS6fRxQzB93BDzOyNTIdq6hlYUH6ixq9Ig9W7J8rzDQgKgC9AI1g008WRwKCmtEU0CscX3UspQJEBFRkbi7NmzGDBggNX2UaNG4eWXX1biEETkhu6YRCr3nZFYu4Lth9HWbjB/dr6lEwEa8SXiTTwRHExBVCwJxBZr8ClDkVp8v/nNbwRLGhkMBnR22q94SUT+T25yhlSxWtvA1WkEggMDsPLx67xaoNWZSuqswaccRQLUww8/jAMHDuD+++9HWVlX2fnz58/j7bffRnJyshKHICIfIzeAOBtQLNPSvVWgVapXpgvQmBd8dKXQLIlTrFjsxo0b8de//hU5OTnQ6XTo7OxEZGQk3n77bSUOQUQ+Rm5yhli7QJ0G51vsR2DkpqUDyk3sFUvQ0GqAu8enMCB5iMZoNMocVZXn559/NheLHT58OHr37q3k7v1GXV2TeblspcXERMgu4EkX8L45z9E9kxsghNoBsAtcAZ
quhRSb9B0OA47QCry21c8dsTwvW87uyxJ/1y7QajWIigoX/MzlAPXggw/ipZdeQmhoKH788Udcdtllbp1kT8MApT68b87z9D2zDBBhIQFobTdYZfFJBQnT5FkhcnpTUkvMWxbBdfY6oiKDMTM7DWmX8I93QDpAuTzEFxMTg46Orkyam266CaGhoUhOTjbX4jPV4xNaCp6IyFn61k67LDqptHJ350xJJUYYjEDxgRoM7t/bYQ9KKI1+yaZ9mPG7ZA4NOuBygHrqqafM/11UVISKigpzHb6VK1eiqqoKGo0GAwcOxEcffaTIyRJRz2H7YBcbcJDqJUkFKUdzphylq8udcyUU6FrbOzmZVwbFkiTi4uKQkZFh3qbX61FRUYFDhw4pcQgi6mHkpnaLZQGKlSayJBWEHAU4R9931IaTeR1zOUCNGzcOKSkpSE1NNQ/rWS6tERoaipEjR2LkyJGKnCgR9SxyHuBSaeW2WX5CpFLc5QQ4OSny3qyF6G9cDlAzZ85ERUUFduzYgbfeegstLS2IioqyegeVkpLC5AkicklYSIBgmrmpYrqcRAdTBQ2xjD6pOVOOApzcOVdCgS44MICTeWVQJM18//79ePjhh3H11VdDp9OhrKwM+/btg0ajQUhICPbs2aPEufoVZvGpD++b8zx1z0pKa+yWhwe60sxnZae69O7G3TlR7nyfWXziPJLFZ+nJJ5/E008/bfUO6ttvv8Vjjz2GnJwcJQ5BRH5A7kN+c1GlYFHY0BCdXXu5+3S3HqE737f9Lv8YkkeRUkfHjx/HoEGDrLZdccUVeOqpp1BcXKzEIYjIx5mG2UzDZaZU75LSGru2UvX5XN0n+R5FAtTw4cPx/vvv221PTExEeXm5EocgIh8nVd3cllQCgWXwcWaf5HsUGeL785//jOnTp6O6uhozZ85EcnIy2trasHz5cvTt21eJQxCRj3Mm3TonIxHLt5YJtrecP6RUCrdSNfuUptbz8hZFelApKSl4//33cfbsWdx6660YMWIERo0ahQ8++ACPPvqoEocgIh/nzPIYUg9hy+CjxJIbah0mVOt5eZMiPSgAuPTSS7FixQrU1NSgrKwMGo0G6enpVnOjiKjncnbpeTnzh1xdzl6qCCygjmXb5S746M8UCVB1dXVYuHAhSkpKEBgYiC1btiA+vmfcQCKSx9ml5+UEH1eWs5cqAmvJG5UepIbwWIFCoQC1aNEinDt3Dq+//jr+8Ic/oL29HQCwePFixMfH495771XiMETk45xJ1ZYbfJxN/5ZbQkmrkf5ciXlVtkVkLQvYsgKFQgGqpKQE+fn5GDJkCLTaC6+1rr/+erz88ssMUETkEnfnLgmR2wORmkfvKLg4suu7KqwoLJOszu7q8KU/USRJQqvVCi6rcckll6CqqkqJQxARKUJuD0SqnTvp7SWlNViyaZ/D6uxj0+Jx101DzOfRE5eTV6QHlZmZiQ8++ADz5s2z2t7U1ISAgAAlDkFEpAg5RWBN7cS4835oc1ElWtvtawyaWAZGT/QgfYkiAWr+/Pl2JY30ej3efPNNpKamKnEIIiJFjE2Lxw8nzqFob7XkMJ5UYHDn/ZBUEOtpQ3iOKDLEFxcXh/Xr1+PgwYPQ6/XIycnB6NGj8f3332PBggVKHIKISBElpTUoPlAjGZy0GkjON8rJSESQzvrxKTe4iAUxrQY9bgjPEbd7UJ2dnfj3v/+N66+/HitWrEB1dTUqKiqg0+kwYsQIREZGKnGeRESKkJPFZzBCMunBlfR2k5yMRLy77ZDVMF+QTsvgJMDtABUQEIBFixZh9OjRuOiii5CQkICEhAQlzo2ISHFys/gcTYp19f3Q2LR4REaEYHVhaY8tYSSXIu+gRowYgR9//BEDBgxQYndERB4jZyl3E09Nis0cNYDrQcmgyDuoKVOm4LXXXmNKORGpntD7IzE9aVKsGimWxQcAEyZMwLXXXosrr7wSqampSElJQWhoqBKHICJShND7o2GJUS
g+UCN7UqyzVSS4oq5rFFny/dSpU6ioqEB5ebn5/1dVVUGj0WDgwIH46KOPlDhXv8Il39WH9815/nTP5Aad/E8q8NmeaqttUkkOQrX/ggMDMON3yXzvBC8s+R4XF4e4uDirJd/1ej0qKipw6NAhJQ5BRORRcpIeSkpr7IITIJ1QIZQ12Nre2aOqkrtKseU22tracPToUQDAZZddhtDQUIwcORIjR45U6hBE5IfUuiif0HlJlTJytrpET6pK7ipFAlRxcTEeffRRnDt3DkajESEhIbj55psxf/58REVFKXEIIvJD7hZd9fZ5OZo/tWBpsV2AZVVy1ymSxbd48WKMHj0an376Kb744gu88sorOH78OG655RZm9hGRKHeKrnqS2Hk5IrTqrVDWYHBgAEsayaBIgDp58iQeeeQRDBgwAH379sUNN9yA9957D2PGjMFzzz2nxCGIyA+pdfjLnePbBlihquQP3D5cFcOYaqfIEF9SUhJOnTqFgQMHWm3/wx/+gFtvvVWJQxCRH1Lr8JfYeWk0gJy8Z9vv2iZg+FP2oycp0oPKzs7GM888g59++slq+9mzZ7u9Ft+mTZuQnJyM5ORknD9/3u7z1tZWvPrqq8jKysLQoUMxbtw4rFy5EgaDfXfeYDBg5cqVGDduHNLT05GVlYVXX30Vra182UnkCleKrpaU1mDB0mLMenEnFiwtlizq6ur3hM5LF6ABZM4M6e4A6y8U6UG9+OKLAIDx48cjKysLKSkpMBgM+PDDD/GnP/1JiUO45OzZs3jllVfQq1cvNDc3C7Z5+OGHsWvXLkydOhWpqan48ssv8be//Q0nT57EE088YdX2+eefR35+PrKzs3HPPfegrKwMy5Ytw5EjR7B06VJvXBKRX3G26KqzSRWWmXiWHH1P6Lxa2jpwvlN8HScTsQBreS4xfUIx+ZrLumWYT61Zk0IUW/LdNEG3vLwcH3/8MX788UcYjUa8+eab+O9//2vuxWRlZSlxSFleeOEFxMXFITk5Gf/5z3/sPi8qKsLOnTsxb9483H///QCA22+/HcHBwcjPz0dubi4uv/xyAMCRI0fw3nvvITc3F88884x5H9HR0Xj99ddRVFRkNQ+MiORxpuiqVFKF7T6EJsjK+Z7Yec16cafoeZmGBMUe+Lbncrpe3y3ZimrNmhTjMEC99dZbuPHGG5GYKN7l7tOnD8aOHYuxY8eat7W1teHw4cPm6hKff/45/vnPf+Lbb79V5swdKCkpwdatW7F27Vps2rRJsM3WrVuh0+lw5513Wm2fOXMm/vWvf+HDDz/Eww8/DAAoLCyE0WjEzJkzrdpOnz4dS5cuRWFhIQMUkYc5k1QhZ1kNZ5IhpN6XvTznasnvigXWFYVlALwXHJwJ8GrgMED94x//QGdnJx544AHztpaWFoSEhEh+LygoCOnp6UhPT3f/LJ3U1taGhQsX4pZbbsGoUaNEA9SBAwcwePBghIdbl9lISkpCeHg4Dh48aN528OBBRERE2AXqyMhIDBo0yKotEXmGM0kVcoKPM++KhJaKl7tIodi5OFp3SmlqzZoU41KSxD//+U+r3pKl2tpawWQEb3r77bdRX1/vcDXf2tpaxMXFCX4WFxeHU6dOyWobHx9v1ZaIPMOZpApHwcfZ5dWF0sXlLjIodS7enPcldh5qTepw+R3UuXPnBLdv3LgRb731FkpLS10+KQBoaGjAmjVrZLXt1asX8vLyAABHjx7FsmXL8Je//AV9+/aV/F5LSwuCgoIEPwsODkZTU5P5Z71ej4iICNG2LS0tss7VRKw4olJiYoTPlaTxvjnPm/dsYmYEIiNC8O7H5ThTr0d0n1DMuCkFmaPs16KbmZ2GJZv2Wa1caxIj8T1Hx5+YebnT5y11LgBwtqHVK/dR6DyCAwMwMztNlb/7itXisySUou2shoYGLFmyRFbb6Ohoc4BauHAhkpKSMHXqVIffCwkJQVtbm+Bnra2tVsOYoaGhstvKwWrm6sP75r
zuuGdpl/TG3+6zHsEROoe0S3pjxu+SJTPW5Jy7WNabM9lwpnNZUVgGoX/2fSODvXIfxe5J2iW9u+133+PVzD2hf//+TldC//TTT/HVV1/h73//u1WJJdOQY1VVFSIiInDxxRcDAGJjY0WH5k6dOmVV6DY2Nhb79u0TbFtTUyM6/EdE3cfVZdlNxLLefjhxzmr9KDnZcKbtrr7HUoq798SbVBugXFFd3VUG/5FHHhH8fNKkSbj44ouxc2dXuujQoUOxbds2NDU1WSVKHDlyBE1NTUhLSzNvS09Px+7du1FZWWmVKNHQ0ICjR4/ipptu8sQlEVE3Est6K9pbbdcTkpMNZ/rsg90/4nS9XvXzkLqbrACl0Wg8fR6KyMrKQny8/f/Qa9euxddff42XXnoJvXtfWMUyOzvbnIp+3333mbevWrUKGo0GN998s3nb+PHj8c4772DNmjVYvHixeXt+fj46OjowYcIED10VEXUXqew7Z9pbGpsWj4mZl3M4WQZZAeqtt97Cjh07MHToUAwdOtTcU1GbgQMH2tUDBIBdu3YBAG644QaEhYWZt2dmZiIrKwuvvfYaTp06hZSUFHz55ZcoLCzEnXfeiaSkJHPb5ORkTJs2DWvXrkVzczPGjBmD8vJyFBQUICsri3OgiPyQWFq7ViMcpNSaDeerHAaoq666CqWlpSgrK0NZWRk2btxo/uz3v/89UlJSkJycjCFDhiA5OdmjJ+sJr732Gt58801s3boVGzduREJCAhYsWIC7777bru0TTzyBhIQEbNy4Edu2bUN0dDRmz56NuXPndsOZE5Gnic19unpovNU7KNN2LqGhLI3RKKc2b1eCwcGDB83/V1ZWhsbGri6qaQgwICAAoaGhaGpqQnl5uefO2g8wi099eN+c1xPumRJZfLZ6wn2TSyqLT3aAEnLs2DG7oNXc3AyNRsMA5QADlPrwvjlPDfdMKlA4CiKeLpwqtn813De18Fia+aWXXopLL70U2dnZAACj0YijR4/iwIED7uyWiEgWqeKnACQLo3q6cKrU/idmqm9SrBopsh6UiUajQWJiIiZPnqzkbomIBEkVP3W0nLynl5tX63L2vsSv5kERUc/iSvFT02eeLpzqa4VZ1YgBioh8lqPq5lKfeWK5ect3TkxFd5+iQ3xERN4kVd3cUeVzV5abl2J652QKekLBianozmEPioh8lpwl48U+c3a5eUfkLJAYqPONqjxq4VaaObmOaebqw/vmPN6zC6SWhLcUpNPiwdwRSLukt+PGPYBUmjmH+IiIFCD33VJbhwHvfsx5onIwQBERKcCZd0tn6vUePBP/wQBFRKSAsWnxCAsJkNU2uk+oh8/GPzBAEREpZNpvk+0yA20F6bSYcVOKl87ItzFAEREpZGxaPO66aYjVXKuskQlWP9910xBkjhrQnafpM5hmTkR+z9NFYS350pLqascARUR+Lf+TCny258Iiq0oXhTXxZhDsKTjER0R+q6S0xio4mShdtNW2ioQpCJaU1ih2jJ6IPSgi8ltSQUjJoq1SlcvV1ovypZ4eAxQR+S2pIKRk0VZfqVzu6TWwlMYhPiLyW1JBSMmirWLHUVvlcl9bo4oBioj8llDFcgDIGpmgaI9B6cronuIrPT0TDvERkd9SumJ5dx/HXZ5YA8uTGKCIyK95a16SL8x/yslItHoHBaizp2fCAEVE1EP4Sk/PhAGKiKgH8YWengmTJIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJUYoIiISJW4oi4RkZeUlNZgc1Elzja0oq/Kl1tXAwYoIiIvKCmtwZqPK9DWYQAA1DW0Ys3HFQDAICWCQ3xERF6wuajSHJxM2joM2FxU2U1npH4MUEREXlDX0OrUdmKAIiLyiqjIYKe2EwMUEZFX5GQkIk
hn/cgN0mmRk5HYTWekfkySICLyAlMiBLP45GOAIiLykrFp8RibFo+YmAicPt3Y3aejehziIyIiVWIPiojIB5km/dY1tCLKT4cLGaCIiHxMT5n0yyE+IiIf01Mm/TJAERH5mJ4y6ZcBiojIx/SUSb8MUEREPqanTPplkgQRkY+xnPTLLD4iIlIV06RfR3w5HZ0BiojIT/l6OjrfQRER+SlfT0dngCIi8lO+no7OAEVE5Kd8PR2dAYqIyE/5ejo6kySIiPyUr6ej+32A2rRpE5588kkAwPfff4+wsDDzZ1999RVmzJgh+L0pU6Zg8eLFVtsMBgNWr16NDRs24Oeff0ZMTAwmTpyIOXPmIDjYN7rMRNSzyE1HVyO/DlBnz57FK6+8gl69eqG5uVm03ZQpUzBq1CirbZdddpldu+effx75+fnIzs7GPffcg7KyMixbtgxHjhzB0qVLFT9/IqKezK8D1AsvvIC4uDgkJyfjP//5j2i7ESNGYNKkSZL7OnLkCN577z3k5ubimWeeMW+Pjo7G66+/jqKiImRkZCh27kREPZ3fJkmUlJRg69atePrppxEQEOCwfXNzM9ra2kQ/LywshNFoxMyZM622T58+HTqdDoWFhe6eMhERWfDLANXW1oaFCxfilltusRu6E/Lss89i5MiRGDp0KMaPH49NmzbZtTl48CAiIiKQmGid/RIZGYlBgwbh4MGDip0/EfVcJaU1WLC0GLNe3IkFS4tRUlrT3afUbfxyiO/tt99GfX09FixYINlOp9PhuuuuQ0ZGBmJjY3Hy5EmsX78eTz75JE6cOIF58+aZ29bW1iIuLk5wP/Hx8fjuu+8UvQYi6nl8vTSR0lQboBoaGrBmzRpZbXv16oW8vDwAwNGjR7Fs2TL85S9/Qd++fSW/N2rUKLseVm5uLqZOnYrly5fjtttuw4ABAwAAer0eERERgvsJDg5GS0uLrHM1iYoKd6q9s2JihM+VpPG+OY/3zDVC9+2D3SWCpYk+2P0jJmZe7q1TUw1VB6glS5bIahsdHW0OUAsXLkRSUhKmTp3q0nEDAwORl5eHefPm4YsvvsCUKVMAAKGhoaLvqFpbWxESEuLUcerqmmAwGF06R0diYiJw+nSjR/btz3jfnMd75hqx+3a6Xi/Y/nS93m/vs1arEf2DXbUBqn///jh06JBT3/n000/x1Vdf4e9//zuqqqrM28+fPw8AqKqqQkREBC6++GKHxwaA+vp687bY2Fjs27dPsH1NTY3o8B8RkVxRkcGCdfJ8pTSR0lQboFxRXV0NAHjkkUcEP580aRIuvvhi7Ny5U3I/x48fBwBERUWZt6Wnp2P37t2orKy0SpRoaGjA0aNHcdNNN7l7+kTUw+VkJFq9gwJ8qzSR0vwqQGVlZSE+3v5F4tq1a/H111/jpZdeQu/evc3b6+vr0adPH6u2zc3NeOeddxAYGIhrrrnGvH38+PF45513sGbNGqsKE/n5+ejo6MCECRM8cEVE1JP4emkipflVgBo4cCAGDhxot33Xrl0AgBtuuMGq1NHs2bMRFxeH1NRUcxbfli1bUF1djQULFqBfv37mtsnJyZg2bRrWrl2L5uZmjBkzBuXl5SgoKEBWVhYn6RKRIny5NJHS/CpAOevGG2/Ejh07kJ+fj8bGRoSFhSE9PR1PP/00MjMz7do/8cQTSEhIwMaNG7Ft2zZER0dj9uzZmDt3rvdPnojIz2mMRqNnUslIErP41If3zXm8Z67hfbtAKovPLytJEBGR72OAIiIiVWKAIiIiVWKAIiIiVerRWXxEROSektIaj83bYoAiIiKXeLr6Oof4iIjIJZuLKgWrr28uqlRk/wxQRETkEqHCtlLbncUARURELhGrsq5U9XUGKCIicklORiKCdNZhRMnq60ySICIil3i6+joDFBERucyT1dc5xEdERKrEAEVERKrEIT4i8huerGpA3scARUR+wdNVDcj7OMRHRH7B01UNyPsYoIjIL3i6qgF5Hw
MUEfkFT1c1IO9jgCIiv+DpqgbkfUySICK/4OmqBuR9DFBE5Dc8WdWAvI9DfEREpEoMUEREYMqQ5AAAECxJREFUpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEoMUEREpEq67j4BIiLqfiWlNdhcVIm6hlZERQYjJyMRY9Piu/WcGKCIiCyo8UHtaSWlNVjzcQXaOgwAgLqGVqz5uAIAuvXaOcRHRPQ/pgd1XUMrgAsP6pLSmm4+M8/aXFRpDk4mbR0GbC6q7KYz6sIARUT0P2p9UHuaKSDL3e4tDFBERP+j1ge1p0VFBju13VsYoIiI/ketD2pPy8lIRJDOOhwE6bTIyUjspjPqwgBFRPQ/an1Qe9rYtHjcddMQcyCOigzGXTcN6fbkEGbxERH9j+mB3NOy+ICua1fbdTJAERFZUOODuqfiEB8REakSAxQREakSAxQREakSAxQREakSAxQREakSAxQREakSAxQREakSAxQREakSAxQREakSAxQREakSSx11E61W49P791e8b87jPXMN71sXqfugMRqNRi+eCxERkSwc4iMiIlVigCIiIlVigCIiIlVigCIiIlVigCIiIlVigCIiIlVigCIiIlVigCIiIlVigCIiIlVigCIiIlViLT4ftGvXLhQUFODQoUM4e/YsQkNDMXDgQNxxxx2YNGkStFrrvztOnjyJt956C1988QVqa2vRp08fDB8+HPfddx/S0tLs9r9582asXr0aP/74Iy666CL89re/xbx58xAZGemtS/QIZ+7b448/ji1btoju6/bbb8ezzz5rtc0f75uzv2sAUFVVhaVLl2L37t2or69Hnz59MGzYMCxatAjR0dFWbf3xngHO3bevvvoKM2bMENzPlClTsHjxYqttBoMBq1evxoYNG/Dzzz8jJiYGEydOxJw5cxAcHOzR6/I2BigfdPjwYQQGBmLKlCmIjo6GXq9HUVERHn/8cezduxeLFi0yt62rq8Ott96Kjo4OTJ06FQMGDMDJkyexfv16fPbZZ1i/fr1VkFq9ejVeeOEF/OY3v8Gdd96Jn376CWvWrMGBAwewbt06BAUFdcclK8KZ+zZlyhSMHTvWbh9btmxBSUkJMjMzrbb7631z5p4BwL59+zBr1iz069cPv//97xEdHY2zZ89iz549aGpqsgpQ/nrPAOfvG9D1Ozdq1CirbZdddpldu+effx75+fnIzs7GPffcg7KyMixbtgxHjhzB0qVLPXZN3cJIfuPee+81DhkyxHju3DnztlWrVhmTkpKM27dvt2r77bffGpOSkozPPvuseVtdXZ1x+PDhxlmzZhkNBoN5+5YtW4xJSUnGtWvXev4iuoHQfRPS2dlpvPbaa42//vWvjW1tbebtPfG+Cd0zvV5vzMrKMubl5VndHyE98Z4ZjcL37csvvzQmJSUZ//Wvfzn8/uHDh43JycnGJ5980mr7kiVLjElJScZdu3Ypfs7die+g/Ei/fv1gMBjQ1NRk3tbY2AgAiImJsWobGxsLAAgNDTVv27FjB/R6PWbMmAGN5kIJ/AkTJiAqKgqFhYWePP1uI3TfhHzxxReoqanBhAkTEBgYaN7eE++b0D378MMP8fPPP2PBggUIDAyEXq9He3u74Pd74j0DHP+uNTc3o62tTfT7hYWFMBqNmDlzptX26dOnQ6fT+d19Y4DyYU1NTTh79iyOHz+OdevWYfPmzRg8eDASEhLMba666ioAwDPPPINvv/0Wp06dwp49e/CXv/wF0dHRmDJlirntgQMHAAAjR460Ok5AQACGDRuGsrIyGP1gdRY5903I5s2bAQA5OTlW23vCfZNzz/7f//t/CA8PR0NDAyZNmoQRI0Zg2LBhmDZtGvbv32+1v55wzwDnfteeffZZjBw5EkOHDsX48eOxad
MmuzYHDx5EREQEEhMTrbZHRkZi0KBBOHjwoMeupTvwHZQPe+ihh7B7924AgEajwVVXXYVFixZZ/UU6atQoLFy4EK+99hp+//vfm7cPGTIEGzduxMUXX2zeVltbi9DQUMEX1PHx8dDr9fjll1/Qu3dvD16V58m5b7YaGxuxfft2pKWlYciQIVaf9YT7JueeHTt2DJ2dnZg9ezZ+97vfYc6cOfj555/x1ltvYcaMGdi0aRMuv/xyAD3jngHy7ptOp8N1112HjIwMxMbGmt8RP/nkkzhx4gTmzZtnbltbW4u4uDjBY8XHx+O7777z7AV5GQNUN2poaMCaNWtkte3Vqxfy8vKstj366KPIy8tDbW0tdu7cifr6ejQ3N9t9Nzo6GpdffjmuuuoqJCcno6qqCv/85z+Rl5eHd9991zzcp9frRV9Mm7KDWlpanLlEj/DWfbP04YcforW11a73BPjGffPGPTt//jz0ej0mTJiAF1980bw9LS0NM2bMwJtvvonXXnsNgG/cM8A7923UqFF2yRG5ubmYOnUqli9fjttuuw0DBgwA0HXfIiIiBI8fHBysinumJAaobtTQ0IAlS5bIahsdHW33y5+SkmL+78mTJ+Opp57C9OnTsW3bNvTt2xcA8Omnn+LBBx/EihUrcM0115jbX3311Zg8eTL+8Y9/4LnnngPQ9T5KbPy7tbUVABASEiL/Aj3EG/fN1pYtWxAYGIjs7Gy7z3zhvnnjnpmu0TaIjxkzBgkJCfj666/N23zhngHd87sGAIGBgcjLy8O8efPwxRdfmIfiHd03NdwzJTFAdaP+/fvj0KFDiu0vOzsbGzZswPbt25GbmwsAePfddxEWFmYVnADg8ssvx6BBg/DNN9+Yt8XGxkKv16OhocFu6KWmpgahoaG46KKLFDtfV3njvlk6evQo9u7di3HjxgkOOfnCffPGPYuNjcXhw4ft5joBXUk6ZWVl5p994Z4B3v9dsz02ANTX15u3xcbGYt++fYLta2pqRIf/fBWTJPyIqXv/yy+/mLedPn0aRqNR8IVzR0cHOjo6zD8PHToUALBnzx6rdgaDAQcOHEBKSorkexpfJXTfLJmSI2699VbBz3vifRO6Z8OGDQPQ9aC0VVNTY9Vj6In3DHD8u2bp+PHjAICoqCjztvT0dDQ2NqKystKqbUNDA44ePSo48d6XMUD5oDNnzthtMxqNKCgoAAAMHz7cvH3w4MFobm7GJ598YtV+3759OHbsGNLT083brr/+eoSEhODdd9+1avuf//wHZ86cERze8iXO3DeTzs5O/Pvf/0ZMTIxdL9TEn++bM/csOzsbWq0W69evt2q/c+dOnDp1Ctdee615mz/fM8C5+2bZQzJpbm7GO++8g8DAQKvfu/Hjx0Oj0di9F8vPz0dHRwcmTJig1CWoAof4fFB2djZGjx6N1NRUREdH48yZM9i2bRsqKiqQnZ2NK6+80tz2vvvuw+eff45HH30U33zzDZKSklBVVYV169YhODgYc+bMMbft27cv/vjHP+Kll17C7NmzceONN+Knn37C6tWrkZaWhttvv707Llcxztw3k+LiYtTW1mL27NkICAgQ3K8/3zdn7lliYiLuvvturFixArNnz0ZmZiaqq6vx3nvvITo6Gg888IC5rT/fM8C5+zZ79mzExcUhNTXVnMW3ZcsWVFdXY8GCBejXr5+5bXJyMqZNm4a1a9eiubkZY8aMQXl5OQoKCpCVlYWMjIzuuFyP0Rj9YbJBD7NkyRLs3r0bx44dQ2NjI3r16oXk5GRMnjwZOTk5dvXRfvjhByxduhT79u3DqVOnEB4ejtGjR2Pu3Ll2KdMA8P7772PNmjU4duwYIiMjccMNN2D+/PmqeCfgDmfvGwDMmzcPH330ET766CO7uSe2/PG+OXvPjEYj1q1bh4KCAhw7dsz8/nP+/PlWUxpM/PGeAc7dt2XLlmHHjh04fvw4GhsbERYWhvT0dMyYMcOupBbQ1atftWoVNm
7ciOrqakRHR2PixImYO3eu39XiY4AiIiJV4jsoIiJSJQYoIiJSJQYoIiJSJQYoIiJSJQYoIiJSJQYoIiJSJQYoIiJSJQYoIiJSJQYoIiJSJdbiI/IRs2bNQnFxsWSbP/7xj5g7d66XzojIsxigiHxEaWkpdDod7r//ftE248aN8+IZEXkWAxSRD6iqqsK5c+eQmpqKBx98sLtPh8gr+A6KyAccOHAAwIWF/oh6AgYoIh9w8OBBAAxQ1LNwiI/IB5gC1Pfffy+4pDoA5OXloVevXt48LSKP4npQRCpnNBoxevRoNDY2irbp3bs3vvrqK/PPr7zyCkpLS7Fq1SrZx3n11Vfx/fffIz8/3+VzdeW4RGLYgyJSOdOqrL/61a9QUFAg6zvl5eWCqyU7+k5KSoorp+jWcYnE8B0UkcqZhvecCR4VFRVITU116jjl5eVOf0eJ4xKJYYAiUjlTgJLbMzl9+jTOnDljFdCWLl2KCRMmYOTIkfj1r3+Nxx9/HC0tLebP6+rqUFtbC61Wi7vuugvDhw/HpEmTsH//fnObU6dO4bHHHsOYMWNwxRVX4MEHH8SZM2ckj0vkDgYoIpVzNkCVl5cjJCQEl112mXlbZ2cnFi5ciMLCQvzf//0fiouLsWbNGqvvAMCqVaswd+5cbNmyBfHx8XjooYfQ0dGBqqoq3HLLLYiLi8O6deuQn5+P+vp6PP3005LHJXIH30ERqZjBYEBZWRkCAgKQlJQk6zsVFRVISkpCQECAeZvl5N6LL74YmZmZOHr0qHlbeXk5AgMD8cYbb6B///4AgAULFuDmm2/GTz/9hOeeew633XYb5s+fb/7OnDlz8MADD0gel8gdDFBEKnb06FE0NzcjPDwcy5cvF22Xm5uLuLg4APbJDidPnsSKFSvw1Vdf4dSpU2hvb0dbWxvuuecec5vy8nL89re/NQcnAIiMjAQAnDlzBrt378a3335rleHX2dmJ0NBQq31weI+UxABFpGKm4b2mpiYsWbJEsI1Wq8Xdd99t/rm8vBwzZswAANTX1+O2227DFVdcgcceewzx8fHQarW47bbbrIYMy8vLkZOTY7Xfffv2oVevXmhsbER4eDg2b95sd+zAwEDB4xIpgQGKSMUmT56MyZMny26v1+tx/PhxcyZdUVERWltb8dprr0Gj0QAAtmzZgubmZnNvR6/X49ixYzAYDOb9GI1GrFy5EhMnToROp4Ner0d0dDTCwsJkHZdICQxQRH7k0KFDAIDk5GQAXRN4m5ubsX37diQlJeHzzz/HO++8g7CwMAwcOND8Ha1Wiw8++ABjxoxBnz598MYbb+DkyZN48803ERAQgMjISDz22GOYO3cuwsPDUVVVhe3bt+Ovf/0rtFqt3XGJlMAAReRHysvLMXDgQPO7oYyMDEyZMgV/+tOfEBwcjJtvvhkTJkzA3r17zT2q8vJyDBgwAPPmzcPDDz+M+vp6XHvttdi4cSP69u0LAFi+fDlefvllzJgxA52dnRgwYABuvvlmaLVaweMSKYGljoiISJU4D4qIiFSJAYqIiFSJAYqIiFSJAYqIiFSJAYqIiFSJAYqIiFSJAYqIiFSJAYqIiFTp/wNz+LJ0QH6U9AAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x432 with 1 Axes>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q55oI3EAAeyN"
},
"source": [
"Despite the fact that the neural network is untrained we see that the outputs of the graph network correlate strongly with the labels. This hints that perhaps graph networks provide some sort of \"deep molecular prior\". \n",
"\n",
"Next, we define losses for the energy and the force as well as a total loss that combines the two terms. We fit both the force and the energy using Mean-Squared-Error (MSE) loss."
]
},
{
"cell_type": "code",
"metadata": {
"id": "pwxERR68AeoB"
},
"source": [
"@jit\n",
"def energy_loss(params, R, energy_targets):\n",
" return np.mean((vectorized_energy_fn(params, R) - energy_targets) ** 2)\n",
"\n",
"@jit\n",
"def force_loss(params, R, force_targets):\n",
" dforces = vectorized_force_fn(params, R) - force_targets\n",
" return np.mean(np.sum(dforces ** 2, axis=(1, 2)))\n",
"\n",
"@jit\n",
"def loss(params, R, targets):\n",
" return energy_loss(params, R, targets[0]) + force_loss(params, R, targets[1])"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "seSkwLZqBIrK"
},
"source": [
"Now we create an optimizer using ADAM with gradient clipping. We will also write helper functions to perform a single update step and perform an entire epochs worth of updates."
]
},
{
"cell_type": "code",
"metadata": {
"id": "xhgLDCFSAcdI"
},
"source": [
"opt = optax.chain(optax.clip_by_global_norm(1.0),\n",
" optax.adam(1e-3))\n",
"\n",
"@jit\n",
"def update_step(params, opt_state, R, labels):\n",
" updates, opt_state = opt.update(grad(loss)(params, R, labels),\n",
" opt_state)\n",
" return optax.apply_updates(params, updates), opt_state\n",
"\n",
"@jit\n",
"def update_epoch(params_and_opt_state, batches):\n",
" def inner_update(params_and_opt_state, batch):\n",
" params, opt_state = params_and_opt_state\n",
" b_xs, b_labels = batch\n",
"\n",
" return update_step(params, opt_state, b_xs, b_labels), 0\n",
" return lax.scan(inner_update, params_and_opt_state, batches)[0]"
],
"execution_count": 17,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "BHkWBfh7CHGt"
},
"source": [
"Finally, we will write a function that creates an epoch's worth of batches given a lookup table that shuffles all of the states in the training set."
]
},
{
"cell_type": "code",
"metadata": {
"id": "We1uelXUBaXj"
},
"source": [
"dataset_size = positions.shape[0]\n",
"batch_size = 128\n",
"\n",
"lookup = onp.arange(dataset_size)\n",
"onp.random.shuffle(lookup)\n",
"\n",
"@jit\n",
"def make_batches(lookup):\n",
" batch_Rs = []\n",
" batch_Es = []\n",
" batch_Fs = []\n",
"\n",
" for i in range(0, len(lookup), batch_size):\n",
" if i + batch_size > len(lookup):\n",
" break\n",
"\n",
" idx = lookup[i:i + batch_size]\n",
"\n",
" batch_Rs += [positions[idx]]\n",
" batch_Es += [energies[idx]]\n",
" batch_Fs += [forces[idx]]\n",
"\n",
" return np.stack(batch_Rs), np.stack(batch_Es), np.stack(batch_Fs)\n",
"\n",
"batch_Rs, batch_Es, batch_Fs = make_batches(lookup)"
],
"execution_count": 18,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "MZCqMAUvCVS6"
},
"source": [
"We're now ready to train our network. We'll start by training for twenty epochs to make sure it starts training."
]
},
{
"cell_type": "code",
"metadata": {
"id": "DrmKVE6lCQz_",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 460
},
"outputId": "5249d254-58a9-4588-ab06-80a5a5e7db91"
},
"source": [
"train_epochs = 20\n",
"\n",
"opt_state = opt.init(params)\n",
"\n",
"train_energy_error = []\n",
"test_energy_error = []\n",
"\n",
"for iteration in range(train_epochs):\n",
" train_energy_error += [float(np.sqrt(energy_loss(params, batch_Rs[0], batch_Es[0])))]\n",
" test_energy_error += [float(np.sqrt(energy_loss(params, test_positions, test_energies)))]\n",
" \n",
" draw_training(params)\n",
"\n",
" params, opt_state = update_epoch((params, opt_state), \n",
" (batch_Rs, (batch_Es, batch_Fs)))\n",
"\n",
" onp.random.shuffle(lookup)\n",
" batch_Rs, batch_Es, batch_Fs = make_batches(lookup)"
],
"execution_count": 19,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 0 Axes>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1cAAAGoCAYAAACqmR8VAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeXxc5Xk3/N99zuybRqMZSZb3VbYxMQ4xNCRgDCEEIwhP2phAcXBeG5yyNhRSUvclISQ0KQ9vCBDnUz8pxRjix27qJuA4oqHGSg1JCGYJwvuCwcZaRpp9zqznvH+MJCwkWSNpZs5I8/v+ZY7uc+7ryFj2Nfd9X5fQNE0DERERERERjYmkdwBEREREREQTAZMrIiIiIiKiAmByRUREREREVABMroiIiIiIiAqAyRUREREREVEBMLkiIiIiIiIqAIPeAZS7QCAGVWW1+rGqqXGgqyuqdxjjHr+PhcPv5dhJkkB1tV3vMGiERvv3WqX9meH7Tmx834lttO9biL/XmFwNQ1U1JlcFwu9jYfD7WDj8XlIlGsvfa5X2Z4bvO7HxfSc2vd6X2wKJiIiIiIgKgMkVERERERFRATC5IiIiIiIiKgAmV0RERERERAXA5IqIiIiIiKgAmFwREREREREVAJMrIiIiIiKiAmByRUREREREVABsIkxERFQAu3fvxpYtW3Dw4EF0d3fDarVi+vTpuOGGG/DFL34RktT/88xkMokNGzbg+eefh9/vR0NDA66//nqsXr16wFhVVfH0009j69atOHXqFHw+H6699lrcdtttMJvNpXxNIiI6CyZXREREBXDo0CEYjUZcf/318Hq9UBQFLS0tuP/++/HWW2/hwQcf7Df+b//2b7F792585StfwcKFC/GHP/wBP/zhD3H69GmsX7++39iHH34YmzdvRlNTE9auXYt9+/Zh48aNOHz4MDZs2FDK1yQiorNgckVERFQAt95664BrX/3qV7Fu3Tps27YN99xzD6qqqgAALS0t2LVrF77xjW/g61//OgDgy1/+MsxmMzZv3oyVK1di7ty5AIDDhw/j2WefxcqVK/HQQw/1Pdvr9eLxxx9HS0sLli1bVoI3JCKi4fDMFRERURFNmjQJqqoiGo32XXvhhRdgMBhw00039Ru7evVqaJqGX//6133XduzYAU3TsHr16n5jV61aBYPBgB07dhQ1fiIiyh9XrsqYqmoAAEkSOkdCRET5ikajSKVSiEQieOWVV7B9+3bMmTMHDQ0NfWPeeecdzJkzBw6Ho9+98+bNg8PhQGtra9+11tZWOJ1OzJ49u99Yl8uFWbNm9RtLRET6YnJVpjRNw84tWyEySVx98816h0NERHm6++67sWfPHgCAEAIXXXQRHnzwQQjx0QdlHR0dWLp06aD319XVob29vd/Yurq6QcfW19dj7969BYyeiGh8S77xPJTGTwD2GbrMz+SqTB1468/4bPRFJDUj4soNsFlNeodERFQRwuEwNm3alNdYm82GNWvW9Lt27733Ys2aNejo6MCuXbsQCAQQj8f7jUkkEjCZBv+5bjab+20hVBQFTqdzyLGJRCKvWHvV1DiGHzQEn2/wOCYqvu/ExvedeIJ/fAGR17cjpsbgu+pcXWJgclWG0skkzH96GgIabFIK7x/cj/nnLdY7LCKiihAOh/Hkk0/mNdbr9Q5IrhYsWND36+uuuw4PPPAAVq1ahebmZng8HgCAxWJBKpUa9JnJZBIWi6Xvv61Wa95j89HVFe3bdj4SPp8TnZ2REd83XvF9Jza+78STPvQKErufhmHG+ai5cs2o3leSxJg+gAKYXJWlY795Bg0I4PS8v8Kkw79A+HgrwOSKiKgkpkyZgoMHDxbseU1NTdi6dSteeuklrFy5EgBQW1vbb+vfmdrb27FkyZK+/66trcXbb7896Ni2trYhtwwSEVWKzIm3kGj5V8iTF8Jy+dchJFm3WFgtsM
xEjr+D+vZXsM90LuZeejUCogrmriN6h0VERKPUu20vFAr1XTv33HNx9OjRftv/gFzZ9Wg0inPOOafv2qJFixCJRHD06NF+Y8PhMI4dO9ZvLBFRJRJOLwzTFsN6xZ0QslHXWJhclREtpSC+62foUp2Y8oXVEEIg6pqF+swpJBJpvcMjIqKz8Pv9A65pmoYtW7YAABYv/mgHQlNTE9LpNJ577rl+4//t3/4NQghcffXVfddWrFgBIcSAc2CbN29GJpPBNddcU8jXICIaN9R4EJqmQfZMgfXKuyFMVr1D4rbActL98jOwZMLYO3kVVtTXAAAsUxfCFnoT7x/aj3mf+ITOERIR0VCampqwdOlSLFy4EF6vF36/H83NzThw4ACamppwwQUX9I299NJLsXz5cjz22GNob2/HggUL8Ic//AE7duzATTfdhHnz5vWNbWxsxI033ojnnnsO8XgcF154Ifbv348tW7Zg+fLlbCBMRBVJDbUh/vzDMJ7zOZg/ea3e4fRhclUm0sf3wnTi99id+QSWX3Fx3/X6+ech2/ocwsdaASZXRERl66abbsKePXuwadMmRCIR2Gw2NDY24vvf/z6+9KUvDRj/2GOP4Sc/+QleeOEFbNu2DQ0NDbjvvvvwta99bcDY9evXo6GhAdu2bUNzczO8Xi9uueUW3H777aV4NSKisqLGAoj/+hFA02CcNXhbC70ITdNGXjKogoy2qtJIqEoY4f/7LZxWTDj1qbvwuQtm9vv6+xvvRlD24BNrvl3UOIqpEqrUlAK/j4XD7+XYFaKqEpUeqwXmh+87sfF9xy8tEUX8hX+CGu2Crel+yL4ZA8aM9n0L8fcaz1zpTNM0KL/7N2gpBc2GK3Dp+dMHjIm4ZqIucwrJJM9dEREREVFl0jQVyn89DjXUDuuVdw+aWOmNyZXOModfgXriTexQluBzn7sQBnngb4l5ykLYRAonDx3QIUIiIiIiIv0JIcF47udhufxvYGhYMPwNOmBypSM14ofyyrM4nq1D16TP4txZNYOOm7TwPABA6Ng7pQyPiIiIiEh3mqoi23EMAGCc+SkYZ56vc0RDY3KlE01Tkdj9M2QyKp6LfQbXf27ekGPtnlp0owoG/+ESRkhEREREpC9N05B8ZTPiv/oesoFTeoczLCZXOkm/81tkTx/AL6LnY/F58zGpxn7W8RFn7txVOp0pUYRERERERPpKvb4d6f0vw/SJL0Cunqx3OMNicqWDbPcpJP/073hPnol3xHxc+5mZw95jmryA566IiIiIqGKk3nkRqTdfgHH+JTBd8GW9w8kLk6sS07IZJF7eiIxkxs/8n8J1F8+Gw2oc9r7ec1fBozx3RUREREQTW7b9CJK/3wLDjPNh/uxqCCH0DikvbCJcYqk3fgW16wR+qX4ezhovLl3SkNd9Tm8dTvDcFRERERFVAKl2NizL1sAw5y8gpPGzHjR+Ip0Asu1HkHprB9qrz8P/BOvxlcvnQB7B/yxhx0zUpk/y3BURERERTUiZtsPIBj+EEALGxosh5OF3eJUTJlcloqWTUHb/H2jWamz44Bwsnl2DRTMHL70+FOPk+bCJFD48crBIURIRERER6SPrPwHlN/8fkr97Gpqm6R3OqDC5KpHkH7dBC7Xjf+xXIpIx4PrL5474GZMW9J67+nOhwyMiIiIi0o0aaoPym0chTFZYLls3bs5YfRyTqxLInGxFet9/IzHrUmw/aMTl509Bvcc24udU1dajG1WQO3nuioiIiIgmBjUWQPzXjwCaBuvV90JyjGx3VzlhclVkWjKGxO6fQbgbsKltIexWI679zIxRPy9kn4Ha1ElkM9nCBUlEREREpJPU3l9CS8ZgveoeyO78ir2VKyZXRZbYsxmaEsF7s/4K+05G8b8umQWbZfQH8wyTF8AqUjh9lOeuiIiIiGj8M1/017Bdcz9k3/C9X8sdk6siSh/9IzJH/wDDkmvwzJ8SmOKz45LFk8b0zPr5iwEAgcM8d5WvZC
qLjS+8i+5wQu9QiIiIiAi53q/J1/4dWjIGYTBB9s7QO6SCYHJVJFoiisSeZyD5ZmFX6lz4Qwl85fK5Iyq9PhhP/SR0aVWQeO4qbyfeO4XPn/wpjra26h0KERERUcXTVBWJlzci9davkTm1T+9wCorJVZFkPtwPJGPInPeX2PGHk1gy14uFMzwFeXbIPgO+1AfIZnnuKh/xtvdQI0ch/Mf0DoWIiIioommahuQrzyBz7DWYL1wJ46yleodUUEyuiiTbfgSQjdi+T0Mmo2LlZXMK9mxDw3xYRQrtRw8V7JkTWSrYAQDQYgGdIyEiIiKqbKnXtyO9fzdMi1fAtHiF3uEUHJOrIsl2HEW6air2vNOBKz41FXXVIy+9PpTa+bl+V91HeO4qH2q0CwAgJUM6R0JERERUubRkDOlDr8DYeAlMF3xZ73CKwqB3ABORls1A9b+HVnEuHDYjmi6aUdDneyfV44TmgtTBlat8yIncipUpHdY5EiIiIqLKJcx22P7XAxAW57htEjycCblylUql8I//+I+47LLLsGTJElx11VXYvn17yeZXu04A2QzeDLhw3cWzYLMUNocVQiBomwFv8gOoPHc1LHMqt2Jly0Z1joSIiIio8mROvIXEK5uhqSokmxtCkvUOqWgm5MpVJpOBz+fD008/jalTp+Ltt9/GLbfcgkmTJuHTn/500efPth8FALyX8eH/afQVZQ65YT6sR/+MjuOHUT9nflHmmAgyWRVOLQIAcCKGdCYLo2Hi/oEmIiIiKieZ0wehvPQTSJ4pQDYNSGa9QyqqCblyZbPZcPfdd2PatGkQQuC8887DhRdeiL1795Zk/mz7EUSFEw6vD06bqShz1PWeu2K/q7PqDsXhluLIQIZVSiMQ4NZAIiIiolLI+k9AaX4MktML61X3QBgndmIFlEFyFYvF8OSTT+LrX/86PvvZz6KxsRF33XXXkONVVcVTTz2FK6+8EosWLcLy5cvxox/9CMlkcsh7kskk/vznP2Pu3LnFeIUBsu1HcTRVg/lTq4s2h69hEvyaC4Lnrs6qu6MTBqEiYm0AAEQ623WOiIiIiGjiU0NtUH7zKITJCuuKeyFZnHqHVBK6J1eBQABPPPEEWltbsWjRomHHP/zww/jhD3+IRYsW4dvf/jYuvfRSbNy4Ed/4xjcGHa9pGtavX4/p06fjiiuuKHT4A6ixALRYF46lvGic5i7aPEIIBGwzUJN4n+euziLqzyVTwjcbABAP+vUMh4iIiKgiqBE/IBthu/o+SI4avcMpGd3PXNXW1uJ3v/sd6urqAACNjY1Djj18+DCeffZZrFy5Eg899FDfda/Xi8cffxwtLS1YtmxZ33VN0/Dtb38bx48fx9NPPw1JKn4ume3Inbc6nvHhuqnFS64AQK5vhPX4n+E/cRi1s8r33FW7PwKv16HL3MmeHlfumQuA93cjFerWJQ4iIiKiSqCpKoQkwTBlEezX/wBCNuodUknpvnJlMpn6Eqvh7NixA5qmYfXq1f2ur1q1CgaDATt27Oi7pmkaHnzwQbzzzjt46qmn4HSWZiky234UWchQ3VPhshfnvFUvX8+5q65D5XvuquPkKRj/4268/crvdZlfDedWqmxTckl7NsrkioiIiKgYtHQCygv/hPSB3wFAxSVWQBkkVyPR2toKp9OJ2bNn97vucrkwa9YstLa29l377ne/izfeeANPPfUUqqqqShZjtv0ITmY9mDOt+Mufk6Y0wK+6INoPFn2u0Qp+eAJGoSJ0Qp8YJSWAJMyQ7W4oMEFSArrEQURERDSRadkMlN8+iWzHEcBs1zsc3ei+LXAkOjo6hlzlqq+v76sGeOrUKfz85z+HyWTCZZdd1jfmmmuuwXe/+90RzVlTk/92Ni2bRtj/Ho6l52DpOZPg8xV/text50xMju6Ht8ZWlj0DjqRyZdAzwY6SfD8+zpwOIWGpgs/nxAnZCVMmokschTTe4y8n/F4SERGNnaaqSLy8EdmTrTBf8j
UYZ56vd0i6GVfJlaIoQ27vM5vNSCQSAIDJkyfj4MHCrJR0dUWhqlpeY7OdxyGyabyX8WGZ24zOzkhBYjgbzTcXltjbOPj6m6iZOfR5Nb3E/LkzT1LMX5Lvx5nSmSwcagRpsw+dnREkDS6YEuGSx1FIPp9zXMdfTvi9HDtJEiP6AIqIiCYeTdOQfPVZZI69BvOFK2Gav2z4myawcbUt0Gq1IpVKDfq1ZDIJi8VS4oj6y7YfAQDEHNNQ5ShNHX9fY+7clf/Q2yWZb6Q0JQQAsKaCJZ+7K5xEtRSFsHsAAFlzFRxaDKqWX7JMRERERMMT9mqYFq+AafEKvUPR3bhKrmpra9HePnifora2trwLYxRLpv0oQqoNDVMnl2zOhmmT4VddQHt59ruSk7mmvVUII5XOlHTuLn83bFIaxqpaAICwu+EUCiIRpaRxEBEREU1EWjIGIQTMS66B6YIv6x1OWRhXydWiRYsQiURw9OjRftfD4TCOHTuGc845R6fIclKnD+N4xovGacVrHvxxkhDosk6HR3kfmqqWbN58GdO5bVdmkUFXR2l7TEU6c1sSbTW55Mrg9EASQNDfWdI4iIiIiCaa9KFXEP2/30S2+xSAXA9WGmfJ1YoVKyCEwKZNm/pd37x5MzKZDK655hqdIgPUeAhyvAvvZXwlTa4AQNQ3wiJSCH5wdPjBJWZVY0giV4Yz2PZhSedOBHKrnA5vPQDA4vYBAGJdTK6IiIiIRitz4k0kWv4Vsnc6pJ4dQpRTFgUtnn32WYTD4b7/PnbsGDZs2AAAWLp0KZYuXQog12D4xhtvxHPPPYd4PI4LL7wQ+/fvx5YtW7B8+fJ+DYRLrbd5cMg6BdXO0py36uWb9wngxC/gP/gWqqfPLencZ6OqKuyII2iegrrke4j7T5d0/mykCwAgO3Nl8R01ueQqGe4qaRxEREREE0Xm9EEoL22A5J0O6xV3VmQvq7Mpi+TqqaeewqlTp/r++/Dhw/jxj38MALjjjjv6kisAWL9+PRoaGrBt2zY0NzfD6/Xilltuwe23317yuM+UbT+CjCbBOWX28IMLbMr0qXhPdUErs35XsUgUZpFBunom0PYeMqHSrhiJeDeykCBsuT5nTo8PcQCZCBsJExEREY1UNvAhlObHIDm9sF51D4TJqndIZacskqtdu3blPVaWZaxduxZr164tYkQjFz95CB9mqzFveumXRiVJoMsyDTOUQ9BUFUIqj92eka5OOAGIqjoobRaIeGnPXJlTQSgmF9wi9/2QbC5kNAlajI2EiYiIiEZKcvlgnHsRTOetgGRhr8jBlMe/wsc5Tc1CCpzoOW/l1iUGUdcIC1IInyqfc1exQC6ZslR5oJiqYU6Wrhx7MpXrcZWxfPT7IYRATLLDkAyVLA4iIiKi8U6NB6ElohCyEZbProLkqNE7pLLF5KoA1O4PIKtpdJka4HHp02urZt4nAACdB8un31UqlNt+Z6+uQdZWA6caQiZbmoqG/nAC1VIMWk+Pq14J2QVTJjzEXURERER0Ji0RhfLrRxBv/hE09godFpOrAsi051aLzJPm6RbD1BlT0am6oJ4+oFsMH5eO5rbfOWtqYaiqRbUUQ3coXpK5uwJRVEkKjC5fv+sZswt2NVqSGIiIiIjGMy2dRLz5R1BD7TAv/UuWW88Dk6sCiJ44gJBqxdSZ03SLwSBL8Junobqc+l3FQ0hrMsx2B+y+STAIFV3tbSWZOtTZAUlofT2u+tjccIk44ol0SeIgIiIiGo+0bAbKb5+A2nkMlsv/BobJC/UOaVxgclUAaucxnChx8+DBaHXzYEEK0TI5dyUlQ4gKG4QQcDdMAQBEOkpTjj0R6Gkg7Knrd122e2ASWQS6WdSCiIiIaCjJ1/4d2ZOtsFz8NRhnnq93OOMGk6sxUpUwLMkudMiT4HXrW46yZm7PuatD5XHuypCOIiE5AADVk3LJVbK7oyRzZ8K5Yhqy09vvuqkqdwAz5i9NHERERETjkWnxVb
AsWwPj/Ev0DmVcYXI1Rtme81aitvT9rT5u+sxp6My6kD1dHv2urNko0sZcmU6T2wcNgBYpUa+reK6YhnD0L2hh9+TOYCnB0paFJyIiIhoP0u/tzVXCtrlhbLxY73DGHSZXYxQ6cQBZTcA7s1HvUGA0SOg0T4U7fkL3c1eapsGOGFSLCwAgDEbEJCcMidI08DUlA0hKNgiDqd91Z88ZrFSYjYSJiIiIzpR657+Q+K8nkD7Qonco4xaTqzFKfngIp7IezJtZ+ubBg1Frc+eu4qeP6RpHPKbAKtKA9aM+U0lTNeyZENQil/FUkhk41ShS5oE9x3q3BapsJExERETUJ33oFSR//3MYZpwP4/xL9Q5n3GJyNQaaqsIaPYk2UQefzuetennmLgYAdOjc7yrcldv+Z3R+lOBo9hpUiwiCkWRR5/aHEqiWY9BsngFfE7IBMVghJ0rX0JiIiIionGVOvIlEy79CnrwQlsu/DiExRRgtfufGINv9AYxaGpmamWVT93/mrKnoyLqQ/VDfflfx7lxyZXZ91MHbUFWLKimOzu5IUef2B+OolmIwVnkH/boiO2FMsZEwERERkZZOINHyFCTvdFivuBNCNuod0rhm0DuA8SxwfD8sAFzT5usdSh+TUUaHeSrmxg9DU1XdPnlI9JxpslV/lFzZvPWQjgHB9tPADN9Qt45ZsLsbc0QGWvXgWzXTRiescW4LJCIiIhJGC6xX3QPh9EKYymMn1njGlasxiL5/EBHVgllzZ+kdSj9a7TyYkYLSdly3GDKR3LY7Z81HSZSzdhIAIO4vbiPheFeuzLp1iORKtbrhRAyZbJk0WyYiIiIqMTXUjlRP4QrZNxOSxalzRBMDk6sxMAXfw0nUotZj0zuUfjyze/pd6XjuSo0HkdUELE5X3zWDK5fsZMLF7THV2+NKcg6+LVDYq+GQkgiGokWNg4iIiKgcqbEA4jsfQeq1X0BL8N9DhcRtgaOkKhG4sgEccy0qm/NWvWbMnopTu13Ah/t1i0EkQojBBrckf3TN7kYWMqRYV3En73m+cNYM+mWjM1foIuzvgNfjGnQMEdFI7d69G1u2bMHBgwfR3d0Nq9WK6dOn44YbbsAXv/hFSGds0/7jH/+InTt34vXXX8eHH34Im82GWbNmYc2aNbj00ksHPFtVVTz99NPYunUrTp06BZ/Ph2uvvRa33XYbzGZzCd+SiMY7LRGFsvNRaIkobE1/D2Fx6B3ShMLkapS6juXOW1knz9M7lAEsJgPaTVMxL6bfuStjOgJFtve7JoQExVAFSyIATdOKlpQak0FkTAYI8+A/LGzVua2KsW4/gDlFiYGIKs+hQ4dgNBpx/fXXw+v1QlEUtLS04P7778dbb72FBx98sG/so48+io6ODlxxxRWYN28ewuEwtm/fjnXr1uGuu+7C7bff3u/ZDz/8MDZv3oympiasXbsW+/btw8aNG3H48GFs2LCh1K9KROOUlk4i/uJjUENtsF51D2TfTL1DmnCYXI1S8Ph+1GoCk+efo3cog1J9c2FuexeJ9vdgnVT6M2GWbBRJ08BS6BmrB+5EABElDZfNNMidYxNLpOHUIkiZ3EMmb/aec2CpUJFX0Iiootx6660Drn31q1/FunXrsG3bNtxzzz2oqqoCANx333345Cc/CVn+aHX/r//6r3Hdddfhpz/9KW666aa+sYcPH8azzz6LlStX4qGHHuob7/V68fjjj6OlpQXLli0r8tsR0USQ+eDPUDuOwfK522GYvFDvcCYknrkaLf8xtGse1NdW6x3JoKrnnAsA6Dz4li7z27QYsuaBByOF0wuPFEVHQCnKvP5gAtXS4D2uetl7Vq4y0e6ixEBEdKZJkyZBVVVEox+da1i6dGm/xAoALBYLLr30UqTTaRw//lFBoh07dkDTNKxevbrf+FWrVsFgMGDHjh1FjZ+IJg7jrKWwr/wnGGeer3coExaTq1FQs1lUJz9E2D6l7M5b9Zo1ezo6si5kdOh3pS
gJOEQSwuYe8DWLpw4OKYkuf3Ga+PpDCjxSDPIQ560AQJhtSGkGCIWNhImo8KLRKLq7u3HixAn8/Oc/x/bt2zFnzhw0NDQMe29bW66aqsfz0QdEra2tcDqdmD17dr+xLpcLs2bNQmtra2FfgIgmFE3T0PXfz/T9m1CqqtM5oomN2wJHoevkcVhEGsb68j2vYzUb0GacgsbokZKfu4p0+WEFIDsGJlcO7ySkAUQ6PgQwveBzdwcimCclIDyDl2EHACEE4pIDhmSo4PMTEd19993Ys2cPgNzPm4suuggPPvjgsB/GHThwAC+99BLOO+88TJs2re96R0cH6uoG/8dQfX099u7dW7jgiWjCSe39T6TeeB6m87IwNJRPb9aJisnVKHQcehfTANTOKc/zVr2y3nkwd+xDquM4zPWzh7+hQKLdueTK5Bq4Nc/orkUaQCrYXpS5Yz09rkxVZ29SnDS6YElGihIDEY1v4XAYmzZtymuszWbDmjVr+l279957sWbNGnR0dGDXrl0IBAKIx+NnfU4oFMLdd98Ng8GA733ve/2+pigKnM7B+8+YzWYkEom8Yu1VUzP6ymA+X2X1weH7TmyV8L6h13Yg8sbzcC6+HN4Vq8t2x1Ux6PX7y+RqFDJtRxDTzKibPkPvUM7KM6MR6AA6jh/B1BImV4lQ7iyTrWrg1jzhyiU9WsRflLnToU4AQ/e46pUxu2BPHC9q1UIiGp/C4TCefPLJvMZ6vd4BydWCBQv6fn3dddfhgQcewKpVq9Dc3Nxvu1+vaDSKW265BadOncJPfvITzJ07t9/XrVYrUqnUoPMnk0lYLJa8Yu3V1RWFqmojugfI/UOls7NyPpTi+05slfC+6cOvIvHyv8Ew43x4V6yD3185/axG+/srSWJMH0ABTK5GxRn/AAHzZNTrUOJ8JOqnTQZeA+LdxVklGko6kkuunN6Bq0fC7EBamGBIFKeYhBbLPVdyDH3mCgBgq4YrvA+ReBIu+8j+YUJEE9uUKVNw8ODBgj2vqakJW7duxUsvvYSVK1f2+1o8Hse6devw7rvv4rHHHhu06l9tbS3efnvwpvBtbW1DbhkkosqWOfUu5IYFsFy2DkKSh7+BCqK8s4My5O/wwyeC0GrKv0ht71cAACAASURBVC9AtduBkGqFVuKqeGo8CFUDbFUDKykKIZAyV8OphhFPZAo6r6ZpMCSC0AAIx9mrOBocHhiEimAny7ETUXH1btsLhUIDrq9btw5vvvkmHnnkEVxxxRWD3r9o0SJEIhEcPXq03/VwOIxjx47hnHPKe4s6EZWWpuVWpi3L1sD6hb+FMBS+9Q0NjcnVCJ06+C4AoHpW+fcGkCUJEeGErJQ2uRJKGHFYIcmDL4yq9hrUSFF0Bgtbjj2qpOFCBCmjC0I6+6KspWfLYjTQWdAYiKhy+f0DtztrmoYtW7YAABYvXtx3PZlM4m/+5m/w+uuv4wc/+AFWrFgx5HNXrFgBIcSAc2CbN29GJpPBNddcU6A3IKLxLus/gfgvH4Ia8UMICcJg1jukisNtgSOknDwEVROonbNg+MFlQDFUwZfuKOmchlQYcWno/arGqlrU+A/heCCO6fWFO2zoD+V6XKnW4XuP9TYSTgaLc/aLiCpPU1MTli5dioULF8Lr9cLv96O5uRkHDhxAU1MTLrjggr6x9957L1599VUsW7YMmqbhV7/6Vb9nfeYzn4HXmzs72tjYiBtvvBHPPfcc4vE4LrzwQuzfvx9btmzB8uXL2UCYiAAAaqgNym8eBSQDwPPkumFyNUKW8AkEDTWoMtv0DiUvGUs1HNGj0DQVQpRmodKcjSJlGjppsnnroR7LIOD3AyjcWQF/KIEaKQbZOXnYsQ5vLRIA0mE2EiaiwrjpppuwZ88ebNq0CZFIBDabDY2Njfj+97+PL33pS/3GvvtubhdES0sLWlpaBjzrmWee6UuuAGD9+vVoaGjAtm3b0NzcDK/Xi1tuuQW33357cV+KiMYFNRZAfO
f/BjQN1qvvHf7sORUNk6sR6A7FMUlrR8S9ePjBZUI4amCIqchEgzA6B1apKgabGkPANHSCY3bXQQEQ62oDULizAv5ADHOkGIzVZy/DDgAGuxuqJqDFAwWbn4gq2x133IE77rgjr7G7du0a0bNlWcbatWuxdu3a0YRGRBOYlohC2fkotEQUtqa/h+wevmE5FQ/PXI3AewcPwyalYZ/aqHcoeTP29HsKdZwuyXzJZBoOkQCsVUOOEa7cp7FqqLDnnSLd3TAIddgeVwAgJBlxYYOcYCNhIiIiGr80TQVMFlg/fxdkX/kXXJvomFyNQPjEfgCAr8ybB5/JVlMLAIh1lqYce7i7C5LQINndQ46RnLnkR4oXtlJfKpQ7W5bvUrhicMKUDhc0BiIiIqJS0LIZaGoGktUF27XrYZhc/sXWKgGTqxGQut9DUpghV9frHUreqmonAQCSwdJUxYt25+Yxu4begiiMFqRkG6zpIFLpbMHm1qK5ZE0480uu0iYXbGrlNNQjIiKiiUFTVSRe3gjlxcehqSoEC1iUDSZXeQpEkqjPtiHmmFaywhCF4KmpRlw1IhspTVW8RDBXIMLqPnuCk7F6UCMXrhy7pmmQE0EA+a9caVY3nIghmSpcgkdERERUTJqmIfnqs8gcew2GhvkQ0vj5d2kl4O9Gng4fP406OQhLwxy9QxkRo0FCGE5IJep1lYrk5nF4vGcdJzl9qJGi6AgUJrkKx1KoQgQZyQxhyq+So2yvhk1KIxDkuSsiIiIaH1J7/xPpfbtgWrwCpsVD98gjfTC5ylPX0f2QBOCZOX7OW/WKG1wwp4IlmUuN5eaxe86+emTx1KFaiqEjECvIvL09rrLW/Csimly5GCOdpe0DRkRERDQaqXdfQuqN52FsvASmC76sdzg0CCZXeVI7j0IDYKifpXcoI5Y2V8OejZRmMiWEuGaGbDCddZi5ug4GoSLiL8xZsM6QgmopBsmRf3Jl9eQKa8TZSJiIiIjGAbluDozzl8F88c08Z1WmmFzlIRRNwps6jbilNu8tZ+VEs3tgESlkE4VZJTobORVBXLIPO05y5RKbZLAwVQy7elauzNW1ed/j6qmkmAoVtmohERERUSGp4dyH0bJ3BiyXfA1CknWOiIbC5CoPB9/vxnRDJyTfbL1DGRVjTyITLkGvK3MmgqTBOey43nLsKFChjUB3CHYpBYPr7Ge9zmR29/TbirKRMBEREZWnzOmDiP37PyC1b2TNx0kfTK7ycOrYcdilFKpmjJ/mwWeyeHIrNJES9LqyqjFkTcMnV8JRAw2AORlAJquOed5UT6n5fCsFAoAwmpGACSJRmvNoRERERCOR9Z+A0vwYJEcNDLOW6h0O5YHJVR6Spw8DAAz1c3WOZHRctbm+XIlAcZOrdCYLp1CgWaqGHStkA9KmKlRLEXSFE2OeO9vT42okyRUAxCUHjClWCyQiIqLyoobaofzmUQiTFdar74NkGf7Da9Ifk6thRJU03MpJZCQzJPckvcMZFY/Ph7QmIRMubuGGcHcABqFCdrjzGq/ZvQUpx65qGmQlt7VPOPPfFggAKaMLlgwbCRMREVH50DIpxH/zKKCqsF5974g/PCb9GPQOoNwd/zCMGQY/sp6Z46p58JksZiNOaw6IeHF7XUW6O1EDwOiozmu80V2Lmq7TODTG5CoYScItolCFBGEdftXsTKrFDYfSjqyqQmYTPiIiIioDwmCC+VNfguSqhexu0DscGgH+a3IY7530Y5IchH3qPL1DGZOY7IIpWdzCDYlgbmue1Z3fpyuW6lq4RBz+7rGVie/tcaWa3SPuUi5sbrhEAqFwYZoZExEREY2Wlk4i25Y7jmKc8xeQa8dfC6BKx+RqGPG245CEBmP9HL1DGZOUyQ1bNlzcOcK5lTG7J7+teZKrFpIA4l1jOwvWW4ZdjGLJ3ODyQBIaQl3sdUVERET60bIZKL99AvFfPwI1zmJb4xWTq2FYYx8CAOTa8VmGvZdq88ABBW
o6VbQ5MrHcDwJnT3Pe4fSej8qGx9ZIuDOkwCNFYawa2XkrALD2lGOPdRWmmTERERHRSGmqisTLG5E92QrLZ26CZMvv/DqVHyZXw5hq6EbaXgdhHr4xbjmTexKZ6BhXic5KCSGhGSGbLXkN7+11JStdUDVt1NN2BeOokpQR9bjqZe9pJJwIceWKiIiISk/TNCRffRaZY6/BfOFKGOdfondINAZMroYxRe6GpWF8lmA/U2+vq3BHW9HmkJNhxEX+Saiwu6EKGW5EEIwkRz2vEuiEJLRRbQu096yyZdlImIiIiHSQee8NpPftgmnxCpgWr9A7HBojVgschk1KQR7n560AwOmtAwAoRVy5MmWiSBgceY8XQkLW6kFNIor2gAKPK78Vr4/LRroAGZBGWIYdACSrE1lNAuJMroiIiKj0DDM+CcvnboNhJpsETwRcucqDXDe+z1sBgLu+HqomkB7j+aazsalRZEwja3AnO72okaPoCMRHNWdWVWFI9PS4cnhGfL8QEmKSHXKCjYSJiIiodNLH/gQ13AEhBIyzLoAQQu+QqACYXA1DlU2Q3JP1DmPM7FYLwpoViBan11Umm4UDcWiWkfWZMlXXwSNF0REcXSn0QCQJt4gBwKgb7CUNTpgzYysHT0RERJSvzPtvIfHfP0XyT/+hdyhUYEyuhmHwTh9x76RyJIRAVHLBUKReV5FgBCaRHXF1G8nlg0NKortrdCtH/mCuDHvWaIcwmEf1jLS5CnY1Cm0MRTWIiIiI8pE5fRDKb38CyTsdlotX6x0OFdj4zxqKTPZO1zuEgkma3LBmirP9LdKd225odFaP6L7eioHp4OjOguUaCEch7CPfEtjH6oZLiiOqpEf/DCIiIqJhZLveh/LiY5AcNbB+4RsQJqveIVGBMbkaxkRKrrLWaji1KDRVLfiz48EuAIClamRb83qLUGhR/6hWjvwhBdVSDIaq/HprDUZ2eGAWGYS62bCPiIiIiif1+n9CGK2wXn0fJKtL73CoCFgtcBhSzXRMlM1iktMLOaQhHuiEvaauoM9OhnJnuezVI0uuhCuXFLnUMMLxNKrsphHd7w8qqJFjkJ2jO28FAOae5sORrk5gSu2on0NERER0NpbL1kFTwqM+J07ljytXwxBmm94hFIzZnUtkgu2F73WVjeXOcjlrRraCJMwOqLIZNXIUnYGRF7WIhEIwiQwkx8jLsPeyeXI/4BLB4lVSJCIiosqkJaJI7HkGWkqBMFogufhB7kTG5KqC9K5WxbuL0OsqHkJKk2Gw5t9EGMgV2oCjBh4pivZRlGPP9pSWH00Z9l69PcBS4a5RP4OIiIjo47R0EvEXH0P6wO+gdp/UOxwqASZXFaS6vgEAkAp2FPzZUjKMuLCPqkeDoao2t3I1wnLsmawKuafH1WgaCPcyOnOJmRZjI2EiIiIqDC2bgfLbJ6B2HIXl8q9Drp+rd0hUAkyuKojT5UBMNUOLFn6FxpSOICE7RnWv7PLBK0XR0T2ylavucALVPT2uxBj2LgvZiDgskNhImIiIiApAU1UkXt6I7MlWmC9eDePMT+kdEpUIk6sKIoRAWHLCoBR+hcaqRpE2Okd1r+T0wSQyCAdGFleuDHsMmmSEsIxu7l6K7IQpFR7TM4iIiIgAQIsHkG07BNMFK2Gav0zvcKiEWC2wwiQMVXCluwv6TFXTYEccMcvoSor29rpSwyMrKNGXXNmrR7Ud8UxpkwtWbgskIiKiApAcNbD/1fcgLKPb1UPjF1euKkzGWg2HFhlVT6mhRMNRWEUaks09qvuFK3deypoJIJ7Iv5GvP6TAI8dgGMN5q16qxQ0n4kils2N+FhEREVWmVOtvkfj9FmiaysSqQjG5qjDCXgOzyCAZKdwWuHBXbsXJ6Kwe1f29K1c1UhQdIyhq4Q8m4JFjkAuQXEn2ajilBIKh6JifRURERJUnffhVJF99DlrEDxTwQ2waX5hcVRhTX6+rDwv2zFggVyDDXDW6cujCaIFqcu
SSqxH0ugoEo3AKZUzFLHoZXbnYw37/mJ9FRERElSXz/ltI7P4Z5IYFsFy2DkKS9Q6JdMLkqsLYenpdRf2F63WVDOXOcNndo09yZJcPNXIU7SNIrtLhXCIkOceeXFndudWveICNhImIiCh/mbZDUH77E0je6bB+/i4Ig0nvkEhHTK4qTFXtJABAsoC9rjI9hSCcXt+onyG7fPAaYujMM7lKZ7IwJIIAxlaGvVdvI+FkiI2EiYiIKH+aEoFUVQ/rF74BYbLqHQ7pjNUCK0yVx4OAZoAaKVwSocVDyGgSHPaqUT9DcvngFq+jszu/M09d4SSqpVyPK6kAyZW12osogGy0sJUUiYiIaGLSshkI2QDjzPNhmL4EQuKaBXHlquLIsoQwHJCVwiURUiKEuLCNqRy6cPogQ0Uiz5Ujf1BBtRyFBgFhH91Zr35MNqRhAJTg2J9FREREE5oaCyD2i/VIH/0jADCxoj78P6ECKYYqmFKFSyJM6QgUyT6mZ0g9Ff8MiW4kU8OXQ+/tcQWrC0Ie+wKsEAJxYYcxyUbCRERENDQtEYWy81Fo8VBfxWOiXkyuKlDa7IZDjRTseeZsFGmjc0zPOLMce2ce5dg7Qwo8UrwgZdh7JY0uWDJMroiIiGhwWjqJ+IuPQQ21wfr5uyDXztI7JCozTK4qkGb3wC4SSCn5V+Yb8lmaBocWh2pxjek5wlEDDQI1cn69rrpCCXgNsYKct+qVMVfBrsWgquxNQURERP1pahbKS09C7TgKy+Vfh2HyQr1DojLE5KoCGat6el11jL3XVTSWgF1KQljdY3qOkA0Qtmp48ux11RlQ4BKxglQK7IvBXo0qKY5wLFmwZxIREdEEISTI3hkwX7waxpmf0jsaKlNMriqQ1dPT66qzbczPCnfl+kIZHGNLrgBArvKh1hDLa+UqGe6GAdmCrlwZHB4YhIpgF8uxExERUY6maVDjQQghYF76lzDNX6Z3SFTGmFxVIFdtPQAg0T32XlfxQC4RMVeNPckRTm9uW2AgftZxyVQWhmQIQGEaCPey9DRBjnaxkTARERHlpPb+EvFf/L9Qo/zwlYbH5KoCuX11yGoC2Yh/zM/qLZ1urR57kiM5fXAghq7A2Xtd+cMJeKTcmEJuC7R7ctslk8Gxf1+IiIho/Eu1/hapN36V62NViNYvNOExuapARqMBYdiBeGDMz0pHcs9wecZeilRy+iAAaNEuZLLqkOO6Qgqq5cI1EO5lr6kFAKQjbCRMRERU6dKHX0Xy1edgmHE+zBffPKZ+nlQ5mFxVqJhcBVNy7L2utHgQqiZgdlWP+Vmip6y6R4rAH0oMOa4zmOtxpRksgMk25nl7yXY3VE1AK0DSSURERONX5vRBJHb/DHLDAlguWwchyXqHROMEk6sKlTZVwZ4Njfk5IhFGXFgL0pn8zF5XZ6sY2BVKoEaOQXbWFPRTJCHJiAsr5OTYvy9EREQ0fsm+GTCeeyWsn78LwmDSOxwaR5hcVSjV7oETcWQyqTE9x5gKQ5HsBYlJ2N3QJENPOfahi1p0hhR4jfGCnrfqlTA4YUoXrsEyERERjR/ZwCloqTiEwQzLX1wPYbLqHRKNM0yuKpTs9EISGkIdY6sYaM5GkTI6CxKTEBIkRw18hthZV678oQTcorANhHtlTFWwZyPQNDYSJiIiqiRqqB3Kjh9C2bVR71BoHDPoHQDpw1pdC7wHRDraUNMwZVTP0DQNdi2OkHl09w9GcvlQF27H62fpdRUOhmGxJSAKWIa9l2Z1oyr6HpRkFjYL/3gQlbvXXnsNP/jBD9Dd3Y1p06ZhwYIFWLBgAebPn485c+bAYOCfYyIanhoLIL7zEUBVYb5wpd7h0DjGv3UqlMOX63WldLeP+hlxJQ2HSCBkGXsD4V6S0wu3ODLkypWSzMCUCgG2wlYK7JvfUQ1bVwqBQBi2SSy5SlTu1q9fj7lz5+Lmm2/G+++/j3379uHFF19EW1sbjEYj3nnnnZLFsnv3bmzZsg
UHDx5Ed3c3rFYrpk+fjhtuuAFf/OIXIZ3lbOqrr76Kr33tawCAF154AfPmzev3dVVV8fTTT2Pr1q04deoUfD4frr32Wtx2220wm81FfS+iiU5LRKHsfBRaIgrb1d+EXN2gd0g0jjG5qlDVdQ1IAUiHR9/TKdLth11okB2FS66E0weLlkAkFIaqapCk/gUr/KFEXxl24fAWbN5evc2QI10dAJMrorLn9/vx1FNPYerUqf2uB4NB7N+/v6SxHDp0CEajEddffz28Xi8URUFLSwvuv/9+vPXWW3jwwQcHvS+ZTOI73/kObDYb4vHBz5s+/PDD2Lx5M5qamrB27Vrs27cPGzduxOHDh7Fhw4ZivhbRhJfYswlqqA3Wq+6BXDtL73BonGNyVaHMVgu6NCsQG3238WjADzsAk6twSUhvxcAqRNAdScBb1f8gqT+kwCP19rgqfPJjrc7NrwTYSJhoPDj//PNx8uTJAcmV2+3Gpz/96ZLGcuuttw649tWvfhXr1q3Dtm3bcM8996CqqmrAmJ/+9KeIx+NYuXIlnn766QFfP3z4MJ599lmsXLkSDz30UN91r9eLxx9/HC0tLVi2bFlB34Wokpj/4iswzvssDJMX6h0KTQAsaFHBYpITxsToe10pwVyzXVt14bbnSWf0uhpsa6A/1NPjSkgQtrH31vo4p7enkXB49EknEZXODTfcgA0bNqC7u3ybf0+aNAmqqiIajQ742tGjR/Gzn/0Mf//3fw+HwzHo/Tt27ICmaVi9enW/66tWrYLBYMCOHTuKETbRhKZpKtIH/weapkJy1MAwbbHeIdEEwZWrCpY0ueFIjP7MVSaSa7br9PgKFRKE64xeV0EFH/8MyR9MoMEQg2SvLkhvrY8zuWqQApCNsZEw0Xhw++23AwCuvPJKLFu2DEuWLOkramG16lNCORqNIpVKIRKJ4JVXXsH27dsxZ84cNDQMPMfxne98B0uWLME111yDJ554YtDntba2wul0Yvbs2f2uu1wuzJo1C62trUV5D6KJStM0dL34r0jsbYbFbIdxxif1DokmECZXFSxr9cCVOIqsqkIeRaLSm4BYqgq3PU+YHYDRAu8Q5dj9IQWLjUpRilkAgDBakIQJQhn9ih4RlU5LSwsOHDiA/fv348CBA3jmmWfwwQcfQAiB6dOnY+fOnSWP6e6778aePXsAAEIIXHTRRXjwwQcHND3/j//4D7z55pv45S9/edbndXR0oK6ubtCv1dfXY+/evYUJnKhCpPb+Eqk3mmH8xFVMrKjgmFxVMNlZA2Mwi2hXF6p8I199khIhxDULnAZjwWISQkByejEpHcfvh9wWGIVwzCjYnB8XlxwwpsJFez4RFU5dXR3q6ur6nTlSFAUHDhzAwYMHR/XMcDiMTZs25TXWZrNhzZo1/a7de++9WLNmDTo6OrBr1y4EAoEBhSq6u7vxz//8z7j55psxZ86cs86hKAqczsH7CZrNZiQSibxi7VVTM/j2w3z4fIXpazhe8H0nntCffo3IG7+Cc/Hl8F69ZsCHHhNZJfz+nkmv92VyVcFM7lrgAyDYcXpUyZWcikCR7AWPS3L64A2+j/ZBkqvuUBw2e7RoK1cAkDK5YI1HivZ8Iiouq9WKJUuWYMmSJaO6PxwO48knn8xrrNfrHZBcLViwoO/X1113HR544AGsWrUKzc3N8HhyK/2PPPIILBZL37bGs7FarUilUoN+LZlMwmKx5BVrr66uKFR15I3SfT4nOjsr52cj33fiUeNBxHY9C8OM8+FdsQ5+/8BzkBNVJfz+nmm07ytJYkwfQAFMriqaw5vbZhLvGt25K3M2iqRxbP8DDkY4vXBqregMxaFpWt+nSrFEGsZ0BBI0CGfhy7D3Us1uOOMdSGdUGA2s+UJUzlKpFH784x9j165dSKVSmDdvHq699lpcddVVo37mlClTRr3qNZimpiZs3boVL730ElauXInW1lZs374d9913Hzo7O/vGhUIhAMDp06dhNp
sxffp0AEBtbS3efvvtQZ/d1tY25JZBIupPsrlhu/YfILkbICRZ73BogmJyVcHc9Q3IAkiHOocdOxi7GkPEXPi/1CWnDwYtDWMmhnAshSpHrkGmP5iAR8p9ylSMMuy9hN0NZ1BBKKzA6yn8yhwRFc4Pf/hDvPzyy7jxxhthMplw5MgR/MM//AN27tyJH/3oRzAY9P9rrnfbXm/y1NbWBiC3evXII48MGN9b0r03wVu0aBH27NmDo0eP9itqEQ6HcezYsTElkkSVINN2CFqwDcb5l0D2ztA7HJrg9P9bh3RjdTjh14xQoyMvO64k03AIBRHrwJ4tY9Xb66pGiqI9oHyUXIUUVEvFayDcy+iqgfyhBn+Xn8kVUZlrbm7GE088gU9+8qND6XfddRduvfVWbNy4EbfddlvJYvH7/fB6+/9s0jQNW7ZsAQAsXpwr9Xzuuefixz/+8YD7f/Ob36C5uRnf+ta3UF9f33d9xYoV+Jd/+Rds2rQJ3/3ud/uub968GZlMBtdcc00xXodoQsh2vQ+l+UeQbG4Y5vwFhMGkd0g0wTG5qmBCCESEE4bEyMuOh7sDsAkVst1d+LhcPb2u5Cg6gwrmTc3N0dvjCkBRz1xZ3LlnxwOdAKYXbR4iGrtkMomamv4/D7xeL771rW/hW9/6VkmTq6amJixduhQLFy6E1+uF3+9Hc3MzDhw4gKamJlxwwQUAckU4vvCFLwy4//DhwwCAiy66CPPmzeu73tjYiBtvvBHPPfcc4vE4LrzwQuzfvx9btmzB8uXL2UCYaAhqqB3Kzv8NYbTCuuJeJlZUEkyuKlzCWAVLOjTi+6LdnbABMLkKvz2vd+XKK8f6FbXwBxOoNcYBswPCaC74vL0cNbXQACSD/qLNQUSFsXTpUvziF7/A3/3d3/W7Xl9fj66u0jYDv+mmm7Bnzx5s2rQJkUgENpsNjY2N+P73v48vfelLY3r2+vXr0dDQgG3btqG5uRlerxe33HJLXgUxiCqRGgsgvvMRQFVhbbq3qB/KEp2JyVWFy1iqUZM61a9wRD4SoW4AgLWq8D+shNECYXGiQYtjf+Cj8sX+kIJFpuL1uOpldXsRB5CJdhd1HiIau3vvvRdf+cpXEAwGcfPNN2PWrFlIp9PYtGnTsCXOC+2OO+7AHXfcMer777zzTtx5552Dfk2WZaxduxZr164d9fOJKknm/behJaKwXf1NyNUDG3gTFQuTqwonHDWwRVKIhiNwVrnyvi8VziUe9prinH0STh9q0zH8LnjGylU4ty1Qckwrypy9JJsLWU0C4mwkTFTuZs+ejc2bN+OBBx5AU1MTDAYDVFWF2+3Ghg0b9A6PiHRiWnApDNMWQ7JX6x0KVRgmVxXO5PYBp4FQ++kRJVdqLJd42KuLs4okOb1wBw6ho2dboKZp8AcVOF2RopZhBwAhJMQlOwzJkW+XJKLSmz9/PrZt24Zjx47hyJEjsNvtWLx4MRyOwreKIKLypWUzSOz+GUyLPge5bg4TK9IFm/hUOHtNrpR61D/CXldKCAnNCMk4suaV+ZJcPtiyEcQTKUSVNKJKGlJGgUFLFbUMe6+EwQlzunKa7RGNJ3feeScUJffBy/Hjx/uuz5o1C5///Ofxmc98hokVUYXRNBWJ3f8HmaN/gBr4UO9wqIJx5arCVdVOAgAkgx0juk9OhRGXilemXDh9kLQs3FIcnT1bA0tRhr1X1lwFe+IUVE2DNIKzaERUfD6fD5lMBgBw1VVXwWq1orGxEfPnz8eCBQuwYMECNDY2wmwuXuEbIiofmqYh+cpzyBz9I0wXrIRx/iV6h0QVjMlVhbN7ahDSJKiRkVXGM2eiSBmdRYoqty0QADxSFO2BOGRJgqcEZdh7aTY3qiKHEY2l4HLwH2hE5eSBBx7o+3VLSwsOHDiA/fv348CBA3jqqafwwQcfQAiB6dOnY+fOnTpGSkSlkNr7S6T3/TeMn7gK5vNW6B0OVTgmVxVOkmREhAOyMrJeV1Y1Bs
U0tUhR9W8k3BlQYJClM1auip9cGRwemDsyCARCcDlqiz4fEY3OmjVrsGXLln69nhRFwYEDCeTqSQAAIABJREFUB3Dw4EEdIyOiUtBUFWr3SRgbL4b5wpV6h0PE5IqAuKEK5nT+lfGSqQycIo64papoMeUSKIHJ1gQ6AgpMRhm1pjggGyCsxVsx62Wuyq2cRbo6gKlMrojK1ZEjR5BKpfpds1qtmDNnDp5//nmdoiKiUtBUFUKSYPnc7QBG1lKGqFhY0IKQMVfDkQ3nPT4cDMMkspDs7qLFJGQDhL0ak8xxdAQVdIYU1JkSEI4aCFH8/21tnlxylQiWtgkpEeVn7dq1ePzxxyGEwOnTpwd8XVEUbN26VYfIiKgUMu+/hfh/fgdqPAghSRCSrHdIRAC4ckUA4PDAFVMQjyuw2azDDo92+2EBYHIVt2qf5PLB0xVFR0CBzWKAR46VrMO6o8aHBIB0eGRn0YioNObNm4c//elP0DQNX/7yl2G32/sKWjQ2NuLYsWPw+Xx6h0lERZA5fRDKb38CqXoyhIHnoqm8MLkiGFw+oB0ItrfBNnPmsOOVQC7hsBQ5uRJOL1ydHyIUy5Vjd9ZEIOxzijpnL0NPuXc1NrKzaERUGt/85jcBAIsWLcLWrVvR0dHRV9SipaUFmUwG9913n85RElGhZbveh/LiY5AcNbBedQ+EafgPhYlKickVwerJnSmKdLYBeSRXyUg3AMBRU9yS6JLTB1MmChlZQAUs2SgkZ2lWroTBBAUWyAk2EiYqZ2+++SaMRiPOOeccLF++XO9wiOj/Z+/O46Oq78X/v845s2eybxBISEIgBJBFRQuiuFVFsVqqta61em17W2sr1dt+72K1y21vtb/aXpe2qA+xVqzXWrUu1B1lUVxQlpAAYQuEkD2Tyexzzu+PMYGQhSyTOVnez8fDxwM/c5b3JJDM+3w+n/d7GOmtR/C/ch+K1YnzkjtRnSlmhyREN5Jcic5eV4Hm/vW6irbHil8kZQx/cqVgkKG2oxPbpJqoZYEAfi0Za7j/e9GEEIn34Ycf4na7mTNnjtmhCCGGm8WGmpaHffENCf08IMRASEELQXJ2LroB0f72uvK3EDY0VJtrWONSPu91lam1JbQMe4ewLQVn1Juw+wkhBu6Xv/wle/bs6TZeWVlJY6MUpBFiLDBCfgw9ipqUjuvSH6Ol55kdkhC9kuRKoFmstONC9TX17/hgG+1K0rCXPO3odZVn95OhxpKcRD6p0p2ppCrt+IORhN1TCDEw+/fv55RTTuk2vmXLls59WUKI0csIB/G9+hsCb//J7FCE6BdJrgQA7VoqtmD/el3Zwm0ELe5hjgiUpDRQLUxyBphgD8TG3MNbRONYalIGyWqAFk97wu4phBiYlJQUmpq6Pxg65ZRT2Lp1qwkRCSHixYhG8L/xAHpdFZbiBWaHI0S/SHIlAAjZ00jS+7e/yKm3E7EN/yZSRVFRkjMpdAcpSY+iOFNRNOuw37eDLSU2S+ZpqE/YPYUQA3PmmWeycuXKbuO6rhONRk2ISAgRD4ahE3hnJdHqrdjPvBFr0almhyREv0hyJQAwkjJIoZ1QONzncZGojhsfhiMxFXrU5GxybD6KUyMoCaoU2MGZFtvz5WuWXldCjFQ/+MEP2Lp1K9/+9rcpLy8HoL29nT/84Q+UlpaaHJ0QYrCCHzxDpOoDbKd9FduMJWaHI0S/SbVAAYCWnIWlQae5rp7cSb1vFPW0enGqYRRXWkLiUpOzCNfvRXG40TILEnLPDu6sbCJAyCOb4oUYqXJzc3nmmWf4r//6L5YvX47FYiEajZKSksIf/vAHs8MTQgyStehUFKsT+7yLzQ5FiAGR5EoA4EzPhr3QWne47+SqsZ5MwJqcnpC4lORsCLZjhAIoU+Yn5J4dHKlZeIFoW/8KfQghzJGbm8uf/vQnDh06REVFBRaLhblz55KWlpiHQEKI+Ik2HUTLmIyWW4KWW2J2OEIMmCwLFAAkZ08AIN
DUd6+rjiVyjtTELNHrqBiIEU18Twt7EmE0FH//Cn0IIRLje9/7Hn6/H4C9e/d2jk+aNInzzjuPJUuWSGIlxCgU3rUB37P/RXjPJrNDEWLQZOZKAJCaOxE/EPH0Xbwh5InN4gx3A+EOavLR+ySyxxWAoij4VDeWkDQSFmIkyc7OJhKJtUhYunQpTqeT0tJSZsyYQVlZGWVlZZSWlmK3202OVAjRX5EDnxJ45xG0vBlYCuaZHY4QgxbX5Oo3v/kNlZWV/OlP0otgtLE4kvAbdmjvewlcxBubxUlOUHKlpGR3/tmMbuwhSwqOQFvC7yuE6N1dd93V+ee1a9dSUVHBjh07qKio4LHHHqO6uhpFUZgyZQqvvPKKiZEKIfojcrgS/+sPomYW4LzgNhSLzeyQhBi0uCZXzc3NvPfee/G8pEggr5aMNdDc90H+FiKGiupMTkhMit0NVgeEA6YkV1FHKm7/PiJRHYsmq2iFGGlyc3PJzc1lyZKj1cT8fj8VFRVUVlaaGJkQoj+MgBf/P3+H4s7EuXQFis1pdkhCDIksCxSdgtY0nIG+K+OpQQ8+xUW6oiQkJkVRUJOz0NsawJ6UkHt24UontbUcjzdIRuro/oGvGwZvfnyQxSdNxGmXf/pi9LrwwgspKytj5syZnUsBs7OPznI7nU7mz5/P/PmJLYIjhBg4xeHGsfh6tAnTUZ2JafMixHCST1iik+7KIDWwn3AkitWi9XiMNdxGQHMnNC41YzJoVpQEJXTHsiZnYK3VqWtqIiN1UsLvH0/VR7ysfmMXFlXhnJMnmx2OEIN24403UlFRwZtvvsnDDz9MIBAgMzOzy56rsrIyioqKzA5VCNEL3deC7qnHMmEa1pKFZocjRNz0mVx99NFHzJo1C6dzdD+xF/2jJWfiaA7T0tRMdk7Pe6qc0XYijsTst+rgOON6DD2a0Ht2sKfFliL6GuuhaHQnV83eIFmqhwNHZA+ZGN2uvvrqzj9v2bKFH/zgB5xxxhlYLBY2bdrEypUrURQFh8PB5s2bTYxUCNETI9iO/+X7MPytJF19H4pVis+IsaPP5Oq6665D0zSKioqYNWsWs2fPZvbs2ZSVleFwOBIVo0gQW1oOHADPkdoek6uorpOEjxZ7YqftFXsSiZ+zinFnxJYaBVpGfyNhX+Nh/iP1eV45fBFQZnY4QsTFf/7nf/KTn/yky56rjz76iH/7t39j+fLlJkYmhOiJEQ7iW/Nb9Nba2B4rSazEGNNncnXaaaexY8cOdu/eze7du3nxxRcBOhOujmSrI+ESo1tyVi4AvqYjwOxur3s8PtxqkFbX+Okf48rIxgdE2kZ/chVuPoKqQGr7fqK6jqZKgQ4x+u3fv5/i4uIuY6eeeip33XUXf/zjH7n11ltNikwIcTwjGsH/xgPodVU4zv8ulkkzzQ5JiLjrM7l64oknANi3bx/bt29n27ZtbNu2jfLycnbt2sWuXbt4/vnngVjCZbNJ6czRLDV3IkEg1NJzr6u2pgbSAas7PaFxmUlNSsMwwPCdoIriKBD1xsrsT9bqOdLkJy/LhAIhQsTZ3LlzefbZZ7n99tu7jE+dOpUdO3aYFJUQoifhirVEq7diP+sbWItONTscIYZFvwpaFBYWUlhYyCWXXNI5tmfPns5ka/v27ZSXl+Pz+UwpOiDiw+pOo83QMNp7nqVpb24kHbCnJr4kulkU1YJPcaEFW80OZeh8sfcwWWtiV22zJFdiTPh//+//cf3111NTU8ONN95IaWkpoVCIlStXkpGRYXZ4QohjWGeeg5qai2Vy99UxQowVg64WWFxcTHFxMV/60pcAMAyDqqoqtm3bFrfgRGIpioJXTcbSS6+rYGss6XKlj5/kCiBgScYWHv1FILSgBwCrotNSvQdmS8VAMfqVlZXx7LPP8rOf/YyvfOUrWCwWotEoVquVX/3qV2aHJ4QAQtvewFI4H9WdKYmVGPPiVopdURRKSkooKSmJ1yWFCQLWNBy9zNJEvC
0AJGcmtlqg2SL2VFzBegzDGNUzs/aIh5DVgU0PoNfvBc4yOyQh4qKwsJBHH32U2tpaysvLURSF2bNnd+l9JYQwR2jb6wQ3/AWjvQn76V81Oxwhhp30uRJdRB1ppAVq0Q0D9bhEwvC1oBsKlqTxU9ACwHCmkdq2j/ZABLfTanY4gxKJ6iQZ7Xidebj8R3C1Hxz1yaIQAI2Njdx9991s3LgRq9XK3//+dyZMmGB2WEIIILxrA8ENf8FSeDK2BV8xOxwhEkLKhYkuVHcmyWqA1hZv99eCHnyKE2WcVZnT3BkkqSGamj1mhzJoLd4gqYoPw5mGL3kyedTR2h4yOywhhuyee+6hpaWF3//+9/h8PsLhMAA//elP+dOf/mRydEKMX5EDnxJ451G0vDIc534bRdXMDkmIhBhfn5LFCdnScgBoOXK422vWcBsBdfwVQbB9XsDD29hzFcXRoKUtQIrqR3OnY8kpJkdtpfrQ6H0/QnTYuHEj//Ef/8GiRYtQj3nwc9555/HKK6+YGJkQ45dhGIQ2v4yamY/zgttQLFJNWowfklyJLlyZsV5X3obabq85ol5C1uREh2Q6V1psj5m/ucHkSAavrakJTTGwJWeQNmUGqgKeAzvNDkuIIVNVFbu9exPSgoICqqurTYhICKEoCs6lt8eaBNucZocjREJJciW6SM2dCECoteushm4YJBk+dEeKGWGZyp0V2xQf8ozeRsK+llhi6MrIwpU3FYBo3V4zQxIiLs4+++zOfovH8nq9aJosQxIikXRPHf53HsGIBFFsLlTn+PvMIIQUtBBdONKyaDUU9LauiYTXG8Ct+Glzja9iFgDW5EyCgN4+ehsJhz2xBsLOtCwUh5tWNR1X+0GToxJi6FasWMHy5cu7jPn9fh588EFmzpxpUlRCjD+6rwXfy/dCyI8xfxlKqhSWEeOTzFyJLhRVo11JQvN3TSRaGxtRFbAkpZsUmXkUm5MgVlR/i9mhDFpHYqh+/v3zJ08mVz9CIBQxMywhhiw3N5enn36abdu24ff7Wb58OQsWLOCTTz7hzjvvNDs8IcYFI9iO/+X7MAJtOJeuQJXESoxjMnMluvFbUrGHuyYSvpYG0gB7SoY5QZnMryVjCY3eaoGKvwUDUFypAGg5xaS1bqWm+hDFU6eYG5wQgxSNRnnhhRc477zzePTRR6mpqaGiogKLxcK8efNISZElSUIMNyMcxLfmt+ittTiXrkDLKTY7JCFMJcmV6CbiSMft2dulD1KgJbaszJWeaWZoptHtqTg8HvzBCE776PtnYwl58KtJpHxeCjetoBR2Qeu+SpDkSoxSmqZxzz33sGDBAlJTU8nLyyMvL8/ssIQYV3RvI4anHsd538YySZbiCiHLAkV3SRmkKj7a2gOdQ2FvbFlZcmaWWVGZypI+gRy1lX01rWaHMiiOiJeQ5Wilx7SCEqKGQrReilqI0W3evHns3St/j4VINMMwANDS80j62q+xFp1qckRCjAySXIluLKnZaIpB85EjnWO6L7ZM0Jo8PpcFphZMw6WGObRvn9mhDFggFMFNO1F7aueYarXRqGXjbJNS1WJ0u+qqq7j//vul7LoQCWQYBsH1TxLc9GxslYu1ezsEIcYrSa5EN0d7XR1NrtSABx8OFG30LYmLB2deCQD+mt0mRzJwLd4Qqaof5bhKj373ZLL1I0SiUtRCjF4rVqygvLycSy+9lNtuu40nn3ySTz75BL/fb3ZoQoxZoY+fJ1z+JoYuvz+EON74/KQs+pSSMwEDCDTXdY517NkZr9T0PCJYsLRUd9mLNhq0tLYzSQ3Q6u6aXKnZxTg9n9BwYD8TiqaaFJ0QQ7N27VoqKirYsWMHFRUV/PnPf6a6uhpFUZgyZQqvvPKK2SEKMaaEtr1O6JMXsEw/E/vpV42q34dCJIIkV6IbV0YO7UDU09A55oh6CduTez9pjFNUC4GkiUxoqae+xU9OusvskPrN2xT7PtpTu+6XS51SClXQvK9Cki
sxauXm5pKbm8uSJUs6x/x+PxUVFVRWVpoYmRBjT3jXBoIb/oKl8GQcZ90oiZUQPZDkSnSjWh2040D1xyoEGoZBktFOu32iyZGZy5JTTL53PbsOtYyq5CrQGmsI7UrvmlzlTCmixbASrZNiAGJ0C4VC7NmzB4CioiKcTifz589n/vz5JkcmxBijKGiTZuE499son1efFUJ0JcmV6JFfS8EWjBWx8PpDJCt+2l2pJzhrbEspmEZo71rq9u2F2aOn3HPEE0uSHWldkyur1UKdmoPTK4UAxOi1fv167rjjDlpaWjAMA4fDwSWXXMKKFSvIzByfrSOEiDcjHESx2rGWLMQy9QsyYyVEH6SghehRyJ5Okh5rmtvW1ISmGGhJ6SZHZS5LThEAwSN7TI5kYDoqPSpJad1e8yVNJiNSjx4JJTosIeLipz/9KQsWLOC1115jw4YN3Hfffezfv58vf/nLUkFQiDiINh6g/ek7iRz4FEASKyFOQJIr0SPDlUGa4qXdH8LbFFtWZk8Zn2XYO6ipE4koVty+QwTDUbPD6Tc10EIUFcXu7vaakl2ERdHxHBxdCaMQHQ4fPswPf/hD8vPzycjI4Pzzz+fJJ5/k9NNP5xe/+IXZ4QkxqumeOvyv3AeqBTUj3+xwhBgVJLkSPbKkZmNTojTVN+D/fM+OM318J1eKqhJOmcxkrYH9tW1mh9Nv1pAHv5bc49PG1CmlADTvk43/YnSaPn06R47pydfhX//1X/nggw9MiEiIsUFvb8b38r2g6zgvvgPVLctshegPSa5EjxzpOQC0NRwh3NYMQHJ6tpkhjQiOiSVM0prZe6jZ7FD6xTAMnHo7YWvPlR7z8ifTqjuJ1lUlODIh4mPZsmX87Gc/48CBA13Gm5qaSElJMSkqIUY3I+TH/8pvMPwenEtXoKWPnn3GQphNClqIHiVnTwDA13gEvT22Z8eeNr5nrgCceVNRKl6n8cAe+EKR2eGcUHsgQorSju7oeTlHktPKDiWHCW0HExyZEPHxq1/9CoCLL76Yc845h7KyMnRd5+WXX+ZHP/pRwuJ45513WL16NZWVlTQ1NeF0OpkyZQpXX301l112Gara/VlmdXU1Dz30EOvWraO5uZn09HTmzJnDPffcQ1ZW1wI0zz33HI8//jh79+4lNTWVL37xi9x+++2SQIrhYbVjKZiDNmkWWk6x2dEIMapIciV6lJyV29nrSg20EsBGssVudlim07ILAYg27BsVzYRb2oKkqn58fRQjaU/KJ9W3DiPYjmIfv42ixei0cePGzibCO3bs4NVXX2Xv3r0YhsGDDz7I66+/TmlpKaWlpZxzzjnDFsfOnTuxWq1cddVVZGVl4ff7Wbt2LT/+8Y/59NNPueeee7oc/9lnn3HTTTcxceJErr32WrKysmhqamLz5s14vd4uydXjjz/OL3/5S84880yuu+46Dhw4wKpVq9i6dStPPfUUNptt2N6XGF+MaATD34rqzsR++lfNDkeIUUmSK9EjxeEmhAV8TVhCbfgV+dANoKTkENEcZAeO0NwWJCPFYXZIfWppaWWKEibs7n3WUc0qggPr8B+uwlU4J4HRCdG3hx9+mAsuuICpU3tvcp2ens7ChQtZuHBh51goFGLnzp3s2LGDiooK3n33XR555BE++uijYYv1m9/8ZrexG264gW9961s888wzrFixgtTUWDuLQCDA7bffzvz583n44YexWq29XrepqYn777+fxYsXs3Llys4HOiUlJfzoRz/i2Wef5ZprrhmeNyXGFcPQCbzzCNHDFSRd+Qt52CbEIElyJXqkKApeNQVboBm77ifUQ6W58UhRVPT0AvKDjVTVeEZ8cuVragDA0ceSzpSCaXAAWvZXSnIlRpTf/e53RKNRbr311s6xQCCAw9H3vzubzcbs2bOZPXv2cId4QhMnTkTXdbxeb2dy9fLLL3Po0KHOxMrv92OxWHpMst588038fj833HBDl5nySy+9lF//+te89NJLkl
yJITMMg8Z/Pkqk6n1sp10piZUQQyAFLUSvQrY0XFEPLsNLxCbr+ju48qaSpzWz92Cj2aGcUNATi9GV0XsxksmTcqiNphIZZf27xPj0yCOPdJmlOlZdXR3t7e0Jjqgrr9dLU1MT+/fv56mnnuK5556jpKSEvLyjBQHee+893G43Ho+Hyy67jHnz5jFnzhyuueYatmzZ0uV6W7duBWD+/PldxjVNY86cOZSXl2MYxvC/MTGmhT5+Hs/Ha7DOuQjb3IvNDkeIUU1mrkSvdFcGqb6DWJUoAWeq2eGMGJacYiyKTuuhPUCZ2eH0qaPSoy2595mrjBQ724xsZnoODPs+MiPkJ/DOSmwLrpDqU2LQWlpaehx/5plnePjhh9m+fXuCIzrq+9//PuvWrQNiKwAWLVrEPffc0+Xf1b59+4hGo9xyyy1cdNFFfOc73+mcybrhhhv4v//7P6ZNmwbEEkan09lj4YoJEybg9/tpbW0lLa17k/DeZGYOfiVCdnbPlUfHqvHwfr3b36PtkxdInnsuWZf8y4jfSxxP4+H7eyx5v4khyZXolZqcRVJTCAAtqf+/uMe6jqIWlpZqwhEdq2UETwD7Yh9CFVfv3z9FUfC6JuEI7cZob0IZxl4m4T2biOz7BMXhRjvrpmG7jxi/dF0f0vkej4dVq1b161iXy8XNN9/cZeyOO+7g5ptvpq6ujrfeeovm5mZ8Pl+XY9rb2/H7/Vx66aWd1Q4BZs2axQ033MCDDz7I/fffD4Df7++1YIXdHisyFAgE+v3+ABobvej6wGe7srOTqa8fPT3+hmq8vF8jfQa2U5eT9cWv0dDgNTuchBkv398O8n77R1WVIT2AAkmuRB/sadmwP/Znax8zH+ON4s4iYk1icrCB6jovxXkjd8mkFmwlhA3F5uzzODW7GA6tJXykCvswJleRXRsBCFdtwr7wWhSrVKAUI4vH4+GBBx7o17FZWVndkquysqOz2Zdffjl33XUX119/PWvWrCEjI/ZztGPP2PLly7uce/rpp5OXl8emTZs6x5xOJ6FQqMf7B4PBLtcTYiCitbtQM/NRrA7sJ38JRdXMDkmIMUGSK9Grjl5XAK406czeQVEU1KxC8v01VNW0jujkyhZuI9CPYiSpk4uJHFQJVu/CPvW0YYlF9zYRPVyJNmkW0UPbiez9COv0M4blXkIM1uTJk6msrIzb9ZYtW8Zf//pX3njjDb761Vhp65ycHHbu3NmtlxVAdnY25eXlnf+fk5OD3+/H4/F0WxpYW1uL0+nsLJQhRH9Fanfif/lerNMW4pBVBELE1QhezyTMdmxylZTR/UPAeOaYUMxErYX9h0ZuUYuorve7GEn+xHQORjOIHKkatngiVR8ABp+kng/J2YR3rhu2e4mxY7Tv/+hYstfa2to5NmdOrCpnbW1tt+Nra2s7Z7gATjrpJAA2b97c5Thd19m6dStlZWWj/mskEivaWI1/zW9R3JnYFlxhdjhCjDmSXIleaUnpRI3YL21nH6W8xyM1uwhVMfDVjNwKe572MKmKH8N54v1yEzNdHIxmYfccxBjinpXehHe/TzAln8fWNbNDKyNaswPdUz8s9xJjx8MPP8zy5cv5yU9+wrPPPktNTY3ZIfWooaGh25hhGKxevRqAuXPndo4vW7YMVVV5+umnuxz/1ltvceTIEc4666zOsfPOOw+Hw8ETTzzR5dgXX3yRhoYGli1bFs+3IcY43VOH/5V7UaxOXJfcieocuSsvhBitZFmg6JWiqrSryTh1H4rNZXY4I4qWXQRASuAwrd4gqe6Rt3eopS1AmuqjrR/FSCyaisc1CUu0Ar3lMFrGpLjGEm2uQW/cz47UcwF4tjqX/0xWCO9ch/3UL8f1XmLsWLRoEdu3b6e8vJzy8nKeeeaZzteuvfZaysrKKC0tZcaMGZSWlpoYaSxhWrBgATNnziQrK4uGhgbWrFlDRUUFy5Yt47TTji63nTp1Kt/4xjd49NFHueWWWzj77LOpqa
nhySefJCsrq0tfr4yMDG677TZ+/etfc8stt3DBBRdw4MABHn/8cWbNmsWVV15pxtsVo5BhGATeXgm6jnPZHajDuL9WiPFMkivRJ1taNoa/WZadHEdNSidqTyE/2MieGg/zp/feR8osnqZmshQdW0r/foEqmUVQB9G6qrgnV5HdG0FRePVwDtMnp7LrEDS6CsneuQ7bKZehKDKJLrp77LHHAKiurmbbtm2d/5WXl/Pxxx/z8ccfd/5s0jQNp7Pvwi3D6brrrmPdunWsWrWKtrY2XC4XpaWl/OIXv+hWuALgzjvvZNKkSaxevZpf/vKXJCUlcf7557NixQomTJjQ5dibb76Z1NRUVq1axU9/+lNSUlJYvnw5K1as6LWSoBDHUxQFxzm3YAR90gpDiGEkyZXoU/rpl2IEzG3KOVJZc4oo8O3jsxGaXPlbYsuUHP0sRpI+KR9frRW9Zje2GWed+IR+MgyD8O738aeXUNto48rTC8jYUcdr+wq4xrGWaE0Flkkz43Y/Mfbk5+eTn5/P0qVLO8f27dvXLeFqa2sz7UHQrbfe2mXG6UQUReHaa6/l2muv7dfxV1xxBVdcIftjxMAZ4SDhyvewzjoXNSXH7HCEGPMkuRJ9shTMMzuEEcuSU0xu9WdUH6oHppodTjdBT6zYRlJ6/4qRFOSmUB3NovBIfPeR6XVVGG31bM9YgMOmMbsog5w0Jz8tP8SVLjuWyvckuRIDVlhYSGFhYeeeI8Mw2LNnD1u3bjU5MiFGDiMawf/GA0QPbkPLKUbLKTY7JCHGPFmLI8QgadlFKEC4fh/RYSoCMRRRbzMAmrt/xUjyc5LZH8nC2laDEem5r85ghHe/D5qFVw5lMq8kC6tFY1K2mznTJ/JxsIjI3o8xQr4TX0iIPiiKwtSpU7n88svNDkWIEcEwdALvrCRavRX7mTdKYiVEgkhyJcQgqdmFAEyknkP1I2/ppOKPlX5WXP3rgeNyWGi2TURFR288EJcYDD1KZM8mfJnAD1oCAAAgAElEQVQzafQrnDrj6JKUSxZNYb2vGKIhwlWb+riKEEKIgTAMg+D6vxCp+gDbaVdim7HE7JCEGDckuRJikFRnCrorg3ytkaoaj9nhdGMJtRJQnCiatd/nKJ9XQYzWxWdpYPRQOYbfwxZ9KvbPlwR2KJyQQmrBdI7oaYQq3ovL/YQQQoDeXEO4Yi3WORdhm3ux2eEIMa5IciXEEFhziii0NrLnUOuJD04we7iNoCV5QOdkT5hAc9RFKE7NhMO73webkzUHU5g7NRObVevy+qVnFLExMBWjvopoy8jsXySEEKONljEJ1/K7sZ9+lVT7FSLBJLkSYgi0nCIy1TYO1YysZrjhSBQ37UTsA2sQmZ/r5kA0i3Dt0JMrIxIksu9j2rPn0OwzOLW0e5WqksmptGTNI2ooBHfI7JUQQgxFePdGwrs2AKBlTJbESggTSHIlxBBoWbFldDZPNV5/2ORojmr2hkhR/eBKH9B5BZ8XtbD4GjAC3iHFENn/GYQDbA4XY7OqnDS155Lw5y+eRXl4Ev6KdRh6dEj3FEKI8Spy4DMCbz9CuPI9DGPkFVkSYryQ5EqIIdA+L2pRYIk1Ex4pWlp9JCt+NHfagM7LSLFTp+YCEK3fO6QYIrs3orjSeK3ayZypWdiPWxLYYcaUdPa7T8IWbiN0QMpoCyHEQEVqd+J//UHUzHycF9wmjdmFMJH86xNiCBR7EiRnf55cjZx9V97mBlQF7Cn9ayDcQVEUyCzAAKL1gy9qYQS8RKq30JYzj9b2CAtm9N64UlEUZp2xBK9up+6jNwZ9TyGEGI+ijdX41/wWxZ2Bc+kKFJvT7JCEGNckuRJiiCzZRRTamkZUxcBAc6yBsCujfw2EjzVxQhZHomlEhtBMOLz3I9CjfBQsxGZRmVPcd5J3UkkulZZS3I3lRH1tg76vEEKMN5HqrShWJ65L7kR1DmyfrRAi/ixmByDEaKflFJG6ZxNHao6gGwbqCN
hAHGqLNRB2pA5s5gogP8fN/vJMcur2YBjGoDZER3ZvREmdwFv7NE6amobd1vOSwA6KopA5/1wsH2+hasPrTD9/+YDvKYQQ45F93sXYZpyF4nCbHUpCbdxey3Nrq2j0BMlMsbN8yVQWzppgdlhCyMyVEEOlfl7UIluv43Cjz+RoYoz2JgDUpIEVtAAoyI0VtVCDbRjexgGfr3sbiR7eiSd7Hq3t4R6rBPZk1slzqSULdc8GdMMY8H2FEGK8MILt+F7+NdGGfQDjMrFa9WoFjZ4gAI2eIKterWDj9lqTIxNCkishhkzLmoKBQoHWMGL6XSmBVnQUlEEsEZmY6eKQng0Mbt9VpOoDwGBToBCLpjKnlyqBx1MVBaNoIbk0UL55y4DvK4QQ44ERCeJb81uih3diBNrNDscUz62tIhTpWhExFNF5bm18ejQKMRSSXAkxRIrNiZo2YUTtu7KEPPhV96AqRlk0FdInEUEjWjfw5Cq8+33U7GLe3RPhpOIMnPb+rz4uPuOLRAyVxk/exJDZKyGE6MKIRvC//iB6XRWOc7+FZfIss0MyRceMVX/HhUgkSa6EiAMtu4gp1qYRUTHQMAycUS8ha/KgrzE5N40aPYNo3cDKsUebD6E3HqAlay4t3hCn9lElsCcWVwrezJlMj1SyrapuQOcKIcRYZhg6gXceIVq9Bfvir2MtXmB2SKbJTLEPaFyIRJLkSog40LKLSDLaaWuoxx+MmBqLPxglGR+6I3XQ18jPdbM3lEm0fu+AGvtGdr8PisImXwEWTWFeycCrFeacej5uNcj2dWtl9koIITpEIxhhP7bTrsBWdrbZ0Zhq+ZKp2CxdP8LaLCrLl0w1KSIhjpJqgULEgZZVCEC+pYG9hz3MLMwwLZYWb5BU1YffNbAGwscqyHHzRiQLJVqB3lKDlpF/wnMMw4gtCcybyfoqH7OLMge0JLCDreAkvNZkCrxbqDzwRWZMGXhRDiGEGEuMaBjFYsN5wfdhBFSkPZHhruTXcS2pFihGIkmuhIgDNasAFJUCSyNVNeYmV62tbUxWQ0Tcg48hPydWMRAgWrenX8mVXleF0VZPa/EFNG0Nsvys4kHdW1E1nDMWM3PLGp5ct50ZUxYP6jpCCDEWhLa9TrhyHa5L7hwVVQE7Kvl1FJzoqOQHxD3BkmRKjESyLFCIOFAsdtT0SZQ4W0yvGOhtagDAnjbwHlcdXA4LJGcTVBzo/SxqEd69ETQL77dNRFMHtySwg6PsLDTFILV+M7tHSAVGIYRItPCuDQQ3/AU1ORNsTrPD6Rep5CfGO5m5EiJOtOxCJrV8RFVN66Cb78ZDsDWWXCWlDz65ASjITeFgQzbT+lGO3dAjRKo2oRXM4/2dHmYVZeByWAd9bzVtIkr2VBYeqeKl9Xv5wVfnDfpaQggxGkUOfEbgnUfR8spwnPttFLXvZuwjxXBV8uvPUkNpLCxGApm5EiJO1OwiHLofW7CF+ha/aXGE25oBsKUOfuYKYvuudvvT0ZsOYYT7/qUYPVSOEWijOXMejZ5AvxsH98U240xy1RZa9u9kf23bkK8nhBCjRbR2F/7XH0TNzMd5wW0oFpvZIfXbcFTy60/TYGksLEYKSa6EiJPOohZaI1WHzOt3ZfhaAFCHUNACYhUD90eywNCJNu7v89jw7vfB5mJjcxaaqjB/+tBmzQCsU08HzcYZrj28tHHfkK8nhBCjheLOwDJ5Fs6lK1BGyXLADoOt5Ldxey13PrSem371Fnc+tL5LUtSfpYayHFGMFLIsUIg4UTPzQdUosjVRVdPKwtnmLEXQAq2EsYDNNaTrFOQkcyASm/3S6/bAhOk9HmdEgkT2fYKl+DQ2lTdRVphO0hCWBHZQbE4sRadwyp5P+FvlYQ41FDMpK2nI1xVCiJFK97WgOFNQ3Zk4L/y+2eEMymAq+fVUBOOxl8pZ/cZOvP7e25scu9RQGguLkUKSKyHiRNGsqBn5TGtu4Zka82aubOE2Alb3kPd8Za
TY0e3JtGspWPooahHZ/ymEAzRlzqWhtZVLFxUO6b7HspaeSWT3RuY7D/HKxn3ccumsuF1bCCFGEt3Xgu+FX2DJPwnH4hvMDmdIBlrJr6dZp6hBn4kVdF1qmJli7zGRksbCItEkuRIijrTsQnKbNnKoro1gOIrdmtgNyLph4NTbCNtShnwtRVHIz3FzyJeDu35vr8eFd21EcaWxsT4FVfEwf3r2kO/dQcubgZKcxQXOA/y8fAqXLS4iJ31oM3JCCDHSGMF2/C/fh+H3YJ0+etpPDKWAxLHnDsbxSw2XL5naZfarp2OESATZcyVEHKnZRVj1IOl4TCnC4PWFSVH8GI7UuFyvIDeZSl8aRls9ur/7bJwR8BKt3opl6ul8tLOBssJ03M6hLwnsoCgq1mlnkOnfR6bm45X3+977JYQQo40RDuJb81v01lqcF9yGljO4HoGJNpQCEsefO1CZKXa+vnRGl0Ru4awJfH3pjM6Zqp6OESIRZOZKiDjSsosAyLc0UlXTyvT8oRWVGKiWtgCpqg9vUnpcrpef4+adzZngAL1+L2rB3C6vh/d8CEaUxoy51DXXsvT0grjc91jW6YsJffICVxTUsXJrEpcuKiIz1RH3+wghhBkCb/8Jva4Kx3nfwTJ59Cx97quAxIkSmp7O7Y/MFDv3fueMXl+XxsJiJJCZKyHiSE3PA81KaVILe0yoGNja3IJNiWJNyYjL9Qpyk6mOZGCgEO1h31Wk6n3UtIl8UGtDVZS4LgnsoKZko+WVMSOyAzBY88GBuN9DCCHMYj3pAhxn3YS1eIHZoQzIUApInOiYJIeGReu6b1iW+InRQpIrIeJIUS2omQUU25rZ/Xkz4UTyNccaCDvThl4KHWBipgtds+G1ZRM9bt+V7m0kergSy9Qv8GFlA6UFaaS4hqcXi3X6YhRvPZdNi7D2sxpavVL9SQgxehmGQfTIbgAsE0uxlp5pckQDN5R+Vn2d+9iPz+V/f7CEb1xcJkv8xKgkyZUQcaZlF5IZrcPjDdDcltgkIOhpAsCVHp/kyqKp5GUlccjIRq/b0yVZDO/+AIDG9DkcafJx6oyhNw7uNY6iU8Hq4IykPUR1nTWbZPZKCDF6hT5+Ht8LPydSs8PsUAatp35WAMFw9IT7rno7d87UzM4/L5w1gXu/cwaP/fhc7v3OGZJYiVFD9lwJEWdadjHa9jfJUT1U1XjISEnc/qCoN5ZcWZPjsywQYv2uKvalMcO6HaOtHiUllkRFqjai5hSz6ZCBosDJw7AksINitWMtPg2qPmDxjC/w1ieHOOfkyeSkja7mmkIIEdr2OqFPXsAy/Uy0iTPMDmfAjq3y53ZaAINQ5OiDN68/wmMvlQNHe14de46qgG7QYzXddz+t4cOKOrz+yICrDwoxUsjMlRBxpmYXAlBkb6LqUGtib+6L3U9xxa+QRn6um52+WIGMjn1X0aZD6I3VWEsW8mFFHaX5aaQmDc+SwA6W0jMhEuTLhU2oqsJTr+9M+LJLIYQYivCuDQQ3/AVL4ck4zrpxyP0IE23j9loee6m8c8+U1x/pklh1iBrw1OuVneccWxlQ//zwYDja43kdva0GUn1QiJFEkish4kxNnQgWOzOTPVTVJDa50kKtBBU7iiV+iU5BjpvD0TR01dq57yqyeyMoCg1pszncOLxLAjtouSUoqROw7X+fyxcXsaWqkc27Gob9vkIIEQ+6p47AO4+iTZyB49xvo6iJ7YMYD0+9Xkm0n8+02gOx5GmwlQHhaPVBIUYTSa6EiDNFVdGyplCgNbK/1kt4kL9UBsMebiNgSY7rNfNzktFRaXNM6Nx3Fa56H23SLD7cF0ABThnGJYEdFEXBOn0x0dqdnDPNyuTsJJ56YyfBUPenn0IIMdKoKTk4zrkF54Xfj+sDsETqSJj6a+P22kH3suow1POFSDRJroQYBmp2EanhI+jRCNV13oTcMxLVcRntRGzxaSDcweWwkJXqoIYcog37idbuxGhrwFryBT
6qrGNafhqp7hNXh4oH6/QzQFEwdm/gugtKafIEeXHD3hOfKIQQJok2VndWBrSWfAHFNn72iq78R/mQr9Gf6oNCjCSSXAkxDLTsIlQ9wgStJWFLA1u9IdJUHzjjm1xBrN9VpS8doiFCm54FzUpDygwO1bdzaunwz1p1UJPS0fLnENqyhsLmjSyenctrm6o51NCesBiEEKK/dE8d/lfuJfDOIxj66J9ljxWwSBzpbSVGI0muhBgG2udFLWYktSasqEVLm59kxY+alB73axfkuNnaEltuGD2yC8uUeXxU1QbAKaXDv9/qWI6zvoE2aRbB9//KFaG/kW9v48l/VkpxCyHEiKL7WvC9fC/oOo4LbhuVe6w2bq/lzofWc9Ov3uJ7968l3EPxiuEiva3EaCWl2IUYBkpKDticlDla+WuNJyH3bGtuIkcxsKfGrwx7h/xcNw26G92ahBpux1KykA/frKdkcirpyYldsqG60nBe+H0iVe8TWP8k33e9yMv1J/H+tlwWnjQpobEIIURPjGA7/lfuw/B7cC37EVp6ntkhDVhHlb+OYhQD3W81FI/9+NyE3UuIeJOZKyGGgaKoaFmF5CkNNLQGaPUO/4ZcX3M9AM60+DQQPlZBTjKg4HFNBpuLxqSpHKz3siDBs1YdFEXBWrKQpCv/G2vhPC51bSZ9/W9pPyz7r4QQ5gttfQ29pRbnBbeh5RSbHc6gDKXK31DIHisx2klyJcQw0bKLSAocQSPKngTMXoU9zQA40+OfXGWk2ElyWNjkOhvXxXfw0a7YvU5J4H6rnqiuVFxfvJW2U28iBS+Rf/yM4McvYOgRU+MSQoxvtpO/hOtL/45l8iyzQxk0s6r0yR4rMdpJciXEMFGzi1CMKJOtLVQlILnS22MJj+aO/7JARVHIz3Gzo9mKllPMR5V1TM1LISPFEfd7DUbeyWexvvDbbA4WEPr47/j+fg/Rhv1mhyWEGEcMQyf44d/Q25tRVG3Uzlh1MGMGqWxKWp97rI7dA3bnQ+ulwbAYkSS5EmKYdBS1mJfuTUxRC38LOgqKM2VYLl+Qm8zBOi+1TT4OHPEmpHHwQFxy9mye53xetCzF8Hnw/f2nBD/8G0Y0bHZoQogxzjAMghv+QmjzP4js+8TscOJiztTMhN1LVeCc+XncefXJvR7TsQesY0at0RNk1asVkmCJEUcKWggxTBR3FordzVRrC68e9hDVdTR1+J5nWEMeAqqL1GGqSJWf4yYU0Xl54z7A/CWBx3M5LHzt3BL+9I8Q+efdygLvO50fdBxLbh71T5GFECNX6JMXCG9/E+uci7DOHFnFGDZur+W5tVU0eoJkptiZMzWTLVWNnf+/fMnULrNFG7fX8tTrlQkpYJHk0PjfHyzp17E97QELRXSeW1slFQXFiCIzV0IME0VRUHOKyNWPEArrHKof3l5MjkgbIUvysF2/IDd27Q3baimamEJW6shrhHn6zFxmFKTxzLrDhE/7Os6LbscI+fC98DOCHzyDEQmZHaIQYowJbXuD0MfPY5l+JvbTr0JRFLND6tTTbM/bm2t6nf25d/UnrPxHecIqAw7kPr3tATNrb5gQvZHkSohhpGUVYvfXYSUyrPuugqEobnxE7cOzJBBgYqYLi6ZgGHDqjJE1a9VBURSuu6CUYDjK/729G0vBXJKu/AXW0jMJffYKvr/dRbR2l9lhCiHGCCMaIVyxFkvhyTjOunFEJVbQv4p/oYjOyn+Uc+kPX2DH/pYERRYzkH1dai9f2t7GhTCLJFdCDKNYUQudaUke9gzjvqsWb5BU1Qeu+DcQ7mDRVPKykgA41aQS7P2Rl5XERacXsH5bLZUHmlFsLhxn3YTz4jswomF8L/43gY2rMSKyF0sIMTSKZsF16Y9xnPvtEdkkeKTP6uSk938FhN5L/+LexoUwiyRXQgwjLbsIgLnpXnYP48xVi6cdtxrE4h6+5Arg5OnZnDw9m+y0kbck8FjLFhWSmeLgydd2EonGntpaJs
8m6YqfY515DuGt/6Tx7SdNjlIIMVpFanfif+uPGJEQij0JxWIzO6RuRkOhhx37W/jzPyv6dWxvs1zSF0uMNJJcCTGM1KR0FFcaRbYmjjT58PqHZ7bE2xhrIGxPHd7qTl86o4hbl580rPeIB7tV45ovTuNQQztvfHSwc1yxOXEsvgHrzHPxbHqZaF2ViVEKIUajaGM1/jW/JVq/FyMycmeGVr+x0+wQ+mXtpzX9Om75kqnYLF0/ttosqvTFEiOOJFdCDDM1q5CM8BEA9tQMz9LAQGsTAK70kbkXygzzp2UzrySLF9btpckT6PKa/bQr0ZIzCKx9DCMqDYeFEP2je+rwv3IvitWJ6+I7UB3DV0RoqLz+0fGzrb/L+hbOmsDXl87onKnKTLHz9aUzpFKgGHEkuRJimGnZRVjb63AqEXYdHJ7kKtwWS67sqfFvIDyaXX3+NAzDYPWbXYtYKDYn2Uu/hd58iNCnL5kUnRBiNNF9LfhevhdDj+K8+IeoyVlmhzQmDKQgxcJZE7j3O2fw2I/P5d7vnCGJlRiRpM+VEMMs1kzY4JQc37AlV3p7MwDqMO+5Gm2y05wsW1TIc+/uYUtVY5emmK5pp2CZ+gVCm/+BpWgBWsYkEyMVQox0Rnsz6FFcS3+Ilj6wnxf3rv6kSyW+silpfTbMPd7xvao6lsIdP9aRbPR3H9NIsGRentkhCBFXklwJMczUz4tazEpp4/09HsIRHaslvpPGaqCVKCqK3R3X644FF55WwIZttfzl9UpmFJyOzXq0opd90TVED24j8O6juL70nyjD2ORZjG3vvPMOq1evprKykqamJpxOJ1OmTOHqq6/msssuQz3u79bhw4d5+OGH2bBhA3V1daSnpzN37ly+9a1vMWvWrG7Xf+6553j88cfZu3cvqampfPGLX+T2228nJWX42i+MJz0lL186O7bkz9B1FFVFyy4i6Wv/g6JZB3Tt4xMriBVyuHf1J90SrN6SqFWvVnSWVG/0BFn5j/Iu53X0qwLYfbCFtzf3bx9TInU0MF77aQ26EZuxWjIvj+svnGF2aELElSRXQgwz1ZmC4s5kstpIJDqJ/bVtlExOjes9bGEPfksyaSOsx8pIYLWoXHfBdO57+lNeeX8/l59Z3Pma6kzBvugaAm//ifD2N7CddIGJkYrRbOfOnVitVq666iqysrLw+/2sXbuWH//4x3z66afcc889ncc2Njbyla98hUgkwte+9jXy8/M5fPgwTz/9NG+//TZPP/10lwTr8ccf55e//CVnnnkm1113HQcOHGDVqlVs3bqVp556Cptt5FWqG006Gu0em7yserWClGQHMye58b/2e7TcEuwnf2nAiRXQa++o48d7i8NmVU/YqwqO9qsaiToKTyycNUGSKTHmSXIlRAJoWYUkNxwA5rDzYEtckyvDMHBEvYScI3djtdlmFmZw+sxcXnn/AAtnTSA3w9X5mqVkIdru9wl++CyWwvmoyVIURAzcN7/5zW5jN9xwA9/61rd45plnWLFiBampsX/3//jHP2hsbOShhx7ivPPO6zz+jDPO4JprruH555/vTK6ampq4//77Wbx4MStXruxsUltSUsKPfvQjnn32Wa655poEvMOxq6dGu6GIzhOvlnN3yRai1VuwFPZ/Cd9AHTtbdbxQRO9XYmUGi6YQifZejSIzxd7jkkUhxjpZAyNEAqjZRSjeegozNHZV9/wUc7DaAxFSFB+6I76zYWPNVeeWYNEU/vL6Tgzj6AcCRVFwnPl1UFQC7z7e5TUhhmrixInouo7X6+0ca2trAyA7u2sin5MTa87tdB7tI/fmm2/i9/u54YYbOhMrgEsvvZTMzExeekkKsgxVz412Dc4KriVS9T62067AVnb2sNy7Y7ZqpDf77ck3Li7D7ez5GX1mil0KT4hxS5IrIRKgo5nwKVnt7D7Uih7HD/At3iCpqg/VlRa3a45FaW47Xz6rmG17m/i4sr7La6o7E/tpVxI9tJ3Irv
UmRSjGAq/XS1NTE/v37+epp57iueeeo6SkhLy8o5v2Fy1aBMDPfvYzPvroI44cOcLmzZv593//d7Kysrjqqqs6j926dSsA8+fP73IfTdOYM2cO5eXl8kBgiHpqQnuRYwtnOSqxzrkI29xLBnS9jdtrufOh9dz0q7e486H15GX23nR95T/KR+zMVF8yU+wsnDWBq8+fLr2nhDiOLAsUIgG0nGJQFErt9fwt4KCmoZ3J2fEpPtHa3EqBEiGUPLwNhMeCc0+exPoth1n95i6WLCjo8pp15jlEqj4gsHE12uSTUF0yEygG7vvf/z7r1q0DYrOiixYt4p577uky63TKKadw9913c//993Pttdd2js+YMYNnnnmGSZOOVqKrq6vD6XT2WLhiwoQJ+P1+WltbSUvr/8OVzMzB/+zJzh57y49vXDaLB/7vM4LhaOdYm+LGO/kLnLTsXzq/d+98XM0Tr+6godlPVrqTG5aWcfYp+V3G3S4r/mCkc7lcoyeIRVNQFBgrObDdqnHjsllkZyfzpbOTSUl29Ph1OV5vXz8zjcW/z32R95sYklwJkQCKzYmaVUhW4ACQz67qlrglV+3NsVkYR5okVyeiqSrXX1jKf//5Y37/zKfcdFFp5wcnRVFxnPUN2v/2XwQ3PInz/O+aHK0wg8fjYdWqVf061uVycfPNN3cZu+OOO7j55pupq6vjrbfeorm5GZ/P1+3crKwspk2bxqJFiygtLaW6uppHHnmEm2++mSeeeKJziaDf7++1YIXdHptxCQQCPb7em8ZGL3p/O7ceIzs7mfr6tgGfN9LNKkjrTKxcShCfYWe9fyrrtwB3vNjjOfXNfn7z1Cf85qlPuoy3+cLdju1rX9Jo07F/alZBGvX1bd2qG16+uKjztWMdX6yjvtnP/z7zKZ62gGlLBsfq3+feyPvtH1VVhvQACiS5EiJhLHll6Fv/SZZ7MbsOtnLOyZPjct1ga6yBcFKGNLTsj6mTUrnynBKeeXs32cl2vrS4qPM1NW0itpMvI/Th3wjv+xhr4SkmRirM4PF4eOCBB/p1bFZWVrfkqqysrPPPl19+OXfddRfXX389a9asISMj1uT7tdde43vf+x6PPvooixcv7jz+jDPO4PLLL+d3v/sdv/jFL4DY/qtQKNTj/YPB2D4dh8PR/zcournpV28BMNN6kK+73+OPbeeyJ5JrclQjzy2XzuySCPVW3RDoljD1VjTkubVVsh9LjDmSXAmRIFpeGXz2Cguz21l/MH5FLSJtsQbCNlkW2G8XnpZPQ1uQ59ftJS8riVNn5HS+Zpu7lMieTQTX/RnLxBko9iQTIxWJNnnyZCorK+N2vWXLlvHXv/6VN954g69+9asAPPHEEyQlJXVJrACmTZtGcXExH374YedYTk4Ofr8fj8fTbWlgbW0tTqezswqhGLxiyxG+4V7LkWgqhyLSjL0nQ0mYeivYMRoLeQhxIlLQQogE0SZMA0VjpuMITZ4gja0DW8rTK38suVJkj1C/KYrCd6+Yy9RJKTzyUjn7a48uHVBUC46zbsbwtxL84K8mRinGgo4le62trZ1j9fX1GIbRYyGKSCRCJBLp/P+TTjoJgM2bN3c5Ttd1tm7dSllZWZf9XGLg8rQmbnG/RbOexB/aziPI6OobpqnD//3vqejHQBKmns7va1yI0UySKyESRLE6UHOKyAlVA7AzTrNXWtBDCBuKrfeKVKI7m1Xj1i+fhNtl5fd/20Kr9+gHAi27EOtJFxGueJfIoZHZlFOMLA0NDd3GDMNg9erVAMydO7dzvKSkBJ/Pxz//+c8ux3/22Wfs27eP2bNnd46dd955OBwOnnjiiS7HvvjiizQ0NLBs2bJ4vo1xR29v5l+T3yCElYfbzsdrjK6fo26nhZX/dg5DTa8yU+z88JqTsWjdr6Qp9Fj9byAJ0/IlU6WqoBg3tLvvvvtus43dnAgAACAASURBVIMYyfz+0Jip8GOmpCQ7Pl/P+wbGE6OtHmPfR2
yMzsLmcDC3ZGD7pHr6Oh7e9E+SrAYp8y+MZ6hjXlKSnWgkSml+Om9tPkjlgRYWzspFU2MfALQJ0whXfUhk/2asZUtQVFlFfTxFUXC5RtdT/uFy7rnn8tlnn7F//3727dvHu+++y89//nM+/PBDli1bxk033dR5bH5+Pi+88AKvvfYajY2N1NbWsmbNGn7+85+jaRr/8z//Q1ZW7GeD0+nEZrOxevVqtmzZQjAYZM2aNfz2t79lxowZ/OQnP0HTtAHFOtjfa2Py57jFzpsbKnjedyoNeveKjCOZzaJy3YWl5Oe48bQH2Vc7uGIFNovK1edP58JFRbisKjurmzuX+yU5NG5YWtbjvqhkl41texqJHlMcpeNa+TldCwLk57jJTHWwv9aDPxglM8XO1edPN3W/1Zj8+9wHeb/9E4/fa4ohDTL6NNiqSqKr8ValpjeRQ+X4X/41a1yXsSWQx8/+5fQBnX/811HXDbb94d9ISUmi8Lq74xzt2Hbs1/Kjijoeen4bC2dN4F+WHV1mFanZgf+l/8E65yIcX/iameGOSPGoqjRWPPDAA6xbt459+/bR1taGy+WitLSUyy+/nOXLl6OqXZ/a7969m4ceeojPPvuMI0eO4Ha7WbBgAd/97neZMWNGt+s/++yzrFq1in379pGSksL555/PihUrBrXfSqoFghFsxwgHUN2xvaodRS2Gi82iEIr0/jVXAJtVIRg2sFkUwlEDwwBVgdKCNOqa/TR6gqgK6MbRqn3HJid//mcFaz+tQTe6Xq/jnCSHhqIoeP2RHq8zmO/v8dUCj49pJBtLf5/7Q95v/8Tj95okVycgyVV8jLd/1L0xIiG8j3+H6vQF3LdrGr///pm4ndZ+n3/817HFG6TlzyswcqZR9JUfDEfIY9bxX8sX1+/l+ff2cuXZU1n6hSmd44F3HydcuRbXZf8V61cmOklyNTqN9+TKiATxv3wfur+VpCv/G0XreVZ6rLzf/pL3O7bJ++2fePxekz1XQiSQYrGh5U4lNxzbd7X7YOsJzuhbs8dPqupHc0l1q6G6dFEhC2bk8Ow7VXy6++j+GfsXvoriTCXw7mMY0UgfVxBCjHSGHsH/+oNE63ZjP+3KXhMrIYQYLEmuhEgwLa8Mm+cQbi3EriEWtfA0N2NRdKwpGXGKbvxSFIWbLimjIDeZP764nYP13ti4zYVj8dfRmw4S+uwVk6OMH8PQiRzYIgmjGDcMQyfwziNEq7dgX/x1rMULzA5JCDEGSXIlRIJpeWWAwcJsz5ArBvpbGgFwpksD4XiwWzW+95WTcFg1fv/sFto+3wxrKZyPpfg0Qp+8SLS5xuQoh84wDIIbV+Nf8/8Rrd9jdjhCJER4yxoiu9/HdtoV2MrONjscIcQYJcmVEAmm5RSDZmW2q4F9h9sIhaODvlaoNZZcudKlgXC8ZKQ4uPUrJ9HiDfHQ37cRicaqZtkXXQtWe2x5oKGf4Cp903WDndUthCP/f3t3Ht5Ulf4B/HuTNElXCnRnKRRIWygUhIKCIFUWhYKIsgwCgoA4bAqyuM2AP0ccAWfYcRiUpWyCw9aKHawsyi4FgUKBQkVKoRstTdukSZPc3x+dRmJbaKHNxvfzPDyac89N3pOe3Js399xzHv5v/yj0Z/agNPl7uET0gdS/lU1iILI2l/CeUDz9GuSR/W0dChE5MSZXRFYmSF0gDWiFwNJ0GE0ifr2tfujnMhaXLSAs8+CwwNrUIqgexvYLw+X0u9i47wpEUYTErR6UT42AKesqSi883MxiBqMJP569hff/fRx/33QaC7acgbrYulPj6pMToT+1EzJVNyieGs4FaMnpGX47A9GggyB3g7x1NPs8EdUpJldENiANDIOi+DbchRJceZRJLbRlwwoFN+9aiozKPdUmAP2fCsaPZ2/hh6SbAABZq66QNo6A7udvYCqsuGhsVXSlRnx/Kh1zvjiGdd9dgqtChkHdmyM9qwgfrz9lvr+rrpVePQbd0Y2QBXeAss
frEASeAsi5laYehfa/S6D/xXnulyQi+8YzK5ENyBq1BgB0bnAXqekPf9+Vi04NreDKGa/qyEs9QtC+pQ+2/nAVF37NgyAIUHYfA4gitPuWoDT1KMRSXZX7a0oM+PbYdcxedRRbElPhW0+JGUMj8dfXOmFgt+aY8+oTMJhMmB+bhHPXqp+sPQzDjV9QcmANpIFhUD73ZwiSmi08S+RoDDfOouTgl5AGhkHenkMBicg6mFwR2YDEtxkgU6Cdey6uZhQ89FpqCkMhdDLP2g2OzCSCgAkDWiPQxw2rdiUjM08DiacPlD3HQ9RpUHJgNYpip0K7/19lM++Zyu6hKtTosePHa5i16ij+cygNwQGeePfVJ/DuyI6ICGloHpbUPNALfxndCX7erljyzTl8fyoddbH0oOH2ZWi/XwFJwyZw7fsWBJkcuQVa7PjxGtQPsYI9kb0zZKb+r883Nvd5IiJr4M/dRDYgSGSQBqoQlHsTJfoIpGcXITigZklSqcEEd7EYRgUns6hLrgoZpr3cDh+vP4Ul35zDh6M7wj0kCrLmHWHMTIXh6jGUpv0Mw9VjEBWeuK4MQ9wtP6TpGuCJUD/EPNXsvn/bBl5KvDvyCfw77iK2JKbi9h0NRvRqBZm0dn77Mub+Bm3CYkg8GsL1hRkwShVIPHEDuw6XzRL4ZOsAeLnxiyc5D9FkRMmhNRA8GsD1hXcgyF1tHRIRPUaYXBHZiDQwHMr08/AUtEi9ebfGydXdIh3qSTQodQ2powipnK+3Kya/FIFFW3/BF7sv4O0h7SCVSCALDIUsMBTq8Jfwy4+H4HY7Ca1LkjDN3QSTvx+UTbvCxdUPwP3/tkq5DJMHt8V/Dl7DdyduIDtfg0mDIuCmdHmkuE0FmdB+9zkEuStc+89CWp6IDQk/42ZOMdq39MGrvVVoWE/5SK9BZG8EiRRufd8GpC6QuHrZOhwieswwuSKyEVlQGPQA2nvdwZWbBejVqUmN9r+r1sBPKIGaMwVaRWjT+hjVNxTrvruEr/dfxYheKmTkFOHb47/hxMUsSCVe6N5uBEI7+MIzLxmlV49Bn7Qb+qRdkPi1gEvLpyBr0bnKL3sSQcCQ6JYIaOiGDQmX8UlsEqa90g7+9d0eKl5TUR403y4ERBHoNR2xh3Nw6JdbaOClwNTBbdFB5fsobweR3TFp7sJw9Thc2vaFxDvQ1uEQ0WOKyRWRjUh8ggG5K9pLc7Hh5l2IolijKYKL8u4gQADk9ZhcWUuPyCBk5BTj+1PpuJ5ZiKs3C6BwkaJvVFP06dwE3h6Ksop+PeAS1gOmojwYrh1HaWrZLH26Y5shbdK2LNFq1gGCTFHhNbq3C4KftyuW7ziPv60/hSmD2yK0af0axSmWFEH73SKIumJcDR+PdV9fR7HWgD5RTTCoe3Mo5Tz0k3MRdcXQ7l0EkzoHsmZPQPDys3VIRPSY4hmWyEYEiRTSgFA0zkpHQZEeOXe18KvBVQrt3bLZ5dzq+9RViFSJoc+2QFa+BldvFmBgt2bo1akJPFwrH74n8WgAeWQ/yCP7wZiXDkPqMZRePY6SG2cBFyVcWnWFPKIPJN4BFvuFNq2PD1/rhCXbz2HR1l8w+vlQdG8XVK34RL0Wmu/+AWNBNvYoXsT+g2qEBHnhnWGhaOrPyU/I+YgGHbQJi2G6mwnX56dDwsSKiGyIyRWRDcmCwqG88QvqCcVIvVlQo+SqtDAPAODqzeTKmqQSCaa93A4mUazRpBPSBk0g7dIE8s6vwHj7Ckqv/ITSSz+i9OIBSJu2g7xtX0iDws1XL/3ru+GD0R2xalcy1u69hMw8DV5+pgUk97m6KRpLofnvUhhyfsW64p64UlgPo/q2wDORQZBIuHAqOR/RZID2+xUwZl+F8rlJkDVuY+uQiOgxx+SKyIakQWEAgAi3HFxJv4tubat/n4CpOB8AIHGv2ZAxenQSiQAJHi5ZEQQJZEFhkAWFwdR5CEovHkDpxf3QfrsAkoZNII/oA1nLJy
FIXeCudMHbQyKx+fsr+O74DWTe0eCNAW2gkFdco0o0GZEdtxRu2SnYXNQNri07Yf6zLVHPo+LQQyJnYcy6BuPNC1A8/RpcQqJsHQ4REZMrIluSNGwCKNzRXn4H/7lZULN9tQUwQgLBlUO9HJXEzRuKTi9B3r4/Sq8eQ+n5fSg59CWEk9vh0vo5uLSOhszVC6P6hiKwoTu27k/Fp5uSMO3ldmjg9fssfwVFOvy6czlaaM9jn9gVPV56GW2a8148cn6ywFC4D/uUQwGJyG4wuSKyIUGQQBYYhqYZ15CZp4G6WA8v9+qtOSQrVaNE4g5vgWuBOzpBJoc87Bm4hPaAMeMC9Of3QZ+0E/pf4uDSsitc2vZB76gm8G/gii92X8DHG05h2svtEBzgiUNnMlB8dCuekSfjV5/uGDhwDFxkFa9sETkT3ek9kHgHwiUkiokVEdkVJldENiYNCofyehIaSAqRerMAHUOrN0W20lAEvTuvWjkTQRAgaxwBWeMIGO/eQun571F65QhKL/8IaaM2aN22L95/tQOW/CcZn206jcCG7mhVcBQD3JKhD+mBts+NrdGMk0SOSJ+cCP2pHXAJ7c6hgERkd/iTN5GNSYPCAQBh8myk3rxbrX20OgM8UQyjsl5dhkY2JPUOgrL7a/B49R+QR70CU34GtAn/QP1Dn+IvXYsR4qdEiOYXDHA7A1mLJ9HguTFMrMjplV4tW9ZAFtwBiu5jbB0OEVEFvHJFZGOS+kEQXL3QXpqL76qZXN0t0qGeRAOtm3cdR0e2Jig9oOgQA3m752FIOwn9+X0QTm7CnxXugIsG0ibtoIweD4HDQ8nJGW6cQ8mBNZAGhkL53J8hSDj8lYjsD5MrIhsTBAHSwDAE30jBb5mFKNEbHrjIa0G+Go0kpSj14KQFjwtBKoNLq66QtXwKxswrKD2/D4AI5bMTIUh4KCfnZ8y8AkmDxnDt+zYEWfXuTSUisjaekYnsgDQoDMq0k2goqJF2S43Wze6fNBXlly0grKzX0BrhkR0RBAGywFDIAkNtHQqRVYiiCEEQII96GfIOAyC4cHkBIrJfHEdCZAdk/7vvqpVLJq6kP3hooK6gbAFhtwZcQJiInJdJnQ3Njnkw5qVDEAQmVkRk93jlisgOCPUCILh5I1LIxcFqrHdVWliWXCl45YqInJRJcxeabxcCei0g8P4qInIMvHJFZAcEQYA0KBzNJLdx7dZdGIym+9YXNWVXtyTu9a0RHhGRVYm6Ymj3LoKoVcP1hRmQ1g+ydUhERNXC5IrITkiDwqA0FqO+MR/p2UX3r1tyF6WQAS6uVoqOiMg6RIMOmoR/wnT3Nlz7TIPUL8TWIRERVRuTKyI7UZP7rlxKC1Ei9eC6RkTkfEQRgosSymffhKxxG1tHQ0RUI0yuiOyE4OkLwaMhItxykHqf+65Mogg3UxFK5V5WjI6IqG6JogmiQQfBRQnXF96BS0iUrUMiIqoxJldEdqLsvqswhEgzcfVmPkRRrLRekbYUXoIGJmU9K0dIRFQ3RFGE7uhmaOI/K0uweFWeiBwUkysiOyILag2FSQsPXQ4y8zSV1rmrLoGXRMvJLIjIaehP70HphURIA1SAlAsEE5HjYnJFZEekQWEAgFay21UODVTfvQu5YISL5/0XGiYicgT65ETok3ZCpnoaii7DeNWKiBwakysiOyLxaAjByw/hymykVjGphSY/FwCg9OYaV0Tk2ErTTkJ3dCNkwR2g7DGWiRUROTwmV0R2RhYUhhayLKTezKt0u66grNy9ga81wyIiqnVSn2aQqbpB+dyfIUi4UDAROT4mV0R2RhoUDrmog7zwNvILdRW2G4vKkisOCyQiR2VSZ0MURUi8/ODacwIEGe+zIiLnwOSKyM5I71nvKvVmJUMDNWVlgpu3NcMiIqoVxjvpKN4xD/qknbYOhYio1jltcrVx40YMHjwYERERmDZtmq3DIao2iZs3hHqBCJVnVTqphVSvRomg5C+9RORwTO
psaPcuguCigEtoD1uHQ0RU65w2ufLz88OkSZMwdOhQW4dCVGOyRuFoIcvCtRt3KmxTlBZCJ/OwQVRERA/PpLkLzbcLIZoMcO03ExJPH1uHRERU65w2uerTpw969eqF+vW5FhA5HmlQGOQoBfJvQFNiMJcbjCa4iUUwyL1sGB0RUc2Iogna/y6BqFXD7YUZkNZvZOuQiIjqhMzWARQXF2Pt2rVITk5GcnIycnJy0LdvXyxdurTS+iaTCevWrcPXX3+NjIwM+Pr6YuDAgZg0aRIUCoWVoyeqG9LAsvWuWsoyce1WAdqGlE27frdQh3oSDQyuwbYMj4ioRgRBAkXUywAAqV8LG0dDRFR3bH7lKj8/H8uWLUNycjIiIiIeWH/+/Pn47LPPEBERgblz56Jnz55YvXo1pk+fboVoiaxD4uoFeDeC6g+TWtzJL4aXUAKpO6/IEpH9E00GGDIuAgBkjSMga/zg8zwRkSOz+ZUrPz8//Pjjj/D39wcAhIaGVlk3NTUVGzduxNChQ/Hxxx+by318fLB06VIcOnQIzzzzTJ3HTGQNLo1bI+TuAey/kQeg7Jfeu7k58BVEyL04DTsR2TdRNKHk4BoYrp6A25C/cSggET0WbH7lSi6XmxOrB4mPj4coihgzZoxF+ahRoyCTyRAfH18HERLZhjQoDC4wwJiThlKDCQBQlJsNAHCtzwWEich+iaII3dFNMFw9Dnnnl5lYEdFjw+bJVU0kJyfD09MTLVpYjtf28vJCSEgIkpOTzWUGgwE6nQ4GgwEmkwk6nQ56vd7aIRM9NFlgGEQIaC7cxm+ZhQAA7d1cAIBbg4a2DI2I6L70p/eg9MIPcGn3POSR/W0dDhGR1dh8WGBNZGdnV3mVKyAgAElJSebHq1atwvLly82P27Vrh86dOyM2NrZGr9mwIae8ri2+vp62DsHBeELrG4xWtzJxK1+Lpzo0Rqk6DwDg16QJZF58Px8V+yRR7TNmXYU+aSdkqqeh6DIMgiDYOiQiIqtxqORKq9XC07PyL0MKhQIlJSXmx1OnTsXUqVMf+TXv3CmCySQ+8vM87nx9PZGTU2jrMByOJCAUITmJ2JSSgR5tAyAW34UJAvK0Ugg6vp+Pgn3y0UkkAn+Aogqk/i2h7DMVsqbtmVgR0WPHoYYFurq6Vjm0T6fTQalUWjkiorolCwqHDEaUZl6FSRQhLSlAicQNgkRq69CIiCwY0s/DmHsdAODSrCOPU0T0WHKo5MrPzw9ZWVmVbsvMzKz2xBhEjkIaGAoREjQWM3ArtxgKgxo6GYeyEZF9MWamQrtvGXTHtkAUOdqDiB5fDpVcRUREoLCwENeuXbMoV6vVSEtLQ5s2bWwUGVHdEOSuEBs0RStZJi78mgcPsRhGRT1bh0VEZGa8kw5Nwj8heNSHstdkDgUkoseaQyVX/fr1gyAIWL9+vUV5bGwsDAYDBgwYYKPIiOqOoklrBMtycSYlA14SLQQ3b1uHREQEADCps6HduwiCiwJu/WaVLYBORPQYs4sJLTZu3Ai1Wm1+nJaWhpUrVwIAoqKiEBUVBaBsgeERI0Zg06ZN0Gg06NKlC1JSUrBlyxZER0dzAWFySrKgcMjO7oVL7hV4eOpg8OACwkRkH/Rnv4NoMsAt5n1IPH1sHQ4Rkc3ZRXL11VdfISMjw/w4NTUVS5YsAQBMmTLFnFwBwAcffICgoCBs27YNCQkJ8PHxwYQJEzB58mSrx01kDdKAVjBBgk7yXwEAinpMrojIPii6vQp52z6QeAfaOhQiIrtgF8nV/v37q11XKpVi/PjxGD9+fB1GRGQ/BBcljPWDEZF3AwDgXp+/DhM5gu3bt+PDDz8EAJw+fRru7u4W23U6HVauXIk9e/YgNzcXQUFBGDZsGMaMGQOJxHLUvslkwrp16/D1118jIyMDvr6+GDhwICZNmgSFQmG1NgGAaNBBd2wr5J1egsTVCwITKyIiM7tIrojo/tyCI1Ca/78rV94NbRwNET1IXl
4eFi1aBDc3N2g0mkrrvP322zh48CCGDx+O1q1b4/jx4/jss89w+/ZtfPDBBxZ158+fj9jYWMTExGD8+PG4ePEiVq9ejdTUVPMwemsQTQZov18BY/p5yJpGQhLc3mqvTUTkCJhcETkAWaNwlP4SBwCQunNYIJG9+/TTT+Hv74/Q0FDs2bOnwvZDhw5h//79mD59Ot58800AwJAhQ6BQKBAbG4uhQ4eiVatWAMqGym/cuBFDhw7Fxx9/bH4OHx8fLF26FIcOHbLKPceiaELJwTUwpp+DovsYyJhYERFV4FCzBRI9rqT+LSFKZIBUBijcH7wDEdnMsWPHEBcXh7lz50IqrXwh3bi4OMhkMowcOdKifMyYMRBFEd9++625LD4+HqIoYsyYMRZ1R40aBZlMhvj4+Fpvwx+Joog7+76C4epxyKNegTy8Z52/JhGRI2JyReQABJkcMv+WkHk25BoyRHZMr9dj3rx5eOmll9CxY8cq650/fx4tW7aEh4eHRblKpYKHhweSk5PNZcnJyfD09ESLFi0s6np5eSEkJMSibp3Ra6BNOwuXtn0hb9+/7l+PiMhBcVggkYNQdBuJeq4mFNk6ECKq0hdffIH8/HzMmjXrvvWys7MtZsK9l7+/P7Kysizq+vv7V1o3ICAASUlJNYqxYUOPB1eqwBPGsX+HROEKQXh8fpf19fW0dQhWxfY6N7bXOphcETkIaYPGcPX1RFFOoa1DIXJqarW6wmL1VXFzc8O4ceMAlK3RuHr1arz//vto0OD+90aWlJRALpdXuk2hUKCo6PefUbRaLTw9K/+SoFAoUFJSUq1Yy925UwSTSazRPkDZF5Wcx+j4w/Y6N7bXuT1seyUS4SF/gPodkysiIqJ7qNVqLF++vFp1fXx8zMnVvHnzoFKpMHz48Afup1QqodfrK92m0+mgVCrNj11dXatdl4iIbIvJFRER0T0aN26My5cv12ifffv24cSJE/j888+Rnp5uLi8uLgYApKenw9PTE40aNQIA+Pn5WQz9u1dWVhY6dOhgfuzn54ezZ89WWjczM7PKIYNERGR9TK6IiIge0a1btwAA77zzTqXbX3zxRTRq1Aj79+8HALRt2xYJCQkoKiqymNQiNTUVRUVFaNOmjbksIiIChw8fxrVr1ywmtVCr1UhLS8MLL7xQF00iIqKHwOSKiIjoEUVHRyMgIKBC+aZNm3Dy5EksWLAA3t7e5vKYmBjExcVh06ZNmDhxorl87dq1EAQB/fv/PiNfv3798K9//Qvr16/H//3f/5nLY2NjYTAYMGDAgDpqFRER1RSTKyIiokcUHByM4ODgCuUHDx4EAPTq1Qvu7r+vUdezZ09ER0dj8eLFyMrKQnh4OI4fP474+HiMHDkSKpXKXDc0NBQjRozApk2boNFo0KVLF6SkpGDLli2Ijo62ygLCRERUPUyuiIiIbGDx4sVYsWIF4uLisG3bNgQFBWHWrFkYO3ZshboffPABgoKCsG3bNiQkJMDHxwcTJkzA5MmTbRA5ERFVRRBFsebzsT5GHnbKWrL0uE0BWlf4PtYevpePrjamrCXr41Ts1cP2Oje217nZcir2x2clQCIiIiIiojrE5IqIiIiIiKgWMLkiIiIiIiKqBUyuiIiIiIiIagGTKyIiIiIiolrA5IqIiIiIiKgWcJ2rB5BIBFuH4DT4XtYOvo+1h+/lo+H755ge5e/2uP3N2V7nxvY6t4dpb228R1znioiIiIiIqBZwWCAREREREVEtYHJFRERERERUC5hcERERERER1QImV0RERERERLWAyRUREREREVEtYHJFRERERERUC5hcERERERER1QImV0RERERERLWAyRUREREREVEtYHJFRERERERUC2S2DsDemEwmrFu3Dl9//TUyMjLg6+uLgQMHYtKkSVAoFLYOzyHcvHkTzz33XKXbnn76aXz55ZdWjsgxFBcXY+3atUhOTkZycjJycnLQt29fLF26tNL6O3bswLp16/Drr7
+iXr166N27N6ZPnw4vLy8rR25favI+hoaGVvoczZs3R0JCQl2HatcuXLiAuLg4HD9+HDdv3oRUKkWzZs0wYsQIDBw4EIIgWNRnf3Q+27dvx4cffggAOH36NNzd3S2263Q6rFy5Env27EFubi6CgoIwbNgwjBkzBhKJ5W+39nhuPXjwILZs2YLLly8jLy8Prq6uCA4Oxp/+9Ce8+OKLFdpw+/ZtrFq1CkePHkV2djbq16+PyMhITJw4EW3atKnw/Pb2mahpewEgPT0dK1euxOHDh5Gfn4/69eujXbt2+Oijj+Dj42NR1xnaW+7o0aMYO3YsACAuLg4qlcpiu6P35xMnTmDv3r04deoUbt26BTc3N4SEhGDcuHHo2bNnhed29PYC1j1eMbn6g/nz5yM2NhYxMTEYP348Ll68iNWrVyM1NRUrV660dXgOpXfv3ujdu7dFmZ+fn42isX/5+flYtmwZfH19ERERgQMHDlRZd926dfj000/RvXt3jBw5Ejdu3MD69etx/vx5bN68GXK53IqR25eavI8A0KlTJwwdOtSizNPTsy5DdAhr1qzBsWPH0KdPHwwfPhw6nQ7fffcdZs+ejRMnTmD+/PnmuuyPzicvLw+LFi2Cm5sbNBpNpXXefvttHDx4EMOHD0fr1q1x/PhxfPbZZ7h9+zY++OADi7r2eG69cuUKXFxcMGzYMPj4+ECr1eLQoUN499138csvv+Cjjz4y171z5w5efvllGAwGDB8+HE2aNMHt27exdetWHDhwAFu3brVIsOzxM1GT9gLA2bNn8frrryMwMBCvvvoqfHx8kJeXhzNnzqCoqMgiuXKG9pbT6XSYN2/effu+o/fnzz//HNnZ2ejduzdUKhXUajV27NiBiRMnlBbpMAAAFkpJREFUYtq0aZg8ebLFczt6ewErH69EMrty5YoYGhoqfvjhhxbly5cvF1UqlXjw4EEbReZY0tPTRZVKJS5dutTWoTgUnU4nZmZmmh+rVCpx6tSpFerduXNHjIyMFF9//XXRZDKZy3fu3CmqVCpx06ZNVonXXlX3fSzfNmfOHGuF5lCSkpJEnU5nUWY0GsWRI0eKKpVKvHz5siiK7I/OaubMmeKAAQPEmTNniiqVSiwqKrLYfvDgQVGlUomrVq2yKH/vvffE0NBQ8cqVK+YyRzu3vvHGG2JYWJh49+5dc9natWtFlUolJiYmWtQ9deqUqFKpxL/97W/mMkf7TFTWXq1WK0ZHR4vjxo0T9Xr9ffd3hvbe65///KfYrVs3cf78+RbHunLO0J9PnjwpGgwGi3parVbs27ev2KZNG4u6ztBeax+veM/VPeLj4yGKIsaMGWNRPmrUKMhkMsTHx9smMAdWUlICrVZr6zAcglwuh7+//wPr/fDDD9BqtRg9erTF0KwBAwagYcOGj30/re77eC+9Xl/lL5SPqyeeeKLCr80SiQR9+vQBAKSmpgJgf3RGx44dQ1xcHObOnQupVFppnbi4OMhkMowcOdKifMyYMRBFEd9++625zNHOrYGBgTCZTCgqKjKXFRYWAgB8fX0t6paPxnB1dTWXOdpnorL2fvvtt8jIyMCsWbPg4uICrVaL0tLSSvd3hvaWu3btGtasWYM5c+bAw8Oj0v2doT9HRUVV+GwrlUr07NkTpaWl+PXXX83lztBeax+vmFzdIzk5GZ6enmjRooVFuZeXF0JCQpCcnGyjyBzTV199hcjISLRv3x7PPvssVq9eDaPRaOuwHN758+cBAB06dLAol0qlaNeuHS5evAhRFG0RmkNKSEhAZGQkOnTogKeffhoLFy5ESUmJrcOyW5mZmQCABg0aAGB/dDZ6vR7z5s3DSy+9hI4dO1ZZ7/z582jZsmWFL6AqlQoeHh4W50t7P7cWFRUhLy8Pv/32GzZv3owdO3agZcuWCAoKMtfp2rUrAODjjz/GqVOnkJWVhTNnzuD999+Hj48Phg0bZq5r75+J6rT3p59+goeHB9RqNV588UW0b9
8e7dq1w4gRI3Du3DmL53OG9pabN28eOnTogAEDBlT5fM7Qn6vyx+M74Bzttfbxivdc3SM7O7vKX7wDAgKQlJRk5Ygck0QiwZNPPonevXsjKCgIubm52L17Nz7//HNcvnwZn3/+ua1DdGjZ2dlwdXWt9CbhgIAAaLVaFBQUwNvb2wbROZbIyEg8//zzaNq0KdRqNb7//nusWbMGZ8+exbp16yCT8RB5r+zsbGzbtg2NGjUyf/Fmf3QuX3zxBfLz8zFr1qz71svOzkZUVFSl2/z9/ZGVlWVR157PrW+99RYOHz4MABAEAV27dsVHH31kcRWmY8eOmDdvHhYvXoxXX33VXB4WFmb+TJSz989Eddp7/fp1GI1GTJgwAc8//zwmTZqEjIwMrFq1CqNHj8b27dvRqlUrAM7RXgD4z3/+gzNnzmDXrl33fT5n6M+VuXTpEhITE9G+fXs0bdrUXO4M7bX28YrfHO6h1WqrvJFdoVDw1+xqCgoKwvr16y3KhgwZgilTpiA+Ph7Dhw+vspPTg2m12ipvDi6fxYZ9tXq2bdtm8Xjw4MH45JNPsGHDBsTHx2PQoEE2isz+6PV6vPXWWygqKsLSpUvNfZD90f6o1eoKx+CquLm5Ydy4cQCAtLQ0rF69Gu+//77FL9eVKSkpue/f/d4hOXV9bn3Y9pabOXMmxo0bh+zsbOzfvx/5+fmVDhP28fFBq1at0LVrV4SGhiI9PR1r1qzBuHHjsGHDBvMQwbr+TFijvcXFxdBqtRgwYAD+/ve/m8vbtGmD0aNHY8WKFVi8eDEA52hvXl4eFixYgNdeew0tW7a872s4S3++V0FBAd566y3IZDL87W9/s9jmDO219vGKydU9XF1dodfrK92m0+mgVCqtHJHzEAQBEydORGJiIn766ScmV4/gQf0UAPvqI3jzzTexYcMG/PTTT0yu/sdgMOCtt97CmTNn8PHHH+Opp54yb2N/tD9qtRrLly+vVl0fHx/zl5V58+ZBpVJh+PDhD9xPqVRW+3xZ1+fWh21vufDwcPP/Dxo0CH/9618xatQoJCQkmJPMffv2YerUqfjyyy/x9NNPm+t369YNgwYNwpIlS/DJJ58AqPvPhDXaWx7f4MGDLfbt0qULgoKCcPLkSXOZM7R34cKFUCqVFWbJq4wz9Od7FRUVYcKECcjIyMCKFSvMVyTLOUN7rX28YnJ1Dz8/P5w9e7bSbZmZmTW+SZ4slQ+byM/Pt3Ekjs3Pzw9arRZqtbrCMIzMzEy4urqiXr16NorO8TVs2BBKpZL99H+MRiPeeecd7N+/Hx9++CGGDBlisZ390f40btwYly9frtE++/btw4kTJ/D5558jPT3dXF5cXAygbL0jT09P83Hcz8/PYijNvbKysizuv6nrc+vDtPd+YmJi8PXXXyMxMdG8TMOGDRvg7u5ukVgBQKtWrRASEoKff/7ZXFbXnwlrtNfPzw9XrlypsJYVUDapx8WLF82PHb29ycnJ2LFjB2bNmoWcnBxzvYKCAgBl65spFAoEBwcDcI7+XE6j0WDixIm4cOECFi9ejGeeeabC/s7QXmsfrzihxT0iIiJQWFiIa9euWZSr1WqkpaVVukggVd9vv/0GoOzLKz28tm3bAgDOnDljUW4ymXD+/HmEh4c/cGw1VS0rKwslJSXspyjrU7Nnz0ZCQgLmzJmDUaNGVajD/ugcbt26BQB455130KdPH/O/ffv2AQBefPFFi79/27Ztce3atQozrqWmpqKoqMjifOlo59byYT/lX64BICcnB6IoVjoxg8FggMFgMD92tM9EZe1t164dgN8nOLhXZmamxRUQR29veRsXLlxo0fdjY2MBAG+88YZ5llTAOfpzefnEiRNx5swZLFy4sMK6pOWcob3WPl4xubpHv379IAhChbGfsbGxMBgM9509hn5X2S/+paWl5su+0dHR1g7JqTz33HNQKpXYsGGDRXn5quMxMTE2isyxVNZPRVE030fwuPdTk8mE99
57D/Hx8ZgxYwZef/31SuuxPzqH6OhoLFmypMK/zp07AwAWLFiAuXPnmuvHxMSgtLQUmzZtsnietWvXQhAE9O/f31xmr+fW3NzcCmWiKGLLli0Ayia8KdeyZUtoNBr897//tah/9uxZXL9+HREREeYye/1M1KS9MTExkEgk2Lp1q0X9/fv3IysrCz169DCXOXp727ZtW2nff/755wEA7733HpYsWWJ+DmfozzqdDn/+859x6tQp/P3vf0e/fv2qfF5naK+1j1ccFniP0NBQjBgxAps2bYJGo0GXLl2QkpKCLVu2IDo6utLLpVTRX/7yF2g0GrRv3x4BAQHIzc3F3r17kZqaihEjRlh0eLK0ceNGqNVq8+O0tDTzauBRUVGIiopCgwYNMG3aNCxYsAATJkxAnz59cOPGDaxbtw5t2rSpMGzrcVSd93HVqlU4d+6c+R6CgoIC7N+/H2fOnEF0dLT5xPq4WrBgAXbt2oW2bdsiICAAu3fvttj+xBNPoEmTJuyPTiI4ONg87OleBw8eBAD06tUL7u7u5vKePXsiOjoaixcvRlZWFsLDw3H8+HHEx8dj5MiRUKlU5rr2em6NiYlBVFQUWrduDR8fH+Tm5iIhIQGXLl1CTEyMObEEgIkTJ+LHH3/EzJkz8fPPP0OlUiE9PR2bN2+GQqHApEmTzHXt9TNRk/a2aNECY8eOxZdffokJEyagZ8+euHXrFjZu3AgfHx9MmTLFXNfR2+vv71/p8b58Lb+uXbs6XX+eOXMmjh49imeeeQaiKFY4vnfr1s08JNQZ2mvt45UgcgESC0ajEWvXrsW2bdtw69Yt+Pj4YODAgZg8ebJ51hu6v+3bt2P37t1IS0uDWq2GQqFAaGgohg4dygkCHuDZZ59FRkZGpdumTJmCqVOnmh9/8803WL9+Pa5fvw4vLy/06tULM2bM4P0tqN77mJiYiM2bNyM1NRX5+flwcXFB8+bNMWjQILz66qtVLp76uBg1apTFTet/9Omnn1rc7M7+6Jzeffdd7Ny5E6dPn7ZIroCy4TcrVqxAXFwccnNzERQUhKFDh2Ls2LEVPj/2eG5dvnw5Dh8+jOvXr6OwsBBubm4IDQ3FoEGDMHjwYEgkloN7rl69ipUrV+Ls2bPIysqCh4cHoqKiMHnyZISFhVV4fnv7TNS0vaIoYvPmzdiyZQuuX79uvudsxowZFlPPl3P09v7RsmXLsHz5csTFxVl8+QYcvz/f7xwJlN1j2KVLF/NjR28vYN3jFZMrIiIiIiKiWsB7roiIiIiIiGoBkysiIiIiIqJawOSKiIiIiIioFjC5IiIiIiIiqgVMroiIiIiIiGoBkysiIiIiIqJawOSKiIiIiIioFjC5InpMLVu2DKGhoThx4oStQyEiIiJyCjJbB0DkqEJDQx9Y54+rnBMRERGR82JyRfSIpkyZUuW2Ro0aWTESIiKiuvX666/jyJEj960zbdo0TJ482UoREdkXJldEj2jq1Km2DoGIiMgqLly4AJlMhjfffLPKOn379rViRET2hckVkZUsW7YMy5cvx4YNG3Dr1i2sX78eaWlpcHd3R8+ePTFjxgz4+vpW2O/69etYuXIljh07hvz8fHh7e6Nr166YNGkSmjVrVqG+0WjEtm3bsHv3bqSmpqK0tBT+/v7o3LkzJkyYUOk+CQkJWLNmDVJTU6FQKNCtWze8++678Pf3r4N3goiIHFF6ejru3r2L1q1b84dFoiowuSKysnXr1uHIkSPo168funfvjqSkJOzYsQMnT57E9u3b0aBBA3Pdc+fOYezYsSguLsazzz6Lli1bIi0tDXv27MEPP/yAtWvXol27dub6er0eb775Jo4cOYLAwEDExMTAw8MDGRkZSExMRMeOHSskV5s3b8b+/fvx7LPPIioqCufOncPevXtx6dIl7N69G3K53FpvDRER2bHz588DANq2bWvjSIjsF5Mroke0bNmySssVCgXeeOONCuU//fQTtm3bhtatW5vL5s
+fj/Xr12PRokWYP38+AEAURcyZMwdFRUVYuHAhBg4caK6/d+9eTJ8+HbNnz8bevXshkZRN/Ll8+XIcOXIE0dHRWLp0qUVipNfrUVRUVGk833zzjcUEHe+88w7i4+ORmJiIfv361fAdISIiZ5ScnAyAyRXR/TC5InpEy5cvr7Tc09Oz0uRq4MCBFokVUHbf1o4dOxAfH4958+ZBLpfj9OnTSEtLQ4cOHSwSKwDo168fNm7ciKSkJCQlJSEqKgpGoxGbN2+GUqnERx99VOGKk1wut7gqVm7UqFEVZj4cMmQI4uPjcf78eSZXREQE4Pfk6vTp08jMzKy0zrhx4+Dm5mbNsIjsCpMrokd0+fLlGtXv3LlzhTJPT0+Eh4fj5MmTuHbtGsLDw3Hx4kUAqHIq9yeffBJJSUm4ePEioqKikJaWhsLCQkRGRtboXqnKfoEMDAwEABQUFFT7eYiIyHmJomg+L+3YsaPSOt7e3hb3Yi1atAgXLlzA2rVrq/06//znP3H69GnExsY+dKwP87pEtYXJFZGVNWzYsNJyHx8fAEBhYaHFf/38/CqtXz75RXk9tVoNADWehMLT07NCmVQqBQCYTKYaPRcRETmn69evo7CwEE888QS2bNlSrX1SUlIQFhZWo9dJSUlBeHj4w4T4SK9LVFsktg6A6HFz586dSstzc3MB/J7slP83Jyen0vrl5R4eHgAALy8vAEBWVlbtBUtERITfhwTWJPG5dOlShWHwD5KSklLjfWrjdYlqC5MrIis7efJkhbLCwkKkpKRAoVCgRYsWAH4/gVVWHwBOnDgBAGjTpg0AICQkBF5eXrh8+TITLCIiqlXlyVV1rwjl5OQgNzfXIhlbuXIlBgwYgA4dOuDJJ5/Eu+++i5KSEvP2O3fuIDs7GxKJBK+99hoiIyPx4osv4ty5c+Y6WVlZmD17Nrp06YJOnTph6tSp5h8nq3pdImtickVkZXv27DGPWy+3bNkyFBYWon///uaJKDp27IjmzZsjKSkJCQkJFvUTEhJw6tQpNGvWDB07dgRQNpRvxIgRKCkpwdy5c6HX6y320ev1yMvLq8OWERGRs6ppcpWSkgKlUonmzZuby4xGI+bNm4f4+Hj84x//wJEjR7B+/XqLfQBg7dq1mDx5Mnbu3ImAgAC89dZbMBgMSE9Px0svvQR/f39s3rwZsbGxyM/Px9y5c+/7ukTWxHuuiB5RVVOxA0CvXr0q/HrWvXt3/OlPf8ILL7wAX19f84x/jRo1wsyZM831BEHAZ599hrFjx2L69OmIj49HSEgIfv31VyQmJsLd3R0LFiwwT8MOAJMnT8bZs2dx4MAB9O3bFz179oS7uztu376NI0eOYPbs2Rg8eHDtvwlEROS0TCYTLl68CKlUCpVKVa19Ll26BJVKZb6HF4DFZBeNGjVCz549kZaWZi5LSUmBi4sLli1bhsaNGwMAZs2ahf79++PGjRv45JNP8Morr2DGjBnmfSZNmoQpU6bc93WJrInJFdEjqmoqdqDs5PHH5GrMmDHo3bs31q9fj71798LNzQ2DBw/G9OnTK0x2ERkZiW+++QarVq3CsWPHcODAAdSvXx/9+/fHpEmTEBISYlFfLpdjzZo12Lp1K3bt2oVdu3ZBFEX4+fmhd+/e5qtcRERE1ZWWlgaNRgMPDw/8+9//rrLe0KFDzZMq/XFiitu3b+PLL7/EiRMnkJWVhdLSUuj1eowfP95cJyUlBb179zYnVsDv9xPn5ubi8OHDOHXqlMVMgkajEa6urhbPwSGBZEtMrogeUk2nYL/X4MGDq30FKSQkBAsXLqz2c8tkMowcORIjR468b72pU6da/Ip4r8aNGz9S+4iIyHmUDwksKiqq8gdFiUSCsWPHmh+npKRg9OjRAID8/Hy88sor6NSpE2bPno2AgABIJBK88sorFsMMU1JSKpwbz549Czc3NxQWFsLDw6PSaeBdXFwqfV0iW2ByRURERERVGjRoEAYNGlTt+lqtFr
/99pt5xr5Dhw5Bp9Nh8eLFEAQBALBz505oNBrzVSatVovr169bLAEiiiK++uorDBw4EDKZDFqtFj4+PnB3d6/W6xLZApMrIiIiIqo15SMfQkNDAZQtLqzRaJCYmAiVSoUff/wR//rXv+Du7o7g4GDzPhKJBLt27UKXLl1Qv359LFu2DLdv38aKFSsglUrh5eWF2bNnY/LkyfDw8EB6ejoSExPxl7/8BRKJpMLrEtkCkysiIiIiqjUpKSkIDg423wv1zDPPYNiwYZgzZw4UCgX69++PAQMG4JdffjFfyUpJSUGTJk0wffp0vP3228jPz0ePHj2wbds2NGjQAADw73//GwsXLsTo0aNhNBrRpEkT9O/f3zyx0x9fl8gWBFEURVsHQURERERE5Oi4zhUREREREVEtYHJFRERERERUC5hcERERERER1QImV0RERERERLWAyRUREREREVEtYHJFRERERERUC5hcERERERER1QImV0RERERERLXg/wEqKUL0b2dnQgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 864x432 with 2 Axes>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MV-qtMJsD6s7"
},
"source": [
"While we see that the network has begun to learn the energies, we also see that it has a long way to go before the predictions get good enough to use in a simulation. As such we're going to take inspiration from cooking shows, and take a ready-made GNN out of the fridge where it has been training overnight for 12,000 epochs on a V100 GPU."
]
},
{
"cell_type": "code",
"metadata": {
"id": "e89weTfODSDz"
},
"source": [
"with open('si_gnn.pickle', 'rb') as f:\n",
" params = pickle.load(f)"
],
"execution_count": 20,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "KiYDpSSVGR0l"
},
"source": [
"Using our trained model we plot the predicted energies and forces against the labels. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "4j0Zz0RDEUVF",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 441
},
"outputId": "e78aa0c4-8460-453a-adb0-fa86a10d4d42"
},
"source": [
"plt.subplot(1, 2, 1)\n",
"\n",
"predicted_energies = vectorized_energy_fn(params, example_positions)\n",
"plt.plot(example_energies, predicted_energies, 'o')\n",
"\n",
"format_plot('$E_{label}$', '$E_{predicted}$')\n",
"plt.subplot(1, 2, 2)\n",
"\n",
"predicted_forces = vectorized_force_fn(params, test_positions[:300])\n",
"plt.plot(test_forces[:300].reshape((-1,)),\n",
" predicted_forces.reshape((-1,)), \n",
" 'o')\n",
"plt.plot(np.linspace(-6, 6, 20), np.linspace(-6, 6, 20), '--')\n",
"plt.xlim([-5, 5])\n",
"plt.ylim([-5, 5])\n",
"\n",
"format_plot('$F_{label}$', '$F_{predicted}$')\n",
"finalize_plot((2, 1))"
],
"execution_count": 21,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1gAAAGoCAYAAABbkkSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeXxU5dk//s+ZJZPJxjIJWQARgoQkrCIqiwbEoPQJwjcta2WRtUWswlMUvyqK2tal/r7WBVQEQZQApagVIZVFosSo1YJANjAIsg2QRZIhM5nt/P5IZ8xklswkZ5ZMPu/X63k98cyZc+45DXPnuu/rvm5BFEURRERERERE1GayYDeAiIiIiIgoXDDAIiIiIiIikggDLCIiIiIiIokwwCIiIiIiIpIIAywiIiIiIiKJMMAiIiIiIiKSiCLYDeioamquwWr1b4V8jSYGVVU6v94j3PEZth2foTT4HF2TyQR06RId7Ga0K4Hof1rC32dHfB6O+Dwc8Xk4C+YzMZYcgOn73VANzEZi1q9dnsMAK0isVjEgHVywO9FwwGfYdnyG0uBzJCkEqv/xph30Cz4PR3wejvg8nAXtmXTuAaH7IMgzs92ewgCLiIiIiIjIA0vNBci7pECRdAMUSTdAEAS353INFhERERERkRumk1+i/u+PwXTq316dzwCLiIiIiIjIBdOP38Fw8G3IU/pDcd1gr97DAIuIiIiIiKgZ89mjMOxfA1m3PlDf9SAERYRX72OARURERERE1IT1Wg30n74GWZfuiLp7GQRlpNfvZZELIiIiIiKiJmTRXRCZdR/k3TMhqHzbDoQBFhEREREREQBL1VmIJj0USf2g7DuiVddgiiAREREREXV4lp8vQP/JC2go2ADRamn1dRhgERERERFRh2atvQL9Jy8CggD1XQ9BkMlbfS0GWERERERE1GFZr9Wg/pMXIJqNUP/PCsg6J7XpegywiIiIiIiowzId3wvRUIeoCf8Ledeebb4ei1wQEVHQFBVrsbOgAlW1DdDEqZCblYoRmW0bOSQiIvJFxPDfQNFvNORdUiS5HmewiIgoKIqKtdi0pwxVtQ0AgKraBmzaU4aiYm2QW0ZEROFONOqhP/AmrLpqCDKZZMEVwACLiIiCZGdBBYxmq8Mxo9mKnQUVQWoRERF1BKLZCP2//gZzxdewVp+V/PoMsIiIKChsM1feHiciImor0WKGfu9rsFwsR+TYhVBcN1jyezDAIiKioNDEqXw6TkRE1Bai1QLDgTdgOXsUqtvntnoj4ZYwwCIioqDIzUpFhMKxG4pQyJCblRqkFhERUVgz6mGtvQTViBmI6J/lt9uwiiAREQWFrVogqwgSEZE/iaIIiFYIkTGImrwKglzp1/uFXYB18OBB5OXloby8HNXV1VCr1ejVqxdmzJiBSZMmQSZzP2n35Zdf4r777gMAfPzxx+jXr5/9tXPnzmHcuHEu3zd69GisX79e2g9CRNQBjMhMYkBFRER+I4oiGr7eBuvPF6HOfsDvwRUQhgHWiRMnoFQqMW3aNMTHx0Ov16OgoAArV67EkSNHsHr1apfva2howFNPPYWoqCjU19e7vX52djays7MdjnXr1k3Sz0BERERERG1n/M9HMB3NhzJjHCCTB+SeYRdgLVq0yOnY7NmzsXjxYmzfvh3Lly9Hp06dnM5Zu3Yt6uvrMXXqVGzcuNHt9dPS0jBp0iQpm0xERERERBIzfr8Hxu8+hKLfaKhG/RaCIATkvh2myEVycjKsVit0Op3TaxUVFXj77bfxyCOPICYmpsVrGQwG6PV6fzSTiIiIiIjayFh6EA1fb4Oiz82IvH0eBCFwYU/YzWDZ6HQ6GI1G1NXVobCwEDt37kTfvn2RkuK8S/NTTz2FoUOHYuLEiXj11Vc9XnfDhg147bXXAADdu3fH9OnTMX/+fMjlgZlyJCIKNUXFWhaqICKikCKP7wXFDaMQmX
UfBA81GPwhbAOsBx98EIcOHQIACIKAkSNHYvXq1U5Tg//4xz9w+PBhfPjhhx6vJ5PJcOuttyI7OxspKSmorKzERx99hJdeegnl5eV46aWXfGqfRtPyTJkUEhJiA3KfcMZn2HZ8htIIxed48LuzeDe/HA0mC4DGTYLfzS9HXGwkxgzrGeTWERFRR2O9egmyTomQJ/SGeuzCoLQhZAOs2tpabNq0yatzo6KiMH/+fIdjf/zjHzF//nxcvnwZBw4cQE1NjVPxiurqarzwwguYM2cO+vbt6/EeKSkpTu2ZMmUKli5dil27dmH69OkYPny4V+0FgKoqHaxW0evzWyMhIRZXrtT59R7hjs+w7fgMpREKz9HVTNXOggp7cGXTYLJg465iZF7X2e9tksmEgA1YERFRaDP/dBT6T/+GyNFzoOx/e9DaEdIBli0VryXx8fFOAVZ6err958mTJ2PVqlWYNWsW8vPz0bVrVwDAiy++iMjISNx///2taqMgCFi8eDH27duHL774wqcAi4ioPSkq1mLTnjIYzVYAjTNVTf+7uarahkA2j4iIOjjzhVLo974KWdceUPQeFtS2hGyA1aNHD5SXl0t2vZycHGzbtg379u3D1KlTcfz4cezcuRMrVqzAlStX7OddvXoVAHDx4kWoVCr06tXL43W7d+8OAKipqZGsrUREoWZnQYVTMGU0WyETAFeT8Zo4VYBaRkREHZ3l0g/Q578MWVwC1L/6IwRVdFDbE7IBltQMBgOAXwIorVYLoHEW68UXX3Q631buvaUg78yZMwAAjUYjWVuJiEKNuxkpqwhEKGQOwVeEQobcrNRANY2IiDowseEa6vP/H4SoTlD/z8OQRQZ/vXLYBViVlZWIj493OCaKIvLy8gAAgwcPBgAMHDgQf/vb35zev2fPHuTn5+PRRx9FUtIvVbBqamrQpUsXh3NNJpM9jXHs2LGSfg4iolCiiVO5DLKarsViFcHAqKiowKRJk2AymfDGG2+w/yGiDk1QRSNy9GzIu6VCFuX/tb/eCLsAKycnB8OHD0dGRgbi4+NRWVmJ/Px8lJWVIScnBzfffDMAIDExEXfffbfT+0+ePAkAGDlyJPr162c//sQTT6C+vh5DhgxBUlISKisrsXv3bpw8eRIzZ860B25EROFoUKoGnx2+4HDMNlM1IjOJAVWAiKKIVatWQalUwmQyBbs5RERBY629DGtdJRTdM6BMvSXYzXEQdgHWvffei0OHDmHTpk2oq6tDVFQU0tLS8Kc//Qm5ubmtvm5WVhY++ugjbN26FbW1tVCpVEhLS8Pzzz+PyZMnS/gJiIhCQ9Oqga6MGsjAKtB27NiB4uJiLFiwoMV9G4mIwpW5tgr1n7wAWMyInv4CBEVEsJvkIOwCrKVLl2Lp0qWtfv8DDzyABx54wOn4lClTMGXKlLY0jYio3WheNdCVoxVVAWwRVVdX469//St+97vfoVu3bsFuDhFRUFjrr+LijuchGnSIynkk5IIrAAjstsZERNQuuKoa2BxLsQfW888/j86dO2PevHnBbgoRUVCIBh30u/8Kc20l1BOWQ57QO9hNcinsZrCIiKjtvAmeWIo9cL766it8+OGH2LBhAyIi2jZaGyobMyckBL/SVyjh83DE5+GoIzyPg9+dxbt7SlFZo0d8FzVmT0jHmGE9Hc6pKdwL69WLSJr6fxHVJ3TrHzDAIiIiJ+6qBtqwFHvgGI1GPPnkk5gwYQJGjRrV5utVVelgdbV5WQAlJMTiypW6oLYhlPB5OOLzcNQRnkfztPQrNXq8uv0IausMDmt9xRvGIUqThqg+mUF/JjKZ4HbAiimCRETkJDcrFREK112EJk6FORP6s8BFgLz11lu4fPkyHn300WA3hYjIL9xtZr+zoAKixQTD5xthrb0CQZBBrrkuSK30HmewiIjIiS144v5WwXX58mW8+eabmD59OgwGg31z+6qqxgIjV65cwZkzZ9C9e3coFOzSiah9cpcxUVOrh2
H/GzCf/g7y5H6QxSUEuGWtw29jIiJyiftbBV9VVRWMRiPeffddvPvuu06vP/HEEwCA/fv3o0ePHoFuHhGRJFylpQsQcV/nr2A+/QNUI38L5Q0jg9Q63zHAIiIiClE9evTA3/72N6fj33zzDd5//30sWrQImZmZ0Gg0QWgdEZE0crNSm20NImJ6zNcYJPsBEcN/g4gB2UFtn68YYBEREYWo2NhY3H333U7H6+vrAQA33ngjxo4dG+hmERFJqnlaelKcHIM61yOibw5UQ3OC3DrfMcAiIiIiIqKgsqWli1YLBJkconkkIFcGu1mtwgCLiKgDKSrWsnBFGMjNzUVubm6wm0FEJKmGI7thOV8M9V0PQlC0bc+/YGKZdiKiDsK2z4htIXFVbQM27SlDUbE2yC0jIqKOzli8H8ZvtkOIjAFk7XsOiAEWEVEH4WmfESIiomAxnTiEhsLNkF83BJFjF0KQte8QpX23noiIvOZunxF3x4mIiPzN9OO3MBSsh7x7JtR3LoHQzmevAAZYREQdhiZO5dNxIiIif5PFJULR60aox/+hXa+7aooBFhFRB5GblYoIhePXfoRChtys1CC1iIiIOipr7RWIogi5pifU4x+AoAyfwT4GWEREHcSIzCTMmdDfPmOliVNhzoT+rCJIREQBZdGexLUdj8N0/NNgN8Uv2n+SIxEReV1+3bbPCBERUTBYKk+jPv//gxDdGYrUW4LdHL9ggEVE1M7Zyq/bKgTayq8DYDBFREQhw1J9HvpP/gohIgpR//MwZFGdg90kv2CKIBFRO8fy60REFOpEsxH6PS8BMnljcBWjCXaT/IYzWERE7RzLrxMRUagTFBFQjZgBWecUyDolBrs5fsUZLCKids5dmXWZ0Jg+SEREFCzW+p9hPnccAKDsMxzyrt2D3CL/Y4BFRNTOuSq/DgBWEdi0p4xBFhERBYVo0EH/yV+h37cGorE+2M0JGAZYRETtnK38ukxwfo1rsYiIKBhEox71e16CtVYLdfZSCBFRwW5SwHANFhFRO+OuJPu6j0tcns+1WEREFEiiqQH6/P8Ha+VPUI9/AIruGcFuUkAxwCIiakc8lWTXxKlcBlPu1mgRERH5g+nEIVgunUTkHb+HoteQYDcn4JgiSETUjngqye5qLVaEQobcrNRANpGIiDo4ZcYdiJr0BJSpNwe7KUHBAIuIqB3xVJLdthbLNmOliVNhzoT+3GyYiIj8TrRaYfhyC6w/ayEIAuTd+gS7SUHDFEEionakpTTAEZlJDKiIiCigRNGKhi82wlT+OWRxCYjo3LH7Ic5gERG1I0wDJCKiUCKKIhqK8mAq/xwRQyciYkB2sJsUdGE3g3Xw4EHk5eWhvLwc1dXVUKvV6NWrF2bMmIFJkyZBJvvlD5Ovv/4as2fPdnmdadOm4emnn3Y4ZrVasXHjRmzbtg3nz59HQkIC7rnnHixZsgQqFReRE5F/FBVr8eGhIlyp0UMTp8KogUk4WlHlVEWQiIgo0Izf7oTp+F4oB4xHxE25wW5OSAi7AOvEiRNQKpWYNm0a4uPjodfrUVBQgJUrV+LIkSNYvXq103umTZuGYcOGORzr3bu303l//vOfsXnzZuTk5GDBggUoKSnBW2+9hZMnT2LNmjV++0xE1HG5qhpYeEzLtVVERBR0osUEy8VyKPvfDtWIGRAEFxsydkBhF2AtWrTI6djs2bOxePFibN++HcuXL0enTp0cXh8yZAgmTZrk8bonT57Ee++9h6lTp+KZZ56xH4+Pj8crr7yCgoICZGVlSfMhiKjDKyrWYsveclwzWJxes1UNZIBFRERScbfHojuiaIUgV0L9q/8FZEoGV010mDVYycnJsFqt0Ol0Ll+vr6+H0Wh0+/5du3ZBFEXMnTvX4fisWbOgUCiwa9cuKZtLRB1YUbEWG3aVuAyubLh5MBERScWWLWHrW2x7LBYVa12ebyr7HPpdL0A06iEoVBBkHSak8ErYzWDZ6HQ6GI1G1NXVob
CwEDt37kTfvn2RkpLidO6zzz6LRx99FACQmpqK++67D1OmTHE45/jx44iNjUVqquNC8ri4OPTp0wfHjx/334chorDlasRwZ0EFLKLn93HzYCIi8pW7WSpPeyw2n8Uy/fAVDJ+/A3mPTEAetqFEm4TtU3nwwQdx6NAhAIAgCBg5ciRWr17tMH2pUChwxx13ICsrC926dcPFixexdetWPP744zh37hyWLVtmP/fy5ctITEx0ea+kpCR89913PrVPo4lpxafyXUJCbEDuE874DNuOz9C1g9+dxbv55WgwNc5UVdU2YMMnpbBYW4iuAMzNyeRzJSIir7la07tpT5n9Z1eaHzefPgzDZ+sgT7oB6vEPQJAr/dvodipkA6za2lps2rTJq3OjoqIwf/58h2N//OMfMX/+fFy+fBkHDhxATU0N6uvrHc4ZNmyYU3GLqVOnYvr06Vi3bh1+85vfoGfPngAAvV6P2FjXf8yoVCoYDAZvPxoAoKpKB6sXf0S1RUJCLK5cqfPrPcIdn2Hb8Rm6t3FXsT24svEmuIqOlCPzus58rgBkMiFgA1ZERO2Zp1kqd3ssyoTGwGxEZhLM50ug3/c6ZPHXQX33MggKZlK4E9IB1muvvebVufHx8U4BVnp6uv3nyZMnY9WqVZg1axby8/PRtWtXt9dSKpWYP38+li1bhi+//BLTpk0DAKjVardrtBoaGhAZGelVW4mImqZo+EouADOz0/zQKiIiagtfi0QEmrezVE1ZRdhnuW7u3hny7hlQj10EIULtlzaGi5ANsHr06IHy8nLJrpeTk4Nt27Zh3759mDp1aov3BoCamhr7sW7duuH77793eb5Wq3WbPkhE1FTzFA1fzcvJCKkOm4iIPKffhcp3trtZqpaoLXXYWfADRiwZjagJy/3QsvDTYUp+2FL4rl692uK5Z86cAQBoNBr7sQEDBqCurg4VFRUO59bW1uLUqVPIzMyUsLVEFK5cpWh4K0atwM6CCsx77gBWrCl0W92JiIgCy1P6XajIzUpFhMK3P/2T5TV4pNPHGGb81k+tCk9hF2BVVlY6HRNFEXl5eQCAwYMH2483naGyqa+vx5tvvgmlUonRo0fbj//qV7+CIAhO68I2b94Ms9mMiRMnSvURiCiMtTR6KHOzj4hMJkBvMHtdQpeIiAKnNel3gTYiMwlzJvT3ugptgqwWv4/dBxPkOGzqzf7GByGbIthaOTk5GD58ODIyMhAfH4/Kykrk5+ejrKwMOTk5uPnmm+3nLly4EImJicjIyLBXEfzggw9w4cIFrFixAsnJyfZz09LSMHPmTLz//vuor6/HLbfcgtLSUuTl5WHs2LHcZJiIWuRN5zQ/Jx15+05Apzfbj8WoFRAEAXX1JodzueEwEVFocJd+F2pbaozITLL3GSvWFLoNALvIdLg/9lPIYcUrtXfhijU25FIeQ1nYBVj33nsvDh06hE2bNqGurg5RUVFIS0vDn/70J+Tm5jqcO378eOzfvx+bN29GXV0doqOjMWDAADz55JMYM2aM07Ufe+wxpKSkYPv27cjPz0d8fDwWLlyI+++/P0CfjojaK1t+vieaOJVD59fU/OcOuHxPKI2OEhF1VLlZqU7rayMUMuRmpXp4V3DlZqVi3cclTsdlsOJ3sfuhEkx4rW48Llk7A+Cgni/CLsBaunQpli5d6tW5ixYtwqJFi7y+tlwux4IFC7BgwYLWNo+IOhhfKgYOStW4fS2+ixpXavROx0NtdJSIqCOyBR2BqiLoz4qFVsjwcf2N0ImROG9x7Jc4qOedsAuwiIhCha8VAwuPadG3R2eXneTsCel4dfuRdjU6SkTUkbjLQJCaVBULmxfgiBSMuF5xBWWm7jhu6unyPRzU807YFbkgIgoVvlYM9FRxasywng6LkzVxKsyZ0J+pGkREHYxUFQubzkZFwITFMfsxP+YgYgXnbAmAg3q+4AwWEZGE2rKJMOA5/SJQo6NERBS6WluxsGn/FKP+JQRQwI
IFsZ/hekUlNupuR53YuIlwdKQckRGKkN04OZQxwCIikkhbNxEG4NDpERERNedtxcLmAZXeYIZFbHzNVqlWBivuiylAmlKL93Sj8L2pF4DG2aqZ2WkMqFqJKYJERBJpyybCNnqDmXuNEBGRW642DG6evmcb8LMFYjr9L8FVU0MjTmNAxDlsv3YL/m1sfL9MAFPQ24hDpUREEpGiupJFBMvgEhGRW95ULPR2wO87Y29U1cbgtLmb/dj8nAwAv+yTxfRA3zHAIiKSQFGxFjIBsLoYIfQVy+ASEZEnLa3J9dyPiJig/h5HjL1w0dLFIbiKUAgAIEmVwo6MARYRkRc87TliS8WQIrgCWAaXiIjaxt06LQDIUR9Gtvo4AOCivov9uFwA5kxI91ilkAGWd7gGi4ioBc1z2W2jeba1UlKsvbJhGVwiImqtomKtPbXPlezIY8hWH8chQz/s0Q+2H5cJwLycDIzITGp1lUL6BWewiIha0NKeI1J1OlxYTERErbX5X2X47PAFt69nqUqRE3UY3zT0wY76WwAI9tes4i/pf95WKST3OINFRNQCT6N5trx0X8kFx/+OUMgw/7+jh0RERL4oKtZ6DK4EWJEZcQ7fG69D3rWREOHYCTUNnrypUkiecQaLiKgF7kbzZAJanRo4LyfDYwUoIiIib9kyKlwRIEKEDG/V3QERgNXF/MqgVI39Z2+qFJJnDLCIiFqQm5XqtIFwhELWpnVXLVWAIiIi8pa7TIuByp8wNrIE63RjoRfdp/gVHtOib4/O9n6JfVTbMEWQiKgFIzKTMGdCf3sKhSZOhVED2fEQEVHw2bYJaS5NcQFzYz6HXLDCKnr+k7/pumJqO85gERF5wTaat/lfZTh4+ILHXHciIqKm3G314WkLEG+v62qbkD6KS1gQ+xm0lk54o24cGqBs8VqsEigdBlhERP9VVKxF3r4T0OnNAIDoSDlmZqfZO7uWKjQ1pYlTwWA045rB4vZeTL8gIgp/tiCo+ca9P5z7GYXHtK3a0LdpYNZcT3klFsceQLU1Bmvr7vSYGtgUqwRKhymCRERo7Kze2V1qD64A4JrBgnUfl+CBlwtQVKxFwZGWgytNnAobVt6BF5eMwszsNLfnMRWDiKhjcLfVR8GRCx63AHGn+d6MzdWLKvxk1mBNbTZ0otqrNrJKoLQYYBERobEDNFtEl69dM1hcpmC40rSD8jQCyVQMIqKOwd33vbs+xVP/UFSsxfpdJS6LLMUKegAiqqyxeL1uPK6KUS6vEaNWYOzQFId1xdyDUVpMESQiQssBj7cVA5t3UNywkYioY/O01YerIMtd/+BuvRUAdJXp8IfYfPzH2Bv/1A9z+X6ZAO63GCCcwSIiAlxWYPKVq06RGzYSEXVs7vqBtOs6O53rqX9wlWoIAHFCPe6P/RQqwYxvjb1dvpeb2QcWZ7CIiOA+VcMXrjpFbthIRNSxueoHBqVqUHhM63TuqIGu958qKta6nAWLFgy4P24vYmQGrKnNxgVLV6dz2O8EHgMsIiK4T+HwhbvOixs2EhF1bM37gRVrCl3ORn1TeglHK6ocBuTiYn/GO7tLnc4VIGJh7AF0lenwRt04nLEkOJ0To1bgxSWjpP0w1CIGWEREaJx9alpGtyl3efJNRUfK/dQyIiIKB55Kq9tcM1js23vYyrarIuQuizCJEJCvHwwAqDC7HsSbcWc/CVpOvmKARUQE5xSOGLUCRpMFRrPoVfqgIEiwiIuIiMJS872wvGU0W53eo4QZqcpLKDN1R5mpu8f3M3siOBhgEVGH0XT00FVOui2F45eO0PuFWU33zyKSSnFxMT7++GN89dVXOHfuHORyOa6//nrMnDkT99xzDwN7onbCXYEKX8lhwX0xBeivvIA/XZ2MKmus23NZrTZ4GGARUYfQuHdIKaxiY9BUVduA9bsac9qbj/C1piNkR0b+8Pbbb6OoqAjjx4/H9OnT0dDQgD179uDhhx/G119/jT//+c/BbiIRea
GlNb6aOBV0eiMaTO4H9mSwYlbMIWRGnMfWa7d6DK4A14WXKDAYYBFR2Gop390qing3v9weYBUVa7Flb7k9/91bLLtO/jJr1iw8//zziIiIcDg2Z84c/OMf/8DcuXPRrx/XWBCFOk97Ir64ZBSKirUuC1nYCBAxPboIQyPO4INrN6GowfO/+xi1gumBQcR9sIgoLNnS/FoaNWwwWeznb9hV4lVwpVIK9hkrTZwKcyb0Z0dGfnHjjTc6BFcAIJPJMH78eADAyZMng9EsIvKRu72wBqVqsGJNIdZ9XOKykIVNhvI8blFV4JP6ITjYkOHxXhEKGYtbBBlnsIgoLPma5rezoAIe+jY7hVzA7LvTGVBRUGm1jfvndO3qvOcNEYUeT3thedNXFZt64NXa8fjBnOjxPO55FRrCLsA6ePAg8vLyUF5ejurqaqjVavTq1QszZszApEmTIJP9Mnrw9ddfY/bs2S6vM23aNDz99NP2/z537hzGjRvn8tzRo0dj/fr10n4QImoTX/a0mvfcAa/Oi46UY2Z2GjsuCqrLly9j+/bt6N69O4YNGxbs5hCRl5oWUtpZUIHPDl9o8T1jI4tx0pSEcxYNfnBTih0AFk7MYN8UQsIuwDpx4gSUSiWmTZuG+Ph46PV6FBQUYOXKlThy5AhWr17t9J5p06Y5dVK9e/d2ef3s7GxkZ2c7HOvWrZt0H4CIJCHFxsHNRUYwp52Cy2g04sEHH4ROp8Mrr7zilD7oDY0mxg8t811CgucF+h0Nn4ejcH0eB787i3fzy+3p6Z6MiSzB5Kjv8LkhDefqNW7PEwTgnjE3SNnMdiGUf0fCLsBatGiR07HZs2dj8eLF2L59O5YvX45OnTo5vD5kyBBMmjTJq+unpaV5fS4RBU7zEuyuUi9kgmCvItgaUgdsRL4wm8148MEHcfjwYTzzzDMYMWJEq65TVaWD1ZvN3fwoISEWV67UBbUNoYTPw1E4P4+Nu4q9Cq5GqE7g/0R9i8PGXvigfrjHc0URYfu83AmF3xGZTHA7YNVhilwkJyfDarVCp9O5fL2+vh5Go9GraxkMBuj1eimbR0Rt0LygRVVtAwqPaTFqYJJDMYqsIcltKqfOUuwULBaLBf/7v/+LAwcO4N9hCKMAACAASURBVLHHHsOUKVOC3SQiagVvBupuijiFqVFfodjYHZt1o2Ft4c91mdCY6r5iTSGKirVSNZXaIOxmsGx0Oh2MRiPq6upQWFiInTt3om/fvkhJSXE699lnn8Wjjz4KAEhNTcV9993ntvPasGEDXnvtNQBA9+7dMX36dMyfPx9yudx/H4aIPNqyt9xpkbDRbEXBkQuwivB5MbErLMVOwWK1WvHwww8jPz8fjzzyCGbNmhXsJhGRl3zf/kPEkIjT+MGchHd0WbCg5b8vbRPSVbUN2LSnDIDz/o4UWGEbYD344IM4dOgQAEAQBIwcORKrV6922PVeoVDgjjvuQFZWFrp164aLFy9i69atePzxx3Hu3DksW7bMfq5MJsOtt96K7OxspKSkoLKyEh999BFeeukllJeX46WXXgr4ZySixs7LXcfVtNPxZjGxO6zKRMFitVrx6KOPYteuXVi+fDnmzZsX7CYRkZds2394U6G2kQhAwDu6LMhhhakVf6YbzVbsLKhgfxVkgii2YUGCH9XW1mLTpk1enRsVFYX58+c7HCstLUVNTQ0uX76MAwcOoKamBo8//jjS0tI8XstkMmH69OkoLS3Fv/71L/Ts2dPtuaIoYunSpdi3bx/ee+89DB/uOUeWiKQ379lPcaXGPym7CV3U2PD4eL9cm8gbzz33HN555x0MHDjQ5czVjTfe6LGfcoVrsEIPn4ej9v48Wtrk3pVUhRY56sN4WzcW18TINrdhw8o72nyNUBYKvyOe1mCF7AxWbW2tPRWvJfHx8U4BVnp6uv3nyZMnY9WqVZg1axby8/M97huiVCoxf/58LFu2DF9++SWmTZvm9lxBELB48WLs27cPX3
zxhU8BViA6uFD45Wvv+Azbzt/P0F/BVYRChsmje4fM//78XXTNUwcXDoqLiwEAx44dw8MPP+z0+l/+8hefAyyi9qx5QaNgZxd4U2CpJb3kV7Ao9gB+tkZL0iauFw6+kA2wevTogfLycsmul5OTg23btmHfvn2YOnVqi/cGgJqamhav2717d6/PJSLpyYRfUgGlEgqdNhEAbN68OdhNIAoZtoJGtuAl2GuOXLXH13T0FHk1fhe7HzprJF6vzW7z7BXXC4eGkA2wpGYwGAAAV69ebfHcM2fOAAA0Gvd7DrTmXCKSnpTBlSAAC3K4WSMRUSjaWVDhsqDRuo9L8MO5nzHrrv5eX0uKmTBX7fFFN9lVLIndhwZRgdfrxqNWjGrVdWwDjRwcDB1hF2BVVlYiPj7e4ZgoisjLywMADB482H68pqYGXbp0cTi3vr4eb775JpRKJUaPHu3xXJPJZE9jHDt2rKSfg4g8s3WOUhJFsAITEVGI8rSmyTZz5E2QJdVMWFv3RjRBjsuWOORdG4lqa+tSnSMUMsyZ0J99VogJuwArJycHw4cPR0ZGBuLj41FZWYn8/HyUlZUhJycHN998s/3chQsXIjExERkZGfYqgh988AEuXLiAFStWIDk52X7uE088gfr6egwZMgRJSUmorKzE7t27cfLkScycOdMhcCMi/3ox7z8oPfOzX67NCkxERKFJE6fyGNQUHLlgD7A8zVC5mwlr+t3vzQxXS+1xJ0owQC+qUGONwSt1dwEQWnxPU5yxCn1hF2Dde++9OHToEDZt2oS6ujpERUUhLS0Nf/rTn5Cbm+tw7vjx47F//35s3rwZdXV1iI6OxoABA/Dkk09izJgxDudmZWXho48+wtatW1FbWwuVSoW0tDQ8//zzmDx5cgA/IVHHtvlfZR6Dq5bWZEUoZC2mdLR1VJKIiKSXm5XqMPPUnO27v6UZKnff8bbj3sxwFRVrYTCaff4MMYIeD8R9ih9Mifh7/a3wNbgCGj9nuFcJbO/CLsBaunQpli5d6tW5ixYtwqJFi7w6d8qUKW43HyaiwCk44nkB8fycDLcdsG20r6XyuazAREQUemzBzbqPS1y+LvtvrNLSDFWMWgGd3nVwZJu58vT+5gGYt9RCA34fuw9dZTp8Z7zVp/c2xT4q9MmC3QAiIm8VFWs9zk5p4lQYkZmEORP62zsgTZwKCydmYOHEDACNHbNOb3R7DVZgIiIKXSMykzB2aIrL17KGNB73NENVVKyF3uB+5mnL3vIWZ7haU9xCBRN+F7sfSfKreLtuLE6ZE316f1MNJguKirWtfj/5X9jNYBFReLKNGHpiC4xGZCY55KQ3H21sMLmO0mQCuFiYiCjE2dZZFRy5AKvY+N2dNSTFftzd2ihNnAo7Cypg8TBQd81gcfuaSilgxZrCVqWRz435HD3lVdigy0K52XWA6C2d3syCTCGOARYRhaTmC4wbTJYWRwzz9p0A4NzheDvaaBXZWRERtQez7urvtmKgq7VatuwEd+mF3mgwiWgwtW6N7j5DJr4x9sFx03Wtvn9TLMgU2pgiSEQhxzbjZBslrKptcJsv35ROb8aGXSVOqRPejjYyr52IqP1zlSpuy04I5Pe8DFakKRrXDVeYk3DY2FvS67MgU+jiDBYRhZy2bN5oERtz6JuO6nlTSpdrr4iI2gdPJdSbv7ZwouPm8S1VIpSKABEzo7/EcNUpvHA1B+ctXb16X9MS7INSNThaUeW2/+KgYOhigEVEIaeto3LNc+hddagKuQCVUoZrBgv3EiEiaic8lVAH4PE1W+AVo1b4OcASMSXqawxXncKu+iFeB1dNNw1uGijGRilRrzc5rB3joGBoY4BFRCGhaWcitRGZSfjh3M8OC6JvG5TsNn+fiIhCQ/MZKYPR7LaEuu3n5q+9m18OURTtr3mTct56Iiapv8OoyBPYqx+AvYaBXr0rOlKOmdlpLsvA19WboJALiOagYLvBAIuIgq61e4q4E6N2/G
orKtai8NgvJd6tIlB4TIu+PTqzgyIiClGuZqvc8fRag8l9ZUCp9VFcxh3qEnxuSMMu/VB4u5FwZITC3h+5SpM3W0R0ilbg1YeypG4y+QEDLCIKurx9JyRN1xjev5u9lG5LI54MsIiIQpMv63Ft65GCXfjhlDkRa2rvxAlzMrwNroDGdrdUAj7Yn428xyqCRBRURcVaSdM1VEoBhce0DhUI3e1rws6KiCh0+fIdPShVE9Q1SbdGnMT1iisAgHJzCkQfgiublj4vi1q0HwywiCiobHnzUohQyKCQy3we8SQiotDjy3d04bHG7TmiI+X+ao5bwyMqMCOmCLerSv12Dxa1aF+YIkhEQdXWWSRbCXbbol9vN5FkZ0VEFFpsBS2qaxvQ9b9lyguPab0aNLOlfc/MTmvTZsK+Gqw8g5nRX6LclIQt10b55R4JXdSYPLo3U9rbEQZYRBQ0zTcE9pUmToUXlzh2aO4qEcaoFVAp5S73TSEiouByVdCi8JgWowYm4ZvSS25TvZuqqm3Alr3lUMgFmJvWNPeTDOU5zI75AqfN8Xi7bizMaN3smW3vq+ZsfVxCQiyuXKlrY2spkBhgEVHQtDU90GA0o6hY2+ImkhEKGWbc2Y8BFRFRiHJV0MJotuJoRRUEwfv1TN4EYlIZFvEjLlo6403dOBihbNU1IhQyjBqY5DRTxyyL9o0BFhEFja/pgYIAiE1G+a4ZLPZNJG3BU9Myt5ytIiJqH9z1B6FZjEgEIOD9a6OgEswwiBGtuoogAKMGJmHWXf3Rt0dn9lthhAEWEQXFi3n/8el8dyV4XZVbH5GZxI6JiKgdsa2nbc5d+lywdJdX49dR32Cj7nbUilHQtzK4AhoHDJvuych+K3wwwCKigLAtXq6qbYBcBlh82PbKlirhbuFyaI5wEhGRtwalavDZ4QsOxyIU3leFDYRE2c9YErsXJlEOuWBtnMhqI+7JGJ4YYBGR3xUVa7FhVwlsa459Ca4EAZgzob/Hc1hunYio/Soq1trLrDsSEaNWSLpXYmtpZHW4P24vrBDwet141FhjJLs2BwnDDwMsIpJc09mqhC5qXK0zoLUFnWxrrjwVxBiUqmndxYmIKOhcFbgAAKNZhMlsDlhVQHc6y67h/thPoYAFr9bdhSvWOEmvz0HC8MMAi4gk1bzU7pUafZuv6a70us3Riqo234OIiILD0/e7CAQ1uAIAqyjgqjUKG+pvxkVLF8mvz2qB4YcBFhFJasvecslz5m1VldpXlSkiIvKGp+/3YIoUjDCKCtSKUfhb3d0AvC8X763oSLlP66+aZoiw2mDoanWAdddddyE9PR0ZGRno378/0tPTkZCQIGXbiKidKSrW+mUPElsn4q7IBdMrqK3YpxEFj6fv92BRwYjfx+5DpSUWm6/dBn8EVxEKGWZmp3l9vqvNmJtvVUKhodUB1ty5c1FWVob9+/dj7dq1MBgM0Gg09o7J9n+9e/eWsr1EFGKajqbJpO9/7BUER2Qm4YdzP7usMsX0Cmor9mlEwTMiMymkAiwlzFgU+xl6yqvwqX6gZNeNjpQjMkLR6tknd5sxswph6Gl1gDVjxgz7z0ePHsVDDz2EUaNGQaFQ4JtvvsG6desgCAIiIyNx+PBhSRpLRKGl+Whaa/YqUSnlaDC5nvWyVRC0dRzcjJH8hX0aUfBs/ldZsJtgJ4cF82MOoo/iEt69dhuKTT0lua5CLmBmdlqr0gGraxvQlWny7Yoka7Aef/xxPPnkk8jKyrIf+/bbb/Hwww8jNzdXilsQUQhyV/nJF+6CKwCIUjnnpnMzRvI39mlEgVVw5ELLJwXI9OgipEdcwBbdCBw2SjdjLfo4AukqHdAdpsmHHkkCrDNnzqBPnz4Ox2666SasWrUKb775JpYuXSrFbYgoyJovrvX3qJk/1nMRtYR9GpF/FRVrsWVveUh+xx8ypOG0OQFfG2+Q9LoWET6l8nk7gMk0+dAkk+IigwcPxo4dO5yOp6amorS0VI
pbEFGQ2UbTbEFVIFISOCpHwcA+jch/bBvPh1ZwJeIGxUUAwBlLAgobvC884Qtf+k1vZqw0cSqHNHoKHZLMYD366KOYNWsWLly4gLlz5yItLQ1GoxHr1q1D165dpbgFEQWZFOmAvpAL3BuEgoN9GpH/7CyoaPXG8/4hYnLUtxgbWYpXasejwuy/YMXdoKGr0uvuskQ0cSq8uGSU39pI0pAkwEpPT8eOHTvwzDPP4Ne//jUUCgUsFguUSiWee+45KW5BREEm9YyVTHBfFCM6Uu7zYmAiqbBPI2o7V0EDEHoFGSaov8fYyFIUGPqjwpzot/u4S+VzV3p91MAkFB7TOgxsMh2w/ZBso+Hrr78e69evx8WLF1FaWgpBEDBgwADuI0IUJqRYcxWhkNnTGTb/q8xlyXWmO1AoYJ9G1HqugoZ3dpf6XOjB38ZFHsfd6qMoMvTFB/XD4Y+9rgB4rHjrrvT60YoqzJnQ36GKIKvmth+SBVg2ycnJSE5OlvqyXjt48CDy8vJQXl6O6upqqNVq9OrVCzNmzMCkSZMgkzkvOzt79izWrFmDQ4cOoaamBl26dMGgQYOwevVqxMfHO5y7c+dObNy4ET/++CM6deqE7OxsLFu2DHFxcYH6iERBkZuV6tBhAo1pfN6meqiUcijkwLqPS7BlbznqG5zz70cNZIVACi3B7tOI2iNXQYM5tPICkSKvwT1R/8F3DddjW/2tEP0UXAHwmNLnqfS6rWpuQkIsrlyp81fzyA8kCbCqqqrw1FNP4csvv0RERAQ++OADJCUF54+kEydOQKlUYtq0aYiPj4der0dBQQFWrlyJI0eOYPXq1Q7nf//995g3bx6Sk5Px29/+FvHx8aiursbhw4eh0+kcAqyNGzfiL3/5C2677Tbce++9+Omnn7Bp0yYcO3YMW7ZsQURERKA/LlFARShlDp2mt/3l2KEpKDymxTVD43vdLW7+pvQSZt3Vv83tJGqLUOrTiNqjUEsDdOWCpQvW1o3DCVMyRGlqvrnUUrEmT2utqP2SJMBavXo1fv75Z7z66qv4/e9/D5PJBAB4+umnkZSUhEWLFklxG6+4utfs2bOxePFibN++HcuXL0enTp0AAAaDAcuWLcPQoUOxdu1aKJVKt9etrq7Gyy+/jNGjR9s3nASAvn374pFHHsGOHTswc+ZM/3wooiBrnu7hi9goJY5WVHn13tCqKkUdVSj1aUTtUSC28Witwcoz0IkqVJiTUGbq7vf7tbRmylV2CNdatX+ShOxFRUV47LHHMHLkSIcUvHHjxmH37t1S3KLNkpOTYbVaodPp7Mc++eQTnD9/HitWrIBSqYRer7d3pM3t378fer0es2fPtgdXADBx4kRoNBrs2rXL75+BKNCKirVYsaYQ6z4uaXUFwbp6k08d7Yo1hSgq1rbqXkRSaA99GlEoy81KRYTC8U9MmeC/FDxvZSjPYU7M5xivPgYgMCmLLaW9j8hMwpwJ/Vl6PcxIMoMlk8mgUjlPZV533XU4e/asFLfwmU6ng9FoRF1dHQoLC7Fz50707dsXKSkp9nO++OILxMTEoLa2FpMmTUJZWRlkMhmGDh2KlStXYtCgQfZzjx07BgAYOnSow33kcjkGDRqEr776CqIoOgRfRO1ZW2at2sJWQQlouWMi8odQ7NOI2hPbd7etimCMWgG9wRzUNt2guIh5MQdx3tIV7+iy4K+CFk15m+ZnW2tF4UOSAGvMmDH48MMPsWzZMofjOp0Ocrlcilv47MEHH8ShQ4cAAIIgYOTIkVi9erVDAHT69GlYLBYsXLgQd999N5YsWYLz589j7dq1mD17Nv7+97/jhhsad/K+fPky1Gq1y2IWSUlJ0Ov1uHr1Kjp37uxV+zSaGAk+ZcsSEmIDcp9w1hGe4cHvzuLdPaWorNEjvosasyekY+v+k5IGVyqlHA0m71IAjWYr3t5Vgrc/LrG3Z8ywnpK1pb3qCL+LoSAU+zSi9kynD25wdb3iMhbGfoYrlj
isrRsHg+j/NfNM8+vYJAmwli9fjtzcXIdjer0er7/+OjIyMlp1zdraWmzatMmrc6OiojB//nyHY3/84x8xf/58XL58GQcOHEBNTQ3q6+sdzrl27Rr0ej0mTpzosLdJZmYmZs+ejddffx0vv/yy/fO4K2JhG+k0GAxef76qKh2sfi5XyqozbdcRnmHzmaorNXq8tOU/kt9n9t1pDnuiDErV4GhFldv0QfG//zyu1Ojx6vYjqK0zdOgRvo7wu9gaMpkg+YCVP/o0oo4kWBkQ7gyPOIVaqxpr6rJRL0b6/X6eyrJTxyBJgJWYmIitW7fiqaeegl6vR25uLvR6PeLi4rBu3bpWXbO2thavvfaaV+fGx8c7BVjp6en2nydPnoxVq1Zh1qxZyM/PR9euXQEAkZGN/8iad6S33HILUlJS8M0339iPqdVqGI1Gl/dvaGhwuB5Re+KqnK4/uEuBWLGmsMU1WkazFTsLKthZUUD4o08j6kgC1a94a0f9zYgSjLjm5+CKezmSjSQBVk5ODvLy8rB+/XpcuHABZWVlUCgUGDJkSKv3h+rRowfKy8ulaJ69jdu2bcO+ffswdepUAEC3bt1w4sQJp72uACAhIQElJSX2/+7WrRv0ej1qa2udPpNWq4VarbZXJyRqTwJR6Wns0BS3r+VmpWLdxyVuX7cJ1YpUFH780acRdQRFxVp7pkKwxctqMT26CO9fG4Uaa4zfgyvOWlFTklQR/OGHH+yzOykpKbjjjjtw++23QxAEp32ngsWWvnf16lX7MVsRC63WuWKZVqu1z3QBwMCBAwEAhw8fdjjParXi2LFjSE9PZ4ELapf8vdfGr0b08ri31YjMJERHtryuhXuCUKC0hz6NKFTYqs3Oe+4A1n1cEhLBVReZDvfH7kWy/GdEwL/rv2LUCmxYeQdeXDKKwRXZtSnAWrBgAV555RUIgoCLFy86va7X67Ft27a23MJnlZWVTsdEUUReXh4AYPDgwfbjOTk5kMlk2Lp1q8P5Bw4cwKVLl3D77bfbj40bNw6RkZF49913Hc795z//icrKSuTk5Ej5MYgCxlU5XSlEKGRYODEDv//NkBbPnZmd5rENXCxMgRCKfRpRKLOttQqFoMomVtBjSexeRApGrK27E5es3hUfay2d3ox5zx1wu8VI0wCU25B0HG1KEezXrx/+/e9/QxRFTJkyBdHR0ejfvz/S09ORlpaGU6dOISEhQaq2eiUnJwfDhw9HRkYG4uPjUVlZifz8fJSVlSEnJwc333yz/dzU1FTcd999WL9+PRYuXIgxY8bgwoULeO+99xAfH4+lS5faz+3atSv+8Ic/4IUXXsDChQsxfvx4/PTTT9i4cSMyMzMxZcqUgH5OorZomsahiVMhvpMKF6r0kl1fJsCnPPTmJX2jI+UQBAE6vZlpFxQwodinEYWyUFtrFSUYsCR2LzrJ9FhTdyfOWTQBu7erLUaaF/vgNiQdhyCKYptL2Q0YMADbtm3D5cuXUVpairKyMpSVlcFsNmP58uUBnd157bXXcOjQIZw+fRp1dXWIiopCWloaJk+ejNzcXIdNI4HG2a0tW7YgLy8Pp0+fRnR0NEaPHo3ly5eje3fnHb537NiBTZs24fTp04iLi8Odd96J5cuX+7z+ilUE24dwfIb+ru7UfJFvOD7DYOBzdM0fVQRDqU/zh0D0Py3h77Oj9vo85j13INhNcBAtGLAw5jPs1g/BCXNyUNqgiVPhxSWjALgv4tT0HG+0198PfwqFZ+Kp/5EkwDKZTFAqlW29TIfCAKt9CMdn6E3VvrZYODHDYWQuHJ9hMPA5uuaPACvc+zQGWKEnFJ9H80wHV5kE/u5PvKWEGVYIsEAOQEQgNhH2ZMPKOwB4DkBt53gjFH8/gi0Unomn/keShRf//ve/cfToUSkuRUR+YssD92dnGKNWMO2B2j32adTRNV9bZUtta7p+qKhYC53e9fY1gaSABQtiP8N9MQUIZHAVo3a9yqZpQSZ357
g7TuFDkgDrL3/5C06dOuV0vLy8HFVVVVLcgojaIFALkSWYECcKOvZp1NG5Wltl248Q+KVPaTAF9ztfBivmxHyO/sqLOGq8DoEIruRCY6bGjDv7ORVnal6QyV2fyL4y/EkSYJ05cwbDhg1zOn706FE8/PDDUtyCiNogUAuRrxksfr8Hkb+xT6OOzt1gXFVtA4qKtVi/qyToxS0EWHFv9CEMijiLv1+7Gd8Y+wbkvurIxtmnEZlJmDOhv33GShOnciru5K5PZF8Z/iSZo4yLi0N1dTV69uzpcHzYsGF48cUXpbgFEbVBoHLkuVcVhQP2adTRaeJULvuN6Eg5Nu0pQ5CX8AEAJkd9h2Gq0/hn/Y041OB+r0Wp6fRmbNhVAqAxyPKUFu/uObKvDH+SzGDddtttWLdundNxq9UKi4VROlEwFRVrIQtASjr3qqJwwT6NOjpX+yNGKGQQBCHoM1c2/27og3/W34j9hgEBv7dFBLbsLW/xPHfPkX1l+JMkwHrooYdw7Ngx/O53v0NJSWNUf+3aNbzxxhtIS0uT4hZE1Aq2PHl/jzb6uu8VUShjn0Ydnbv0N53eHOSWAamKSwCAcxZNUIIrG2/S/LxJI6TwJEmKYGJiIrZv344nnngCubm5UCgUsFgsiIuLwxtvvCHFLYioFQK19soqctNECh/s04hcp7/ZyrYHy52RxzAx6jDW1Y3FcVPPlt8QAlpKI6TwJFmdyMTERLz11ls4f/48ysrKoFAoMHjwYHTu3FmqWxCRj1rbEcoE+DTrxXxyCjfs06gj8Gavq+bnBcttqlJMjDqMbxt6o9jUPWjtsGGpdfKk1b8dDzzwAF544QWo1Wr8+OOP6N27NwCge/fu6N49+L/4RB1VUbEWW/aWt7pK0cKJGQCATXvKHGa/IhQyjBqYhMJjWqfjzCen9o59GnU0thRy2/e5ba8roHHWJRSCKptbIk7iN9H/xvfG6/D+tVEQpVnh0moKuYAZd/YLahsotLU6wEpISIDZ3JiLO2HCBKjVaqSlpaF///5IT09Heno60tLSoFJxZJsoUIqKtdiwqwSWNqy52llQgdysVMyZ0N/lyGbfHp29GvEkak9CuU+zWq3YuHEjtm3bhvPnzyMhIQH33HMPlixZwj6WWq2lva6aD7IFS1dZHaZFf4VSYwo26W6DNQjBVYxaAVEUcc1gYb9HXhFECXY7u3TpEsrKylBaWmr//2fPnoUgCOjVqxd2794tRVvDSlWVDlY/Vx5ISIjFlSt1fr1HuGtvz3DFmkJJRhsjFDLJFuK2t2cYqvgcXZPJBGg0MZJeM9T6tGeffRabN29GTk4Obr31VpSUlGDr1q0YO3Ys1qxZ4/P1AtH/tIS/z46C8TzmPXfA7WvuyosHS7ryPH4wJcIk3coWn0nZL/qK/16chcIz8dT/SFbkIjExEVlZWfZjer0eZWVlKC9vuYwlEUlDqg7RNorJETrqiEKpTzt58iTee+89TJ06Fc8884z9eHx8PF555RUUFBQ4tJPIW+6CqBi1IiSCq36KCxAAlJtTUBoCa67YL5IvWh1g3XXXXUhPT0dGRoY9hSIhIcH+ulqtxtChQzF06FBJGkpELYuOlEu2Q3wodLBEgRKqfdquXbsgiiLmzp3rcHzWrFlYs2YNdu3axQCLWiU3KxXv7C6FuVlOud5gRoxaEdSS7H0Ul7Ag9iC0lk44UZsMEQHYzNELVbUNKCrWMsiiFrU6wJo7dy7Kysqwf/9+rF27FgaDARqNxiFfPT093b5QmIj8TxCk64RYGZA6klDt044fP47Y2FikpjoWkomLi0OfPn1w/PjxgLaHwseIzCRs2lPqdNwiAqIoQi6gTet5W6unvBKLYg7gZ2sU1tXdEdDgSiEXnALO5poWAiFyp9UB1owZM+w/Hz16FA899BBGjRoFhUKBb775BuvWrYMgCIiMjMThw4claSwRuWar9iTViCMrA1JHE6p92uXLl5GYmOjytaSkJHz33Xc+X1
PqNWutlZAQG+wmhJRAP4+D352F0ew6mLhmsEARhAgrWV6D38fuQ72owuu12agT1QG7tyAA+5IIKgAAIABJREFUmb274tipao9rFI1mKz489CPuGXNDwNoG8N+LK6H8TCRZg/X444/jySefdEhT+Pbbb/Hwww8jNzdXilsQkRvNS+22li0fnxWSqKMLpT5Nr9cjNtb1HxEqlQoGg8Hna7LIRegJxvPYuKvY7WsyAS3O5PjDTRGnYIIcr9dl46oYHdB7iyLw/Q9VXp17pUYf0P+9+O/FWSg8E09FLiSpdXnmzBn06dPH4dhNN92EVatWobCwUIpbEJEbrkrt+koTp8KLS0Zhw8o78OKSUQyuqEMLpT5NrVbDaDS6fK2hoQGRkZEBbQ+FD0/rbIMVf+/S34iXrv4PqqyhOzMBMIWeWiZJgDV48GDs2LHD6XhqaipKS53ze4lIOm0tRsF0QCJHodSndevWDZcuXXL5mlardZs+SORKUbEWK9YUYt5zByBzs7RJpZS7fc0f4oR6LIndC42sDiIE1IpRgbt5K7DPJG9IEmA9+uijeP/997FixQoUFxfDbDajvr4e69atQ9euXaW4BRG50ZqRNHmT3lOpCI3qTEShIpT6tAEDBqCurg4VFRUOx2tra3Hq1ClkZmYGtD3UftnSyW2Dcq5mqRRyAWazJWAzWNGCAUti9+J6xRXECL6nuwZCjFph72c1caqg7YVF7Yska7DS09OxY8cOPPPMM/j1r38NhUIBi8UCpVKJ5557TopbEFETtqIWVbUNiFErfKr2lKJRo/JqAyz/7UGvGSysikTURCj1ab/61a/w5ptvYtOmTXj66aftxzdv3gyz2YyJEycGtD3UfrlLJ5cJjcGWJk6FunojjAEKriIFI34fuw8auQ5v1o3DGUtCy28KsAiFDDPu7Me+kXwm2ZbY119/PdavXw+tVouSkhIIgoABAwY47CNCRG1TVKzFlr3lDntd6fRmKOQCopUyXDNY7EUqbAFYc9pqvdPoJDdQJHIUKn1aWloaZs6ciffffx/19fW45ZZbUFpairy8PIwdO5Z7YJHX3KWTW0Vgw8o7sPlfZfjs8IWAtCUCJiyO2Y9k+c9YrxuDH8yh0fcIAhAd2bgHGAs+UVtIEmBVVVXhqaeeQlFREZRKJT744AMkJfEXkkhKnqoFmi0iIiMEeyVAd8EV4H7xMjcWJmoUan3aY489hpSUFGzfvh35+fmIj4/HwoULcf/99wetTdT+2PoHV8eLirUBC64AQC409mPv6m5DialHwO7bXIRCQIRSbg+oBqVqcLSiKqibLFN4kGQN1urVq/Hzzz/jlVdeQX19PUwmEwDg6aefxltvvSXFLYg6vJaqBer0Znvn6SlYcrd4mVWRiBqFWp8ml8uxYMECfPrppzh+/DgOHjyI5cuXQ6Xiv1nyXm5WKiIUzn/2NZgs2LK3PCBtkMEKBSzQiyq8Unc3vjf1Csh93TGaRRhNViycmIHcrFQUHtM69KOb9pShqFgb1DZS+yRJgFVUVITHHnsMI0eOhEz2yyXHjRuH3bt3S3ELog5PihmmCIUMWUNSnDpZVkUi+gX7NApHIzKTMGdCf0RHyh2O6/Rmh7RzfxFgxazoL7Ao9gAEWCEiNAos2VLkXQ1i2l4j8pUkAZZMJnM5knbdddfh7NmzUtyCqMNr7QxT8+pHs+7qjzkT+rMqEpEb7NMoXAXre16AiOnRRbhRdQalphSI0vz5KZmq2ga3g5hMn6fWkGQN1pgxY/Dhhx9i2bJlDsd1Oh3kcrmbdxGRL3KzUt2uwXLHtoFwcyMykxhQEbnBPo3CWSBmqxyJyI36BreqKrBHPwifGUJvawHbgKO7NWpEvpIkwFq+fDlyc3Mdjun1erz++uvIyMiQ4hZEHZ4tIFr3cYlX5zPtj6h12KdROGm6rUcwgoXxkcdwe2Q5DugzkK8fHPD7t6RpX9l8EJP9KLWWJAFWYmIitm7diqeeegp6vR65ub
nQ6/WIi4vDunXrpLgFEaExyPJUIdCG5WWJWo99GrUnzQOopt/9zavPBiPd7ZipJ5R6Cz7RDwFCZN2Vjau+0t2zJPJFmwMsi8WCjz76COPGjcP69etx4cIFlJWVQaFQYMiQIYiLi5OinUT0X7lZqR5nsWwjbuwUiHzHPo3aE1cBVNON41uqPutPveRXcMYSj4uWLvhE3yUobfAkOlLulELP9HmSSpsDLLlcjtWrV2P48OHo1KkTUlJSkJKSIkXbWuXgwYPIy8tDeXk5qquroVar0atXL8yYMQOTJk1yqAhlc/bsWaxZswaHDh1CTU0NunTpgkGDBmH16tWIj4+3n5eWlubynr1790Z+fr7fPhNRUyMyk5w2G26KmwYTtV6o9WlEnniqfDciMyloBRpGqE5gevRX2KwbhW+NoZliJwihNZtG4UWSFMEhQ4bgxx9/RM+ePaW4XJucOHECSqUS06ZNQ3x8PPR6PQoKCrBy5UocOXIEq1evdjj/+++/x7x585CcnIzf/va3iI+PR3V1NQ4fPgydTucQYAHATTfdhKlTpzoci42N9fvnImpqZnYaNuwqgYWbBhNJLpT6NCJPWqp8525zYX8aFnEKU6O+QomxOw4brw/ovX3BzYTJnyQJsKZNm4aXX34ZvXv3DnqHtGjRIqdjs2fPxuLFi7F9+3YsX74cnTp1AgAYDAYsW7YMQ4cOxdq1a6FUKlu8fs+ePTFp0iTJ203kq8Y9RFxHWKx6RNR6odSnEXniLoCy9QEtpZRLbYDyJ/w2uhAV5kRs0GXBgtCtusl+kvxJsiqCADBx4kTcfvvtuPnmm5GRkYH09HSo1WopbtFmycnJsFqt0Ol09gDrk08+wfnz5+3BlV6vh0KhaDHQMhqNMJvNiIqKCkTTqQPxtFi5qbx9J2AVXQdXrHpE1DbtoU8jAlxv39G0DxiRmYS8fScCMlsTI+gxO+YQzlo0WFd3B0zS/InpF+wnyd8k+e0vKChAWVkZSktLUVZWhs2bN+Ps2bMQBAG9evUKys73Op0ORqMRdXV1KCwsxM6dO9G3b1+HXPovvvgCMTExqK2txaRJk1BWVgaZTIahQ4di5cqVGDRokNN18/Pz8dFHH8FqtSIhIQGTJk3CAw88gMjIyEB+PApDLS1WbspTZ8lNg4naJhT7NCJXbN/1zQfmAGDFmkJU1TYgOlIOmSC4HZSTik5U4x1dFk6b49GAljOCAkUTp8KgVA2OVlSxOiAFjGRl2hMTE/H/s3fvcVFX+f/AXzPDAAOCCoPhLTOUa+qioaIWUl4Lk2jNS4oli7a25U/Tsrav6ZZbWdvXdk0rLwGWpLZgSUWmJCWR5Z1AsLxSishFuQ2Xmfn8/uA7I+NcGGCGGWZez8ejx+pnznw+53N2nDPvzznnfaKiorTHFAoFCgsLUVRUZIlLtNmSJUtw6NAhAM0LGceMGYM1a9boLGq8cOECVCoVEhMTMWXKFCxevFg7ohUfH4/du3dj8ODB2vLDhg3DlClTcPvtt6OqqgrffPMNtmzZgpMnTyIpKQkuLvb7tIbsX2uLlTVy80tMnoedBlHH2GOfRmTMrZnvbn1YZ+2NhW+XlMFLrEB+U3+cbupr1Wu1hyZTYMsZImnZZwGwvyTrsVhE0NjYiHPnzgFozqonk8kQHh6O8PDwdp2vqqoKycnJZpX18PBAQkKCzrHly5cjISEBpaWlyMrKQmVlJerq6nTK1NbWQqFQYNq0aXj99de1x8PCwhAfH493330X69ev1x7ftWuXzvvj4uKwdu1apKSkICMjA7GxsWbfn69vN7PLdoSfHxNwdFRntWGFkYXIFVUNOnXYcyjX6Dm8PKR2+f+5PdapK2I7dh5L92lEnaUzU7P3kVTiSa/9qBHccfpGX6ihn6nZljTrrNoyQ4TIEiwSYOXk5GD58uW4fv06BEGAu7s7HnzwQSxbtgy+vr7tOmdVVRU2bNhgVlm5XK4XYIWEhGj/HBsbi1WrVmHevHnIzM
yEj48PAGin9cXFxem8d9SoUejTpw9++umnVq/95JNPIiUlBd9//32bAqzy8hqo1dYdrvfz88K1a9VWvYaj68w29DGyWNnH201bh9z8ElyrVBg9x6z7B9vd/+f8HFoG29EwsVhk8QdW1ujTiDpLZ2UN9BNXYbHXN2iCC96rvt/ugisA2umS5s4QIbIUiwRY//jHPxAREYEVK1bA09MTx44dQ1JSEh5++GF8/PHH7crC1K9fP4tOxYiJicHOnTuxf/9+bZr1Xr164cyZM3qp2AHAz88PBQWtZ97x9fWFu7s7KisrLVZXck6GFisDzZ3lio05GBrgi5w849MDPd0l7CiILMAafRqRpRlLitQZqdl9xDV4ymsfAODdqomoUNvf6Hp0eB9tn9haOnsiS7PI44YrV67g2WefRf/+/eHj44MJEybgo48+wqhRo7B27VpLXKLD6uvrAQA3btzQHtMksSgp0f/RWlJSoh3pMuXq1auor6/nU03qsMgwf8yfGgxPd/20tuVVDfj2+GWj0z5cXcSYM9HwRthE1DZdoU8j56aZ8qYJEMqrGrB5bwEWvJ6FhiYVJFbeQ/du13NwFSmxsXoiStXdrXuxdvB0l2De5GDt342lZGeqdrIWiwRYgYGBuHr1qt7xv/71rzh8+LAlLmG2srIyvWOCICA1NRVAc6IKjZiYGIjFYnzyySc65bOysnD16lXce++92mOGRqgEQdCu0YqOjrZI/YmalG2fOsrMgUSWY099GtGtcvNLsDWjwOgDtxqF0ugm9Jayr34I3qyKwWVVT+teqJ1uTewRFxUAVxfdn7xM1U7WZJEpgjExMXjllVfw7rvv4vbbb9cer6iogLe3tyUu0aa6REREIDQ0FHK5HGVlZcjMzERhYSFiYmIwcuRIbdmAgAA88cQT2Lp1KxITEzF+/HhcvnwZH330EeRyOf72t79py27atAmnTp3Srs+6ceMGsrKycPz4cURHR2PKlCmdep/kmNqzONnX243BFZEF2VOfRtSSZuTKyku4DZKJGjHX8xA+qxuBUnV3VKo7J1lXe9w6MmUsnT37TrIWiwRYmgx8DzzwAKKjoxESEgK1Wo0vvvgCzz//vCUuYba5c+fi0KFDSE5ORnV1NTw8PBAUFIS1a9fqJbMAgBUrVqBv375ITU3Fa6+9Bk9PT0yYMAHLli2Dv//Nf3gjR47Eb7/9hrS0NFRWVkIqlWLgwIH4+9//jsceewxisf0t7iT7ZGoz4bbOB+cTOCLLs6c+jailzswQ2JIrmrDIaz/6SyrwnTjYLqcFahjrF29NZ09kTSJB6PjOc5WVldpNGTUbM54/fx6CIOCOO+5AYGAggoKCEBQUxKl0/4dZBLsGS7fh9q8L8e3xyzrHXF3E2il+mo0hTRGLALWALvMEjp9Dy2A7GmaNLIKO3qd1Rv/TGn6edZnbHgtez+qE2uiSQomFXlkIcLmKpJoonGq6vfU3dRI3qQhj7urt8JsI89+LPntoE1P9T6sjWJs2bcKkSZMQEGD8KXnPnj0RGRmJyMhI7bHGxkacOXNG2zl999132LJlC44cOdKOWyDq+gwFV4BuqlhjmQRbSogJdbjOg6izvPDCC5gxYwaGDx9utAz7NLJXnZEhsCUJVHiiWzYGuZTgo9pxdhNcaR40dpO5YlC/HjoJLYjsQasB1jvvvAOVSqWzHqm+vl67h5Qxrq6uuOuuu3DXXXd1vJZEXVxufonB4EqjvKpBO3XQVHDFVOxEHZOeno6+ffuaDLAMYZ9G9sDQQzhXFzEC+nqj8NJ1dHxOki4XqOEqUmJX3WgcbbzTsifvAM0ALDcMJnvVroVDW7Zs0Xmy11JpaSlqa2s7VCkiR5OWfdbk657uEp2Uu4YwFTuRdWzduhXx8fG2rgZRqzTbeWiSOHi6SyASCTh90bLBlQgCpFCiAVK8Wz0RuQ2Blju5hWlmgRDZk3Ynubh+/brB47t27cKmTZuQn5/f7koROZrWpnSIRCI0Kl
V6x7vaeiuirkihUODnn3+2dTWI9BhLihQZ5q/NKNjYjq09TBPwZ4/D6C25jo3VE6GE/t6M9oYbBpO9sUgWwVup1Z2f4YbIHmk6R1NCBvTA6YuGH1ioBWDbyvusUTUiIrJjNwOo5t9Ut06Hs05GQQEPyY5hnPsZ7FeEQWmZ7VKtjhsGk73pGv9yiLogTedo6slayIAe+PX3G0ZfZ6dBROScDAVQLafDWWPUZrL7Kdwvy8f39UHYqxgOQGTxa1gatyshe8QAi8hKTD1d9PV2Q+K0UBSX1kCpMjy9g50GkXWIRPb/o5HIWAClOW7pB3Dj3ArxgMdJHG4IwH/rRqIrBFcAtNucENkTs6YIsjMiajtTTxffXDwWALB5b4HRMuw0iKxj06ZNOHDgAIYMGYIhQ4bg8mXjGT6JbMVUSvbc/BLERQXgwy9PG31I11ZFTX3wXX0Q0usiIHSR4MrX2439JNklswIsdkZEbWesczT3qSM7DSLLGzNmDPLz81FQUICCggLs2rVL+9rcuXMRGhqK0NBQhIWFISAgAGIxJ3qQbcRFBRh9CJeWfRZxUQEQLLBhdD9JOX5X+eCa2hv/rRvV4fNZQzeZCxqb1Hrp6TnLg+xVqwGWqc7oscceQ0hICIKCghAcHIygIKaQJtIwtF+JRAQ0NKmw4PUs+Hq7wdVFZDADlKe7/WdtIuqKtm3bBgAoLi7GL7/8ov2voKAAR44cwZEjR7SzNtzc3BAYGIiwsDC8/PLLtqw2OaHIMH+jAVZ5VQPSss+io4NXQ6UX8Xi375BeF4HvG+xzs15XFzFmT2hOE28ooyKRPWo1wDLVGR09ehRHjx7VdkYSiQQymcy6NSbqIjRf/JoOwdNdgoYmNWoUSgDNHaSLRAQRgJZ9pEQE7ndFZGX9+/dH//79MXXqVO2xCxcu6PVzp06dQl5eHgMssglTMyE6muQiWPoH5nf7HpeUchxusM+RoFsDKQZU1FWYnabd3M6ourqaa7aI/o9mvxIAWLExB7X1uh2iUiWgm8wFblIJn8oR2dgdd9yBO+64AzExMQAAQRBw7tw55OXl2bhm5KwMzYRwdRFjaIAvvj3e/uUaAS4lSOh2EFdUPfB+zf1ohNQS1bUosejmemWirqZD+2CxMyIyreUmkcbUKJT495J7O7FWRGQOkUiEgIAABATY59N9cny3zoQQi5pTtWefaH9w5YZGJHQ7iHJVN2yqngCF4Gqp6raJRAQsiAk1Og3SAsvLiGzGohsNszMiuunWTSKN4V5XRETO7eDRYiRl5BucyaD535b9SUeCjwa4IrnmXlxR9UCt4N7hurfXgphQRIb5Y2tGgcH7EXMyFHVhTI9EZCWm9sHSYBYkIiLnlptfgg27T2pnOpRXNSD5q0Lk5pdoy5jTn7Sml/gGhkovAgCKlH1QJXh06HwdpQkco/7Ux+Drxo4TdQUWHcEioptaW4DM9VZERJSWfRYNTSqdY41KNdKyz2r7h44mtPAVV+Mp730AgMLrfexizdXT67NRW6+Cr7cbQgb0QNGl61ALzSNXU0YPwJ/58JG6MAZYRBamWXdljK+3GxfuEhERAOPBU8vjHcka2F1Ui8Ve30AKFf5TPdkugisAqK1vDirLqxpQXdeEhP+bMggAfn5euHat2pbVI+oQThEksiDNuitjHaFEBE4JJCIiLWPrcFvuhxgXFQBXl7b/ZOsmUuAp72/QTdyA96on4IqqZ7vraU2aETsiR8EAi8iCWpsnL+KqXSIiaiEuKgASA31DQ5Nauw4rMswf86e2fSPg4a4X0FNci/er78MllbzDdbWmjk6DJLInDLCILKi1DkKpEviUjoiItCLD/OHhrr9iQ6kSsHlvAVZszMH2rwvb1Xd81xCM1288hHPK2yxR1TZxkYh0RuHMsWJjjk5yD6KuigEWkQWZk3KdT+mIiKilmromo6+VVzXg2+OXze47pFBivmc2eksqAYhQrvayUC1NS5wWqu0Dfb3d8MQDIfjP/4
tq01YkmgyKB48WW6uaRJ2CSS6ILCguKsDopoka3PeKiIhakveU4VqlosPnkUCFBd0OIlh6Gaeabu+0NVe+3m6IDPM3mBW3rQ8VG5VqpHx1Gm8sirRU9Yg6HUewiEw4eLQYKzbmYMHrWWZNXWgt5Tr3vSIiolvFTw1pVxKLlsRQY3637xHqehk7ayNxvHGghWpnmql+zVSfaephY5kFgk0iW2KARWSEOZs/GmKs0xCLgPlTg7nvFRER6Rg/oj/mTw1u9wwHEQTM8fwBw1wvIa02Aj82DrZwDY0L6OuNtOyzBh9Emlo3FhcVYPR+5T1lFq8nUWfiFEEiI4xt/rjjmyKkZZ9FeVWDwc2C46ICkPxVoU42QVcXMYMrIiIySjPFLje/BNsyCqASzH+vC1TwEiuQUReO7IYQ61XSgNMXr2v/rHkQCTTfj6npgZr+0FB/GT+1c++ByNIYYBEZYaxjqK1X6WyQ2LIzafm/poIwIiJyDJrN5S35fW98s49bCZBChSa44P3q+6G2g4lJmj2tIsP8jW6QrBm5MtZfjh/RnxsNU5fGAIvICGMdw61adiYaxhb7EhGR49BsLq8ZgTH00K2t59uaUQDBzNGrKbKTCJX+gXerJqEB0jZfz1o0faexGR0t12yxvyRHZPtHHUR2Ki4qAG5S8/bwYOp1IiLnY2hzec1Dt7bSTA1Umxlc3ef+C6bKTuGKqgcarfC8vJvMBd1k7TtvyxGqlmvLfL3dOF2enAJHsIiMiAzzh7eXO5Iy8lsNoJh6nYjI+RjrG9rz0G3HN0Vmr7sa51aI6R7HcKxhAD6pjYQAUZuvZ8y2lfdp/3zrCJ25hgb4YsXGHE6TJ6fFESwiE8aP6I83F4+F2ETfxdTrRETOydjDtfY8dNOs7W1NuOt5zPD8CXmN/bC99h4IZv6U83SXIHFaqMl08LfW29AIVOK0UGxbeZ/Re3STSpCTV9LmDLxEjsThRrAOHjyI1NRUFBUVoaKiAjKZDAMGDMDs2bMxffp0iMU3v1hWrlyJ9PR0o+eaMWMGXn31VZ1jaWlpSEpKwvnz59G9e3dMnDgRS5cuhbe3t9XuiTqfZtFyRVUDfLzd4O8jw+Vy/X053KQixE/hdAciImdkzhojS7uklOPHhgDsrh3dpqQWIpFI21ft+KZIL6AzVm9ja6SM3buLBKitNzxtkn0lOQuHC7DOnDkDqVSKmTNnQi6XQ6FQIDs7GytXrsSJEyewZs0abdmZM2ciMlJ/p/D09HTk5uZi/PjxOseTkpLw2muv4Z577sHcuXNx6dIlJCcnIy8vDzt27ICrq6u1b486gaFFy4aEDOiBFbOHd2bViIjIjlgqa+xLmw61Wqa3pBIlqh4oV3shtXZsm+tao1Bq66xJB9+Rehu79817CwyW51plciYOF2AtXLhQ71h8fDwWLVqEXbt2YdmyZejevTsAIDw8HOHh4Tpl1Wo13n77bfj4+CAqKkp7vKKiAuvXr8e4ceOwefNmiETNc8YGDRqE559/Hp9++inmzJljxTujzmJo0bIhpdxpnojI6XU0C15ufglO/lZusswglxIs8jqAbxRDsK9+aLuuY2j6X0dHlAydI3X/GW0w11J7E2YQdUVO82nv3bs31Go1ampqtAGWIT/88ANKSkowf/58SKU3U54eOHAACoUC8fHx2uAKAKZNm4Z169YhIyODAVYXZOgJnrlP2fg0joiIzGVsxCh1/xmT77vD5RoWemWhXOWFnIbAdl27M9cKC0ZyzBs7TuSIHDbAqqmpQWNjI6qrq5GTk4O0tDQMGjQIffr0Mfm+tLQ0AEBcXJzO8by8PADQG/GSSCQYOnQofvzxRwiCoBN8kX0zNBVw894CuElFaGhqvSNg5kAiIjKHof5mW0aB0dEejb6Scizqth9Vahk2Vk9AreBu8joSMeDuKkFtvQpiEaAW0OlZ/Iwl6zA3iQeRI3DYAGvJkiU4dKh5TrNIJMKYMWOwZs0akwFQdXU19u/fj7
CwMAQHB+u8VlpaCplMZjCZhb+/PxQKBW7cuIEePXqYVT9f325tuJv28/Pz6pTrdEV7DuUanArY0CTARSKC0kS+XDepBI/HhLF9zcR2sgy2I1HXZGjquUqAyeDKBSokdvsW9YIr3q2eiCrBw2A5T3cJ5kwM0o6Oubu6YM7EIJsllPD1djM4w4MPJcmZ2G2AVVVVheTkZLPKenh4ICEhQefY8uXLkZCQgNLSUmRlZaGyshJ1dXUmz/PFF1+goaFBb/QKABQKhdEkFm5uzV8a9fX1ZtUXAMrLa6A2dzfBdvLz88K1a9VWvUZXds3EGiqxqLkz0GQRHBrgi1Nny3WmdoTd3oPtawZ+Di2D7WiYWCzqtAdWRO3VninlSkiwo3YsKtWeqFQb/4wrVcC2jALtHlqatOgAbBJk2SKzIpG9sesAa8OGDWaVlcvlegFWSEiI9s+xsbFYtWoV5s2bh8zMTPj4+Bg8T3p6OqRSKWJiYvRek8lkaGxsNPi+hobmL053d9ND92RfjD1lA4BGpYC4qAA8NH4wf9QSEVGHmOpvbtVdVIuB0ms40XgHzih7t1q+oUl/6p0t06JbKrMiUVdmtwFWv379UFRUZLHzxcTEYOfOndi/fz8effRRvdfPnTuHEydOYPLkyQan+fXq1QsKhQJVVVV60wRLSkogk8lMJs8g+2MqnSzQ3Dk8NH5wJ9aIiEjXgQMHsG/fPpw4cQIlJSXo3r07goOD8eSTT2L4cG4T0VUYGtUxxEukwFPe38BLpMCZJn/UtbLmyhRbJmKyRIZCoq7M/B3qujjN9L0bN24YfF2T3OKRRx4x+PqQIUMAAMePH9c5rlarkZeXh5CQECa46GIiw/wRHW486QmzBBKRra1atQonT57EpEmT8NJLL2H27Nk4c+YMZs+ejT179ti6emSmyDB/zJ8abHIdkoeoAX/12o8e4jpsrrmvQ8EVwDVPRLbkcAFWWVmZ3jFBEJCamgoAGDZsmN7rKpUKn332Gfz8/DAqhV/bAAAgAElEQVRu3DiD573//vvh7u6OlJQUneOff/45ysrKDE4rJPuWm1+Cn05fNfo6OycisrV//etfyMzMxLPPPosZM2bgr3/9K/773/+iZ8+eWLduHdTq1vfsI/sQGeaPNxcb3iDYDY1Y5HUAt0luYEv1eJxT3mb0POY8yuWaJyLbstspgu0VExODiIgIhIaGQi6Xo6ysDJmZmSgsLERMTAxGjhyp956cnByUlpYiMTEREonE4Hl9fHzwzDPPYN26dUhMTMSkSZNw6dIlJCUlISwsDDNmzLD2rZEF3Zoy1xB2TkRka6NHj9Y75uvri4iICOzbtw/l5eXw8/OzQc2ovTTp01sa6lqM/pJybK0ZjzNK09vJCGgOoFr2Xy4SEdykYtTWq7jmicgOOFyANXfuXBw6dAjJycmorq6Gh4cHgoKCsHbtWoPZAYHm5BYA8PDDD5s8d0JCArp3747k5GT84x//gLe3N+Li4rBs2TKjGQbJvrTc6LE17JyIyF6VlJRAKpXCy4up+23B2KbB5rzHUALhnxsDcFEpR6m69bXcmusxiQSR/RIJ3FrbJpimvfPl5pfgwy9Pm9zfSsPX2w1vLh7LNrQAtqFlsB0Nc8Y07dnZ2Vi4cCEefPBBvP3227aujtM5eLQYG3af1Mne5yaV4G8zhmH8iP5mv0cMNWZ65uJQfRCKVXKzri0Ri/D/ZoVj/Ij+OHi0GClfnUZZpQLynjLETw0xen0i6lwON4JFZEzq/jNmBVecu05EltbRvR01fv/9dzz//POQy+V44YUX2lWXznjA15qu/MAgKSNfLzV6Q5MKSRn5CLv9ZhbilqNct04LFEGNxzxzcLfbeVxSys0KsDQbCofd3gOfH/xVZ5r7tUoF/rPrBKqq6x1iJKsrfz6sge2hzx7axNQDPgZY5PA0nVyNQtlq2W4yF8yeEOgQHRQR2Y+O7u0IAFevXsUTTzwBpVKJrVu3cu2VjRibYt7y+K
3rfHXjWQGPehzG3W7nsbcuHDkNQSavp5lR0VJa9lm9NcS23PuKiHQxwCKHZk4yCwCcw05EVtXRvR3Ly8vx+OOPo6ysDNu2bUNYWJgFa0dtYWzT4JaZZw0FQM0EPOxxBGPcf8U+xRDsrx/S6vU01zJnDTG3FyGyDwywyKEZ7+Ru8nSXGE2dS0RkaxUVFZg/fz6uXLmCzZs3Izw83NZVckqtBTjlVQ145p3vIAgCautVBsuIIcBXXI3s+mB8ofiTWdf19XZr08NCIrI9Bljk0Fp7micRAXMmmp6eQURkK9evX8fjjz+O4uJivPfee4iIiLB1lZySuQGOqanoEqigggTbasZDgAjm7GjlIhFpMwa2dm2uHyayHwywyGlxWiAR2bsFCxagqKgIsbGxKC0txWeffabz+sSJE+Hh4WGj2jkPYwGOoT2tDBnnVohIt1/xbvVE1AnuZl2z5ZrgzXsLTJZlf0ZkXxhgkcMxd68rTgskInuXn58PANizZw/27Nmj9/qBAwcYYHUCY/2JOcHVSNffMMPzJ+Q19kO9YN6embcmtjC17ot9GZH9YYBFDsXcaRxERF1BRxJjkOUYC3BEAEzFWH9yvYDZnrk43dQbH9ZEQQ2xWde79VpxUQF6fRunBBLZL/P+pRN1Adu/LsTmvQVmBVfdZHy2QERE5omLCoCri+5PJkkry6iCpX8g3vN7nFf6YWt1NFSQmH09sYHzukpvXt/TXYL5U4M5JZDITvFXJjmE7V8X4tvjl80q6yIRYfaEQCvXiIiIHIUmkNFMP/f1dkONohENTcbHr64oe+BE4wDsqhuNpjb+3Go59dDQzIwmpW03iiYi0xhgUZeXm1/SanClWYjMhcBERNQekWH+2r4jN7/EaOKJXuIbuKb2wg3BEym197brWq3tqcVNhYnsGwMs6tLMGblydRFzKgUREVlMWvZZg8f7ScrxN699+KEhEJ8rRpg8h6+3G4YG+CInr8Tk2ipT+24RkX1igEVdljkjVwAYXBERkUUZCm78JdfxV6/9qBNc8V1DcKvn0GT/G9SvB/YcOo9rlQqDsyxMZRAkIvvEAIu6rNT9Z1ot4yYVMbgiIiKLujXokYursNjrG6ggxsbqibiu9mz1/RqRYf54aPxgXLtWbbAsMwgSdT0MsKhLys0vQY1CabKMSATETwnppBoREZGzGBrgq51BIYIaiV7fQgI1/l01GWVqb5PvbWtwZCjBRstRrpZ7P3KdMZF9YIBFds9Q52Fs/ruGp7sEcyYGsZMhIiKLaNkXtUyjLkCMnbWj0Si44Kq6h977PN0lEIlEqFEo2x0AtUywcWudWo5ulVc1IPmrQu17iMg2GGCRXTPWeZja6yo6vA/mTW59/jsREZE5bu2L1ALgIapHkPQKjjcOxDnlbQbf5+vtpl1rZQ3MMEhknxhgkV0z1nlo0q7fytNdwuCKiIjazdisiZZ9kbuoEX/1OgB/yXWca+qFG4LhNVfWzvTHDINE9okBFtk1Y52EoeDK1UWMORODrFwjIiJyVObMmnBFExZ2y0JfSQW21EQbDa4AoJvMuj+zmGGQyD6JbV0BIlM83SVmlfP1dmM6diIi6hBTsyYAwAUqJHgdxECXa0ipuQcFTf1Mnk8QDDwNtKC4qAC4uuj+lGOGQSLb4wgW2TWRSNRqGU93iVXnuBMRkXNobdZEqPR3BLpcwY7asTjRdEer56utV1mwdvpayzBIRLbBAIvsWmup2AHrd2BEROQcjE250zjVNADrqqbhiqonAGjXAxtbF9wZU/WMZRgkItvhFEGya5xHTkRE1pabX4IVG3MMBlciCHjY42fc4VIKANrgCmgOqhKnhSIhJpRT9YhIiwEW2TVD88uJiIgsRZPYwvDIVXNwNd79NAJdSgy+X7Pv1PypwdqHgi3XBWuCtwWvZ2HFxhzk5hs+DxE5Dk4RJLt26/xyQ6ydpYmIiByXocQWGg/ITiDKvRDf1odgX/0Qg2U0+069uXis3lQ9bg
RM5Jz4y5TshqG9RzRzyzVPAT/88jSUqpsT3V0kIsyeEGjDWhMRUVdm7OHdBPc8TJblIad+MPbU3Q2JSASVkaSA5VUNZu2fBXAjYCJnwLlXZBdunaKhecrXcipFZJg/nnggRGcKxhMPhLCTIiKidjO01lcEAf0kFTjSMBC760YBEGFBTKjRdcGe7hKDfRg3AiZyThzBIptp+bTPEENP+ZgtiYiILCkuKkBnGp8YaqghRkrtPQAAAWL4ertp+x69jYddxBCJRGhU6ma01eyfZavsgkRkOxzBIpswvaj4Jj7lIyIia4oM88f8qcHoJnNBuOt5POedAS+RAmqIoYYYLhKRNhugpuytySyMbSmiFsDsgkROyOFGsA4ePIjU1FQUFRWhoqICMpkMAwYMwOzZszF9+nSIxTe/6FauXIn09HSj55oxYwZeffVV7d+DgoIMlhs4cCAyMzMtdxNOIHX/GaOLilviUz4iIrK2yDB/RHheQd03Obig9EO9IAXQnERp9oTAVmdSGJuN0XItFjcCJnIeDhdgnTlzBlKpFDNnzoRcLodCoUB2djZWrlyJEydOYM2aNdqyM2fORGRkpN450tPTkZubi/Hjx+u9dvfdd+PRRx/VOebl5WXx+3BkufklZm0gzKd8RETUGZS/50Ox/11I5HdgyIMr8L6rrE3vv3WaIXCzD+PUdiLn43AB1sKFC/WOxcfHY9GiRdi1axeWLVuG7t27AwDCw8MRHh6uU1atVuPtt9+Gj48PoqKi9M7Vv39/TJ8+3TqVdxJp2WdbLSMWQbuHCBERkSW1XAM81Ps6nnD9Ci49esNj6jKI2hhcAfpbinCkisi5OVyAZUzv3r2hVqtRU1OjDbAM+eGHH1BSUoL58+dDKpUaLNPY2AilUgkPDw9rVdehtbauytVFzOCKiIis4ta9qc5Xu+GXbn0gDZiLke7d2n1ejlQRkYbDBlg1NTVobGxEdXU1cnJykJaWhkGDBqFPnz4m35eWlgYAiIuLM/h6ZmYmPvvsM6jVavj5+WH69Ol4+umn4e7ubvF7cESt7WDPp35ERGRNmr2pfMTVuK72RLUgw9bqKPjmlmJk+GBbV4+IHIDDBlhLlizBoUOHAAAikQhjxozBmjVrIBKJjL6nuroa+/fvR1hYGIKDg/VeHzZsGKZMmYLbb78dVVVV+Oabb7BlyxacPHkSSUlJcHFx2Oa0GFPTAxOnhTKwIiIiqyqvaoCfuArPeGcir/F27KobrT1ORGQJdhsRVFVVITk52ayyHh4eSEhI0Dm2fPlyJCQkoLS0FFlZWaisrERdXZ3J83zxxRdoaGgwOnq1a9cunb/HxcVh7dq1SElJQUZGBmJjY82qLwD4+rZ/GkJb+PnZVwKOChMd2EPj7fPJob21YVfENrQMtiNRx93p3Yh40T6IISC7/ubDVGatJSJLsesAa8OGDWaVlcvlegFWSEiI9s+xsbFYtWoV5s2bh8zMTPj4+Bg8T3p6OqRSKWJiYsyu55NPPomUlBR8//33bQqwystroDa0+6AF+fl54dq1aqteo618vN2MprK1t7oC9tmGXQ3b0DLYjoaJxaJOe2BFXZ+67jqe7PYNVAolNlRNwlV1DwDMWktElmW3AVa/fv1QVFRksfPFxMRg586d2L9/v16adQA4d+4cTpw4gcmTJ6NHjx5mn9fX1xfu7u6orKy0WF0dSctMTb7ebhga4IucvBKDqWyJiIisobkv+g3zsAd9JFX4ufcs1KvcAWb8IyIrsNsAy9Lq6+sBADdu3DD4uia5xSOPPNKm8169ehX19fXw9fXtWAUd0K2ZmsqrGpCTV4KxQ/xx6mw5U9kSEZHVteyL0iQj4SpS4lK1BPOnsu8hIutwuACrrKwMcrlc55ggCEhNTQXQnKjiViqVCp999hn8/Pwwbtw4g+etrKxEz5499c67fv16AEB0dLQlqu9QNJmaWmpUqnHqbDneXDzWRrUiIiJnsje7EGHisziOgbik0vw+UC
Mt+ywDLCKyCocLsGJiYhAREYHQ0FDI5XKUlZUhMzMThYWFiImJwciRI/Xek5OTg9LSUiQmJkIikRg876ZNm3Dq1CmMGjUKffr0wY0bN5CVlYXjx48jOjoaU6ZMsfatdTnGMjIxUxMREXUGQdmIR4RMBHhexR9KH5Sqb+6Dyb6IiKzF4QKsuXPn4tChQ0hOTkZ1dTU8PDwQFBSEtWvXGs0OmJ6eDgB4+OGHjZ535MiR+O2335CWlobKykpIpVIMHDgQf//73/HYY49BLBZb5X66mpZrrsQiwFAeD2ZqIiIiaxNUSij2v4tAaQm214zVCa4A9kVEZD0iQRCsm8qODHLELIK3rrkyxNVFjPlTg7vMtAxmbus4tqFlsB0NYxbBtuuM/qc11v48C2o16rPeg/LcT/jjzlisP9FDL7mSPfVF/Peti+2hi+2hzx7axFT/w2EXspjU/WcMBlfi/9vb2dfbza46NCIickyqy6ehPPcT3EbPRPCEWMyfGqwdsWJfRETW5nBTBMk23kw9hhqF0uBragHYtvK+Tq4RERE5K5d+YfCIWwOJfAAAIDLMnwEVEXUaBljUYW+mHsPpi9eNvt5Nxo8ZERFZj2afq4imn1Dq2h8j7r0HkWEDbF0tInJS/OVLHZKbX2IyuAKa09kTERFZg2b9b5TLCUz1OIWD9U1I/soPADhqRUQ2wTVY1G65+SXYmlHQarnaelUn1IaIiJxRWvZZREp+QYzHCfzUcCf21N3dvKlw9llbV42InBRHsKhdtn9diG+PXzarLFPhEhGRtQQ25CHO8whONN6O1NoxENCcWYn7XBGRrXAEi9osN7/E7ODK1UWMuKgAK9eIiIickSAICJZVoKCxL1Jq7oG6xc8aPtwjIlthgEVtkptfgs17W58WCDAVLhERWY+gVkMkEsFl3OPYXh8NFSTa1/hwj4hsiVMEyWzmTAsUi4CEmFAGVUREZDXK339BQ24qZFOXIfKuPoBIjLTssyivaoCvtxviogLYDxGRzTDAIrOYOy2QwRUREVmT8koRFF//G+Ie/hBJ3QFwnysisi+cIkitys0vwRYzpgW6uojYwRERkdWoSs9Bkfm/EHv5QvbAcojcPG1dJSIiPRzBIpPaki1w/tQQK9eGiIiclariD9R99S+I3L0ge/A5iGXetq4SEZFBDLDIqLZkC4wO78PRKyIishqRhzdc/APhFjkHYs+etq4OEZFRDLBIT25+iXaxsDmiw/tg3uRgK9eKiIickbq2EiJ3L4jdvSCbvMTW1SEiahUDLNKRm1+CLRkFEATzyidOY1ILIiKyDnVtJeo+/ycktwVAdt+Ttq4OEZFZmOSCdKRknjY7uPJ0lzC4IiIiq1ArqqD44k0I9dVwvWuSratDRGQ2Bliko6HJzOgKQG29yoo1ISIiZyU01ELx5VtQV5dBNmUpJL3utHWViIjMxgCL2s3X283WVSAiIgekyHof6so/IJv0NFx6B9m6OkREbcI1WITc/BLs+KaoTSNSri5ixEUFWLFWRETkrNwiHoEQEg2X/kNsXRUiojbjCJaTy80vwbaMArOCK093CYDmkav5U4O5/oqIiCxGUCnRdPYnAIBEPgAud4TbuEZERO3DESwnt+ObIqjMWHbFVOxERGQtglqF+qz3oDx/BGIvOddcEVGXxgDLiW3/urDVkStfbzfERQVwtIqIiKxCENSoz94G5fkjcIuczeCKiLo8ThF0Urn5Jfj2+OVWy725eCyDKyIiO7F7924EBQUhKCgItbW1tq5OhwmCgIacj6D8NQeud8fBdchkW1eJiKjDGGA5odz8EmzNKGi1nGbNFRER2V5FRQXeeusteHh42LoqFqO+dg5NBd/CddgDcA2fZuvqEBFZBKcIOpk3U4/h9MXrrZaTiIA5E5kal4jIXrz22mu47bbbEBQUhM8//9zW1bEISa8AeMT+D8R+AyESiWxdHSIii+AIlhPZ/nWhWcFVN5kLFsSEcmogEZGdyM3Nxd
69e/Hyyy9DIun6swtu/JQBZfEpAICk150MrojIoXAEy0mYu+aK2QKJiOxLY2MjVq9ejYcffhgjRozA7t27bV2lDmksyEL1oRS4DIqES/+hFjlnbn4J0rLPoryqgcmZiMjmGGA5Ac1eV61JnMZRKyIie/Pee++hsrISK1assMj5fH27WeQ87VGddxDVh7bDY9AI3PbnJRBJpB0+58GjxUjJLEJDU3NW3PKqBqRkFsHbyx3jR/Tv8Pk7i5+fl62rYFfYHrrYHvrsuU0YYDmBtOyzZu11xeCKiMg6qqqqkJycbFZZDw8PJCQkAADOnTuHDz74AC+++CJ8fHwsUpfy8hqo1WZ0ChbWdO5n1B/YCEmfYPSKexZlFfUA6jt83qSMfG1wpdHQpEJSRj7Cbu/R4fN3Bj8/L1y7Vm3ratgNtocutoc+e2gTsVhk9IEVAywnUF7V0GqZ6PA+nVATIiLnVFVVhQ0bNphVVi6XawOs1atXIzAwELNmzbJm9TqF6nIhxL0CIJu8BGKpG4BGi5zXWB9nTt9HRGQNDhdgHTx4EKmpqSgqKkJFRQVkMhkGDBiA2bNnY/r06RCLdfN6XLlyBZs2bcIPP/yA0tJS9OzZE8OGDcOiRYsQFhamd/60tDQkJSXh/Pnz6N69OyZOnIilS5fC29u7s27RbJs+PYHMHy+2Wo7rroiIrKtfv34oKipq03v27duHw4cP41//+heKi4u1xzX7XxUXF8PLywt9+/a1aF0tTRDUEInEcBs7F1A2QiR1s+j5fb3dDAZTvt6WvQ4RkbkcLsA6c+YMpFIpZs6cCblcDoVCgezsbKxcuRInTpzAmjVrtGXLy8vxyCOPQKlUYtasWejfvz+uXLmCTz75BN9++y0++eQTnSArKSkJr732Gu655x7MnTsXly5dQnJyMvLy8rBjxw64urra4pYNMicdu4tEhCceCOHUQCIiO3T5cnNiomeffdbg69OnT0ffvn2RlZXVmdVqE9XV31D/XRJkk5+B2LsXYOHgCgDiogKQ/FUhGpVq7TFXFzHiogIsfi0iInM4XIC1cOFCvWPx8fFYtGgRdu3ahWXLlqF79+4AgL1796K8vBwbN27E/fffry0/duxYzJkzB3v27NEGWBUVFVi/fj3GjRuHzZs3a1PKDho0CM8//zw+/fRTzJkzpxPusHXmpGPvJnPB7AmBDK6IiOxUdHQ0/P31v6M//vhj/PTTT1i3bh169LDfNUaqsouo++ptiNy7ARZIZmGMph9jFkEishcOF2AZ07t3b6jVatTU1GgDrOrq5sVxfn5+OmV79eoFAJDJZNpjBw4cgEKhQHx8vM5+HdOmTcO6deuQkZFhFwHW9q8LW03Hvm3lfZ1UGyIiaq8BAwZgwIABescPHjwIAJgwYQI8PT07uVbmUVVehuLLtyCSusPjwecg9uxp1etFhvkzoCIiu+GwGw3X1NSgoqICFy9exI4dO5CWloZBgwahT5+byRzGjBkDAHjllVdw5MgRXL16FcePH8eLL74IuVyOmTNnasvm5eUBAMLDw3WuI5FIMHToUBQUFEAQOj8rU0svbc5tNbgScy9HIiKyInXVNSi+WAeIRM3BlZfc1lUiIupUDjuCtWTJEhw6dAgAIBKJMGbMGKxZs0Zn9GnEiBFYvXo11q9fj8cee0x7PDg4GLt27dJZOFxaWgqZTGYwmYW/vz8UCgVu3Lhh9nQNS+9DMuvFDNQ2qFotN2X0ALveN8Aesb06jm1oGWxHAoDXX38dr7/+uq2rYZTIzQNi+QC4jfwzxD04qkREzsduA6z27hmisXz5ciQkJKC0tBRZWVmorKxEXV2d3nvlcjkGDx6MMWPGICgoCMXFxdiyZQsSEhKQkpKinS6oUCiMJrFwc2tetFtfb/5+Hpbch2TZf743K7gKGdADf44KsPm+AV2JPeyz0NWxDS2D7WiYqX1IqHOp66shcnGDyM0THlOW2ro6REQ2Y9cBVnv2DNEICQnR/jk2Nh
arVq3CvHnzkJmZqd2scd++fXj66aexdetWjBs3Tlt+7NixiI2NxTvvvIO1a9cCaF6P1dhoeM+Ohobm9LDu7u7m36CFvLQ5F9drm1otx1TsRERkLUJ9DRQZ6yDq5sPgioicnt0GWO3ZM8SUmJgY7Ny5E/v378ejjz4KAEhJSYGnp6dOcAUAgwcPxp133omff/5Ze6xXr15QKBSoqqrSmyZYUlICmUymTZ7RWbZ/XYjL5YpWyzG4IiIiaxEaFaj76m2or1+BbPTM1t9AROTgHDbJxa000/du3LihPXbt2jUIgmAwOYVSqYRSqdT+fciQIQCA48eP65RTq9XIy8tDSEiIzvoua3sz9VirCS0AQAQwuCIiIqsQlA1QfL0e6rILcJ+wGC797rJ1lYiIbM7hAqyysjK9Y4IgIDU1FQAwbNgw7fFBgwahrq4OX3/9tU75kydP4sKFC7jrrpsdxf333w93d3ekpKTolP38889RVlaGmJgYS96GSQtez2p1nyuNv0wLtXJtiIjIWdV/lwTVlTNwj14I6R3DbV0dIiK7YLdTBNsrJiYGERERCA0NhVwuR1lZGTIzM1FYWIiYmBiMHDlSW3bRokX47rvvsHz5cvz8888IDAxEcXExduzYATc3NyxevFhb1sfHB8888wzWrVuHxMRETJo0CZcuXUJSUhLCwsIwY8aMTrm/Ba9nmV122CBf7gtCRERW4/qnGLj0HwLpoNG2rgoRkd0QCbbevMnCNmzYgEOHDuHChQuorq6Gh4cHgoKCEBsbi7i4OIjFuoN2v/32GzZu3IiTJ0/i6tWr6NatGyIiIvDUU08hOFh/at2nn36K5ORkXLhwAd7e3pgwYQKWLVvW5vVX7cki2JbgKmRAD6x7JopZxzqImds6jm1oGWxHw5hFsO06msVWENRQnj8Kl4F3t3tqPD/Putgeutgeutge+uyhTUz1Pw4XYHUVbe3g2hJcyVzFeHfZeLv48HV1bMOOYxtaBtvRMAZYbdeRAEsQBDQcSkbT6YOQPbC83Wuu+HnWxfbQxfbQxfbQZw9tYqr/cbg1WM5OE1wRERFZkiAIaPjxEzSdPgjXPz0ISd8wW1eJiMguMcByIFIJGFwREZFVNB7dg6a8ryENmwDXiD93auZcIqKuhAGWA3l/xX22rgIRETkgVeVlNB7fC2nQPXAbM4fBFRGRCQ6XRdBZbVvJ4IqIiKxD0rMPPB56EWK/OyES8dksEZEp/JZ0AAyuiIjIGprO5KDp/FEAgOS2QRCJ+bOBiKg1/KbsIhKNbBhs7DgREVFHNJ37GfXZW9BUeBBMOExEZD5OEewiNBsGp2WfRXlVA3y93RAXFcCNhImIyOKUl06g/sB7kPQaBNmEv3HNFRFRGzDA6kIiw/wZUBERkVUp/yiA4psNEPv2g2zqUoikbrauEhFRl8IpgkRERKSl+v0XiL1vg+yB5RC5eti6OkREXQ5HsIiIiAiCIEAkEsF15Ay4hk+DyFVm6yoREXVJHMEiIiJycqrKy6hLWwVV5R8QiUQMroiIOoAjWERERE5MXVUKxRfrAEGASCyxdXWIiLo8jmARERE5KXVNOeoy3gBUSsgeXAFxdyZSIiLqKAZYRERETkitqELdF29CaKiD7IHlkPj0s3WViIgcAqcIEhEROSGRRApx99vgGpUAid8dtq4OEZHDYIBFRETkRISmeghqASJXGTymLLV1dYiIHA6nCBIRETmR+uytUGT+LwRBbeuqEBE5JAZYRERETkR97QKkofdBJOJPACIia+AUQRsRi0UOdR1HxjbsOLahZbAd9bFN2k42bh7E/Yfauhr8/+4WbA9dbA9dbA99tm4TU9cXCYIgdGJdiIiIiIiIHBbnBxAREREREVkIAywiIiIiIiILYYBFRERERERkIQywiIiIiIiILIQBFhERERERkYUwwCIiIiIiIrIQBlhERG8T3nwAABRISURBVEREREQWwgCLiIiIiIjIQhhgERERERERWQgDLC
IiIiIiIgtxsXUFqP0OHjyI1NRUFBUVoaKiAjKZDAMGDMDs2bMxffp0iMW68fOVK1ewadMm/PDDDygtLUXPnj0xbNgwLFq0CGFhYXrnT0tLQ1JSEs6fP4/u3btj4sSJWLp0Kby9vTvrFq2uLW24cuVKpKenGz3XjBkz8Oqrr+ocYxvqfw4BoLi4GBs3bsShQ4dQWVmJnj17YujQoVizZg3kcrlOWWdoQ6Bt7Xj48GHEx8cbPM/MmTPxj3/8Q+eYWq1GUlISdu7ciT/++AN+fn546KGHsHjxYri5uVn1vojaavfu3XjppZcAAMeOHYOnp6eNa9S5Dhw4gH379uHEiRMoKSlB9+7dERwcjCeffBLDhw+3dfWsht9TN+Xn52Pv3r348ccf8fvvv0MikeCOO+7AnDlz8NBDD0EkEtm6ijZ39uxZTJ8+HU1NTXjvvfcQHR1t6yrpYIDVhZ05cwZSqRQzZ86EXC6HQqFAdnY2Vq5ciRMnTmDNmjXasuXl5XjkkUegVCoxa9Ys9O/fH1euXMEnn3yCb7/9Fp988olOkJWUlITXXnsN99xzD+bOnYtLly4hOTkZeXl52LFjB1xdXW1xyxbXljacOXMmIiMj9c6Rnp6O3NxcjB8/Xuc421C/DQHg5MmTWLBgAXr37o3HHnsMcrkcFRUVOH78OGpqanQCLGdpQ6Dt7Qg0fyZHjBihc2zgwIF65f75z39i+/btiImJwV/+8hcUFBTggw8+wK+//oqNGzda7Z6I2qqiogJvvfUWPDw8UFdXZ+vq2MSqVavg5eWFSZMm4fbbb0dZWRl27tyJ2bNn44033kBsbKytq2gV/J66acuWLcjNzcWkSZMwa9YsNDQ04KuvvsJzzz2Hw4cP45///Ketq2hTgiBg1apVkEqlaGpqsnV1DBPI4SxcuFAIDg4Wrl+/rj324YcfCoGBgcL+/ft1yh45ckQIDAwUXn31Ve2x8vJyYdiwYcKCBQsEtVqtPZ6eni4EBgYKH3/8sfVvwsYMtaEhKpVKuPfee4XRo0cLjY2N2uNsQ8NtqFAohOjoaCEhIUGnvQxhGzYz1I4//vijEBgYKPz3v/9t9f1nzpwRgoKChJdeeknn+IYNG4TAwEDh4MGDFq8zUXstX75cmDZtmrB8+XIhMDBQqKmpsXWVOl1ubq7esbKyMmHUqFFCZGSkoFKpbFAr6+L3lK6jR48KDQ0NOsdUKpUwd+5cITAwUCgqKrJRzezDrl27hGHDhgn/+c9/hMDAQCErK8vWVdLDNVgOqHfv3lCr1aipqdEeq66uBgD4+fnplO3VqxcAQCaTaY8dOHAACoUC8fHxOsPQ06ZNg6+vLzIyMqxZfbtgqA0N+eGHH1BSUoJp06ZBKpVqj7MNDbfhF198gT/++AMrVqyAVCqFQqEw+vSJbdistc9iXV0dGhsbjb4/IyMDgiDg8ccf1zk+b948uLi4OE07kv3Lzc3F3r178fLLL0Mikdi6OjYzevRovWO+vr6IiIhAeXk5ysvLbVAr6+L3lK7hw4frzdAQi8WYNGkSAODXX3+1RbXsgmaU+8knn0SfPn1sXR2jGGA5gJqaGlRUVODixYvYsWMH0tLSMGjQIJ0P3pgxYwAAr7zyCo4cOYKrV6/i+PHjePHFFyGXyzFz5kxt2by8PABAeHi4znUkEgmGDh2KgoICCILQCXfWecxpQ0PS0tIAAHFxcTrH2YaG2/D7779Ht27dUFVVhenTp+NPf/oThg4dijlz5uDUqVM653PGNgTa9ll89dVXER4ejiFDhuCBBx7A7t279cr88ssv8PLyQkBAgM5xb29v3Hnnnfjll1+sdi9E5mpsbMTq1avx8MMP6017pWYlJSWQSqXw8vKydVUsjt9T5ikpKQEA+Pj42LgmtvPGG2+gR48eWLBgga2rYhLXYDmAJUuW4NChQwAAkUiEMWPGYM2aNTpP/UeMGIHVq1dj/fr1eOyxx7THg4ODsWvXLvTt21d7rLS0FDKZzG
ASAX9/fygUCty4cQM9evSw4l11LnPa8FbV1dXYv38/wsLCEBwcrPMa29BwG164cAEqlQqJiYmYMmUKFi9ejD/++AObNm1CfHw8du/ejcGDBwNwzjYEzGtHFxcX3HfffYiKikKvXr206ylfeukl/P7771i6dKm2bGlpKW677TaD1/L398fRo0ete0NEZnjvvfdQWVmJFStW2Loqdik7OxunTp3Cgw8+CHd3d1tXx+L4PdW60tJS7e81Z30I8eOPP2LPnj3Ytm2b3a/BZoBlB6qqqpCcnGxWWQ8PDyQkJOgcW758ORISElBaWoqsrCxUVlYaXBwsl8sxePBgjBkzBkFBQSguLsaWLVuQkJCAlJQU7XRBhUJh9IOryeRTX1/fllu0us5qw5a++OILNDQ06I1eAWxDY21YW1sLhUKBadOm4fXXX9ceDwsLQ3x8PN59912sX78eQNdsQ6Bz2nHEiBF6Heyjjz6KWbNmYfPmzfjzn/+M/v37A2huR2NPvN3c3OyyDalrau9n/9y5c/jggw/w4osvOtST+Y5+F2j8/vvveP755yGXy/HCCy9Ysop2g99TpjU2NmLJkiWoqanBv//9b7sPLqyhsbERL7/8MqZOnYqxY8faujqtYoBlB6qqqrBhwwazysrlcr0v4ZCQEO2fY2NjsWrVKsybNw+ZmZnazmrfvn14+umnsXXrVowbN05bfuzYsYiNjcU777yDtWvXAmhej2VsTUdDQwMA2N0TtM5ow1ulp6dDKpUiJiZG7zW2oeE21NzzrUHpqFGj0KdPH/z000/aY12xDQHbfBYBQCqVIiEhAUuXLsUPP/ygnfbbWjvaYxtS19Tez/7q1asRGBiIWbNmWbN6na6j3wUAcPXqVTzxxBNQKpXYunWr3jpqR8HvKeOUSiWWLFmC48eP45VXXjGYzdgZfPDBBygtLUVKSoqtq2IWBlh2oF+/figqKrLY+WJiYrBz507s378fjz76KAAgJSUFnp6eOsEVAAwePBh33nknfv75Z+2xXr16QaFQoKqqSm96VklJCWQyGbp3726x+lpCZ7RhS+fOncOJEycwefJkg1PU2IaG27BXr144c+aM3l5XQHMCloKCAu3fu2IbAp3/Wbz12gBQWVmpPdarVy+cPHnSYPmSkhKj03KI2qo9n/19+/bh8OHD+Ne//oXi4mLt8draWgDNe+Z5eXnpTGPvKjr6XVBeXo7HH38cZWVl2LZtm8H9Kh0Fv6cMU6lUePbZZ5GVlYWXXnoJM2bMsHWVbKK0tBTvv/8+Zs2ahfr6ely8eBEAtAlfrl27hosXL6Jv375wcbGP0IZJLhyQZij9xo0b2mPXrl2DIAgGkwIolUoolUrt34cMGQIAOH78uE45tVqNvLw8hISEOPwmd4basCVNcotHHnnE4OtsQ8NtOHToUAA3F+q2VFJSojNCwzZs1tpnsSVNp+Pr66s9dtddd6G6uhpnz57VKVtVVYVz58459I82sn+XL18GADz77LOYNGmS9r99+/YBAKZPn4558+bZsoo2UVFRgfnz5+PKlSv44IMP9JL9OBp+T+lTq9V47rnnkJmZieeff94p/x1olJeXo7GxESkpKTrfE2+99RYA4H/+538wadIkg78tbIUBVhdWVlamd0wQBKSmpgIAhg0bpj0+aNAg1NXV4euvv9Ypf/LkSVy4cAF33XWX9tj9998Pd3d3vWHYzz//HGVlZQanxHVVbWlDDZVKhc8++wx+fn56I4IabEPDbRgTEwOxWIxPPvlEp3xWVhauXr2Ke++9V3vMmdoQaFs7thyh0qirq8P7778PqVSq87l84IEHIBKJ9NaCbN++HUqlEtOmTbPULRC1WXR0NN555x29/0aOHAkAWLduHV5++WUb17JzXb9+HY8//jiKi4uxadMmRERE2LpKVsfvKV1qtRovvPACMjIysGzZMrvPmGdt/fr1M/g9oUnatnDhQrzzzjs6DxdtzT7G0ahdYmJiEBERgdDQUM
jlcpSVlSEzMxOFhYWIiYnRdlAAsGjRInz33XdYvnw5fv75ZwQGBqK4uBg7duyAm5sbFi9erC3r4+ODZ555BuvWrUNiYiImTZqES5cuISkpCWFhYQ41RN2WNtTIyclBaWkpEhMTje7VwjY03IYBAQF44oknsHXrViQmJmL8+PG4fPkyPvroI8jlcvztb3/TlnWmNgTa1o6JiYm47bbbEBoaqs0imJ6ejsuXL2PFihXo3bu3tmxQUBDmzJmDjz/+GHV1dRg1ahROnz6N1NRUREdHIyoqyha3SwQAGDBgAAYMGKB3/ODBgwCACRMmwNPTs5NrZVsLFixAUVERYmNjUVpais8++0zn9YkTJ8LDw8NGtbMOfk/pWrduHfbs2YMhQ4bA399f7zMwfPhwbSIjZ+Dl5YUpU6boHdckgBo+fDiio6M7u1omiQRH3EjGSWzYsAGHDh3ChQsXUF1dDQ8PDwQFBSE2NhZxcXEQi3UHKH/77Tds3LgRJ0+exNWrV9GtWzdERETgqaee0kszDgCffvopkpOTceHCBXh7e2PChAlYtmyZXa57aa+2tiEALF26FF9++SW+/PJLvT07bsU21G9DQRCwY8cOpKam4sKFC9q1gcuWLTO4zsIZ2hBoWzt+8MEHOHDgAC5evIjq6mp4enrirrvuQnx8PMaPH693bpVKhQ8//BC7du3C5cuXIZfL8dBDD+Gpp57SZmQksicrV65Eeno6jh075nQBVlBQkMnXDxw4oF1v6Uj4PXXTvHnzdJI+3eq1114zmMHY2aSlpeGFF17Ae++9xwCLiIiIiIjIUXENFhERERERkYUwwCIiIiIiIrIQBlhEREREREQWwgCLiIiIiIjIQhhgERERERERWQgDLCIiIiIiIgthgEVERERERGQhDLCIiIiIiIgshAEWERERERGRhbjYugJEZDkLFixATk6OyTLPPPMMnnrqqU6qEREROYPW+p/Y2Fi88cYbnVgjItthgEXkQPLz8+Hi4oInn3zSaJnJkyd3Yo2IiMgZtNb/REZGdnKNiGyHARaRgyguLsb169cRGhqKp59+2tbVISIiJ8H+h0gX12AROYi8vDwAwJAhQ2xcEyIicibsf4h0McAichC//PILAHZwRETUudj/EOniFEEiB6Hp4I4dO4aSkhKDZRISEuDh4dGZ1SIiIgdnqv+RyWT4y1/+YotqEdmMSBAEwdaVIKKOEQQBERERqK6uNlqmR48eOHz4sPbvb731FvLz8/Hhhx+afZ3//d//xbFjx7B9+/Z217U91yUiIvvUWv8zcuRInT6jrX2AJfqd9lyXqCM4gkXkAC5cuIDq6moMHz4cqampZr3n9OnTCA4ObtN1Tp8+jZCQkPZUsUPXJSIi+6Tpf0aMGIEdO3a0Wr6tfYAl+p32XJeoI7gGi8gBaKZntKUTKiwsRGhoaJuuc/r06Ta/xxLXJSIi+6Tpf8z9Xm9rH2CJfqc91yXqCAZYRA5A08GZ+3Tu2rVrKCsr0wnINm7ciGnTpiE8PByjR4/GypUrUV9fr329vLwcpaWlEIvFmD9/PoYNG4bp06fj1KlT2jJXr17Fc889h1GjRuHuu+/G008/jbKyMpPXJSKirkvT/4SFhbVatq19jzn9DsC+h+wPAywiB9DWAOv06dNwd3fHwIEDtcdUKhVWr16NjIwMvP3228jJyUFycrLOewDgww8/xFNPPYX09HT4+/tjyZIlUCqVKC4uxsMPP4zbbrsNO3bswPbt21FZWYmXX37Z5HWJiKjrassIVlv7ntb6HQDse8gucQ0WURenVqtRUFAAiUSCwMBAs95TWFiIwMBASCQS7bGWm0P27dsX48ePx/9v7/5dUovDOI5/OmhSikEENWRCDU1BQ3NNTVJIHGoTWiPoB2hrDU0RBLVJLW0N1R/QEtUQNNjSoUU6QkgQOBTnUKHe4dIBb91MEK7e3q9N+Z7v4/bxOT5+Tzab9d6zLEt+v1/b29vq7e2VJCWTScViMe
VyOa2vr8s0TS0vL3vXzM3NaX5+/su6AIDm9J4/gUBAAwMDVdfXmj3Vcqe/v1+rq6tkDxoODRbQ5LLZrBzHUSgUUjqd/uu66elpdXd3S/r4p+F8Pq/d3V1dXl7q4eFBb29ven19rTha17IsjY+PeyEnSeFwWJL0+Pio8/NzXV1dVZz0VCwW1dbWVrEHIxoA8H94z5+hoSH5fNW/UtaaPV/lTqlU0v39PdmDhkSDBTS59/GM5+dn7ezsfLrGMAzNzs56ry3LUiKRkCQVCgWZpqmRkRGlUin19PTIMAyZplkxcmhZlqampir2vb6+Vnt7u56enhQKhXR4ePihtt/v/7QuAKC51XrARa3Z81Xu9PX16ezsjOxBQ6LBAppcPB5XPB7/9nrXdWXbtheIp6enenl50dbWllpaWiRJR0dHchzHu+Pnuq7u7u5UKpW8fcrlsvb29jQ5OSmfzyfXddXV1aVgMPitugCA5lZL/tSaPdVyp7W1lexBw+KQC+CHub29lSQNDg5K+v0AYsdxdHJyItu2tb+/r83NTQWDQUWjUe8awzB0fHysTCYj27aVTCaVz+e1sLCg4eFhhcNhpVIp3dzcKJfL6eLiQmtra144/lkXAPBz1Jo91XJHEtmDhsUvWMAPY1mWotGoN58+NjammZkZraysKBAIKBaLaWJiQplMxruraFmWIpGIlpaWtLi4qEKhoNHRUR0cHKizs1OSlE6ntbGxoUQioWKxqEgkolgsJsMwPq0LAPg5as2e7+ROR0cH2YOG1FIul8v/+kMAAAAAwP+AEUEAAAAAqBMaLAAAAACoExosAAAAAKgTGiwAAAAAqBMaLAAAAACoExosAAAAAKgTGiwAAAAAqBMaLAAAAACok18jvMQTp4BDhgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 864x432 with 2 Axes>"
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CEHy6mb-GtHp"
},
"source": [
"We see that the model prediction for the energy is extremely accurate and the force prediction is reasonable. To make this a bit more quantitative, we can compute the RMSE of the energy and convert it to meV / atom."
]
},
{
"cell_type": "code",
"metadata": {
"id": "JMTkHZ9uGrES",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "10e02429-fbf7-4b2b-9525-272e91c64b9b"
},
"source": [
"rmse = energy_loss(params, test_positions, test_energies) * 1000 / 64\n",
"print('RMSE Error of {:.02f} meV / atom'.format(rmse))"
],
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"RMSE Error of 7.98 meV / atom\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kw14TB9YHovL"
},
"source": [
"We see that we get an error of about $8$ meV / atom in this run (the exact value depends on training), which is comparable to previous work on this system.\n",
"\n",
"Now that we have a well-performing neural network, we can see how easily it can be used to run a simulation of silicon. We will run a constant-temperature (NVT) simulation using a Nosé-Hoover thermostat. First, we \"bake\" the trained params into the energy function using partial evaluation."
]
},
{
"cell_type": "code",
"metadata": {
"id": "-5pfEq6YHFxd"
},
"source": [
"E_fn = partial(energy_fn, params)"
],
"execution_count": 23,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "J92Mr9OvI5ME"
},
"source": [
"Then, we set up the parameters of the simulation (time step, temperature, and the mass of a silicon atom) and create the simulation environment."
]
},
{
"cell_type": "code",
"metadata": {
"id": "zLQjf50qI2Ed"
},
"source": [
"K_B = 8.617e-5\n",
"dt = 1e-3\n",
"kT = K_B * 300 \n",
"Si_mass = 2.91086E-3\n",
"\n",
"init_fn, apply_fn = simulate.nvt_nose_hoover(E_fn, shift, dt, kT)\n",
"\n",
"apply_fn = jit(apply_fn)"
],
"execution_count": 24,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "uzt4iy2JJACQ"
},
"source": [
"Finally, we run the simulation for 10000 steps, recording the positions every 25 steps and printing the energy and temperature every 1000 steps."
]
},
{
"cell_type": "code",
"metadata": {
"id": "pXpsLlnXI9K5",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e7485d23-8d23-4236-dc86-226c7e52c9c7"
},
"source": [
"# Define the simulation.\n",
"total_steps = 10000\n",
"steps_per_recording = 25\n",
"total_records = total_steps // steps_per_recording\n",
"\n",
"positions = []\n",
"\n",
"@jit\n",
"def sim(state, nbrs):\n",
" def step(i, state_nbrs):\n",
" state, nbrs = state_nbrs\n",
" nbrs = nbrs.update(state.position)\n",
" return apply_fn(state, neighbor=nbrs), nbrs\n",
" return lax.fori_loop(0, steps_per_recording, step, (state, nbrs))\n",
"\n",
"\n",
"# Initialize the simulation\n",
"\n",
"nbrs = neighbor_fn(test_positions[0])\n",
"state = init_fn(key, test_positions[0], Si_mass, neighbor=nbrs)\n",
"\n",
"\n",
"# Run the simulation.\n",
"\n",
"print('Energy (eV)\\tTemperature (K)')\n",
"for i in range(total_records):\n",
" state, nbrs = sim(state, nbrs)\n",
"\n",
" positions += [state.position]\n",
"\n",
" if i % 40 == 0:\n",
" print('{:.02f}\\t\\t\\t{:.02f}'.format(\n",
" E_fn(state.position, neighbor=nbrs),\n",
" quantity.temperature(momentum=state.momentum, mass=Si_mass) / K_B))\n",
"\n",
"positions = np.stack(positions)"
],
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:absl:Using a depricated code path to create / update neighbor lists. It will be removed in a later version of JAX MD. Using `neighbor_fn.allocate` and `neighbor_fn.update` is preferred.\n",
"WARNING:absl:Using a depricated code path to create / update neighbor lists. It will be removed in a later version of JAX MD. Using `neighbor_fn.allocate` and `neighbor_fn.update` is preferred.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Energy (eV)\tTemperature (K)\n",
"-375.17\t\t\t257.33\n",
"-378.84\t\t\t333.09\n",
"-378.59\t\t\t244.51\n",
"-378.16\t\t\t248.73\n",
"-378.60\t\t\t342.88\n",
"-378.53\t\t\t269.34\n",
"-378.91\t\t\t297.15\n",
"-378.36\t\t\t293.20\n",
"-378.28\t\t\t304.22\n",
"-378.61\t\t\t288.59\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XbpisZHBKIOP"
},
"source": [
"We see that the energy of the simulation is reasonable and the temperature fluctuates around the 300 K target set by the thermostat. Of course, if we were validating this model for use in a research setting, there are many measurements that one would like to perform to check its fidelity.\n",
"\n",
"We can now draw the simulation to see what is happening."
]
},
{
"cell_type": "code",
"metadata": {
"id": "WYxLhxrUjcB7",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 826
},
"outputId": "7f32b5ea-7623-4f77-d496-980155cbceb6"
},
"source": [
"from jax_md.colab_tools import renderer\n",
"\n",
"nbrs = neighbor_fn(state.position)\n",
"\n",
"renderer.render(box_size,\n",
" {\n",
" 'atom': renderer.Sphere(positions),\n",
" 'bonds': renderer.Bond('atom', nbrs.idx),\n",
" },\n",
" resolution=[512, 512])"
],
"execution_count": 26,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:absl:Using a depricated code path to create / update neighbor lists. It will be removed in a later version of JAX MD. Using `neighbor_fn.allocate` and `neighbor_fn.update` is preferred.\n"
]
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<!--\n",
" Copyright 2020 Google LLC\n",
" Licensed under the Apache License, Version 2.0 (the \"License\");\n",
" you may not use this file except in compliance with the License.\n",
" You may obtain a copy of the License at\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
" Unless required by applicable law or agreed to in writing, software\n",
" distributed under the License is distributed on an \"AS IS\" BASIS,\n",
" WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
" See the License for the specific language governing permissions and\n",
" limitations under the License.\n",
"-->\n",
"\n",
"<!--\n",
" A fragment of HTML and Javascript that describes a visualization tool.\n",
" \n",
" This code is expected to be injected into Jupyter or Colaboratory notebooks using the `IPython.display.HTML` function. The tool is rendered using WebGL2.\n",
"-->\n",
"\n",
"<div id='seek'>\n",
" <button type='button'\n",
" id='pause_play' \n",
" style='width:40px; vertical-align:middle;' \n",
" onclick=\"toggle_play()\"> || \n",
" </button>\n",
" <input type=\"range\" \n",
" min=\"0\"\n",
" max=\"1\"\n",
" value=\"0\"\n",
" style=\"width:512px; vertical-align:middle;\"\n",
" class=\"slider\"\n",
" id=\"frame_range\"\n",
" oninput='change_frame(this.value)'>\n",
"</div>\n",
"<canvas id=\"canvas\"></canvas>\n",
"<div id='info'> </div>\n",
"<div id='error' style=\"color:red\"> </div>\n",
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/gl-matrix/2.8.1/gl-matrix-min.js\"></script>\n",
"\n",
"<script>\n",
" var DIMENSION;\n",
"\n",
" var SIZE;\n",
"\n",
" var SHAPE = {};\n",
"\n",
" var GEOMETRY = {};\n",
"\n",
" var CURRENT_FRAME = 0;\n",
" var FRAME_COUNT = 0;\n",
"\n",
" var BOX_SIZE;\n",
" var READ_BUFFER_SIZE = null;\n",
" var IS_LOADED = false;\n",
" var SIMULATION_IDX = 0;\n",
"\n",
" // Info\n",
"\n",
" var INFO = document.getElementById('info');\n",
" var ERROR = document.getElementById('error');\n",
"\n",
" // Graphics\n",
"\n",
" var GL;\n",
" var SHADER;\n",
" var BACKGROUND_COLOR = [0.2, 0.2, 0.2];\n",
"\n",
" // 3D Camera\n",
"\n",
" var EYE = mat4.create();\n",
" var PERSPECTIVE = mat4.create();\n",
" var LOOK_AT = mat4.create()\n",
" var YAW = 0.0;\n",
" var PITCH = 0.0;\n",
" var CAMERA_POSITION = mat4.create();\n",
" var Y_ROTATION_MATRIX = mat4.create();\n",
" var X_ROTATION_MATRIX = mat4.create();\n",