In [1]:
from typing import List, Sequence

import torch
from torch import nn
from torch import Tensor

In [2]:

from torchvision.ops import StochasticDepth

class LayerScaler(nn.Module):
    """Learnable per-channel scaling ("LayerScale").

    Holds a parameter ``gamma`` created as ``torch.ones(dimensions)`` scaled by
    ``init_value`` and multiplies it into axis 1 of the input. For the intended 1-D
    ``gamma`` that is the channel axis of an (N, C, H, W) tensor.
    """
    def __init__(self, init_value: float, dimensions: int):
        super().__init__()
        initial = torch.ones(dimensions).mul(init_value)
        self.gamma = nn.Parameter(initial)

    def forward(self, x):
        # gamma: (C,) -> (1, C, 1, 1) so it broadcasts over batch and spatial axes
        scale = self.gamma.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        return scale * x
    
class BottleNeckBlock(nn.Module):
    '''Inverted-bottleneck residual block of ConvNeXt.

    Computes ``x + drop_path(layer_scaler(block(x)))`` where ``block`` is
    7x7 depthwise conv -> GroupNorm -> 1x1 expand conv -> GELU -> 1x1 project conv.

    Args:
        in_features: channels of the input feature map.
        out_features: channels produced by ``block``. The residual addition only works
            when this equals ``in_features`` (ConvNextStage always passes equal values).
        expansion: width multiplier of the hidden layer (``out_features * expansion``).
        drop_p: stochastic-depth probability of skipping the residual branch.
        layer_scaler_init_value: initial value of every LayerScaler ``gamma`` entry.
        drop_mode: ``mode`` forwarded to torchvision's ``StochasticDepth``. ``"batch"``
            (the default, unchanged behaviour) drops the branch for the whole batch at
            once; ``"row"`` drops it independently per sample, which is what
            torchvision's own ConvNeXt uses.
    '''
    def __init__(self, in_features: int, out_features: int,
                expansion: int = 4, drop_p: float = .0, layer_scaler_init_value: float = 1e-6,
                drop_mode: str = "batch"):
        super().__init__()
        expanded_features = out_features * expansion
        self.block = nn.Sequential(
            # depthwise 7x7 conv (groups == channels): spatial mixing, channel count unchanged
            nn.Conv2d(
                in_features, in_features, kernel_size=7, padding=3, bias=False, groups=in_features
            ),
            # GroupNorm with a single group normalizes each sample over (C, H, W) jointly.
            # NOTE: this differs from the channels-only LayerNorm of the reference ConvNeXt,
            # which normalizes every spatial position over C alone.
            nn.GroupNorm(num_groups=1, num_channels=in_features),
            # narrow -> wide (pointwise)
            nn.Conv2d(in_features, expanded_features, kernel_size=1),
            nn.GELU(),  # GELU instead of ReLU
            # wide -> narrow (pointwise)
            nn.Conv2d(expanded_features, out_features, kernel_size=1),
        )
        self.layer_scaler = LayerScaler(layer_scaler_init_value, out_features)
        self.drop_path = StochasticDepth(drop_p, mode=drop_mode)

    def forward(self, x: Tensor) -> Tensor:
        branch = self.drop_path(self.layer_scaler(self.block(x)))
        # Out-of-place residual add: never mutates a tensor that autograd or a caller may hold.
        return x + branch


In [3]:
class ConvNextStage(nn.Sequential):
    """One ConvNeXt stage: a downsampling layer followed by ``depth`` residual blocks.

    The downsampling layer (index 0) applies single-group GroupNorm to the incoming
    ``in_features`` channels and then a 2x2, stride-2 convolution to ``out_features``
    channels. Indices 1..depth are BottleNeckBlocks that keep ``out_features`` channels.
    Any extra keyword arguments are forwarded unchanged to every BottleNeckBlock.
    """
    def __init__(self, in_features: int, out_features: int, depth: int, **kwargs):
        downsample = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=in_features),
            nn.Conv2d(in_features, out_features, kernel_size=2, stride=2),
        )
        blocks = []
        for _ in range(depth):
            blocks.append(BottleNeckBlock(out_features, out_features, **kwargs))
        super().__init__(downsample, *blocks)

