-
Notifications
You must be signed in to change notification settings - Fork 0
/
nnue.rs
798 lines (710 loc) · 27.9 KB
/
nnue.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
#![allow(non_upper_case_globals)]
use crate::color::{WB, Color::*};
use crate::default::DEFAULT_NETWORK;
use crate::misc::vmirror;
use crate::piece::{KQRBNP, Piece::{WhiteKing, BlackKing}};
use crate::rand::RandDist;
use crate::state::{MiniState, State};
use crate::simd::{LANES, simd_load, relu_ps, horizontal_sum};
use std::fs::File;
use std::io::{Read, Write, BufWriter, Error};
use std::mem::MaybeUninit;
use std::simd::Simd;
/// Compile-time boolean check: instantiating `Guard::<COND>::assert()` forces
/// evaluation of the associated const `CHECK`, which fails compilation when
/// the condition is false. Used by the `static_assert!` macro below.
struct Guard<const B : bool> { }
impl <const B : bool> Guard<B> {
  // Evaluating this const is a compile error when B is false.
  const CHECK : () = assert!(B);
  // Touch CHECK so that instantiating assert() triggers its evaluation.
  fn assert() { let _ = Self::CHECK; }
}
// Compile-time assertion built on Guard above; the condition must be a const
// expression usable as a const-generic argument.
macro_rules! static_assert {
  ($cond:expr) => { Guard::<{$cond}>::assert(); }
}
// ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
// The Expositor NNUE has 768 inputs, two hidden layers, and a single output neuron. For the
// network to be efficiently updatable, the order of the inputs and first layer neurons (and
// the weights between them) is fixed: inputs #0 to #383 are for white and #384 to #767 are
// for black. However, for the sake of symmetry (and to avoid having the network learn some
// concepts twice), if black is the side to move we swap the upper and lower banks of the
// first layer activations before computing the second layer. In this way we ensure the
// lower bank, #0 to #(N1-1), is for the side to move and the upper bank, #N1 to #(N1×2-1),
// is for the side waiting.
//
// For this to work, we do two things. First, we arrange the inputs for black so that black's
// position is flipped vertically, e.g. input #4 is hot when a white king is on e1 and the
// corresponding input, #(Np+4), is hot when a black king is on e8. Second, we mirror the
// weights to the banks of the first layer, so that e.g. the #4 → #0 weight is the same as
// the #(Np+4) → #N1 weight, or e.g. the #(Np+4) → #0 weight is the same as the #4 → #N1
// weight.
//
// neuron-to
// ─┴─
// weight[x][n] = weight[x±Np][n±N1] with signs chosen so that n±N1 and x±Np
// ─┬─ are not out of bounds
// input-from
//
// Since the weights are mirrored, we only bother storing half of them. Here is the same
// equality as above, written out explicitly, with the canonical form on the righthand of
// each equality:
//
// x = 0..Np x = Np..Np×2
// -----------------------------------------------------------
// |
// n = 0..N1 | w1[x][n] w1[x][n]
// |
// n = N1..N1×2 | w1[x][n] = w1[x+Np][n-N1] w1[x][n] = w1[x-Np][n-N1]
// |
//
// A last word about notation and terminology: we always use minuscule s to denote the
// stimulus or weighted sum of inputs to a neuron, e.g. s1[n]. The letter a is used for
// activation, by which we always mean the output of the activation function, and so we have
// that a1[n] := relu(s1[n]) and a2[n] := relu(s2[n]) for all n.
pub type Simd32 = Simd<f32, LANES>;
const SIMD_ZERO : Simd32 = Simd::from_array([0.0; LANES]);
pub const SideToMove : usize = 0;
pub const SideWaiting : usize = 1;
pub const SameSide : usize = 0;
pub const OppoSide : usize = 1;
// We switch between different regions based on the position of the kings.
pub const REGIONS : usize = 5;
// We switch between different heads (layer two and output neurons)
// based on the number of men on the board
pub const HEADS : usize = 4;
// Np must be a multiple of 2×LANES
// N1 must be a multiple of 8×LANES
pub const Np : usize = 384; // Do not modify
pub const N1 : usize = 256; // Okay to vary
// N2 must be a multiple of LANES
pub const N2 : usize = 8; // Okay to vary
pub const N3 : usize = 1; // Do not modify
// Number of input vectors and number of vectors per layer
pub const vNp : usize = Np / LANES;
pub const vN1 : usize = N1 / LANES;
pub const vN2 : usize = N2 / LANES;
const BODY_TOTAL : usize = (Np*2)*N1 // Weights
+ N1; // Biases
const HEAD_TOTAL : usize = (N1*2)*N2 + N2*N3 // Weights
+ N2 + N3; // Biases
// We set the alignment of Network structs to 32 bytes so that SIMD loads and stores will be
// aligned. The Rust reference states that the size of the struct will be a multiple of the
// alignment, so the size of the struct in terms of single-precision floating point numbers
// (which are 4 bytes long) is a multiple of 8.
//
pub const BODY_SIZE : usize = ((BODY_TOTAL + 7) / 8) * 8;
pub const HEAD_SIZE : usize = ((HEAD_TOTAL + 7) / 8) * 8;
const SIZE : usize = BODY_SIZE*REGIONS + HEAD_SIZE*HEADS;
// ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
// One network head — the second layer and the output neuron — selected at
// evaluation time by the number of men on the board (see HEADS). Aligned to
// 32 bytes so SIMD loads from its arrays are aligned.
#[derive(Clone, PartialEq)]
#[repr(align(32))]
pub struct NetworkHead {
  pub w2 : [[[f32; N1]; 2]; N2], // weight[to-second][from-fst-side][from-first]
  pub w3 : [f32; N2],            // weight[from-second]
  pub b2 : [f32; N2],            // bias[second]
  pub b3 : f32,                  // bias
}
// One network body — the input→first-layer weights and first-layer biases —
// selected at evaluation time by the king's region (see REGIONS). Only the
// canonical half of the mirrored weights is stored (see the module comment).
#[derive(Clone, PartialEq)]
#[repr(align(32))]
pub struct NetworkBody {
  pub w1 : [[[f32; N1]; Np]; 2], // weight[from-inp-side][from-input][to-first]
  pub b1 : [f32; N1],            // bias[first]
}
// The complete network: one body per king region plus one head per
// man-count bucket. The transmutes in this module treat the whole struct as
// a flat array of SIZE f32 words (SIZE*4 bytes).
#[derive(Clone, PartialEq)]
#[repr(align(32))]
pub struct Network {
  pub rn : [NetworkBody; REGIONS],
  pub hd : [NetworkHead; HEADS],
}
fn static_assert_block() { static_assert!(SIZE == std::mem::size_of::<Network>()); }
// The active network, initialized by reinterpreting the embedded default
// weight blob. SAFETY-relevant: this is a `static mut`, so all access must
// be unsynchronized-single-threaded or externally synchronized, and
// `*DEFAULT_NETWORK` must be exactly size_of::<Network>() bytes for the
// transmute to compile — TODO confirm at DEFAULT_NETWORK's definition.
pub static mut NETWORK : Network = unsafe { std::mem::transmute(*DEFAULT_NETWORK) };
// ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
// Running extrema of first- and second-layer stimuli, accumulated by the
// (currently commented-out) instrumented State::evaluate below and reported
// by print_stats. `static mut`: single-threaded use only.
pub static mut S1_MAX : Simd32 = SIMD_ZERO;
pub static mut S1_MIN : Simd32 = SIMD_ZERO;
pub static mut S2_INP_MAX : f32 = 0.0;
pub static mut S2_INP_MIN : f32 = 0.0;
pub static mut S2_WORST_MAX : f32 = 0.0;
pub static mut S2_WORST_MIN : f32 = 0.0;
pub static mut S2_MAX : f32 = 0.0;
pub static mut S2_MIN : f32 = 0.0;
/// Scans every first-layer weight across all regions and every second- and
/// third-layer weight across all heads, then reports the observed extrema on
/// stderr alongside the running stimulus statistics gathered during
/// evaluation.
pub fn print_stats()
{
  use std::simd::num::SimdFloat;
  unsafe {
    // Extrema of the input→first-layer weights over every region body.
    let (mut w1_min, mut w1_max) = (0.0_f32, 0.0_f32);
    for body in &NETWORK.rn {
      for c in WB {
        for x in 0..Np {
          for n in 0..N1 {
            let w = body.w1[c][x][n];
            w1_min = w1_min.min(w);
            w1_max = w1_max.max(w);
          }
        }
      }
    }
    // Extrema of the first→second-layer weights over every head.
    let (mut w2_min, mut w2_max) = (0.0_f32, 0.0_f32);
    for head in &NETWORK.hd {
      for n in 0..N2 {
        for c in WB {
          for x in 0..N1 {
            let w = head.w2[n][c][x];
            w2_min = w2_min.min(w);
            w2_max = w2_max.max(w);
          }
        }
      }
    }
    // Extrema of the second→output weights over every head.
    let (mut w3_min, mut w3_max) = (0.0_f32, 0.0_f32);
    for head in &NETWORK.hd {
      for n in 0..N2 {
        let w = head.w3[n];
        w3_min = w3_min.min(w);
        w3_max = w3_max.max(w);
      }
    }
    // Collapse the per-lane stimulus extrema into scalars.
    let upper = S1_MAX.reduce_max();
    let lower = S1_MIN.reduce_min();
    eprintln!("w1 maximum: {w1_max:+11.6}");
    eprintln!("w1 minimum: {w1_min:+11.6}");
    eprintln!();
    eprintln!("s1 maximum: {upper:+11.6}");
    eprintln!("s1 minimum: {lower:+11.6}");
    eprintln!();
    eprintln!("w2 maximum: {w2_max:+11.6}");
    eprintln!("w2 minimum: {w2_min:+11.6}");
    eprintln!();
    eprintln!("s2 inp max: {:+11.6}", S2_INP_MAX  );
    eprintln!("s2 inp min: {:+11.6}", S2_INP_MIN  );
    eprintln!("s2 peak mx: {:+11.6}", S2_WORST_MAX);
    eprintln!("s2 peak mn: {:+11.6}", S2_WORST_MIN);
    eprintln!("s2 maximum: {:+11.6}", S2_MAX      );
    eprintln!("s2 minimum: {:+11.6}", S2_MIN      );
    eprintln!();
    eprintln!("w3 maximum: {w3_max:+11.6}");
    eprintln!("w3 minimum: {w3_min:+11.6}");
  }
}
// ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
impl Network {
/// Const-context zero initialization: a union lets us reinterpret an
/// all-zero byte array as a `Network`, since `mem::zeroed` is not usable in
/// const fn. Sound because Network consists entirely of f32 arrays, for
/// which all-zero bytes are a valid representation.
pub const fn zero() -> Self
{
  // See https://github.com/rust-lang/rust/issues/62061
  // and https://github.com/maxbla/const-zero
  const SZ : usize = std::mem::size_of::<Network>();
  union Empty {
    ary : [u8; SZ],
    net : std::mem::ManuallyDrop<Network>
  }
  const ZERO : Empty = Empty { ary: [0; SZ] };
  return std::mem::ManuallyDrop::<Network>::into_inner(unsafe { ZERO.net });
}
/// Randomly perturbs the weights attached to first-layer neuron `n1`:
/// triangular noise on the input→first weights and uniform noise on the
/// first→second weights. Each noise sample is shared across all regions
/// (respectively all heads) so the region bodies stay correlated.
pub fn perturb_fst_snd(&mut self, n1 : usize)
{
  // Scale 1/sqrt(32) for the first-layer noise.
  let f1 = 32.0_f32.sqrt().recip();
  // Earlier variant that drew independent noise per region:
  // for r in 0..REGIONS {
  //   let region = &mut self.rn[r];
  //   for np in 0..Np { region.w1[SameSide][np][n1] += f32::triangular() * f1; }
  //   for np in 0..Np { region.w1[OppoSide][np][n1] += f32::triangular() * f1; }
  // }
  for kind in 0..6 {
    for rank in 0..8 {
      // Pawns (kind 5) never stand on the first or last rank.
      if kind == 5 && (rank == 0 || rank == 7) { continue; }
      for file in 0..8 {
        let sq = rank*8 + file;
        let np = kind*64 + sq;
        let a = f32::triangular() * f1;
        let b = f32::triangular() * f1;
        for r in 0..REGIONS {
          // A same-side king (kind 0) on this square only occurs in the
          // region that square maps to; skip the other bodies.
          if kind != 0 || king_region(sq) == r {
            self.rn[r].w1[SameSide][np][n1] += a;
          }
          self.rn[r].w1[OppoSide][np][n1] += b;
        }
      }
    }
  }
  // Scale 1/sqrt(N1*2) for the second-layer noise.
  let f2 = ((N1*2) as f32).sqrt().recip();
  // Earlier variant that drew independent noise per head:
  // for h in 0..HEADS {
  //   let head = &mut self.hd[h];
  //   for n2 in 0..N2 {
  //     head.w2[n2][SideToMove ][n1] += f32::uniform() * f2;
  //     head.w2[n2][SideWaiting][n1] += f32::uniform() * f2;
  //   }
  // }
  for n2 in 0..N2 {
    let a = f32::uniform() * f2;
    let b = f32::uniform() * f2;
    for h in 0..HEADS {
      self.hd[h].w2[n2][SideToMove ][n1] += a;
      self.hd[h].w2[n2][SideWaiting][n1] += b;
    }
  }
}
/// Adds uniform noise, scaled by 1/sqrt(N2), to every third-layer weight of
/// every head (one fresh sample per weight, drawn in head-major order).
pub fn perturb_thd(&mut self)
{
  let scale = (N2 as f32).sqrt().recip();
  for head in self.hd.iter_mut() {
    for w in head.w3.iter_mut() {
      *w += f32::uniform() * scale;
    }
  }
}
/// Computes a Fletcher-style checksum over the raw bits of the network;
/// used by `load` and `save` to verify file integrity.
fn checksum(&self) -> u64
{
  // Fletcher's checksum (with end-around carry folded in on overflow).
  let mut lo : u32 = 0;
  let mut hi : u32 = 0;
  // SAFETY: the struct is exactly SIZE 4-byte words, so this view stays in
  // bounds. NOTE(review): SIZE includes the heads' trailing alignment
  // padding, so padding bytes are hashed too — this assumes they hold
  // stable (zeroed) values; confirm for networks not built via zero()/load.
  let array = unsafe { std::mem::transmute::<_,&[u32; SIZE]>(self) };
  for x in &array[0..SIZE] {
    let (sum, overflow) = lo.overflowing_add(*x);
    lo = if overflow { sum + 1 } else { sum };
    let (sum, overflow) = hi.overflowing_add(lo);
    hi = if overflow { sum + 1 } else { sum };
  }
  return ((hi as u64) << 32) | (lo as u64);
}
/// Loads a network from `path`. The file format (written by `save`) is a
/// 4-byte "EXPO" signature, the raw native-endian weight bytes, and an
/// 8-byte little-endian Fletcher checksum.
///
/// # Errors
/// Returns an error on I/O failure, a missing signature, or a checksum
/// mismatch.
pub fn load(path : &str) -> std::io::Result<Self>
{
  let mut fh = File::open(path)?;
  let mut sgntr = [0; 4];
  // NOTE(review): this buffer is SIZE*4 bytes (megabytes) on the stack, and
  // the transmute below copies it by value again — fine on a main-thread
  // stack, but worth confirming for threads with small stacks.
  let mut array = [0; SIZE*4];
  let mut check = [0; 8];
  fh.read_exact(&mut sgntr)?;
  if sgntr != "EXPO".as_bytes() {
    return Err(Error::other("missing signature"));
  }
  fh.read_exact(&mut array)?;
  // SAFETY: [u8; SIZE*4] and Network have the same size, and any bit
  // pattern is a valid f32. Assumes the file was produced on a machine of
  // the same endianness.
  let network = unsafe { std::mem::transmute::<_,Self>(array) };
  fh.read_exact(&mut check)?;
  if network.checksum() != u64::from_le_bytes(check) {
    return Err(Error::other("checksum mismatch"));
  }
  return Ok(network);
}
/// Saves the network to `path` in the format read by `load`: a 4-byte
/// "EXPO" signature, the raw native-endian weight bytes, and an 8-byte
/// little-endian Fletcher checksum.
///
/// # Errors
/// Returns any I/O error from creating, writing, or flushing the file.
pub fn save(&self, path : &str) -> std::io::Result<()>
{
  let mut w = BufWriter::new(File::create(path)?);
  // SAFETY: Network is exactly SIZE*4 bytes of f32 data, so viewing it as a
  // byte array stays in bounds.
  let bytes = unsafe { std::mem::transmute::<_,&[u8; SIZE*4]>(self) };
  w.write_all("EXPO".as_bytes())?;
  w.write_all(bytes)?;
  w.write_all(&self.checksum().to_le_bytes())?;
  // Flush explicitly: BufWriter's Drop impl also flushes, but it silently
  // swallows any error, so a failed write could otherwise go unreported.
  w.flush()?;
  return Ok(());
}
/// Writes the raw network bytes — no signature, no checksum — to
/// "default.nnue". This matches how the NETWORK static is initialized: the
/// embedded DEFAULT_NETWORK blob is transmuted directly, so it needs no
/// framing or verification.
///
/// # Errors
/// Returns any I/O error from creating, writing, or flushing the file.
pub fn save_default(&self) -> std::io::Result<()>
{
  let mut w = BufWriter::new(File::create("default.nnue")?);
  // SAFETY: as in `save`, Network is exactly SIZE*4 valid bytes.
  let bytes = unsafe { std::mem::transmute::<_,&[u8; SIZE*4]>(self) };
  w.write_all(bytes)?;
  // Report flush errors instead of letting BufWriter's Drop swallow them.
  w.flush()?;
  return Ok(());
}
/// Evaluates `state` from scratch through head `head_idx` and returns the
/// raw output-neuron value.
pub fn evaluate(&self, state : &State, head_idx : usize) -> f32
{
  // This method is slower, but does not mutate state.
  // Each side uses the body for its own king's region; black's king square
  // is vertically mirrored so one region map serves both sides.
  let wk_idx = state.boards[WhiteKing].trailing_zeros() as usize ;
  let bk_idx = vmirror(state.boards[BlackKing].trailing_zeros() as usize);
  let w_region = &self.rn[king_region(wk_idx)];
  let b_region = &self.rn[king_region(bk_idx)];
  let head = &self.hd[head_idx];
  // First-layer stimuli, one bank per perspective, seeded with the biases.
  let mut s1 : [[MaybeUninit<Simd32>; vN1]; 2] =
    [MaybeUninit::uninit_array(), MaybeUninit::uninit_array()];
  for n in 0..vN1 { s1[White][n].write(simd_load!(w_region.b1, n)); }
  for n in 0..vN1 { s1[Black][n].write(simd_load!(b_region.b1, n)); }
  // SAFETY: every element of both banks was just initialized above.
  let s1 = unsafe { std::mem::transmute::<_,&mut [[Simd32; vN1]; 2]>(&mut s1) };
  // Accumulate white's men: same-side weights into white's bank and
  // opposite-side weights into black's.
  for kind in KQRBNP {
    let mut sources = state.boards[White+kind];
    while sources != 0 {
      let src = sources.trailing_zeros() as usize;
      let x = (kind as usize)*64 + src;
      for n in 0..vN1 { s1[White][n] += simd_load!(w_region.w1[SameSide][x], n); }
      for n in 0..vN1 { s1[Black][n] += simd_load!(b_region.w1[OppoSide][x], n); }
      sources &= sources - 1; // clear the lowest set bit
    }
  }
  // Accumulate black's men (squares mirrored into white's frame).
  for kind in KQRBNP {
    let mut sources = state.boards[Black+kind];
    while sources != 0 {
      let src = sources.trailing_zeros() as usize;
      let x = (kind as usize)*64 + vmirror(src);
      for n in 0..vN1 { s1[White][n] += simd_load!(w_region.w1[OppoSide][x], n); }
      for n in 0..vN1 { s1[Black][n] += simd_load!(b_region.w1[SameSide][x], n); }
      sources &= sources - 1;
    }
  }
  // First-layer activations.
  for c in WB { for n in 0..vN1 { s1[c][n] = relu_ps(s1[c][n]); } }
  let a1 = s1;
  // Second layer: the side to move's bank is read against the SideToMove
  // weights, implementing the bank swap described in the module comment.
  let mut s2 : [MaybeUninit<f32>; N2] = MaybeUninit::uninit_array();
  for n in 0..N2 {
    let mut s = SIMD_ZERO;
    let c = state.turn;
    for x in 0..vN1 { s += a1[ c][x] * simd_load!(head.w2[n][SideToMove ], x); }
    for x in 0..vN1 { s += a1[!c][x] * simd_load!(head.w2[n][SideWaiting], x); }
    s2[n].write(head.b2[n] + horizontal_sum(s));
  }
  // SAFETY: all N2 stimuli were written in the loop above.
  let s2 = unsafe { std::mem::transmute::<_,&mut [f32; N2]>(&mut s2) };
  // Output neuron.
  let mut s = SIMD_ZERO;
  for x in 0..vN2 { s += relu_ps(simd_load!(s2, x)) * simd_load!(head.w3, x); }
  let s3 = head.b3 + horizontal_sum(s);
  return s3;
}
}
/// Maps a king square index (0..64, from white's point of view) to one of
/// the five region indices used to select a network body. Square e1
/// (index 4) is its own region 0; every square on rank 4 or above is
/// region 4; the remaining squares of ranks 1-3 are looked up in a
/// 2-bit-per-square table packed into MAP, giving a region in 1..=4.
///
/// (This implementation also appears in Mantissa with a different constant;
/// it was written by this crate's author and is used there with permission,
/// not borrowed from Mantissa.)
pub const fn king_region(idx : usize) -> usize {
  const MAP : usize = 256597072250967;
  if idx == 4 {
    0
  } else if idx >= 24 {
    4
  } else {
    ((MAP >> (idx * 2)) & 3) + 1
  }
}
impl MiniState {
/// Counts the men recorded in this MiniState and maps the count to a head
/// index, using the same 1..=32 bucketing as State::head_index.
pub const fn head_index(&self) -> usize
{
  // This is abhorrent, but "for" cannot be used in const functions.
  // Each positions array is viewed as two u64s so the absent-man flags can
  // be counted in bulk. NOTE(review): the high bit (0x80) of each byte
  // appears to mark an absent man — confirm against MiniState's layout.
  let w = unsafe { std::mem::transmute::<_,&[u64; 2]>(&self.positions[0]) };
  let b = unsafe { std::mem::transmute::<_,&[u64; 2]>(&self.positions[1]) };
  let absent = (w[0] & 0x_80_80_80_80_80_80_80_80).count_ones()
             + (w[1] & 0x_80_80_80_80_80_80_80_80).count_ones()
             + (b[0] & 0x_80_80_80_80_80_80_80_80).count_ones()
             + (b[1] & 0x_80_80_80_80_80_80_80_80).count_ones();
  let mut men = 32 - (absent as usize);
  // The two variable slots contribute a man each when their second field is
  // nonnegative. NOTE(review): semantics inferred from usage — verify.
  if self.variable[0].1 >= 0 { men += 1; }
  if self.variable[1].1 >= 0 { men += 1; }
  // Bucket the man count 1..=32 evenly across the configured head count.
  if HEADS == 8 { return (men - 1) / 4; }
  else if HEADS == 4 { return (men - 1) / 8; }
  else if HEADS == 2 { return (men - 1) / 16; }
  else { return 0; }
}
}
impl State {
/// Selects a network head from the number of men on the board: the range
/// 1..=32 is split into HEADS equal buckets. Counts above 32 (which should
/// not occur for a legal position) are clamped into the last head.
pub const fn head_index(&self) -> usize
{
  let men = (self.sides[0] | self.sides[1]).count_ones() as usize;
  if men > 32 { return HEADS - 1; }
  match HEADS {
    8 => (men - 1) / 4,
    4 => (men - 1) / 8,
    2 => (men - 1) / 16,
    _ => 0,
  }
}
/*
pub fn initialize_nnue(&mut self)
{
let wk_idx = self.boards[WhiteKing].trailing_zeros() as usize ;
let bk_idx = vmirror(self.boards[BlackKing].trailing_zeros() as usize);
let w_region = unsafe { &NETWORK.rn[king_region(wk_idx)] };
let b_region = unsafe { &NETWORK.rn[king_region(bk_idx)] };
self.s1.clear();
// NOTE this is somewhat unsafe, but the only way
// I've found to prevent an unnecessary copy.
self.s1.reserve(1);
unsafe { self.s1.set_len(1); }
let s1 = &mut self.s1[0];
for n in 0..vN1 { s1[White][n] = simd_load!(w_region.b1, n); }
for n in 0..vN1 { s1[Black][n] = simd_load!(b_region.b1, n); }
for kind in KQRBNP {
let mut sources = self.boards[White+kind];
while sources != 0 {
let src = sources.trailing_zeros() as usize;
let x = (kind as usize)*64 + src;
for n in 0..vN1 { s1[White][n] += simd_load!(w_region.w1[SameSide][x], n); }
for n in 0..vN1 { s1[Black][n] += simd_load!(b_region.w1[OppoSide][x], n); }
sources &= sources - 1;
}
}
for kind in KQRBNP {
let mut sources = self.boards[Black+kind];
while sources != 0 {
let src = sources.trailing_zeros() as usize;
let x = (kind as usize)*64 + vmirror(src);
for n in 0..vN1 { s1[White][n] += simd_load!(w_region.w1[OppoSide][x], n); }
for n in 0..vN1 { s1[Black][n] += simd_load!(b_region.w1[SameSide][x], n); }
sources &= sources - 1;
}
}
}
pub fn evaluate(&self) -> f32
{
unsafe {
let head = &NETWORK.hd[self.head_index()];
let s1 = &self.s1[self.s1.len()-1];
let mut a1 : [[MaybeUninit<Simd32>; vN1]; 2] =
[MaybeUninit::uninit_array(), MaybeUninit::uninit_array()];
for c in WB { for n in 0..vN1 { a1[c][n].write(relu_ps(s1[c][n])); } }
let a1 = std::mem::transmute::<_,&mut [[Simd32; vN1]; 2]>(&mut a1);
let mut s2 : [MaybeUninit<f32>; N2] = MaybeUninit::uninit_array();
let c = self.turn;
for n in 0..N2 {
let mut s_a = SIMD_ZERO;
let mut s_b = SIMD_ZERO;
let mut s_c = SIMD_ZERO;
let mut s_d = SIMD_ZERO;
for x in 0..vN1/4 {
s_a += a1[c][x*4+0] * simd_load!(head.w2[n][SideToMove], x*4+0);
s_b += a1[c][x*4+1] * simd_load!(head.w2[n][SideToMove], x*4+1);
s_c += a1[c][x*4+2] * simd_load!(head.w2[n][SideToMove], x*4+2);
s_d += a1[c][x*4+3] * simd_load!(head.w2[n][SideToMove], x*4+3);
}
for x in 0..vN1/4 {
s_a += a1[!c][x*4+0] * simd_load!(head.w2[n][SideWaiting], x*4+0);
s_b += a1[!c][x*4+1] * simd_load!(head.w2[n][SideWaiting], x*4+1);
s_c += a1[!c][x*4+2] * simd_load!(head.w2[n][SideWaiting], x*4+2);
s_d += a1[!c][x*4+3] * simd_load!(head.w2[n][SideWaiting], x*4+3);
}
let s = (s_a + s_b) + (s_c + s_d);
s2[n].write(head.b2[n] + horizontal_sum(s));
{
let mut pos = 0.0;
let mut neg = 0.0;
let sm = std::mem::transmute::<_, &[f32; N1]>(&a1[ c]);
for x in 0..N1 {
let i = sm[x] * head.w2[n][SideToMove][x];
if i > S2_INP_MAX { S2_INP_MAX = i; }
if i < S2_INP_MIN { S2_INP_MIN = i; }
if i > 0.0 { pos += i; }
if i < 0.0 { neg += i; }
}
let sw = std::mem::transmute::<_, &[f32; N1]>(&a1[!c]);
for x in 0..N1 {
let i = sw[x] * head.w2[n][SideWaiting][x];
if i > S2_INP_MAX { S2_INP_MAX = i; }
if i < S2_INP_MIN { S2_INP_MIN = i; }
if i > 0.0 { pos += i; }
if i < 0.0 { neg += i; }
}
let b = head.b2[n];
if b > 0.0 { pos += b; }
if b < 0.0 { neg += b; }
if pos > S2_WORST_MAX { S2_WORST_MAX = pos; }
if neg < S2_WORST_MIN { S2_WORST_MIN = neg; }
let s = pos + neg;
if s > S2_MAX { S2_MAX = s; }
if s < S2_MIN { S2_MIN = s; }
}
}
let s2 = std::mem::transmute::<_,&mut [f32; N2]>(&mut s2);
let mut s = SIMD_ZERO;
for x in 0..vN2 { s += relu_ps(simd_load!(s2, x)) * simd_load!(head.w3, x); }
let s3 = head.b3 + horizontal_sum(s);
return s3;
}
}
*/
}
impl Network {
/// Renders the first-layer weights as a PPM image (converted to PNG with
/// ImageMagick's `convert` when available). Each first-layer neuron becomes
/// a tile of six 8×8 panels (one per piece kind) in two rows (one per input
/// side). When `vis < 0` every region is drawn, stacked; otherwise only
/// region `vis`. When `unif` is true one global scale normalizes all
/// neurons; otherwise each neuron is normalized independently.
///
/// Fix over the previous version: all pixel output now uses `write_all`
/// instead of `write`, whose returned byte count was discarded — a partial
/// write would have silently corrupted the image.
///
/// # Errors
/// Returns any I/O error from writing the image or running `convert`.
pub fn save_image(&self, file : &str, unif : bool, vis : i8) -> std::io::Result<()>
{
  let all = vis < 0;
  let region = if all { 0 } else { vis as usize };
  let aspect = if all { 32 } else { 8 }; // width of image in neurons
  let upscale = 2;                       // number of pixels per square
  let wasp = aspect;
  let hasp = N1 / aspect;
  let wnum = wasp;
  let hnum = hasp * if all { REGIONS } else { 1 };
  // Panel pixels plus 1- and 3-pixel separator lines between panels/tiles.
  let width  = (6*8*upscale)*wnum + (6*wnum - 1) + (wasp - 1)*2;
  let height = (2*8*upscale)*hnum + (2*hnum - 1) + (hasp - 1)*2;
  let border = [0, 0, 32];
  // Per-neuron scale: largest absolute weight feeding that neuron; also
  // accumulate a global mean absolute weight for uniform scaling.
  let mut scale = [0.0; N1];
  let mut unif_scale : f32 = 0.0;
  let mut count = 0;
  for n in 0..N1 {
    let mut bound : f32 = 0.0;
    let rs = if all || unif { 0..REGIONS } else { region..region+1 };
    for r in rs {
      for side in 0..2 {
        for x in 0..Np {
          let z = self.rn[r].w1[side][x][n].abs();
          bound = bound.max(z);
          unif_scale += z;
          if z != 0.0 { count += 1; }
        }
      }
    }
    scale[n] = bound;
  }
  unif_scale /= count as f32;
  let mut w = BufWriter::new(File::create(format!("{}.ppm", file))?);
  writeln!(&mut w, "P6")?;
  writeln!(&mut w, "{} {}", width, height)?;
  writeln!(&mut w, "255")?;
  for tile_row in 0..hasp {
    let rs = if all { 0..REGIONS } else { 0..1 };
    for subtile_row in rs {
      for subneuron_row in 0..2 {
        let topmost = tile_row == 0 && subtile_row == 0 && subneuron_row == 0;
        if !topmost {
          // Horizontal separator: 3 rows between tiles, 1 between halves.
          let pixels = if subtile_row == 0 && subneuron_row == 0 { 3 } else { 1 };
          for _ in 0..(width*pixels) { w.write_all(&border)?; }
        }
        for rank in (0..8).rev() {
          for _ in 0..upscale {
            for tile_column in 0..wasp {
              let n = tile_row * aspect + tile_column;
              for kind in 0..6 {
                let leftmost = tile_column == 0 && kind == 0;
                if !leftmost {
                  // Vertical separator: 3 px between tiles, 1 between kinds.
                  let pixels = if kind == 0 { 3 } else { 1 };
                  for _ in 0..pixels { w.write_all(&border)?; }
                }
                for file in 0..8 {
                  let r = if all { (REGIONS-1) - subtile_row } else { region };
                  let square = rank*8 + file;
                  let side = subneuron_row ^ 1;
                  let square = if side != 0 { vmirror(square) } else { square };
                  let x : usize = kind*64 + square;
                  let s = if unif { unif_scale } else { scale[n] };
                  if s == 0.0 {
                    // Dead neuron (all-zero weights): paint it red.
                    for _ in 0..upscale { w.write_all(&[192, 64, 64])?; }
                    continue;
                  }
                  let w1 = self.rn[r].w1[side][x][n] / s;
                  // Map the scaled weight into [0, 1]: a sigmoid for the
                  // (unbounded) uniform scale, linear otherwise.
                  let normed =
                    if unif { (1.0 + (w1 * -0.5).exp2()).recip() }
                    else    { (w1 + 1.0) * 0.5 };
                  debug_assert!(
                    1.0 >= normed && normed >= 0.0,
                    "out of range ({} {} {} {} {})",
                    normed, w1, self.rn[r].w1[side][x][n], s,
                    if unif { "unif" } else { "indp" }
                  );
                  const C : f32 = 31.5;
                  let b1 = self.rn[r].b1[n];
                  // NOTE(review): the bias shading divides by scale[n] even
                  // in unif mode — confirm this asymmetry is intentional.
                  let bias = if b1 > 0.0 { (b1/scale[n]) * (C*2.0) } else { 0.0 };
                  let red = bias + normed * (255.0 - bias*2.0);
                  let grn =        normed *  255.0             ;
                  let blu = C    + normed * (255.0 - C*2.0)    ;
                  let red = red.round() as u8;
                  let grn = grn.round() as u8;
                  let blu = blu.round() as u8;
                  for _ in 0..upscale { w.write_all(&[red, grn, blu])?; }
                }
              }
            }
          }
        }
      }
    }
  }
  w.flush()?;
  // Convert the PPM to PNG and delete the intermediate file on success.
  let status = std::process::Command::new("convert")
    .arg(&format!("{}.ppm", file)).arg(&format!("{}.png", file)).status()?;
  if status.success() { std::fs::remove_file(format!("{}.ppm", file))?; }
  return Ok(());
}
/*
pub fn stat(&self)
{
use std::mem::transmute;
for r in 0..REGIONS {
let region = &self.rn[r];
println!("r{r} w1"); stat_slice(unsafe { transmute::<_,&[f32; Np*N1*2]>(®ion.w1) });
println!("r{r} b1"); stat_slice(®ion.b1);
}
for h in 0..HEADS {
let head = &self.hd[h];
println!("h{h} w2"); stat_slice(unsafe { transmute::<_,&[f32; N1*N2*2]>(&head.w2) });
println!("h{h} b2"); stat_slice(&head.b2);
println!("h{h} w3"); stat_slice(&head.w3);
}
let mut fst_lo = [0.0; N1*REGIONS];
let mut fst_hi = [0.0; N1*REGIONS];
let mut fst_lo_b = [0.0; N1*REGIONS];
let mut fst_hi_b = [0.0; N1*REGIONS];
for r in 0..REGIONS {
let region = &self.rn[r];
for n in 0..N1 {
let mut inp = [[0.0; Np]; 2];
for x in 0..Np { inp[0][x] = region.w1[0][x][n]; }
for x in 0..Np { inp[1][x] = region.w1[1][x][n]; }
inp.sort_by(|a, b| a.partial_cmp(b).unwrap());
// Strictly speaking, we should check that inp[0..16] contains only
// negative weights and inp[Np-16..Np] contains only positive weights.
let lo : f32 = inp[0][ 0 ..16].iter().sum::<f32>()
+ inp[1][ 0 ..16].iter().sum::<f32>();
let hi : f32 = inp[0][Np-16..Np].iter().sum::<f32>()
+ inp[1][Np-16..Np].iter().sum::<f32>();
fst_lo [N1*r + n] = lo;
fst_hi [N1*r + n] = hi;
fst_lo_b[N1*r + n] = lo + region.b1[n];
fst_hi_b[N1*r + n] = hi + region.b1[n];
}
}
println!("first layer");
let lo_min = fst_lo .iter().fold(0.0f32, |a, x| a.min(*x));
let hi_max = fst_hi .iter().fold(0.0f32, |a, x| a.max(*x));
let lo_min_b = fst_lo_b.iter().fold(0.0f32, |a, x| a.min(*x));
let hi_max_b = fst_hi_b.iter().fold(0.0f32, |a, x| a.max(*x));
println!(" {lo_min:+10.3} {hi_max:+10.3}");
println!(" {lo_min_b:+10.3} {hi_max_b:+10.3}");
/*
for h in 0..HEADS {
let head = &self.hd[h];
// This assumes the first layer activations are clamped to [0, 1].
let mut snd_lo = [0.0; N2];
let mut snd_hi = [0.0; N2];
let mut snd_lo_b = [0.0; N2];
let mut snd_hi_b = [0.0; N2];
for n in 0..N2 {
let mut lo = 0.0;
let mut hi = 0.0;
for x in 0..N1 {
let w = head.w2[n][x];
if w < 0.0 { lo += w; }
if w > 0.0 { hi += w; }
}
snd_lo [n] = lo;
snd_hi [n] = hi;
snd_lo_b[n] = lo + head.b2[n];
snd_hi_b[n] = hi + head.b2[n];
}
println!("second layer head {h}");
let lo_min = snd_lo .iter().fold(0.0f32, |a, x| a.min(*x));
let hi_max = snd_hi .iter().fold(0.0f32, |a, x| a.max(*x));
let lo_min_b = snd_lo_b.iter().fold(0.0f32, |a, x| a.min(*x));
let hi_max_b = snd_hi_b.iter().fold(0.0f32, |a, x| a.max(*x));
println!(" {lo_min:+10.3} {hi_max:+10.3}");
println!(" {lo_min_b:+10.3} {hi_max_b:+10.3}");
// This assumes the second layer activations are clamped to [0, 1].
let mut thd_lo = 0.0;
let mut thd_hi = 0.0;
for n in 0..N2 {
let w = head.w3[n];
if w < 0.0 { thd_lo += w; }
if w > 0.0 { thd_hi += w; }
}
println!("third layer head {h}");
println!(" {thd_lo:+10.3} {thd_hi:+10.3}");
println!(" {:+10.3} {:+10.3}", thd_lo+head.b3, thd_hi+head.b3);
}
*/
}
*/
}
/*
fn stat_slice(xs : &[f32])
{
let len = xs.len();
let mut xs = xs.to_vec();
xs.sort_by(|a, b| a.partial_cmp(b).unwrap());
let min = xs[0];
let p05 = xs[len/20];
let med = xs[len/2];
let p95 = xs[len*19/20];
let max = xs[len-1];
let avg = xs.iter().sum::<f32>() / (len as f32);
println!(" {min:+10.6} {p05:+10.6} {med:+10.6} {p95:+10.6} {max:+10.6} ({avg:+10.6})");
for x in xs.iter_mut() { *x = x.abs(); }
xs.sort_by(|a, b| a.partial_cmp(b).unwrap());
let min = xs[0];
let p05 = xs[len/20];
let med = xs[len/2];
let p95 = xs[len*19/20];
let max = xs[len-1];
let avg = xs.iter().sum::<f32>() / (len as f32);
println!(" {min:+10.6} {p05:+10.6} {med:+10.6} {p95:+10.6} {max:+10.6} ({avg:+10.6})");
}
*/