<a href="https://colab.research.google.com/github/liangjieyu123/CropIntensity/blob/main/%E6%89%8B%E5%86%99%E6%95%B0%E5%AD%97%E8%AF%86%E5%88%AB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize

In [2]:
batch_size=128
def get_dataloader(train=True):
  transform_fn=Compose([
      ToTensor(),
      Normalize(mean=(0.1307,),std=(0.3081,))
  ])
  dataset = MNIST(root="./data",train=train,download=True,transform=transform_fn)
  data_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
  return data_loader

In [3]:
class MnistNet(nn.Module):
  def __init__(self):
    super(MnistNet,self).__init__()
    self.fc1=nn.Linear(28*28*1,28)
    self.fc2=nn.Linear(28,10)

  def forward(self,x):
    x=x.view([-1,28*28*1])
    x=self.fc1(x)
    x=F.relu(x)
    x=self.fc2(x)
    out=F.log_softmax(x,dim=-1)
    return out

In [4]:
if torch.cuda.is_available():
  print("Cuda is available")
else:
  print("Cuda is unavailable")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Cuda is unavailable


In [5]:
from torch.optim import Adam
mnist_net=MnistNet().to(device)
optimizer = Adam(mnist_net.parameters(),lr=0.001)
def train(epoch):
  mode=True
  mnist_net.train(mode=mode)

  data_loader = get_dataloader()
  for idx,(input,traget) in enumerate(data_loader):
    input = input.to(device)
    traget = traget.to(device)
    output = mnist_net(input)
    loss = F.nll_loss(output,traget)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if idx%1000 ==0:
      print(epoch,idx,loss.item())

In [None]:
# 模型保存
torch.save(mnist_net.state_dict(),"save/mnist_net.pt")  #保存模型参数
torch.save(optimizer.state_dict(),"save/mnist_optimizer.pt")

In [None]:
# 模型加载
mnist_net.load_state_dict(torch.load("data/mnist_net.pt"))
optimizer.load_state_dict(torch.load("data/mnist_optimizer.pt"))

In [6]:
if __name__=='__main__':
  for i in range(3):
    train(i)
  # test()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

0 0 2.3094167709350586
1 0 0.21299728751182556
2 0 0.21546868979930878


In [7]:
# 模型测试
def test():
  loss_list = []
  acc_list = []
  test_dataloader = get_dataloader(train=False)
  for idx,(input,traget) in enumerate(test_dataloader):
    with torch.no_grad():
      output = mnist_net(input)
      cur_loss = F.nll_loss(output,traget)
      loss_list.append(cur_loss)
      # 计算准确率
      pred = output.max(dim=-1)[-1]
      cur_acc = pred.eq(traget).float().mean()
      acc_list.append(cur_acc)
  print("平均精度：",np.mean(acc_list),"平均损失：",np.mean(loss_list)) 

test()


平均精度： 0.9478837 平均损失： 0.17319018


In [None]:
import os
from google.colab import drive
drive.mount('/content/drive')

path = "/content/drive/MyDrive/model/pytorch-lightning-unet/train_img"
os.chdir(path)
os.listdir(path)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

img = Image.open("BJ2000520171218RGB_2_img.png")
print(img.mode)
plt.figure("img")
plt.imshow(img)
plt.show()

import os
from google.colab import drive
drive.mount('/content/drive')

path = "/content/drive/MyDrive/model/pytorch-lightning-unet/train_img"
os.chdir(path)
os.listdir(path)
