-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
show_mnist.py
63 lines (56 loc) · 1.84 KB
/
show_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from model import Net
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from model import Net
import matplotlib.cm as cm
import numpy as np
# Per-class colour table: 10 evenly spaced samples of the rainbow colormap,
# looked up by digit label when the embeddings are plotted.
colors = cm.rainbow(np.linspace(start=0, stop=1, num=10))
print(colors)  # debug dump of the palette
######################################################################
# Load model
#---------------------------
def load_network(network, save_path='./model/best.pth'):
    """Load saved weights into ``network`` in place and return it.

    Args:
        network: ``torch.nn.Module`` whose parameter names match the
            checkpoint's keys.
        save_path: path of a ``state_dict`` checkpoint written with
            ``torch.save``. Defaults to the path this script has always used.

    Returns:
        The same ``network`` object, so calls can be chained.
    """
    # map_location='cpu' lets a checkpoint that was saved from GPU tensors be
    # loaded on a CPU-only machine; load_state_dict then copies the values
    # into the module's existing parameters, and callers move the module to
    # the desired device afterwards.
    state_dict = torch.load(save_path, map_location='cpu')
    network.load_state_dict(state_dict)
    return network
def test(model, test_loader, axes=None, palette=None):
    """Scatter-plot the model's output for every sample in ``test_loader``.

    Points are coloured by their ground-truth label and drawn in loader
    order. The legend gets exactly one entry per label that occurs.

    Args:
        model: network to evaluate. Only the first two output columns are
            plotted (assumes the output is (batch, >=2) -- TODO confirm).
        test_loader: iterable of ``(data, target)`` batches, where ``target``
            holds integer labels usable as indices into ``palette``.
        axes: matplotlib Axes to draw on. Defaults to the module-level ``ax``.
        palette: indexable of colours, one per label. Defaults to the
            module-level ``colors``.
    """
    if axes is None:
        axes = ax
    if palette is None:
        palette = colors

    # Send batches to wherever the model's weights live instead of assuming
    # CUDA; fall back to CPU for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    point_batches, label_batches = [], []
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.to(device))
            point_batches.append(output[:, :2].cpu())
            label_batches.append(target.cpu())
    if not point_batches:
        return  # empty loader: nothing to draw

    points = torch.cat(point_batches).numpy()
    labels = torch.cat(label_batches).numpy()
    palette = np.asarray(palette)

    # A single artist for all points (rather than one scatter call per sample)
    # keeps rendering fast while preserving the loader-order z-stacking.
    axes.scatter(points[:, 0], points[:, 1], c=palette[labels], s=10,
                 alpha=0.7, edgecolors='none')
    # Empty, labelled scatters act as legend proxies: one entry per class.
    # A one-element colour list is broadcast as a single colour by matplotlib.
    for label in np.unique(labels):
        axes.scatter([], [], c=[palette[label]], s=10, label=str(label),
                     alpha=0.7, edgecolors='none')
if __name__ == '__main__':
    # MNIST test split; the Normalize constants are presumably the MNIST
    # training-set mean/std used when the model was trained -- confirm.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=100, shuffle=False)

    model = Net()
    model = load_network(model)
    # An empty nn.Sequential returns its input unchanged, so replacing fc2
    # makes the forward pass stop before that layer; presumably this exposes
    # the 2-D features fed into fc2 -- confirm against Net.forward.
    model.fc2 = nn.Sequential()
    model = model.eval()
    model = model.cuda()

    # `fig, ax` stay module-level names here (an `if` block opens no new
    # scope) because `test` draws onto the global `ax` by default.
    fig, ax = plt.subplots()
    test(model, test_loader)
    ax.grid(True)
    ax.legend(loc='best')
    # NOTE(review): this plots the *test* split but saves it as train.jpg.
    fig.savefig('train.jpg')
    plt.close(fig)  # release the figure once it has been written