In [1]:
print("begin")

begin


In [5]:
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers

ModuleNotFoundError: No module named 'ignite'

In [None]:
DATA_DIR = "./data_cn"

IMAGE_SIZE = 32

NUM_CLASSES = 10
NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 100

LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-1

In [3]:
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)

NameError: name 'torch' is not defined

In [6]:
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(IMAGE_SIZE, padding=4),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
    ]
)

In [7]:
train_dset = datasets.CIFAR10(
    root=DATA_DIR, train=True, download=True, transform=train_transform
)
test_dset = datasets.CIFAR10(
    root=DATA_DIR, train=False, download=True, transform=transforms.ToTensor()
)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
train_loader = torch.utils.data.DataLoader(
    train_dset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

test_loader = torch.utils.data.DataLoader(
    test_dset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [9]:
class LayerNormChannels(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        x = x.transpose(1, -1)
        x = self.norm(x)
        x = x.transpose(-1, 1)
        return x

In [10]:
class Residual(nn.Module):
    def __init__(self, *layers):
        super().__init__()
        self.residual = nn.Sequential(*layers)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return x + self.gamma * self.residual(x)

In [11]:
class ConvNeXtBlock(Residual):
    def __init__(self, channels, kernel_size, mult=4, p_drop=0.0):
        padding = (kernel_size - 1) // 2
        hidden_channels = channels * mult
        super().__init__(
            nn.Conv2d(
                channels, channels, kernel_size, padding=padding, groups=channels
            ),
            LayerNormChannels(channels),
            nn.Conv2d(channels, hidden_channels, 1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, channels, 1),
            nn.Dropout(p_drop),
        )

In [12]:
class DownsampleBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__(
            LayerNormChannels(in_channels),
            nn.Conv2d(in_channels, out_channels, stride, stride=stride),
        )

In [13]:
class Stage(nn.Sequential):
    def __init__(self, in_channels, out_channels, num_blocks, kernel_size, p_drop=0.0):
        layers = (
            []
            if in_channels == out_channels
            else [DownsampleBlock(in_channels, out_channels)]
        )
        layers += [
            ConvNeXtBlock(out_channels, kernel_size, p_drop=p_drop)
            for _ in range(num_blocks)
        ]
        super().__init__(*layers)

In [14]:
class ConvNeXtBody(nn.Sequential):
    def __init__(
        self, in_channels, channel_list, num_blocks_list, kernel_size, p_drop=0.0
    ):
        layers = []
        for out_channels, num_blocks in zip(channel_list, num_blocks_list):
            layers.append(
                Stage(in_channels, out_channels, num_blocks, kernel_size, p_drop)
            )
            in_channels = out_channels
        super().__init__(*layers)

In [15]:
class Stem(nn.Sequential):
    def __init__(self, in_channels, out_channels, patch_size):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, patch_size, stride=patch_size),
            LayerNormChannels(out_channels),
        )

In [16]:
class Head(nn.Sequential):
    def __init__(self, in_channels, classes):
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, classes),
        )

In [17]:
class ConvNeXt(nn.Sequential):
    def __init__(
        self,
        classes,
        channel_list,
        num_blocks_list,
        kernel_size,
        patch_size,
        in_channels=3,
        res_p_drop=0.0,
    ):
        super().__init__(
            Stem(in_channels, channel_list[0], patch_size),
            ConvNeXtBody(
                channel_list[0], channel_list, num_blocks_list, kernel_size, res_p_drop
            ),
            Head(channel_list[-1], classes),
        )
        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1.0)
                nn.init.zeros_(m.bias)
            elif isinstance(m, Residual):
                nn.init.zeros_(m.gamma)

    def separate_parameters(self):
        parameters_decay = set()
        parameters_no_decay = set()
        modules_weight_decay = (nn.Linear, nn.Conv2d)
        modules_no_weight_decay = (nn.LayerNorm,)

        for m_name, m in self.named_modules():
            for param_name, param in m.named_parameters():
                full_param_name = f"{m_name}.{param_name}" if m_name else param_name

                if isinstance(m, modules_no_weight_decay):
                    parameters_no_decay.add(full_param_name)
                elif param_name.endswith("bias"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, Residual) and param_name.endswith("gamma"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, modules_weight_decay):
                    parameters_decay.add(full_param_name)

        # sanity check
        assert len(parameters_decay & parameters_no_decay) == 0
        assert len(parameters_decay) + len(parameters_no_decay) == len(
            list(model.parameters())
        )

        return parameters_decay, parameters_no_decay

In [18]:
model = ConvNeXt(
    NUM_CLASSES,
    channel_list=[64, 128, 256, 512],
    num_blocks_list=[2, 2, 2, 2],
    kernel_size=7,
    patch_size=1,
    res_p_drop=0.0,
)

In [19]:
model.to(DEVICE)

