-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_nnd_inf_vs_fixed_gan.py
62 lines (53 loc) · 2.22 KB
/
plot_nnd_inf_vs_fixed_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import numpy as np
from matplotlib import pyplot as plt
import argparse
import re
import os
def read_nnd_file(path):
    """Parse one ``nnd.txt`` file into a 2-D float tensor.

    Every non-blank line must contain at least three whitespace-separated
    fields: an integer (used as the x / "Epoch" value when plotting) followed
    by two floats (plotted as the "Fix" and "Inf" NND curves respectively).
    Any extra fields on a line are ignored.

    Args:
        path: path of the ``nnd.txt`` file to read.

    Returns:
        A ``(num_rows, 3)`` tensor in torch's default float dtype.
    """
    rows = []
    with open(path) as rf:
        for line in rf:
            # split() with no argument tolerates repeated spaces/tabs.
            fields = line.split()
            if not fields:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # crashing on int('').
                continue
            rows.append((int(fields[0]), float(fields[1]), float(fields[2])))
    # reshape keeps the (N, 3) rank even when the file has no data rows.
    return torch.tensor(rows).reshape(-1, 3)


def load_nnd_runs(folder, ts):
    """Stack the NND tables of every run stored under ``folder/ts``.

    Each entry of ``folder/ts`` that contains an ``nnd.txt`` file counts as
    one run. Entries without one (stray files such as ``.DS_Store``) are
    skipped. Runs are read in sorted order, so the stacking order is
    reproducible.

    Args:
        folder: root result folder (already user-expanded).
        ts: name of the train-size sub-folder, e.g. ``'10k'``.

    Returns:
        A ``(num_runs, num_rows, 3)`` tensor. ``torch.stack`` requires every
        run to have the same number of rows.

    Raises:
        FileNotFoundError: if no run under ``folder/ts`` has an ``nnd.txt``.
    """
    ts_dir = os.path.join(folder, ts)
    candidates = [os.path.join(ts_dir, run, 'nnd.txt')
                  for run in sorted(os.listdir(ts_dir))]
    nnd_files = [p for p in candidates if os.path.isfile(p)]
    if not nnd_files:
        raise FileNotFoundError('no run with nnd.txt found under ' + ts_dir)
    return torch.stack([read_nnd_file(p) for p in nnd_files])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-path', type=str, default='~/github/exps/mdl/mnist/',
                        help='path to the nnd result folder')
    args = parser.parse_args()
    cs = ['r', 'g', 'b', 'y', 'k', 'm']
    folder = os.path.expanduser(args.path)
    train_sizes = ['10k', '30k', '60k']
    nnd_means = {}
    nnd_stds = {}
    fig, ax = plt.subplots(1, 1)
    # The output directory is not guaranteed to exist on a fresh checkout.
    os.makedirs('results', exist_ok=True)
    with open('results/nnd_inf_vs_fixed_gan.txt', 'w+') as resultf:
        for i, ts in enumerate(train_sizes):
            nndts = load_nnd_runs(folder, ts)
            print(nndts.size())
            # Statistics across runs (dim 0). NOTE: with a single run,
            # torch's unbiased std is NaN.
            nnd_means[ts] = nndts.mean(dim=0)
            nnd_stds[ts] = nndts.std(dim=0)
            print(nnd_means[ts])
            print(nnd_stds[ts])
            resultf.write('ts ' + ts + '\n')
            resultf.write('mean\n')
            resultf.write(str(nnd_means[ts]) + '\n')
            resultf.write('std\n')
            resultf.write(str(nnd_stds[ts]) + '\n')
            # Row 0 is left out of the plot (presumably the initial/epoch-0
            # entry -- TODO confirm). Both curves share the same integer x
            # values so the two series line up.
            epochs = nnd_means[ts][1:, 0].long()
            ax.errorbar(x=epochs, y=nnd_means[ts][1:, 1], yerr=nnd_stds[ts][1:, 1],
                        label='Fix ' + ts, linestyle='--', capsize=4,
                        c=cs[(2 * i) % len(cs)])
            ax.errorbar(x=epochs, y=nnd_means[ts][1:, 2], yerr=nnd_stds[ts][1:, 2],
                        label='Inf ' + ts, linestyle='-', capsize=4,
                        c=cs[(2 * i + 1) % len(cs)])
            print()
    ax.legend(fontsize=16)
    ax.set_ylabel('NND', fontsize=16)
    ax.set_xlabel('Epoch', fontsize=16)
    fig.savefig('results/nnd_inf_vs_fixed_gan.pdf', bbox_inches='tight')