In [8]:
import torch
from torch import nn, optim
import torchvision
from torchvision import transforms

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
class CNN(nn.Module):
    """Small LeNet-style CNN for 3-channel images.

    Two ``conv(5x5) -> BatchNorm -> ReLU -> MaxPool(2)`` stages extract
    features. The features are flattened per sample and passed to a
    three-layer MLP classifier with dropout.

    Args:
        flat_features: Number of features per sample after ``feature_extraction``,
            i.e. ``16 * H' * W'`` where ``H' = ((H - 4) // 2 - 4) // 2`` for an
            ``H x W`` input (same formula for ``W'``). Defaults to 512, the value
            this model originally hard-coded (e.g. a 28x44 input).
            NOTE(review): a 60x60 input, as produced by ``RandomCrop(60)`` in
            this notebook, yields 16 * 12 * 12 = 2304 features. Confirm the
            intended input size and pass ``flat_features`` accordingly.
        num_classes: Length of the output logit vector. Defaults to 10.
    """

    def __init__(self, flat_features: int = 512, num_classes: int = 10) -> None:
        super().__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(flat_features, 120),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(120, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Image batch of shape ``(batch, 3, H, W)``. BatchNorm2d requires
                a 4-D input.
        """
        features = self.feature_extraction(x)
        # Flatten every axis except the batch axis. A mismatched input size then
        # raises a shape error in the first Linear layer, instead of silently
        # mixing features from different samples (which ``view(-1, 512)`` does).
        features = torch.flatten(features, start_dim=1)
        return self.classifier(features)
    
# Instantiate the network. nn.Module.to() moves parameters and buffers in place
# and returns the same module, so rebinding the name is equivalent to chaining.
model = CNN()
model = model.to(device)


In [7]:
# Adam over all model parameters. In torch.optim.Adam, weight_decay is an L2
# term added to the gradient (coupled decay).
# NOTE(review): if decoupled weight decay was intended, optim.AdamW is the usual
# choice. Confirm before changing, since that alters training dynamics.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-2,
)

In [10]:
# Random augmentation pipeline (presumably for training): random 60x60 crop,
# mild colour jitter, and a rotation in [-10, 10] degrees, then conversion to a
# float tensor. ToPILImage runs first, so inputs are expected to be a tensor
# or ndarray rather than a PIL image.
transform = transforms.Compose(
    transforms=[
        transforms.ToPILImage(),
        transforms.RandomCrop(size=60),
        transforms.ColorJitter(
            brightness=0.1,
            contrast=0.1,
            saturation=0.1,
            hue=0.1,
        ),
        transforms.RandomRotation(
            degrees=10,
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.ToTensor(),
    ]
)

In [14]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class receives probability ``1 - smoothing``. Each of the other
    ``classes - 1`` classes receives ``smoothing / (classes - 1)``, so every
    target distribution sums to 1. The loss is the mean, over all non-class
    positions, of ``-sum(target_dist * log_softmax(pred))`` along ``dim``.

    Args:
        classes: Number of classes (size of ``pred`` along ``dim``); must be >= 2.
        smoothing: Probability mass moved off the true class. ``0.0`` recovers
            plain cross-entropy.
        dim: Class axis of ``pred``.

    Raises:
        ValueError: If ``classes < 2``, because the off-target share would
            divide by zero.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1) -> None:
        super().__init__()
        if classes < 2:
            raise ValueError(f"classes must be >= 2, got {classes}")
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Compute the label-smoothed loss.

        Args:
            pred: Unnormalized logits with the class axis at ``self.dim``.
            target: int64 class indices shaped like ``pred`` without the class
                axis, e.g. ``(N,)`` for ``(N, C)`` logits.

        Returns:
            Scalar tensor: the per-position loss averaged over all positions.
        """
        log_probs = pred.log_softmax(dim=self.dim)
        # The target distribution is a constant, so build it without autograd.
        with torch.no_grad():
            true_dist = torch.full_like(log_probs, self.smoothing / (self.cls - 1))
            # In-place scatter_ along the same class axis. It *overwrites* the
            # true-class entry so each distribution sums to exactly 1.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)

        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))

In [15]:
# Label-smoothing criterion sized for the 10 output logits of CNN's final
# Linear(64, 10), with a smoothing factor of 0.2.
ls = LabelSmoothingLoss(classes=10, smoothing=0.2)