In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10 
import torch.nn.functional as F
import os
from torch import optim
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchinfo  import summary

In [2]:
# The only preprocessing step applied to each image is ToTensor.
transform = transforms.Compose([transforms.ToTensor()])

# Build the train split first, then the test split. Both use the current
# working directory as the dataset root and request a download.
train_data, test_data = (
    CIFAR10(root=os.getcwd(), download=True, train=is_train, transform=transform)
    for is_train in (True, False)
)

# Batches of 16; only the training loader shuffles its samples.
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=16)
test_dataloader = DataLoader(test_data, shuffle=False, batch_size=16)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class Resnet(nn.Module):
    """ResNet-style image classifier: stem -> four ``ResLayer`` stages -> linear head.

    The stem (stride-2 7x7 convolution followed by a stride-2 max-pool) and
    the three stride-2 stages reduce the spatial resolution step by step, so
    a 32x32 input reaches the head as a 1x1 feature map with 512 channels.
    A global average pool in front of the head collapses whatever spatial
    size is left to 1x1, so larger inputs work too; for a 1x1 map the pool
    is an identity, which keeps results for 32x32 inputs unchanged.

    Args:
        n_classes: Number of output logits produced by the final linear layer.
    """

    def __init__(self,n_classes):
        super().__init__()
        # Stem: 3 -> 64 channels, spatial size reduced by the strided conv
        # and again by the strided max-pool.
        self.pre = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        )

        # Residual stages; every stage after the first uses stride 2.
        self.layer1 = ResLayer(64,128,3)
        self.layer2 = ResLayer(128,256,4,stride = 2)
        self.layer3 = ResLayer(256,512,4,stride = 2)
        self.layer4 = ResLayer(512,512,4,stride = 2)

        # Parameter-free global pooling, kept as its own attribute so the
        # parameter names inside ``self.linear`` stay exactly as before.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512,n_classes)
        )

    def forward(self,X):
        """Return class logits for a batch of 3-channel images ``X``."""
        x = self.pre(X)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Collapse the remaining spatial extent so Flatten always yields
        # exactly 512 features for the linear layer.
        x = self.pool(x)

        return self.linear(x)


class ResBlock(nn.Module):
    """Residual block computing ``gelu(f(X) + shortcut(X))``.

    The residual branch ``f`` is a 3x3 convolution (carrying the stride),
    batch norm, GELU, a 1x1 convolution and a second batch norm. The
    shortcut is the identity, or a strided 1x1 convolution plus batch norm
    when a projection is needed.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution and of the projection shortcut.
        downsample: Force a projection shortcut. A projection is also created
            automatically whenever ``stride != 1`` or
            ``in_channels != out_channels``, because an identity shortcut
            would then not match the residual branch's output shape.
    """
    def __init__(self, in_channels,out_channels,stride=1,downsample = False):
        super().__init__()
        self.f = nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size = 3,stride= stride,padding=1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels,out_channels,kernel_size = 1,stride= 1,padding=0,bias=False),
            nn.BatchNorm2d(out_channels)
        )

        # An identity shortcut is only shape-compatible with ``f`` when the
        # block changes neither the channel count nor the resolution.
        needs_projection = downsample or stride != 1 or in_channels != out_channels

        self.downsample = None
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels,out_channels,kernel_size=1,stride =stride,bias= False),
                nn.BatchNorm2d(out_channels)
                )

    def forward(self,X):
        """Apply the residual branch, add the shortcut and return the GELU of the sum."""
        x = self.f(X)

        # Identity shortcut unless a projection was built in __init__.
        res = X
        if self.downsample is not None:
            res = self.downsample(res)

        x = x+res

        return F.gelu(x)


class ResLayer(nn.Module):
    """A stage of ``ResBlock`` modules applied one after another.

    The first block maps ``in_channels`` to ``out_channels`` with the given
    stride and is always built with ``downsample=True``. The remaining
    ``n_blocks - 1`` blocks keep ``out_channels`` and use ``ResBlock``'s
    default stride and shortcut.
    """

    def __init__(self,in_channels,out_channels,n_blocks,stride =1):
        super().__init__()
        self.f = self.make_layer(in_channels,out_channels,n_blocks,stride)

    def make_layer(self,in_channels,out_channels,n_blocks,stride):
        """Build the stage as an ``nn.Sequential`` of residual blocks."""
        # The entry block is created first, then the same-width followers.
        entry = ResBlock(in_channels, out_channels, stride=stride, downsample=True)
        followers = [ResBlock(out_channels, out_channels) for _ in range(1, n_blocks)]
        return nn.Sequential(entry, *followers)

    def forward(self,X):
        """Run ``X`` through every block of the stage in order."""
        return self.f(X)

In [4]:
# Build a 10-class Resnet instance (no device move happens in this cell).
model = Resnet(n_classes=10)

In [7]:
# Use the GPU when one is available and fall back to the CPU otherwise, so
# this cell no longer fails on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fresh model, Adam with its default hyper-parameters, and cross-entropy
# with a small amount of label smoothing.
model = Resnet(10).to(device)
optimizers = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss(label_smoothing =.01)

# `i` counts optimisation steps across all epochs; the current batch loss is
# printed once every 500 steps.
i = 0
for epoch in range(10):
    for X,y in tqdm(train_dataloader):
        optimizers.zero_grad()
        pred = model(X.to(device))
        loss = loss_fn(pred,y.to(device))
        if i%500==0:
            print(loss.item())
        loss.backward()
        optimizers.step()
        i+=1

  0%|          | 0/3125 [00:00<?, ?it/s]

2.633793592453003
1.8164016008377075
1.670569658279419
1.9431865215301514
1.6359196901321411
1.9278552532196045
1.4356526136398315


  0%|          | 0/3125 [00:00<?, ?it/s]

1.5709526538848877
1.5933587551116943
1.3135788440704346
1.5626775026321411
1.7421927452087402
1.335776448249817


  0%|          | 0/3125 [00:00<?, ?it/s]

1.170583963394165
1.1138091087341309
0.8874974846839905
1.9145245552062988
0.8875733017921448
1.4602675437927246


  0%|          | 0/3125 [00:00<?, ?it/s]

1.0584009885787964
1.107518196105957
1.6755684614181519
0.8177648782730103
0.7844170928001404
1.4553511142730713


  0%|          | 0/3125 [00:00<?, ?it/s]

0.9191901683807373
0.8207089304924011
0.7553139328956604
0.9718639850616455
1.1178653240203857
0.38493332266807556
0.9778648018836975


  0%|          | 0/3125 [00:00<?, ?it/s]

0.5056002140045166
0.41254478693008423
0.7643025517463684
0.6467880606651306
0.7527156472206116
0.8349853157997131


  0%|          | 0/3125 [00:00<?, ?it/s]

0.6749799847602844
0.929689884185791
0.7965520620346069
0.6235604882240295
0.5792805552482605
0.45669227838516235


  0%|          | 0/3125 [00:00<?, ?it/s]

0.4719741940498352
0.683571994304657
1.4468988180160522
0.8766253590583801
0.8353974223136902
0.7703700065612793


  0%|          | 0/3125 [00:00<?, ?it/s]

0.23359215259552002
0.14869390428066254
0.24234870076179504
0.27703550457954407
0.36324381828308105
0.44614049792289734
0.2456749677658081


  0%|          | 0/3125 [00:00<?, ?it/s]

0.4411335587501526
0.1990174502134323
0.2498278021812439
0.49793684482574463
0.13435281813144684
0.1390952616930008
