diff --git a/Cargo.lock b/Cargo.lock index aaf9341..618ff07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6,6 +6,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d2e7343e7fc9de883d1b0341e0b13970f764c14101234857d2ddafa1cb1cac2" +[[package]] +name = "ansi_term" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" +dependencies = [ + "winapi", +] + [[package]] name = "approx" version = "0.3.2" @@ -36,6 +45,17 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cff77d8686867eceff3105329d4698d96c2391c176d5d03adc90c7389162b5b8" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "0.1.7" @@ -122,6 +142,21 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" +[[package]] +name = "clap" +version = "2.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9" +dependencies = [ + "ansi_term", + "atty", + "bitflags", + "strsim", + "textwrap", + "unicode-width", + "vec_map", +] + [[package]] name = "cloudabi" version = "0.0.3" @@ -426,6 +461,15 @@ dependencies = [ "lzw", ] +[[package]] +name = "heck" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20564e78d53d2bb135c343b3f47714a56af2061f1c928fdb541dc7b9fdd94205" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "hermit-abi" version = "0.1.8" @@ -489,13 +533,13 @@ dependencies = [ [[package]] name = "kohonen" -version = "0.1.4" +version = "0.1.5" dependencies = [ "csv", "easy_graph", - "num-traits", "rand 0.5.6", "statistical", + "structopt", ] [[package]] @@ -778,6 +822,32 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" +[[package]] +name = "proc-macro-error" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "syn-mid", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.9" @@ -1203,6 +1273,36 @@ dependencies = [ "byteorder", ] +[[package]] +name = "strsim" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" + +[[package]] +name = "structopt" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8faa2719539bbe9d77869bfb15d4ee769f99525e707931452c97b693b3f159d" +dependencies = [ + "clap", + "lazy_static", + "structopt-derive", +] + +[[package]] +name = "structopt-derive" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f88b8e18c69496aad6f9ddf4630dd7d585bcaf765786cb415b9aec2fe5a0430" +dependencies = [ + "heck", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "syn" version = "1.0.17" @@ -1214,6 +1314,17 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "syn-mid" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7be3539f6c128a931cf19dcee741c1af532c7fd387baa739c03dd2e96479338a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tempfile" version = "3.1.0" @@ -1228,6 +1339,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "tiff" version = "0.4.0" @@ -1250,12 +1370,36 @@ dependencies = [ "winapi", ] +[[package]] +name = "unicode-segmentation" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e83e153d1053cbb5a118eeff7fd5be06ed99153f00dbcd8ae310c5fb2b22edc0" + +[[package]] +name = "unicode-width" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caaa9d531767d1ff2150b9332433f32a24622147e5ebb1f26409d5da67afd479" + [[package]] name = "unicode-xid" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c" +[[package]] +name = "vec_map" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a" + +[[package]] +name = "version_check" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "078775d0255232fb988e6fccf26ddc9d1ac274299aaedcedce21c6f72cc533ce" + [[package]] name = "void" version = "1.0.2" diff --git a/Cargo.toml b/Cargo.toml index cfb6264..818dcdb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "kohonen" -version = "0.1.4" +version = "0.1.5" authors = ["m-lange "] edition = "2018" @@ -8,10 +8,12 @@ edition = "2018" [dependencies] csv = "1.1" -num-traits = "0.2" rand = "0.5.5" easy_graph = { git = "https://github.com/mlange-42/easy_graph.git" } +# TODO put CLI in feature +structopt = "0.3" + [dev-dependencies] statistical = "1.0.0" diff --git a/cmd_examples/countries.bat b/cmd_examples/countries.bat new file mode 100644 index 0000000..11ca467 --- /dev/null +++ b/cmd_examples/countries.bat @@ -0,0 +1,14 @@ +..\target\release\kohonen.exe ^ +--file ..\example_data\countries.csv ^ +--size 20 16 ^ +--episodes 10000 ^ +--layers "child_mort_2010 birth_p_1000 GNI LifeExpectancy PopGrowth PopUrbanized PopGrowthUrb AdultLiteracy PrimSchool Income_low_40 Income_high_20" "continent" ^ +--categ 0 1 ^ +--norm gauss none ^ +--weights 1 1 ^ +--alpha 0.2 0.01 lin ^ +--radius 10 0.8 lin ^ +--decay 0.2 0.001 exp ^ +--neigh gauss ^ +--no-data - ^ +--fps 1 diff --git a/cmd_examples/countries_debug.bat b/cmd_examples/countries_debug.bat new file mode 100644 index 0000000..edcfcb3 --- /dev/null +++ b/cmd_examples/countries_debug.bat @@ -0,0 +1,13 @@ +..\target\debug\kohonen.exe ^ +--file ..\example_data\countries.csv ^ +--size 20 16 ^ +--episodes 5000 ^ +--layers "child_mort_2010 birth_p_1000 GNI LifeExpectancy PopGrowth PopUrbanized PopGrowthUrb AdultLiteracy PrimSchool Income_low_40 Income_high_20" ^ +--categ 0 ^ +--norm gauss ^ +--weights 1 ^ +--alpha 0.2 0.01 lin ^ +--radius 8 0.7 lin ^ +--decay 0.2 0.001 exp ^ +--neigh gauss ^ +--no-data - diff --git a/cmd_examples/iris.bat b/cmd_examples/iris.bat new file mode 100644 index 0000000..f3d22be --- /dev/null +++ b/cmd_examples/iris.bat @@ -0,0 +1,12 @@ +..\target\release\kohonen.exe ^ +--file ..\example_data/iris.csv ^ +--size 20 16 ^ +--episodes 5000 ^ +--layers "sepal_length sepal_width petal_length petal_width" "species" ^ +--categ 0 1 ^ +--norm gauss none ^ +--weights 1 1 ^ +--alpha 0.2 0.01 lin ^ +--radius 8 0.7 lin ^ +--decay 0.2 0.001 exp ^ +--neigh gauss diff --git a/example_data/countries.csv b/example_data/countries.csv new file mode 100644 index 0000000..60473d9 --- /dev/null +++ b/example_data/countries.csv @@ -0,0 +1,181 @@ +Country;code;continent;child_mort_rank;child_mort_1990;child_mort_2010;mort_u1_1990;mort_u1_2010;mort_neo;pop_1000;birth_p_1000;GNI;LifeExpectancy;PopGrowth;Fertility;PopUrbanized;PopGrowthUrb;AdultLiteracy;PrimSchool;Income_low_40;Income_high_20 +Afghanistan;AF;AS;11;209;149;140;103;45;31412;44;330;48;4.4;6.3;23;5.5;-;-;22;39 +Albania;AL;EU;108;41;18;36;16;9;3204;13;4000;77;-0.1;1.5;52;1.6;96;85;20;43 +Algeria;DZ;AF;69;68;36;55;31;18;35468;20;4460;73;1.7;2.3;66;2.9;73;95;18;42 +Angola;AO;AF;8;243;161;144;98;41;19082;42;3960;51;3.1;5.4;59;5.3;70;-;8;62 +Argentina;AR;SA;126;27;14;24;12;7;40412;17;8450;76;1.1;2.2;92;1.4;98;-;13;51 +Armenia;AM;AS;98;55;20;46;18;11;3092;15;3090;74;-0.7;1.7;64;-0.9;100;93;22;40 +Australia;AU;OC;165;9;5;8;4;3;22268;14;43740;82;1.3;1.9;89;1.5;-;97;18;41 +Austria;AT;EU;172;9;4;8;4;2;8394;9;46710;81;0.5;1.4;68;0.6;-;-;22;38 +Azerbaijan;AZ;AS;63;93;46;74;39;19;9188;20;5180;71;1.2;2.2;52;1;100;86;20;42 +Bahamas;BS;MA;118;22;16;18;14;7;343;15;14000;75;1.5;1.9;84;1.7;-;92;-;- +Bahrain;BH;AS;139;17;10;15;9;4;1262;18;25420;75;4.7;2.5;89;4.7;91;99;-;- +Bangladesh;BD;AS;61;143;48;99;38;27;148692;20;640;69;1.7;2.2;28;3.5;56;89;22;41 +Barbados;BB;MA;98;18;20;16;17;10;273;11;14000;77;0.3;1.6;44;1.8;-;-;-;- +Belarus;BY;EU;156;17;6;14;4;3;9595;11;6030;70;-0.3;1.4;75;0.3;100;95;23;36 +Belgium;BE;EU;172;10;4;9;4;2;10712;11;45420;80;0.4;1.8;97;0.4;-;99;21;41 +Belize;BZ;MA;113;44;17;35;14;8;312;25;3740;76;2.5;2.8;52;2.9;-;100;11;59 +Benin;BJ;AF;20;178;115;107;73;32;8850;40;750;56;3.1;5.3;42;4.1;42;94;18;46 +Bhutan;BT;AS;52;139;56;96;44;26;726;20;1920;67;1.3;2.4;35;5.1;53;88;14;53 +Bolivia;BO;SA;55;121;54;84;42;23;9930;26;1790;66;2;3.3;67;2.9;91;95;9;61 +Bosnia and Herzegovina;BA;EU;145;19;8;17;8;5;3760;9;4790;76;-0.7;1.1;49;0.4;98;87;18;43 +Botswana;BW;AF;61;59;48;46;36;19;2007;24;6890;53;1.9;2.8;61;3.7;84;87;9;65 +Brazil;BR;SA;103;59;19;50;17;12;194946;16;9390;73;1.3;1.8;87;2.1;90;95;11;58 +Brunei Darussalam;BN;AS;152;12;7;9;6;4;399;19;31180;78;2.3;2;76;3;95;97;-;- +Bulgaria;BG;EU;130;22;13;18;11;7;7494;10;6240;73;-0.8;1.5;71;-0.4;98;98;14;51 +Burkina Faso;BF;AF;3;205;176;103;93;38;16469;43;550;55;2.8;5.9;26;5.9;29;64;18;47 +Burundi;BI;AF;14;183;142;110;88;42;8383;34;160;50;2;4.3;11;4.8;67;99;21;43 +Cambodia;KH;AS;58;121;51;87;43;22;14138;23;760;63;2;2.6;20;4.3;78;89;16;52 +Cameroon;CM;AF;15;137;136;85;84;34;19599;36;1160;51;2.4;4.5;58;4.2;71;92;15;51 +Canada;CA;NA;156;8;6;7;5;4;34017;11;41950;81;1;1.7;81;1.3;-;-;20;40 +Cape Verde;CV;AF;69;59;36;46;29;14;496;21;3160;74;1.8;2.4;61;3.4;85;83;13;56 +Central African Republic ;CF;AF;9;165;159;110;106;42;4401;35;460;48;2;4.6;39;2.3;55;67;15;49 +Chad;TD;AF;5;207;173;113;99;41;11227;45;600;49;3.1;6;28;4.5;34;-;17;47 +Chile;CL;SA;142;19;9;16;8;5;17114;14;9940;79;1.3;1.9;89;1.6;99;95;24;31 +China;CN;AS;108;48;18;38;16;11;1341335;12;4260;73;0.8;1.6;47;3.7;94;96;16;48 +Colombia;CO;SA;103;37;19;30;17;12;46295;20;5510;73;1.7;2.4;75;2.1;93;93;8;62 +Comoros;KM;AF;34;125;86;88;63;32;735;38;820;61;2.6;4.9;28;2.6;74;87;8;68 +Congo;??;AF;29;116;93;74;61;29;4043;35;2310;57;2.6;4.5;62;3.3;-;-;13;53 +Costa Rica;CR;MA;139;17;10;15;9;6;4659;16;6580;79;2.1;1.8;64;3.3;96;-;12;55 +Cote dIvoire;CI;AF;18;151;123;105;86;41;19738;34;1070;55;2.3;4.4;51;3.5;55;57;16;48 +Croatia;HR;EU;156;13;6;11;5;3;4403;10;13760;76;-0.1;1.5;58;0.2;99;95;20;42 +Cuba;CU;MA;156;13;6;11;5;3;11258;10;5550;79;0.3;1.5;75;0.4;100;100;-;- +Cyprus;CY;EU;172;11;4;10;3;2;1104;12;30460;79;1.8;1.5;70;2.1;98;99;-;- +Czech Republic;CZ;EU;172;14;4;12;3;2;10493;11;17870;78;0.1;1.5;74;0;-;-;25;36 +North Korea;KP;AS;73;45;33;23;26;18;24346;14;500;69;0.9;2;60;1.1;100;-;-;- +Democratic Republic of the Congo;CD;AF;6;181;170;117;112;46;65966;44;180;48;3;5.8;35;4.2;67;-;15;51 +Denmark;DK;EU;172;9;4;7;3;2;5550;11;58980;79;0.4;1.9;87;0.5;-;95;23;36 +Djibouti;DJ;AF;31;123;91;95;73;34;889;29;1280;58;2.3;3.8;76;2.3;-;45;17;47 +Dominican Republic;DO;MA;81;62;27;48;22;15;9927;22;4860;73;1.6;2.6;69;2.7;88;82;13;54 +Ecuador;EC;SA;98;52;20;41;18;10;14465;21;4510;75;1.7;2.5;67;2.7;84;97;13;54 +Egypt;EG;AF;91;94;22;68;19;9;81121;23;2340;73;1.8;2.7;43;1.8;66;95;22;42 +El Salvador;SV;MA;118;62;16;48;14;6;6193;20;3360;72;0.7;2.3;64;2.1;84;96;13;52 +Equatorial Guinea;GQ;AF;19;190;121;118;81;35;700;37;14680;51;3.1;5.2;40;3.8;93;57;-;- +Eritrea;ER;AF;49;141;61;87;42;18;5254;36;340;61;2.5;4.5;22;4.1;67;37;-;- +Estonia;EE;EU;165;21;5;17;4;3;1341;12;14360;75;-0.8;1.7;69;-0.9;100;97;18;43 +Ethiopia;ET;AF;23;184;106;111;68;35;82950;31;380;59;2.7;4.2;17;4.1;30;84;23;39 +Fiji;FJ;OC;113;30;17;25;15;8;861;22;3610;69;0.8;2.7;52;1.9;-;92;-;- +Finland;FI;EU;186;7;3;6;2;2;5365;11;47170;80;0.4;1.9;85;0.7;-;96;24;37 +France;FR;EU;172;9;4;7;3;2;62787;13;42390;81;0.5;2;85;1.2;-;99;20;40 +Gabon;GA;AF;43;93;74;68;54;26;1505;27;7760;62;2.4;3.3;86;3.5;88;-;16;48 +Gambia;GM;AF;28;165;98;78;57;31;1728;38;440;58;2.9;4.9;58;5;46;76;13;53 +Georgia;GE;AS;91;47;22;40;20;15;4352;12;2700;74;-1.1;1.6;53;-1.3;100;100;16;47 +Germany;DE;EU;172;9;4;7;3;2;82302;8;43330;80;0.2;1.4;74;0.2;-;100;22;37 +Ghana;GH;AF;43;122;74;77;50;28;24392;32;1240;64;2.5;4.2;51;4.2;67;76;15;48 +Greece;GR;EU;172;13;4;11;3;2;11359;10;27240;80;0.6;1.5;61;0.8;97;100;19;42 +Grenada;GD;MA;136;21;11;17;9;5;104;19;5560;76;0.4;2.2;39;1.2;-;98;-;- +Guatemala;GT;MA;76;78;32;56;25;15;14389;32;2740;71;2.4;4;49;3.3;74;96;11;58 +Guinea;GN;AF;17;229;130;135;81;38;9982;39;380;54;2.7;5.2;35;3.9;39;74;17;46 +Guinea-Bissau;GW;AF;10;210;150;125;92;40;1515;38;540;48;2;5.1;30;2.3;52;-;19;43 +Guyana;GY;SA;79;66;30;50;25;19;754;18;3270;70;0.2;2.3;29;0;-;99;14;50 +Haiti;HT;MA;7;151;165;104;70;27;9993;27;650;62;1.7;3.3;52;4.7;49;-;8;63 +Honduras;HN;MA;88;58;24;45;20;12;7601;27;1880;73;2.2;3.1;52;3.4;84;97;8;61 +Hungary;HU;EU;156;19;6;17;5;4;9984;10;12990;74;-0.2;1.4;68;0;99;96;21;40 +Iceland;IS;EU;193;6;2;5;2;1;320;15;33870;82;1.1;2.1;93;1.3;-;98;-;- +India;IN;AS;46;115;63;81;48;32;1224614;22;1340;65;1.7;2.6;30;2.5;63;97;19;45 +Indonesia;ID;AS;72;85;35;56;27;17;239871;18;2580;69;1.3;2.1;44;3.2;92;98;19;45 +Iran;IR;AS;85;65;26;50;22;14;73974;17;4530;73;1.5;1.7;71;2.6;85;100;17;45 +Iraq;IQ;AS;67;46;39;37;31;20;31672;36;2320;68;3;4.7;66;2.7;78;88;-;- +Ireland;IE;EU;172;9;4;8;3;2;4470;16;40990;80;1.2;2.1;62;1.6;-;97;20;42 +Israel;IL;AS;165;12;5;10;4;2;7418;21;27340;81;2.5;2.9;92;2.6;-;97;16;45 +Italy;IT;EU;172;10;4;8;3;2;60551;9;35090;82;0.3;1.4;68;0.4;99;99;18;42 +Jamaica;JM;MA;88;38;24;31;20;9;2741;18;4750;73;0.7;2.3;52;1;86;81;14;51 +Japan;JP;AS;186;6;3;5;2;1;126536;9;42150;83;0.2;1.4;67;0.5;-;100;25;36 +Jordan;JO;AS;91;38;22;32;18;13;6187;25;4350;73;3;3.1;79;3.4;92;94;18;45 +Kazakhstan;KZ;AS;73;57;33;48;29;17;16026;21;7440;67;-0.2;2.6;59;0;100;99;21;40 +Kenya;KE;AF;35;99;85;64;55;28;40513;38;780;57;2.7;4.7;22;3.7;87;83;13;53 +Kuwait;KW;AS;136;15;11;13;10;6;2737;18;14000;74;1.4;2.3;98;1.4;94;93;-;- +Kyrgyzstan;KG;AS;68;72;38;59;33;19;5334;24;880;67;1;2.7;35;0.5;99;91;21;43 +Laos;LA;AS;55;145;54;100;42;21;6201;23;1010;67;2;2.7;33;5.8;73;82;19;45 +Latvia;LV;EU;139;21;10;16;8;5;2252;11;11620;73;-0.8;1.5;68;-1;100;94;18;43 +Lebanon;LB;AS;91;38;22;31;19;12;4228;15;9020;72;1.8;1.8;87;2;90;91;-;- +Lesotho;LS;AF;35;89;85;72;65;35;2171;28;1080;48;1.4;3.2;27;4.7;90;73;10;56 +Liberia;LR;AF;24;227;103;151;74;34;3994;38;190;56;3.2;5.2;48;3.9;59;-;18;45 +Libya;LY;AF;113;45;17;33;13;10;6355;23;12020;75;1.9;2.6;78;2.1;89;-;-;- +Lithuania;LT;EU;152;17;7;14;5;3;3324;10;11400;72;-0.5;1.5;67;-0.6;100;97;18;44 +Luxembourg;LU;EU;186;8;3;7;2;1;507;11;79510;80;1.4;1.6;85;1.7;-;97;21;39 +Madagascar;MG;AF;48;159;62;97;43;22;20714;35;440;66;3;4.7;30;4.3;64;99;16;54 +Malawi;MW;AF;30;222;92;131;58;27;14901;44;330;54;2.3;6;20;5;74;91;18;46 +Malaysia;MY;AS;156;18;6;15;5;3;28401;20;7900;74;2.2;2.6;72;4.1;92;94;13;52 +Maldives;MV;AS;124;102;15;74;14;9;316;17;4270;77;1.8;1.8;40;4;98;96;17;44 +Mali;ML;AF;2;255;178;131;99;48;15370;46;600;51;2.9;6.3;36;5;26;77;17;46 +Malta;MT;EU;156;11;6;10;5;4;417;9;18350;79;0.6;1.3;95;0.9;92;91;-;- +Mauritania;MR;AF;21;124;111;80;75;39;3460;34;1060;58;2.8;4.5;41;3;57;76;17;46 +Mauritius;MU;AF;124;24;15;21;13;9;1299;13;7740;73;1;1.6;42;0.8;88;94;-;- +Mexico;MX;MA;113;49;17;38;14;7;113423;20;9330;77;1.5;2.3;78;1.9;93;100;12;56 +Micronesia;FM;OC;64;56;42;44;34;18;111;25;2700;69;0.7;3.5;23;0.1;-;-;7;64 +Mongolia;MN;AS;76;107;32;76;26;12;2756;23;1890;68;1.1;2.5;62;1.6;97;100;18;44 +Montenegro;ME;EU;145;18;8;16;7;5;631;12;6690;74;0.2;1.7;61;1.4;-;88;22;39 +Morocco;MA;AF;69;86;36;67;30;19;31951;20;2850;72;1.3;2.3;58;2.2;56;90;17;48 +Mozambique;MZ;AF;16;219;135;146;92;39;23391;38;440;50;2.7;4.9;38;5.7;55;91;15;52 +Myanmar;MM;AS;45;112;66;79;50;32;47963;17;500;65;1;2;34;2.5;92;-;-;- +Namibia;NA;AF;65;73;40;49;29;17;2283;26;4650;62;2.4;3.2;38;4;89;90;4;78 +Nepal;NP;AS;59;141;50;97;41;28;29959;24;490;68;2.3;2.7;19;6;59;-;15;54 +Netherlands;NL;EU;172;8;4;7;4;3;16613;11;49720;81;0.5;1.8;83;1.5;-;99;21;39 +New Zealand;NZ;OC;156;11;6;9;5;3;4368;15;29050;81;1.3;2.2;86;1.3;-;99;18;44 +Nicaragua;NI;MA;81;68;27;52;23;12;5788;24;1080;74;1.7;2.6;57;2.2;78;93;12;57 +Niger;NE;AF;12;311;143;132;73;32;15512;49;360;54;3.4;7.1;17;4;29;54;20;43 +Nigeria;NG;AF;12;213;143;126;88;40;158423;40;1180;51;2.4;5.5;50;4.1;61;63;15;49 +Norway;NO;EU;186;9;3;7;3;2;4883;12;85380;81;0.7;1.9;79;1.2;-;99;24;37 +Palestine;PS;AS;91;45;22;36;20;15;4039;33;2500;73;3.3;4.5;74;3.8;95;78;-;- +Oman;OM;AS;142;47;9;36;8;5;2782;18;17890;73;2;2.3;73;2.5;87;81;-;- +Pakistan;PK;AS;33;124;87;96;70;41;173593;27;1050;65;2.2;3.4;36;3;56;66;21;42 +Panama;PA;MA;98;33;20;26;17;9;3517;20;6990;76;1.9;2.5;75;3.5;94;97;11;57 +Papua New Guinea;PG;OC;49;90;61;65;47;23;6858;30;1300;62;2.5;4;13;1.6;60;-;12;56 +Paraguay;PY;SA;87;50;25;40;21;14;6455;24;2940;72;2.1;3;61;3.3;95;86;11;57 +Peru;PE;SA;103;78;19;55;15;9;29077;20;4710;74;1.5;2.5;77;2;90;97;12;53 +Philippines;PH;AS;80;59;29;42;23;14;93261;25;2050;68;2.1;3.1;49;2.1;95;92;15;50 +Poland;PL;EU;156;17;6;15;5;4;38277;11;12420;76;0;1.4;61;0;100;96;20;42 +Portugal;PT;EU;172;15;4;11;3;2;10676;9;21860;79;0.4;1.3;61;1.5;95;99;17;46 +Qatar;QA;AS;145;21;8;17;7;4;1759;12;14000;78;6.6;2.3;96;6.8;95;98;-;52 +South Korea;KR;AS;165;8;5;6;4;2;48184;10;19890;81;0.6;1.3;83;1.2;-;99;21;38 +Moldova;MD;EU;103;37;19;30;16;9;3573;12;1810;69;-1;1.5;47;-1;98;90;18;45 +Romania;RO;EU;126;37;14;29;11;8;21486;10;7840;74;-0.4;1.4;57;0;98;96;21;39 +Russian Federation;RU;EU;133;27;12;22;9;6;142958;12;9910;69;-0.2;1.5;73;-0.2;100;94;16;49 +Rwanda;RW;AF;31;163;91;99;59;29;10624;41;540;55;2;5.4;19;8.2;71;96;12;58 +Saint Lucia;LC;MA;118;23;16;18;14;10;174;18;4970;74;1.2;2;28;0.9;-;93;15;49 +Saint Vincent and the Grenadines;VC;MA;97;27;21;21;19;13;109;17;4850;72;0.1;2.1;49;1;-;98;-;- +Samoa;WS;OC;98;27;20;23;17;8;183;25;2930;72;0.6;3.9;20;0.4;99;99;-;- +Sao Tome and Principe;ST;AF;37;94;80;61;53;25;165;31;1200;64;1.8;3.7;62;3.5;89;98;14;56 +Saudi Arabia;SA;AS;108;45;18;36;15;10;27448;22;17200;74;2.7;2.8;82;3;86;86;-;- +Senegal;SN;AF;42;139;75;70;50;27;12434;37;1050;59;2.7;4.8;42;3.1;50;75;17;46 +Serbia;RS;EU;152;29;7;25;6;4;9856;11;5820;74;0.1;1.6;56;0.7;98;96;23;37 +Sierra Leone;SL;AF;4;276;174;162;114;45;5868;39;340;47;1.9;5;38;2.7;41;-;16;49 +Singapore;SG;AS;186;8;3;6;2;1;5086;9;40920;81;2.6;1.3;100;2.6;95;-;14;49 +Slovakia;SK;EU;145;18;8;15;7;4;5462;10;16220;75;0.2;1.3;55;0;-;-;24;35 +Slovenia;SI;EU;186;10;3;9;2;2;2030;10;23860;79;0.3;1.4;50;0.2;100;98;21;39 +Solomon Islands;SB;OC;81;45;27;36;23;12;538;32;1030;67;2.8;4.2;19;4.3;-;81;-;- +Somalia;SO;AF;1;180;180;108;108;52;9331;44;500;51;1.7;6.3;37;2.9;-;-;-;- +South Africa;ZA;AF;51;60;57;47;41;18;50133;21;6100;52;1.5;2.5;62;2.4;89;90;9;63 +Spain;ES;EU;165;11;5;9;4;3;46077;11;31650;81;0.8;1.5;77;1;98;100;19;42 +Sri Lanka;LK;AS;113;32;17;26;14;10;20860;18;2290;75;0.9;2.3;14;-0.4;91;95;17;48 +Suriname;SR;SA;78;52;31;44;27;14;525;18;5920;70;1.3;2.3;69;2;95;90;11;57 +Swaziland;SZ;AF;39;96;78;70;55;21;1186;29;2600;48;1.6;3.4;21;1.2;87;83;12;56 +Sweden;SE;EU;186;7;3;6;2;2;9380;12;49930;81;0.5;1.9;85;0.6;-;96;23;37 +Switzerland;CH;EU;165;8;5;7;4;3;7664;10;70350;82;0.7;1.5;74;0.7;-;100;20;41 +Syrian Arab Republic;SY;AS;118;38;16;31;14;9;20411;23;2640;76;2.5;2.9;56;3.2;84;-;19;44 +Tajikistan;TJ;AS;46;116;63;91;52;25;6879;28;780;67;1.3;3.3;26;0.4;100;98;23;39 +Thailand;TW;AS;130;32;13;26;11;8;69122;12;4210;74;1;1.6;34;1.7;94;90;11;59 +Macedonia;MK;EU;133;39;12;34;10;8;2061;11;4520;75;0.4;1.4;59;0.5;97;93;15;50 +Timor-Leste;TL;OC;54;169;55;127;46;24;1124;39;2220;62;2.1;6.2;28;3.6;51;83;21;41 +Togo;TG;AF;24;147;103;87;66;32;6028;32;440;57;2.5;4.1;43;4.3;57;95;16;47 +Tonga;TO;OC;118;25;16;21;13;8;104;27;3380;72;0.4;3.9;23;0.6;99;-;-;- +Trinidad and Tobago;TT;MA;81;37;27;32;24;18;1341;15;15380;70;0.5;1.6;14;2.9;99;96;16;46 +Tunisia;TN;AF;118;49;16;39;14;9;10481;17;4070;74;1.2;2;67;2;78;99;16;47 +Turkey;TR;AS;108;80;18;66;14;10;72752;18;9500;74;1.5;2.1;70;2.3;91;95;16;46 +Turkmenistan;TM;AS;52;98;56;78;47;23;5042;22;3700;65;1.6;2.4;50;2.1;100;-;16;47 +Uganda;UG;AF;27;175;99;106;63;26;33425;45;490;54;3.2;6.1;13;4.1;73;92;15;51 +Ukraine;UA;EU;130;21;13;18;11;6;45448;11;3010;68;-0.6;1.4;69;-0.5;100;89;23;37 +United Arab Emirates;AE;AS;152;22;7;18;6;4;7512;12;14000;76;7.1;1.7;84;7.4;90;98;-;- +United Kingdom;GB;EU;165;9;5;8;5;3;62036;12;38540;80;0.4;1.9;80;0.5;-;100;18;44 +Tanzania;TZ;AF;41;155;76;95;50;26;44841;42;530;57;2.8;5.5;26;4.5;73;97;18;45 +United States;US;NA;145;11;8;9;7;4;310384;14;47140;78;1;2.1;82;1.5;-;92;16;46 +Uruguay;UY;SA;136;23;11;20;9;6;3369;15;10590;77;0.4;2.1;92;0.6;98;99;15;49 +Uzbekistan;UZ;AS;57;77;52;63;44;23;27445;21;1280;68;1.5;2.4;36;0.9;99;90;19;44 +Vanuatu;VU;OC;126;39;14;31;12;7;240;30;2760;71;2.5;3.9;26;4;82;-;-;- +Venezuela;VE;SA;108;33;18;28;16;10;28980;21;11590;74;1.9;2.5;93;2.4;95;94;15;49 +Vietnam;VN;AS;90;51;23;37;19;12;87848;17;1100;75;1.3;1.8;30;3.4;93;-;18;45 +Yemen;YE;AS;40;128;77;90;57;32;24053;38;1060;65;3.5;5.2;32;5.6;62;73;18;45 +Zambia;ZM;AF;21;183;111;109;69;30;13089;46;1070;49;2.5;6.3;36;2.1;71;92;11;55 +Zimbabwe;ZW;AF;37;78;80;52;51;27;12571;30;460;50;0.9;3.3;38;2.3;92;-;13;56 +Sudan and South Sudan;SD;AF;24;125;103;78;66;35;43552;33;1270;61;2.5;4.4;40;4.5;70;-;-;- diff --git a/examples/countries.rs b/examples/countries.rs new file mode 100644 index 0000000..31a0299 --- /dev/null +++ b/examples/countries.rs @@ -0,0 +1,53 @@ +use easy_graph::ui::window::WindowBuilder; +use kohonen::calc::neighborhood::Neighborhood; +use kohonen::map::som::DecayParam; +use kohonen::proc::{InputLayer, ProcessorBuilder}; +use kohonen::ui::LayerView; + +fn main() { + let layers = vec![ + InputLayer::cont_simple(&[ + "child_mort_2010", + "birth_p_1000", + "GNI", + "LifeExpectancy", + "PopGrowth", + "PopUrbanized", + "PopGrowthUrb", + "AdultLiteracy", + "PrimSchool", + "Income_low_40", + "Income_high_20", + ]), + //InputLayer::cat_simple("species"), + ]; + + let proc = ProcessorBuilder::new(&layers) + .with_delimiter(b';') + .with_no_data("-") + .build_from_file("example_data/countries.csv") + .unwrap(); + + let mut som = proc.create_som( + 16, + 20, + 5000, + Neighborhood::Gauss, + DecayParam::lin(0.2, 0.01), + DecayParam::lin(8.0, 0.5), + DecayParam::exp(0.2, 0.001), + ); + + let win_x = WindowBuilder::new() + .with_position((10, 10)) + .with_dimensions(1000, 500) + .with_fps_skip(5.0) + .build(); + + let mut view_x = LayerView::new(win_x, &[0], &proc.data().names_ref_vec(), None); + + while view_x.is_open() { + som.epoch(proc.data(), None); + view_x.draw(&som); + } +} diff --git a/examples/csv_tests.rs b/examples/csv_tests.rs deleted file mode 100644 index 30dc29d..0000000 --- a/examples/csv_tests.rs +++ /dev/null @@ -1,17 +0,0 @@ -use csv::StringRecord; - -fn main() { - let mut reader = csv::ReaderBuilder::new() - .delimiter(b';') - .from_path("example_data/iris.csv") - .unwrap(); - let header: StringRecord = reader.headers().unwrap().clone(); - let _header: Vec<_> = header.iter().collect(); - - /* - println!("{:?}", header); - for record in reader.records() { - println!("{:?}", record); - } - */ -} diff --git a/examples/iris.rs b/examples/iris.rs index 0f3fba2..3556a31 100644 --- a/examples/iris.rs +++ b/examples/iris.rs @@ -1,5 +1,5 @@ use easy_graph::ui::window::WindowBuilder; -use kohonen::calc::neighborhood::GaussNeighborhood; +use kohonen::calc::neighborhood::Neighborhood; use kohonen::map::som::DecayParam; use kohonen::proc::{InputLayer, ProcessorBuilder}; use kohonen::ui::LayerView; @@ -19,7 +19,7 @@ fn main() { 16, 20, 5000, - GaussNeighborhood(), + Neighborhood::Gauss, DecayParam::lin(0.2, 0.01), DecayParam::lin(8.0, 0.5), DecayParam::exp(0.2, 0.001), @@ -37,8 +37,8 @@ fn main() { .with_fps_skip(5.0) .build(); - let mut view_x = LayerView::new(win_x, &[0], None); - let mut view_y = LayerView::new(win_y, &[1], None); + let mut view_x = LayerView::new(win_x, &[0], &proc.data().names_ref_vec(), None); + let mut view_y = LayerView::new(win_y, &[1], &proc.data().names_ref_vec(), None); while view_x.is_open() || view_y.is_open() { som.epoch(proc.data(), None); diff --git a/examples/layer_view_simple.rs b/examples/layer_view_simple.rs index 6b974ee..b49c2a6 100644 --- a/examples/layer_view_simple.rs +++ b/examples/layer_view_simple.rs @@ -1,5 +1,5 @@ use easy_graph::ui::window::WindowBuilder; -use kohonen::calc::neighborhood::GaussNeighborhood; +use kohonen::calc::neighborhood::Neighborhood; use kohonen::map::som::{DecayParam, Layer, Som, SomParams}; use kohonen::ui::LayerView; @@ -7,7 +7,7 @@ fn main() { let dim = 5; let params = SomParams::xyf( 1000, - GaussNeighborhood(), + Neighborhood::Gauss, DecayParam::lin(0.1, 0.01), DecayParam::lin(10.0, 0.6), DecayParam::exp(0.25, 0.0001), @@ -20,7 +20,7 @@ fn main() { .with_fps_skip(5.0) .build(); - let mut view = LayerView::new(win, &[0], None); + let mut view = LayerView::new(win, &[0], &["A", "B", "C", "D", "E"], None); while view.is_open() { view.draw(&som); diff --git a/examples/som_with_view.rs b/examples/som_with_view.rs index 70c4e0e..25cefd0 100644 --- a/examples/som_with_view.rs +++ b/examples/som_with_view.rs @@ -2,7 +2,7 @@ use easy_graph::color::style::{BLACK, BLUE, WHITE}; use easy_graph::ui::drawing::IntoDrawingArea; use easy_graph::ui::element::{Circle, PathElement}; use easy_graph::ui::window::WindowBuilder; -use kohonen::calc::neighborhood::GaussNeighborhood; +use kohonen::calc::neighborhood::Neighborhood; use kohonen::data::DataFrame; use kohonen::map::som::{DecayParam, Som, SomParams}; use rand::prelude::*; @@ -16,7 +16,7 @@ fn run_som(graphics: bool) { let cols = ["A", "B"]; let params = SomParams::simple( 1000, - GaussNeighborhood(), + Neighborhood::Gauss, DecayParam::lin(0.1, 0.01), DecayParam::lin(10.0, 0.6), DecayParam::exp(0.25, 0.0001), @@ -24,7 +24,7 @@ fn run_som(graphics: bool) { let mut som = Som::new(cols.len(), 20, 16, params); let mut rng = rand::thread_rng(); - let mut data = DataFrame::::empty(&cols); + let mut data = DataFrame::empty(&cols); let norm = rand::distributions::Normal::new(0.0, 0.06); for _i in 0..5000 { diff --git a/examples/xyf_with_view.rs b/examples/xyf_with_view.rs index b959363..1b49f05 100644 --- a/examples/xyf_with_view.rs +++ b/examples/xyf_with_view.rs @@ -2,7 +2,7 @@ use easy_graph::color::style::{BLACK, BLUE, WHITE}; use easy_graph::ui::drawing::IntoDrawingArea; use easy_graph::ui::element::{Circle, PathElement}; use easy_graph::ui::window::WindowBuilder; -use kohonen::calc::neighborhood::GaussNeighborhood; +use kohonen::calc::neighborhood::Neighborhood; use kohonen::data::DataFrame; use kohonen::map::som::{DecayParam, Layer, Som, SomParams}; use kohonen::ui::LayerView; @@ -18,7 +18,7 @@ fn run_xyf(graphics: bool) { let cols = ["A", "B", "C", "D"]; let params = SomParams::xyf( 1000, - GaussNeighborhood(), + Neighborhood::Gauss, DecayParam::lin(0.1, 0.01), DecayParam::lin(10.0, 0.6), DecayParam::exp(0.25, 0.0001), @@ -27,7 +27,7 @@ fn run_xyf(graphics: bool) { let mut som = Som::new(cols.len(), 12, 24, params); let mut rng = rand::thread_rng(); - let mut data = DataFrame::::empty(&cols); + let mut data = DataFrame::empty(&cols); let norm = rand::distributions::Normal::new(0.0, 0.06); for _i in 0..5000 { @@ -55,7 +55,7 @@ fn run_xyf(graphics: bool) { .with_dimensions(800, 500) .with_fps_skip(2.0) .build(); - Some(LayerView::new(win, &[], None)) + Some(LayerView::new(win, &[], &cols, None)) } else { None }; diff --git a/run_cli.bat b/run_cli.bat new file mode 100644 index 0000000..1664ff8 --- /dev/null +++ b/run_cli.bat @@ -0,0 +1,12 @@ +.\target\debug\kohonen.exe ^ +--file example_data/iris.csv ^ +--size 20 16 ^ +--episodes 5000 ^ +--layers "sepal_length sepal_width petal_length petal_width" "species" ^ +--categ 0 1 ^ +--norm gauss none ^ +--weights 1 1 ^ +--alpha 0.2 0.01 lin ^ +--radius 8 0.5 lin ^ +--decay 0.2 0.001 exp ^ +--neigh gauss diff --git a/run_cli_release.bat b/run_cli_release.bat new file mode 100644 index 0000000..89741a5 --- /dev/null +++ b/run_cli_release.bat @@ -0,0 +1,12 @@ +.\target\release\kohonen.exe ^ +--file example_data/iris.csv ^ +--size 20 16 ^ +--episodes 5000 ^ +--layers "sepal_length sepal_width petal_length petal_width" "species" ^ +--categ 0 1 ^ +--norm gauss none ^ +--weights 1 1 ^ +--alpha 0.2 0.01 lin ^ +--radius 8 0.7 lin ^ +--decay 0.2 0.001 exp ^ +--neigh gauss diff --git a/src/calc/neighborhood.rs b/src/calc/neighborhood.rs index 8a1ca33..3d7fa55 100644 --- a/src/calc/neighborhood.rs +++ b/src/calc/neighborhood.rs @@ -1,7 +1,45 @@ //! Neighborhoods (i.e. kernels), for effect on nearby SOM-units. +use crate::ParseEnumError; + +/// Neighborhoods. +#[derive(Debug, Clone)] +pub enum Neighborhood { + Gauss, +} +impl Neighborhood { + /// Calculates the weight, depending on the squared(!) distance. + pub fn weight(&self, distance_sq: f64) -> f64 { + match self { + Neighborhood::Gauss => { + if distance_sq == 0.0 { + 1.0 + } else { + (-0.5 * distance_sq).exp() + } + } + } + } + /// Maximum search distance in the SOM. Not squared! + pub fn radius(&self) -> f64 { + match self { + Neighborhood::Gauss => 3.0, + } + } + pub fn from_string(str: &str) -> Result { + match str { + "gauss" => Ok(Neighborhood::Gauss), + _ => Err(ParseEnumError(format!( + "Not a neighborhood: {}. Must be one of (gauss|)", + str + ))), + } + } +} + +/* /// Trait for neighborhoods. -pub trait Neighborhood { +pub trait Neighborhood: Debug + Copy { /// Calculates the weight, depending on the squared(!) distance. fn weight(&self, distance_sq: f64) -> f64; /// Maximum search distance in the SOM. Not squared! @@ -9,6 +47,7 @@ pub trait Neighborhood { } /// Gaussian (normal) neighborhood. +#[derive(Debug, Copy)] pub struct GaussNeighborhood(); impl Neighborhood for GaussNeighborhood { @@ -23,14 +62,14 @@ impl Neighborhood for GaussNeighborhood { 3.0 } } - +*/ #[cfg(test)] mod test { - use crate::calc::neighborhood::{GaussNeighborhood, Neighborhood}; + use crate::calc::neighborhood::Neighborhood; #[test] fn gauss() { - let neigh = GaussNeighborhood(); + let neigh = Neighborhood::Gauss; assert_eq!(neigh.weight(0.0), 1.0); assert!(neigh.weight(3.0 * 3.0) < 0.12); } diff --git a/src/calc/nn.rs b/src/calc/nn.rs index 598efd7..06e6b25 100644 --- a/src/calc/nn.rs +++ b/src/calc/nn.rs @@ -16,7 +16,7 @@ const TANIMOTO: TanimotoMetric = TanimotoMetric(); /// Dimensions with NA values are ignored. /// # Returns /// (index, distance) -pub fn nearest_neighbor(from: &[f64], to: &DataFrame) -> (usize, f64) { +pub fn nearest_neighbor(from: &[f64], to: &DataFrame) -> (usize, f64) { assert_eq!(from.len(), to.ncols()); let mut min_dist = std::f64::MAX; @@ -35,7 +35,7 @@ pub fn nearest_neighbor(from: &[f64], to: &DataFrame) -> (usize, f64) { /// Dimensions with NA values are ignored. /// # Returns /// (index, distance) -pub fn nearest_neighbor_tanimoto(from: &[f64], to: &DataFrame) -> (usize, f64) { +pub fn nearest_neighbor_tanimoto(from: &[f64], to: &DataFrame) -> (usize, f64) { assert_eq!(from.len(), to.ncols()); let mut min_dist = std::f64::MAX; @@ -54,7 +54,7 @@ pub fn nearest_neighbor_tanimoto(from: &[f64], to: &DataFrame) -> (usize, f /// Dimensions with NA values are ignored. /// # Returns /// (index, weighted-distance) -pub fn nearest_neighbor_xyf(from: &[f64], to: &DataFrame, layers: &[Layer]) -> (usize, f64) { +pub fn nearest_neighbor_xyf(from: &[f64], to: &DataFrame, layers: &[Layer]) -> (usize, f64) { assert_eq!(from.len(), to.ncols()); let mut min_dist = std::f64::MAX; @@ -84,8 +84,8 @@ pub fn nearest_neighbor_xyf(from: &[f64], to: &DataFrame, layers: &[Layer]) /// # Returns /// Vec(index, weighted-distance) pub fn nearest_neighbors( - from: &DataFrame, - to: &DataFrame, + from: &DataFrame, + to: &DataFrame, mut result: Vec<(usize, f64)>, ) -> Vec<(usize, f64)> { assert_eq!(from.ncols(), to.ncols()); @@ -182,7 +182,7 @@ mod test { fn xyf_nn() { let mut rng = rand::thread_rng(); let from = [0.0, 0.0, 0.0, 0.0, 0.0]; - let mut to = DataFrame::::empty(&["A", "B", "C", "D", "E"]); + let mut to = DataFrame::empty(&["A", "B", "C", "D", "E"]); for _i in 0..10 { to.push_row(&[ @@ -202,7 +202,7 @@ mod test { fn nn_simple() { let mut rng = rand::thread_rng(); let from = [0.0, 0.0, 0.0]; - let mut to = DataFrame::::empty(&["A", "B", "C"]); + let mut to = DataFrame::empty(&["A", "B", "C"]); for _i in 0..100 { to.push_row(&[ @@ -220,8 +220,8 @@ mod test { #[test] fn nns_simple() { let mut rng = rand::thread_rng(); - let mut from = DataFrame::::empty(&["A", "B", "C"]); - let mut to = DataFrame::::empty(&["A", "B", "C"]); + let mut from = DataFrame::empty(&["A", "B", "C"]); + let mut to = DataFrame::empty(&["A", "B", "C"]); for _i in 0..100 { from.push_row(&[ diff --git a/src/calc/norm.rs b/src/calc/norm.rs index 0d5e0ff..8e23734 100644 --- a/src/calc/norm.rs +++ b/src/calc/norm.rs @@ -1,48 +1,75 @@ //! Normalization and de-normalization of data. use crate::data::DataFrame; +use crate::ParseEnumError; /// Normalization types. #[derive(PartialEq, Clone, Debug)] pub enum Norm { /// Normalize to [0, 1]. - Unity, + Unit, /// Normalize to a mean of 0.5 and standard deviation of 0.5. Gauss, /// No normalization None, } +impl Norm { + pub fn from_string(str: &str) -> Result { + match str { + "unit" => Ok(Norm::Unit), + "gauss" => Ok(Norm::Gauss), + "none" => Ok(Norm::None), + _ => Err(ParseEnumError(format!( + "Not a normalizer: {}. Must be one of (unit|gauss|none)", + str + ))), + } + } +} + /// De-normalization parameters. Obtained from [`normalize`](fn.normalize.html). #[derive(Debug)] -pub struct DeNorm { +pub struct LinearTransform { scale: f64, offset: f64, } +impl LinearTransform { + pub fn transform(&self, value: f64) -> f64 { + value * self.scale + self.offset + } + pub fn inverse(&self) -> LinearTransform { + LinearTransform { + scale: 1.0 / self.scale, + offset: -self.offset / self.scale, + } + } +} + /// Normalize a data frame, with a [`Norm`](struct.Norm.html) and scale per column. /// # Returns -/// A tuple of: (normalized data frame, vector of [`DeNorm`](struct.DeNorm.html), one per column). +/// A tuple of: (normalized data frame, vector of [`LinearTransform`](struct.LinearTransform.html) for de-normalization, one per column). pub fn normalize( - data: &DataFrame, + data: &DataFrame, norm: &[Norm], scale: &[f64], -) -> (DataFrame, Vec) { +) -> (DataFrame, Vec) { let mut counts = vec![0; data.ncols()]; let mut params: Vec<_> = norm .iter() .map(|n| match n { - Norm::Unity => (std::f64::MAX, std::f64::MIN), + Norm::Unit => (std::f64::MAX, std::f64::MIN), _ => (0.0, 0.0), }) .collect(); for row in data.iter_rows() { for (i, v) in row.iter().enumerate() { - let norm = &norm[i]; if !v.is_nan() { + let norm = &norm[i]; match norm { - Norm::Unity => { + Norm::Unit => { if *v < params[i].0 { params[i].0 = *v } @@ -60,25 +87,35 @@ pub fn normalize( } } } + //println!("Params: {:?}", params); + //println!("Counts: {:?}", counts); let denorm: Vec<_> = params .iter() .zip(counts) .zip(norm) .zip(scale) .map(|((((p1, p2), count), norm), scale)| match norm { - Norm::Unity => DeNorm { - scale: scale * 1.0 / (p2 - p1), - offset: -*p1, - }, + Norm::Unit => { + let sc = scale / (p2 - p1); + LinearTransform { + //scale: scale * 1.0 / (p2 - p1), + //offset: -*p1, + scale: sc, + offset: -*p1 * sc, + } + } Norm::Gauss => { let sd = ((count as f64 * p2 - p1.powi(2)) / (count * (count - 1)) as f64).sqrt(); let mean = p1 / count as f64; - DeNorm { - scale: scale * 1.0 / (2.0 * sd), - offset: -(mean - sd), + let sc = scale / (2.0 * sd); + LinearTransform { + //scale: scale * 1.0 / (2.0 * sd), + //offset: -(mean - sd), + scale: sc, + offset: -(mean - sd) * sc, } } - Norm::None => DeNorm { + Norm::None => LinearTransform { scale: *scale, offset: 0.0, }, @@ -86,39 +123,35 @@ pub fn normalize( .collect(); let cols: Vec<_> = data.names().iter().map(|x| &**x).collect(); - let mut df = DataFrame::::empty(&cols); + let mut df = DataFrame::empty(&cols); for row in data.iter_rows() { df.push_row_iter( denorm .iter() .zip(row) - .map(|(de, v)| (v + de.offset) * de.scale), + //.map(|(de, v)| (v + de.offset) * de.scale), + .map(|(de, v)| de.transform(*v)), ); } - /*let denorm = denorm - .iter() - .map(|de| DeNorm { - scale: 1.0 / de.scale, - offset: -de.offset, - }) - .collect();*/ + let denorm = denorm.iter().map(|de| de.inverse()).collect(); (df, denorm) } -/// De-normalize a data frame, with a [`DeNorm`](struct.DeNorm.html) per column, as obtained from [`normalize`](fn.normalize.html). +/// De-normalize a data frame, with a [`LinearTransform`](struct.DeNorm.html) per column, as obtained from [`normalize`](fn.normalize.html). /// # Returns /// A de-normalized data frame -pub fn denormalize(data: &DataFrame, denorm: &[DeNorm]) -> DataFrame { +pub fn denormalize(data: &DataFrame, denorm: &[LinearTransform]) -> DataFrame { let cols: Vec<_> = data.names().iter().map(|x| &**x).collect(); - let mut df = DataFrame::::empty(&cols); + let mut df = DataFrame::empty(&cols); for row in data.iter_rows() { df.push_row_iter( denorm .iter() .zip(row) - .map(|(de, v)| v / de.scale - de.offset), + //.map(|(de, v)| v / de.scale - de.offset), + .map(|(de, v)| de.transform(*v)), ); } df @@ -134,7 +167,7 @@ mod tests { #[test] fn normalization() { let mut rng = rand::thread_rng(); - let mut data = DataFrame::::empty(&["A", "B", "C"]); + let mut data = DataFrame::empty(&["A", "B", "C"]); let norm = rand::distributions::Normal::new(1.0, 2.0); for _i in 0..20 { @@ -147,9 +180,10 @@ mod tests { let (df, denorm) = normalize( &data, - &[Norm::Unity, Norm::Gauss, Norm::None], + &[Norm::Unit, Norm::Gauss, Norm::None], &[1.0, 1.0, 0.5], ); + assert_eq!(data.nrows(), df.nrows()); assert_eq!(data.ncols(), df.ncols()); diff --git a/src/cli.rs b/src/cli.rs new file mode 100644 index 0000000..8206942 --- /dev/null +++ b/src/cli.rs @@ -0,0 +1,176 @@ +//! Command-line interface for SOMs. +use crate::calc::neighborhood::Neighborhood; +use crate::calc::norm::Norm; +use crate::map::som::{DecayFunction, DecayParam}; +use crate::proc::InputLayer; +use structopt::StructOpt; + +#[derive(StructOpt)] +#[structopt(name = "Super-SOM command line application")] +pub struct Cli { + // TODO: add and implement no-data value (use countries example) + /// Path to the training data file. + #[structopt(short, long)] + file: String, + /// SOM size: width, height. + #[structopt(short, long, number_of_values = 2)] + size: Vec, + /// Number of training episodes + #[structopt(short, long)] + episodes: u32, + /// Layer columns. Put layers in quotes: `"X1 X2 X3" "Y1"` + #[structopt(short, long)] + layers: Vec, + /// Layer weights list + #[structopt(short, long)] + weights: Vec, + /// Are layers categorical list (0/1). Default 1.0 + #[structopt(short, long)] + categ: Vec, + /// Normalizer per layer list (gauss, unit, none). Default gauss. + #[structopt(short, long)] + norm: Vec, + /// Learning rate: start, end, type (lin|exp) + #[structopt(short, long, number_of_values = 3)] + alpha: Vec, + /// Neighborhood radius: start, end, type (lin|exp) + #[structopt(short, long, number_of_values = 3)] + radius: Vec, + /// Weight decay: start, end, type (lin|exp) + #[structopt(short, long, number_of_values = 3)] + decay: Vec, + /// Neighborhood function (gauss|) + #[structopt(short = "-g", long)] + neigh: Option, + /// Disable GUI + #[structopt(long = "--no-gui")] + nogui: bool, + /// Disable GUI + #[structopt(long = "--fps")] + fps: Option, + /// No-data value. Default 'NA'. + #[structopt(long = "--no-data")] + no_data: Option, +} + +#[derive(Debug)] +pub struct CliParsed { + pub file: String, + pub size: (usize, usize), + pub episodes: u32, + pub layers: Vec, + pub alpha: DecayParam, + pub radius: DecayParam, + pub decay: DecayParam, + pub neigh: Neighborhood, + pub gui: bool, + pub no_data: String, + pub fps: f64, +} + +impl CliParsed { + pub fn from_cli(mut cli: Cli) -> Self { + CliParsed { + file: cli.file.clone(), + size: (cli.size[0], cli.size[1]), + episodes: cli.episodes, + layers: Self::to_layers(&mut cli), + alpha: Self::to_decay(cli.alpha, "alpha"), + radius: Self::to_decay(cli.radius, "radius"), + decay: Self::to_decay(cli.decay, "decay"), + neigh: match &cli.neigh { + Some(n) => Neighborhood::from_string(n).unwrap(), + None => Neighborhood::Gauss, + }, + gui: !cli.nogui, + no_data: cli.no_data.unwrap_or("NA".to_string()), + fps: cli.fps.unwrap_or(2.0), + } + } + + fn to_decay(values: Vec, name: &str) -> DecayParam { + if values.len() != 3 { + panic!(format!( + "Three argument required for {}: start value, end value, decay function (lin|exp)", + name + )); + } + DecayParam::new( + values[0] + .parse() + .expect(&format!("Unable to parse value {} in {}", values[0], name)), + values[1] + .parse() + .expect(&format!("Unable to parse value {} in {}", values[1], name)), + DecayFunction::from_string(&values[2]).unwrap(), + /* + match &values[2][..] { + "lin" => DecayFunction::Linear, + "exp" => DecayFunction::Exponential, + _ => panic!("Expected decay funtion 'lin' or 'exp'"), + },*/ + ) + } + fn to_layers(cli: &mut Cli) -> Vec { + if cli.layers.is_empty() { + panic!("Expected columns for at least one layer (option --layers)"); + } + let n_layers = cli.layers.len(); + + if cli.weights.len() != 0 && cli.weights.len() != n_layers { + panic!("Expected no weights, or as many as layers (option --weights)"); + } + if cli.categ.len() != 0 && cli.categ.len() != n_layers { + panic!("Expected no categorical 0/1, or as many as layers (option --weights)"); + } + if cli.norm.len() != 0 && cli.norm.len() != n_layers { + panic!("Expected no normalizers, or as many as layers (option --norm)"); + } + + if cli.weights.is_empty() { + cli.weights = vec![1.0; n_layers]; + } + if cli.categ.is_empty() { + cli.categ = vec![0; n_layers]; + } + if cli.norm.is_empty() { + cli.norm = vec!["gauss".to_string(); n_layers]; + } + + cli.layers + .iter() + .zip(&cli.weights) + .zip(&cli.categ) + .zip(&cli.norm) + .map(|(((lay, wt), cat), norm)| { + InputLayer::new( + &lay.trim().split(' ').map(|s| &*s).collect::>(), + *wt, + *cat > 0, + Norm::from_string(norm).unwrap(), + None, + ) + }) + .collect::>() + } +} + +/* +mod parse { + use crate::map::som::DecayFunction; + use std::convert::TryFrom; + + pub fn parse_decay_param(src: &str) -> (f64, f64, DecayFunction) { + let split: Vec<_> = src.split(" ").collect(); + ( + split[0].parse().unwrap(), + split[1].parse().unwrap(), + match split[2] { + "lin" => DecayFunction::Linear, + "exp" => DecayFunction::Exponential, + _ => panic!("Expected decay funtion 'lin' or 'exp'"), + }, + ) + } +} +*/ diff --git a/src/data/mod.rs b/src/data/mod.rs index 8bf5ee6..dab8245 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -1,25 +1,18 @@ //! Data structures like tables. -use num_traits::Float; use std::slice::{Chunks, ChunksMut}; /// A data frame with all columns of the same Float type. #[allow(dead_code)] -pub struct DataFrame -where - T: Float, -{ +pub struct DataFrame { ncols: usize, nrows: usize, names: Vec, - data: Vec, + data: Vec, } #[allow(dead_code)] -impl DataFrame -where - T: Float, -{ +impl DataFrame { /// Creates an empty data frame, with the given columns and zero rows. pub fn empty(columns: &[&str]) -> Self { DataFrame { @@ -31,7 +24,7 @@ where } /// Creates a blank data frame, with the given number of columns and rows, filled with a value. - pub fn filled(nrows: usize, columns: &[&str], fill: T) -> Self { + pub fn filled(nrows: usize, columns: &[&str], fill: f64) -> Self { DataFrame { names: columns.iter().map(|s| s.to_string()).collect(), ncols: columns.len(), @@ -41,7 +34,7 @@ where } /// Creates a data frame from a vector of rows. - pub fn from_rows(columns: &[&str], rows: &[Vec]) -> Self { + pub fn from_rows(columns: &[&str], rows: &[Vec]) -> Self { assert_eq!(columns.len(), rows[0].len()); DataFrame { names: columns.iter().map(|s| s.to_string()).collect(), @@ -66,7 +59,7 @@ where /// ` x x x x x x x x x x x x x x x ...` /// /// `|___ row 1 ___|___ row 2 ___|___ ...` - pub fn data(&self) -> &[T] { + pub fn data(&self) -> &[f64] { &self.data } @@ -74,54 +67,58 @@ where pub fn names(&self) -> &[String] { &self.names } + /// Returns the data frame's column names as a vector of references. + pub fn names_ref_vec(&self) -> Vec<&str> { + self.names.iter().map(|x| &**x).collect() + } /// Appends a row to the end of the data frame, from a slice. - pub fn push_row(&mut self, row: &[T]) { + pub fn push_row(&mut self, row: &[f64]) { assert_eq!(row.len(), self.ncols); self.data.extend_from_slice(row); self.nrows += 1; } /// Appends a row to the end of the data frame, from an iterator. - pub fn push_row_iter(&mut self, row: impl Iterator) { + pub fn push_row_iter(&mut self, row: impl Iterator) { self.data.extend(row); self.nrows += 1; } /// Returns a reference to the value at (row, column). - pub fn get(&self, row: usize, col: usize) -> &T { + pub fn get(&self, row: usize, col: usize) -> &f64 { let idx = self.index(row, col); &self.data[idx] } /// Returns a mutable reference to the value at (row, column). - pub fn get_mut(&mut self, row: usize, col: usize) -> &mut T { + pub fn get_mut(&mut self, row: usize, col: usize) -> &mut f64 { let idx = self.index(row, col); &mut self.data[idx] } /// Sets the value at (row, column), consuming the value. - pub fn set(&mut self, row: usize, col: usize, value: T) { + pub fn set(&mut self, row: usize, col: usize, value: f64) { let idx = self.index(row, col); self.data[idx] = value } /// Returns a reference to the value at the given index in raw data. - pub fn get_at(&self, index: usize) -> &T { + pub fn get_at(&self, index: usize) -> &f64 { &self.data[index] } /// Returns a mutable reference to the value at the given index in raw data. - pub fn get_mut_at(&mut self, index: usize) -> &mut T { + pub fn get_mut_at(&mut self, index: usize) -> &mut f64 { &mut self.data[index] } /// Sets the value at the given index in raw data, consuming the value. - pub fn set_at(&mut self, index: usize, value: T) { + pub fn set_at(&mut self, index: usize, value: f64) { self.data[index] = value } /// Returns a row as a slice reference. - pub fn get_row(&self, row: usize) -> &[T] { + pub fn get_row(&self, row: usize) -> &[f64] { let idx = self.index(row, 0); &self.data[idx..idx + self.ncols] } /// Returns a row as a mutable slice reference. - pub fn get_row_mut(&mut self, row: usize, col: usize) -> &mut [T] { + pub fn get_row_mut(&mut self, row: usize, col: usize) -> &mut [f64] { let idx = self.index(row, col); &mut self.data[idx..idx + self.ncols] } @@ -139,52 +136,70 @@ where } /// An iterator over rows. - pub fn iter_rows(&self) -> Chunks { + pub fn iter_rows(&self) -> Chunks { self.data.chunks(self.ncols) } /// A mutable iterator over rows. - pub fn iter_rows_mut(&mut self) -> ChunksMut { + pub fn iter_rows_mut(&mut self) -> ChunksMut { self.data.chunks_mut(self.ncols) } /// Copies a column's values into a new vector. - pub fn copy_column(&self, column: usize) -> Vec { + pub fn copy_column(&self, column: usize) -> Vec { self.iter_rows().map(|row| row[column]).collect() } /// Returns ranges of columns. - pub fn ranges(&self) -> Vec<(T, T)> { + pub fn ranges(&self) -> Vec<(f64, f64)> { let ncol = self.ncols; - let mut min = vec![T::max_value(); self.ncols]; - let mut max = vec![T::min_value(); self.ncols]; + let mut min = vec![std::f64::MAX; ncol]; + let mut max = vec![std::f64::MIN; ncol]; + let mut any = vec![false; ncol]; for row in self.iter_rows() { for col in 0..ncol { let v = row[col]; - if v < min[col] { - min[col] = v; - } - if v > max[col] { - max[col] = v; + if !v.is_nan() { + if v < min[col] { + min[col] = v; + } + if v > max[col] { + max[col] = v; + } + any[col] = true; } } } - min.into_iter().zip(max).collect() + min.into_iter() + .zip(max) + .enumerate() + .map(|(i, (mn, mx))| { + if any[i] { + (mn, mx) + } else { + (std::f64::NAN, std::f64::NAN) + } + }) + .collect() } /// Returns means of columns. - pub fn means(&self) -> Vec { + pub fn means(&self) -> Vec { let ncol = self.ncols; - let nrows = T::from(self.nrows).unwrap(); + //let nrows = self.nrows; - let mut means = vec![T::zero(); ncol]; + let mut means = vec![0.0; ncol]; + let mut counts = vec![0; ncol]; for row in self.iter_rows() { for col in 0..ncol { let v = row[col]; - means[col] = means[col] + v; + if !v.is_nan() { + means[col] += v; + counts[col] += 1; + } } } - for col in 0..ncol { - means[col] = means[col] / nrows; + for i in 0..ncol { + means[i] /= counts[i] as f64; } means } @@ -198,7 +213,7 @@ mod test { fn create_df() { let cols = ["A", "B", "C", "D"]; let rows = 100; - let df = DataFrame::::filled(rows, &cols, 0.0); + let df = DataFrame::filled(rows, &cols, 0.0); assert_eq!(df.ncols, cols.len()); assert_eq!(df.nrows, rows); @@ -213,7 +228,7 @@ mod test { vec![2.0, 3.0, 4.0, 5.], vec![3.0, 4.0, 5.0, 6.0], ]; - let df = DataFrame::::from_rows(&cols, &data); + let df = DataFrame::from_rows(&cols, &data); assert_eq!(df.ncols, 4); assert_eq!(df.nrows, 3); @@ -223,7 +238,7 @@ mod test { #[test] fn add_rows() { let cols = ["A", "B", "C", "D"]; - let mut df = DataFrame::::empty(&cols); + let mut df = DataFrame::empty(&cols); df.push_row(&[1.0, 2.0, 3.0, 4.0]); df.push_row(&[2.0, 3.0, 4.0, 5.0]); @@ -242,7 +257,7 @@ mod test { fn iter_rows() { let cols = ["A", "B", "C", "D"]; let rows = 10; - let mut df = DataFrame::::empty(&cols); + let mut df = DataFrame::empty(&cols); for _i in 0..rows { df.push_row(&[1.0, 2.0, 3.0, 4.0]); @@ -259,7 +274,7 @@ mod test { #[test] fn ranges() { let cols = ["A", "B", "C", "D"]; - let mut df = DataFrame::::empty(&cols); + let mut df = DataFrame::empty(&cols); df.push_row(&[1.0, 2.0, 3.0, 4.0]); df.push_row(&[2.0, 3.0, 4.0, 5.0]); diff --git a/src/lib.rs b/src/lib.rs index 71e514e..393b2a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,19 @@ //! Self-organizing maps / Kohonen maps with an arbitrary amount of layers (Super-SOMs). pub mod calc; +pub mod cli; pub mod data; pub mod map; pub mod proc; pub mod ui; + +use core::fmt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParseEnumError(String); + +impl fmt::Display for ParseEnumError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/main.rs b/src/main.rs index f328e4d..66fc9ac 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1 +1,60 @@ -fn main() {} +use easy_graph::ui::window::WindowBuilder; +use kohonen::cli::{Cli, CliParsed}; +use kohonen::proc::ProcessorBuilder; +use kohonen::ui::LayerView; +use std::time::Duration; +use structopt::StructOpt; + +fn main() { + let args = Cli::from_args(); + let parsed = CliParsed::from_cli(args); + println!("{:#?}", parsed); + + let proc = ProcessorBuilder::new(&parsed.layers) + .with_delimiter(b';') + .with_no_data(&parsed.no_data) + .build_from_file(&parsed.file) + .unwrap(); + + let mut som = proc.create_som( + parsed.size.1, + parsed.size.0, + parsed.episodes, + parsed.neigh.clone(), + parsed.alpha.clone(), + parsed.radius.clone(), + parsed.decay.clone(), + ); + + let mut viewers: Option> = if parsed.gui { + Some( + proc.layers() + .iter() + .enumerate() + .map(|(i, _)| { + let win = WindowBuilder::new() + .with_dimensions(800, 700) + .with_fps_skip(parsed.fps) + .build(); + LayerView::new(win, &[i], &proc.data().names_ref_vec(), None) + }) + .collect(), + ) + } else { + None + }; + + if let Some(views) = &mut viewers { + while views.iter().fold(false, |a, v| a || v.is_open()) { + let res = som.epoch(&proc.data(), None); + for view in views.iter_mut() { + view.draw(&som); + } + if res.is_none() { + std::thread::sleep(Duration::from_millis(40)); + } + } + } else { + while let Some(()) = som.epoch(&proc.data(), None) {} + } +} diff --git a/src/map/som.rs b/src/map/som.rs index 1d1575a..83f8095 100644 --- a/src/map/som.rs +++ b/src/map/som.rs @@ -4,31 +4,26 @@ use crate::calc::metric::{Metric, SqEuclideanMetric}; use crate::calc::neighborhood::Neighborhood; use crate::calc::nn; use crate::data::DataFrame; +use crate::ParseEnumError; use rand::prelude::*; use std::cmp; /// SOM training parameters -pub struct SomParams -where - N: Neighborhood, -{ +pub struct SomParams { epochs: u32, //metric: M, - neighborhood: N, + neighborhood: Neighborhood, alpha: DecayParam, radius: DecayParam, decay: DecayParam, layers: Vec, } -impl SomParams -where - N: Neighborhood, -{ +impl SomParams { /// Creates parameters for a simple SOM with a simple layer. pub fn simple( epochs: u32, - neighborhood: N, + neighborhood: Neighborhood, alpha: DecayParam, radius: DecayParam, decay: DecayParam, @@ -46,7 +41,7 @@ where /// Creates parameters for a multi-layers SOM (Super-SOM) using the X-Y-Fused algorithm (XYF). pub fn xyf( epochs: u32, - neighborhood: N, + neighborhood: Neighborhood, alpha: DecayParam, radius: DecayParam, decay: DecayParam, @@ -107,19 +102,41 @@ impl Layer { } /// Decay functions for learing parameters. +#[derive(Debug, Clone)] pub enum DecayFunction { /// Linear decay Linear, /// Exponential decay Exponential, } +impl DecayFunction { + pub fn from_string(str: &str) -> Result { + match str { + "lin" => Ok(DecayFunction::Linear), + "exp" => Ok(DecayFunction::Exponential), + _ => Err(ParseEnumError(format!( + "Not a decay function: {}. Must be one of (lin|exp)", + str + ))), + } + } +} /// Decay parameters for learing parameters. +#[derive(Debug, Clone)] pub struct DecayParam { start: f64, end: f64, function: DecayFunction, } impl DecayParam { + /// Creates a learning parameter from start and end value and decay function. + pub fn new(start: f64, end: f64, function: DecayFunction) -> Self { + DecayParam { + start, + end, + function, + } + } /// Creates a linearly decaying learning parameter from start and end value. pub fn lin(start: f64, end: f64) -> Self { DecayParam { @@ -153,26 +170,20 @@ impl DecayParam { /// Super-SOM core type. #[allow(dead_code)] -pub struct Som -where - N: Neighborhood, -{ +pub struct Som { dims: usize, nrows: usize, ncols: usize, - weights: DataFrame, - distances_sq: DataFrame, - params: SomParams, + weights: DataFrame, + distances_sq: DataFrame, + params: SomParams, epoch: u32, } #[allow(dead_code)] -impl Som -where - N: Neighborhood, -{ +impl Som { /// Creates a new SOM or Super-SOM - pub fn new(dims: usize, nrows: usize, ncols: usize, params: SomParams) -> Self { + pub fn new(dims: usize, nrows: usize, ncols: usize, params: SomParams) -> Self { let mut som = Som { dims, nrows, @@ -187,7 +198,7 @@ where } /// Returns a reference to the SOM's parameters. - pub fn params(&self) -> &SomParams { + pub fn params(&self) -> &SomParams { &self.params } @@ -203,7 +214,7 @@ where } /// Pre-calculates the unit-to-unit distance matrix. - fn calc_distance_matix(nrows: usize, ncols: usize) -> DataFrame { + fn calc_distance_matix(nrows: usize, ncols: usize) -> DataFrame { let metric = SqEuclideanMetric(); let mut df = DataFrame::filled(nrows * ncols, &vec![""; nrows * ncols], 0.0); for r1 in 0..nrows { @@ -215,8 +226,7 @@ where df.set( idx1, idx2, - metric.distance(&[r1 as f64, c1 as f64], &[r2 as f64, c2 as f64]) - as f32, + metric.distance(&[r1 as f64, c1 as f64], &[r2 as f64, c2 as f64]), ); } } @@ -233,7 +243,7 @@ where (row * self.ncols as i32 + col) as usize } /// Returns a reference to the units weights data frame. - pub fn weights(&self) -> &DataFrame { + pub fn weights(&self) -> &DataFrame { &self.weights } /// Returns a reference to the weights of the unit at (row, col). @@ -254,7 +264,7 @@ where } /// Trains the SOM for one epoch. Updates learning parameters - pub fn epoch(&mut self, samples: &DataFrame, count: Option) -> Option<()> { + pub fn epoch(&mut self, samples: &DataFrame, count: Option) -> Option<()> { if self.epoch >= self.params.epochs { return None; } @@ -263,7 +273,7 @@ where let mut indices: Vec<_> = (0..samples.nrows()).collect(); rng.shuffle(&mut indices); - let cnt = cmp::min(count.unwrap_or(samples.nrows()), samples.nrows()); + let cnt = cmp::min(count.unwrap_or_else(|| samples.nrows()), samples.nrows()); for idx in indices.iter().take(cnt) { let sample = samples.get_row(*idx); @@ -327,9 +337,12 @@ where if dist_sq <= search_rad_sq { let weight = neigh.weight(radius_inf_sq * dist_sq); for i in 0..self.dims { - let value = *self.weights.get(index, i); - self.weights - .set(index, i, value + weight * alpha * (sample[i] - value)) + let smp = sample[i]; + if !smp.is_nan() { + let value = *self.weights.get(index, i); + self.weights + .set(index, i, value + weight * alpha * (smp - value)); + } } } } @@ -339,7 +352,7 @@ where #[cfg(test)] mod test { - use crate::calc::neighborhood::GaussNeighborhood; + use crate::calc::neighborhood::Neighborhood; use crate::data::DataFrame; use crate::map::som::{DecayParam, Som, SomParams}; use rand::Rng; @@ -348,7 +361,7 @@ mod test { fn create_som() { let params = SomParams::simple( 100, - GaussNeighborhood(), + Neighborhood::Gauss, DecayParam::lin(0.2, 0.01), DecayParam::lin(1.0, 0.5), DecayParam::lin(0.2, 0.001), @@ -361,7 +374,7 @@ mod test { fn create_large_som() { let params = SomParams::simple( 100, - GaussNeighborhood(), + Neighborhood::Gauss, DecayParam::lin(0.2, 0.01), DecayParam::lin(1.0, 0.5), DecayParam::lin(0.2, 0.001), @@ -378,7 +391,7 @@ mod test { fn train_step() { let params = SomParams::simple( 100, - GaussNeighborhood(), + Neighborhood::Gauss, DecayParam::lin(0.2, 0.01), DecayParam::lin(1.0, 0.5), DecayParam::lin(0.2, 0.001), @@ -392,7 +405,7 @@ mod test { let cols = ["A", "B", "C", "D", "E"]; let params = SomParams::simple( 10, - GaussNeighborhood(), + Neighborhood::Gauss, DecayParam::lin(0.2, 0.01), DecayParam::lin(5.0, 0.5), DecayParam::exp(0.2, 0.001), @@ -400,7 +413,7 @@ mod test { let mut som = Som::new(cols.len(), 16, 16, params); let mut rng = rand::thread_rng(); - let mut data = DataFrame::::empty(&cols); + let mut data = DataFrame::empty(&cols); for _i in 0..100 { data.push_row(&[ diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 8c250ce..b57776a 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -1,7 +1,7 @@ //! Pre- and post-processing of SOM training data, SOM creation. use crate::calc::neighborhood::Neighborhood; -use crate::calc::norm::{normalize, DeNorm, Norm}; +use crate::calc::norm::{normalize, LinearTransform, Norm}; use crate::data::DataFrame; use crate::map::som::{DecayParam, Layer, Som, SomParams}; use csv::{ReaderBuilder, StringRecord}; @@ -86,43 +86,62 @@ impl InputLayer { } } +#[derive(Clone, Debug)] +pub struct CsvOptions { + delimiter: u8, + no_data: String, +} + pub struct ProcessorBuilder { input_layers: Vec, - delimiter: u8, + csv_options: CsvOptions, } impl ProcessorBuilder { pub fn new(layers: &[InputLayer]) -> Self { ProcessorBuilder { input_layers: layers.to_vec(), - delimiter: b',', + csv_options: CsvOptions { + delimiter: b',', + no_data: "NA".to_string(), + }, } } pub fn with_delimiter(mut self, delimiter: u8) -> Self { - self.delimiter = delimiter; + self.csv_options.delimiter = delimiter; + self + } + pub fn with_no_data(mut self, no_data: &str) -> Self { + self.csv_options.no_data = no_data.to_string(); self } pub fn build_from_file(self, path: &str) -> Result> { - let proc = Processor::new(self.input_layers, path)?; + let proc = Processor::new(self.input_layers, path, &self.csv_options)?; Ok(proc) } } +#[allow(dead_code)] pub struct Processor { input_layers: Vec, - data: DataFrame, + data: DataFrame, layers: Vec, norm: Vec, - denorm: Vec, + denorm: Vec, scale: Vec, + csv_options: CsvOptions, } impl Processor { - fn new(input_layers: Vec, path: &str) -> Result> { - Self::read_file(input_layers, path) + fn new( + input_layers: Vec, + path: &str, + csv_options: &CsvOptions, + ) -> Result> { + Self::read_file(input_layers, path, csv_options) } /// The normalized data. - pub fn data(&self) -> &DataFrame { + pub fn data(&self) -> &DataFrame { &self.data } pub fn layers(&self) -> &[Layer] { @@ -134,7 +153,7 @@ impl Processor { pub fn norm(&self) -> &[Norm] { &self.norm } - pub fn denorm(&self) -> &[DeNorm] { + pub fn denorm(&self) -> &[LinearTransform] { &self.denorm } pub fn scale(&self) -> &[f64] { @@ -144,10 +163,13 @@ impl Processor { fn read_file( mut input_layers: Vec, path: &str, + csv_options: &CsvOptions, ) -> Result> { + let no_data = &csv_options.no_data; + // Read csv let mut reader = ReaderBuilder::new() - .delimiter(b';') + .delimiter(csv_options.delimiter) .from_path(path) .unwrap(); let header: StringRecord = reader.headers().unwrap().clone(); @@ -179,7 +201,7 @@ impl Processor { for (idx, lay) in categorical.iter() { let v = rec.get(lay.indices.as_ref().unwrap()[0]).unwrap(); let levels = &mut cat_levels[*idx]; - if !levels.contains(v) { + if v != no_data && !levels.contains(v) { levels.insert(v.to_string()); } } @@ -217,10 +239,12 @@ impl Processor { let levels = &cat_levels[idx]; colnames.extend(levels.iter().map(|l| base.clone() + l)); } else { - colnames.extend(lay.names.iter().map(|l| l.clone())); + colnames.extend(lay.names.iter().cloned()); } } - let mut df = DataFrame::::empty(&colnames.iter().map(|x| &**x).collect::>()); + + // transform to SOM training data format + let mut df = DataFrame::empty(&colnames.iter().map(|x| &**x).collect::>()); let mut row = vec![0.0; colnames.len()]; reader.seek(start_pos).unwrap(); @@ -234,15 +258,29 @@ impl Processor { let indices = inp.indices.as_ref().unwrap(); if inp.is_class { let v = rec.get(indices[0]).unwrap(); - let pos = cat_levels[layer_index] - .iter() - .position(|v2| v == v2) - .unwrap(); - row[start + pos] = 1.0; + if v == no_data { + for i in start..(start + cat_levels[layer_index].len()) { + row[i] = std::f64::NAN; + } + } else { + let pos = cat_levels[layer_index] + .iter() + .position(|v2| v == v2) + .unwrap(); + row[start + pos] = 1.0; + } } else { for (i, idx) in inp.indices.as_ref().unwrap().iter().enumerate() { - let v: f64 = rec.get(*idx).unwrap().parse()?; - row[start + i] = v; + let str = rec.get(*idx).unwrap(); + if str == no_data { + row[start + i] = std::f64::NAN; + } else { + let v: f64 = str.parse().expect(&format!( + "Unable to parse value {} in column {}", + str, inp.names[i] + )); + row[start + i] = v; + } } } start += lay.ncols(); @@ -259,10 +297,14 @@ impl Processor { } } let (data_norm, denorm) = normalize(&df, &norm, &scale); + /* for row in df.iter_rows() { println!("{:?}", row); } + for row in data_norm.iter_rows() { + println!("{:?}", row); + } println!("{:?}", cat_levels); println!("{:?}", df.names()); println!("{:?}", norm); @@ -276,29 +318,27 @@ impl Processor { norm, denorm, scale, + csv_options: csv_options.clone(), }) } - pub fn create_som( + pub fn create_som( &self, nrows: usize, ncols: usize, epochs: u32, - neighborhood: N, + neighborhood: Neighborhood, alpha: DecayParam, radius: DecayParam, decay: DecayParam, - ) -> Som - where - N: Neighborhood, - { + ) -> Som { let params = SomParams::xyf( epochs, neighborhood, alpha, radius, decay, - self.layers.iter().map(|l| l.clone()).collect(), + self.layers.to_vec(), ); Som::new(self.data.ncols(), nrows, ncols, params) @@ -307,12 +347,10 @@ impl Processor { #[cfg(test)] mod test { - use crate::calc::neighborhood::GaussNeighborhood; + use crate::calc::neighborhood::Neighborhood; use crate::calc::norm::Norm; use crate::map::som::DecayParam; use crate::proc::{InputLayer, ProcessorBuilder}; - use crate::ui::LayerView; - use easy_graph::ui::window::WindowBuilder; #[test] fn create_proc() { @@ -331,11 +369,11 @@ mod test { .build_from_file("example_data/iris.csv") .unwrap(); - let mut som = proc.create_som( + let som = proc.create_som( 16, 20, 1000, - GaussNeighborhood(), + Neighborhood::Gauss, DecayParam::lin(0.2, 0.01), DecayParam::lin(8.0, 0.5), DecayParam::exp(0.2, 0.001), diff --git a/src/ui/layer_view.rs b/src/ui/layer_view.rs index 6f5082e..4ef5bab 100644 --- a/src/ui/layer_view.rs +++ b/src/ui/layer_view.rs @@ -1,8 +1,10 @@ //! Viewer for SOMs as heatmaps. -use crate::calc::neighborhood::Neighborhood; use crate::map::som::Som; -use easy_graph::color::style::{ShapeStyle, BLACK, GREEN, RED, WHITE, YELLOW}; +use easy_graph::color::style::text_anchor::{HPos, Pos, VPos}; +use easy_graph::color::style::{ + IntoFont, Palette, Palette99, ShapeStyle, TextStyle, BLACK, GREEN, RED, WHITE, YELLOW, +}; use easy_graph::color::{ColorMap, LinearColorMap}; use easy_graph::ui::drawing::IntoDrawingArea; use easy_graph::ui::element::Rectangle; @@ -12,16 +14,25 @@ use easy_graph::ui::window::BufferWindow; pub struct LayerView { window: BufferWindow, layers: Vec, + names: Vec, layout_columns: Option, + scale: Option, } impl LayerView { /// Creates a new viewer for a selection of layers, or of all layers it `layers` is empty. - pub fn new(window: BufferWindow, layers: &[usize], layout_columns: Option) -> Self { + pub fn new( + window: BufferWindow, + layers: &[usize], + names: &[&str], + layout_columns: Option, + ) -> Self { LayerView { window, layers: layers.to_vec(), + names: names.iter().map(|n| n.to_string()).collect(), layout_columns, + scale: None, } } /// If the viewer's window is still open. @@ -29,13 +40,124 @@ impl LayerView { self.window.is_open() } /// Draws the given SOM. Should be called only for the same SOM repeatedly, not for different SOMs! - pub fn draw(&mut self, som: &Som) - where - N: Neighborhood, - { + pub fn draw(&mut self, som: &Som) { + let params = som.params(); + if (self.layers.len() == 1 && params.layers()[self.layers[0]].categorical()) + || (self.layers.is_empty() + && params.layers().len() == 1 + && params.layers()[0].categorical()) + { + self.draw_classes(som); + } else { + self.draw_columns(som); + } + } + + fn draw_classes(&mut self, som: &Som) { + let params = som.params(); + let layer = if self.layers.is_empty() { + 0 + } else { + self.layers[0] + }; + let start_col = params + .layers() + .iter() + .enumerate() + .take_while(|(i, _)| *i != layer) + .map(|(_, l)| l.ncols()) + .sum(); + let classes: Vec<_> = self.names[start_col..(start_col + params.layers()[layer].ncols())] + .iter() + .map(|n| n.splitn(2, ':').nth(1).unwrap()) + .collect(); + + let columns = self.get_columns(som); + + let margin = 5_i32; + let heading = 16_i32; + let legend = 120_i32; + + let (som_rows, som_cols) = som.size(); + let (width, height) = self.window.size(); + let width = width - 2 * margin as usize; + let height = height - 2 * margin as usize; + + if self.layout_columns.is_none() { + let (cols, scale) = + Self::calc_layout_columns(width, height, som_rows, som_cols, 1, heading, legend); + self.layout_columns = Some(cols); + self.scale = Some(scale); + } + + let scale = self.scale.unwrap(); + let test_style = + TextStyle::from(("sans-serif", 14).into_font()).pos(Pos::new(HPos::Left, VPos::Top)); + + self.window.draw(|b| { + let root = b.into_drawing_area(); + root.fill(&WHITE).unwrap(); + + let x_min = margin; + let y_min = margin + heading; + for (idx, row) in som.weights().iter_rows().enumerate() { + let (r, c) = som.to_row_col(idx); + let x = x_min + (c as i32 * scale); + let y = y_min + (r as i32 * scale); + + let mut v_max = std::f64::MIN; + let mut idx_max = 0; + for (index, col) in columns.iter() { + let v = row[*col]; + if v > v_max { + v_max = v; + idx_max = *index; + } + } + + let color = Palette99::pick(idx_max); //color_map.get_color(v_min, v_max, v); + + root.draw(&Rectangle::new( + [(x, y), (x + scale, y + scale)], + ShapeStyle::from(&color).filled(), + )) + .unwrap(); + } + root.draw(&Rectangle::new( + [ + (x_min, y_min), + ( + x_min + scale * som_cols as i32, + y_min + scale * som_rows as i32, + ), + ], + ShapeStyle::from(&BLACK), + )) + .unwrap(); + + let x = x_min + som.ncols() as i32 * scale + 10; + for (i, class) in classes.iter().enumerate() { + let color = Palette99::pick(i); + root.draw(&Rectangle::new( + [ + (x, y_min + i as i32 * 14), + (x + 10, y_min + i as i32 * 14 + 10), + ], + ShapeStyle::from(&color).filled(), + )) + .unwrap(); + root.draw_text(class, &test_style, (x + 14, y_min + i as i32 * 14)) + .unwrap(); + } + }); + } + + fn draw_columns(&mut self, som: &Som) { let columns = self.get_columns(som); let margin = 5_i32; + let heading = 16_i32; + let legend = 20_i32; let (som_rows, som_cols) = som.size(); let (width, height) = self.window.size(); @@ -43,13 +165,17 @@ impl LayerView { let height = height - 2 * margin as usize; if self.layout_columns.is_none() { - self.layout_columns = Some(Self::calc_layout_columns( + let (cols, scale) = Self::calc_layout_columns( width, height, som_rows, som_cols, columns.len(), - )); + heading, + legend, + ); + self.layout_columns = Some(cols); + self.scale = Some(scale); } let layout_columns = self.layout_columns.unwrap(); @@ -58,12 +184,14 @@ impl LayerView { let panel_width = width as f64 / layout_columns as f64; let panel_height = height as f64 / layout_rows as f64; - let x_scale = panel_width / som_cols as f64; - let y_scale = panel_height / som_rows as f64; - let scale = (if x_scale < y_scale { x_scale } else { y_scale }) as i32; + let scale = self.scale.unwrap(); let ranges = som.weights().ranges(); + let color_map = LinearColorMap::new(&[&GREEN, &YELLOW, &RED]); + let names = &self.names; + let test_style = + TextStyle::from(("sans-serif", 14).into_font()).pos(Pos::new(HPos::Left, VPos::Bottom)); self.window.draw(|b| { let root = b.into_drawing_area(); @@ -73,7 +201,7 @@ impl LayerView { let lay_row = index / layout_columns; let lay_col = index % layout_columns; let x_min = margin + (lay_col as f64 * panel_width) as i32; - let y_min = margin + (lay_row as f64 * panel_height) as i32; + let y_min = margin + heading + (lay_row as f64 * panel_height) as i32; for (idx, row) in som.weights().iter_rows().enumerate() { let (r, c) = som.to_row_col(idx); let v = row[col]; @@ -99,14 +227,34 @@ impl LayerView { ShapeStyle::from(&BLACK), )) .unwrap(); + root.draw_text(&names[col], &test_style, (x_min, y_min - 1)) + .unwrap(); + let steps = 25; + let total_height = scale * som.nrows() as i32 - 40; + let total_width = scale * som.ncols() as i32; + let x = x_min + total_width; + for i in 0..steps { + let value = i as f64 / steps as f64; + let color = color_map.get_color(0.0, 1.0, value); + let y = y_min + total_height + 20 - (total_height as f64 * value) as i32; + root.draw(&Rectangle::new( + [ + (x + 3, y), + ( + x + legend - 3, + y + (total_height as f64 / steps as f64) as i32, + ), + ], + ShapeStyle::from(&color).filled(), + )) + .unwrap(); + } } }); } + /// Calculates the required columns as a vector of (index, column index). - fn get_columns(&self, som: &Som) -> Vec<(usize, usize)> - where - N: Neighborhood, - { + fn get_columns(&self, som: &Som) -> Vec<(usize, usize)> { let params = som.params(); let mut columns = vec![]; if params.layers().is_empty() || self.layers.is_empty() { @@ -133,28 +281,40 @@ impl LayerView { som_rows: usize, som_cols: usize, data_columns: usize, - ) -> usize { - (1..data_columns) - .map(|cols| { - let layout_rows = (data_columns as f64 / cols as f64).ceil() as usize; - let panel_width = width as f64 / cols as f64; - let panel_height = height as f64 / layout_rows as f64; - - let x_scale = panel_width / som_cols as f64; - let y_scale = panel_height / som_rows as f64; - let scale = (if x_scale < y_scale { x_scale } else { y_scale }) as i32; - - (cols, scale) - }) - .max_by(|(_col1, scale1), (_col2, scale2)| scale1.cmp(scale2)) - .unwrap() - .0 + heading: i32, + legend: i32, + ) -> (usize, i32) { + if data_columns == 1 { + let panel_width = width as f64 - legend as f64; + let panel_height = height as f64 - heading as f64; + + let x_scale = panel_width / som_cols as f64; + let y_scale = panel_height / som_rows as f64; + let scale = (if x_scale < y_scale { x_scale } else { y_scale }) as i32; + + (1, scale) + } else { + (1..data_columns) + .map(|cols| { + let layout_rows = (data_columns as f64 / cols as f64).ceil() as usize; + let panel_width = (width as f64 / cols as f64) - legend as f64; + let panel_height = (height as f64 / layout_rows as f64) - heading as f64; + + let x_scale = panel_width / som_cols as f64; + let y_scale = panel_height / som_rows as f64; + let scale = (if x_scale < y_scale { x_scale } else { y_scale }) as i32; + + (cols, scale) + }) + .max_by(|(_col1, scale1), (_col2, scale2)| scale1.cmp(scale2)) + .unwrap() + } } } #[cfg(test)] mod test { - use crate::calc::neighborhood::GaussNeighborhood; + use crate::calc::neighborhood::Neighborhood; use crate::map::som::{DecayParam, Layer, Som, SomParams}; use crate::ui::layer_view::LayerView; use easy_graph::ui::window::WindowBuilder; @@ -164,7 +324,7 @@ mod test { let dim = 5; let params = SomParams::xyf( 1000, - GaussNeighborhood(), + Neighborhood::Gauss, DecayParam::lin(0.1, 0.01), DecayParam::lin(10.0, 0.6), DecayParam::exp(0.25, 0.0001), @@ -177,7 +337,7 @@ mod test { .with_fps_skip(10.0) .build(); - let mut view = LayerView::new(win, &[0], None); + let mut view = LayerView::new(win, &[0], &["A", "B", "C", "D", "E"], None); //while view.window.is_open() { view.draw(&som);