In [1]:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam

import numpy as np
#from tqdm import tqdm
from qqdm.notebook import qqdm

import os

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

Using cuda device


In [3]:
def get_dataloader(train,batch_size=128):
    transform_fu = Compose([
        ToTensor(),
        Normalize(
            (0.1307,), (0.3081,))
    ])
    dataset = MNIST(root='./data', train=train, download=True, transform=transform_fu)
    data_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
    return data_loader

train_batch_size = 128
test_batch_size = 1000
train_dataloader = get_dataloader(train=True,batch_size=train_batch_size)
test_dataloader = get_dataloader(train=False, batch_size=test_batch_size)

In [None]:
class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel,self).__init__()
        self.fc1 = nn.Linear(1*28*28,100)
        self.fc2 = nn.Linear(100,10)

    def forward(self, image):
        image_viwed = image.view(-1,1*28*28)
        fc1_out = self.fc1(image_viwed)
        fc1_out_relu = F.relu(fc1_out)
        out = self.fc2(fc1_out_relu)
        return F.log_softmax(out,dim=-1)

#1. 实例化模型，优化器，损失函数
model = MnistModel()
optimizer = Adam(model.parameters(),lr=0.001)
# criterion = nn.NLLLoss()
criterion = nn.CrossEntropyLoss()

# if os.path.exists("./models/model.pkl"):
#     model.load_state_dict(torch.load("./models/model.pkl"))
#     optimizer.load_state_dict(torch.load("./models/optimizer.pkl"))


def train(epoch):
    total_loss = []
    loop = qqdm(train_dataloader)
    for index,(input,target) in enumerate(loop):
        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output,target)
        loss.backward()
        total_loss.append(loss.item())
        optimizer.step()
        loop.set_infos({
            'Loss': round(torch.rand(1).item(),4),
            'Accuracy':round(torch.rand(1).item(),4),
            'Epoch': epoch+1,
        })

def test():
    total_loss = []
    total_acc = []
    with torch.no_grad():
        loop2 = qqdm(test_dataloader)
        for index, (input, target) in enumerate(loop2):
            output = model(input)
            loss = criterion(output,target)
            total_loss.append(loss.item())
            pred = output.max(dim=-1)[-1]
            total_acc.append(pred.eq(target).float().mean().item())
            loop2.set_infos({
                'Test-Loss': round(np.mean(total_loss),4),
                'Test-Accuracy':round(np.mean(total_acc),4),
            })

if __name__ == '__main__':
    epochs=20
    for epoch in range(epochs):
        train(epoch)
        test()

