<a href="https://colab.research.google.com/github/maheravi/Deep-Learning/blob/main/PyTorch%20Persian%20Mnist%20TL/Persian_Mnist_Sweep.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wandb --upgrade

Collecting wandb
  Downloading wandb-0.12.9-py2.py3-none-any.whl (1.7 MB)
[?25l[K     |▏                               | 10 kB 24.5 MB/s eta 0:00:01[K     |▍                               | 20 kB 29.7 MB/s eta 0:00:01[K     |▋                               | 30 kB 33.5 MB/s eta 0:00:01[K     |▊                               | 40 kB 35.6 MB/s eta 0:00:01[K     |█                               | 51 kB 38.3 MB/s eta 0:00:01[K     |█▏                              | 61 kB 31.1 MB/s eta 0:00:01[K     |█▍                              | 71 kB 27.5 MB/s eta 0:00:01[K     |█▌                              | 81 kB 29.0 MB/s eta 0:00:01[K     |█▊                              | 92 kB 31.1 MB/s eta 0:00:01[K     |██                              | 102 kB 29.9 MB/s eta 0:00:01[K     |██                              | 112 kB 29.9 MB/s eta 0:00:01[K     |██▎                             | 122 kB 29.9 MB/s eta 0:00:01[K     |██▌                             | 133 kB 29.9 MB/s eta 

In [2]:
import wandb

wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader
import torchvision.models as models


In [4]:
sweep_config = {
    'method': 'random'
    }

In [5]:
metric = {
    'name': 'loss',
    'goal': 'minimize'   
    }

sweep_config['metric'] = metric

In [6]:
parameters_dict = {
    'optimizer': {
        'values': ['adam', 'sgd']
        },
    'fc_layer_size': {
        'values': [128, 256, 512]
        },
    'dropout': {
          'values': [0.3, 0.4, 0.5]
        },
    }

sweep_config['parameters'] = parameters_dict

In [7]:
parameters_dict.update({
    'epochs': {
        'value': 1}
    })

In [8]:
import math

parameters_dict.update({
    'learning_rate': {
        # a flat distribution between 0 and 0.1
        'distribution': 'uniform',
        'min': 0,
        'max': 0.1
      },
    'batch_size': {
        # integers between 32 and 256
        # with evenly-distributed logarithms 
        'distribution': 'q_log_uniform',
        'q': 1,
        'min': math.log(32),
        'max': math.log(256),
      }
    })

In [9]:
import pprint

pprint.pprint(sweep_config)

{'method': 'random',
 'metric': {'goal': 'minimize', 'name': 'loss'},
 'parameters': {'batch_size': {'distribution': 'q_log_uniform',
                               'max': 5.545177444479562,
                               'min': 3.4657359027997265,
                               'q': 1},
                'dropout': {'values': [0.3, 0.4, 0.5]},
                'epochs': {'value': 1},
                'fc_layer_size': {'values': [128, 256, 512]},
                'learning_rate': {'distribution': 'uniform',
                                  'max': 0.1,
                                  'min': 0},
                'optimizer': {'values': ['adam', 'sgd']}}}


In [24]:
sweep_id = wandb.sweep(sweep_config, project="pytorch-sweeps-PesianMnist")

Create sweep with ID: 1sj0kr8y
Sweep URL: https://wandb.ai/ma_heravi/pytorch-sweeps-PesianMnist/sweeps/1sj0kr8y


In [35]:
epochs = 10
loss_function = nn.CrossEntropyLoss()

def calc_acc(preds, labels):
    _, preds_max = torch.max(preds, 1)
    acc = torch.sum(preds_max == labels.data, dtype=torch.float64) / len(preds)
    return acc

In [36]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(config=None):
    # Initialize a new wandb run
    with wandb.init(config=config):
        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller
        config = wandb.config

        loader = build_dataset(config.batch_size)
        network = build_network(config.fc_layer_size, config.dropout)
        optimizer = build_optimizer(network, config.optimizer, config.learning_rate)

        for epoch in range(1, epochs+1):
            total_loss, total_acc = train_epoch(network, loader, optimizer)
            print(f"Epoch: {epoch+1}, Loss: {total_loss}")
            wandb.log({'epochs':  epoch + 1,
              'loss': total_loss,
              'acc': total_acc
                              })           

In [37]:
def build_dataset(batch_size):
   
    transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((28, 28)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])


    dataset = torchvision.datasets.ImageFolder("/content/drive/MyDrive/MNIST_persian", transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader


def build_network(fc_layer_size, dropout):
    network = models.resnet50(pretrained=True)
    in_fetures = network.fc.in_features
    network.fc = nn.Linear(in_fetures, 10)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    network = network.to(device)
    # This freezes layers 1-6 in the total 10 layers of Resnet50
    ct = 0
    for child in network.children():
        ct += 1
        if ct < 7:
            for param in child.parameters():
                param.requires_grad = False
                
    return network
        

def build_optimizer(network, optimizer, learning_rate):
    if optimizer == "sgd":
        optimizer = optim.SGD(network.parameters(),
                              lr=learning_rate, momentum=0.9)
    elif optimizer == "adam":
        optimizer = optim.Adam(network.parameters(),
                               lr=learning_rate)
    return optimizer


In [38]:
def train_epoch(network, loader, optimizer):

    train_loss = 0.0
    train_acc = 0.0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # 1- forwarding
        preds = network(images)
        # 2- backwarding 
        loss = loss_function(preds, labels)
        loss.backward()
        # 3- Update
        optimizer.step()

        train_loss += loss
        train_acc += calc_acc(preds, labels)
    
    total_loss = train_loss / len(loader)
    total_acc = train_acc / len(loader)
    print(f"loss_train:{total_loss},accuracy_train:{total_acc}")
  
    return total_loss, total_acc

In [39]:
wandb.agent(sweep_id, train, count=5)

[34m[1mwandb[0m: Agent Starting Run: j8jndt45 with config:
[34m[1mwandb[0m: 	batch_size: 102
[34m[1mwandb[0m: 	dropout: 0.4
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	fc_layer_size: 512
[34m[1mwandb[0m: 	learning_rate: 0.01847258268541586
[34m[1mwandb[0m: 	optimizer: adam


loss_train:3.591172218322754,accuracy_train:0.16327300150829563
Epoch: 2, Loss: 3.591172218322754
loss_train:1.7755763530731201,accuracy_train:0.47410759175465056
Epoch: 3, Loss: 1.7755763530731201
loss_train:0.6114603281021118,accuracy_train:0.7741955756661639
Epoch: 4, Loss: 0.6114603281021118
loss_train:0.2777659595012665,accuracy_train:0.896870286576169
Epoch: 5, Loss: 0.2777659595012665
loss_train:0.1447882354259491,accuracy_train:0.9504776269482151
Epoch: 6, Loss: 0.1447882354259491
loss_train:0.1050422415137291,accuracy_train:0.9678858722976369
Epoch: 7, Loss: 0.1050422415137291
loss_train:0.07511337101459503,accuracy_train:0.9744218200100554
Epoch: 8, Loss: 0.07511337101459503
loss_train:0.07653386145830154,accuracy_train:0.9775012569130217
Epoch: 9, Loss: 0.07653386145830154
loss_train:0.060092609375715256,accuracy_train:0.9780040221216693
Epoch: 10, Loss: 0.060092609375715256
loss_train:0.0756952092051506,accuracy_train:0.973667672197084
Epoch: 11, Loss: 0.0756952092051506


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▄▆▇██████
epochs,▁▂▃▃▄▅▆▆▇█
loss,█▄▂▁▁▁▁▁▁▁

0,1
acc,0.97367
epochs,11.0
loss,0.0757


[34m[1mwandb[0m: Agent Starting Run: 5b8n9r6d with config:
[34m[1mwandb[0m: 	batch_size: 73
[34m[1mwandb[0m: 	dropout: 0.4
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	fc_layer_size: 128
[34m[1mwandb[0m: 	learning_rate: 0.057846578822503825
[34m[1mwandb[0m: 	optimizer: adam


loss_train:4.497877597808838,accuracy_train:0.17770447219983881
Epoch: 2, Loss: 4.497877597808838
loss_train:1.685288667678833,accuracy_train:0.4410505640612409
Epoch: 3, Loss: 1.685288667678833
loss_train:0.9928544163703918,accuracy_train:0.6360545930701049
Epoch: 4, Loss: 0.9928544163703918
loss_train:0.4982936680316925,accuracy_train:0.8081184528605962
Epoch: 5, Loss: 0.4982936680316925
loss_train:0.27892425656318665,accuracy_train:0.8966307413376308
Epoch: 6, Loss: 0.27892425656318665
loss_train:0.21645592153072357,accuracy_train:0.9183873892022562
Epoch: 7, Loss: 0.21645592153072357
loss_train:0.18998916447162628,accuracy_train:0.938079170024174
Epoch: 8, Loss: 0.18998916447162628
loss_train:0.13645130395889282,accuracy_train:0.9582242143432715
Epoch: 9, Loss: 0.13645130395889282
loss_train:0.11285076290369034,accuracy_train:0.9661563255439161
Epoch: 10, Loss: 0.11285076290369034
loss_train:0.0896410271525383,accuracy_train:0.9731819097502014
Epoch: 11, Loss: 0.0896410271525383


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▃▅▇▇█████
epochs,▁▂▃▃▄▅▆▆▇█
loss,█▄▂▂▁▁▁▁▁▁

0,1
acc,0.97318
epochs,11.0
loss,0.08964


[34m[1mwandb[0m: Agent Starting Run: 242pt454 with config:
[34m[1mwandb[0m: 	batch_size: 162
[34m[1mwandb[0m: 	dropout: 0.4
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	fc_layer_size: 128
[34m[1mwandb[0m: 	learning_rate: 0.09864128685962284
[34m[1mwandb[0m: 	optimizer: sgd


loss_train:2.569227933883667,accuracy_train:0.3943602693602693
Epoch: 2, Loss: 2.569227933883667
loss_train:0.8794772028923035,accuracy_train:0.7189253647586981
Epoch: 3, Loss: 0.8794772028923035
loss_train:0.47122275829315186,accuracy_train:0.8394360269360269
Epoch: 4, Loss: 0.47122275829315186
loss_train:0.32574617862701416,accuracy_train:0.9036195286195285
Epoch: 5, Loss: 0.32574617862701416
loss_train:0.18385858833789825,accuracy_train:0.9421997755331089
Epoch: 6, Loss: 0.18385858833789825
loss_train:0.08291301131248474,accuracy_train:0.9769219977553311
Epoch: 7, Loss: 0.08291301131248474
loss_train:0.05020197480916977,accuracy_train:0.9830246913580245
Epoch: 8, Loss: 0.05020197480916977
loss_train:0.09632593393325806,accuracy_train:0.9762906846240179
Epoch: 9, Loss: 0.09632593393325806
loss_train:0.06912099570035934,accuracy_train:0.9814814814814815
Epoch: 10, Loss: 0.06912099570035934
loss_train:0.08081506937742233,accuracy_train:0.9762906846240179
Epoch: 11, Loss: 0.080815069377

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▅▆▇██████
epochs,▁▂▃▃▄▅▆▆▇█
loss,█▃▂▂▁▁▁▁▁▁

0,1
acc,0.97629
epochs,11.0
loss,0.08082


[34m[1mwandb[0m: Agent Starting Run: y6kvhirv with config:
[34m[1mwandb[0m: 	batch_size: 231
[34m[1mwandb[0m: 	dropout: 0.5
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	fc_layer_size: 256
[34m[1mwandb[0m: 	learning_rate: 0.035274123437077476
[34m[1mwandb[0m: 	optimizer: sgd


loss_train:1.78464937210083,accuracy_train:0.4157287157287157
Epoch: 2, Loss: 1.78464937210083
loss_train:0.3159542977809906,accuracy_train:0.896007696007696
Epoch: 3, Loss: 0.3159542977809906
loss_train:0.2256239652633667,accuracy_train:0.939009139009139
Epoch: 4, Loss: 0.2256239652633667
loss_train:0.38771745562553406,accuracy_train:0.9392015392015393
Epoch: 5, Loss: 0.38771745562553406
loss_train:0.41507193446159363,accuracy_train:0.9107744107744107
Epoch: 6, Loss: 0.41507193446159363
loss_train:0.20149096846580505,accuracy_train:0.9360269360269361
Epoch: 7, Loss: 0.20149096846580505
loss_train:0.15587183833122253,accuracy_train:0.9499278499278498
Epoch: 8, Loss: 0.15587183833122253
loss_train:0.10358169674873352,accuracy_train:0.9673400673400674
Epoch: 9, Loss: 0.10358169674873352
loss_train:0.07414284348487854,accuracy_train:0.9817700817700816
Epoch: 10, Loss: 0.07414284348487854
loss_train:0.05978706479072571,accuracy_train:0.975998075998076
Epoch: 11, Loss: 0.05978706479072571


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▇▇▇▇▇████
epochs,▁▂▃▃▄▅▆▆▇█
loss,█▂▂▂▂▂▁▁▁▁

0,1
acc,0.976
epochs,11.0
loss,0.05979


[34m[1mwandb[0m: Agent Starting Run: o2ecpjwj with config:
[34m[1mwandb[0m: 	batch_size: 45
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	fc_layer_size: 256
[34m[1mwandb[0m: 	learning_rate: 0.0651755112269236
[34m[1mwandb[0m: 	optimizer: sgd


loss_train:1.7072601318359375,accuracy_train:0.5563786008230451
Epoch: 2, Loss: 1.7072601318359375
loss_train:2.7587695121765137,accuracy_train:0.6831275720164608
Epoch: 3, Loss: 2.7587695121765137
loss_train:0.626624345779419,accuracy_train:0.7860082304526752
Epoch: 4, Loss: 0.626624345779419
loss_train:0.3697603642940521,accuracy_train:0.8748971193415639
Epoch: 5, Loss: 0.3697603642940521
loss_train:1.0771573781967163,accuracy_train:0.9069958847736627
Epoch: 6, Loss: 1.0771573781967163
loss_train:0.16874054074287415,accuracy_train:0.9432098765432099
Epoch: 7, Loss: 0.16874054074287415
loss_train:0.1087624579668045,accuracy_train:0.9580246913580248
Epoch: 8, Loss: 0.1087624579668045
loss_train:0.8997863531112671,accuracy_train:0.9016460905349797
Epoch: 9, Loss: 0.8997863531112671
loss_train:1.390612244606018,accuracy_train:0.8831275720164612
Epoch: 10, Loss: 1.390612244606018
loss_train:0.27761682868003845,accuracy_train:0.9193415637860083
Epoch: 11, Loss: 0.27761682868003845


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
acc,▁▃▅▇▇██▇▇▇
epochs,▁▂▃▃▄▅▆▆▇█
loss,▅█▂▂▄▁▁▃▄▁

0,1
acc,0.91934
epochs,11.0
loss,0.27762


In [None]:
torch.save(models.state_dict(), "PersianMnistSweepTL.pth")