In [30]:
import torch
from torchsummary import summary
import albumentations as A
from albumentations.pytorch import ToTensorV2


from data import CIFAR10WithAlbumentations
from custom_resnet import Net
from utils_train import train, test

In [32]:
means = [0.4914, 0.4822, 0.4465]
stds = [0.2470, 0.2435, 0.2616]

train_transforms = A.Compose(
    [
        A.Normalize(mean=means, std=stds, always_apply=True),
        A.PadIfNeeded(min_height=4, min_width=4, always_apply=True),
        A.RandomCrop(height=32, width=32, always_apply=True),
        A.HorizontalFlip(),
        A.CoarseDropout(max_holes=1, max_height=8, max_width=8, min_holes=1, min_height=8, min_width=8, fill_value=means),
        ToTensorV2(),
    ]
)

test_transforms = A.Compose(
    [
        A.Normalize(mean=means, std=stds, always_apply=True),
        ToTensorV2(),
    ]
)

train_ds = CIFAR10WithAlbumentations('./data', train=True, download=True, transform=train_transforms)
test_ds = CIFAR10WithAlbumentations('./data', train=False, download=True, transform=test_transforms)


Files already downloaded and verified
Files already downloaded and verified


In [33]:
SEED = 1
cuda = torch.cuda.is_available()
print("CUDA Available?", cuda)

# For reproducibility
torch.manual_seed(SEED)

if cuda:
    torch.cuda.manual_seed(SEED)

# dataloader arguments - something you'll fetch these from cmdprmt
dataloader_args = dict(shuffle=True, batch_size=512, num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)

# train dataloader
train_loader = torch.utils.data.DataLoader(train_ds, **dataloader_args)

# test dataloader
test_loader = torch.utils.data.DataLoader(test_ds, **dataloader_args)


CUDA Available? True


In [23]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

model = Net().to(device)
summary(model, input_size=(3, 32, 32))

cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
              ReLU-2           [-1, 64, 32, 32]               0
       BatchNorm2d-3           [-1, 64, 32, 32]             128
           Dropout-4           [-1, 64, 32, 32]               0
            Conv2d-5          [-1, 128, 32, 32]          73,728
              ReLU-6          [-1, 128, 32, 32]               0
       BatchNorm2d-7          [-1, 128, 32, 32]             256
           Dropout-8          [-1, 128, 32, 32]               0
         MaxPool2d-9          [-1, 128, 16, 16]               0
           Conv2d-10          [-1, 128, 16, 16]         147,456
             ReLU-11          [-1, 128, 16, 16]               0
      BatchNorm2d-12          [-1, 128, 16, 16]             256
          Dropout-13          [-1, 128, 16, 16]               0
           Conv2d-14          [-1,

In [41]:
train_losses = []
test_losses = []
train_acc = []
test_acc = []

In [None]:
model =  Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# scheduler = CosineAnnealingLR(optimizer, T_max=100, verbose=True)

EPOCHS = 24
for epoch in range(EPOCHS):
    print("EPOCH:", epoch)
    train(model, device, train_loader, optimizer, epoch, train_losses, train_acc)
    # scheduler.step()
    test(model, device, test_loader, test_losses, test_acc)


EPOCH: 0


Loss=1.4879919290542603 Batch_id=97 Accuracy=30.77: 100%|██████████| 98/98 [00:25<00:00,  3.83it/s]



Test set: Average loss: 1.6490, Accuracy: 4050/10000 (40.50%)

EPOCH: 1


Loss=1.4426014423370361 Batch_id=97 Accuracy=48.76: 100%|██████████| 98/98 [00:26<00:00,  3.68it/s]



Test set: Average loss: 1.3136, Accuracy: 5474/10000 (54.74%)

EPOCH: 2


Loss=1.143445372581482 Batch_id=97 Accuracy=58.37: 100%|██████████| 98/98 [00:27<00:00,  3.57it/s] 



Test set: Average loss: 1.0598, Accuracy: 6248/10000 (62.48%)

EPOCH: 3


Loss=0.8664868474006653 Batch_id=97 Accuracy=66.15: 100%|██████████| 98/98 [00:27<00:00,  3.61it/s]



Test set: Average loss: 0.9966, Accuracy: 6523/10000 (65.23%)

EPOCH: 4


Loss=0.8230717182159424 Batch_id=97 Accuracy=71.22: 100%|██████████| 98/98 [00:27<00:00,  3.62it/s]



Test set: Average loss: 0.8641, Accuracy: 7140/10000 (71.40%)

EPOCH: 5


Loss=0.6936295628547668 Batch_id=97 Accuracy=76.25: 100%|██████████| 98/98 [00:27<00:00,  3.62it/s]



Test set: Average loss: 0.8161, Accuracy: 7229/10000 (72.29%)

EPOCH: 6


Loss=0.621336042881012 Batch_id=97 Accuracy=78.88: 100%|██████████| 98/98 [00:27<00:00,  3.62it/s]  



Test set: Average loss: 0.6605, Accuracy: 7834/10000 (78.34%)

EPOCH: 7


Loss=0.503520667552948 Batch_id=97 Accuracy=81.38: 100%|██████████| 98/98 [00:27<00:00,  3.62it/s]  



Test set: Average loss: 0.6734, Accuracy: 7825/10000 (78.25%)

EPOCH: 8


Loss=0.4869544506072998 Batch_id=97 Accuracy=83.06: 100%|██████████| 98/98 [00:27<00:00,  3.58it/s] 



Test set: Average loss: 0.6062, Accuracy: 8007/10000 (80.07%)

EPOCH: 9


Loss=0.4928434491157532 Batch_id=97 Accuracy=84.91: 100%|██████████| 98/98 [00:27<00:00,  3.52it/s] 



Test set: Average loss: 0.5615, Accuracy: 8202/10000 (82.02%)

EPOCH: 10


Loss=0.34550741314888 Batch_id=97 Accuracy=86.51: 100%|██████████| 98/98 [00:27<00:00,  3.59it/s]   



Test set: Average loss: 0.5557, Accuracy: 8253/10000 (82.53%)

EPOCH: 11


Loss=0.2929704189300537 Batch_id=97 Accuracy=87.82: 100%|██████████| 98/98 [00:26<00:00,  3.65it/s] 



Test set: Average loss: 0.5466, Accuracy: 8307/10000 (83.07%)

EPOCH: 12


Loss=0.40034154057502747 Batch_id=97 Accuracy=89.01: 100%|██████████| 98/98 [00:27<00:00,  3.62it/s]



Test set: Average loss: 0.5438, Accuracy: 8361/10000 (83.61%)

EPOCH: 13


Loss=0.2906898856163025 Batch_id=97 Accuracy=89.57: 100%|██████████| 98/98 [00:27<00:00,  3.54it/s] 



Test set: Average loss: 0.5418, Accuracy: 8366/10000 (83.66%)

EPOCH: 14


Loss=0.3396380841732025 Batch_id=97 Accuracy=90.94: 100%|██████████| 98/98 [00:27<00:00,  3.60it/s] 



Test set: Average loss: 0.5316, Accuracy: 8383/10000 (83.83%)

EPOCH: 15


Loss=0.21374957263469696 Batch_id=97 Accuracy=91.69: 100%|██████████| 98/98 [00:26<00:00,  3.65it/s]



Test set: Average loss: 0.5086, Accuracy: 8552/10000 (85.52%)

EPOCH: 16


Loss=0.2035236805677414 Batch_id=97 Accuracy=92.62: 100%|██████████| 98/98 [00:26<00:00,  3.65it/s] 



Test set: Average loss: 0.5352, Accuracy: 8512/10000 (85.12%)

EPOCH: 17


Loss=0.30982136726379395 Batch_id=97 Accuracy=92.69: 100%|██████████| 98/98 [00:27<00:00,  3.57it/s]



Test set: Average loss: 0.5392, Accuracy: 8524/10000 (85.24%)

EPOCH: 18


Loss=0.29852205514907837 Batch_id=97 Accuracy=93.54: 100%|██████████| 98/98 [00:27<00:00,  3.54it/s]



Test set: Average loss: 0.5132, Accuracy: 8624/10000 (86.24%)

EPOCH: 19


Loss=0.17050530016422272 Batch_id=97 Accuracy=93.73: 100%|██████████| 98/98 [00:26<00:00,  3.63it/s]



Test set: Average loss: 0.5380, Accuracy: 8553/10000 (85.53%)

EPOCH: 20


Loss=0.197427436709404 Batch_id=21 Accuracy=95.01:  22%|██▏       | 22/98 [00:06<00:20,  3.70it/s]  