In [4]:
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [5]:
import torch
import torch.nn as nn

In [6]:
def conv_block(n_channels):
  return nn.Sequential(
      nn.LazyBatchNorm2d(),
      nn.ReLU(),
      nn.LazyConv2d(n_channels, kernel_size = 3, padding = 1)
  )

In [7]:
class DenseBlock(nn.Module):
  def __init__(self, n_convs, n_channels):
    super(DenseBlock, self).__init__()
    layer = []
    for i in range(n_convs):
      layer.append(conv_block(n_channels))
    self.net = nn.Sequential(*layer)
  
  def forward(self, X):
    for blk in self.net:
      Y = blk(X)
      X = torch.cat((X, Y), dim = 1)
    return X

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [9]:
blk = DenseBlock(2, 10).to(device)
X = torch.randn(4, 3, 8, 8).to(device)
Y = blk(X)
Y.shape



torch.Size([4, 23, 8, 8])

## Transition layer

In [10]:
def transition_block(n_channels):
  return nn.Sequential(
      nn.LazyBatchNorm2d(),
      nn.ReLU(),
      nn.LazyConv2d(n_channels, kernel_size = 1),
      nn.AvgPool2d(kernel_size = 2, stride = 2)
  )

In [11]:
blk = transition_block(10).to(device)
blk(Y).shape

torch.Size([4, 10, 4, 4])

## DenseNet model

In [18]:
class DenseNet(nn.Module):
  def __init__(self, n_channels = 64, g_rate = 32, arch = (4, 4, 4, 4), lr = 0.1, n_classes = 10):
    super(DenseNet, self).__init__()
    self.net = nn.Sequential(self.b1())
    for i, n_convs in enumerate(arch):
      self.net.add_module(f'dense_blk{i + 1}', DenseBlock(n_convs, g_rate))
      n_channels += n_convs * g_rate
      if i != len(arch) - 1:
        n_channels //= 2
        self.net.add_module(f'tran_blk{i + 1}', transition_block(n_channels))
    
    self.net.add_module('Last', nn.Sequential(
        nn.LazyBatchNorm2d(),
        nn.ReLU(),
        nn.AdaptivePool2d((1, 1)),
        nn.Flatten(),
        nn.LazyLinear(n_classes)
    ))

  def b1(self):
    return nn.Sequential(
        nn.LazyConv2d(64, kernel_size = 7, stride = 2, padding = 3),
        nn.LazyBatchNorm2d(),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
    )

In [13]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchmetrics import Accuracy

In [14]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [15]:
train_data = FashionMNIST(root = './data', train = True, 
                          transform = ToTensor(), target_transform = None,
                          download = True)

test_data = FashionMNIST(root = './data', train = True,
                         transform = ToTensor(), target_transform = None,
                         download = True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw



In [16]:
torch.manual_seed(42)

train_loader = DataLoader(train_data, batch_size = 128, shuffle = True)
test_loader = DataLoader(test_data, batch_size = 128, shuffle = True)

In [19]:
model = DenseNet(lr = 0.01).to(device)

AttributeError: ignored