In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms

In [28]:
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5,self).__init__()
        
        self.conv_unit = nn.Sequential(
            nn.Conv2d(3,16,kernel_size=5,stride=1,padding=0),
            nn.MaxPool2d(kernel_size=2,stride=2,padding=0),
            nn.Conv2d(16,32,kernel_size=5,stride=1,padding=0),
            nn.MaxPool2d(kernel_size=2,stride=2,padding=0),   
        )
        self.fc_unit = nn.Sequential(
            nn.Linear(32*5*5,120),
            nn.ReLU(inplace=True),
            nn.Linear(120,84),
            nn.ReLU(inplace=True),
            nn.Linear(84,32),
            nn.ReLU(inplace=True),
            nn.Linear(32,10),
            nn.ReLU(inplace=True),
        )
        
    def forward(self,x):
        batch_size = x.size(0)
        x = self.conv_unit(x)
        x = x.view(batch_size,32*5*5)
        logits = self.fc_unit(x)
        return logits
    
# def main():
#     model = LeNet5()
#     x = torch.randn(2,3,32,32)
#     logits = model(x)
#     print(logits.shape)
    
# if '__main__'==__name__:
#     main()

In [29]:
class ResnetBlk(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(ResnetBlk,self).__init__()
        self.unit = nn.Sequential(
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(ch_out),
        )
        self.extra = nn.Sequential()
        if ch_out!=ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1),
                nn.BatchNorm2d(ch_out),
            )
            
    def forward(self,x):
        out = self.unit(x)
        return self.extra(x) + out

In [30]:
class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18,self).__init__()
        self.unit = nn.Sequential(
            nn.Conv2d(3,16,stride=3,kernel_size=3,padding=0),
            nn.BatchNorm2d(16),
        )
        self.bltk1=ResnetBlk(16,32)
        self.bltk2=ResnetBlk(32,64)
        self.bltk3=ResnetBlk(64,128)
        self.bltk4=ResnetBlk(128,256)
        self.outlayer=nn.Linear(256*10*10,10)
        
    def forward(self,x):
        x = F.relu(self.unit(x))
        x = self.bltk1(x)
        x = self.bltk2(x)
        x = self.bltk3(x)
        x = self.bltk4(x)
        x = x.view(x.size(0),-1)
        x = self.outlayer(x)
        
        return x

In [31]:
# def main():
#     blk = ResnetBlk(64,128)
#     tmp = torch.randn(2,64,32,32)
#     out = blk(tmp)
#     print(out.shape)
    
#     model = ResNet18()
#     tmp = torch.randn(2,3,32,32)
#     out = model(tmp)
#     print('resnet:',out.shape)
    

# if __name__=='__main__':
#     main()

torch.Size([2, 128, 32, 32])
resnet: torch.Size([2, 10])


In [36]:
def main():
    batch_size = 32
    cifar_train = datasets.CIFAR10('cifar',True,transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]),download = True)
    cifar_train = torch.utils.data.DataLoader(cifar_train,batch_size=batch_size,shuffle=True)
    cifar_test = datasets.CIFAR10('cifar',False,transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]),download = True)
    cifar_test = torch.utils.data.DataLoader(cifar_test,batch_size=batch_size,shuffle=True)
    x,label = iter(cifar_train).next()
#     torch.Size([32, 3, 32, 32])
#     torch.Size([32])
    device = torch.device('cuda:0')
    model = ResNet18().to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(),lr = 3e-4)
    
    for epoch in range(1000):
        model.train()
        for idx,(train,target) in enumerate(cifar_train):
            train,target = train.to(device),target.to(device)
            logits = model(train)
            loss = criterion(logits,target)
            optim.zero_grad()
            loss.backward()
            optim.step()
        print(epoch,'loss:',loss.item())
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x,label in cifar_test:
                #[b,3,32,32]
                #[b]
                x,label = x.to(device),label.to(device)
                #[b,10]
                logits = model(x)
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred,target).float().sum().item()
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'acc:', acc)


if __name__ == '__main__':
    main()

In [37]:
main()

Files already downloaded and verified
Files already downloaded and verified


RuntimeError: Error attempting to use dtype torch.float32 with layout torch.strided and device type CUDA.  Torch not compiled with CUDA enabled.
