In [1]:
# Reload edited modules (e.g. datasets/, utils/, models/) automatically before
# each cell executes, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Log metrics to Weights & Biases. On by default; set the environment variable
# USE_WANDB to "0", "false", "no" or "off" to disable it (e.g. for headless or
# CI runs) without editing the notebook. Every W&B cell below is guarded on this.
use_wandb = os.environ.get("USE_WANDB", "1").strip().lower() not in {"0", "false", "no", "off"}

In [3]:
from ds_toolkit.general_utils.gpu_utils import gpu_alloc

# Pick a GPU for this session. Per the cell output, gpu_alloc() also sets the
# process title and sets the CUDA device in the environment.
# NOTE(review): presumably this writes CUDA_VISIBLE_DEVICES, in which case it must
# run before torch initialises CUDA (torch is imported in a later cell) — confirm.
gpu_alloc()

Setting proc title..
Process title :  MUHSIN_16-Nov_01-AM_NO_JOB_NAME
Searching for GPUs..
CUDA environment device set to 1
GPU allocation finished.


In [4]:
# Authenticate with Weights & Biases only when logging is enabled. The `wandb`
# module is imported only inside this branch, so every later W&B call must also
# be guarded on `use_wandb`.
if use_wandb:
    import wandb
    wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33maskmuhsin[0m (use `wandb login --relogin` to force relogin)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary

In [6]:
from datasets.cifar10 import trainloader, testloader
from utils.general_utils import setup_env
from utils.training import train_model
from utils.testing import test_model

Files already downloaded and verified
Total number of images:  50000
Mean =  (0.4890062, 0.47970363, 0.47680542) 
 STD =  (0.264582, 0.258996, 0.25643882)
Files already downloaded and verified
Files already downloaded and verified
Batch Size -- 512


In [7]:
from models.model_v7 import Net

In [8]:
# Resolve the compute device and build the model on it.
# NOTE(review): the first value returned by setup_env() (`cuda`) is not used
# anywhere in this notebook.
cuda, device = setup_env()

train_dataloader, test_dataloader = trainloader, testloader

model = Net(
    batch_norm=True,
    dropout_value=0.0,
).to(device)

# torchsummary builds its dummy input for device="cuda" by default, which breaks
# (tensor/weight device mismatch) when the model was placed on the CPU while CUDA
# is still available. Pass the actual device type ("cuda" or "cpu") instead;
# torch.device() accepts either a string or an existing torch.device.
summary(model, input_size=(3, 32, 32), device=torch.device(device).type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             864
              ReLU-2           [-1, 32, 32, 32]               0
       BatchNorm2d-3           [-1, 32, 32, 32]              64
           Dropout-4           [-1, 32, 32, 32]               0
            Conv2d-5           [-1, 64, 32, 32]          18,432
              ReLU-6           [-1, 64, 32, 32]               0
       BatchNorm2d-7           [-1, 64, 32, 32]             128
           Dropout-8           [-1, 64, 32, 32]               0
        conv_block-9           [-1, 64, 32, 32]               0
           Conv2d-10           [-1, 32, 16, 16]          18,464
           Conv2d-11           [-1, 32, 16, 16]           1,056
             ReLU-12           [-1, 32, 16, 16]               0
      BatchNorm2d-13           [-1, 32, 16, 16]              64
 transition_block-14           [-1, 32,

In [12]:
if use_wandb:
    # Open a W&B run for this training session. The run metadata is collected
    # in one mapping so it reads as a single table.
    wandb.init(**{
        "project": 's7_cifar10',             # unique to each project / assignment
        "entity": 'weights_heist_eva7',      # shared team entity; does not change
        # The fields below are optional but recommended.
        "tags": ['muhsin'],                  # lets runs be filtered later
        "name": 'version 7.2 continue',      # display name for the run
        "notes": "continue training",
    })

[34m[1mwandb[0m: wandb version 0.12.6 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [13]:
# Train `model` for EPOCHS epochs with SGD + StepLR and record one log entry per
# epoch (also sent to W&B when enabled).
# NOTE(review): the run is named "continue training", and the epoch-0 train
# accuracy in the output below (~77%) shows `model` had already been trained by
# cells that are no longer in this notebook. A fresh Restart & Run All will start
# from the randomly initialised weights built earlier instead.
LR, MOMENTUM = 0.1, 0.9
LR_STEP_SIZE, LR_GAMMA = 12, 0.6   # scheduler.step() runs once per epoch: LR *= 0.6 every 12 epochs
EPOCHS = 40

optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
scheduler = StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

logs = []
for epoch in range(EPOCHS):
    print("EPOCH:", epoch)
    # Read the LR *before* scheduler.step(): this is the rate this epoch actually
    # trains with. Reading it after the step logged the next epoch's LR (off by one).
    epoch_lr = optimizer.param_groups[0]['lr']

    train_batch_loss, train_batch_acc = train_model(
        model, device, train_dataloader, optimizer, epoch
    )
    test_loss, test_acc = test_model(model, device, test_dataloader)
    scheduler.step()

    # Training outputs are averaged with np.mean, so they are presumably
    # per-batch sequences — confirm against utils.training.train_model.
    epoch_log = {
        'train_losses': np.mean(train_batch_loss),
        'test_losses': test_loss,
        'train_acc': np.mean(train_batch_acc),
        'test_acc': test_acc,
        "lr": epoch_lr,
    }
    logs.append(epoch_log)

    if use_wandb:
        wandb.log(epoch_log)

EPOCH: 0


Loss=0.7177518010139465 Batch_id=97 Accuracy=77.45: 100%|██████████| 98/98 [00:09<00:00, 10.85it/s]



Test set: Average loss: 0.5912, Accuracy: 7989/10000 (79.89%)

EPOCH: 1


Loss=0.7636212110519409 Batch_id=97 Accuracy=77.57: 100%|██████████| 98/98 [00:09<00:00, 10.75it/s]



Test set: Average loss: 0.5593, Accuracy: 8106/10000 (81.06%)

EPOCH: 2


Loss=0.6215888857841492 Batch_id=97 Accuracy=77.64: 100%|██████████| 98/98 [00:09<00:00, 10.71it/s]



Test set: Average loss: 0.5854, Accuracy: 8045/10000 (80.45%)

EPOCH: 3


Loss=0.6967869400978088 Batch_id=97 Accuracy=77.74: 100%|██████████| 98/98 [00:09<00:00, 10.25it/s]



Test set: Average loss: 0.5490, Accuracy: 8118/10000 (81.18%)

EPOCH: 4


Loss=0.5811810493469238 Batch_id=97 Accuracy=77.89: 100%|██████████| 98/98 [00:09<00:00, 10.20it/s]



Test set: Average loss: 0.5350, Accuracy: 8169/10000 (81.69%)

EPOCH: 5


Loss=0.5997220277786255 Batch_id=97 Accuracy=78.69: 100%|██████████| 98/98 [00:08<00:00, 11.15it/s]



Test set: Average loss: 0.5472, Accuracy: 8156/10000 (81.56%)

EPOCH: 6


Loss=0.6460701823234558 Batch_id=97 Accuracy=78.67: 100%|██████████| 98/98 [00:09<00:00, 10.22it/s]



Test set: Average loss: 0.5211, Accuracy: 8254/10000 (82.54%)

EPOCH: 7


Loss=0.6910363435745239 Batch_id=97 Accuracy=78.83: 100%|██████████| 98/98 [00:09<00:00, 10.83it/s]



Test set: Average loss: 0.5505, Accuracy: 8180/10000 (81.80%)

EPOCH: 8


Loss=0.5742940306663513 Batch_id=97 Accuracy=79.35: 100%|██████████| 98/98 [00:08<00:00, 10.98it/s] 



Test set: Average loss: 0.5136, Accuracy: 8276/10000 (82.76%)

EPOCH: 9


Loss=0.6108983755111694 Batch_id=97 Accuracy=79.69: 100%|██████████| 98/98 [00:08<00:00, 11.14it/s] 



Test set: Average loss: 0.5231, Accuracy: 8254/10000 (82.54%)

EPOCH: 10


Loss=0.5643239617347717 Batch_id=97 Accuracy=79.82: 100%|██████████| 98/98 [00:09<00:00, 10.73it/s] 



Test set: Average loss: 0.5085, Accuracy: 8297/10000 (82.97%)

EPOCH: 11


Loss=0.5313037037849426 Batch_id=97 Accuracy=80.06: 100%|██████████| 98/98 [00:09<00:00, 10.63it/s]



Test set: Average loss: 0.5393, Accuracy: 8189/10000 (81.89%)

EPOCH: 12


Loss=0.4850599765777588 Batch_id=97 Accuracy=81.22: 100%|██████████| 98/98 [00:08<00:00, 11.45it/s] 



Test set: Average loss: 0.4672, Accuracy: 8431/10000 (84.31%)

EPOCH: 13


Loss=0.49570515751838684 Batch_id=97 Accuracy=81.75: 100%|██████████| 98/98 [00:08<00:00, 11.11it/s]



Test set: Average loss: 0.4741, Accuracy: 8369/10000 (83.69%)

EPOCH: 14


Loss=0.6148566007614136 Batch_id=97 Accuracy=81.68: 100%|██████████| 98/98 [00:08<00:00, 10.97it/s] 



Test set: Average loss: 0.4642, Accuracy: 8415/10000 (84.15%)

EPOCH: 15


Loss=0.5777533650398254 Batch_id=97 Accuracy=82.03: 100%|██████████| 98/98 [00:09<00:00, 10.88it/s] 



Test set: Average loss: 0.4783, Accuracy: 8401/10000 (84.01%)

EPOCH: 16


Loss=0.5223316550254822 Batch_id=97 Accuracy=82.20: 100%|██████████| 98/98 [00:08<00:00, 11.08it/s] 



Test set: Average loss: 0.4655, Accuracy: 8472/10000 (84.72%)

EPOCH: 17


Loss=0.5484999418258667 Batch_id=97 Accuracy=82.27: 100%|██████████| 98/98 [00:08<00:00, 11.30it/s] 



Test set: Average loss: 0.4612, Accuracy: 8470/10000 (84.70%)

EPOCH: 18


Loss=0.4879567325115204 Batch_id=97 Accuracy=82.19: 100%|██████████| 98/98 [00:08<00:00, 10.94it/s] 



Test set: Average loss: 0.4763, Accuracy: 8389/10000 (83.89%)

EPOCH: 19


Loss=0.4811389446258545 Batch_id=97 Accuracy=82.31: 100%|██████████| 98/98 [00:09<00:00, 10.12it/s] 



Test set: Average loss: 0.4766, Accuracy: 8454/10000 (84.54%)

EPOCH: 20


Loss=0.5073748826980591 Batch_id=97 Accuracy=82.25: 100%|██████████| 98/98 [00:09<00:00,  9.82it/s] 



Test set: Average loss: 0.4746, Accuracy: 8417/10000 (84.17%)

EPOCH: 21


Loss=0.4610551595687866 Batch_id=97 Accuracy=82.71: 100%|██████████| 98/98 [00:09<00:00, 10.32it/s] 



Test set: Average loss: 0.4543, Accuracy: 8482/10000 (84.82%)

EPOCH: 22


Loss=0.5047925114631653 Batch_id=97 Accuracy=82.66: 100%|██████████| 98/98 [00:08<00:00, 11.33it/s] 



Test set: Average loss: 0.4512, Accuracy: 8503/10000 (85.03%)

EPOCH: 23


Loss=0.5287055969238281 Batch_id=97 Accuracy=82.72: 100%|██████████| 98/98 [00:09<00:00, 10.64it/s] 



Test set: Average loss: 0.4740, Accuracy: 8418/10000 (84.18%)

EPOCH: 24


Loss=0.4288865923881531 Batch_id=97 Accuracy=83.77: 100%|██████████| 98/98 [00:09<00:00, 10.73it/s] 



Test set: Average loss: 0.4442, Accuracy: 8504/10000 (85.04%)

EPOCH: 25


Loss=0.5110190510749817 Batch_id=97 Accuracy=83.58: 100%|██████████| 98/98 [00:09<00:00, 10.75it/s] 



Test set: Average loss: 0.4410, Accuracy: 8530/10000 (85.30%)

EPOCH: 26


Loss=0.48107433319091797 Batch_id=97 Accuracy=83.74: 100%|██████████| 98/98 [00:08<00:00, 11.32it/s]



Test set: Average loss: 0.4367, Accuracy: 8576/10000 (85.76%)

EPOCH: 27


Loss=0.41568878293037415 Batch_id=97 Accuracy=84.09: 100%|██████████| 98/98 [00:08<00:00, 11.21it/s]



Test set: Average loss: 0.4349, Accuracy: 8535/10000 (85.35%)

EPOCH: 28


Loss=0.4198441207408905 Batch_id=97 Accuracy=84.01: 100%|██████████| 98/98 [00:08<00:00, 11.03it/s] 



Test set: Average loss: 0.4422, Accuracy: 8552/10000 (85.52%)

EPOCH: 29


Loss=0.44396382570266724 Batch_id=97 Accuracy=84.15: 100%|██████████| 98/98 [00:09<00:00, 10.77it/s]



Test set: Average loss: 0.4392, Accuracy: 8552/10000 (85.52%)

EPOCH: 30


Loss=0.4845697283744812 Batch_id=97 Accuracy=84.15: 100%|██████████| 98/98 [00:08<00:00, 11.08it/s] 



Test set: Average loss: 0.4436, Accuracy: 8509/10000 (85.09%)

EPOCH: 31


Loss=0.45054784417152405 Batch_id=97 Accuracy=84.25: 100%|██████████| 98/98 [00:08<00:00, 11.59it/s]



Test set: Average loss: 0.4346, Accuracy: 8564/10000 (85.64%)

EPOCH: 32


Loss=0.45164790749549866 Batch_id=97 Accuracy=84.49: 100%|██████████| 98/98 [00:09<00:00, 10.86it/s]



Test set: Average loss: 0.4364, Accuracy: 8532/10000 (85.32%)

EPOCH: 33


Loss=0.4942794144153595 Batch_id=97 Accuracy=84.22: 100%|██████████| 98/98 [00:09<00:00, 10.78it/s] 



Test set: Average loss: 0.4367, Accuracy: 8588/10000 (85.88%)

EPOCH: 34


Loss=0.40555673837661743 Batch_id=97 Accuracy=84.43: 100%|██████████| 98/98 [00:08<00:00, 10.96it/s]



Test set: Average loss: 0.4356, Accuracy: 8564/10000 (85.64%)

EPOCH: 35


Loss=0.5164335370063782 Batch_id=97 Accuracy=84.37: 100%|██████████| 98/98 [00:08<00:00, 11.16it/s] 



Test set: Average loss: 0.4301, Accuracy: 8566/10000 (85.66%)

EPOCH: 36


Loss=0.4619891345500946 Batch_id=97 Accuracy=84.98: 100%|██████████| 98/98 [00:09<00:00, 10.78it/s] 



Test set: Average loss: 0.4168, Accuracy: 8596/10000 (85.96%)

EPOCH: 37


Loss=0.3569701015949249 Batch_id=97 Accuracy=85.09: 100%|██████████| 98/98 [00:09<00:00, 10.60it/s] 



Test set: Average loss: 0.4240, Accuracy: 8615/10000 (86.15%)

EPOCH: 38


Loss=0.39051949977874756 Batch_id=97 Accuracy=85.00: 100%|██████████| 98/98 [00:08<00:00, 11.51it/s]



Test set: Average loss: 0.4218, Accuracy: 8613/10000 (86.13%)

EPOCH: 39


Loss=0.4365992844104767 Batch_id=97 Accuracy=85.42: 100%|██████████| 98/98 [00:08<00:00, 11.39it/s] 



Test set: Average loss: 0.4186, Accuracy: 8623/10000 (86.23%)



In [11]:
# Close the W&B run so the remaining logged data is uploaded.
# NOTE(review): this cell's execution count (11) is lower than the wandb.init (12)
# and training (13) cells. Its summary below (test_acc 83.52) also does not match
# the last epoch logged above (86.23%), so the output is from an earlier run.
# Re-run this cell after training.
if use_wandb:
    wandb.finish()

VBox(children=(Label(value=' 0.01MB of 0.01MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_losses,0.53289
test_losses,0.485
train_acc,81.18632
test_acc,83.52
lr,0.0216
_step,39.0
_runtime,415.0
_timestamp,1636996799.0


0,1
train_losses,█▆▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_losses,██▇▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc,▁▃▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████
test_acc,▁▂▃▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇████████████████
lr,███████████▄▄▄▄▄▄▄▄▄▄▄▄▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
