<a href="https://colab.research.google.com/github/mohsenahmadi2003/cnn_ai/blob/main/Session4_dataloder_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **`Import`**

In [1]:
import copy
import time

import torch
import torch.nn as nn
import torchvision

# **`Initialization`**

In [2]:
# Settings shared by the cells below.
batch_size = 256  # samples per mini-batch
num_class = 10    # number of output classes of the classifier

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# **`Dataset`**

In [None]:
# One transform for both splits: converts each image to a float tensor scaled to [0, 1].
transform = torchvision.transforms.ToTensor()

# Both splits share one root directory: torchvision's MNIST downloads the train
# AND test files on every download, so separate roots fetched the dataset twice.
data_root = "./mnist"
train_dataset = torchvision.datasets.MNIST(data_root, train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(data_root, train=False, transform=transform, download=True)


# **`Data Loader`**

In [4]:
# Training: reshuffle every epoch and drop the final partial batch so every
# optimizer step sees exactly `batch_size` samples.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Evaluation: deterministic order (shuffling test data has no benefit) and no
# dropped samples, so metrics cover the whole test split.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Draw a single mini-batch from a fresh iterator to sanity-check tensor shapes
# and labels; the two values are printed on separate lines.
one_train_batch_imgs, one_train_batch_lbls = next(iter(train_loader))
print(one_train_batch_imgs.shape, one_train_batch_lbls, sep="\n")

In [None]:
# One full pass over the training loader, printing each batch's index and image-tensor shape.
for index, batch in enumerate(train_loader):
    images, labels = batch
    print(index, images.shape)

# **`Model`**

In [None]:
# The CNN written as a single nn.Sequential.
# Shapes for a 1x28x28 input: conv 3x3 -> 26x26, max-pool(3, stride 2) -> 12x12,
# conv 3x3 -> 10x10, max-pool(3, stride 2) -> 4x4, so the flattened feature
# vector has 64*4*4 = 1024 entries (not 64*7*7).
sequential_model = nn.Sequential(
    nn.Conv2d(1, 32, 3),
    nn.ReLU(),
    nn.MaxPool2d(3, 2),
    nn.Conv2d(32, 64, 3),
    nn.ReLU(),
    nn.MaxPool2d(3, 2),
    nn.Flatten(),            # Linear needs (N, features), not (N, C, H, W)
    nn.Linear(64 * 4 * 4, 1024),
    nn.ReLU(),               # without it the two Linear layers collapse into one affine map
    nn.Linear(1024, 10),     # 10 outputs, matching num_class
)
sequential_model

In [8]:
class convnet(nn.Module):
    """Two conv stages (conv -> ReLU -> max-pool) followed by two linear layers.

    Args:
        num_class: number of output logits produced by ``forward``.
        input_size: height and width of the square, single-channel input
            images (default 28). Used only to size ``fc1`` so that it matches
            the flattened conv feature map.
    """

    def __init__(self, num_class, input_size=28):
        super().__init__()
        # Layer 1
        self.conv2d_1 = nn.Conv2d(1, 32, (3, 3))
        self.relu_1 = nn.ReLU()
        self.maxpool_1 = nn.MaxPool2d(3, 2)
        # Layer 2
        self.conv2d_2 = nn.Conv2d(32, 64, (3, 3))
        self.relu_2 = nn.ReLU()
        self.maxpool_2 = nn.MaxPool2d(3, 2)
        # Layer 3
        # For 28x28 inputs the conv/pool stack yields 64x4x4 = 1024 features
        # (28 -> 26 -> 12 -> 10 -> 4), not 64*7*7, so a hard-coded 3136 made
        # forward() fail. Measure the size with a dummy pass instead.
        with torch.no_grad():
            dummy = torch.zeros(1, 1, input_size, input_size)
            num_features = self._features(dummy).shape[1]
        self.fc1 = nn.Linear(num_features, 1024)
        # Non-linearity between the linear layers; without it fc2(fc1(x)) is
        # just a single affine map.
        self.relu_3 = nn.ReLU()
        self.fc2 = nn.Linear(1024, num_class)

    def _features(self, x):
        """Apply conv stages 1 and 2, then flatten to (N, C*H*W)."""
        # Layer 1
        y = self.maxpool_1(self.relu_1(self.conv2d_1(x)))
        # Layer 2
        y = self.maxpool_2(self.relu_2(self.conv2d_2(y)))
        return torch.flatten(y, 1)

    def forward(self, x):
        """Return raw class logits of shape (N, num_class) for an (N, 1, H, W) batch."""
        y = self._features(x)
        # Layer 3
        y = self.fc1(y)
        y = self.relu_3(y)
        y = self.fc2(y)
        return y

In [None]:
# Build the network with `num_class` outputs, move it to `device`, and show its layers.
model = convnet(num_class)
model = model.to(device)  # Module.to returns the same module
print(model)

In [None]:
# Display the first conv layer's weight Parameter; with nn.Conv2d(1, 32, (3,3))
# its shape is (out_channels=32, in_channels=1, kernel_h=3, kernel_w=3).
model.conv2d_1.weight

# **Config**

In [None]:
# Cross-entropy loss, applied to the model's raw (un-normalized) output scores.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over all of the model's parameters.
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

# **Trainer**

In [None]:
def _run_phase(model, loader, criterion, optimizer, device, train):
    """Run one pass over ``loader``; return (mean loss, accuracy) over the samples seen.

    Gradients and optimizer steps happen only when ``train`` is True. Metrics
    are divided by the number of samples actually yielded, which differs from
    ``len(loader.dataset)`` when the loader uses ``drop_last=True``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train(train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_corrects = torch.zeros((), dtype=torch.long, device=device)
    num_samples = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if train:
            optimizer.zero_grad()

        # Track autograd history only while training.
        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

            if train:
                loss.backward()
                optimizer.step()

        batch_len = inputs.size(0)
        # criterion is assumed to return a per-batch mean (CrossEntropyLoss default).
        running_loss += loss.item() * batch_len
        running_corrects += torch.sum(preds == labels)
        num_samples += batch_len

    if num_samples == 0:
        raise ValueError('dataloader yielded no samples')

    return running_loss / num_samples, running_corrects.double() / num_samples


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Each epoch runs a ``'train'`` phase (gradients and optimizer steps) and
    then a ``'val'`` phase (no gradients). When training finishes, the model is
    loaded with the weights from the epoch with the highest validation accuracy.

    Args:
        model: the ``nn.Module`` to train. Batches are moved to the device of
            its first parameter.
        dataloaders: mapping with ``'train'`` and ``'val'`` keys, each an
            iterable of ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of train/val epochs.
        is_inception: accepted for backward compatibility and ignored; no
            auxiliary-output handling is implemented.

    Returns:
        ``(model, val_acc_history)``, where ``val_acc_history`` holds one
        0-dim tensor per epoch with that epoch's validation accuracy.

    Raises:
        ValueError: if a phase's loader yields no samples.
    """
    since = time.time()
    device = next(model.parameters()).device

    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_phase(
                model, dataloaders[phase], criterion, optimizer, device,
                train=(phase == 'train'),
            )
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                # Snapshot the weights whenever validation accuracy improves.
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [None]:
# Train this notebook's model. No separate validation split is created, so the
# MNIST test split is used as the 'val' phase.
num_epochs = 25  # same as train_model's default
dataloaders_dict = {'train': train_loader, 'val': test_loader}
model_ft, hist = train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

In [None]:
# Plot validation accuracy against the training epoch for the model trained above.
import matplotlib.pyplot as plt

# float() handles both Python floats and 0-dim tensors (including CUDA ones).
val_acc = [float(acc) for acc in hist]
epochs = range(1, len(val_acc) + 1)

plt.title("Validation Accuracy vs. Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Accuracy")
plt.plot(epochs, val_acc, label="Scratch")
plt.ylim((0, 1.))
plt.xticks(epochs)
plt.legend()
plt.show()