ConvNeXt(
  (0): Stem(
    (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): LayerNormChannels(
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
  )
  (1): ConvNeXtBody(
    (0): Stage(
      (0): ConvNeXtBlock(
        (residual): Sequential(
          (0): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64)
          (1): LayerNormChannels(
            (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          )
          (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          (3): GELU(approximate='none')
          (4): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          (5): Dropout(p=0.0, inplace=False)
        )
      )
      (1): ConvNeXtBlock(
        (residual): Sequential(
          (0): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64)
          (1): LayerNormChannels(
            (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          )
    

In [20]:
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))

Number of parameters: 6,376,466


In [21]:
def get_optimizer(model, learning_rate, weight_decay):
    param_dict = {pn: p for pn, p in model.named_parameters()}
    parameters_decay, parameters_no_decay = model.separate_parameters()

    optim_groups = [
        {
            "params": [param_dict[pn] for pn in parameters_decay],
            "weight_decay": weight_decay,
        },
        {"params": [param_dict[pn] for pn in parameters_no_decay], "weight_decay": 0.0},
    ]
    optimizer = optim.AdamW(optim_groups, lr=learning_rate)
    return optimizer

In [22]:
loss = nn.CrossEntropyLoss()
optimizer = get_optimizer(model, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
lr_scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LEARNING_RATE, steps_per_epoch=len(train_loader), epochs=EPOCHS
)
trainer.add_event_handler(
    Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step()
)
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")

In [23]:
val_metrics = {"accuracy": ignite.metrics.Accuracy(), "loss": ignite.metrics.Loss(loss)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=DEVICE)
history = defaultdict(list)

In [24]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    train_state = engine.state
    epoch = train_state.epoch
    max_epochs = train_state.max_epochs
    train_loss = train_state.metrics["loss"]
    history["train loss"].append(train_loss)

    evaluator.run(test_loader)
    val_metrics = evaluator.state.metrics
    val_loss = val_metrics["loss"]
    val_acc = val_metrics["accuracy"]
    history["val loss"].append(val_loss)
    history["val acc"].append(val_acc)

    print(
        "{}/{} - train: loss {:.3f}; val: loss {:.3f} accuracy {:.3f}".format(
            epoch, max_epochs, train_loss, val_loss, val_acc
        )
    )

In [25]:
trainer.run(train_loader, max_epochs=EPOCHS)

1/100 - train: loss 1.882; val: loss 1.817 accuracy 0.329
2/100 - train: loss 1.668; val: loss 1.692 accuracy 0.392
3/100 - train: loss 1.544; val: loss 1.546 accuracy 0.443
4/100 - train: loss 1.491; val: loss 1.463 accuracy 0.469
5/100 - train: loss 1.417; val: loss 1.394 accuracy 0.492
6/100 - train: loss 1.380; val: loss 1.325 accuracy 0.526
7/100 - train: loss 1.318; val: loss 1.289 accuracy 0.538
8/100 - train: loss 1.229; val: loss 1.186 accuracy 0.572
9/100 - train: loss 1.165; val: loss 1.097 accuracy 0.602
10/100 - train: loss 1.060; val: loss 1.028 accuracy 0.640
11/100 - train: loss 0.975; val: loss 0.966 accuracy 0.658
12/100 - train: loss 0.875; val: loss 0.862 accuracy 0.697
13/100 - train: loss 0.822; val: loss 0.797 accuracy 0.727
14/100 - train: loss 0.764; val: loss 0.694 accuracy 0.760
15/100 - train: loss 0.675; val: loss 0.634 accuracy 0.777
16/100 - train: loss 0.662; val: loss 0.621 accuracy 0.784
17/100 - train: loss 0.619; val: loss 0.635 accuracy 0.780
18/100

State:
	iteration: 156300
	epoch: 100
	epoch_length: 1563
	max_epochs: 100
	output: 0.004967710934579372
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

In [26]:
save_path = "./model_cn"

torch.save(model, save_path)

In [32]:
torch.save(model.state_dict(), save_path)

In [27]:
cnex1 = ConvNeXt(
    NUM_CLASSES,
    channel_list=[64, 128, 256, 512],
    num_blocks_list=[2, 2, 2, 2],
    kernel_size=7,
    patch_size=1,
    res_p_drop=0.0,
)

In [34]:
cnex1.load_state_dict(torch.load(save_path))

<All keys matched successfully>

In [35]:
print(cnex1)

ConvNeXt(
  (0): Stem(
    (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): LayerNormChannels(
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
  )
  (1): ConvNeXtBody(
    (0): Stage(
      (0): ConvNeXtBlock(
        (residual): Sequential(
          (0): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64)
          (1): LayerNormChannels(
            (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          )
          (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          (3): GELU(approximate='none')
          (4): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          (5): Dropout(p=0.0, inplace=False)
        )
      )
      (1): ConvNeXtBlock(
        (residual): Sequential(
          (0): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64)
          (1): LayerNormChannels(
            (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          )
    