In [52]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [53]:
trainset, validset = torch.utils.data.random_split(
    torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform = transform),
    lengths=[40000, 10000])
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True, num_workers=2)
validloader = torch.utils.data.DataLoader(validset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('beaver', 'dolphin', 'otter', 'seal', 'whale',
      'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',
      'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',
      'bottles', 'bowls', 'cans', 'cups', 'plates',
      'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',
      'clock', 'computer keyboard', 'lamp', 'telephone', 'television',
      'bed', 'chair', 'couch', 'table', 'wardrobe',
      'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',
      'bear', 'leopard', 'lion', 'tiger', 'wolf',
      'bridge', 'castle', 'house', 'road', 'skyscraper',
      'cloud', 'forest', 'mountain', 'plain', 'sea',
      'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',
      'fox', 'porcupine', 'possum', 'raccoon', 'skunk',
      'crab', 'lobster', 'snail', 'spider', 'worm',
      'baby', 'boy', 'girl', 'man', 'woman',
      'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',
      'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',
      'maple', 'oak', 'palm', 'pine', 'willow',
      'bicycle', 'bus', 'motorcycle', 'pickup' 'truck', 'train',
      'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor')


Files already downloaded and verified
Files already downloaded and verified


In [54]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride = stride, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride = 1, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(nn.Conv2d(in_channels,out_channels,stride=stride,kernel_size = 1,bias = False),nn.BatchNorm2d(out_channels))
            
    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        out += self.shortcut(identity)
        out = self.relu(out)
        
        return out
    
class Resnet18(nn.Module):
    def __init__(self, num_classes = 100):
        super(Resnet18,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace = True)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        
        self.layer1 = self._make_layer(64,64,2,stride = 1)
        self.layer2 = self._make_layer(64,128,2,stride = 2)
        self.layer3 = self._make_layer(128,256,2,stride = 2)
        self.layer4 = self._make_layer(256,512,2,stride = 2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        
        self.fc = nn.Linear(512, num_classes)
        
    def _make_layer(self, in_channels, out_channels, blocks, stride = 1):
        layer = []
        layer.append(ResidualBlock(in_channels,out_channels,stride))
        for _ in range(1,blocks):
            layer.append(ResidualBlock(out_channels, out_channels))
        return nn.Sequential(*layer)
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        
        return out
net = Resnet18().to(device)

# optimizer and loss function
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)


In [55]:
# start training
#from tqdm.notebook import tqdm
#progress_bar = tqdm(enumerate(trainloader, 0), total=len(trainloader), desc='Training' )

for epoch in range(30):  # epoch
    running_loss = 0.0
    print(epoch)
    for i, data in (enumerate(trainloader, 0)):

        inputs, labels = data[0].to(device), data[1].to(device)

        # zero gradient
        optimizer.zero_grad()

        # forward, backward, optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        #wandb.log({"loss": loss})
        #progress_bar.update(1)
        running_loss += loss.item()
        if i % 250 == 0:
            # validation
            net.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for data in validloader:
                    images, labels = data[0].to(device), data[1].to(device)
                    outputs = net(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            #progress_bar.set_postfix(accuracy=100 * correct / total)
            #wandb.log({"accuracy": correct/total})
        running_loss = 0.0

print('Finished Training')
torch.save(net, 'ResNet18CIFAR100.pt')

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
Finished Training


In [56]:
# test network on test set
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

# calculate accuracy
class_correct = list(0. for i in range(100))
class_total = list(0. for i in range(100))
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(100):
    if class_total[i] != 0:
        print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
    else:
        print('Accuracy of %5s : %2d %%' % (classes[i], 0))


Accuracy of the network on the 10000 test images: 37 %
Accuracy of beaver : 80 %
Accuracy of dolphin : 14 %
Accuracy of otter : 33 %
Accuracy of  seal : 33 %
Accuracy of whale : 33 %
Accuracy of aquarium fish : 42 %
Accuracy of flatfish :  0 %
Accuracy of   ray :  0 %
Accuracy of shark : 16 %
Accuracy of trout : 66 %
Accuracy of orchids : 25 %
Accuracy of poppies : 40 %
Accuracy of roses : 33 %
Accuracy of sunflowers : 16 %
Accuracy of tulips :  0 %
Accuracy of bottles : 12 %
Accuracy of bowls : 28 %
Accuracy of  cans : 33 %
Accuracy of  cups : 14 %
Accuracy of plates : 28 %
Accuracy of apples : 75 %
Accuracy of mushrooms : 42 %
Accuracy of oranges :  0 %
Accuracy of pears : 50 %
Accuracy of sweet peppers : 50 %
Accuracy of clock : 20 %
Accuracy of computer keyboard : 20 %
Accuracy of  lamp :  0 %
Accuracy of telephone : 40 %
Accuracy of television : 25 %
Accuracy of   bed : 25 %
Accuracy of chair : 28 %
Accuracy of couch : 44 %
Accuracy of table : 50 %
Accuracy of wardrobe : 25 %
Accu