<a href="https://colab.research.google.com/github/manishmcsa/Assigment-6/blob/main/model_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm

In [None]:
def train(model, device, train_loader, optimizer, epoch, l1_decay, l2_decay):
    model.train()
    pbar = tqdm(train_loader)
    correct = 0
    processed = 0
    l1_loss = None
    l2_loss = None
    for batch_idx, (data, target) in enumerate(pbar):
        # get samples
        data, target = data.to(device), target.to(device)

        # Init
        optimizer.zero_grad()
        # In PyTorch, we need to set the gradients to zero before starting
        # backpropagation, because PyTorch accumulates the gradients on
        # subsequent backward passes.
        # Because of this, at the start of each training iteration you
        # should zero out the gradients so that the parameter update is
        # computed correctly.

        # Predict
        y_pred = model(data)

        # Calculate loss
        loss = F.nll_loss(y_pred, target)
        if l1_decay > 0:
            l1_loss = 0
            for param in model.parameters():
                l1_loss += torch.norm(param, 1)
            loss += l1_decay * l1_loss
        if l2_decay > 0:
            l2_loss = 0
            for param in model.parameters():
                l2_loss += torch.norm(param, 2)
            loss += l2_decay * l2_loss

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Update pbar-tqdm

        pred = y_pred.argmax(dim=1,
                             keepdim=True)  # get the index of the max
        # log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar_str = f'Loss={loss.item()} Batch_id={batch_idx} Accuracy=' \
                   f'{100 * correct / processed:0.2f}'
        if l1_decay > 0:
            pbar_str = f'L1_loss={l1_loss.item()} %s' % pbar_str
        if l2_decay > 0:
            pbar_str = f'L2_loss={l2_loss.item()} %s' % pbar_str