In [4]:
class ConvNextStem(nn.Sequential):
    '''"Patchify" stem of ConvNeXt.

    Instead of the ResNet-style stem (7x7 stride-2 convolution followed by 3x3 stride-2
    max pooling, i.e. overlapping windows), a single non-overlapping 4x4 convolution with
    stride 4 is used, followed by BatchNorm2d. Spatial size is divided by 4 (floored).
    '''
    def __init__(self, in_features: int, out_features: int):
        patchify = nn.Conv2d(in_features, out_features, kernel_size=4, stride=4)
        norm = nn.BatchNorm2d(out_features)
        super().__init__(patchify, norm)

In [5]:
class ConvNextEncoder(nn.Module):
    """Patchify stem followed by one ConvNextStage per entry of ``depths``/``widths``.

    Args:
        in_channels: channels of the input image.
        stem_features: channels produced by the stem (input width of the first stage).
        depths: number of BottleNeckBlocks in each stage.
        widths: output channels of each stage; stage ``i`` maps ``widths[i - 1]``
            (``stem_features`` for ``i == 0``) to ``widths[i]``.
        drop_p: stochastic-depth rate of the last block. Rates increase linearly from 0
            at the first block to ``drop_p`` at the last block of the whole network.

    Raises:
        ValueError: if ``depths`` and ``widths`` have different lengths (``zip`` would
            otherwise silently drop stages and the last width would not match the head).
    """
    def __init__(self, in_channels: int, stem_features: int,
                depths: List[int], widths: List[int], drop_p: float = .0):
        super().__init__()
        depths, widths = list(depths), list(widths)
        if len(depths) != len(widths):
            raise ValueError(
                f"depths and widths must have the same length, got {len(depths)} and {len(widths)}"
            )
        self.stem = ConvNextStem(in_channels, stem_features)

        # Stochastic-depth decay rule: one rate per block across the whole network,
        # so each block must get its own entry of this schedule (not one per stage).
        block_drop_probs = iter(torch.linspace(0, drop_p, sum(depths)).tolist())

        stage_in_widths = [stem_features, *widths[:-1]]
        self.stages = nn.ModuleList()
        for in_features, out_features, depth in zip(stage_in_widths, widths, depths):
            stage = ConvNextStage(in_features, out_features, depth)
            # ConvNextStage hands a single drop_p to all of its blocks, so give every
            # block its own scheduled rate after construction.
            for module in stage:
                if isinstance(module, BottleNeckBlock):
                    module.drop_path.p = next(block_drop_probs)
            self.stages.append(stage)

    def forward(self, x):
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x

In [6]:
class ClassificationHead(nn.Sequential):
    """Classifier head: global average pool -> flatten -> LayerNorm -> linear.

    Turns a (batch, num_channels, H, W) feature map into (batch, num_classes) logits.
    """
    def __init__(self, num_channels: int, num_classes: int = 1000):
        pool = nn.AdaptiveAvgPool2d((1, 1))
        flatten = nn.Flatten(start_dim=1)
        norm = nn.LayerNorm(num_channels)
        classifier = nn.Linear(num_channels, num_classes)
        super().__init__(pool, flatten, norm, classifier)

class ConvNext(nn.Sequential):
    '''ConvNeXt for image classification: a ConvNextEncoder followed by a ClassificationHead.

    As an ``nn.Sequential``, calling the model runs ``encoder`` and then ``head`` (in the
    order they are assigned below), producing logits with ``num_classes`` entries per sample.

    Args:
        in_channels: channels of the input images.
        stem_features: channels produced by the patchify stem.
        depths: number of blocks in each stage.
        widths: output channels of each stage; ``widths[-1]`` feeds the head.
        drop_p: stochastic-depth rate forwarded to the encoder.
        num_classes: number of output logits.

    The defaults are tuples so no mutable default object is shared between instances;
    both sequences are copied into lists before being handed to the encoder.
    '''
    def __init__(self,
                in_channels: int = 3,
                stem_features: int = 64,
                depths: Sequence[int] = (3, 3, 9, 3),
                widths: Sequence[int] = (96, 192, 384, 768),
                drop_p: float = .0,
                num_classes: int = 1000):

        super().__init__()
        depths, widths = list(depths), list(widths)
        self.encoder = ConvNextEncoder(in_channels, stem_features, depths, widths, drop_p)
        self.head = ClassificationHead(widths[-1], num_classes)

