In [1]:
import os
import torch
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

In [2]:
class CNNModel(nn.Module):
    def __init__(self,in_ch,num_cl,hidden_ch=32):
        super(CNNModel,self).__init__()
        self.Model=nn.Sequential(
        self.block(in_ch,hidden_ch,final=False),
        self.block(hidden_ch,hidden_ch*2,final=False),
        self.block(hidden_ch*2,num_cl,final=True)
        )
        
    def block(self,in_ch,out_ch,final):
        if not final:
            return nn.Sequential(
            nn.Conv2d(in_ch,out_ch,kernel_size=5,stride=1,padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d((2,2))
            )
        else:
            return nn.Sequential(
            nn.Conv2d(in_ch,out_ch,kernel_size=7,stride=1,padding=1),
            nn.Sigmoid()
            )
    def forward(self,x):
        x=self.Model(x)
        return x.reshape(x.shape[0],-1)

In [3]:
x=torch.randn(10,1,28,28)
model=CNNModel(1,10)
print(model(x).shape)

torch.Size([10, 10])


In [4]:
dataset=MNIST('.',download=False,train=False,transform=transforms.ToTensor())
data=DataLoader(dataset,batch_size=100,shuffle=True)
print(dataset)

Dataset MNIST
    Number of datapoints: 10000
    Root location: .
    Split: Test
    StandardTransform
Transform: ToTensor()


In [5]:
epoch=1
criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=.001)

In [6]:
def accuracy(data,model):
    correct=0
    total=0
    model.eval()
    with torch.no_grad():
        for images,labels in tqdm(data):
            preds=model(images)
            values,index=preds.max(1)
            correct+=(index==labels).sum()
            total+=preds.size(0)
            acc=correct/total
        print(acc)

In [7]:
for e in range(epoch):
    for images,labels in tqdm(data):
        preds=model(images)
        loss=criterion(preds,labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
torch.save(model.state_dict(),os.path.join('F:','info.pth'))

HBox(children=(FloatProgress(value=0.0), HTML(value='')))




In [8]:
accuracy(data,model)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


tensor(0.9653)