[K[F  [1mIters[0m     [1mElapsed Time[0m       [1mSpeed[0m     [1mLoss[0m   [1mAccuracy[0m  [1mEpoch[0m                 
 [99m469/[93m469[0m[0m  [99m00:00:12<[93m00:00:00[0m[0m  [99m38.90it/s[0m  [99m0.0974[0m   [99m0.5336[0m     [99m1[0m                   

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mTest-Loss[0m  [1mTest-Accuracy[0m                   
 [99m10/[93m10[0m[0m  [99m00:00:01<[93m00:00:00[0m[0m  [99m5.38it/s[0m    [99m0.171[0m       [99m0.9498[0m                       

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F  [1mIters[0m     [1mElapsed Time[0m       [1mSpeed[0m     [1mLoss[0m   [1mAccuracy[0m  [1mEpoch[0m                 
 [99m469/[93m469[0m[0m  [99m00:00:12<[93m00:00:00[0m[0m  [99m38.39it/s[0m  [99m0.3596[0m   [99m0.759[0m      [99m2[0m                   

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mTest-Loss[0m  [1mTest-Accuracy[0m                   
 [99m10/[93m10[0m[0m  [99m00:00:01<[93m00:00:00[0m[0m  [99m5.39it/s[0m   [99m0.1246[0m       [99m0.9608[0m                       

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F  [1mIters[0m     [1mElapsed Time[0m       [1mSpeed[0m     [1mLoss[0m   [1mAccuracy[0m  [1mEpoch[0m                 
 [99m469/[93m469[0m[0m  [99m00:00:12<[93m00:00:00[0m[0m  [99m37.89it/s[0m  [99m0.6343[0m   [99m0.7843[0m     [99m3[0m                   

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mTest-Loss[0m  [1mTest-Accuracy[0m                   
 [99m10/[93m10[0m[0m  [99m00:00:02<[93m00:00:00[0m[0m  [99m4.94it/s[0m   [99m0.0965[0m       [99m0.9711[0m                       

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F  [1mIters[0m     [1mElapsed Time[0m       [1mSpeed[0m     [1mLoss[0m   [1mAccuracy[0m  [1mEpoch[0m                 
 [99m469/[93m469[0m[0m  [99m00:00:12<[93m00:00:00[0m[0m  [99m37.79it/s[0m  [99m0.3695[0m   [99m0.8226[0m     [99m4[0m                   

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mTest-Loss[0m  [1mTest-Accuracy[0m                   
 [99m10/[93m10[0m[0m  [99m00:00:02<[93m00:00:00[0m[0m  [99m4.79it/s[0m   [99m0.0894[0m       [99m0.9731[0m                       

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F  [1mIters[0m     [1mElapsed Time[0m       [1mSpeed[0m     [1mLoss[0m   [1mAccuracy[0m  [1mEpoch[0m                 
 [99m469/[93m469[0m[0m  [99m00:00:11<[93m00:00:00[0m[0m  [99m39.31it/s[0m  [99m0.4872[0m   [99m0.6008[0m     [99m5[0m                   

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mTest-Loss[0m  [1mTest-Accuracy[0m                   
 [99m10/[93m10[0m[0m  [99m00:00:01<[93m00:00:00[0m[0m  [99m5.31it/s[0m    [99m0.085[0m       [99m0.9729[0m                       

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F  [1mIters[0m     [1mElapsed Time[0m       [1mSpeed[0m     [1mLoss[0m   [1mAccuracy[0m  [1mEpoch[0m                 
 [99m469/[93m469[0m[0m  [99m00:00:11<[93m00:00:00[0m[0m  [99m39.23it/s[0m  [99m0.8408[0m   [99m0.686[0m      [99m6[0m                   

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mTest-Loss[0m  [1mTest-Accuracy[0m                   
 [99m10/[93m10[0m[0m  [99m00:00:01<[93m00:00:00[0m[0m  [99m5.84it/s[0m   [99m0.0781[0m       [99m0.9765[0m                       

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F  [1mIters[0m     [1mElapsed Time[0m       [1mSpeed[0m     [1mLoss[0m   [1mAccuracy[0m  [1mEpoch[0m                 
 [99m469/[93m469[0m[0m  [99m00:00:12<[93m00:00:00[0m[0m  [99m38.91it/s[0m  [99m0.9579[0m   [99m0.7185[0m     [99m7[0m                   

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mTest-Loss[0m  [1mTest-Accuracy[0m                   
 [99m10/[93m10[0m[0m  [99m00:00:01<[93m00:00:00[0m[0m  [99m5.66it/s[0m   [99m0.0843[0m       [99m0.9727[0m                       

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F  [1mIters[0m     [1mElapsed Time[0m       [1mSpeed[0m     [1mLoss[0m   [1mAccuracy[0m  [1mEpoch[0m                 
 [99m274/[93m469[0m[0m  [99m00:00:07<[93m00:00:05[0m[0m  [99m38.06it/s[0m  [99m0.8204[0m   [99m0.1125[0m     [99m8[0m                   

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mTest-Loss[0m  [1mTest-Accuracy[0m                   
 [99m10/[93m10[0m[0m  [99m00:00:01<[93m00:00:00[0m[0m  [99m5.59it/s[0m   [99m0.0803[0m        [99m0.976[0m                       

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F  [1mIters[0m     [1mElapsed Time[0m       [1mSpeed[0m     [1mLoss[0m   [1mAccuracy[0m  [1mEpoch[0m                 
 [99m469/[93m469[0m[0m  [99m00:00:12<[93m00:00:00[0m[0m  [99m38.09it/s[0m  [99m0.6094[0m   [99m0.9392[0m     [99m9[0m                   

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mTest-Loss[0m  [1mTest-Accuracy[0m                   
 [99m10/[93m10[0m[0m  [99m00:00:01<[93m00:00:00[0m[0m  [99m5.08it/s[0m   [99m0.0747[0m       [99m0.9771[0m                       

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))

[K[F [1mIters[0m     [1mElapsed Time[0m       [1mSpeed[0m    [1mLoss[0m   [1mAccuracy[0m  [1mEpoch[0m                   
 [99m15/[93m469[0m[0m  [99m00:00:00<[93m00:00:12[0m[0m  [99m35.64it/s[0m  [99m0.768[0m   [99m0.4947[0m    [99m10[0m                     

IpythonBar(children=(HTML(value='  0.0%'), FloatProgress(value=0.0)))