In [7]:
# ConvNeXt-T: stage depths (3, 3, 9, 3) and stage widths (96, 192, 384, 768); all other
# arguments keep ConvNext's defaults (3 input channels, 64 stem features, 1000 classes).
# The bare name on the last line displays the module tree shown below.
convnext_tiny = ConvNext(depths=[3, 3, 9, 3], widths=[96, 192, 384, 768])
convnext_tiny

ConvNext(
  (encoder): ConvNextEncoder(
    (stem): ConvNextStem(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (0): Sequential(
          (0): GroupNorm(1, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 96, kernel_size=(2, 2), stride=(2, 2))
        )
        (1): BottleNeckBlock(
          (block): Sequential(
            (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96, bias=False)
            (1): GroupNorm(1, 96, eps=1e-05, affine=True)
            (2): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1))
            (3): GELU(approximate='none')
            (4): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_scaler): LayerScaler()
          (drop_path): StochasticDepth(p=0.0, mode=batch)
        )
        (2): BottleNeckBlock(
          (block

In [8]:
# ConvNeXt-S: same stage widths as ConvNeXt-T, but a deeper third stage (27 blocks).
convnext_small = ConvNext(depths=[3, 3, 27, 3], widths=[96, 192, 384, 768])
convnext_small

ConvNext(
  (encoder): ConvNextEncoder(
    (stem): ConvNextStem(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (0): Sequential(
          (0): GroupNorm(1, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 96, kernel_size=(2, 2), stride=(2, 2))
        )
        (1): BottleNeckBlock(
          (block): Sequential(
            (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96, bias=False)
            (1): GroupNorm(1, 96, eps=1e-05, affine=True)
            (2): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1))
            (3): GELU(approximate='none')
            (4): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_scaler): LayerScaler()
          (drop_path): StochasticDepth(p=0.0, mode=batch)
        )
        (2): BottleNeckBlock(
          (block

In [9]:
# ConvNeXt-B: depths (3, 3, 27, 3) with wider stages (128 -> 1024 channels).
convnext_base = ConvNext(depths=[3, 3, 27, 3], widths=[128, 256, 512, 1024])
convnext_base

ConvNext(
  (encoder): ConvNextEncoder(
    (stem): ConvNextStem(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (0): Sequential(
          (0): GroupNorm(1, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 128, kernel_size=(2, 2), stride=(2, 2))
        )
        (1): BottleNeckBlock(
          (block): Sequential(
            (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128, bias=False)
            (1): GroupNorm(1, 128, eps=1e-05, affine=True)
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
            (3): GELU(approximate='none')
            (4): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_scaler): LayerScaler()
          (drop_path): StochasticDepth(p=0.0, mode=batch)
        )
        (2): BottleNeckBlock(
         

In [10]:
# ConvNeXt-L: depths (3, 3, 27, 3) with stage widths 192 -> 1536 channels.
convnext_large = ConvNext(depths=[3, 3, 27, 3], widths=[192, 384, 768, 1536])
convnext_large

ConvNext(
  (encoder): ConvNextEncoder(
    (stem): ConvNextStem(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (0): Sequential(
          (0): GroupNorm(1, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 192, kernel_size=(2, 2), stride=(2, 2))
        )
        (1): BottleNeckBlock(
          (block): Sequential(
            (0): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192, bias=False)
            (1): GroupNorm(1, 192, eps=1e-05, affine=True)
            (2): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1))
            (3): GELU(approximate='none')
            (4): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_scaler): LayerScaler()
          (drop_path): StochasticDepth(p=0.0, mode=batch)
        )
        (2): BottleNeckBlock(
         

In [11]:
# ConvNeXt-XL: depths (3, 3, 27, 3) with stage widths 256 -> 2048 channels.
# NOTE(review): every variant built above stays alive as a notebook global; consider
# constructing only the variant you actually need to keep memory usage down.
convnext_xlarge = ConvNext(depths=[3, 3, 27, 3], widths=[256, 512, 1024, 2048])
convnext_xlarge

ConvNext(
  (encoder): ConvNextEncoder(
    (stem): ConvNextStem(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (0): Sequential(
          (0): GroupNorm(1, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 256, kernel_size=(2, 2), stride=(2, 2))
        )
        (1): BottleNeckBlock(
          (block): Sequential(
            (0): Conv2d(256, 256, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=256, bias=False)
            (1): GroupNorm(1, 256, eps=1e-05, affine=True)
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
            (3): GELU(approximate='none')
            (4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_scaler): LayerScaler()
          (drop_path): StochasticDepth(p=0.0, mode=batch)
        )
        (2): BottleNeckBlock(
       

In [13]:
import sys
import logging
import matplotlib.pyplot as plt
from logging.handlers import RotatingFileHandler

def plot_losses(train_losses, valid_losses, path, valid_epochs=None):
    """Save a line plot of training and validation losses to ``{path}/losses.png``.

    Args:
        train_losses: one loss value per point, plotted at x = 0, 1, 2, ...
        valid_losses: validation loss values.
        path: existing directory the figure is written into.
        valid_epochs: optional x positions for ``valid_losses``. When omitted the
            validation curve is plotted at 0, 1, 2, ... (the previous behaviour), which
            only lines up with ``train_losses`` if validation ran at every point.
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(train_losses, label='train')
        if valid_epochs is None:
            ax.plot(valid_losses, label='valid')
        else:
            ax.plot(valid_epochs, valid_losses, label='valid')
        ax.set_xlabel("epoch")
        ax.set_ylabel("loss")
        ax.set_title("Training and validation loss")
        # the 'train'/'valid' labels are only rendered once a legend is drawn
        ax.legend()
        fig.savefig(f"{path}/losses.png")
    finally:
        # close (not just clear) the figure so repeated calls don't accumulate open figures
        plt.close(fig)


class CustomLogger:
    """Small wrapper around a named :mod:`logging` logger that records INFO messages.

    Messages go to a size-rotated file when ``file_path`` is given, otherwise to stdout,
    formatted as ``"<level> - <time> - <message>"``.

    Args:
        name: name passed to ``logging.getLogger``.
        file_path: optional log file path.
        log_size: size in bytes at which the log file is rotated.
        backup_count: number of rotated log files to keep.

    Note:
        ``logging.getLogger`` returns the same process-wide logger for a given name, so
        constructing another CustomLogger with the same name removes and closes the
        handlers installed earlier instead of stacking duplicates (which would repeat
        every line, e.g. when a notebook cell is re-run). Construction also renames the
        INFO level to "[INF]" via ``logging.addLevelName``, which is process-wide.
    """

    def __init__(self,
                 name,
                 file_path=None,
                 log_size=10 * 1024 * 1024,
                 backup_count=5):
        self.log_size = log_size
        self.backup_count = backup_count
        self._init_logger(name, file_path)

    def log_info(self, message):
        """Log ``message`` at INFO level."""
        self.logger.info(message)

    def close(self):
        """Detach and close every handler attached to the underlying logger."""
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)
            handler.close()

    def _init_logger(self, name, file_path):
        """Configure ``self.logger`` with exactly one file or stdout handler."""
        logging.addLevelName(logging.INFO, "[INF]")

        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)
        # Drop handlers left over from a previous instance with the same name.
        self.close()
        self.formatter = logging.Formatter(
            "%(levelname)s - %(asctime)s - %(message)s"
        )

        if file_path:
            file_handler = RotatingFileHandler(file_path,
                                               maxBytes=self.log_size,
                                               backupCount=self.backup_count)
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(self.formatter)
            self.logger.addHandler(file_handler)
        else:
            stream_handler = logging.StreamHandler(sys.stdout)
            stream_handler.setFormatter(self.formatter)
            self.logger.addHandler(stream_handler)



In [24]:
import torch
from torch import nn
from torchinfo import summary
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

from datasets import load_dataset
from torchvision.transforms import (CenterCrop, Compose, ToTensor, RandAugment)

from timm.data.mixup import Mixup
from timm.data.random_erasing import RandomErasing

import os
import argparse
from tqdm.auto import tqdm
from datetime import datetime
from sklearn.metrics import classification_report


def _make_hf_transform(image_transform):
    """Wrap a per-image transform so it can be passed to ``Dataset.set_transform``.

    The returned function maps a batch dict with ``"img"`` (PIL images) and ``"label"``
    columns to ``{"images": [...], "labels": [...]}``; every image is converted to RGB
    before ``image_transform`` is applied.
    """
    def apply(examples):
        images = [image_transform(image.convert("RGB")) for image in examples["img"]]
        labels = list(examples["label"])
        return {"images": images, "labels": labels}

    return apply


def _train_one_epoch(model, dataloader, optimizer, lr_scheduler, loss_fn,
                     mixup, rand_erasing, device, use_clip_grad, epoch, global_step):
    """Run one training epoch with Mixup/CutMix and random erasing.

    The LR scheduler is stepped once per batch. Returns ``(mean_loss, logs, global_step)``
    where ``logs`` is the progress-bar dict of the last batch (empty if the dataloader
    yielded nothing) and ``global_step`` has been advanced by the number of batches.
    """
    model.train()
    running_loss = 0.0
    num_steps = 0
    logs = {}
    progress_bar = tqdm(total=len(dataloader))
    progress_bar.set_description(f"Epoch {epoch}")
    for batch in dataloader:
        images = batch["images"].to(device)
        labels = batch["labels"].to(device)

        images, labels = mixup(images, labels)
        images = rand_erasing(images)

        loss = loss_fn(model(images), labels)
        loss.backward()

        if use_clip_grad:
            clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        loss_value = loss.item()
        running_loss += loss_value
        num_steps += 1
        logs = {
            "loss_avg": running_loss / num_steps,
            "loss": loss_value,
            "lr": lr_scheduler.get_last_lr()[0],
            "step": global_step
        }
        progress_bar.update(1)
        progress_bar.set_postfix(**logs)
        global_step += 1
    progress_bar.close()
    return running_loss / max(num_steps, 1), logs, global_step


def _evaluate(model, dataloader, loss_fn, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns ``(mean per-batch loss, true labels, argmax predictions)``; the two label
    lists follow dataloader order.
    """
    model.eval()
    total_loss = 0.0
    true_labels, pred_labels = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            images = batch["images"].to(device)
            labels = batch["labels"].to(device)

            logits = model(images)
            total_loss += loss_fn(logits, labels).item()

            true_labels.extend(labels.detach().cpu().tolist())
            pred_labels.extend(logits.argmax(dim=-1).detach().cpu().tolist())
    return total_loss / max(len(dataloader), 1), true_labels, pred_labels


def main(args):
    """Train a ConvNeXt-Tiny classifier on a Hugging Face image dataset.

    Reads these attributes from ``args``: dataset_name, num_classes, resolution,
    train_batch_size, eval_batch_size, learning_rate, adam_beta1, adam_beta2,
    adam_weight_decay, adam_epsilon, num_epochs, use_clip_grad, save_model_epochs,
    output_dir. The dataset must provide "train" and "test" splits with "img" and
    "label" columns.

    Every ``save_model_epochs`` epochs, and after the final epoch, the model is evaluated
    on the test split, results are logged under ``./training_logs/<timestamp>/`` and a
    checkpoint is written to ``<output_dir>/convnext_tiny.pt``.

    Returns:
        The trained model.

    Raises:
        ValueError: if no dataset name is given, or if ``train_batch_size`` is odd
            (timm's Mixup pairs samples within a batch and requires an even size).
    """
    if not args.dataset_name:
        raise ValueError(
            "You must specify a dataset name."
        )
    if args.train_batch_size % 2 != 0:
        raise ValueError(
            f"train_batch_size must be even for Mixup/CutMix, got {args.train_batch_size}."
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    mixup_args = {
        'mixup_alpha': 0.8,
        'cutmix_alpha': 1.0,
        'cutmix_minmax': None,
        'prob': 0.4,
        'switch_prob': 0.5,
        'mode': 'elem',
        'label_smoothing': 0.1,
        'num_classes': args.num_classes
    }
    mixup = Mixup(**mixup_args)
    # timm's RandomErasing creates its noise on `device` (default "cuda"); pass the
    # training device so CPU-only runs work.
    rand_erasing = RandomErasing(probability=0.25, max_area=1/4, mode="pixel", device=device)

    # Build a fresh model on every call. Reusing the notebook-global instance would
    # silently resume from a previous run's weights, and its default 1000-way head would
    # not match the `num_classes`-wide soft targets produced by Mixup.
    model = ConvNext(
        depths=[3, 3, 9, 3],
        widths=[96, 192, 384, 768],
        num_classes=args.num_classes,
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    transforms_train = Compose([
        CenterCrop(args.resolution),
        RandAugment(num_ops=2),
        ToTensor()
    ])
    transforms_test = Compose([
        CenterCrop(args.resolution),
        ToTensor()
    ])

    train_data, test_data = load_dataset(args.dataset_name,
                                         split=["train", "test"])
    train_data.set_transform(_make_hf_transform(transforms_train))
    test_data.set_transform(_make_hf_transform(transforms_test))

    train_dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.train_batch_size,
        shuffle=True,
        # Mixup asserts an even batch size; a trailing partial batch could be odd.
        drop_last=True,
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_data, batch_size=args.eval_batch_size, shuffle=False
    )

    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.learning_rate,
        steps_per_epoch=len(train_dataloader),
        epochs=args.num_epochs
    )
    loss_fn = nn.CrossEntropyLoss()

    current_date = datetime.today().strftime('%Y%m%d_%H%M%S')
    logs_path = f"./training_logs/{current_date}/"
    os.makedirs(logs_path, exist_ok=True)
    # `output_dir` names a directory; the checkpoint is a file inside it.
    os.makedirs(args.output_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.output_dir, "convnext_tiny.pt")

    logger = CustomLogger("convnext_tiny", file_path=f"{logs_path}/training_log.txt")
    try:
        model_summary = str(summary(model, (1, 3, args.resolution, args.resolution), verbose=0))
        logger.log_info(model_summary)

        global_step = 0
        train_losses = []
        valid_losses = []
        for epoch in range(args.num_epochs):
            train_loss, logs, global_step = _train_one_epoch(
                model, train_dataloader, optimizer, lr_scheduler, loss_fn,
                mixup, rand_erasing, device, args.use_clip_grad, epoch, global_step,
            )
            train_losses.append(train_loss)

            is_last_epoch = epoch == args.num_epochs - 1
            if epoch % args.save_model_epochs != 0 and not is_last_epoch:
                continue

            valid_loss, valid_labels, valid_preds = _evaluate(
                model, test_dataloader, loss_fn, device
            )
            report = classification_report(valid_labels, valid_preds)

            # print for debug
            print(f"Valid loss: {valid_loss}")
            print(report)

            logger.log_info(f"Epoch {epoch}")
            logger.log_info(logs)
            logger.log_info(f"Valid loss: {valid_loss}")
            logger.log_info(report)

            torch.save(
                {
                    'epoch': epoch,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                }, checkpoint_path
            )

            epoch_path = os.path.join(logs_path, str(epoch))
            os.makedirs(epoch_path, exist_ok=True)

            valid_losses.append(valid_loss)
            plot_losses(
                train_losses=train_losses,
                valid_losses=valid_losses,
                path=epoch_path
            )
    finally:
        # Release the log file. Loggers are process-wide, so a handler left attached
        # would keep writing into this run's file on the next call.
        for handler in list(logger.logger.handlers):
            logger.logger.removeHandler(handler)
            handler.close()

    return model

In [None]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train ConvNeXt")
    parser.add_argument("--dataset_name", type=str, default=None)
    