In [None]:
from d2l import torch as dl

In [None]:
dl.try_gpu()

In [None]:
import torch

In [None]:
torch.cuda.is_available()

In [None]:
torch.__version__

In [None]:
torch.rand((1,2,4)).to('cuda')

In [None]:
from torch import nn

# model

In [112]:
model = nn.Sequential(
    nn.BatchNorm2d(1),
    nn.Conv2d(1,16,kernel_size=(5,5)),
    nn.ReLU(),
    nn.MaxPool2d((2,2)),
    nn.Conv2d(16,32,kernel_size=(3,3)),
    nn.ReLU(),
    nn.MaxPool2d((2,2)),
    nn.Conv2d(32,64,kernel_size=(3,3)),
    nn.ReLU(),
    nn.MaxPool2d((2,2)),
    nn.Flatten(),
    nn.Linear(64,10),
    nn.Softmax()
)


In [113]:
print(model)

Sequential(
  (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (2): ReLU()
  (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (5): ReLU()
  (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (10): Flatten(start_dim=1, end_dim=-1)
  (11): Linear(in_features=64, out_features=10, bias=True)
  (12): Softmax(dim=None)
)


In [114]:
x = model(torch.rand((32,1,28,28)))

  return self._call_impl(*args, **kwargs)


In [None]:
x.shape # == (1,10)

torch.Size([32, 10])

# dataset

In [100]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

In [None]:
ds = MNIST('./',train=True,transform=ToTensor(), download=True)

In [None]:
ds_loader = DataLoader(ds, 32, True)

In [None]:
i = ds_loader.__iter__()

# training

In [None]:
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(),lr=1e-2)

In [None]:
from torch import Tensor


In [119]:

def train(model, loss, optim, ds_loader:DataLoader, epoch=1):
    # param
    num_batches = len(ds_loader.dataset) // ds_loader.batch_size
    
    for i in range(epoch):
        for batch_idx,(X,y) in enumerate(ds_loader):
            # process y -> []*10
            labels = []
            for j in range(32):
                label = [0]*10
                label[y.numpy()[j]] = 1
                label = Tensor(label)
                labels.append(label)
            labels = torch.stack(labels)
            # print(X.shape)
            # return
            assert type(labels)==torch.Tensor, 'labels must be tensor'
            # train
            pred = model.forward(X)
            batch_loss = loss(pred, labels)
            print('[{}:{}]loss={:>7}'.format(batch_idx, num_batches, batch_loss.item()))
            batch_loss.backward()
            optim.step()
            optim.zero_grad()
    print('finish')


In [120]:
train(model,loss_fn, optim, ds_loader, 2)

  return self._call_impl(*args, **kwargs)


[0:1875]loss=2.2999894618988037
[1:1875]loss=2.30553936958313
[2:1875]loss=2.3061225414276123
[3:1875]loss=2.3032851219177246
[4:1875]loss=2.3010573387145996
[5:1875]loss=2.299762487411499
[6:1875]loss=2.306690216064453
[7:1875]loss=2.3008933067321777
[8:1875]loss=2.3019018173217773
[9:1875]loss=2.3039231300354004
[10:1875]loss=2.3017737865448
[11:1875]loss=2.3033289909362793
[12:1875]loss=2.3053090572357178
[13:1875]loss=2.301856756210327
[14:1875]loss=2.2936551570892334
[15:1875]loss=2.300687074661255
[16:1875]loss=2.3001770973205566
[17:1875]loss=2.3032355308532715
[18:1875]loss=2.297631025314331
[19:1875]loss=2.3025012016296387
[20:1875]loss=2.3092782497406006
[21:1875]loss=2.298928737640381
[22:1875]loss=2.30399489402771
[23:1875]loss=2.3025026321411133
[24:1875]loss=2.2983405590057373
[25:1875]loss=2.3007075786590576
[26:1875]loss=2.304231643676758
[27:1875]loss=2.3022918701171875
[28:1875]loss=2.301826000213623
[29:1875]loss=2.3013992309570312
[30:1875]loss=2.3072447776794434
[3

In [124]:
s = model.state_dict()
torch.save(s,'./model.pt')

In [125]:
models = model

In [128]:
models.load_state_dict(torch.load('./model.pt'))

<All keys matched successfully>