-
Notifications
You must be signed in to change notification settings - Fork 0
/
softmax回归的从零开始实现
197 lines (190 loc) · 7.16 KB
/
softmax回归的从零开始实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from IPython import display
from matplotlib import pyplot as plt
def get_fashion_mnist_labels(labels):
    """Return the Fashion-MNIST text labels for a sequence of class indices.

    Each element of ``labels`` is converted with ``int()`` and used as an
    index into the ten class names, so tensors of indices work as well as
    plain Python numbers.
    """
    class_names = (
        't-shirt', 'trouser', 'pullover', 'dress', 'coat',
        'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
    )
    names = []
    for label in labels:
        names.append(class_names[int(label)])
    return names
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images on a ``num_rows`` x ``num_cols`` grid.

    Args:
        imgs: Iterable of images; each is either a torch tensor or anything
            ``Axes.imshow`` accepts directly (e.g. a PIL image).
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.
        titles: Optional sequence of per-image titles.
        scale: Size of each grid cell, used to derive the figure size.

    Returns:
        The flattened array of matplotlib axes.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always returns a 2-D array of axes, so flatten() also
    # works for a 1x1 grid (where subplots would otherwise return a bare Axes).
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Image tensor: detach/move to CPU so tensors that track gradients
            # or live on another device can still be converted to numpy.
            ax.imshow(img.detach().cpu().numpy())
        else:
            # PIL image (or any other array-like imshow understands).
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        # Only title the images for which a title was actually supplied.
        if titles and i < len(titles):
            ax.set_title(titles[i])
    return axes
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Configure a matplotlib axes object.

    Args:
        axes: The axes to configure.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        xlim: Limits passed to ``axes.set_xlim``.
        ylim: Limits passed to ``axes.set_ylim``.
        xscale: Scale passed to ``axes.set_xscale``.
        yscale: Scale passed to ``axes.set_yscale``.
        legend: Legend entries; the legend is only drawn when this is truthy.
    """
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    # The y-limits were previously never applied (set_xlim was called twice).
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()
class Accumulator:
    """Keep running sums over ``n`` variables."""

    def __init__(self, n):
        # One float slot per tracked variable, all starting at zero.
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add each positional value (converted with ``float``) to its slot.

        Values are paired with the slots positionally via ``zip``.
        """
        updated = []
        for total, value in zip(self.data, args):
            updated.append(total + float(value))
        self.data = updated

    def reset(self):
        """Set every running sum back to zero, keeping the number of slots."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        """Return the running sum stored at position ``idx``."""
        return self.data[idx]
def get_dataloader_workers():
    """Return the number of worker processes used to read data (2)."""
    num_workers = 2
    return num_workers
def load_data_fashion_mnist(batch_size, resize=None):
    """Download Fashion-MNIST and wrap it in train/test data loaders.

    Args:
        batch_size: Batch size used by both loaders.
        resize: Optional size passed to ``transforms.Resize``; when falsy the
            images are left at their original size.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles
        its samples; the test loader does not.
    """
    trans = [transforms.ToTensor()]
    if resize:
        # Resize is placed ahead of ToTensor in the transform pipeline.
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="./", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="./", train=False, transform=trans, download=True)
    num_workers = get_dataloader_workers()
    train_loader = data.DataLoader(
        mnist_train, batch_size, shuffle=True, num_workers=num_workers)
    # Evaluation does not benefit from shuffling; a fixed order keeps test
    # batches (and the samples previewed from them) reproducible.
    test_loader = data.DataLoader(
        mnist_test, batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader
def softmax(x):
    """Apply a row-wise softmax to a 2-D tensor of logits.

    Args:
        x: Tensor whose dimension 1 holds the per-class scores of each row.

    Returns:
        A tensor of the same shape in which every row sums to 1.
    """
    # Subtracting each row's maximum leaves the softmax value unchanged but
    # keeps exp() from overflowing to inf (which would yield NaN) on large
    # logits.
    shifted = x - x.max(dim=1, keepdim=True).values
    x_exp = torch.exp(shifted)
    partition = x_exp.sum(1, keepdim=True)
    return x_exp / partition
def net(x):
    """Softmax-regression forward pass using the module-level ``w`` and ``b``.

    The input is reshaped to rows of ``w.shape[0]`` features, multiplied by
    ``w``, offset by ``b`` and passed through ``softmax``.
    """
    flat = x.reshape((-1, w.shape[0]))
    logits = torch.matmul(flat, w) + b
    return softmax(logits)
def cross_entropy(y_hat, y):
    """Return the per-sample cross-entropy loss.

    For every row of ``y_hat`` the entry at the column given by ``y`` is
    selected, and its negative log is returned.
    """
    row_index = range(len(y_hat))
    picked = y_hat[row_index, y]
    return -torch.log(picked)
def accuracy(y_hat, y):
    """Count the number of correct predictions.

    When ``y_hat`` has more than one dimension and more than one column, the
    predicted class is taken as the argmax along axis 1; otherwise ``y_hat``
    is compared with ``y`` as-is. The count is returned as a float.
    """
    has_class_scores = y_hat.dim() > 1 and y_hat.shape[1] > 1
    preds = y_hat.argmax(axis=1) if has_class_scores else y_hat
    hits = preds.type(y.dtype) == y
    return float(hits.type(y.dtype).sum())
def evaluate_accuacy(net, data_iter):
    """Compute the accuracy of a model on the given dataset.

    Args:
        net: Callable mapping a batch of inputs to predictions. If it is a
            ``torch.nn.Module`` it is switched to evaluation mode first.
        data_iter: Iterable yielding ``(x, y)`` batches.

    Returns:
        The fraction of correctly predicted samples.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no samples.
    """
    if isinstance(net, torch.nn.Module):
        net.eval()  # Put the model in evaluation mode.
    metric = Accumulator(2)  # Number of correct predictions, number of predictions.
    # Evaluation needs no gradients; disabling autograd avoids building a
    # computation graph for every batch.
    with torch.no_grad():
        for x, y in data_iter:
            metric.add(accuracy(net(x), y), y.numel())
    return metric[0] / metric[1]
def train_epoch_ch3(net, train_iter, loss, updater):
    """Train the model for one epoch.

    Args:
        net: Callable mapping a batch of inputs to predictions. If it is a
            ``torch.nn.Module`` it is switched to training mode first.
        train_iter: Iterable yielding ``(x, y)`` batches.
        loss: Callable ``loss(y_hat, y)``.
        updater: Either a ``torch.optim.Optimizer`` or a callable taking the
            batch size and applying the parameter update.

    Returns:
        A ``(mean training loss, training accuracy)`` tuple.
    """
    # Put the model in training mode.
    if isinstance(net, torch.nn.Module):
        net.train()
    # Sum of training loss, sum of correct predictions, number of samples.
    metric = Accumulator(3)
    for x, y in train_iter:
        # Compute gradients and update the parameters.
        y_hat = net(x)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # PyTorch built-in optimizer.
            updater.zero_grad()
            # mean() is a no-op for an already-reduced scalar loss and also
            # allows per-sample loss vectors to be back-propagated.
            l.mean().backward()
            updater.step()
            if l.dim() == 0:
                # Scalar loss: treated as the batch mean, so scale it back to
                # a batch total.
                batch_loss = float(l) * len(y)
            else:
                batch_loss = float(l.sum())
            metric.add(batch_loss, accuracy(y_hat, y), y.numel())
        else:
            # Custom optimizer and loss function.
            l.sum().backward()
            updater(x.shape[0])
            metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # Return the training loss and the training accuracy.
    return metric[0] / metric[2], metric[1] / metric[2]
class Animator:
    """Plot data incrementally, redrawing the figure as points are added."""

    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None, ylim=None,
                 xscale='linear', yscale='linear', fmts=('-', 'm--', 'g-', 'r:'),
                 nrows=1, ncols=1, figsize=(3.5, 2.5)):
        """Create the figure and remember how its first axes is configured.

        Args:
            xlabel, ylabel: Axis labels.
            legend: Legend entries (defaults to an empty list).
            xlim, ylim: Axis limits.
            xscale, yscale: Axis scales.
            fmts: Line format strings, one per curve.
            nrows, ncols: Subplot grid passed to ``plt.subplots``.
            figsize: Figure size passed to ``plt.subplots``.
        """
        # Draw multiple lines incrementally.
        if legend is None:
            legend = []
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # Capture the arguments with a lambda.
        self.config_axes = lambda: set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.x, self.y, self.fmts = None, None, fmts
        # Right edge of the x-axis; reaching it marks the plot as complete.
        # This replaces a lookup of a module-level ``num_epochs`` variable,
        # which raised NameError whenever that global was not defined.
        self._x_end = xlim[1] if xlim is not None else None

    def add(self, x, y):
        """Add one data point per curve and redraw the figure.

        Args:
            x: A single x value shared by all curves, or one x per curve.
            y: A single y value, or a sequence with one y per curve.
        """
        # Add multiple data points to the chart.
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.x:
            self.x = [[] for _ in range(n)]
        if not self.y:
            self.y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.x[i].append(a)
                self.y[i].append(b)
        self.axes[0].cla()
        for xs, ys, fmt in zip(self.x, self.y, self.fmts):
            self.axes[0].plot(xs, ys, fmt)
        self.config_axes()
        # Show the finished figure once the first curve reaches the end of
        # the configured x range.
        first_x = self.x[0]
        if self._x_end is not None and first_x and first_x[-1] >= self._x_end:
            plt.show()
        display.display(self.fig)
        display.clear_output(wait=True)
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    """Train a model (defined in Chapter 3).

    Runs ``num_epochs`` training epochs, evaluates on ``test_iter`` after
    each one, plots the train loss / train accuracy / test accuracy curves,
    and finally asserts that the last epoch's metrics are in the expected
    range.
    """
    animator = Animator(
        xlabel='epoch',
        xlim=[1, num_epochs],
        ylim=[0.3, 0.9],
        legend=['train loss', 'train acc', 'test acc'],
    )
    for epoch_number in range(1, num_epochs + 1):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuacy(net, test_iter)
        animator.add(epoch_number, (*train_metrics, test_acc))
    # Sanity-check the metrics of the final epoch.
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert 0.7 < train_acc <= 1, train_acc
    assert 0.7 < test_acc <= 1, test_acc
def sgd(params, lr, batch_size):
    """Minibatch stochastic gradient descent.

    Updates every tensor in ``params`` in place by
    ``lr * grad / batch_size`` and then zeroes its gradient.
    """
    with torch.no_grad():
        for tensor in params:
            gradient = tensor.grad
            tensor.sub_(lr * gradient / batch_size)
            gradient.zero_()
def updater(batch_size):
    """Apply one ``sgd`` step to the module-level ``w`` and ``b`` using ``lr``."""
    params = [w, b]
    return sgd(params, lr, batch_size)
def predict_ch3(net, test_iter, n=6):
    """Predict labels (defined in Chapter 3) and show them on sample images.

    Takes the first batch of ``test_iter``, predicts its labels with ``net``
    and displays up to ``n`` images titled with the true label followed by
    the predicted label.

    Args:
        net: Callable mapping a batch of inputs to per-class scores.
        test_iter: Iterable yielding ``(x, y)`` batches.
        n: Maximum number of images to display.

    Raises:
        ValueError: If ``test_iter`` yields no batches.
    """
    for x, y in test_iter:
        break
    else:
        raise ValueError("test_iter yielded no batches")
    # Never ask for more images than the batch actually contains.
    n = min(n, len(x))
    # Inference only: no gradients are needed for the predictions.
    with torch.no_grad():
        predicted = net(x).argmax(axis=1)
    trues = get_fashion_mnist_labels(y[:n])
    preds = get_fashion_mnist_labels(predicted[:n])
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    # Use the batch's own height/width instead of a hard-coded 28x28 so
    # resized images can be displayed as well.
    height, width = x.shape[-2], x.shape[-1]
    show_images(x[0:n].reshape((n, height, width)), 1, n, titles=titles)
    plt.show()
if __name__ == "__main__":
    # Hyperparameters. ``lr`` is read as a module-level global by ``updater``.
    batch_size = 256
    lr = 0.1
    num_epochs = 10
    num_inputs, num_outputs = 784, 10

    train_iter, test_iter = load_data_fashion_mnist(batch_size)

    # Initialise the model parameters; ``net`` and ``updater`` read ``w`` and
    # ``b`` as module-level globals.
    w = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
    b = torch.zeros(num_outputs, requires_grad=True)

    # Train, then preview a few predictions from the test set.
    train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
    predict_ch3(net, test_iter)